Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 29 additions & 13 deletions src/stdlib/mexpr/lamlift.mc
Original file line number Diff line number Diff line change
Expand Up @@ -533,25 +533,37 @@ end
lang MExprLambdaLiftAllowSpineCapture =
MExprLambdaLift + MExprAsDecl

sem findFreeVariablesSpine : LambdaLiftState -> Expr -> LambdaLiftState
sem findFreeVariablesSpine state =
syn AllowCapture =
| AllowCapture
| DisallowCapture

sem findFreeVariablesSpine : ({ty : Type, body : Expr} -> AllowCapture) -> LambdaLiftState -> Expr -> LambdaLiftState
sem findFreeVariablesSpine shouldAllowCapture state =
| TmLet t ->
let state =
match t.body with TmLam _ then
-- NOTE(vipa, 2023-10-09): A let-bound lambda, find a solution
-- for it
let sol = findFreeVariablesInBody state _solEmpty t.body in
{state with sols = mapInsert t.ident sol state.sols}
else
-- NOTE(vipa, 2025-01-14): A normal variable along the spine,
-- don't treat it as free later
state
else switch shouldAllowCapture {ty = t.tyBody, body = t.body}
case AllowCapture _ then
-- NOTE(vipa, 2025-01-14): Not adding a variable to the
-- state means we don't add it as a parameter to functions
-- that close over it, i.e., we allow their capture.
state
case DisallowCapture _ then
-- NOTE(vipa, 2025-05-16): A normal variable that should not
-- be implicitly captured, but rather passed as an explicit
-- argument.
{state with vars = mapInsert t.ident t.tyBody state.vars}
end
in
let state =
let tyvars = concat (stripTyAll t.tyAnnot).0 (stripTyAll t.tyBody).0 in
foldl (lam acc. lam pair. {acc with tyVars = mapInsert pair.0 pair.1 acc.tyVars}) state tyvars in
let state = findFreeVariables state t.body in
findFreeVariablesSpine state t.inexpr
findFreeVariablesSpine shouldAllowCapture state t.inexpr
| tm & TmRecLets t -> recursive
let insertInitialSolution = lam state. lam binding.
let sol = findFreeVariablesInBody state _solEmpty binding.body in
Expand Down Expand Up @@ -587,23 +599,27 @@ lang MExprLambdaLiftAllowSpineCapture =
let sccs = digraphTarjan g in
let state = propagateFunNames state (reverse sccs) in
let state = foldl findFreeVariablesBinding state t.bindings in
findFreeVariablesSpine state t.inexpr
findFreeVariablesSpine shouldAllowCapture state t.inexpr
| TmExt t ->
let state = {state with sols = mapInsert t.ident _solEmpty state.sols} in
findFreeVariablesSpine state t.inexpr
findFreeVariablesSpine shouldAllowCapture state t.inexpr
| tm ->
match exprAsDecl tm with Some (decl, inexpr) then
let state = sfold_Decl_Expr findFreeVariables state decl in
findFreeVariablesSpine state inexpr
findFreeVariablesSpine shouldAllowCapture state inexpr
else findFreeVariables state tm

sem liftLambdasWithSolutionsAllowSpineCapture : Expr -> (Map Name FinalOrderedLamLiftSolution, Expr)
sem liftLambdasWithSolutionsAllowSpineCapture = | t ->
sem liftLambdasWithSolutionsMaybeAllowSpineCapture : ({ty : Type, body : Expr} -> AllowCapture) -> Expr -> (Map Name FinalOrderedLamLiftSolution, Expr)
sem liftLambdasWithSolutionsMaybeAllowSpineCapture shouldAllowCapture = | t ->
let t = nameAnonymousLambdas t in
let state = findFreeVariablesSpine emptyLambdaLiftState t in
let state = findFreeVariablesSpine shouldAllowCapture emptyLambdaLiftState t in
let t = insertFreeVariables state.sols t in
let t = liftGlobal t in
replaceCapturedParameters state.sols t

sem liftLambdasWithSolutionsAllowSpineCapture : Expr -> (Map Name FinalOrderedLamLiftSolution, Expr)
sem liftLambdasWithSolutionsAllowSpineCapture = | t ->
liftLambdasWithSolutionsMaybeAllowSpineCapture (lam. AllowCapture ()) t
end

lang TestLang =
Expand Down
Loading