Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
4d34998
Update test to expose HeapParameterization bug
keyboardDrummer May 5, 2026
7760c92
Add fix
keyboardDrummer May 5, 2026
f8acfc6
Update
keyboardDrummer May 5, 2026
39189f3
update comment
keyboardDrummer May 5, 2026
515f13d
Replace panic
keyboardDrummer May 5, 2026
4c72dc4
Fixes
keyboardDrummer May 5, 2026
9788a1b
Fix bug in ConstrainedTypeElim
keyboardDrummer May 5, 2026
60cedac
Merge branch 'main' into heapParamScopeBug
keyboardDrummer May 5, 2026
a5a6d5c
Fix test
keyboardDrummer May 5, 2026
702bf34
Update Strata/Languages/Laurel/LaurelCompilationPipeline.lean
keyboardDrummer May 5, 2026
906fe44
Fix oops
keyboardDrummer May 5, 2026
dd17265
Rename Build/ to IntermediatePrograms/ and fix trailing newline
keyboardDrummer-bot May 18, 2026
d5e90af
Add comment
keyboardDrummer May 18, 2026
66c9707
Refactor HeapParameterization to use list-returning traversal (#1192)
keyboardDrummer-bot May 20, 2026
426d53b
Remove unused mapStmtExprFlattenGoM and move wrapList to HeapParamete…
keyboardDrummer-bot May 20, 2026
e01f41f
Fix consistency check: use dbg_trace instead of aborting pipeline
keyboardDrummer-bot May 20, 2026
d9d7889
Remove dbg_trace from consistency check
keyboardDrummer-bot May 20, 2026
1d54bda
Merge remote-tracking branch 'origin/main' into heapParamScopeBug
MikaelMayer May 20, 2026
b381cc0
Factor out heap variable name constants and move intermediate program…
MikaelMayer May 20, 2026
bcf4c4a
Merge remote-tracking branch 'origin/main' into heapParamScopeBug
MikaelMayer May 20, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@ vcs/*.smt2
*.py.ion
*.py.ion.core.st

Strata.code-workspace
Strata.code-workspace
IntermediatePrograms/
2 changes: 1 addition & 1 deletion Strata/Languages/Laurel/ConstrainedTypeElim.lean
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ private def mkWitnessProc (ptMap : ConstrainedTypeMap) (ct : ConstrainedType) :
{ name := mkId s!"$witness_{ct.name.text}"
inputs := []
outputs := []
body := .Transparent ⟨.Block [witnessInit, assert] none, src⟩
body := .Opaque [] (some ⟨.Block [witnessInit, assert] none, src⟩) []
Comment thread
MikaelMayer marked this conversation as resolved.
preconditions := []
isFunctional := false
decreases := none }
Expand Down
141 changes: 71 additions & 70 deletions Strata/Languages/Laurel/HeapParameterization.lean
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ public import Strata.Languages.Laurel.Laurel
public import Strata.Languages.Laurel.Grammar.AbstractToConcreteTreeTranslator
public import Strata.Languages.Laurel.LaurelTypes
public import Strata.Languages.Laurel.HeapParameterizationConstants
public import Strata.Languages.Laurel.MapStmtExpr
public import Strata.Util.Tactics

/-
Expand Down Expand Up @@ -260,22 +261,25 @@ Transform an expression, adding heap parameters where needed.
- `valueUsed`: whether the result value of this expression is used (affects optimization of heap-writing calls)
-/
def heapTransformExpr (heapVar : Identifier) (model: SemanticModel) (expr : StmtExprMd) (valueUsed : Bool := true) : TransformM StmtExprMd :=
recurse expr valueUsed
recurseOne expr valueUsed
where
recurse (exprMd : StmtExprMd) (valueUsed : Bool := true) : TransformM StmtExprMd := do
recurseOne (exprMd : StmtExprMd) (valueUsed : Bool := true) : TransformM StmtExprMd :=
wrapList exprMd.source <$> recurse exprMd valueUsed
termination_by (sizeOf exprMd, 1)
recurse (exprMd : StmtExprMd) (valueUsed : Bool := true) : TransformM (List StmtExprMd) := do
let ⟨expr, source⟩ := exprMd
match _h : expr with
| .Var (.Field selectTarget fieldName) => do
let some qualifiedName := resolveQualifiedFieldName model fieldName
| return ⟨ .Hole, source ⟩
| return [⟨ .Hole, source ⟩]

let valTy := (model.get fieldName).getType
let readExpr := ⟨ .StaticCall "readField" [mkMd (.Var (.Local heapVar)), selectTarget, mkMd (.StaticCall qualifiedName [])], source ⟩
-- Unwrap Box: apply the appropriate destructor
recordBoxConstructor model valTy.val
return mkMd <| .StaticCall (boxDestructorName model valTy.val) [readExpr]
return [mkMd <| .StaticCall (boxDestructorName model valTy.val) [readExpr]]
| .StaticCall callee args =>
let args' ← args.mapM (recurse ·)
let args' ← args.mapM (recurseOne ·)
let calleeReadsHeap ← readsHeap callee
let calleeWritesHeap ← writesHeap callee
if calleeWritesHeap then
Expand All @@ -284,7 +288,7 @@ where
let callWithHeap := ⟨ .Assign
[mkVarMd (.Local heapVar), mkVarMd (.Declare ⟨freshVar, computeExprType model exprMd⟩)]
(⟨ .StaticCall callee (mkMd (.Var (.Local heapVar)) :: args'), source ⟩), source ⟩
return ⟨ .Block [callWithHeap, mkMd (.Var (.Local freshVar))] none, source ⟩
return [callWithHeap, mkMd (.Var (.Local freshVar))]
else
-- Generate throwaway Declare targets for any non-heap outputs
let procOutputs := match model.get callee with
Expand All @@ -294,18 +298,18 @@ where
let extraTargets ← procOutputs.mapM fun out => do
pure (mkVarMd (.Declare ⟨← freshVarName, out.type⟩))
let allTargets := mkVarMd (.Local heapVar) :: extraTargets
return ⟨ .Assign allTargets (⟨ .StaticCall callee (mkMd (.Var (.Local heapVar)) :: args'), source ⟩), source ⟩
return [⟨ .Assign allTargets (⟨ .StaticCall callee (mkMd (.Var (.Local heapVar)) :: args'), source ⟩), source ⟩]
else if calleeReadsHeap then
return ⟨ .StaticCall callee (mkMd (.Var (.Local heapVar)) :: args'), source ⟩
return [⟨ .StaticCall callee (mkMd (.Var (.Local heapVar)) :: args'), source ⟩]
else
return ⟨ .StaticCall callee args', source ⟩
return [⟨ .StaticCall callee args', source ⟩]
| .InstanceCall callTarget callee args =>
let t ← recurse callTarget
let args' ← args.mapM (recurse ·)
return ⟨ .InstanceCall t callee args', source ⟩
let t ← recurseOne callTarget
let args' ← args.mapM (recurseOne ·)
return [⟨ .InstanceCall t callee args', source ⟩]
| .IfThenElse c t e =>
let e' ← match e with | some x => some <$> recurse x valueUsed | none => pure none
return ⟨ .IfThenElse (← recurse c) (← recurse t valueUsed) e', source ⟩
let e' ← match e with | some x => some <$> recurseOne x valueUsed | none => pure none
return [⟨ .IfThenElse (← recurseOne c) (← recurseOne t valueUsed) e', source ⟩]
| .Block stmts label =>
let n := stmts.length
let rec processStmts (idx : Nat) (remaining : List StmtExprMd) : TransformM (List StmtExprMd) := do
Expand All @@ -315,16 +319,16 @@ where
let isLast := idx == n - 1
let s' ← recurse s (isLast && valueUsed)
let rest' ← processStmts (idx + 1) rest
pure (s' :: rest')
termination_by sizeOf remaining
pure (s' ++ rest')
termination_by (sizeOf remaining, 0)
let stmts' ← processStmts 0 stmts
return ⟨ .Block stmts' label, source ⟩
return [⟨ .Block stmts' label, source ⟩]
| .While c invs d b =>
let invs' ← invs.mapM (recurse ·)
return ⟨ .While (← recurse c) invs' d (← recurse b false), source ⟩
let invs' ← invs.mapM (recurseOne ·)
return [⟨ .While (← recurseOne c) invs' d (← recurseOne b false), source ⟩]
| .Return v =>
let v' ← match v with | some x => some <$> recurse x | none => pure none
return ⟨ .Return v', source ⟩
let v' ← match v with | some x => some <$> recurseOne x | none => pure none
return [⟨ .Return v', source ⟩]
| .Assign targets v =>

-- Process field targets
Expand All @@ -338,7 +342,7 @@ where
let valTy := (model.get fieldName).getType
recordBoxConstructor model valTy.val
let freshVar ← freshVarName
let target' ← recurse target
let target' ← recurseOne target
let boxedVal := mkMd <| .StaticCall (boxConstructorName model valTy.val) [mkMd (.Var (.Local freshVar))]
let updateStmt : StmtExprMd := ⟨ .Assign [mkVarMd (.Local heapVar)]
(mkMd (.StaticCall "updateField" [mkMd (.Var (.Local heapVar)), target', mkMd (.StaticCall qualifiedName []), boxedVal])), source ⟩
Expand All @@ -350,7 +354,7 @@ where
-- Detect calls and add a heap argument if needed
let (v', addedHeap) <- match _hv : v.val with
| .StaticCall callee args => do
let args' <- args.mapM recurse
let args' <- args.mapM recurseOne
let calleeWritesHeap ← writesHeap callee
let calleeReadsHeap ← readsHeap callee
if calleeWritesHeap then
Expand All @@ -360,11 +364,11 @@ where
else
pure (⟨ .StaticCall callee args', v.source ⟩, false)
| .InstanceCall callTarget _callee args => do
let _callTarget' ← recurse callTarget
let _args' <- args.mapM recurse
let _callTarget' ← recurseOne callTarget
let _args' <- args.mapM recurseOne
pure (⟨ .InstanceCall _callTarget' _callee _args', v.source ⟩, false)
| _ =>
pure (<- recurse v, false)
pure (<- recurseOne v, false)
let allTargets := if addedHeap
then ⟨ Variable.Local heapVar, v.source ⟩ :: processedTargets
else processedTargets
Expand All @@ -387,15 +391,12 @@ where
else updateStatements
pure (newAssign, suffixes)

-- Create a block if necessary
if suffixes.length > 0 then
return ⟨ StmtExpr.Block (newAssign :: suffixes) none, source ⟩
else
return newAssign
-- Return the list of statements directly (flattened into enclosing block)
return newAssign :: suffixes

| .PureFieldUpdate t f v => return ⟨ .PureFieldUpdate (← recurse t) f (← recurse v), source ⟩
| .PureFieldUpdate t f v => return [⟨ .PureFieldUpdate (← recurseOne t) f (← recurseOne v), source ⟩]
| .PrimitiveOp op args =>
let args' ← args.mapM (recurse ·)
let args' ← args.mapM (recurseOne ·)
-- For == and != on Composite types, compare refs instead
match op, args with
| .Eq, [e1, _e2] =>
Expand All @@ -404,54 +405,54 @@ where
| .UserDefined _ =>
let ref1 := mkMd (.StaticCall "Composite..ref!" [args'[0]!])
let ref2 := mkMd (.StaticCall "Composite..ref!" [args'[1]!])
return ⟨ .PrimitiveOp .Eq [ref1, ref2], source ⟩
| _ => return ⟨ .PrimitiveOp op args', source ⟩
return [⟨ .PrimitiveOp .Eq [ref1, ref2], source ⟩]
| _ => return [⟨ .PrimitiveOp op args', source ⟩]
| .Neq, [e1, _e2] =>
let ty := (computeExprType model e1).val
match ty with
| .UserDefined _ =>
let ref1 := mkMd (.StaticCall "Composite..ref!" [args'[0]!])
let ref2 := mkMd (.StaticCall "Composite..ref!" [args'[1]!])
return ⟨ .PrimitiveOp .Neq [ref1, ref2], source ⟩
| _ => return ⟨ .PrimitiveOp op args', source ⟩
| _, _ => return ⟨ .PrimitiveOp op args', source ⟩
| .New _ => return exprMd
| .ReferenceEquals l r => return ⟨ .ReferenceEquals (← recurse l) (← recurse r), source ⟩
return [⟨ .PrimitiveOp .Neq [ref1, ref2], source ⟩]
| _ => return [⟨ .PrimitiveOp op args', source ⟩]
| _, _ => return [⟨ .PrimitiveOp op args', source ⟩]
| .New _ => return [exprMd]
| .ReferenceEquals l r => return [⟨ .ReferenceEquals (← recurseOne l) (← recurseOne r), source ⟩]
| .AsType t ty =>
let t' ← recurse t valueUsed
let t' ← recurseOne t valueUsed
let isCheck := ⟨ .IsType t' ty, source ⟩
let assertStmt := ⟨ .Assert { condition := isCheck }, source ⟩
return ⟨ .Block [assertStmt, t'] none, source ⟩
| .IsType t ty => return ⟨ .IsType (← recurse t) ty, source ⟩
return [⟨ .Block [assertStmt, t'] none, source ⟩]
| .IsType t ty => return [⟨ .IsType (← recurseOne t) ty, source ⟩]
| .Quantifier mode p trigger b =>
let trigger' ← trigger.attach.mapM fun ⟨t, _⟩ => recurse t
return ⟨.Quantifier mode p trigger' (← recurse b), source⟩
| .Assigned n => return ⟨ .Assigned (← recurse n), source ⟩
| .Old v => return ⟨ .Old (← recurse v), source ⟩
| .Fresh v => return ⟨ .Fresh (← recurse v), source ⟩
let trigger' ← trigger.attach.mapM fun ⟨t, _⟩ => recurseOne t
return [⟨.Quantifier mode p trigger' (← recurseOne b), source⟩]
| .Assigned n => return [⟨ .Assigned (← recurseOne n), source ⟩]
| .Old v => return [⟨ .Old (← recurseOne v), source ⟩]
| .Fresh v => return [⟨ .Fresh (← recurseOne v), source ⟩]
| .Assert ⟨condExpr, summary⟩ =>
return ⟨ .Assert { condition := ← recurse condExpr, summary }, source ⟩
| .Assume c => return ⟨ .Assume (← recurse c), source ⟩
| .ProveBy v p => return ⟨ .ProveBy (← recurse v) (← recurse p), source ⟩
| .ContractOf ty f => return ⟨ .ContractOf ty (← recurse f), source ⟩
| _ => return exprMd
termination_by sizeOf exprMd
decreasing_by
all_goals simp_wf
all_goals (try have := AstNode.sizeOf_val_lt exprMd)
all_goals (try have := AstNode.sizeOf_val_lt v)
all_goals (try term_by_mem)
all_goals (try (cases exprMd; simp_all; omega))
-- For field inner expressions in attach-based:
all_goals (try (
have := List.sizeOf_lt_of_mem ‹_›
have := Variable.sizeOf_field_target_lt_of_eq _htv
omega))
-- Remaining goals
all_goals (
cases exprMd with | mk val src mmd =>
simp_all
omega)
return [⟨ .Assert { condition := ← recurseOne condExpr, summary }, source ⟩]
| .Assume c => return [⟨ .Assume (← recurseOne c), source ⟩]
| .ProveBy v p => return [⟨ .ProveBy (← recurseOne v) (← recurseOne p), source ⟩]
| .ContractOf ty f => return [⟨ .ContractOf ty (← recurseOne f), source ⟩]
| _ => return [exprMd]
termination_by (sizeOf exprMd, 0)
decreasing_by
all_goals simp_wf
all_goals (try have := AstNode.sizeOf_val_lt exprMd)
all_goals (try have := AstNode.sizeOf_val_lt v)
all_goals (try term_by_mem)
all_goals (try (cases exprMd; simp_all; omega))
-- For field inner expressions in attach-based:
all_goals (try (
have := List.sizeOf_lt_of_mem ‹_›
have := Variable.sizeOf_field_target_lt_of_eq _htv
omega))
-- Remaining goals
all_goals (
cases exprMd with | mk val src =>
simp_all
omega)

def heapTransformProcedure (model: SemanticModel) (proc : Procedure) : TransformM Procedure := do
let heapName : Identifier := "$heap"
Expand Down
9 changes: 9 additions & 0 deletions Strata/Languages/Laurel/LaurelCompilationPipeline.lean
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,15 @@ private def runLaurelPasses (options : LaurelTranslateOptions) (program : Progra
-- Run resolve after the pass if needed
if pass.needsResolves then
let result := resolve program (some model)
let newErrors := result.errors.filter fun e => !resolutionErrors.contains e
if !newErrors.isEmpty then
let newDiags := newErrors.toList.map fun d =>
{ d with
message :=
s!"Internal error: resolution after '{pass.name}' introduced this diagnostic: {d.message}"
type := .StrataBug }
emit pass.name "laurel.st" program
return (program, model, allDiags ++ newDiags, allStats)
program := result.program
model := result.model
emit pass.name "laurel.st" program
Comment thread
MikaelMayer marked this conversation as resolved.
Expand Down
Loading
Loading