Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions Examples/expected/HeapReasoning.core.expected
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ Successfully parsed.
HeapReasoning.core.st(98, 2) [modifiesFrameRef1]: ✅ pass
HeapReasoning.core.st(103, 2) [modifiesFrameRef1]: ✅ pass
HeapReasoning.core.st(108, 2) [modifiesFrameRef1]: ✅ pass
[Container_ctor_ensures_4]: ✅ pass
HeapReasoning.core.st(86, 2) [Container_ctor_ensures_7]: ✅ pass
HeapReasoning.core.st(87, 2) [Container_ctor_ensures_8]: ✅ pass
HeapReasoning.core.st(88, 2) [Container_ctor_ensures_9]: ✅ pass
Expand All @@ -12,7 +11,6 @@ HeapReasoning.core.st(169, 2) [modifiesFrameRef2]: ✅ pass
HeapReasoning.core.st(172, 2) [modifiesFrameRef1Next]: ✅ pass
HeapReasoning.core.st(177, 2) [modifiesFrameRef2Next]: ✅ pass
HeapReasoning.core.st(132, 2) [UpdateContainers_ensures_5]: ✅ pass
[UpdateContainers_ensures_6]: ✅ pass
HeapReasoning.core.st(150, 2) [UpdateContainers_ensures_14]: ✅ pass
HeapReasoning.core.st(151, 2) [UpdateContainers_ensures_15]: ✅ pass
HeapReasoning.core.st(152, 2) [UpdateContainers_ensures_16]: ✅ pass
Expand Down Expand Up @@ -51,5 +49,4 @@ HeapReasoning.core.st(238, 2) [c2Pineapple0]: ✅ pass
HeapReasoning.core.st(240, 2) [c1NextEqC2]: ✅ pass
HeapReasoning.core.st(241, 2) [c2NextEqC1]: ✅ pass
HeapReasoning.core.st(195, 2) [Main_ensures_1]: ✅ pass
[Main_ensures_2]: ✅ pass
All 53 goals passed.
All 50 goals passed.
4 changes: 2 additions & 2 deletions Strata/Languages/Core/ProcedureEval.lean
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,13 @@ def eval (E : Env) (p : Procedure) : Env × Statistics :=
match check.attr with
| .Free =>
-- NOTE: A free postcondition is not checked.
-- We simply change a free-postcondition to "true", but
-- We simply change a free-postcondition to "assume true", but
-- keep a record in the metadata field.
-- TODO: Perhaps introduce an "opaque" expression construct
-- that hides the expression from the evaluator, allowing us
-- to retain the postcondition body instead of replacing it
-- with "true".
(.assert label (.true ())
(.assume label (.true ())
((Imperative.MetaData.pushElem
#[]
(.label label)
Expand Down
135 changes: 78 additions & 57 deletions Strata/Languages/Laurel/CoreGroupingAndOrdering.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,17 @@
-/

module
public import Strata.Languages.Laurel.Laurel
public import Strata.Languages.Laurel.TransparencyPass
import Strata.DL.Lambda.LExpr
import Strata.DDM.Util.Graph.Tarjan
import Strata.Languages.Laurel.Grammar.AbstractToConcreteTreeTranslator

/-!
## Grouping and Ordering for Core Translation

Utilities for computing the grouping and topological ordering of Laurel
declarations before they are emitted as Strata Core declarations.

- `groupDatatypesByScc` — groups mutually recursive datatypes into SCC groups
using Tarjan's SCC algorithm.
- `computeSccDecls` — builds the procedure call graph, runs Tarjan's SCC
algorithm, and returns each SCC as a list of procedures paired with a flag
indicating whether the SCC is recursive. The result is in reverse topological
Expand Down Expand Up @@ -90,7 +89,7 @@ def collectStaticCallNames (expr : StmtExprMd) : List String :=
| .InstanceCall t _ args =>
collectStaticCallNames t ++ args.flatMap (fun a => collectStaticCallNames a)
| .Old v | .Fresh v | .Assume v => collectStaticCallNames v
| .Assert ⟨cond, _summary⟩ => collectStaticCallNames cond
| .Assert ⟨cond, _summary, _⟩ => collectStaticCallNames cond
| .ProveBy v p => collectStaticCallNames v ++ collectStaticCallNames p
| .ReferenceEquals l r => collectStaticCallNames l ++ collectStaticCallNames r
| .AsType t _ | .IsType t _ => collectStaticCallNames t
Expand All @@ -113,27 +112,24 @@ Build the procedure call graph, run Tarjan's SCC algorithm, and return each SCC
as a list of procedures paired with a flag indicating whether the SCC is recursive.
Results are in reverse topological order: dependencies before dependents.

Procedures with an `invokeOn` trigger are placed as early as possible — before
unrelated procedures without one — by stably partitioning them first before building
Procedures with `invokeOn` are placed as early as possible — before
unrelated procedures without them — by stably partitioning them first before building
the graph. Tarjan then naturally assigns them lower indices, causing them to appear
earlier in the output.

External procedures are excluded.
-/
public def computeSccDecls (program : Program) : List (List Procedure × Bool) :=
-- External procedures are completely ignored (not translated to Core).
public def computeSccDecls (program : UnorderedCoreWithLaurelTypes) : List (List Procedure × Bool) :=
-- Stable partition: procedures with invokeOn come first, preserving relative
-- order within each group. Tarjan then places them earlier in the topological output.
let allProcs := program.functions ++ program.coreProcedures
let (withInvokeOn, withoutInvokeOn) :=
(program.staticProcedures.filter (fun p => !p.body.isExternal))
|>.partition (fun p => p.invokeOn.isSome)
let nonExternal : List Procedure := withInvokeOn ++ withoutInvokeOn
allProcs.partition (fun p => p.invokeOn.isSome)
let orderedProcs : List Procedure := withInvokeOn ++ withoutInvokeOn

-- Build a call-graph over all non-external procedures.
-- Build a call-graph over all procedures.
-- An edge proc → callee means proc's body/contracts contain a StaticCall to callee.
let nonExternalArr : Array Procedure := nonExternal.toArray
let procsArr : Array Procedure := orderedProcs.toArray
let nameToIdx : Std.HashMap String Nat :=
nonExternalArr.foldl (fun (acc : Std.HashMap String Nat × Nat) proc =>
procsArr.foldl (fun (acc : Std.HashMap String Nat × Nat) proc =>
(acc.1.insert proc.name.text acc.2, acc.2 + 1)) ({}, 0) |>.1

-- Collect all callee names from a procedure's body and contracts.
Expand All @@ -149,9 +145,9 @@ public def computeSccDecls (program : Program) : List (List Procedure × Bool) :
(bodyExprs ++ contractExprs).flatMap collectStaticCallNames

-- Build the OutGraph for Tarjan.
let n := nonExternalArr.size
let n := procsArr.size
let graph : Strata.OutGraph n :=
nonExternalArr.foldl (fun (acc : Strata.OutGraph n × Nat) proc =>
procsArr.foldl (fun (acc : Strata.OutGraph n × Nat) proc =>
let callerIdx := acc.2
let g := acc.1
let callees := procCallees proc
Expand All @@ -167,7 +163,7 @@ public def computeSccDecls (program : Program) : List (List Procedure × Bool) :

sccs.toList.filterMap fun scc =>
let procs := scc.toList.filterMap fun idx =>
nonExternalArr[idx.val]?
procsArr[idx.val]?
if procs.isEmpty then none else
let isRecursive := procs.length > 1 ||
(match scc.toList.head? with
Expand All @@ -176,60 +172,85 @@ public def computeSccDecls (program : Program) : List (List Procedure × Bool) :
some (procs, isRecursive)

/--
A single declaration in an ordered Laurel program. Declarations are in
A single declaration in a CoreWithLaurelTypes program. Declarations are in
dependency order (dependencies before dependents).
-/
public inductive OrderedDecl where
/-- A group of functions (single non-recursive, or mutually recursive). -/
| procs (procs : List Procedure) (isRecursive : Bool)
/-- A group of functions (single non-recursive, or mutually recursive).
Invariant: `funcs.length > 1 → isRecursive = true`. -/
| funcs (funcs : List Procedure) (isRecursive : Bool)
/-- A single (non-functional) procedure. -/
| procedure (procedure : Procedure)
/-- A group of (possibly mutually recursive) datatypes. -/
| datatypes (dts : List DatatypeDefinition)
/-- A named constant. -/
| constant (c : Constant)

/--
A Laurel program whose declarations have been grouped and topologically ordered.
Produced by `orderProgram` from a `Program`.
A program whose declarations have been grouped and topologically ordered,
using Laurel types. Produced by `orderFunctionsAndProcedures` from a
`UnorderedCoreWithLaurelTypes`.
-/
public structure OrderedLaurel where
public structure CoreWithLaurelTypes where
decls : List OrderedDecl

/--
Group mutually recursive datatypes into SCC groups using Tarjan's SCC algorithm.
Returns groups in topological order (dependencies before dependents).
-/
public def groupDatatypesByScc (program : Program) : List (List DatatypeDefinition) :=
let laurelDatatypes := program.types.filterMap fun td => match td with
| .Datatype dt => some dt
| _ => none
let n := laurelDatatypes.length
if n == 0 then [] else
let nameToIdx : Std.HashMap String Nat :=
laurelDatatypes.foldlIdx (fun m i dt => m.insert dt.name.text i) {}
let edges : List (Nat × Nat) :=
laurelDatatypes.foldlIdx (fun acc i dt =>
(datatypeRefs dt).filterMap nameToIdx.get? |>.foldl (fun acc j => (j, i) :: acc) acc) []
let g := OutGraph.ofEdges! n edges
let dtsArr := laurelDatatypes.toArray
OutGraph.tarjan g |>.toList.filterMap fun comp =>
let members := comp.toList.filterMap fun idx => dtsArr[idx]?
if members.isEmpty then none else some members
open Std (Format ToFormat)

/--
Group procedures into SCC groups and wrap them as `OrderedDecl.procs`.
-/
public def groupProcsByScc (program : Program) : List OrderedDecl :=
(computeSccDecls program).map fun (procs, isRecursive) =>
OrderedDecl.procs procs isRecursive
public section

def formatOrderedDecl : OrderedDecl → Format
| .funcs funcs _ => Format.joinSep (funcs.map ToFormat.format) "\n\n"
| .procedure proc => ToFormat.format proc
| .datatypes dts => Format.joinSep (dts.map ToFormat.format) "\n\n"
| .constant c => ToFormat.format c

instance : ToFormat OrderedDecl where
format := formatOrderedDecl

def formatCoreWithLaurelTypes (p : CoreWithLaurelTypes) : Format :=
Format.joinSep (p.decls.map formatOrderedDecl) "\n\n"

instance : ToFormat CoreWithLaurelTypes where
format := formatCoreWithLaurelTypes

end -- public section

/--
Produce an `OrderedLaurel` from a `Program` by grouping and ordering
procedures via SCC, collecting datatypes, and constants.
Produce a `CoreWithLaurelTypes` from a `UnorderedCoreWithLaurelTypes` by
computing a combined ordering of functions and proofs using the call graph,
then collecting datatypes and constants.

Functions are grouped into SCCs (for mutual recursion). Proofs are emitted
as individual `procedure` decls. Both participate in the topological ordering
so that axioms are available to functions that need them.
-/
public def orderProgram (program : Program) : OrderedLaurel :=
let datatypeDecls := (groupDatatypesByScc program).map OrderedDecl.datatypes
public def orderFunctionsAndProcedures (program : UnorderedCoreWithLaurelTypes) : CoreWithLaurelTypes :=
let datatypeDecls := (groupDatatypesByScc' program).map OrderedDecl.datatypes
let constantDecls := program.constants.map OrderedDecl.constant
let procDecls := groupProcsByScc program
{ decls := datatypeDecls ++ constantDecls ++ procDecls }
let funcNames : Std.HashSet String :=
program.functions.foldl (fun s p => s.insert p.name.text) {}
let orderedDecls := (computeSccDecls program).flatMap fun (procs, isRecursive) =>
-- Split the SCC into functions and proofs
let (funcs, proofs) := procs.partition (fun p => funcNames.contains p.name.text)
let funcDecl := if funcs.isEmpty then [] else [OrderedDecl.funcs funcs isRecursive]
let proofDecls := proofs.map OrderedDecl.procedure
funcDecl ++ proofDecls
{ decls := datatypeDecls ++ constantDecls ++ orderedDecls }
where
/-- Group datatypes from a UnorderedCoreWithLaurelTypes by SCC. -/
groupDatatypesByScc' (program : UnorderedCoreWithLaurelTypes) : List (List DatatypeDefinition) :=
let laurelDatatypes := program.datatypes
let n := laurelDatatypes.length
if n == 0 then [] else
let nameToIdx : Std.HashMap String Nat :=
laurelDatatypes.foldlIdx (fun m i dt => m.insert dt.name.text i) {}
let edges : List (Nat × Nat) :=
laurelDatatypes.foldlIdx (fun acc i dt =>
(datatypeRefs dt).filterMap nameToIdx.get? |>.foldl (fun acc j => (j, i) :: acc) acc) []
let g := OutGraph.ofEdges! n edges
let dtsArr := laurelDatatypes.toArray
OutGraph.tarjan g |>.toList.filterMap fun comp =>
let members := comp.toList.filterMap fun idx => dtsArr[idx]?
if members.isEmpty then none else some members

end Strata.Laurel
5 changes: 3 additions & 2 deletions Strata/Languages/Laurel/DesugarShortCircuit.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
-/
module

public import Strata.Languages.Laurel.MapStmtExpr
public import Strata.Languages.Laurel.LiftImperativeExpressions
public import Strata.Languages.Laurel.Resolution
import Strata.Languages.Laurel.LiftImperativeExpressions
import Strata.Languages.Laurel.MapStmtExpr

/-!
# Desugar Short-Circuit Operators
Expand Down
11 changes: 11 additions & 0 deletions Strata/Languages/Laurel/EliminateValueReturns.lean
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
module

public import Strata.Languages.Laurel.MapStmtExpr
public import Strata.Languages.Laurel.TransparencyPass

/-!
# Eliminate Value Returns
Expand Down Expand Up @@ -87,6 +88,16 @@ def eliminateValueReturnsTransform (program : Program) : Program × Array Diagno
) ([], #[])
({ program with staticProcedures := procs.reverse }, diags)

/-- Transform an `UnorderedCoreWithLaurelTypes` by eliminating value returns
in all core (non-functional) procedures. -/
def eliminateValueReturnsTransformUnordered (uc : UnorderedCoreWithLaurelTypes)
: UnorderedCoreWithLaurelTypes × Array DiagnosticModel :=
let (procs, diags) := uc.coreProcedures.foldl (fun (ps, ds) proc =>
let (proc', procDiags) := eliminateValueReturnsInProc proc
(proc' :: ps, ds ++ procDiags)
) ([], #[])
({ uc with coreProcedures := procs.reverse }, diags)

end -- public section

end Laurel
6 changes: 3 additions & 3 deletions Strata/Languages/Laurel/HeapParameterization.lean
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def collectExpr (expr : StmtExpr) : StateM AnalysisResult Unit := do
| .Assigned n => collectExprMd n
| .Old v => collectExprMd v
| .Fresh v => collectExprMd v
| .Assert ⟨c, _⟩ => collectExprMd c
| .Assert ⟨c, _, _⟩ => collectExprMd c
| .Assume c => collectExprMd c
| .ProveBy v p => collectExprMd v; collectExprMd p
| .ContractOf _ f => collectExprMd f
Expand Down Expand Up @@ -434,8 +434,8 @@ where
| .Assigned n => return [⟨ .Assigned (← recurseOne n), source ⟩]
| .Old v => return [⟨ .Old (← recurseOne v), source ⟩]
| .Fresh v => return [⟨ .Fresh (← recurseOne v), source ⟩]
| .Assert ⟨condExpr, summary⟩ =>
return [⟨ .Assert { condition := ← recurseOne condExpr, summary }, source ⟩]
| .Assert ⟨condExpr, summary, free⟩ =>
return [⟨ .Assert { condition := ← recurseOne condExpr, summary, free }, source ⟩]
| .Assume c => return [⟨ .Assume (← recurseOne c), source ⟩]
| .ProveBy v p => return [⟨ .ProveBy (← recurseOne v) (← recurseOne p), source ⟩]
| .ContractOf ty f => return [⟨ .ContractOf ty (← recurseOne f), source ⟩]
Expand Down
8 changes: 5 additions & 3 deletions Strata/Languages/Laurel/InferHoleTypes.lean
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,10 @@ private def inferExpr (expr : StmtExprMd) (expectedType : HighTypeMd) : InferHol
| some d => pure (some (← inferExpr d (⟨ .TInt, source ⟩)))
| none => pure none
return ⟨.While (← inferExpr cond ⟨ .TBool, source ⟩) (← invs.mapM (inferExpr · ⟨ .TBool, source ⟩)) dec' (← inferExpr body ⟨ .TVoid, source⟩), source⟩
| .Assert ⟨condExpr, summary⟩ =>
return ⟨.Assert { condition := ← inferExpr condExpr ⟨ .TBool, source ⟩, summary }, source⟩
| .Assume cond => return ⟨.Assume (← inferExpr cond ⟨ .TBool, source ⟩), source⟩
| .Assert ⟨condExpr, summary, free⟩ =>
return ⟨.Assert { condition := ← inferExpr condExpr ⟨ .TBool, source ⟩, summary, free }, source⟩
| .Assume cond =>
return ⟨.Assume (← inferExpr cond ⟨ .TBool, source ⟩), source⟩
| .Return (some retExpr) =>
return ⟨.Return (some (← inferExpr retExpr (← get).currentOutputType)), source⟩
| .Old v => return ⟨.Old (← inferExpr v expectedType), source⟩
Expand Down Expand Up @@ -180,6 +181,7 @@ private def inferProcedure (proc : Procedure) : InferHoleM Procedure := do

/--
Annotate every `.Hole` in the program with a type inferred from context.
Returns the updated program and any diagnostics (e.g. holes whose type could not be inferred).
-/
def inferHoleTypes (model : SemanticModel) (program : Program) : Program × List DiagnosticModel × Statistics :=
let initState : InferHoleState := { model := model, currentOutputType := { val := .Unknown, source := none }}
Expand Down
40 changes: 40 additions & 0 deletions Strata/Languages/Laurel/Laurel.lean
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,11 @@ structure Condition where
condition : AstNode StmtExpr
/-- Optional human-readable summary describing the property being checked. -/
summary : Option String := none
/-- When `true`, this condition is *free*: assumed but not checked.
A free precondition is assumed by the implementation but not asserted at
call sites. A free postcondition is assumed upon return from calls but
not checked on exit from implementations. -/
free : Bool := false

/--
The body of a procedure. A body can be transparent (with a visible
Expand Down Expand Up @@ -445,6 +450,41 @@ def HighType.isBool : HighType → Bool
| TBool => true
| _ => false

/-- Return the constructor name of a `StmtExprMd` as a `String`. -/
def StmtExpr.constructorName (e : StmtExpr) : String :=
match e with
| .IfThenElse .. => "IfThenElse"
| .Block .. => "Block"
| .While .. => "While"
| .Exit .. => "Exit"
| .Return .. => "Return"
| .LiteralInt .. => "LiteralInt"
| .LiteralBool .. => "LiteralBool"
| .LiteralString .. => "LiteralString"
| .LiteralDecimal .. => "LiteralDecimal"
| .Var .. => "Var"
| .Assign .. => "Assign"
| .PureFieldUpdate .. => "PureFieldUpdate"
| .StaticCall .. => "StaticCall"
| .PrimitiveOp .. => "PrimitiveOp"
| .New .. => "New"
| .This => "This"
| .ReferenceEquals .. => "ReferenceEquals"
| .AsType .. => "AsType"
| .IsType .. => "IsType"
| .InstanceCall .. => "InstanceCall"
| .Quantifier .. => "Quantifier"
| .Assigned .. => "Assigned"
| .Old .. => "Old"
| .Fresh .. => "Fresh"
| .Assert .. => "Assert"
| .Assume .. => "Assume"
| .ProveBy .. => "ProveBy"
| .ContractOf .. => "ContractOf"
| .Abstract => "Abstract"
| .All => "All"
| .Hole .. => "Hole"

/-- Check whether a single modifies entry is the wildcard (`*`). -/
def StmtExprMd.isWildcard (m : StmtExprMd) : Bool := match m.val with | .All => true | _ => false

Expand Down
Loading