From 7ad172e9ce72f7f9699eef2a24736dc611e5096a Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 20 Jun 2025 16:21:21 +0100 Subject: [PATCH 01/31] Move the simp wrappers --- backends/lean/Aeneas.lean | 1 + backends/lean/Aeneas/BvTac/BvTac.lean | 2 +- backends/lean/Aeneas/Progress/Progress.lean | 16 +-- .../lean/Aeneas/Progress/ProgressStar.lean | 2 +- .../lean/Aeneas/ScalarTac/CondSimpTac.lean | 8 +- backends/lean/Aeneas/ScalarTac/ScalarTac.lean | 21 +-- backends/lean/Aeneas/Utils.lean | 120 ++---------------- 7 files changed, 36 insertions(+), 134 deletions(-) diff --git a/backends/lean/Aeneas.lean b/backends/lean/Aeneas.lean index 1235dc570..e094e663c 100644 --- a/backends/lean/Aeneas.lean +++ b/backends/lean/Aeneas.lean @@ -19,6 +19,7 @@ import Aeneas.Saturate import Aeneas.ScalarDecrTac import Aeneas.ScalarNF import Aeneas.ScalarTac +import Aeneas.Simp import Aeneas.SimpIfs import Aeneas.SimpLemmas import Aeneas.SimpLists diff --git a/backends/lean/Aeneas/BvTac/BvTac.lean b/backends/lean/Aeneas/BvTac/BvTac.lean index 381042e21..f7a2e4113 100644 --- a/backends/lean/Aeneas/BvTac/BvTac.lean +++ b/backends/lean/Aeneas/BvTac/BvTac.lean @@ -71,7 +71,7 @@ partial def bvTacPreprocess (config : Config) (n : Option Expr): TacticM Unit := /- Call `simp_all ` to normalize the goal a bit -/ let simpLemmas ← bvifySimpExt.getTheorems let simprocs ← bvifySimprocExt.getSimprocs - Utils.simpAll {dsimp := false, failIfUnchanged := false, maxDischargeDepth := 0} true + Simp.simpAll {dsimp := false, failIfUnchanged := false, maxDischargeDepth := 0} true {simprocs := #[simprocs], simpThms := #[simpLemmas]} allGoals do trace[BvTac] "Goal after `simp_all`: {← getMainGoal}" diff --git a/backends/lean/Aeneas/Progress/Progress.lean b/backends/lean/Aeneas/Progress/Progress.lean index ef4fa9694..795217e44 100644 --- a/backends/lean/Aeneas/Progress/Progress.lean +++ b/backends/lean/Aeneas/Progress/Progress.lean @@ -235,7 +235,7 @@ def progressWith (fExpr : Expr) (th : Expr) splitEqAndPost fun hEq hPost ids => do trace[Progress] "eq and post:\n{hEq} : {← inferType hEq}\n{hPost}" trace[Progress] "current goal: {← getMainGoal}" - simpAt true { maxDischargeDepth := 1, failIfUnchanged := false} + Simp.simpAt true { maxDischargeDepth := 1, failIfUnchanged := false} {simpThms := #[← progressSimpExt.getTheorems], hypsToUse := #[hEq.fvarId!]} (.targets #[] true) /- It may happen that at this point the goal is already solved (though this is rare) TODO: not sure this is the best way of checking it -/ @@ -247,7 +247,7 @@ def progressWith (fExpr : Expr) (th : Expr) else trace[Progress] "goal after applying the eq and simplifying the binds: {← getMainGoal}" -- TODO: remove this? (some types get unfolded too much: we "fold" them back) - let _ ← tryTac (simpAt true {} {addSimpThms := scalar_eqs} .wildcard_dep) + let _ ← tryTac (Simp.simpAt true {} {addSimpThms := scalar_eqs} .wildcard_dep) trace[Progress] "goal after folding back scalar types: {← getMainGoal}" -- Clear the equality, unless the user requests not to do so if keep.isSome then pure () @@ -339,16 +339,16 @@ def progressWith (fExpr : Expr) (th : Expr) let nonHPosts ← Utils.refreshFVarIds Std.HashSet.emptyWithCapacity hPostsSet let nonHPosts := Std.HashSet.ofArray nonHPosts -- Simplify the post-conditions - let args : SimpArgs := + let args : Simp.SimpArgs := {simpThms := #[← progressPostSimpExt.getTheorems], simprocs := #[← ScalarTac.scalarTacSimprocExt.getSimprocs]} - simpAt true { maxDischargeDepth := 0, failIfUnchanged := false } + Simp.simpAt true { maxDischargeDepth := 0, failIfUnchanged := false } args (.targets hPosts false) -- The introduced post-conditions may have been modified, so we need to recompute their fvar ids let hPosts ← Utils.refreshFVarIds Std.HashSet.emptyWithCapacity nonHPosts -- Simplify the goal again tryTac do - simpAt true { maxDischargeDepth := 1, failIfUnchanged := false} + Simp.simpAt true { maxDischargeDepth := 1, failIfUnchanged := false} {simpThms := #[← progressSimpExt.getTheorems], declsToUnfold := #[``pure]} (.targets #[] true) -- pure (hPosts) @@ -535,7 +535,7 @@ def evalProgress (keep keepPretty : Option Name) (withArg: Option Expr) (ids: Ar : TacticM Stats := do /- Simplify the goal -- TODO: this might close it: we need to check that and abort if necessary, and properly track that in the `Stats` -/ - simpAt true { maxDischargeDepth := 1, failIfUnchanged := false} + Simp.simpAt true { maxDischargeDepth := 1, failIfUnchanged := false} {simpThms := #[← progressSimpExt.getTheorems]} (.targets #[] true) withMainContext do let splitPost := true @@ -557,11 +557,11 @@ def evalProgress (keep keepPretty : Option Name) (withArg: Option Expr) (ids: Ar if ← isProp decl.type then pure (some decl.fvarId) else pure none - let simpArgs : SimpArgs := {simpThms := #[simpLemmas], hypsToUse := localAsms.toArray} + let simpArgs : Simp.SimpArgs := {simpThms := #[simpLemmas], hypsToUse := localAsms.toArray} let simpTac : TacticM Unit := do trace[Progress] "Attempting to solve with `simp [*]`" -- Simplify the goal - Utils.simpAt false { maxDischargeDepth := 1 } simpArgs (.targets #[] true) + Simp.simpAt false { maxDischargeDepth := 1 } simpArgs (.targets #[] true) -- Raise an error if the goal is not proved allGoalsNoRecover (throwError "Goal not proved") /- We use our custom assumption tactic, which instantiates meta-variables only if there is a single diff --git a/backends/lean/Aeneas/Progress/ProgressStar.lean b/backends/lean/Aeneas/Progress/ProgressStar.lean index b468a734e..53df8679c 100644 --- a/backends/lean/Aeneas/Progress/ProgressStar.lean +++ b/backends/lean/Aeneas/Progress/ProgressStar.lean @@ -152,7 +152,7 @@ attribute [progress_simps] Aeneas.Std.bind_assoc_eq partial def evalProgressStar(cfg: Config): TacticM Info := withMainContext do focus do trace[ProgressStar] "Simplifying the goal: {←(getMainTarget >>= (liftM ∘ ppExpr))}" - Utils.simpAt (simpOnly := true) + Simp.simpAt (simpOnly := true) { maxDischargeDepth := 1, failIfUnchanged := false} {simpThms := #[← Progress.progressSimpExt.getTheorems]} (.targets #[] true) diff --git a/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean b/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean index 7d93d5f07..f03e7079f 100644 --- a/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean +++ b/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean @@ -74,7 +74,7 @@ structure CondSimpArgs where def condSimpTacSimp (config : Simp.Config) (args : CondSimpArgs) (loc : Utils.Location) (additionalAsms : Array FVarId := #[]) (dischWithScalarTac : Bool) : TacticM Unit := do withMainContext do - let simpArgs : Utils.SimpArgs := + let simpArgs : Simp.SimpArgs := {simpThms := args.simpThms, simprocs := args.simprocs, declsToUnfold := args.declsToUnfold, @@ -87,11 +87,11 @@ def condSimpTacSimp (config : Simp.Config) (args : CondSimpArgs) (loc : Utils.Lo let dischargeWrapper := Lean.Elab.Tactic.Simp.DischargeWrapper.custom ref d let _ ← dischargeWrapper.with fun discharge? => do -- Initialize the simp context - let (ctx, simprocs) ← Utils.mkSimpCtx true config .simp simpArgs + let (ctx, simprocs) ← Simp.mkSimpCtx true config .simp simpArgs -- Apply the simplifier - let _ ← Utils.customSimpLocation ctx simprocs discharge? loc + let _ ← Simp.customSimpLocation ctx simprocs discharge? loc else - Utils.simpAt true config simpArgs loc + Simp.simpAt true config simpArgs loc /-- A helper to define tactics which perform conditional simplifications with `scalar_tac` as a discharger. -/ def condSimpTac diff --git a/backends/lean/Aeneas/ScalarTac/ScalarTac.lean b/backends/lean/Aeneas/ScalarTac/ScalarTac.lean index 37e3885f2..f459e57ce 100644 --- a/backends/lean/Aeneas/ScalarTac/ScalarTac.lean +++ b/backends/lean/Aeneas/ScalarTac/ScalarTac.lean @@ -8,6 +8,7 @@ import Aeneas.ScalarTac.Init import Aeneas.Saturate import Aeneas.SimpScalar.Init import Aeneas.SimpBoolProp.SimpBoolProp +import Aeneas.Simp namespace Aeneas @@ -169,7 +170,7 @@ elab "scalar_tac_saturate" config:Parser.Tactic.optConfig : tactic => do let config ← elabConfig config let _ ← scalarTacSaturateForward config.toSaturateConfig (fun _ => pure ()) -def getSimpArgs : CoreM SimpArgs := do +def getSimpArgs : CoreM Simp.SimpArgs := do pure { simprocs := #[ ← SimpBoolProp.simpBoolPropSimprocExt.getSimprocs, @@ -191,13 +192,13 @@ def getSimpThmNames : CoreM (Array Name) := do /- Sometimes `simp at *` doesn't work in the presence of dependent types. However, simplifying the assumptions *does* work, hence this peculiar way of simplifying the context. -/ -def simpAsmsTarget (simpOnly : Bool) (config : Simp.Config) (args : SimpArgs) : TacticM Unit := +def simpAsmsTarget (simpOnly : Bool) (config : Simp.Config) (args : Simp.SimpArgs) : TacticM Unit := withMainContext do let lctx ← getLCtx let decls ← lctx.getDecls let props ← decls.filterM (fun d => do pure (← inferType d.type).isProp) let props := (props.map fun d => d.fvarId).toArray - Aeneas.Utils.simpAt simpOnly config args (.targets props true) + Aeneas.Simp.simpAt simpOnly config args (.targets props true) /- Boosting a bit the `omega` tac. -/ def scalarTacPreprocess (config : Config) : Tactic.TacticM Unit := do @@ -209,7 +210,7 @@ def scalarTacPreprocess (config : Config) : Tactic.TacticM Unit := do typed expressions such as `UScalar.ofNat`. -/ trace[ScalarTac] "Original goal before preprocessing: {← getMainGoal}" - let simpArgs : SimpArgs ← getSimpArgs + let simpArgs : Simp.SimpArgs ← getSimpArgs simpAsmsTarget true {dsimp := false, failIfUnchanged := false, maxDischargeDepth := 1} -- Remove the forall quantifiers to prepare for the call of `simp_all` (we -- don't want `simp_all` to use assumptions of the shape `∀ x, P x`)) @@ -223,7 +224,7 @@ def scalarTacPreprocess (config : Config) : Tactic.TacticM Unit := do if config.saturate then allGoalsNoRecover (scalarTacSaturateForward config.toSaturateConfig (fun _ => pure ())) trace[ScalarTac] "Goal after saturation: {← getMainGoal}" - let simpArgs : SimpArgs ← getSimpArgs + let simpArgs : Simp.SimpArgs ← getSimpArgs -- Apply `simpAll` if config.simpAllMaxSteps ≠ 0 then tryTac do @@ -231,7 +232,7 @@ def scalarTacPreprocess (config : Config) : Tactic.TacticM Unit := do will not have any effect. This is important because it often happens that the user instantiates one such assumptions with specific arguments, meaning that if we call `simpAll` naively, those instantiations will get simplified to `True` and thus eliminated. -/ - Utils.simpAll + Simp.simpAll {failIfUnchanged := false, maxSteps := config.simpAllMaxSteps, maxDischargeDepth := 0} true simpArgs -- We might have proven the goal @@ -240,7 +241,7 @@ def scalarTacPreprocess (config : Config) : Tactic.TacticM Unit := do return trace[ScalarTac] "Goal after simpAll: {← getMainGoal}" -- Call `simp` again, this time to inline the let-bindings (otherwise, omega doesn't always manage to deal with them) - Utils.simpAt true {zetaDelta := true, failIfUnchanged := false, maxDischargeDepth := 1} simpArgs .wildcard + Simp.simpAt true {zetaDelta := true, failIfUnchanged := false, maxDischargeDepth := 1} simpArgs .wildcard -- We might have proven the goal if (← getGoals).isEmpty then trace[ScalarTac] "Goal proven by preprocessing!" @@ -254,7 +255,7 @@ def scalarTacPreprocess (config : Config) : Tactic.TacticM Unit := do return trace[ScalarTac] "Goal after normCast: {← getMainGoal}" -- Call `simp` again because `normCast` sometimes does weird things - Utils.simpAt true {failIfUnchanged := false, maxDischargeDepth := 1} simpArgs .wildcard + Simp.simpAt true {failIfUnchanged := false, maxDischargeDepth := 1} simpArgs .wildcard -- We might have proven the goal if (← getGoals).isEmpty then trace[ScalarTac] "Goal proven by preprocessing!" @@ -268,7 +269,7 @@ elab "scalar_tac_preprocess" config:Parser.Tactic.optConfig : tactic => do def scalarTacCore (config : Config) : Tactic.TacticM Unit := do Tactic.withMainContext do Tactic.focus do - let simpArgs : SimpArgs ← getSimpArgs + let simpArgs : Simp.SimpArgs ← getSimpArgs let g ← Tactic.getMainGoal trace[ScalarTac] "Original goal: {g}" -- Introduce all the universally quantified variables (includes the assumptions) @@ -284,7 +285,7 @@ def scalarTacCore (config : Config) : Tactic.TacticM Unit := do splitAll do allGoalsNoRecover (tryTac do - Utils.simpAll {failIfUnchanged := false, maxSteps := config.simpAllMaxSteps, maxDischargeDepth := 0} + Simp.simpAll {failIfUnchanged := false, maxSteps := config.simpAllMaxSteps, maxDischargeDepth := 0} true simpArgs) trace[ScalarTac] "Calling omega" allGoalsNoRecover (Tactic.Omega.omegaTactic {}) diff --git a/backends/lean/Aeneas/Utils.lean b/backends/lean/Aeneas/Utils.lean index 89c5b0ba0..7f4f5a439 100644 --- a/backends/lean/Aeneas/Utils.lean +++ b/backends/lean/Aeneas/Utils.lean @@ -769,116 +769,6 @@ example (h : ∃ x y z, x + y + z ≥ 0) : ∃ x, x ≥ 0 := by rename_i x y z exists x + y + z -structure SimpArgs where - simprocs : Simp.SimprocsArray := #[] - simpThms : Array SimpTheorems := #[] - addSimprocs : Array Name := #[] - declsToUnfold : Array Name := #[] - addSimpThms : Array Name := #[] - hypsToUse : Array FVarId := #[] - -/- Initialize a context for the `simp` function. - - The initialization of the context is adapted from `Tactic.elabSimpArgs`. - Something very annoying is that there is no function which allows to - initialize a simp context without doing an elaboration - as a consequence - we write our own here. -/ -def mkSimpCtx (simpOnly : Bool) (config : Simp.Config) (kind : SimpKind) (args : SimpArgs) : - Tactic.TacticM (Simp.Context × Simp.SimprocsArray) := do - -- Initialize either with the builtin simp theorems or with all the simp theorems - let simpThms ← - if simpOnly then Tactic.simpOnlyBuiltins.foldlM (·.addConst ·) ({} : SimpTheorems) - else getSimpTheorems - -- Add the equational theorems for the declarations to unfold - let addDeclToUnfold (thms : SimpTheorems) (decl : Name) : Tactic.TacticM SimpTheorems := - if kind == .dsimp then pure (thms.addDeclToUnfoldCore decl) - else thms.addDeclToUnfold decl - let simpThms ← - args.declsToUnfold.foldlM addDeclToUnfold simpThms - -- Add the hypotheses and the rewriting theorems - let simpThms ← - args.hypsToUse.foldlM (fun thms fvarId => - -- post: TODO: don't know what that is. It seems to be true by default. - -- inv: invert the equality - thms.add (.fvar fvarId) #[] (mkFVar fvarId) (post := true) (inv := false) - -- thms.eraseCore (.fvar fvar) - ) simpThms - -- Add the rewriting theorems to use - let simpThms ← - args.addSimpThms.foldlM (fun thms thmName => do - let info ← getConstInfo thmName - if (← isProp info.type) then - -- post: TODO: don't know what that is - -- inv: invert the equality - thms.addConst thmName (post := false) (inv := false) - else - throwError "Not a proposition: {thmName}" - ) simpThms - let congrTheorems ← getSimpCongrTheorems - let defaultSimprocs ← if simpOnly then pure {} else Simp.getSimprocs - let addSimprocs ← args.addSimprocs.foldlM (fun simprocs name => simprocs.add name true) defaultSimprocs - let ctx ← Simp.mkContext config (simpTheorems := #[simpThms] ++ args.simpThms) congrTheorems - pure (ctx, #[addSimprocs] ++ args.simprocs) - -inductive Location where - /-- Apply the tactic everywhere. Same as `Tactic.Location.wildcard` -/ - | wildcard - /-- Apply the tactic everywhere, including in the variable types (i.e., in - assumptions which are not propositions). - --/ - | wildcard_dep - /-- Same as Tactic.Location -/ - | targets (hypotheses : Array FVarId) (type : Bool) - --- Adapted from Tactic.simpLocation -def customSimpLocation (ctx : Simp.Context) (simprocs : Simp.SimprocsArray) - (discharge? : Option Simp.Discharge := none) - (loc : Location) : TacticM Simp.Stats := do - match loc with - | Location.targets hyps simplifyTarget => - -- Custom behavior: we directly provide the fvar ideas of the assumption rather than syntax - withMainContext do - simpLocation.go ctx simprocs discharge? hyps simplifyTarget - | Location.wildcard => - -- Simply call the regular simpLocation - simpLocation ctx simprocs discharge? Tactic.Location.wildcard - | Location.wildcard_dep => - -- Custom behavior - withMainContext do - -- Lookup *all* the declarations - let lctx ← Lean.MonadLCtx.getLCtx - let decls ← lctx.getDecls - let tgts := (decls.map (fun d => d.fvarId)).toArray - -- Call the regular simpLocation.go - simpLocation.go ctx simprocs discharge? tgts (simplifyTarget := true) - -/- Call the simp tactic. -/ -def simpAt (simpOnly : Bool) (config : Simp.Config) (args : SimpArgs) (loc : Location) : - Tactic.TacticM Unit := do - -- Initialize the simp context - let (ctx, simprocs) ← mkSimpCtx simpOnly config .simp args - -- Apply the simplifier - let _ ← customSimpLocation ctx simprocs (discharge? := .none) loc - -/- Call the dsimp tactic. -/ -def dsimpAt (simpOnly : Bool) (config : Simp.Config) (args : SimpArgs) (loc : Tactic.Location) : - Tactic.TacticM Unit := do - -- Initialize the simp context - let (ctx, simprocs) ← mkSimpCtx simpOnly config .dsimp args - -- Apply the simplifier - dsimpLocation ctx simprocs loc - --- Call the simpAll tactic -def simpAll (config : Simp.Config) (simpOnly : Bool) (args : SimpArgs) : - Tactic.TacticM Unit := do - -- Initialize the simp context - let (ctx, simprocs) ← mkSimpCtx simpOnly config .simpAll args - -- Apply the simplifier - let (result?, _) ← Lean.Meta.simpAll (← getMainGoal) ctx (simprocs := simprocs) - match result? with - | none => replaceMainGoal [] - | some mvarId => replaceMainGoal [mvarId] - /- Adapted from Elab.Tactic.Rewrite -/ def rewriteTarget (eqThm : Expr) (symm : Bool) (config : Rewrite.Config := {}) : TacticM Unit := do Term.withSynthesize <| withMainContext do @@ -1693,6 +1583,16 @@ def duplicateAssumptions (toDuplicate : Option (Array FVarId) := none) : elab "duplicate_assumptions" : tactic => do let _ ← duplicateAssumptions +inductive Location where + /-- Apply the tactic everywhere. Same as `Location.wildcard` -/ + | wildcard + /-- Apply the tactic everywhere, including in the variable types (i.e., in + assumptions which are not propositions). + --/ + | wildcard_dep + /-- Same as Location -/ + | targets (hypotheses : Array FVarId) (type : Bool) + /-- info: example (a : Nat) From 5561355eff64a97665e9786a0af84b8f6a4b3235 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 20 Jun 2025 17:52:51 +0100 Subject: [PATCH 02/31] Update the simp wrappers so that they output the updated fvar ids --- backends/lean/Aeneas/BvTac/BvTac.lean | 36 ++--- backends/lean/Aeneas/Bvify/Bvify.lean | 2 +- backends/lean/Aeneas/Progress/Progress.lean | 43 +++--- .../lean/Aeneas/Progress/ProgressStar.lean | 5 +- .../lean/Aeneas/ScalarTac/CondSimpTac.lean | 41 +++--- backends/lean/Aeneas/ScalarTac/ScalarTac.lean | 6 +- backends/lean/Aeneas/Simp.lean | 1 + backends/lean/Aeneas/Simp/Simp.lean | 133 ++++++++++++++++++ backends/lean/Aeneas/Utils.lean | 16 --- 9 files changed, 199 insertions(+), 84 deletions(-) create mode 100644 backends/lean/Aeneas/Simp.lean create mode 100644 backends/lean/Aeneas/Simp/Simp.lean diff --git a/backends/lean/Aeneas/BvTac/BvTac.lean b/backends/lean/Aeneas/BvTac/BvTac.lean index f7a2e4113..9612ebfd9 100644 --- a/backends/lean/Aeneas/BvTac/BvTac.lean +++ b/backends/lean/Aeneas/BvTac/BvTac.lean @@ -55,26 +55,26 @@ partial def bvTacPreprocess (config : Config) (n : Option Expr): TacticM Unit := /- First try simplifying the goal - if it is an (in-)equality between scalars, it may get the bitwidth to use for the bit-vectors might be obvious from the goal: we marked some theorems wiht `bvify_simps` for this reason. -/ - Bvify.bvifyTacSimp (Utils.Location.targets #[] true) + let r ← Bvify.bvifyTacSimp (Utils.Location.targets #[] true) + if r.isNone then return /- The simp call above may have proven the goal (unlikely, but we have to take this into account) -/ - allGoals do - trace[BvTac] "Goal after `bvifyTacSimp`: {← getMainGoal}" - /- Figure out the bitwidth if the user didn't provide it -/ - let n ← do - match n with - | some n => pure n - | none => getn - /- Then apply bvify -/ - bvifyTac config.toConfig n Utils.Location.wildcard - trace[BvTac] "Goal after `bvifyTac`: {← getMainGoal}" - /- Call `simp_all ` to normalize the goal a bit -/ - let simpLemmas ← bvifySimpExt.getTheorems - let simprocs ← bvifySimprocExt.getSimprocs - Simp.simpAll {dsimp := false, failIfUnchanged := false, maxDischargeDepth := 0} true - {simprocs := #[simprocs], simpThms := #[simpLemmas]} - allGoals do - trace[BvTac] "Goal after `simp_all`: {← getMainGoal}" + trace[BvTac] "Goal after `bvifyTacSimp`: {← getMainGoal}" + /- Figure out the bitwidth if the user didn't provide it -/ + let n ← do + match n with + | some n => pure n + | none => getn + /- Then apply bvify -/ + bvifyTac config.toConfig n Utils.Location.wildcard + trace[BvTac] "Goal after `bvifyTac`: {← getMainGoal}" + /- Call `simp_all ` to normalize the goal a bit -/ + let simpLemmas ← bvifySimpExt.getTheorems + let simprocs ← bvifySimprocExt.getSimprocs + Simp.simpAll {dsimp := false, failIfUnchanged := false, maxDischargeDepth := 0} true + {simprocs := #[simprocs], simpThms := #[simpLemmas]} + -- The simpAll may have solved the goal, so we need to be careful + allGoals do trace[BvTac] "Goal after `simp_all`: {← getMainGoal}" elab "bv_tac_preprocess" config:Parser.Tactic.optConfig n:(colGt term)? : tactic => do bvTacPreprocess (← elabConfig config) (← optElabTerm n) diff --git a/backends/lean/Aeneas/Bvify/Bvify.lean b/backends/lean/Aeneas/Bvify/Bvify.lean index 261b610c8..dce465f91 100644 --- a/backends/lean/Aeneas/Bvify/Bvify.lean +++ b/backends/lean/Aeneas/Bvify/Bvify.lean @@ -280,7 +280,7 @@ def bvifyAddSimpThms (n : Expr) : TacticM (Array FVarId) := do def bvifySimpConfig : Simp.Config := {maxDischargeDepth := 2, failIfUnchanged := false} -def bvifyTacSimp (loc : Utils.Location) : TacticM Unit := do +def bvifyTacSimp (loc : Utils.Location) : TacticM (Option (Array FVarId)) := do let args : ScalarTac.CondSimpArgs := { simpThms := #[← bvifySimpExt.getTheorems, ← SimpBoolProp.simpBoolPropSimpExt.getTheorems] simprocs := #[← bvifySimprocExt.getSimprocs, ← SimpBoolProp.simpBoolPropSimprocExt.getSimprocs] diff --git a/backends/lean/Aeneas/Progress/Progress.lean b/backends/lean/Aeneas/Progress/Progress.lean index 795217e44..f8c2de626 100644 --- a/backends/lean/Aeneas/Progress/Progress.lean +++ b/backends/lean/Aeneas/Progress/Progress.lean @@ -235,19 +235,17 @@ def progressWith (fExpr : Expr) (th : Expr) splitEqAndPost fun hEq hPost ids => do trace[Progress] "eq and post:\n{hEq} : {← inferType hEq}\n{hPost}" trace[Progress] "current goal: {← getMainGoal}" - Simp.simpAt true { maxDischargeDepth := 1, failIfUnchanged := false} + let r ← Simp.simpAt true { maxDischargeDepth := 1, failIfUnchanged := false} {simpThms := #[← progressSimpExt.getTheorems], hypsToUse := #[hEq.fvarId!]} (.targets #[] true) /- It may happen that at this point the goal is already solved (though this is rare) TODO: not sure this is the best way of checking it -/ - let goals ← getUnsolvedGoals - assert! (goals.length ≤ 1) -- We focused on the main goal so there should be at most one goal - if goals == [] then + if r.isNone then trace[Progress] "The main goal was solved!" next (Ok #[]) else trace[Progress] "goal after applying the eq and simplifying the binds: {← getMainGoal}" -- TODO: remove this? (some types get unfolded too much: we "fold" them back) - let _ ← tryTac (Simp.simpAt true {} {addSimpThms := scalar_eqs} .wildcard_dep) + tryTac (do let _ ← Simp.simpAt true {} {addSimpThms := scalar_eqs} .wildcard_dep) trace[Progress] "goal after folding back scalar types: {← getMainGoal}" -- Clear the equality, unless the user requests not to do so if keep.isSome then pure () @@ -331,27 +329,26 @@ def progressWith (fExpr : Expr) (th : Expr) setGoals curGoals let hPosts ← match curGoals with - | [] => pure hPosts + | [] => pure (hPosts) | [ _ ] => - /- Compute the list of assumptions which are not post-conditions (we need this below to re-compute - the list of introduced post-conditions after the simplification) -/ - let hPostsSet := Std.HashSet.ofArray hPosts - let nonHPosts ← Utils.refreshFVarIds Std.HashSet.emptyWithCapacity hPostsSet - let nonHPosts := Std.HashSet.ofArray nonHPosts -- Simplify the post-conditions let args : Simp.SimpArgs := {simpThms := #[← progressPostSimpExt.getTheorems], simprocs := #[← ScalarTac.scalarTacSimprocExt.getSimprocs]} - Simp.simpAt true { maxDischargeDepth := 0, failIfUnchanged := false } + let hPosts ← Simp.simpAt true { maxDischargeDepth := 0, failIfUnchanged := false } args (.targets hPosts false) - -- The introduced post-conditions may have been modified, so we need to recompute their fvar ids - let hPosts ← Utils.refreshFVarIds Std.HashSet.emptyWithCapacity nonHPosts - -- Simplify the goal again - tryTac do - Simp.simpAt true { maxDischargeDepth := 1, failIfUnchanged := false} - {simpThms := #[← progressSimpExt.getTheorems], declsToUnfold := #[``pure]} (.targets #[] true) - -- - pure (hPosts) + match hPosts with + | none => + -- We actually closed the goal: we shouldn't get there + -- TODO: make this more robust + throwError "Unexpected: goal closed by simplifying the introduced post-conditions" + | some hPosts => + -- Simplify the goal again + tryTac do + let _ ← Simp.simpAt true { maxDischargeDepth := 1, failIfUnchanged := false} + {simpThms := #[← progressSimpExt.getTheorems], declsToUnfold := #[``pure]} (.targets #[] true) + -- + pure (hPosts) | _ => throwError "Unexpected number of goals" let curGoals ← getUnsolvedGoals trace[Progress] "Main goal after simplifying the post-conditions and the target: {curGoals}" @@ -535,7 +532,7 @@ def evalProgress (keep keepPretty : Option Name) (withArg: Option Expr) (ids: Ar : TacticM Stats := do /- Simplify the goal -- TODO: this might close it: we need to check that and abort if necessary, and properly track that in the `Stats` -/ - Simp.simpAt true { maxDischargeDepth := 1, failIfUnchanged := false} + let _ ← Simp.simpAt true { maxDischargeDepth := 1, failIfUnchanged := false} {simpThms := #[← progressSimpExt.getTheorems]} (.targets #[] true) withMainContext do let splitPost := true @@ -561,9 +558,9 @@ def evalProgress (keep keepPretty : Option Name) (withArg: Option Expr) (ids: Ar let simpTac : TacticM Unit := do trace[Progress] "Attempting to solve with `simp [*]`" -- Simplify the goal - Simp.simpAt false { maxDischargeDepth := 1 } simpArgs (.targets #[] true) + let r ← Simp.simpAt false { maxDischargeDepth := 1 } simpArgs (.targets #[] true) -- Raise an error if the goal is not proved - allGoalsNoRecover (throwError "Goal not proved") + if r.isSome then throwError "Goal not proved" /- We use our custom assumption tactic, which instantiates meta-variables only if there is a single assumption matching the goal. -/ let customAssumTac : TacticM Unit := do diff --git a/backends/lean/Aeneas/Progress/ProgressStar.lean b/backends/lean/Aeneas/Progress/ProgressStar.lean index 53df8679c..ae6a01a10 100644 --- a/backends/lean/Aeneas/Progress/ProgressStar.lean +++ b/backends/lean/Aeneas/Progress/ProgressStar.lean @@ -152,13 +152,12 @@ attribute [progress_simps] Aeneas.Std.bind_assoc_eq partial def evalProgressStar(cfg: Config): TacticM Info := withMainContext do focus do trace[ProgressStar] "Simplifying the goal: {←(getMainTarget >>= (liftM ∘ ppExpr))}" - Simp.simpAt (simpOnly := true) + let r ← Simp.simpAt (simpOnly := true) { maxDischargeDepth := 1, failIfUnchanged := false} {simpThms := #[← Progress.progressSimpExt.getTheorems]} (.targets #[] true) /- We may have proven the goal already -/ - let goals ← getUnsolvedGoals - if goals == [] then + if r.isNone then let progress_simps ← `(Parser.Tactic.simpLemma| $(mkIdent `progress_simps):term) return ⟨ #[← `(tactic|simp [$progress_simps])], [] ⟩ /- Continue -/ diff --git a/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean b/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean index f03e7079f..9c94fd079 100644 --- a/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean +++ b/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean @@ -72,7 +72,7 @@ structure CondSimpArgs where hypsToUse : Array FVarId := #[] def condSimpTacSimp (config : Simp.Config) (args : CondSimpArgs) (loc : Utils.Location) - (additionalAsms : Array FVarId := #[]) (dischWithScalarTac : Bool) : TacticM Unit := do + (additionalAsms : Array FVarId := #[]) (dischWithScalarTac : Bool) : TacticM (Option (Array FVarId)) := do withMainContext do let simpArgs : Simp.SimpArgs := {simpThms := args.simpThms, @@ -85,11 +85,11 @@ def condSimpTacSimp (config : Simp.Config) (args : CondSimpArgs) (loc : Utils.Lo already saturated by looking at the assumptions (we do this once and for all beforehand) -/ let (ref, d) ← tacticToDischarge (← `(tactic|scalar_tac -saturateAssumptions)) let dischargeWrapper := Lean.Elab.Tactic.Simp.DischargeWrapper.custom ref d - let _ ← dischargeWrapper.with fun discharge? => do + dischargeWrapper.with fun discharge? => do -- Initialize the simp context let (ctx, simprocs) ← Simp.mkSimpCtx true config .simp simpArgs -- Apply the simplifier - let _ ← Simp.customSimpLocation ctx simprocs discharge? loc + pure ((← Simp.customSimpLocation ctx simprocs discharge? loc).fst) else Simp.simpAt true config simpArgs loc @@ -108,12 +108,9 @@ def condSimpTac | .wildcard => pure none | .wildcard_dep => throwError "{tacName} does not support using location `Utils.Location.wildcard_dep`" | .targets hyps _ => pure (some hyps) - let getOtherAsms (ignore : Std.HashSet FVarId) : TacticM (Array FVarId) := - Utils.refreshFVarIds Std.HashSet.emptyWithCapacity ignore /- First duplicate the propositions in the context: we will need the original ones (which mention integers rather than bit-vectors) for `scalar_tac` to succeed when doing the conditional rewritings. -/ let (oldAsms, newAsms) ← Utils.duplicateAssumptions toDuplicate - let oldAsmsSet := Std.HashSet.ofArray oldAsms trace[CondSimpTac] "Goal after duplicating the assumptions: {← getMainGoal}" /- Introduce the scalar_tac assumptions - by doing this beforehand we don't have to redo it every time we call `scalar_tac`. TODO: also do the `simp_all`. -/ @@ -124,23 +121,27 @@ def condSimpTac let additionalSimpThms ← addSimpThms trace[CondSimpTac] "Goal after adding the additional simp assumptions: {← getMainGoal}" /- Simplify the targets (note that we preserve the new assumptions for `scalar_tac`) -/ - let (loc, notLocAsms) ← do + let loc ← do match loc with - | .wildcard => pure (Utils.Location.targets oldAsms true, ← getOtherAsms oldAsmsSet) + | .wildcard => pure (Utils.Location.targets oldAsms true) | .wildcard_dep => throwError "Unreachable" - | .targets hyps type => pure (Utils.Location.targets hyps type, ← getOtherAsms (Std.HashSet.ofArray hyps)) - if doFirstSimp then - condSimpTacSimp simpConfig args loc additionalSimpThms false - if (← getUnsolvedGoals) == [] then return + | .targets hyps type => pure (Utils.Location.targets hyps type) + let nloc ← + if doFirstSimp then + match ← condSimpTacSimp simpConfig args loc additionalSimpThms false with + | none => return + | some freshFvarIds => + match loc with + | .wildcard => pure (Utils.Location.targets freshFvarIds true) + | .wildcard_dep => throwError "Unreachable" + | .targets _ type => pure (Utils.Location.targets freshFvarIds type) + else pure loc trace[CondSimpTac] "Goal after simplifying: {← getMainGoal}" - /- Simplify the targets by using `scalar_tac` as a discharger -/ - let notLocAsmsSet := Std.HashSet.ofArray notLocAsms - let nloc ← do - match loc with - | .wildcard => pure (Utils.Location.targets (← Utils.refreshFVarIds oldAsmsSet notLocAsmsSet) true) - | .wildcard_dep => throwError "Unreachable" - | .targets hyps type => pure (Utils.Location.targets (← Utils.refreshFVarIds (Std.HashSet.ofArray hyps) notLocAsmsSet) type) - condSimpTacSimp simpConfig args nloc additionalSimpThms true + /- Simplify the targets by using `scalar_tac` as a discharger. + TODO: scalar_tac should only be allowed to preprocess `scalarTacAsms`. + TODO: we should preprocess those. + -/ + let _ ← condSimpTacSimp simpConfig args nloc additionalSimpThms true if (← getUnsolvedGoals) == [] then return /- Clear the additional assumptions -/ Utils.clearFVarIds scalarTacAsms diff --git a/backends/lean/Aeneas/ScalarTac/ScalarTac.lean b/backends/lean/Aeneas/ScalarTac/ScalarTac.lean index f459e57ce..56e6927c6 100644 --- a/backends/lean/Aeneas/ScalarTac/ScalarTac.lean +++ b/backends/lean/Aeneas/ScalarTac/ScalarTac.lean @@ -198,7 +198,7 @@ def simpAsmsTarget (simpOnly : Bool) (config : Simp.Config) (args : Simp.SimpArg let decls ← lctx.getDecls let props ← decls.filterM (fun d => do pure (← inferType d.type).isProp) let props := (props.map fun d => d.fvarId).toArray - Aeneas.Simp.simpAt simpOnly config args (.targets props true) + let _ ← Aeneas.Simp.simpAt simpOnly config args (.targets props true) /- Boosting a bit the `omega` tac. -/ def scalarTacPreprocess (config : Config) : Tactic.TacticM Unit := do @@ -241,7 +241,7 @@ def scalarTacPreprocess (config : Config) : Tactic.TacticM Unit := do return trace[ScalarTac] "Goal after simpAll: {← getMainGoal}" -- Call `simp` again, this time to inline the let-bindings (otherwise, omega doesn't always manage to deal with them) - Simp.simpAt true {zetaDelta := true, failIfUnchanged := false, maxDischargeDepth := 1} simpArgs .wildcard + let _ ← Simp.simpAt true {zetaDelta := true, failIfUnchanged := false, maxDischargeDepth := 1} simpArgs .wildcard -- We might have proven the goal if (← getGoals).isEmpty then trace[ScalarTac] "Goal proven by preprocessing!" @@ -255,7 +255,7 @@ def scalarTacPreprocess (config : Config) : Tactic.TacticM Unit := do return trace[ScalarTac] "Goal after normCast: {← getMainGoal}" -- Call `simp` again because `normCast` sometimes does weird things - Simp.simpAt true {failIfUnchanged := false, maxDischargeDepth := 1} simpArgs .wildcard + let _ ← Simp.simpAt true {failIfUnchanged := false, maxDischargeDepth := 1} simpArgs .wildcard -- We might have proven the goal if (← getGoals).isEmpty then trace[ScalarTac] "Goal proven by preprocessing!" diff --git a/backends/lean/Aeneas/Simp.lean b/backends/lean/Aeneas/Simp.lean new file mode 100644 index 000000000..759bca9c5 --- /dev/null +++ b/backends/lean/Aeneas/Simp.lean @@ -0,0 +1 @@ +import Aeneas.Simp.Simp diff --git a/backends/lean/Aeneas/Simp/Simp.lean b/backends/lean/Aeneas/Simp/Simp.lean new file mode 100644 index 000000000..ed1f7c721 --- /dev/null +++ b/backends/lean/Aeneas/Simp/Simp.lean @@ -0,0 +1,133 @@ +import Aeneas.Utils + +namespace Aeneas.Simp + +open Lean Meta Elab Tactic + +structure SimpArgs where + simprocs : Simp.SimprocsArray := #[] + simpThms : Array SimpTheorems := #[] + addSimprocs : Array Name := #[] + declsToUnfold : Array Name := #[] + addSimpThms : Array Name := #[] + hypsToUse : Array FVarId := #[] + +/- Initialize a context for the `simp` function. + + The initialization of the context is adapted from `elabSimpArgs`. + Something very annoying is that there is no function which allows to + initialize a simp context without doing an elaboration - as a consequence + we write our own here. -/ +def mkSimpCtx (simpOnly : Bool) (config : Simp.Config) (kind : SimpKind) (args : SimpArgs) : + MetaM (Simp.Context × Simp.SimprocsArray) := do + -- Initialize either with the builtin simp theorems or with all the simp theorems + let simpThms ← + if simpOnly then simpOnlyBuiltins.foldlM (·.addConst ·) ({} : SimpTheorems) + else getSimpTheorems + -- Add the equational theorems for the declarations to unfold + let addDeclToUnfold (thms : SimpTheorems) (decl : Name) : MetaM SimpTheorems := + if kind == .dsimp then pure (thms.addDeclToUnfoldCore decl) + else thms.addDeclToUnfold decl + let simpThms ← + args.declsToUnfold.foldlM addDeclToUnfold simpThms + -- Add the hypotheses and the rewriting theorems + let simpThms ← + args.hypsToUse.foldlM (fun thms fvarId => + -- post: TODO: don't know what that is. It seems to be true by default. + -- inv: invert the equality + thms.add (.fvar fvarId) #[] (mkFVar fvarId) (post := true) (inv := false) + -- thms.eraseCore (.fvar fvar) + ) simpThms + -- Add the rewriting theorems to use + let simpThms ← + args.addSimpThms.foldlM (fun thms thmName => do + let info ← getConstInfo thmName + if (← isProp info.type) then + -- post: TODO: don't know what that is + -- inv: invert the equality + thms.addConst thmName (post := false) (inv := false) + else + throwError "Not a proposition: {thmName}" + ) simpThms + let congrTheorems ← getSimpCongrTheorems + let defaultSimprocs ← if simpOnly then pure {} else Simp.getSimprocs + let addSimprocs ← args.addSimprocs.foldlM (fun simprocs name => simprocs.add name true) defaultSimprocs + let ctx ← Simp.mkContext config (simpTheorems := #[simpThms] ++ args.simpThms) congrTheorems + pure (ctx, #[addSimprocs] ++ args.simprocs) + +/- Adapted from `Lean.Elab.Tactic.simpLocation` so that: + - we can use our own `Location` + - we return the fvars which have been simplified (that is, the unmodified fvars, + and the fresh fvars introduced because of the simplification). + Note that we return an option: if `none`, it means that the goal was closed. +-/ +def customSimpLocation (ctx : Simp.Context) (simprocs : Simp.SimprocsArray) + (discharge? : Option Simp.Discharge := none) + (loc : Utils.Location) : TacticM (Option (Array FVarId) × Simp.Stats) := do + match loc with + | .targets hyps simplifyTarget => + -- Custom behavior: we directly provide the fvar ids of the assumption rather than syntax + withMainContext do + go hyps simplifyTarget + | .wildcard => + -- Simply call the regular simpLocation + withMainContext do + go (← (← getMainGoal).getNondepPropHyps) (simplifyTarget := true) + | .wildcard_dep => + -- Custom behavior + withMainContext do + -- Lookup *all* the declarations + let lctx ← Lean.MonadLCtx.getLCtx + let decls ← lctx.getDecls + let tgts := (decls.map (fun d => d.fvarId)).toArray + -- Call the regular simpLocation.go + go tgts (simplifyTarget := true) +where + go (fvarIdsToSimp : Array FVarId) (simplifyTarget : Bool) : TacticM (Option (Array FVarId) × Simp.Stats) := do + let mvarId ← getMainGoal + let (result?, stats) ← simpGoal mvarId ctx (simprocs := simprocs) (simplifyTarget := simplifyTarget) (discharge? := discharge?) (fvarIdsToSimp := fvarIdsToSimp) + let freshFVarIds ← + match result? with + | none => replaceMainGoal []; pure none + | some (fvars, mvarId) => replaceMainGoal [mvarId]; pure fvars + /- We need to filter the `fvarIdsToSimp` to remove those which have been replaced with fresh fvars -/ + let fvars ← do + match freshFVarIds with + | none => pure none + | some freshFVarIds => + withMainContext do + let ctx ← getLCtx + let ldecls := ctx.foldl (fun set decl => set.insert decl.fvarId) Std.HashSet.emptyWithCapacity + pure (fvarIdsToSimp.filter ldecls.contains ++ freshFVarIds) + return (fvars, stats) + +/- Call the simp tactic. -/ +def simpAt (simpOnly : Bool) (config : Simp.Config) (args : SimpArgs) (loc : Utils.Location) : + TacticM (Option (Array FVarId)) := do + -- Initialize the simp context + let (ctx, simprocs) ← mkSimpCtx simpOnly config .simp args + -- Apply the simplifier + pure ((← customSimpLocation ctx simprocs (discharge? := .none) loc).fst) + +/- Call the dsimp tactic. + TODO: update to return the fresh fvar ids. +-/ +def dsimpAt (simpOnly : Bool) (config : Simp.Config) (args : SimpArgs) (loc : Location) : + TacticM Unit := do + -- Initialize the simp context + let (ctx, simprocs) ← mkSimpCtx simpOnly config .dsimp args + -- Apply the simplifier + dsimpLocation ctx simprocs loc + +-- Call the simpAll tactic +def simpAll (config : Simp.Config) (simpOnly : Bool) (args : SimpArgs) : + TacticM Unit := do + -- Initialize the simp context + let (ctx, simprocs) ← mkSimpCtx simpOnly config .simpAll args + -- Apply the simplifier + let (result?, _) ← Lean.Meta.simpAll (← getMainGoal) ctx (simprocs := simprocs) + match result? with + | none => replaceMainGoal [] + | some mvarId => replaceMainGoal [mvarId] + +end Aeneas.Simp diff --git a/backends/lean/Aeneas/Utils.lean b/backends/lean/Aeneas/Utils.lean index 7f4f5a439..edc59b74d 100644 --- a/backends/lean/Aeneas/Utils.lean +++ b/backends/lean/Aeneas/Utils.lean @@ -1637,22 +1637,6 @@ def optElabTerm (e : Option (TSyntax `term)) : TacticM (Option Expr) := do | none => pure none | some e => pure (some (← Lean.Elab.Tactic.elabTerm e none)) -/-- Compute the list of assumptions which: - - either belong to `keep` - - or do not belong to `ignore` - - This is useful to refresh a list of fvar ids after applying a tactic such as `simp` - to them. Indeed, whenever we apply a simplification to some assumptions, the only way - to retrieve their new ids is to go through the context and filter the ids which we know - do not come from the duplicated assumptions. --/ -def refreshFVarIds (keep ignore : Std.HashSet FVarId) : TacticM (Array FVarId) := do - withMainContext do - let decls := (← (← getLCtx).getDecls).toArray - decls.filterMapM fun d => do - if (← inferType d.type).isProp ∧ (d.fvarId ∈ keep ∨ d.fvarId ∉ ignore) - then pure (some d.fvarId) else pure none - end Utils end Aeneas From 43ae00194836588901c9ff7c4c91fde96f1e1b95 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Mon, 23 Jun 2025 11:05:31 +0100 Subject: [PATCH 03/31] Implement simpAllAssumptions --- backends/lean/Aeneas/Simp.lean | 1 + .../lean/Aeneas/Simp/SimpAllAssumptions.lean | 171 ++++++++++++++++++ 2 files changed, 172 insertions(+) create mode 100644 backends/lean/Aeneas/Simp/SimpAllAssumptions.lean diff --git a/backends/lean/Aeneas/Simp.lean b/backends/lean/Aeneas/Simp.lean index 759bca9c5..25ed8bac9 100644 --- a/backends/lean/Aeneas/Simp.lean +++ b/backends/lean/Aeneas/Simp.lean @@ -1 +1,2 @@ import Aeneas.Simp.Simp +import Aeneas.Simp.SimpAllAssumptions diff --git a/backends/lean/Aeneas/Simp/SimpAllAssumptions.lean b/backends/lean/Aeneas/Simp/SimpAllAssumptions.lean new file mode 100644 index 000000000..bebbc06db --- /dev/null +++ b/backends/lean/Aeneas/Simp/SimpAllAssumptions.lean @@ -0,0 +1,171 @@ +import Lean.Meta.Tactic.Clear +import Lean.Meta.Tactic.Util +import Lean.Meta.Tactic.Simp.Main +import Aeneas.Simp.Simp + +/-! +This file is adapted from Lean.Meta.Tactic.Simp.SimpAll.lean +We slightly modify `simpAll` to be able to simplify only a *subset* of the assumptions. +-/ + +namespace Aeneas.Simp + +open Lean Meta + +open Simp (Stats SimprocsArray) + +namespace SimpAll + +structure Entry where + fvarId : FVarId -- original fvarId + userName : Name + id : Origin -- id of the theorem at `SimpTheorems` + origType : Expr + type : Expr + proof : Expr + deriving Inhabited + +structure State where + modified : Bool := false + mvarId : MVarId + entries : Array Entry := #[] + ctx : Simp.Context + simprocs : SimprocsArray + usedTheorems : Simp.UsedSimps := {} + diag : Simp.Diagnostics := {} + +abbrev M := StateRefT State MetaM + +private def initEntries (fvars : Array FVarId) : M Unit := do + let hs := fvars + let hsNonDeps ← (← get).mvarId.getNondepPropHyps + let mut simpThms := (← get).ctx.simpTheorems + for h in hs do + unless simpThms.isErased (.fvar h) do + let localDecl ← h.getDecl + let proof := localDecl.toExpr + let ctx := (← get).ctx + simpThms ← simpThms.addTheorem (.fvar h) proof (config := ctx.indexConfig) + modify fun s => { s with ctx := s.ctx.setSimpTheorems simpThms } + if hsNonDeps.contains h then + -- We only simplify nondependent hypotheses + let type ← instantiateMVars localDecl.type + let entry : Entry := { fvarId := h, userName := localDecl.userName, id := .fvar h, origType := type, type, proof } + modify fun s => { s with entries := s.entries.push entry } + +private abbrev getSimpTheorems : M SimpTheoremsArray := + return (← get).ctx.simpTheorems + +private partial def loop : M Bool := do + modify fun s => { s with modified := false } + let simprocs := (← get).simprocs + -- simplify entries + let entries := (← get).entries + for h : i in [:entries.size] do + let entry := entries[i] + let ctx := (← get).ctx + -- We disable the current entry to prevent it to be simplified to `True` + let simpThmsWithoutEntry := (← getSimpTheorems).eraseTheorem entry.id + let ctx := ctx.setSimpTheorems simpThmsWithoutEntry + let (r, stats) ← simpStep (← get).mvarId entry.proof entry.type ctx simprocs (stats := { (← get) with }) + modify fun s => { s with usedTheorems := stats.usedTheorems, diag := stats.diag } + match r with + | none => return true -- closed the goal + | some (proofNew, typeNew) => + unless typeNew == entry.type do + /- We must erase the `id` for the simplified theorem. Otherwise, + the previous versions can be used to self-simplify the new version. For example, suppose we have + ``` + x : Nat + h : x ≠ 0 + ⊢ Unit + ``` + In the first round, `h : x ≠ 0` is simplified to `h : ¬ x = 0`. + + It is also important for avoiding identical hypotheses to simplify each other to `True`. + Example + ``` + ... + h₁ : p a + h₂ : p a + ⊢ q a + ``` + `h₁` is first simplified to `True`. If we don't remove `h₁` from the set of simp theorems, it will + be used to simplify `h₂` to `True` and information is lost. + + We must use `mkExpectedTypeHint` because `inferType proofNew` may not be equal to `typeNew` when + we have theorems marked with `rfl`. + -/ + trace[Aeneas.Meta.Tactic.simp.all] "entry.id: {← ppOrigin entry.id}, {entry.type} => {typeNew}" + let mut simpThmsNew := (← getSimpTheorems).eraseTheorem (.fvar entry.fvarId) + let idNew ← mkFreshId + simpThmsNew ← simpThmsNew.addTheorem (.other idNew) (← mkExpectedTypeHint proofNew typeNew) (config := ctx.indexConfig) + modify fun s => { s with + modified := true + ctx := ctx.setSimpTheorems simpThmsNew + entries[i] := { entry with type := typeNew, proof := proofNew, id := .other idNew } + } + if (← get).modified then + loop + else + return false + +def main (fvars : Array FVarId) : M (Option (Array FVarId × MVarId)) := do + initEntries fvars + if (← loop) then + return none -- close the goal + else + let mvarId := (← get).mvarId + -- Prior to #2334, the logic here was to re-assert all hypotheses and call `tryClearMany` on them all. + -- This had the effect that the order of hypotheses was sometimes modified, whether or not any where simplified. + -- Now we only re-assert the first modified hypothesis, + -- along with all subsequent hypotheses, so as to preserve the order of hypotheses. + let mut toAssert := #[] + let mut toClear := #[] + let mut preserved := #[] + let mut modified := false + for e in (← get).entries do + if e.type.isTrue then + -- Do not assert `True` hypotheses + toClear := toClear.push e.fvarId + else if modified || e.type != e.origType then + toClear := toClear.push e.fvarId + toAssert := toAssert.push { userName := e.userName, type := e.type, value := e.proof } + modified := true + else + preserved := preserved.push e.fvarId + let (freshIds, mvarId) ← mvarId.assertHypotheses toAssert + let mvarId ← mvarId.tryClearMany toClear + pure (some (preserved ++ freshIds, mvarId)) + +end SimpAll + +def simpAllAssumptionsAux (mvarId : MVarId) (ctx : Simp.Context) (fvars : Array FVarId) (simprocs : SimprocsArray := #[]) (stats : Stats := {}) : + MetaM (Option (Array FVarId × MVarId) × Stats) := do + mvarId.withContext do + let (r, s) ← (SimpAll.main fvars).run { stats with mvarId, ctx, simprocs } + if let .some (_, mvarIdNew) := r then + if ctx.config.failIfUnchanged && mvarId == mvarIdNew then + throwError "simp_all made no progress" + return (r, { s with }) + +open Utils Simp Elab Tactic in +def simpAllAssumptions (config : Simp.Config) (simpOnly : Bool) (args : SimpArgs) (mvarId : MVarId) (fvars : Array FVarId) : + MetaM (Option (Array FVarId × MVarId)) := do + -- Initialize the simp context + let (ctx, simprocs) ← mkSimpCtx simpOnly config .simpAll args + -- Apply the simplifier + let (result?, _) ← simpAllAssumptionsAux mvarId ctx fvars (simprocs := simprocs) + pure result? + +open Utils Simp Elab Tactic in +def evalSimpAllAssumptions (config : Simp.Config) (simpOnly : Bool) (args : SimpArgs) (mvarId : MVarId) (fvars : Array FVarId) : + TacticM (Option (Array FVarId)) := do + match ← simpAllAssumptions config simpOnly args mvarId fvars with + | none => replaceMainGoal []; pure none + | some (fvarIds, mvarId) => replaceMainGoal [mvarId]; pure (some fvarIds) + +initialize + registerTraceClass `Aeneas.Meta.Tactic.simp.all + +end Aeneas.Simp From b41ea322025634d25db584de313af9cfdbefe929 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Mon, 23 Jun 2025 11:09:32 +0100 Subject: [PATCH 04/31] Start working on an incremental version of scalarTacPreprocess --- backends/lean/Aeneas/Saturate/Tactic.lean | 167 ++++++++++++------ .../lean/Aeneas/ScalarTac/CondSimpTac.lean | 4 +- backends/lean/Aeneas/ScalarTac/ScalarTac.lean | 157 +++++++++++++--- .../lean/Aeneas/Simp/SimpAllAssumptions.lean | 43 +++-- backends/lean/Aeneas/SimpScalar.lean | 1 + backends/lean/Aeneas/Utils.lean | 71 ++++++-- 6 files changed, 335 insertions(+), 108 deletions(-) diff --git a/backends/lean/Aeneas/Saturate/Tactic.lean b/backends/lean/Aeneas/Saturate/Tactic.lean index c1d7fb66a..5d8b71ad0 100644 --- a/backends/lean/Aeneas/Saturate/Tactic.lean +++ b/backends/lean/Aeneas/Saturate/Tactic.lean @@ -106,6 +106,8 @@ structure State where assumptions to instantiate some other theorems. -/ assumptions : Std.HashMap Expr AsmPath + /- Do not introduce a theorem if its conclusion is already in the set -/ + ignore : Std.HashSet Expr def State.insertHitRules (s : State) (rules : Array Rule) : State := { s with diagnostics := s.diagnostics.insertHitRules rules } @@ -131,8 +133,9 @@ def State.new (pmatches : DiscrTree PartialMatch := DiscrTree.empty) (diagnostics : Diagnostics := Diagnostics.empty) (matched : Std.HashSet Expr := Std.HashSet.emptyWithCapacity) - (assumptions : Std.HashMap Expr AsmPath := Std.HashMap.emptyWithCapacity) : State := - { rules, pmatches, diagnostics, matched, assumptions } + (assumptions : Std.HashMap Expr AsmPath := Std.HashMap.emptyWithCapacity) + (ignore : Std.HashSet Expr := Std.HashSet.emptyWithCapacity) : State := + { rules, pmatches, diagnostics, matched, assumptions, ignore } def mkExprFromPath (path : AsmPath) : MetaM Expr := do match path with @@ -455,6 +458,51 @@ private partial def visit trace[Saturate.explore] ".proj" visit config none (depth + 1) exploreSubterms preprocessThm boundVars state b +/-- Recompute the set of assumptions. + + This is necessary if we want to saturate the goal in several steps and modify + the assumptions in between (with a call to simp for example). +-/ +private partial def visitRecomputeAssumptions + (path : Option AsmPath) + (depth : Nat) + (exploreSubterms : Expr → Array Expr → MetaM (Array Expr)) + (state : State) (e : Expr) + : MetaM State := do + trace[Saturate.explore] "Visiting {e}" + -- Register the current assumption, if it is a conjunct inside an assumption + let state := state.insertAssumption path e + let e := e.consumeMData + match e with + | .bvar _ + | .fvar _ + | .mvar _ + | .sort _ + | .lit _ + | .const _ _ => + trace[Saturate.explore] "Stop: bvar, fvar, etc." + pure state + | .app .. => do e.withApp fun f args => do + trace[Saturate.explore] ".app" + /- Check if this is a conjunction and we know the path to the current sub-expression (because it is a conjunct + in an assumption) -/ + if path.isSome ∧ f.isConst ∧ f.constName! == ``And ∧ args.size == 2 then do + -- This is a conjunction + let some path := path + | throwError "Unreachable" + let state ← visitRecomputeAssumptions (some (.conj .left path)) (depth + 1) exploreSubterms state (args[0]!) + let state ← visitRecomputeAssumptions (some (.conj .right path)) (depth + 1) exploreSubterms state (args[1]!) + pure state + else + pure state + | .lam .. + | .forallE .. + | .letE .. => do pure state + | .mdata _ b => do + trace[Saturate.explore] ".mdata" + visitRecomputeAssumptions path (depth + 1) exploreSubterms state b + | .proj _ _ _ => do pure state + def arithOpArity3 : Std.HashSet Name := Std.HashSet.ofList [ ``Nat.cast, ``Int.cast ] @@ -494,24 +542,17 @@ def exploreArithSubterms (f : Expr) (args : Array Expr) : MetaM (Array Expr) := pure #[] /- The saturation tactic itself -/ -partial def evalSaturate {α} +partial def evalSaturateCore (config : Config) - (satAttr : Array SaturateAttribute) + (state : State) (exploreSubterms : Option (Expr → Array Expr → MetaM (Array Expr)) := none) (preprocessThm : Option (Array Expr → Expr → MetaM Unit)) - (declsToExplore : Option (Array FVarId) := none) - (exploreAssumptions : Bool := true) + (declsToExplore : Array FVarId) (exploreTarget : Bool := true) - (next : Array FVarId → TacticM α) - : TacticM α + : TacticM (State × Array FVarId) := do withMainContext do trace[Saturate] "Exploring goal: {← getMainGoal}" - -- Retrieve the rule sets - let env ← getEnv - let s := satAttr.map fun s => s.ext.getState env - -- Get the local context - let ctx ← Lean.MonadLCtx.getLCtx -- Explore let exploreSubterms := match exploreSubterms with @@ -527,11 +568,10 @@ partial def evalSaturate {α} -- Explore the assumptions trace[Saturate] "Exploring the assumptions" - let state := State.new s let visitLocalDecl (state : State) (decl : LocalDecl) : TacticM State := do trace[Saturate] "Exploring local decl: {decl.userName}" - /- We explore both the type, the expresion and the body (if there is) -/ + /- We explore both the type, the expression and the body (if there is) -/ /- Note that the path is used only when exploring the type of assumptions -/ let path ← if (← inferType decl.type).isProp then pure (some (.asm decl.fvarId)) @@ -543,24 +583,9 @@ partial def evalSaturate {α} | some value => visit config none state value let state ← do - if exploreAssumptions then do - let decls ← - match declsToExplore with - | none => do pure (← ctx.getDecls).toArray - | some decls => decls.mapM fun d => d.getDecl - decls.foldlM (fun state (decl : LocalDecl) => do - trace[Saturate] "Exploring local decl: {decl.userName}" - /- We explore both the type, the expresion and the body (if there is) -/ - /- Note that the path is used only when exploring the type of assumptions -/ - let path ← - if (← inferType decl.type).isProp then pure (some (.asm decl.fvarId)) - else pure none - let state ← visit config path state decl.type - let state ← visit config none state decl.toExpr - match decl.value? with - | none => pure state - | some value => visit config none state value) state - else pure state + trace[Saturate] "declsToExplore: {declsToExplore.map Expr.fvar}" + let decls ← declsToExplore.mapM fun d => d.getDecl + decls.foldlM visitLocalDecl state -- Explore the target trace[Saturate] "Exploring the target" @@ -583,7 +608,7 @@ partial def evalSaturate {α} -- The application worked: introduce the assumption in the context let thTy ← withMainContext do inferType th -- Check that we don't add the same assumption twice - if assumptions.contains thTy then + if assumptions.contains thTy || state.ignore.contains thTy then continue else let x ← Utils.addDeclTac (.num (.str .anonymous "_h") i) th thTy (asLet := false) fun x => pure x @@ -636,21 +661,61 @@ partial def evalSaturate {α} -- Display the diagnostics information trace[Saturate.diagnostics] "Saturate diagnostics info: {state.diagnostics.toArray}" -- Continue - next allFVars + pure (state, allFVars) + +/- Reexplore the context to recompute the set of assumptions -/ +def recomputeAssumptions + (state : State) + (exploreSubterms : Option (Expr → Array Expr → MetaM (Array Expr)) := none) + (declsToExplore : Array FVarId) + : TacticM State + := do + withMainContext do + trace[Saturate] "Exploring goal: {← getMainGoal}" + let ignore := state.assumptions.fold (fun ignore asm _ => ignore.insert asm) state.ignore + let state : State := { state with ignore, assumptions := Std.HashMap.emptyWithCapacity } + -- Explore + let exploreSubterms := + match exploreSubterms with + | none => fun f args => pure (#[f] ++ args) + | some explore => explore + let visit path (state : State) expr : MetaM State := + visitRecomputeAssumptions path 0 exploreSubterms state expr + + -- Explore the assumptions + trace[Saturate] "Exploring the assumptions" + + let decls ← declsToExplore.mapM fun d => d.getDecl + decls.foldlM (fun (state : State) (decl : LocalDecl) => do + trace[Saturate] "Exploring local decl: {decl.userName}" + if (← inferType decl.type).isProp then + let path := some (.asm decl.fvarId) + visit path state decl.type + else pure state) state + +partial def evalSaturate {α} + (config : Config) + (satAttr : Array SaturateAttribute) + (exploreSubterms : Option (Expr → Array Expr → MetaM (Array Expr)) := none) + (preprocessThm : Option (Array Expr → Expr → MetaM Unit)) + (declsToExplore : Array FVarId) + (exploreTarget : Bool := true) + (next : Array FVarId → TacticM α) + : TacticM α + := do + -- Retrieve the rule sets + let env ← getEnv + let s := satAttr.map fun s => s.ext.getState env + let state := State.new s + let (_, fvarIds) ← evalSaturateCore config state exploreSubterms preprocessThm declsToExplore exploreTarget + withMainContext do next fvarIds elab "aeneas_saturate" : tactic => do let _ ← evalSaturate {} #[saturateAttr] none none - (declsToExplore := none) - (exploreAssumptions := true) + (declsToExplore := ((← (← getLCtx).getDecls).map fun d => d.fvarId).toArray) (exploreTarget := true) (fun _ => pure ()) namespace Test - local elab "aeneas_saturate_test" : tactic => do - let _ ← evalSaturate {} #[saturateAttr] none none - (declsToExplore := none) - (exploreAssumptions := true) - (exploreTarget := true) (fun _ => pure ()) - set_option trace.Saturate.attribute false in @[aeneas_saturate l.length] -- TODO: local doesn't work here theorem rule1 (α : Type u) (l : List α) : l.length ≥ 0 := by simp @@ -685,7 +750,7 @@ let _g := fun l => l.length + l5.length; let _k := l4.length let _g := fun (l : List α) => l.length + l5.length (l0 ++ l1 ++ l2).length = l0.length + l1.length + l2.length := by - aeneas_saturate_test + aeneas_saturate extract_goal1 simp [Nat.add_assoc] @@ -730,7 +795,7 @@ let _g := fun l => l.length + l5.length; let _k := l4.length let _g := fun (l : List α) => l.length + l5.length (l0 ++ l1 ++ l2).length = l0.length + l1.length + l2.length := by - aeneas_saturate_test + aeneas_saturate extract_goal1 simp [Nat.add_assoc] @@ -765,7 +830,7 @@ let _g := fun l => l.length + l5.length; let _k := l4.length let _g := fun (l : List α) => l.length + l5.length (l0 ++ l1 ++ l2).length = l0.length + l1.length + l2.length := by - aeneas_saturate_test + aeneas_saturate extract_goal1 simp [Nat.add_assoc] end Test @@ -799,7 +864,7 @@ let _g := fun l => l.length + l5.length; let _k := l4.length let _g := fun (l : List α) => l.length + l5.length (l0 ++ l1 ++ l2).length = l0.length + l1.length + l2.length := by - aeneas_saturate_test + aeneas_saturate extract_goal1 simp [Nat.add_assoc] @@ -812,11 +877,11 @@ let _g := fun l => l.length + l5.length; theorem rule2 (x : Nat) (h : 1 ≤ x) : 2 ≤ x + x := by omega example (x : Nat) (_h : 1 ≤ x) : 2 ≤ x + x := by - aeneas_saturate_test + aeneas_saturate assumption example (x : Nat) (_h : True ∧ 1 ≤ x ∧ True) : 2 ≤ x + x := by - aeneas_saturate_test + aeneas_saturate assumption set_option trace.Saturate.attribute true in @@ -828,12 +893,12 @@ let _g := fun l => l.length + l5.length; *after one exploration pass* (in particular, the assumption `0 ≤ x * y`is useless). -/ set_option trace.Saturate.insertPartialMatch true in example (x y : Nat) (_ : 0 ≤ x * y) (h : 3 ≤ y ∧ 2 ≤ x) : 6 ≤ x * y := by - aeneas_saturate_test + aeneas_saturate omega set_option trace.Saturate.insertPartialMatch true in example (x y z : Nat) (h : 2 ≤ x ∧ 3 ≤ y) (h1 : 8 ≤ z) : 32 ≤ x * y * z := by - aeneas_saturate_test + aeneas_saturate omega -/ end Test diff --git a/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean b/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean index 9c94fd079..72f5f32b3 100644 --- a/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean +++ b/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean @@ -115,8 +115,8 @@ def condSimpTac /- Introduce the scalar_tac assumptions - by doing this beforehand we don't have to redo it every time we call `scalar_tac`. TODO: also do the `simp_all`. -/ withMainContext do - ScalarTac.scalarTacSaturateForward { nonLin := config.nonLin, saturationPasses := config.saturationPasses } - fun scalarTacAsms => do + ScalarTac.scalarTacSaturateForward { nonLin := config.nonLin, saturationPasses := config.saturationPasses } none none + fun _ scalarTacAsms => do trace[CondSimpTac] "Goal after saturating the context: {← getMainGoal}" let additionalSimpThms ← addSimpThms trace[CondSimpTac] "Goal after adding the additional simp assumptions: {← getMainGoal}" diff --git a/backends/lean/Aeneas/ScalarTac/ScalarTac.lean b/backends/lean/Aeneas/ScalarTac/ScalarTac.lean index 56e6927c6..df5b9c10f 100644 --- a/backends/lean/Aeneas/ScalarTac/ScalarTac.lean +++ b/backends/lean/Aeneas/ScalarTac/ScalarTac.lean @@ -143,32 +143,46 @@ structure Config extends SaturateConfig where declare_config_elab elabConfig Config /-- Apply the scalar_tac forward rules -/ -def scalarTacSaturateForward {α} (config : SaturateConfig) (f : Array FVarId → TacticM α) : TacticM α := do - /- - let options : Aesop.Options := {} - -- Use a forward max depth of 0 to prevent recursively applying forward rules on the assumptions - -- introduced by the forward rules themselves. - let options ← options.toOptions' (some 0)-/ - -- We always use the rule set `Aeneas.ScalarTac`, but also need to add other rule sets locally - -- activated by the user. The `Aeneas.ScalarTacNonLin` rule set has a special treatment as - -- it is activated through an option. +def scalarTacSaturateForward {α} + (config : SaturateConfig) + (satState : Option Saturate.State) + (declsToExplore : Option (Array FVarId)) + (f : Saturate.State → Array FVarId → TacticM α) : TacticM α := do + withMainContext do + /- We always use the rule set `Aeneas.ScalarTac`, but also need to add other rule sets locally + activated by the user. The `Aeneas.ScalarTacNonLin` rule set has a special treatment as + it is activated through an option. -/ let rules := if config.nonLin then #[scalarTacAttribute, scalarTacNonLinAttribute] else #[scalarTacAttribute] - Saturate.evalSaturate - { visitProofTerms := false, - visitBoundExpressions := config.saturateVisitBoundExpressions, - saturationPasses := config.saturationPasses } - rules none none - (declsToExplore := none) - (exploreAssumptions := config.saturateAssumptions) - (exploreTarget := config.saturateTarget) f - + let declsToExplore ← do + match declsToExplore with + | none => + if config.saturateAssumptions + then pure ((← (← getLCtx).getDecls).map fun d => d.fvarId).toArray + else pure #[] + | some decls => pure decls + let state ← do + match satState with + | none => + let env ← getEnv + let rules := rules.map fun s => s.ext.getState env + pure (Saturate.State.new rules) + | some state => pure state + let (state, fvarIds) ← + Saturate.evalSaturateCore + { visitProofTerms := false, + visitBoundExpressions := config.saturateVisitBoundExpressions, + saturationPasses := config.saturationPasses } + state none none + declsToExplore + (exploreTarget := config.saturateTarget) + withMainContext do f state fvarIds -- For debugging elab "scalar_tac_saturate" config:Parser.Tactic.optConfig : tactic => do let config ← elabConfig config - let _ ← scalarTacSaturateForward config.toSaturateConfig (fun _ => pure ()) + let _ ← scalarTacSaturateForward config.toSaturateConfig none none (fun _ _ => pure ()) def getSimpArgs : CoreM Simp.SimpArgs := do pure { @@ -192,13 +206,13 @@ def getSimpThmNames : CoreM (Array Name) := do /- Sometimes `simp at *` doesn't work in the presence of dependent types. However, simplifying the assumptions *does* work, hence this peculiar way of simplifying the context. -/ -def simpAsmsTarget (simpOnly : Bool) (config : Simp.Config) (args : Simp.SimpArgs) : TacticM Unit := +def simpAsmsTarget (simpOnly : Bool) (config : Simp.Config) (args : Simp.SimpArgs) : TacticM (Option (Array FVarId)) := withMainContext do let lctx ← getLCtx let decls ← lctx.getDecls let props ← decls.filterM (fun d => do pure (← inferType d.type).isProp) let props := (props.map fun d => d.fvarId).toArray - let _ ← Aeneas.Simp.simpAt simpOnly config args (.targets props true) + Aeneas.Simp.simpAt simpOnly config args (.targets props true) /- Boosting a bit the `omega` tac. -/ def scalarTacPreprocess (config : Config) : Tactic.TacticM Unit := do @@ -211,20 +225,19 @@ def scalarTacPreprocess (config : Config) : Tactic.TacticM Unit := do -/ trace[ScalarTac] "Original goal before preprocessing: {← getMainGoal}" let simpArgs : Simp.SimpArgs ← getSimpArgs - simpAsmsTarget true {dsimp := false, failIfUnchanged := false, maxDischargeDepth := 1} + let r ← simpAsmsTarget true {dsimp := false, failIfUnchanged := false, maxDischargeDepth := 1} -- Remove the forall quantifiers to prepare for the call of `simp_all` (we -- don't want `simp_all` to use assumptions of the shape `∀ x, P x`)) {simpArgs with addSimpThms := #[``forall_eq_forall']} -- We might have proven the goal - if (← getGoals).isEmpty then + if r.isNone then trace[ScalarTac] "Goal proven by preprocessing!" return trace[ScalarTac] "Goal after first simplification: {← getMainGoal}" -- Apply the forward rules if config.saturate then - allGoalsNoRecover (scalarTacSaturateForward config.toSaturateConfig (fun _ => pure ())) + scalarTacSaturateForward config.toSaturateConfig none none (fun _ _ => pure ()) trace[ScalarTac] "Goal after saturation: {← getMainGoal}" - let simpArgs : Simp.SimpArgs ← getSimpArgs -- Apply `simpAll` if config.simpAllMaxSteps ≠ 0 then tryTac do @@ -248,7 +261,7 @@ def scalarTacPreprocess (config : Config) : Tactic.TacticM Unit := do return trace[ScalarTac] "Goal after 2nd simp (with zetaDelta): {← getMainGoal}" -- Apply normCast - Utils.normCastAtAll + let _ ← Utils.normCastAt .wildcard -- We might have proven the goal if (← getGoals).isEmpty then trace[ScalarTac] "Goal proven by preprocessing!" @@ -262,6 +275,98 @@ def scalarTacPreprocess (config : Config) : Tactic.TacticM Unit := do return trace[ScalarTac] "Goal after 2nd call to simpAt: {← getMainGoal}" +structure State where + saturateState : Saturate.State + +def scalarTacPartialPreprocess (config : Config) (state : State) + (oldAssumptions assumptions : Array FVarId) (simpTarget : Bool) : + Tactic.TacticM (Option (State × Array FVarId)) := do + Tactic.focus do + Tactic.withMainContext do + -- Pre-preprocessing + /- We simplify a first time before saturating the context. + This is useful because simplifying often introduces expressions which are useful + for the saturation phase, and it also often allows to get rid of some dependently + typed expressions such as `UScalar.ofNat`. + -/ + trace[ScalarTac] "Original goal before preprocessing: {← getMainGoal}" + let simpArgs : Simp.SimpArgs ← getSimpArgs + let r ← Simp.simpAt true {dsimp := false, failIfUnchanged := false, maxDischargeDepth := 1} + /- Remove the forall quantifiers to prepare for the call of `simp_all` (we + don't want `simp_all` to use assumptions of the shape `∀ x, P x`)) -/ + {simpArgs with addSimpThms := #[``forall_eq_forall']} + -- TODO: it would be good to always simplify the target before exploring it to saturate + (.targets assumptions simpTarget) + -- We might have proven the goal + let some assumptions := r + | trace[ScalarTac] "Goal proven by preprocessing!"; return none + trace[ScalarTac] "Goal after first simplification: {← getMainGoal}" + -- Apply the forward rules + let satConfig := config.toSaturateConfig + let satConfig := { satConfig with + saturateAssumptions := satConfig.saturateAssumptions && config.saturate, + saturateTarget := satConfig.saturateTarget && config.saturate, + } + scalarTacSaturateForward satConfig state.saturateState (some assumptions) fun saturateState nassumptions => do + let state := { state with saturateState } + let assumptions := assumptions ++ nassumptions + + let finish : Tactic.TacticM (Option (State × Array FVarId)) := do + trace[ScalarTac] "Goal after saturation: {← getMainGoal}" + -- Apply `simpAll` to the *new assumptions* + let applySimpAll assumptions simpTarget := do + match + ← tryTactic? do + /- By setting the maxDischargeDepth at 0, we make sure that assumptions of the shape `∀ x, P x → ...` + will not have any effect. This is important because it often happens that the user instantiates + one such assumptions with specific arguments, meaning that if we call `simpAll` naively, those + instantiations will get simplified to `True` and thus eliminated. -/ + Simp.simpAllAssumptions + {failIfUnchanged := false, maxSteps := config.simpAllMaxSteps, maxDischargeDepth := 0} true + simpArgs (← getMainGoal) assumptions simpTarget + with + | some r => pure r + | none => pure (some (assumptions, ← getMainGoal)) + let r ← do + if config.simpAllMaxSteps ≠ 0 then + /- Simplify the new assumptions with the old assumptions (it's often enough to go in that direction) -/ + let some assumptions ← Simp.simpAt true {dsimp := false, failIfUnchanged := false, maxDischargeDepth := 1} + {simpArgs with hypsToUse := oldAssumptions} (.targets assumptions false) + | trace[ScalarTac] "Goal proven by preprocessing!"; return none + /- Apply `simpAll` on the new assumptions -/ + let some (assumptions, mvarId) ← applySimpAll assumptions false + | trace[ScalarTac] "Goal proven by preprocessing!"; return none + /- We apply `simpAll` on the old and new assumptions -/ + setGoals [mvarId] + applySimpAll (oldAssumptions ++ assumptions) simpTarget + else + pure (some (assumptions, ← getMainGoal)) + let some (assumptions, mvarId) := r + | trace[ScalarTac] "Goal proven by preprocessing!"; return none + setGoals [mvarId] + trace[ScalarTac] "Goal after simpAll: {← getMainGoal}" + + -- Call `simp` again, this time to inline the let-bindings (otherwise, omega doesn't always manage to deal with them) + let r ← Simp.simpAt true {zetaDelta := true, failIfUnchanged := false, maxDischargeDepth := 1} simpArgs (.targets assumptions simpTarget) + -- We might have proven the goal + let some assumptions := r + | trace[ScalarTac] "Goal proven by preprocessing!"; return none + trace[ScalarTac] "Goal after 2nd simp (with zetaDelta): {← getMainGoal}" + -- Apply normCast + let some assumptions ← Utils.normCastAt (.targets assumptions simpTarget) + | trace[ScalarTac] "Goal proven by preprocessing!"; return none + trace[ScalarTac] "Goal after normCast: {← getMainGoal}" + -- Call `simp` again because `normCast` sometimes does weird things + let some assumptions ← Simp.simpAt true {failIfUnchanged := false, maxDischargeDepth := 1} simpArgs (.targets assumptions simpTarget) + | trace[ScalarTac] "Goal proven by preprocessing!"; return none + trace[ScalarTac] "Goal after 2nd call to simpAt: {← getMainGoal}" + -- We modified the assumptions in the context so we need to update the state accordingly + let saturateState ← Saturate.recomputeAssumptions state.saturateState none assumptions + let state := { state with saturateState } + -- We're done + return some (state, assumptions) + finish + elab "scalar_tac_preprocess" config:Parser.Tactic.optConfig : tactic => do let config ← elabConfig config scalarTacPreprocess config diff --git a/backends/lean/Aeneas/Simp/SimpAllAssumptions.lean b/backends/lean/Aeneas/Simp/SimpAllAssumptions.lean index bebbc06db..81c9aee82 100644 --- a/backends/lean/Aeneas/Simp/SimpAllAssumptions.lean +++ b/backends/lean/Aeneas/Simp/SimpAllAssumptions.lean @@ -56,7 +56,7 @@ private def initEntries (fvars : Array FVarId) : M Unit := do private abbrev getSimpTheorems : M SimpTheoremsArray := return (← get).ctx.simpTheorems -private partial def loop : M Bool := do +private partial def loop (target : Bool) : M Bool := do modify fun s => { s with modified := false } let simprocs := (← get).simprocs -- simplify entries @@ -105,14 +105,30 @@ private partial def loop : M Bool := do ctx := ctx.setSimpTheorems simpThmsNew entries[i] := { entry with type := typeNew, proof := proofNew, id := .other idNew } } + if (← get).modified then - loop + loop target else - return false - -def main (fvars : Array FVarId) : M (Option (Array FVarId × MVarId)) := do + -- We're done with the loop: simplify the target + if target then + let mvarId := (← get).mvarId + let (r, stats) ← simpTarget mvarId (← get).ctx simprocs (stats := { (← get) with }) + modify fun s => { s with usedTheorems := stats.usedTheorems, diag := stats.diag } + match r with + | none => return true + | some mvarIdNew => + unless mvarId == mvarIdNew do + modify fun s => { s with + modified := true + mvarId := mvarIdNew + } + return false + else + return false + +def main (fvars : Array FVarId) (target : Bool) : M (Option (Array FVarId × MVarId)) := do initEntries fvars - if (← loop) then + if (← loop target) then return none -- close the goal else let mvarId := (← get).mvarId @@ -140,28 +156,31 @@ def main (fvars : Array FVarId) : M (Option (Array FVarId × MVarId)) := do end SimpAll -def simpAllAssumptionsAux (mvarId : MVarId) (ctx : Simp.Context) (fvars : Array FVarId) (simprocs : SimprocsArray := #[]) (stats : Stats := {}) : +def simpAllAssumptionsAux (mvarId : MVarId) (ctx : Simp.Context) (fvars : Array FVarId) (target : Bool) + (simprocs : SimprocsArray := #[]) (stats : Stats := {}) : MetaM (Option (Array FVarId × MVarId) × Stats) := do mvarId.withContext do - let (r, s) ← (SimpAll.main fvars).run { stats with mvarId, ctx, simprocs } + let (r, s) ← (SimpAll.main fvars target).run { stats with mvarId, ctx, simprocs } if let .some (_, mvarIdNew) := r then if ctx.config.failIfUnchanged && mvarId == mvarIdNew then throwError "simp_all made no progress" return (r, { s with }) open Utils Simp Elab Tactic in -def simpAllAssumptions (config : Simp.Config) (simpOnly : Bool) (args : SimpArgs) (mvarId : MVarId) (fvars : Array FVarId) : +def simpAllAssumptions (config : Simp.Config) (simpOnly : Bool) (args : SimpArgs) + (mvarId : MVarId) (fvars : Array FVarId) (target : Bool) : MetaM (Option (Array FVarId × MVarId)) := do -- Initialize the simp context let (ctx, simprocs) ← mkSimpCtx simpOnly config .simpAll args -- Apply the simplifier - let (result?, _) ← simpAllAssumptionsAux mvarId ctx fvars (simprocs := simprocs) + let (result?, _) ← simpAllAssumptionsAux mvarId ctx fvars target (simprocs := simprocs) pure result? open Utils Simp Elab Tactic in -def evalSimpAllAssumptions (config : Simp.Config) (simpOnly : Bool) (args : SimpArgs) (mvarId : MVarId) (fvars : Array FVarId) : +def evalSimpAllAssumptions (config : Simp.Config) (simpOnly : Bool) (args : SimpArgs) + (mvarId : MVarId) (fvars : Array FVarId) (target : Bool) : TacticM (Option (Array FVarId)) := do - match ← simpAllAssumptions config simpOnly args mvarId fvars with + match ← simpAllAssumptions config simpOnly args mvarId fvars target with | none => replaceMainGoal []; pure none | some (fvarIds, mvarId) => replaceMainGoal [mvarId]; pure (some fvarIds) diff --git a/backends/lean/Aeneas/SimpScalar.lean b/backends/lean/Aeneas/SimpScalar.lean index 8c551f076..639aad374 100644 --- a/backends/lean/Aeneas/SimpScalar.lean +++ b/backends/lean/Aeneas/SimpScalar.lean @@ -1 +1,2 @@ import Aeneas.SimpScalar.SimpScalar +import Aeneas.SimpScalar.Tests diff --git a/backends/lean/Aeneas/Utils.lean b/backends/lean/Aeneas/Utils.lean index edc59b74d..a92bd6418 100644 --- a/backends/lean/Aeneas/Utils.lean +++ b/backends/lean/Aeneas/Utils.lean @@ -20,6 +20,10 @@ namespace LocalContext let ls ← lctx.getAllDecls pure (ls.filter (fun d => not d.isImplementationDetail)) + def getAssumptions (lctx : Lean.LocalContext) : MetaM (List Lean.LocalDecl) := do + let ls ← lctx.getAllDecls + ls.filterM (fun d => do pure (¬ d.isImplementationDetail && (← isProp d.type))) + end LocalContext end Lean @@ -246,6 +250,16 @@ section Methods end Methods +inductive Location where + /-- Apply the tactic everywhere. Same as `Location.wildcard` -/ + | wildcard + /-- Apply the tactic everywhere, including in the variable types (i.e., in + assumptions which are not propositions). + --/ + | wildcard_dep + /-- Same as Location -/ + | targets (hypotheses : Array FVarId) (type : Bool) + def addDeclTac {α} (name : Name) (val : Expr) (type : Expr) (asLet : Bool) (m : Expr → TacticM α) : TacticM α := -- I don't think we need that @@ -830,13 +844,46 @@ def rewriteAt (cfg : Rewrite.Config) (rpt : Bool) else evalRewriteSeqAux cfg thms loc -/-- Apply norm_cast to the whole context -/ -def normCastAtAll : TacticM Unit := do +@[inline] def liftMetaTactic2 (tactic : MVarId → MetaM (Option (α × MVarId))) : TacticM (Option α) := withMainContext do - let ctx ← Lean.MonadLCtx.getLCtx - let decls ← ctx.getDecls - NormCast.normCastTarget {} - decls.forM (fun d => NormCast.normCastHyp {} d.fvarId) + if let some (res, mvarId) ← tactic (← getMainGoal) then + replaceMainGoal [mvarId] + pure (some res) + else + replaceMainGoal [] + pure none + +/-- Copy/paster from `norm_cast`: we need to retrieve the new fvar id -/ +def normCastHyp (cfg : Simp.NormCastConfig) (fvarId : FVarId) : TacticM (Option FVarId) := + liftMetaTactic2 fun goal => do + let hyp ← instantiateMVars (← fvarId.getDecl).type + let prf ← NormCast.derive hyp cfg + return (← applySimpResultToLocalDecl goal fvarId prf false) + +/- Return `none` if the goal was closed, otherwise return the modified fvar ids -/ +def normCastAt (loc : Location) : TacticM (Option (Array (FVarId))) := do + withMainContext do + match loc with + | .targets asms tgt => + let mut fvarIds := #[] + if tgt then + NormCast.normCastTarget {} + for fvarId in asms do + match ← normCastHyp {} fvarId with + | some id => fvarIds := fvarIds.push id + | none => return none + return some fvarIds + | .wildcard => + NormCast.normCastTarget {} + let ctx ← Lean.MonadLCtx.getLCtx + let decls ← ctx.getDecls + let mut fvarIds := #[] + for d in decls do + match ← normCastHyp {} d.fvarId with + | some id => fvarIds := fvarIds.push id + | none => return none + return some fvarIds + | .wildcard_dep => throwError "Unimplemented: normCastAt wildcarp_dep" @[inline] def tryLiftMetaTactic1 (tactic : MVarId → MetaM (Option MVarId)) (msg : String) : TacticM Unit := withMainContext do @@ -1560,7 +1607,7 @@ def duplicateAssumptions (toDuplicate : Option (Array FVarId) := none) : match toDuplicate with | none => pure (← (← getLCtx).getDecls).toArray | some decls => do decls.mapM fun d => d.getDecl - let props ← decls.filterM (fun d => do pure (← inferType d.type).isProp) + let props ← decls.filterM (fun d => isProp d.type) trace[Utils] "Current assumptions: {props.map LocalDecl.type}" let goal ← getMainGoal let goalType ← instantiateMVars (← goal.getType) @@ -1583,16 +1630,6 @@ def duplicateAssumptions (toDuplicate : Option (Array FVarId) := none) : elab "duplicate_assumptions" : tactic => do let _ ← duplicateAssumptions -inductive Location where - /-- Apply the tactic everywhere. Same as `Location.wildcard` -/ - | wildcard - /-- Apply the tactic everywhere, including in the variable types (i.e., in - assumptions which are not propositions). - --/ - | wildcard_dep - /-- Same as Location -/ - | targets (hypotheses : Array FVarId) (type : Bool) - /-- info: example (a : Nat) From 6c1bdebbddeaebf66fd5fbd7aa00021def1feece Mon Sep 17 00:00:00 2001 From: Son Ho Date: Mon, 23 Jun 2025 12:01:21 +0100 Subject: [PATCH 05/31] Implement Simp.tacticToDischarge --- .../lean/Aeneas/ScalarTac/CondSimpTac.lean | 3 +- backends/lean/Aeneas/Simp/Simp.lean | 32 +++++++++++++++++++ 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean b/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean index 72f5f32b3..12757413f 100644 --- a/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean +++ b/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean @@ -83,8 +83,7 @@ def condSimpTacSimp (config : Simp.Config) (args : CondSimpArgs) (loc : Utils.Lo if dischWithScalarTac then /- Note that when calling `scalar_tac` we saturate only by looking at the target: we have already saturated by looking at the assumptions (we do this once and for all beforehand) -/ - let (ref, d) ← tacticToDischarge (← `(tactic|scalar_tac -saturateAssumptions)) - let dischargeWrapper := Lean.Elab.Tactic.Simp.DischargeWrapper.custom ref d + let dischargeWrapper ← Simp.tacticToDischarge (scalarTac {saturateAssumptions := false}) dischargeWrapper.with fun discharge? => do -- Initialize the simp context let (ctx, simprocs) ← Simp.mkSimpCtx true config .simp simpArgs diff --git a/backends/lean/Aeneas/Simp/Simp.lean b/backends/lean/Aeneas/Simp/Simp.lean index ed1f7c721..a74f6ad5a 100644 --- a/backends/lean/Aeneas/Simp/Simp.lean +++ b/backends/lean/Aeneas/Simp/Simp.lean @@ -12,6 +12,38 @@ structure SimpArgs where addSimpThms : Array Name := #[] hypsToUse : Array FVarId := #[] +/- This is adapted from `Lean.Elab.Tactic.tacticToDischarge` -/ +def tacticToDischargeAux (tactic : TacticM Unit) : TacticM (IO.Ref Term.State × Simp.Discharge) := do + let ref ← IO.mkRef (← getThe Term.State) + let ctx ← readThe Term.Context + let disch : Simp.Discharge := fun e => do + let mvar ← mkFreshExprSyntheticOpaqueMVar e `simp.discharger + let s ← ref.get + let runTac? : TermElabM (Option Expr) := + try + /- SH: I don't understand why we need this even though we do not do any elaboration + (removing those wrappers leads to failures in proofs). -/ + Term.withoutModifyingElabMetaStateWithInfo do + let ngoals ← Term.withSynthesize (postpone := .no) do + Tactic.run mvar.mvarId! tactic + if ngoals.isEmpty then + let result ← instantiateMVars mvar + if result.hasExprMVar then + return none + else + return some result + else return none + catch _ => + return none + let (result?, s) ← liftM (m := MetaM) <| Term.TermElabM.run runTac? ctx s + ref.set s + return result? + return (ref, disch) + +def tacticToDischarge (tactic : TacticM Unit) : TacticM Simp.DischargeWrapper := do + let (ref, d) ← tacticToDischargeAux tactic + pure (Lean.Elab.Tactic.Simp.DischargeWrapper.custom ref d) + /- Initialize a context for the `simp` function. The initialization of the context is adapted from `elabSimpArgs`. From 14b418803c29e8ad606835ef6e3bf93c52bb2ee9 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Mon, 23 Jun 2025 22:44:46 +0100 Subject: [PATCH 06/31] Update condSimpTac to incrementally preprocess the goal --- backends/lean/Aeneas/Array.lean | 2 + backends/lean/Aeneas/BitVec.lean | 3 + backends/lean/Aeneas/Bvify/Bvify.lean | 2 +- .../lean/Aeneas/ScalarTac/CondSimpTac.lean | 89 ++++++++++++------- backends/lean/Aeneas/ScalarTac/ScalarTac.lean | 77 +++++++++++----- .../lean/Aeneas/Simp/SimpAllAssumptions.lean | 40 ++++++--- backends/lean/Aeneas/SimpIfs/SimpIfs.lean | 2 +- backends/lean/Aeneas/SimpLists/SimpLists.lean | 1 + backends/lean/Aeneas/Std/Array/Array.lean | 4 +- backends/lean/Aeneas/Std/Scalar/Core.lean | 8 ++ backends/lean/Aeneas/Utils.lean | 10 +-- 11 files changed, 166 insertions(+), 72 deletions(-) diff --git a/backends/lean/Aeneas/Array.lean b/backends/lean/Aeneas/Array.lean index 15b5ed5b8..43fe5fd44 100644 --- a/backends/lean/Aeneas/Array.lean +++ b/backends/lean/Aeneas/Array.lean @@ -5,6 +5,8 @@ namespace Array attribute [-simp] List.getElem!_eq_getElem?_getD +attribute [scalar_tac_simps, simp_lists_simps] Array.size + def setSlice! {α} (a : Array α) (i : ℕ) (s : List α) : Array α := ⟨ a.toList.setSlice! i s⟩ diff --git a/backends/lean/Aeneas/BitVec.lean b/backends/lean/Aeneas/BitVec.lean index 6e122005c..4d4b77f12 100644 --- a/backends/lean/Aeneas/BitVec.lean +++ b/backends/lean/Aeneas/BitVec.lean @@ -17,6 +17,9 @@ open Lean attribute [-simp] List.getElem!_eq_getElem?_getD +attribute [bvify_simps, simp_scalar_simps] BitVec.zero_eq +attribute [bvify_simps, simp_scalar_simps] BitVec.instInhabited + def BitVec.toArray {n} (bv: BitVec n) : Array Bool := Array.finRange n |>.map (bv[·]) def BitVec.ofFn {n} (f: Fin n → Bool) : BitVec n := (BitVec.ofBoolListLE <| List.ofFn f).cast (by simp) def BitVec.set {n} (i: Fin n) (b: Bool) (bv: BitVec n) : BitVec n := bv ^^^ (((bv[i] ^^ b).toNat : BitVec n) <<< i.val) diff --git a/backends/lean/Aeneas/Bvify/Bvify.lean b/backends/lean/Aeneas/Bvify/Bvify.lean index dce465f91..d05eceadf 100644 --- a/backends/lean/Aeneas/Bvify/Bvify.lean +++ b/backends/lean/Aeneas/Bvify/Bvify.lean @@ -285,7 +285,7 @@ def bvifyTacSimp (loc : Utils.Location) : TacticM (Option (Array FVarId)) := do simpThms := #[← bvifySimpExt.getTheorems, ← SimpBoolProp.simpBoolPropSimpExt.getTheorems] simprocs := #[← bvifySimprocExt.getSimprocs, ← SimpBoolProp.simpBoolPropSimprocExt.getSimprocs] } - ScalarTac.condSimpTacSimp bvifySimpConfig args loc #[] false + ScalarTac.condSimpTacSimp bvifySimpConfig args loc #[] #[] none def bvifyTac (config : Config) (n : Expr) (loc : Utils.Location) : TacticM Unit := do let args : ScalarTac.CondSimpArgs := { diff --git a/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean b/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean index 12757413f..896839452 100644 --- a/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean +++ b/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean @@ -71,25 +71,31 @@ structure CondSimpArgs where addSimpThms : Array Name := #[] hypsToUse : Array FVarId := #[] +def CondSimpArgs.toSimpArgs (args : CondSimpArgs) : Simp.SimpArgs := { + simpThms := args.simpThms, + simprocs := args.simprocs, + declsToUnfold := args.declsToUnfold, + addSimpThms := args.addSimpThms, + hypsToUse := args.hypsToUse } + def condSimpTacSimp (config : Simp.Config) (args : CondSimpArgs) (loc : Utils.Location) - (additionalAsms : Array FVarId := #[]) (dischWithScalarTac : Bool) : TacticM (Option (Array FVarId)) := do + (toClear : Array FVarId := #[]) + (additionalAsms : Array FVarId := #[]) (state : Option (ScalarTac.State × Array FVarId)) : + TacticM (Option (Array FVarId)) := do withMainContext do - let simpArgs : Simp.SimpArgs := - {simpThms := args.simpThms, - simprocs := args.simprocs, - declsToUnfold := args.declsToUnfold, - addSimpThms := args.addSimpThms, - hypsToUse := args.hypsToUse ++ additionalAsms} - if dischWithScalarTac then + let simpArgs := args.toSimpArgs + let simpArgs := { simpArgs with hypsToUse := simpArgs.hypsToUse ++ additionalAsms } + match state with + | some (state, asms) => /- Note that when calling `scalar_tac` we saturate only by looking at the target: we have already saturated by looking at the assumptions (we do this once and for all beforehand) -/ - let dischargeWrapper ← Simp.tacticToDischarge (scalarTac {saturateAssumptions := false}) + let dischargeWrapper ← Simp.tacticToDischarge (incrScalarTac {saturateAssumptions := false} state toClear asms) dischargeWrapper.with fun discharge? => do -- Initialize the simp context let (ctx, simprocs) ← Simp.mkSimpCtx true config .simp simpArgs -- Apply the simplifier pure ((← Simp.customSimpLocation ctx simprocs discharge? loc).fst) - else + | none => Simp.simpAt true config simpArgs loc /-- A helper to define tactics which perform conditional simplifications with `scalar_tac` as a discharger. -/ @@ -101,22 +107,45 @@ def condSimpTac Elab.Tactic.focus do withMainContext do trace[CondSimpTac] "Initial goal: {← getMainGoal}" + /- First duplicate the propositions in the context: we need the preprocessing of `scalar_tac` to modify + the assumptions, but we need to preserve a copy so that we can present a clean state to the user later + (and pretend nothing happened). Note that we do this in two times: we want to treat the simp + theorems provided by the user in `args` separately from the other assumptions. -/ + let allAssumptions ← pure (← (← getLCtx).getAssumptions).toArray + trace[CondSimpTac] "allAssumptions: {allAssumptions.map fun d => Expr.fvar d.fvarId}" + let (_, hypsToUse) ← Utils.duplicateAssumptions (some args.hypsToUse) + withMainContext do + trace[CondSimpTac] "Goal after duplicating the hyps to use: {← getMainGoal}" + trace[CondSimpTac] "hypsToUse: {hypsToUse.map Expr.fvar}" /- -/ - let toDuplicate ← do - match loc with - | .wildcard => pure none - | .wildcard_dep => throwError "{tacName} does not support using location `Utils.Location.wildcard_dep`" - | .targets hyps _ => pure (some hyps) - /- First duplicate the propositions in the context: we will need the original ones (which mention - integers rather than bit-vectors) for `scalar_tac` to succeed when doing the conditional rewritings. -/ - let (oldAsms, newAsms) ← Utils.duplicateAssumptions toDuplicate + let (oldAsms, newAsms) ← Utils.duplicateAssumptions (some (allAssumptions.map LocalDecl.fvarId)) + let toClear := oldAsms + withMainContext do trace[CondSimpTac] "Goal after duplicating the assumptions: {← getMainGoal}" - /- Introduce the scalar_tac assumptions - by doing this beforehand we don't have to - redo it every time we call `scalar_tac`. TODO: also do the `simp_all`. -/ + trace[CondSimpTac] "newAsms: {newAsms.map Expr.fvar}" + /- Preprocess the assumptions -/ + let scalarConfig : ScalarTac.Config := { nonLin := config.nonLin, saturationPasses := config.saturationPasses } + let state ← State.new scalarConfig + /- First the hyps to use -/ + let some (_, hypsToUse) ← scalarTacPartialPreprocess scalarConfig state #[] hypsToUse false + | trace[CondSimpTac] "Goal proven through preprocessing!"; return + withMainContext do + trace[CondSimpTac] "Goal after preprocessing the hyps to use ({hypsToUse.map Expr.fvar}): {← getMainGoal}" + /- Remove the `forall'` and simplify the hyps to use -/ + let simpHypsToUseArgs := { args.toSimpArgs with hypsToUse := #[], declsToUnfold := #[``forall'] } + let some hypsToUse ← Simp.simpAt true {failIfUnchanged := false, maxDischargeDepth := 0} + simpHypsToUseArgs (.targets hypsToUse false) + | trace[ScalarTac] "Goal proven by preprocessing!"; return + let args := { args with hypsToUse } + withMainContext do + trace[CondSimpTac] "Goal after simplifying the preprocessed hyps to use ({hypsToUse.map Expr.fvar}): {← getMainGoal}" + /- Preprocess the "regular" assumptions -/ + let some (state, newAsms) ← scalarTacPartialPreprocess scalarConfig state #[] newAsms false + | trace[CondSimpTac] "Goal proven through preprocessing!"; return withMainContext do - ScalarTac.scalarTacSaturateForward { nonLin := config.nonLin, saturationPasses := config.saturationPasses } none none - fun _ scalarTacAsms => do - trace[CondSimpTac] "Goal after saturating the context: {← getMainGoal}" + trace[CondSimpTac] "Goal after the initial preprocessing: {← getMainGoal}" + trace[CondSimpTac] "newAsms: {newAsms.map Expr.fvar}" + /- Introduce the additional simp theorems -/ let additionalSimpThms ← addSimpThms trace[CondSimpTac] "Goal after adding the additional simp assumptions: {← getMainGoal}" /- Simplify the targets (note that we preserve the new assumptions for `scalar_tac`) -/ @@ -127,12 +156,12 @@ def condSimpTac | .targets hyps type => pure (Utils.Location.targets hyps type) let nloc ← if doFirstSimp then - match ← condSimpTacSimp simpConfig args loc additionalSimpThms false with + match ← condSimpTacSimp simpConfig args loc toClear additionalSimpThms none with | none => return | some freshFvarIds => match loc with | .wildcard => pure (Utils.Location.targets freshFvarIds true) - | .wildcard_dep => throwError "Unreachable" + | .wildcard_dep => throwError "{tacName} does not support using location `Utils.Location.wildcard_dep`" | .targets _ type => pure (Utils.Location.targets freshFvarIds type) else pure loc trace[CondSimpTac] "Goal after simplifying: {← getMainGoal}" @@ -140,14 +169,14 @@ def condSimpTac TODO: scalar_tac should only be allowed to preprocess `scalarTacAsms`. TODO: we should preprocess those. -/ - let _ ← condSimpTacSimp simpConfig args nloc additionalSimpThms true + let _ ← condSimpTacSimp simpConfig args nloc toClear additionalSimpThms (some (state, newAsms)) if (← getUnsolvedGoals) == [] then return /- Clear the additional assumptions -/ - Utils.clearFVarIds scalarTacAsms - trace[CondSimpTac] "Goal after clearing the scalar_tac assumptions: {← getMainGoal}" - Utils.clearFVarIds newAsms + setGoals [← (← getMainGoal).tryClearMany hypsToUse] + trace[CondSimpTac] "Goal after clearing the duplicated hypotheses to use: {← getMainGoal}" + setGoals [← (← getMainGoal).tryClearMany newAsms] trace[CondSimpTac] "Goal after clearing the duplicated assumptions: {← getMainGoal}" - Utils.clearFVarIds additionalSimpThms + setGoals [← (← getMainGoal).tryClearMany additionalSimpThms] trace[CondSimpTac] "Goal after clearing the additional theorems: {← getMainGoal}" end Aeneas.ScalarTac diff --git a/backends/lean/Aeneas/ScalarTac/ScalarTac.lean b/backends/lean/Aeneas/ScalarTac/ScalarTac.lean index df5b9c10f..b3b513069 100644 --- a/backends/lean/Aeneas/ScalarTac/ScalarTac.lean +++ b/backends/lean/Aeneas/ScalarTac/ScalarTac.lean @@ -278,8 +278,17 @@ def scalarTacPreprocess (config : Config) : Tactic.TacticM Unit := do structure State where saturateState : Saturate.State +def State.new (config : Config) : MetaM State := do + let env ← getEnv + let rules := + if config.nonLin then #[scalarTacAttribute, scalarTacNonLinAttribute] + else #[scalarTacAttribute] + let rules := rules.map fun s => s.ext.getState env + let saturateState := Saturate.State.new rules + pure { saturateState } + def scalarTacPartialPreprocess (config : Config) (state : State) - (oldAssumptions assumptions : Array FVarId) (simpTarget : Bool) : + (hypsToUseForSimp assumptionsToPreprocess : Array FVarId) (simpTarget : Bool) : Tactic.TacticM (Option (State × Array FVarId)) := do Tactic.focus do Tactic.withMainContext do @@ -296,7 +305,7 @@ def scalarTacPartialPreprocess (config : Config) (state : State) don't want `simp_all` to use assumptions of the shape `∀ x, P x`)) -/ {simpArgs with addSimpThms := #[``forall_eq_forall']} -- TODO: it would be good to always simplify the target before exploring it to saturate - (.targets assumptions simpTarget) + (.targets assumptionsToPreprocess simpTarget) -- We might have proven the goal let some assumptions := r | trace[ScalarTac] "Goal proven by preprocessing!"; return none @@ -309,41 +318,41 @@ def scalarTacPartialPreprocess (config : Config) (state : State) } scalarTacSaturateForward satConfig state.saturateState (some assumptions) fun saturateState nassumptions => do let state := { state with saturateState } - let assumptions := assumptions ++ nassumptions let finish : Tactic.TacticM (Option (State × Array FVarId)) := do trace[ScalarTac] "Goal after saturation: {← getMainGoal}" -- Apply `simpAll` to the *new assumptions* let applySimpAll assumptions simpTarget := do - match - ← tryTactic? do + let r ← tryTactic? do /- By setting the maxDischargeDepth at 0, we make sure that assumptions of the shape `∀ x, P x → ...` will not have any effect. This is important because it often happens that the user instantiates one such assumptions with specific arguments, meaning that if we call `simpAll` naively, those instantiations will get simplified to `True` and thus eliminated. -/ - Simp.simpAllAssumptions + Simp.evalSimpAllAssumptions {failIfUnchanged := false, maxSteps := config.simpAllMaxSteps, maxDischargeDepth := 0} true - simpArgs (← getMainGoal) assumptions simpTarget - with - | some r => pure r - | none => pure (some (assumptions, ← getMainGoal)) + simpArgs assumptions simpTarget + if r.isSome then trace[ScalarTac] "applySimpAll succeeded" + else trace[ScalarTac] "applySimpAll failed" + match r with + | some (some (fvars, _)) => pure (some fvars) + | some none | none => pure (some assumptions) let r ← do - if config.simpAllMaxSteps ≠ 0 then + if config.simpAllMaxSteps ≠ 0 ∧ assumptionsToPreprocess.size + nassumptions.size > 0 then /- Simplify the new assumptions with the old assumptions (it's often enough to go in that direction) -/ - let some assumptions ← Simp.simpAt true {dsimp := false, failIfUnchanged := false, maxDischargeDepth := 1} - {simpArgs with hypsToUse := oldAssumptions} (.targets assumptions false) + let some nassumptions ← Simp.simpAt true {dsimp := false, failIfUnchanged := false, maxDischargeDepth := 1} + {simpArgs with hypsToUse := hypsToUseForSimp ++ assumptions} (.targets nassumptions false) | trace[ScalarTac] "Goal proven by preprocessing!"; return none /- Apply `simpAll` on the new assumptions -/ - let some (assumptions, mvarId) ← applySimpAll assumptions false + let some nassumptions ← applySimpAll nassumptions false | trace[ScalarTac] "Goal proven by preprocessing!"; return none - /- We apply `simpAll` on the old and new assumptions -/ - setGoals [mvarId] - applySimpAll (oldAssumptions ++ assumptions) simpTarget + /- We apply `simpAll` to all the assumptions -/ + applySimpAll (hypsToUseForSimp ++ assumptions ++ nassumptions) simpTarget + let some _ ← Simp.simpAt true {failIfUnchanged := false, maxDischargeDepth := 0} + { hypsToUse := hypsToUseForSimp } (.targets #[] simpTarget) else - pure (some (assumptions, ← getMainGoal)) - let some (assumptions, mvarId) := r + pure (some assumptions) + let some assumptions := r | trace[ScalarTac] "Goal proven by preprocessing!"; return none - setGoals [mvarId] trace[ScalarTac] "Goal after simpAll: {← getMainGoal}" -- Call `simp` again, this time to inline the let-bindings (otherwise, omega doesn't always manage to deal with them) @@ -357,9 +366,16 @@ def scalarTacPartialPreprocess (config : Config) (state : State) | trace[ScalarTac] "Goal proven by preprocessing!"; return none trace[ScalarTac] "Goal after normCast: {← getMainGoal}" -- Call `simp` again because `normCast` sometimes does weird things - let some assumptions ← Simp.simpAt true {failIfUnchanged := false, maxDischargeDepth := 1} simpArgs (.targets assumptions simpTarget) + let some assumptions ← Simp.simpAt true {failIfUnchanged := false, maxDischargeDepth := 1} simpArgs + (.targets assumptions simpTarget) | trace[ScalarTac] "Goal proven by preprocessing!"; return none trace[ScalarTac] "Goal after 2nd call to simpAt: {← getMainGoal}" + /- Remove the occurrences of `forall'` in the target -/ + if simpTarget then + let some _ ← Simp.simpAt true {failIfUnchanged := false, maxDischargeDepth := 0} + { declsToUnfold := #[``forall'] } (.targets #[] simpTarget) + | trace[ScalarTac] "Goal proven by preprocessing!"; return none + trace[ScalarTac] "Goal after eliminating `forall'` in the target: {← getMainGoal}" -- We modified the assumptions in the context so we need to update the state accordingly let saturateState ← Saturate.recomputeAssumptions state.saturateState none assumptions let state := { state with saturateState } @@ -446,6 +462,25 @@ elab "scalar_tac" config:Parser.Tactic.optConfig : tactic => do let config ← elabConfig config scalarTac config +/-- Incremental version of `scalar_tac`, where we call preprocessing several times to incrementally + saturate the context. + TODO: do we really need the config? + -/ +def incrScalarTac (config : Config) (state : State) (toClear : Array FVarId) (assumptions : Array FVarId) : TacticM Unit := do + Tactic.focus do + Tactic.withMainContext do + /- Clear the useless assumptions -/ + let mvarId ← (← getMainGoal).tryClearMany toClear + setGoals [mvarId] + /- Saturate by exploring only the goal -/ + let some (_, _) ← scalarTacPartialPreprocess config state assumptions #[] true + | trace[ScalarTac] "incrScalarTac: goal proven by preprocessing" + trace[ScalarTac] "Goal after final preprocessing: {← getMainGoal}" + /- Call omega -/ + trace[ScalarTac] "Calling omega" + Tactic.Omega.omegaTactic {} + trace[ScalarTac] "Goal proved!" + -- For termination proofs syntax "int_decr_tac" : tactic macro_rules diff --git a/backends/lean/Aeneas/Simp/SimpAllAssumptions.lean b/backends/lean/Aeneas/Simp/SimpAllAssumptions.lean index 81c9aee82..c56789564 100644 --- a/backends/lean/Aeneas/Simp/SimpAllAssumptions.lean +++ b/backends/lean/Aeneas/Simp/SimpAllAssumptions.lean @@ -12,6 +12,8 @@ namespace Aeneas.Simp open Lean Meta +initialize registerTraceClass `Aeneas.Meta.Tactic.simp.all + open Simp (Stats SimprocsArray) namespace SimpAll @@ -126,7 +128,8 @@ private partial def loop (target : Bool) : M Bool := do else return false -def main (fvars : Array FVarId) (target : Bool) : M (Option (Array FVarId × MVarId)) := do +def main (fvars : Array FVarId) (target : Bool) : + M (Option (Array FVarId × Std.HashMap FVarId FVarId × MVarId)) := do initEntries fvars if (← loop target) then return none -- close the goal @@ -140,6 +143,9 @@ def main (fvars : Array FVarId) (target : Bool) : M (Option (Array FVarId × MVa let mut toClear := #[] let mut preserved := #[] let mut modified := false + let mut fvarIdMap := Std.HashMap.emptyWithCapacity + let mut toRemap := #[] + let mut i := 0 for e in (← get).entries do if e.type.isTrue then -- Do not assert `True` hypotheses @@ -148,20 +154,33 @@ def main (fvars : Array FVarId) (target : Bool) : M (Option (Array FVarId × MVa toClear := toClear.push e.fvarId toAssert := toAssert.push { userName := e.userName, type := e.type, value := e.proof } modified := true + toRemap := toRemap.push (e.fvarId, i) + i := i + 1 else preserved := preserved.push e.fvarId + fvarIdMap := fvarIdMap.insert e.fvarId e.fvarId + let _ ← mvarId.withContext do trace[Aeneas.Meta.Tactic.simp.all] "preserved: {preserved.map Expr.fvar}" + let _ ← mvarId.withContext do + trace[Aeneas.Meta.Tactic.simp.all] "toClear: {← toClear.mapM fun fid => do pure (Expr.fvar fid, ← inferType (Expr.fvar fid))}" + let _ ← mvarId.withContext do trace[Aeneas.Meta.Tactic.simp.all] "toRemap: {toRemap.map (Expr.fvar ∘ Prod.fst)}" let (freshIds, mvarId) ← mvarId.assertHypotheses toAssert + let _ ← mvarId.withContext do trace[Aeneas.Meta.Tactic.simp.all] "freshIds: {freshIds.map Expr.fvar}" + trace[Aeneas.Meta.Tactic.simp.all] "Before clearing: {mvarId}" let mvarId ← mvarId.tryClearMany toClear - pure (some (preserved ++ freshIds, mvarId)) + trace[Aeneas.Meta.Tactic.simp.all] "After clearing: {mvarId}" + for (fvarId, id) in toRemap do + fvarIdMap := fvarIdMap.insert fvarId freshIds[id]! + pure (some (preserved ++ freshIds, fvarIdMap, mvarId)) end SimpAll def simpAllAssumptionsAux (mvarId : MVarId) (ctx : Simp.Context) (fvars : Array FVarId) (target : Bool) (simprocs : SimprocsArray := #[]) (stats : Stats := {}) : - MetaM (Option (Array FVarId × MVarId) × Stats) := do + MetaM (Option (Array FVarId × Std.HashMap FVarId FVarId × MVarId) × Stats) := do mvarId.withContext do + trace[Aeneas.Meta.Tactic.simp.all] "Initial mvar: {mvarId}" let (r, s) ← (SimpAll.main fvars target).run { stats with mvarId, ctx, simprocs } - if let .some (_, mvarIdNew) := r then + if let .some (_, _, mvarIdNew) := r then if ctx.config.failIfUnchanged && mvarId == mvarIdNew then throwError "simp_all made no progress" return (r, { s with }) @@ -169,7 +188,7 @@ def simpAllAssumptionsAux (mvarId : MVarId) (ctx : Simp.Context) (fvars : Array open Utils Simp Elab Tactic in def simpAllAssumptions (config : Simp.Config) (simpOnly : Bool) (args : SimpArgs) (mvarId : MVarId) (fvars : Array FVarId) (target : Bool) : - MetaM (Option (Array FVarId × MVarId)) := do + MetaM (Option (Array FVarId × Std.HashMap FVarId FVarId × MVarId)) := do -- Initialize the simp context let (ctx, simprocs) ← mkSimpCtx simpOnly config .simpAll args -- Apply the simplifier @@ -178,13 +197,10 @@ def simpAllAssumptions (config : Simp.Config) (simpOnly : Bool) (args : SimpArgs open Utils Simp Elab Tactic in def evalSimpAllAssumptions (config : Simp.Config) (simpOnly : Bool) (args : SimpArgs) - (mvarId : MVarId) (fvars : Array FVarId) (target : Bool) : - TacticM (Option (Array FVarId)) := do - match ← simpAllAssumptions config simpOnly args mvarId fvars target with + (fvars : Array FVarId) (target : Bool) : + TacticM (Option (Array FVarId × Std.HashMap FVarId FVarId)) := do + match ← simpAllAssumptions config simpOnly args (← getMainGoal) fvars target with | none => replaceMainGoal []; pure none - | some (fvarIds, mvarId) => replaceMainGoal [mvarId]; pure (some fvarIds) - -initialize - registerTraceClass `Aeneas.Meta.Tactic.simp.all + | some (fvarIds, idsMap, mvarId) => replaceMainGoal [mvarId]; pure (some (fvarIds, idsMap)) end Aeneas.Simp diff --git a/backends/lean/Aeneas/SimpIfs/SimpIfs.lean b/backends/lean/Aeneas/SimpIfs/SimpIfs.lean index 370ff36df..d5ed1180f 100644 --- a/backends/lean/Aeneas/SimpIfs/SimpIfs.lean +++ b/backends/lean/Aeneas/SimpIfs/SimpIfs.lean @@ -57,7 +57,7 @@ elab stx:simp_ifs : tactic => let (config, args, loc) ← parseSimpIfs stx simpIfsTac config args loc -example [Inhabited α] (i j : Nat) (h :i ≥ j ∧ i < j + 1) : (if i = j then 0 else 1) = 0 := by +example {α} [Inhabited α] (i j : Nat) (h :i ≥ j ∧ i < j + 1) : (if i = j then 0 else 1) = 0 := by simp_ifs end Aeneas.SimpIfs diff --git a/backends/lean/Aeneas/SimpLists/SimpLists.lean b/backends/lean/Aeneas/SimpLists/SimpLists.lean index 428a7d4ee..576b3166d 100644 --- a/backends/lean/Aeneas/SimpLists/SimpLists.lean +++ b/backends/lean/Aeneas/SimpLists/SimpLists.lean @@ -35,6 +35,7 @@ attribute [simp_lists_simps] Nat.testBit_two_pow_add_eq attribute [simp_lists_simps] List.map_map List.map_id_fun List.map_id_fun' id_eq +attribute [simp_lists_simps] Fin.getElem!_fin def simpListsTac (config : ScalarTac.CondSimpTacConfig) (args : ScalarTac.CondSimpPartialArgs) (loc : Utils.Location) : TacticM Unit := do diff --git a/backends/lean/Aeneas/Std/Array/Array.lean b/backends/lean/Aeneas/Std/Array/Array.lean index fb39463fa..732430d94 100644 --- a/backends/lean/Aeneas/Std/Array/Array.lean +++ b/backends/lean/Aeneas/Std/Array/Array.lean @@ -119,12 +119,12 @@ def Array.set {α : Type u} {n : Usize} (v: Array α n) (i: Usize) (x: α) : Arr def Array.set_opt {α : Type u} {n : Usize} (v: Array α n) (i: Usize) (x: Option α) : Array α n := ⟨ v.val.set_opt i.val x, by have := v.property; simp [*] ⟩ -@[simp] +@[simp, simp_lists_simps] theorem Array.set_val_eq {α : Type u} {n : Usize} (v: Array α n) (i: Usize) (x: α) : (v.set i x).val = v.val.set i.val x := by simp [set] -@[simp] +@[simp, simp_lists_simps] theorem Array.set_opt_val_eq {α : Type u} {n : Usize} (v: Array α n) (i: Usize) (x: Option α) : (v.set_opt i x).val = v.val.set_opt i.val x := by simp [set_opt] diff --git a/backends/lean/Aeneas/Std/Scalar/Core.lean b/backends/lean/Aeneas/Std/Scalar/Core.lean index f815654f6..81d821aba 100644 --- a/backends/lean/Aeneas/Std/Scalar/Core.lean +++ b/backends/lean/Aeneas/Std/Scalar/Core.lean @@ -751,6 +751,14 @@ instance (ty : UScalarTy) : Inhabited (UScalar ty) := by instance (ty : IScalarTy) : Inhabited (IScalar ty) := by constructor; cases ty <;> apply (IScalar.ofInt 0 (by simp [IScalar.cMin, IScalar.cMax, IScalar.rMin, IScalar.rMax]; simp_bounds)) +@[simp, simp_scalar_simps] +theorem UScalar.default_val {ty} : (default : UScalar ty).val = 0 := by + simp only [default]; cases ty <;> simp + +@[simp, simp_scalar_simps] +theorem UScalar.default_bv {ty} : (default : UScalar ty).bv = 0 := by + simp only [default]; cases ty <;> simp + theorem IScalar.min_lt_max (ty : IScalarTy) : IScalar.min ty < IScalar.max ty := by cases ty <;> simp [IScalar.min, IScalar.max] <;> (try simp_bounds) have : (0 : Int) < 2 ^ (System.Platform.numBits - 1) := by simp diff --git a/backends/lean/Aeneas/Utils.lean b/backends/lean/Aeneas/Utils.lean index a92bd6418..5fbbd8c93 100644 --- a/backends/lean/Aeneas/Utils.lean +++ b/backends/lean/Aeneas/Utils.lean @@ -21,8 +21,8 @@ namespace LocalContext pure (ls.filter (fun d => not d.isImplementationDetail)) def getAssumptions (lctx : Lean.LocalContext) : MetaM (List Lean.LocalDecl) := do - let ls ← lctx.getAllDecls - ls.filterM (fun d => do pure (¬ d.isImplementationDetail && (← isProp d.type))) + let ls ← lctx.getDecls + ls.filterM (fun d => isProp d.type) end LocalContext @@ -1606,8 +1606,9 @@ def duplicateAssumptions (toDuplicate : Option (Array FVarId) := none) : let decls ← do match toDuplicate with | none => pure (← (← getLCtx).getDecls).toArray - | some decls => do decls.mapM fun d => d.getDecl - let props ← decls.filterM (fun d => isProp d.type) + | some decls => decls.mapM fun d => d.getDecl + trace[Utils] "All declarations: {decls.map fun d => Expr.fvar d.fvarId}" + let props ← decls.filterM (fun d => do isProp d.type) trace[Utils] "Current assumptions: {props.map LocalDecl.type}" let goal ← getMainGoal let goalType ← instantiateMVars (← goal.getType) @@ -1622,7 +1623,6 @@ def duplicateAssumptions (toDuplicate : Option (Array FVarId) := none) : let e ← mkLambdaFVars #[var] mgoal let e := mkApp e (.fvar d.fvarId) currentGoal.assign e - --currentGoal.assign (.letE d.userName d.type (.fvar d.fvarId) mgoal false) replaceMainGoal [mgoal.mvarId!] pure var.fvarId! pure (props.map LocalDecl.fvarId, newProps) From 675bb5a3b61c62867afafa7978410ebee5c09a4b Mon Sep 17 00:00:00 2001 From: Son Ho Date: Mon, 23 Jun 2025 22:51:52 +0100 Subject: [PATCH 07/31] Fix an issue in ScalarTac.scalarTacPartialPreprocess --- backends/lean/Aeneas/ScalarTac/ScalarTac.lean | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/backends/lean/Aeneas/ScalarTac/ScalarTac.lean b/backends/lean/Aeneas/ScalarTac/ScalarTac.lean index b3b513069..c068d9b71 100644 --- a/backends/lean/Aeneas/ScalarTac/ScalarTac.lean +++ b/backends/lean/Aeneas/ScalarTac/ScalarTac.lean @@ -347,8 +347,13 @@ def scalarTacPartialPreprocess (config : Config) (state : State) | trace[ScalarTac] "Goal proven by preprocessing!"; return none /- We apply `simpAll` to all the assumptions -/ applySimpAll (hypsToUseForSimp ++ assumptions ++ nassumptions) simpTarget + else if hypsToUseForSimp.size > 0 && simpTarget then + /- Even though there is nothing to preprocess, we want to simplify the goal by using the hypotheses to use, + to make sure we propagate equalities for instance -/ let some _ ← Simp.simpAt true {failIfUnchanged := false, maxDischargeDepth := 0} { hypsToUse := hypsToUseForSimp } (.targets #[] simpTarget) + | trace[ScalarTac] "Goal proven by preprocessing!"; return none + pure (some assumptions) else pure (some assumptions) let some assumptions := r From 90b57297b7a05f624ed4f6d2784df16969da2fda Mon Sep 17 00:00:00 2001 From: Son Ho Date: Mon, 23 Jun 2025 22:59:51 +0100 Subject: [PATCH 08/31] Use the isProp from the standard library --- backends/lean/Aeneas/Let/Let.lean | 4 ++-- backends/lean/Aeneas/Saturate/Attribute.lean | 2 +- backends/lean/Aeneas/Saturate/Tactic.lean | 8 ++++---- backends/lean/Aeneas/ScalarTac/CondSimpTac.lean | 2 +- backends/lean/Aeneas/ScalarTac/ScalarTac.lean | 4 +--- 5 files changed, 9 insertions(+), 11 deletions(-) diff --git a/backends/lean/Aeneas/Let/Let.lean b/backends/lean/Aeneas/Let/Let.lean index ad766f1bc..f7889ea18 100644 --- a/backends/lean/Aeneas/Let/Let.lean +++ b/backends/lean/Aeneas/Let/Let.lean @@ -15,7 +15,7 @@ def opaque_refold (h x : Name) (e : Expr) : TacticM Unit := withMainContext do /- Retrieve the list of propositions in the context -/ let ctx ← getLCtx - let props ← (← ctx.getDecls).filterM fun x => do pure (← inferType x.type).isProp + let props ← (← ctx.getDecls).filterM fun x => isProp x.type /- Generalize -/ let goal ← getMainGoal let (_, _, ngoal) ← goal.generalizeHyp #[{expr := e, xName? := x, hName? := h}] (props.map LocalDecl.fvarId).toArray @@ -66,7 +66,7 @@ def transparent_refold (x : Name) (e : Expr) : TacticM Unit := withMainContext do /- Retrieve the list of propositions in the context -/ let ctx ← getLCtx - let props ← (← ctx.getDecls).filterM fun x => do pure (← inferType x.type).isProp + let props ← (← ctx.getDecls).filterM fun x => isProp x.type /- List the assumptions which contain the declaration that we want to refold -/ let mut toRevert := #[] for decl in props.reverse do diff --git a/backends/lean/Aeneas/Saturate/Attribute.lean b/backends/lean/Aeneas/Saturate/Attribute.lean index c58b470a8..cd1f7f5cc 100644 --- a/backends/lean/Aeneas/Saturate/Attribute.lean +++ b/backends/lean/Aeneas/Saturate/Attribute.lean @@ -245,7 +245,7 @@ def makeAttribute (mapName attributeName : Name) (elabAttribute : Syntax → Met else -- Create a pattern let pat ← inferType fv - unless (← inferType pat).isProp do + unless ← isProp pat do throwError "Found a free variable not bound in the (optional) user provided pattern or in a precondition: {fv}" let curPatFVars ← getFVarIds pat (Std.HashSet.emptyWithCapacity) patFVars := patFVars.union (curPatFVars.insert fv.fvarId!) diff --git a/backends/lean/Aeneas/Saturate/Tactic.lean b/backends/lean/Aeneas/Saturate/Tactic.lean index 5d8b71ad0..994621aea 100644 --- a/backends/lean/Aeneas/Saturate/Tactic.lean +++ b/backends/lean/Aeneas/Saturate/Tactic.lean @@ -367,8 +367,8 @@ def matchExpr def filterProofTerms (config : Config) (exprs : Array Expr) : MetaM (Array Expr) := if ¬ config.visitProofTerms then exprs.filterM fun arg => do - let ty ← inferType (← inferType arg) - if ty.isProp then pure false + let ty ← inferType arg + if ← isProp ty then pure false else pure true else pure exprs @@ -574,7 +574,7 @@ partial def evalSaturateCore /- We explore both the type, the expression and the body (if there is) -/ /- Note that the path is used only when exploring the type of assumptions -/ let path ← - if (← inferType decl.type).isProp then pure (some (.asm decl.fvarId)) + if ← isProp decl.type then pure (some (.asm decl.fvarId)) else pure none let state ← visit config path state decl.type let state ← visit config none state decl.toExpr @@ -688,7 +688,7 @@ def recomputeAssumptions let decls ← declsToExplore.mapM fun d => d.getDecl decls.foldlM (fun (state : State) (decl : LocalDecl) => do trace[Saturate] "Exploring local decl: {decl.userName}" - if (← inferType decl.type).isProp then + if ← isProp decl.type then let path := some (.asm decl.fvarId) visit path state decl.type else pure state) state diff --git a/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean b/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean index 896839452..14d90e9fc 100644 --- a/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean +++ b/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean @@ -55,7 +55,7 @@ def condSimpParseArgs (tacName : String) (args : TSyntaxArray [`term, `token.«* trace[CondSimpTac] "found token: *" let decls ← (← getLCtx).getDecls let decls ← decls.filterMapM ( - fun d => do if (← inferType d.type).isProp then pure (some d.fvarId) else pure none) + fun d => do if ← isProp d.type then pure (some d.fvarId) else pure none) trace[CondSimpTac] "filtered decls: {decls.map Expr.fvar}" hypsToUse := hypsToUse.append decls.toArray else diff --git a/backends/lean/Aeneas/ScalarTac/ScalarTac.lean b/backends/lean/Aeneas/ScalarTac/ScalarTac.lean index c068d9b71..7d4163d06 100644 --- a/backends/lean/Aeneas/ScalarTac/ScalarTac.lean +++ b/backends/lean/Aeneas/ScalarTac/ScalarTac.lean @@ -208,9 +208,7 @@ def getSimpThmNames : CoreM (Array Name) := do the assumptions *does* work, hence this peculiar way of simplifying the context. -/ def simpAsmsTarget (simpOnly : Bool) (config : Simp.Config) (args : Simp.SimpArgs) : TacticM (Option (Array FVarId)) := withMainContext do - let lctx ← getLCtx - let decls ← lctx.getDecls - let props ← decls.filterM (fun d => do pure (← inferType d.type).isProp) + let props ← (← getLCtx).getAssumptions let props := (props.map fun d => d.fvarId).toArray Aeneas.Simp.simpAt simpOnly config args (.targets props true) From 50ae531ba930eb7f411f109150c8af456c6f36e1 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Mon, 23 Jun 2025 23:30:02 +0100 Subject: [PATCH 09/31] Do not unfold the local decls in the hyps to use in condSimpTac --- backends/lean/Aeneas/ScalarTac/CondSimpTac.lean | 8 +++++--- backends/lean/Aeneas/ScalarTac/ScalarTac.lean | 9 ++++++--- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean b/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean index 14d90e9fc..5541cb2f9 100644 --- a/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean +++ b/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean @@ -126,8 +126,10 @@ def condSimpTac /- Preprocess the assumptions -/ let scalarConfig : ScalarTac.Config := { nonLin := config.nonLin, saturationPasses := config.saturationPasses } let state ← State.new scalarConfig - /- First the hyps to use -/ - let some (_, hypsToUse) ← scalarTacPartialPreprocess scalarConfig state #[] hypsToUse false + /- First the hyps to use. + Note that we do not inline the local let-declarations: we will do this only for the "regular" assumptions + and the target. -/ + let some (_, hypsToUse) ← scalarTacPartialPreprocess scalarConfig state (zetaDelta := false) #[] hypsToUse false | trace[CondSimpTac] "Goal proven through preprocessing!"; return withMainContext do trace[CondSimpTac] "Goal after preprocessing the hyps to use ({hypsToUse.map Expr.fvar}): {← getMainGoal}" @@ -140,7 +142,7 @@ def condSimpTac withMainContext do trace[CondSimpTac] "Goal after simplifying the preprocessed hyps to use ({hypsToUse.map Expr.fvar}): {← getMainGoal}" /- Preprocess the "regular" assumptions -/ - let some (state, newAsms) ← scalarTacPartialPreprocess scalarConfig state #[] newAsms false + let some (state, newAsms) ← scalarTacPartialPreprocess scalarConfig state (zetaDelta := true) #[] newAsms false | trace[CondSimpTac] "Goal proven through preprocessing!"; return withMainContext do trace[CondSimpTac] "Goal after the initial preprocessing: {← getMainGoal}" diff --git a/backends/lean/Aeneas/ScalarTac/ScalarTac.lean b/backends/lean/Aeneas/ScalarTac/ScalarTac.lean index 7d4163d06..3deceaacb 100644 --- a/backends/lean/Aeneas/ScalarTac/ScalarTac.lean +++ b/backends/lean/Aeneas/ScalarTac/ScalarTac.lean @@ -286,7 +286,7 @@ def State.new (config : Config) : MetaM State := do pure { saturateState } def scalarTacPartialPreprocess (config : Config) (state : State) - (hypsToUseForSimp assumptionsToPreprocess : Array FVarId) (simpTarget : Bool) : + (zetaDelta : Bool) (hypsToUseForSimp assumptionsToPreprocess : Array FVarId) (simpTarget : Bool) : Tactic.TacticM (Option (State × Array FVarId)) := do Tactic.focus do Tactic.withMainContext do @@ -359,7 +359,10 @@ def scalarTacPartialPreprocess (config : Config) (state : State) trace[ScalarTac] "Goal after simpAll: {← getMainGoal}" -- Call `simp` again, this time to inline the let-bindings (otherwise, omega doesn't always manage to deal with them) - let r ← Simp.simpAt true {zetaDelta := true, failIfUnchanged := false, maxDischargeDepth := 1} simpArgs (.targets assumptions simpTarget) + let r ← do + if zetaDelta then + Simp.simpAt true {zetaDelta := true, failIfUnchanged := false, maxDischargeDepth := 1} simpArgs (.targets assumptions simpTarget) + else pure assumptions -- We might have proven the goal let some assumptions := r | trace[ScalarTac] "Goal proven by preprocessing!"; return none @@ -476,7 +479,7 @@ def incrScalarTac (config : Config) (state : State) (toClear : Array FVarId) (as let mvarId ← (← getMainGoal).tryClearMany toClear setGoals [mvarId] /- Saturate by exploring only the goal -/ - let some (_, _) ← scalarTacPartialPreprocess config state assumptions #[] true + let some (_, _) ← scalarTacPartialPreprocess config state (zetaDelta := true) assumptions #[] true | trace[ScalarTac] "incrScalarTac: goal proven by preprocessing" trace[ScalarTac] "Goal after final preprocessing: {← getMainGoal}" /- Call omega -/ From 85606ac58158f30c58be72b769bc1ae9d01388c7 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Tue, 24 Jun 2025 10:10:08 +0100 Subject: [PATCH 10/31] Make a minor modification to scalar_decr_tac --- backends/lean/Aeneas/ScalarDecrTac.lean | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/backends/lean/Aeneas/ScalarDecrTac.lean b/backends/lean/Aeneas/ScalarDecrTac.lean index f2e67c343..c96413a46 100644 --- a/backends/lean/Aeneas/ScalarDecrTac.lean +++ b/backends/lean/Aeneas/ScalarDecrTac.lean @@ -46,9 +46,10 @@ def removeInvImageAssumptions : TacticM Unit := do let filtDecls ← liftM (decls.filterM fun decl => do if ← isProp decl.type then containsInvertImage decl else pure false) + let filtDecls := filtDecls.toArray.map LocalDecl.fvarId /- Attempt to clear those assumptions - note that it may not always succeed as some other assumptions might depend on them -/ - tryClearFVarIds ⟨ filtDecls.map fun d => d.fvarId ⟩ + setGoals [← (← getMainGoal).tryClearMany filtDecls] elab "remove_invImage_assumptions" : tactic => removeInvImageAssumptions From 509649df6b397e75e202af52c1ce091c1991ff45 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Tue, 24 Jun 2025 10:10:29 +0100 Subject: [PATCH 11/31] Move some scalar_tac tests --- backends/lean/Aeneas/ScalarTac/ScalarTac.lean | 50 ------------------ backends/lean/Aeneas/ScalarTac/Tests.lean | 51 +++++++++++++++++++ 2 files changed, 51 insertions(+), 50 deletions(-) diff --git a/backends/lean/Aeneas/ScalarTac/ScalarTac.lean b/backends/lean/Aeneas/ScalarTac/ScalarTac.lean index 3deceaacb..d7b8b2ec5 100644 --- a/backends/lean/Aeneas/ScalarTac/ScalarTac.lean +++ b/backends/lean/Aeneas/ScalarTac/ScalarTac.lean @@ -498,56 +498,6 @@ macro_rules | apply ScalarTac.to_int_sub_to_nat_lt) <;> simp_all <;> scalar_tac) --- Checking that things happen correctly when there are several conjunctions -example (x y : Int) (h0: 0 ≤ x) (h1: x ≠ 0) (h2 : 0 ≤ y) (h3 : y ≠ 0) : 0 < x ∧ 0 < y := by - scalar_tac - --- Checking that things happen correctly when there are several conjunctions -example (x y : Int) (h0: 0 ≤ x) (h1: x ≠ 0) (h2 : 0 ≤ y) (h3 : y ≠ 0) : 0 < x ∧ 0 < y ∧ x + y ≥ 2 := by - scalar_tac - --- Checking that we can prove exfalso -example (a : Prop) (x : Int) (h0: 0 < x) (h1: x < 0) : a := by - scalar_tac - --- Intermediate cast through natural numbers -example (a : Prop) (x : Int) (h0: (0 : Nat) < x) (h1: x < 0) : a := by - scalar_tac - -example (x : Int) (h : x ≤ -3) : x ≤ -2 := by - scalar_tac - -example (x y : Int) (h : x + y = 3) : - let z := x + y - z = 3 := by - intro z - omega - -example (P : Nat → Prop) (z : Nat) (h : ∀ x, P x → x ≤ z) (y : Nat) (hy : P y) : - y + 2 ≤ z + 2 := by - have := h y hy - scalar_tac - --- Checking that we manage to split the cases/if then else -example (x : Int) (b : Bool) (h : if b then x ≤ 0 else x ≤ 0) : x ≤ 0 := by - scalar_tac +split - -/-! -Checking some non-linear problems --/ - -example (x y : Nat) (h0 : x ≤ 4) (h1 : y ≤ 5): x * y ≤ 4 * 5 := by - scalar_tac +nonLin - -example (x y : Nat) (h0 : x ≤ 4) (h1 : y < 5): x * y < 4 * 5 := by - scalar_tac +nonLin - -example (x y : Nat) (h0 : x < 4) (h1 : y < 5): x * y < 4 * 5 := by - scalar_tac +nonLin - -example (x y : Nat) (h0 : x < 4) (h1 : y ≤ 5): x * y < 4 * 5 := by - scalar_tac +nonLin - end ScalarTac end Aeneas diff --git a/backends/lean/Aeneas/ScalarTac/Tests.lean b/backends/lean/Aeneas/ScalarTac/Tests.lean index 9298b23d7..b7acdcf94 100644 --- a/backends/lean/Aeneas/ScalarTac/Tests.lean +++ b/backends/lean/Aeneas/ScalarTac/Tests.lean @@ -6,6 +6,57 @@ namespace Aeneas.Std.ScalarTac.Tests # Tests -/ +-- Checking that things happen correctly when there are several conjunctions +example (x y : Int) (h0: 0 ≤ x) (h1: x ≠ 0) (h2 : 0 ≤ y) (h3 : y ≠ 0) : 0 < x ∧ 0 < y := by + scalar_tac + +-- Checking that things happen correctly when there are several conjunctions +example (x y : Int) (h0: 0 ≤ x) (h1: x ≠ 0) (h2 : 0 ≤ y) (h3 : y ≠ 0) : 0 < x ∧ 0 < y ∧ x + y ≥ 2 := by + scalar_tac + +-- Checking that we can prove exfalso +example (a : Prop) (x : Int) (h0: 0 < x) (h1: x < 0) : a := by + scalar_tac + +-- Intermediate cast through natural numbers +example (a : Prop) (x : Int) (h0: (0 : Nat) < x) (h1: x < 0) : a := by + scalar_tac + +example (x : Int) (h : x ≤ -3) : x ≤ -2 := by + scalar_tac + +example (x y : Int) (h : x + y = 3) : + let z := x + y + z = 3 := by + intro z + omega + +example (P : Nat → Prop) (z : Nat) (h : ∀ x, P x → x ≤ z) (y : Nat) (hy : P y) : + y + 2 ≤ z + 2 := by + have := h y hy + scalar_tac + +-- Checking that we manage to split the cases/if then else +example (x : Int) (b : Bool) (h : if b then x ≤ 0 else x ≤ 0) : x ≤ 0 := by + scalar_tac +split + +/-! +Checking some non-linear problems +-/ + +example (x y : Nat) (h0 : x ≤ 4) (h1 : y ≤ 5): x * y ≤ 4 * 5 := by + scalar_tac +nonLin + +example (x y : Nat) (h0 : x ≤ 4) (h1 : y < 5): x * y < 4 * 5 := by + scalar_tac +nonLin + +example (x y : Nat) (h0 : x < 4) (h1 : y < 5): x * y < 4 * 5 := by + scalar_tac +nonLin + +example (x y : Nat) (h0 : x < 4) (h1 : y ≤ 5): x * y < 4 * 5 := by + scalar_tac +nonLin + + example (x _y : U32) : x.val ≤ UScalar.max .U32 := by scalar_tac_preprocess From 3f1e9ee5208759bd504c4a4812db87e794acc163 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Tue, 24 Jun 2025 11:55:59 +0100 Subject: [PATCH 12/31] Use specialized simp lemmas to simplify the hyps to use in condSimpTac --- backends/lean/Aeneas/Bvify/Bvify.lean | 6 +++++- backends/lean/Aeneas/Bvify/Init.lean | 14 ++++++++++++++ backends/lean/Aeneas/List/List.lean | 2 +- .../lean/Aeneas/ScalarTac/CondSimpTac.lean | 18 ++++++++++++++---- backends/lean/Aeneas/ScalarTac/ScalarTac.lean | 5 ++--- backends/lean/Aeneas/SimpBoolProp/Init.lean | 16 ++++++++++++++++ backends/lean/Aeneas/SimpIfs/Init.lean | 15 +++++++++++++++ backends/lean/Aeneas/SimpIfs/SimpIfs.lean | 13 ++++++++++--- backends/lean/Aeneas/SimpLists/Init.lean | 14 ++++++++++++++ backends/lean/Aeneas/SimpLists/SimpLists.lean | 9 ++++++++- backends/lean/Aeneas/SimpListsScalar.lean | 16 +++++++++++++++- backends/lean/Aeneas/SimpScalar/Init.lean | 14 ++++++++++++++ .../lean/Aeneas/SimpScalar/SimpScalar.lean | 9 ++++++++- backends/lean/Aeneas/Std/Array/Array.lean | 16 ++++++++-------- backends/lean/Aeneas/Std/Slice.lean | 16 ++++++++-------- backends/lean/Aeneas/Std/Vec.lean | 16 ++++++++-------- backends/lean/Aeneas/ZModify/Init.lean | 17 +++++++++++++++++ backends/lean/Aeneas/ZModify/ZModify.lean | 11 +++++++++-- 18 files changed, 186 insertions(+), 41 deletions(-) diff --git a/backends/lean/Aeneas/Bvify/Bvify.lean b/backends/lean/Aeneas/Bvify/Bvify.lean index d05eceadf..ec75e043b 100644 --- a/backends/lean/Aeneas/Bvify/Bvify.lean +++ b/backends/lean/Aeneas/Bvify/Bvify.lean @@ -288,12 +288,16 @@ def bvifyTacSimp (loc : Utils.Location) : TacticM (Option (Array FVarId)) := do ScalarTac.condSimpTacSimp bvifySimpConfig args loc #[] #[] none def bvifyTac (config : Config) (n : Expr) (loc : Utils.Location) : TacticM Unit := do + let hypsArgs : ScalarTac.CondSimpArgs := { + simpThms := #[← bvifyHypsSimpExt.getTheorems] + simprocs := #[← bvifyHypsSimprocExt.getSimprocs] + } let args : ScalarTac.CondSimpArgs := { simpThms := #[← bvifySimpExt.getTheorems, ← SimpBoolProp.simpBoolPropSimpExt.getTheorems] simprocs := #[← bvifySimprocExt.getSimprocs, ← SimpBoolProp.simpBoolPropSimprocExt.getSimprocs] } let config := { nonLin := config.nonLin, saturationPasses := config.saturationPasses } - ScalarTac.condSimpTac "bvify" config bvifySimpConfig args (bvifyAddSimpThms n) true loc + ScalarTac.condSimpTac "bvify" config bvifySimpConfig hypsArgs args (bvifyAddSimpThms n) true loc syntax (name := bvify) "bvify " colGt Parser.Tactic.optConfig term (location)? : tactic diff --git a/backends/lean/Aeneas/Bvify/Init.lean b/backends/lean/Aeneas/Bvify/Init.lean index 2b446df94..b7a6fa352 100644 --- a/backends/lean/Aeneas/Bvify/Init.lean +++ b/backends/lean/Aeneas/Bvify/Init.lean @@ -29,4 +29,18 @@ initialize bvifySimprocExt : Simp.SimprocExtension ← The `bvify_simps_proc` attribute registers simp procedures to be used by `bvify` during its preprocessing phase." none --(some bvifySimprocsRef) +/-- The `bvify_hyps_simps` simp attribute. -/ +initialize bvifyHypsSimpExt : SimpExtension ← + registerSimpAttr `bvify_hyps_simps "\ + The `bvify_hyps_simps` attribute registers simp lemmas to be used by `bvify`." + +-- TODO: initialization fails with this, while the same works for `scalar_tac`?? +--initialize bvifySimprocsRef : IO.Ref Simprocs ← IO.mkRef {} + +/-- The `bvify_hyps_simps_proc` simp attribute for the simp rocs. -/ +initialize bvifyHypsSimprocExt : Simp.SimprocExtension ← + Simp.registerSimprocAttr `bvify_hyps_simps_proc "\ + The `bvify_hyps_simps_proc` attribute registers simp procedures to be used by `bvify` + during its preprocessing phase." none --(some bvifySimprocsRef) + end Aeneas.Bvify diff --git a/backends/lean/Aeneas/List/List.lean b/backends/lean/Aeneas/List/List.lean index 0b225d945..19c689aa1 100644 --- a/backends/lean/Aeneas/List/List.lean +++ b/backends/lean/Aeneas/List/List.lean @@ -588,7 +588,7 @@ theorem setSlice!_drop {α} (l : List α) (i : ℕ) : by_cases h1: j < l.length <;> simp_lists simp_scalar -@[simp_lists_simps] +@[simp_lists_hyps_simps] def Inhabited_getElem_eq_getElem! {α} [Inhabited α] (l : List α) (i : ℕ) (hi : i < l.length) : l[i] = l[i]! := by simp only [List.getElem!_eq_getElem?_getD, List.getElem?_eq_getElem, Option.getD_some, hi] diff --git a/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean b/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean index 5541cb2f9..aab149182 100644 --- a/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean +++ b/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean @@ -71,6 +71,12 @@ structure CondSimpArgs where addSimpThms : Array Name := #[] hypsToUse : Array FVarId := #[] +instance : HAppend CondSimpArgs CondSimpArgs CondSimpArgs where + hAppend c0 c1 := + let ⟨ a0, b0, c0, d0, e0 ⟩ := c0 + let ⟨ a1, b1, c1, d1, e1 ⟩ := c1 + ⟨ a0 ++ a1, b0 ++ b1, c0 ++ c1, d0 ++ d1, e0 ++ e1 ⟩ + def CondSimpArgs.toSimpArgs (args : CondSimpArgs) : Simp.SimpArgs := { simpThms := args.simpThms, simprocs := args.simprocs, @@ -101,11 +107,15 @@ def condSimpTacSimp (config : Simp.Config) (args : CondSimpArgs) (loc : Utils.Lo /-- A helper to define tactics which perform conditional simplifications with `scalar_tac` as a discharger. -/ def condSimpTac (tacName : String) (config : CondSimpTacConfig) - (simpConfig : Simp.Config) (args : CondSimpArgs) + (simpConfig : Simp.Config) (hypsArgs args : CondSimpArgs) (addSimpThms : TacticM (Array FVarId)) (doFirstSimp : Bool) (loc : Utils.Location) : TacticM Unit := do Elab.Tactic.focus do withMainContext do + /- Concatenate the arguments for the conditional rewritings: we only want to restrict the simplifications + performed on the higher-order hypotheses, but we need all the simplifications performed to those to + be applied to the rest of the context as well. -/ + let args := args ++ hypsArgs trace[CondSimpTac] "Initial goal: {← getMainGoal}" /- First duplicate the propositions in the context: we need the preprocessing of `scalar_tac` to modify the assumptions, but we need to preserve a copy so that we can present a clean state to the user later @@ -129,12 +139,12 @@ def condSimpTac /- First the hyps to use. Note that we do not inline the local let-declarations: we will do this only for the "regular" assumptions and the target. -/ - let some (_, hypsToUse) ← scalarTacPartialPreprocess scalarConfig state (zetaDelta := false) #[] hypsToUse false + let some (_, hypsToUse) ← scalarTacPartialPreprocess scalarConfig hypsArgs.toSimpArgs state (zetaDelta := false) #[] hypsToUse false | trace[CondSimpTac] "Goal proven through preprocessing!"; return withMainContext do trace[CondSimpTac] "Goal after preprocessing the hyps to use ({hypsToUse.map Expr.fvar}): {← getMainGoal}" /- Remove the `forall'` and simplify the hyps to use -/ - let simpHypsToUseArgs := { args.toSimpArgs with hypsToUse := #[], declsToUnfold := #[``forall'] } + let simpHypsToUseArgs := { hypsArgs with hypsToUse := #[], declsToUnfold := #[``forall'] } let some hypsToUse ← Simp.simpAt true {failIfUnchanged := false, maxDischargeDepth := 0} simpHypsToUseArgs (.targets hypsToUse false) | trace[ScalarTac] "Goal proven by preprocessing!"; return @@ -142,7 +152,7 @@ def condSimpTac withMainContext do trace[CondSimpTac] "Goal after simplifying the preprocessed hyps to use ({hypsToUse.map Expr.fvar}): {← getMainGoal}" /- Preprocess the "regular" assumptions -/ - let some (state, newAsms) ← scalarTacPartialPreprocess scalarConfig state (zetaDelta := true) #[] newAsms false + let some (state, newAsms) ← scalarTacPartialPreprocess scalarConfig (← ScalarTac.getSimpArgs) state (zetaDelta := true) #[] newAsms false | trace[CondSimpTac] "Goal proven through preprocessing!"; return withMainContext do trace[CondSimpTac] "Goal after the initial preprocessing: {← getMainGoal}" diff --git a/backends/lean/Aeneas/ScalarTac/ScalarTac.lean b/backends/lean/Aeneas/ScalarTac/ScalarTac.lean index d7b8b2ec5..8c46e026b 100644 --- a/backends/lean/Aeneas/ScalarTac/ScalarTac.lean +++ b/backends/lean/Aeneas/ScalarTac/ScalarTac.lean @@ -285,7 +285,7 @@ def State.new (config : Config) : MetaM State := do let saturateState := Saturate.State.new rules pure { saturateState } -def scalarTacPartialPreprocess (config : Config) (state : State) +def scalarTacPartialPreprocess (config : Config) (simpArgs : Simp.SimpArgs) (state : State) (zetaDelta : Bool) (hypsToUseForSimp assumptionsToPreprocess : Array FVarId) (simpTarget : Bool) : Tactic.TacticM (Option (State × Array FVarId)) := do Tactic.focus do @@ -297,7 +297,6 @@ def scalarTacPartialPreprocess (config : Config) (state : State) typed expressions such as `UScalar.ofNat`. -/ trace[ScalarTac] "Original goal before preprocessing: {← getMainGoal}" - let simpArgs : Simp.SimpArgs ← getSimpArgs let r ← Simp.simpAt true {dsimp := false, failIfUnchanged := false, maxDischargeDepth := 1} /- Remove the forall quantifiers to prepare for the call of `simp_all` (we don't want `simp_all` to use assumptions of the shape `∀ x, P x`)) -/ @@ -479,7 +478,7 @@ def incrScalarTac (config : Config) (state : State) (toClear : Array FVarId) (as let mvarId ← (← getMainGoal).tryClearMany toClear setGoals [mvarId] /- Saturate by exploring only the goal -/ - let some (_, _) ← scalarTacPartialPreprocess config state (zetaDelta := true) assumptions #[] true + let some (_, _) ← scalarTacPartialPreprocess config (← ScalarTac.getSimpArgs) state (zetaDelta := true) assumptions #[] true | trace[ScalarTac] "incrScalarTac: goal proven by preprocessing" trace[ScalarTac] "Goal after final preprocessing: {← getMainGoal}" /- Call omega -/ diff --git a/backends/lean/Aeneas/SimpBoolProp/Init.lean b/backends/lean/Aeneas/SimpBoolProp/Init.lean index 9038ac238..b4e9499f2 100644 --- a/backends/lean/Aeneas/SimpBoolProp/Init.lean +++ b/backends/lean/Aeneas/SimpBoolProp/Init.lean @@ -36,4 +36,20 @@ initialize simpBoolPropSimprocExt : Simp.SimprocExtension ← Those simp procedures are used by several tactics such as `scalar_tac`, `simp_scalar`, `simp_ifs`, etc." none --(some simpListsSimprocsRef) +/-- The `simp_bool_prop_hyps_simps` simp attribute. -/ +initialize simpBoolPropHypsSimpExt : SimpExtension ← + registerSimpAttr `simp_bool_prop_hyps_simps "\ + The `simp_bool_prop_hyps_simps` attribute registers simp lemmas to be used to simplify booleans and propositions. + Those simp lemmas are used by several tactics such as `scalar_tac`, `simp_scalar`, `simp_ifs`, etc." + +-- TODO: initialization fails with this, while the same works for `scalar_tac`?? +--initialize simpListsSimprocsRef : IO.Ref Simprocs ← IO.mkRef {} + +/-- The `simp_bool_prop_hyps_simps_proc` simp attribute for the simp rocs. -/ +initialize simpBoolPropHypsSimprocExt : Simp.SimprocExtension ← + Simp.registerSimprocAttr `simp_bool_prop_hyps_simps_proc "\ + The `simp_bool_prop_hyps_simps` attribute registers simp procedures to be used to simplify booleans and propositions. + Those simp procedures are used by several tactics such as `scalar_tac`, `simp_scalar`, `simp_ifs`, etc." + none --(some simpListsSimprocsRef) + end Aeneas.SimpBoolProp diff --git a/backends/lean/Aeneas/SimpIfs/Init.lean b/backends/lean/Aeneas/SimpIfs/Init.lean index d2cdf72d5..f6d3e8754 100644 --- a/backends/lean/Aeneas/SimpIfs/Init.lean +++ b/backends/lean/Aeneas/SimpIfs/Init.lean @@ -29,4 +29,19 @@ initialize simpIfsSimprocExt : Simp.SimprocExtension ← The `simp_ifs_simps_proc` attribute registers simp procedures to be used by `simp_ifs` during its preprocessing phase." none --(some simpIfsSimprocsRef) +/-- The `simp_ifs_hyps_simps` simp attribute. -/ +initialize simpIfsHypsSimpExt : SimpExtension ← + registerSimpAttr `simp_ifs_hyps_simps "\ + The `simp_ifs_hyps_simps` attribute registers simp lemmas to be used by `simp_ifs` when preprocessing hypotheses." + +-- TODO: initialization fails with this, while the same works for `scalar_tac`?? +--initialize simpIfsSimprocsRef : IO.Ref Simprocs ← IO.mkRef {} + +/-- The `simp_ifs_hyps_simps_proc` simp attribute for the simp rocs. -/ +initialize simpIfsHypsSimprocExt : Simp.SimprocExtension ← + Simp.registerSimprocAttr `simp_ifs_hyps_simps_proc "\ + The `simp_ifs_hyps_simps_proc` attribute registers simp procedures to be used by `simp_ifs` + during its preprocessing phase." none --(some simpIfsSimprocsRef) + + end Aeneas.SimpIfs diff --git a/backends/lean/Aeneas/SimpIfs/SimpIfs.lean b/backends/lean/Aeneas/SimpIfs/SimpIfs.lean index d5ed1180f..58ee45b30 100644 --- a/backends/lean/Aeneas/SimpIfs/SimpIfs.lean +++ b/backends/lean/Aeneas/SimpIfs/SimpIfs.lean @@ -15,6 +15,13 @@ open Lean Lean.Meta Lean.Parser.Tactic Lean.Elab.Tactic def simpIfsTac (config : ScalarTac.CondSimpTacConfig) (args : ScalarTac.CondSimpPartialArgs) (loc : Utils.Location) : TacticM Unit := do let addSimpThms : TacticM (Array FVarId) := pure #[] + let hypsArgs : ScalarTac.CondSimpArgs := { + simpThms := #[← simpIfsHypsSimpExt.getTheorems, ← SimpBoolProp.simpBoolPropHypsSimpExt.getTheorems], + simprocs := #[← simpIfsHypsSimprocExt.getSimprocs, ← SimpBoolProp.simpBoolPropHypsSimprocExt.getSimprocs], + declsToUnfold := #[], + addSimpThms := #[], + hypsToUse := #[], + } let args : ScalarTac.CondSimpArgs := { simpThms := #[← simpIfsSimpExt.getTheorems, ← SimpBoolProp.simpBoolPropSimpExt.getTheorems], simprocs := #[← simpIfsSimprocExt.getSimprocs, ← SimpBoolProp.simpBoolPropSimprocExt.getSimprocs], @@ -22,7 +29,7 @@ def simpIfsTac (config : ScalarTac.CondSimpTacConfig) addSimpThms := args.addSimpThms, hypsToUse := args.hypsToUse, } - ScalarTac.condSimpTac "simp_ifs" config {maxDischargeDepth := 2, failIfUnchanged := false} args addSimpThms false loc + ScalarTac.condSimpTac "simp_ifs" config {maxDischargeDepth := 2, failIfUnchanged := false} hypsArgs args addSimpThms false loc syntax (name := simp_ifs) "simp_ifs" Parser.Tactic.optConfig ("[" term,* "]")? (location)? : tactic @@ -35,11 +42,11 @@ theorem if_false {α} (b : Prop) [Decidable b] (x y : α) (hb : ¬ b) : (if b th simp only [hb, Bool.false_eq_true, ↓reduceIte] @[simp_ifs_simps] -theorem dite_true (c : Prop) [Decidable c] (h : c) (t : c → α) (e : ¬c → α) : +theorem dite_true {α} (c : Prop) [Decidable c] (h : c) (t : c → α) (e : ¬c → α) : dite c t e = t h := by simp [h] @[simp_ifs_simps] -theorem dite_fase (c : Prop) [Decidable c] (h : ¬ c) (t : c → α) (e : ¬c → α) : +theorem dite_fase {α} (c : Prop) [Decidable c] (h : ¬ c) (t : c → α) (e : ¬c → α) : dite c t e = e h := by simp [h] def parseSimpIfs : diff --git a/backends/lean/Aeneas/SimpLists/Init.lean b/backends/lean/Aeneas/SimpLists/Init.lean index 5b0183f08..763f664f7 100644 --- a/backends/lean/Aeneas/SimpLists/Init.lean +++ b/backends/lean/Aeneas/SimpLists/Init.lean @@ -29,4 +29,18 @@ initialize simpListsSimprocExt : Simp.SimprocExtension ← The `simp_lists_simps_proc` attribute registers simp procedures to be used by `simp_lists` during its preprocessing phase." none --(some simpListsSimprocsRef) +/-- The `simp_lists_hyps_simp` simp attribute. -/ +initialize simpListsHypsSimpExt : SimpExtension ← + registerSimpAttr `simp_lists_hyps_simps "\ + The `simp_lists_hyps_simp` attribute registers simp lemmas to be used by `simp_lists`." + +-- TODO: initialization fails with this, while the same works for `scalar_tac`?? +--initialize simpListsSimprocsRef : IO.Ref Simprocs ← IO.mkRef {} + +/-- The `simp_lists_hyps_simp_proc` simp attribute for the simp rocs. -/ +initialize simpListsHypsSimprocExt : Simp.SimprocExtension ← + Simp.registerSimprocAttr `simp_lists_hyps_simps_proc "\ + The `simp_lists_hyps_simp_proc` attribute registers simp procedures to be used by `simp_lists` + during its preprocessing phase." none --(some simpListsSimprocsRef) + end Aeneas.SimpLists diff --git a/backends/lean/Aeneas/SimpLists/SimpLists.lean b/backends/lean/Aeneas/SimpLists/SimpLists.lean index 576b3166d..25165d4f0 100644 --- a/backends/lean/Aeneas/SimpLists/SimpLists.lean +++ b/backends/lean/Aeneas/SimpLists/SimpLists.lean @@ -40,6 +40,13 @@ attribute [simp_lists_simps] Fin.getElem!_fin def simpListsTac (config : ScalarTac.CondSimpTacConfig) (args : ScalarTac.CondSimpPartialArgs) (loc : Utils.Location) : TacticM Unit := do let addSimpThms : TacticM (Array FVarId) := pure #[] + let hypsArgs : ScalarTac.CondSimpArgs := { + simpThms := #[← simpListsHypsSimpExt.getTheorems, ← SimpBoolProp.simpBoolPropHypsSimpExt.getTheorems], + simprocs := #[← simpListsHypsSimprocExt.getSimprocs, ← SimpBoolProp.simpBoolPropHypsSimprocExt.getSimprocs], + declsToUnfold := #[], + addSimpThms := #[], + hypsToUse := #[], + } let args : ScalarTac.CondSimpArgs := { simpThms := #[← simpListsSimpExt.getTheorems, ← SimpBoolProp.simpBoolPropSimpExt.getTheorems], simprocs := #[← simpListsSimprocExt.getSimprocs, ← SimpBoolProp.simpBoolPropSimprocExt.getSimprocs], @@ -47,7 +54,7 @@ def simpListsTac (config : ScalarTac.CondSimpTacConfig) addSimpThms := args.addSimpThms, hypsToUse := args.hypsToUse, } - ScalarTac.condSimpTac "simp_lists" config {maxDischargeDepth := 2, failIfUnchanged := false, contextual := true} args addSimpThms false loc + ScalarTac.condSimpTac "simp_lists" config {maxDischargeDepth := 2, failIfUnchanged := false, contextual := true} hypsArgs args addSimpThms false loc syntax (name := simp_lists) "simp_lists" Parser.Tactic.optConfig ("[" (term<|>"*"),* "]")? (location)? : tactic diff --git a/backends/lean/Aeneas/SimpListsScalar.lean b/backends/lean/Aeneas/SimpListsScalar.lean index f0e19df15..feae76c4e 100644 --- a/backends/lean/Aeneas/SimpListsScalar.lean +++ b/backends/lean/Aeneas/SimpListsScalar.lean @@ -14,6 +14,20 @@ open Lean Lean.Meta Lean.Parser.Tactic Lean.Elab.Tactic def simpListsScalarTac (config : ScalarTac.CondSimpTacConfig) (args : ScalarTac.CondSimpPartialArgs) (loc : Utils.Location) : TacticM Unit := do let addSimpThms : TacticM (Array FVarId) := pure #[] + let hypsArgs : ScalarTac.CondSimpArgs := { + simpThms := #[ + ← SimpBoolProp.simpBoolPropHypsSimpExt.getTheorems, + ← SimpLists.simpListsHypsSimpExt.getTheorems, + ← SimpScalar.simpScalarHypsSimpExt.getTheorems], + simprocs := #[ + ← SimpBoolProp.simpBoolPropHypsSimprocExt.getSimprocs, + ← SimpLists.simpListsHypsSimprocExt.getSimprocs, + ← SimpScalar.simpScalarHypsSimprocExt.getSimprocs + ], + declsToUnfold := #[], + addSimpThms := #[], + hypsToUse := #[], + } let args : ScalarTac.CondSimpArgs := { simpThms := #[ ← SimpBoolProp.simpBoolPropSimpExt.getTheorems, @@ -28,7 +42,7 @@ def simpListsScalarTac (config : ScalarTac.CondSimpTacConfig) addSimpThms := args.addSimpThms, hypsToUse := args.hypsToUse, } - ScalarTac.condSimpTac "simp_lists_scalar" config {maxDischargeDepth := 2, failIfUnchanged := false, contextual := true} args addSimpThms false loc + ScalarTac.condSimpTac "simp_lists_scalar" config {maxDischargeDepth := 2, failIfUnchanged := false, contextual := true} hypsArgs args addSimpThms false loc syntax (name := simp_lists) "simp_lists_scalar" Parser.Tactic.optConfig ("[" (term<|>"*"),* "]")? (location)? : tactic diff --git a/backends/lean/Aeneas/SimpScalar/Init.lean b/backends/lean/Aeneas/SimpScalar/Init.lean index c4a0d5ed6..3f9fb800d 100644 --- a/backends/lean/Aeneas/SimpScalar/Init.lean +++ b/backends/lean/Aeneas/SimpScalar/Init.lean @@ -29,4 +29,18 @@ initialize simpScalarSimprocExt : Simp.SimprocExtension ← The `simp_scalar_simps_proc` attribute registers simp procedures to be used by `simp_scalar` during its preprocessing phase." none --(some simpScalarSimprocsRef) +/-- The `simp_scalar_hyps_simp` simp attribute. -/ +initialize simpScalarHypsSimpExt : SimpExtension ← + registerSimpAttr `simp_scalar_hyps_simps "\ + The `simp_scalar_hyps_simp` attribute registers simp lemmas to be used by `simp_scalar`." + +-- TODO: initialization fails with this, while the same works for `scalar_tac`?? +--initialize simpScalarSimprocsRef : IO.Ref Simprocs ← IO.mkRef {} + +/-- The `simp_scalar_hyps_simp_proc` simp attribute for the simp rocs. -/ +initialize simpScalarHypsSimprocExt : Simp.SimprocExtension ← + Simp.registerSimprocAttr `simp_scalar_hyps_simps_proc "\ + The `simp_scalar_hyps_simp_proc` attribute registers simp procedures to be used by `simp_scalar` + during its preprocessing phase." none --(some simpScalarSimprocsRef) + end Aeneas.SimpScalar diff --git a/backends/lean/Aeneas/SimpScalar/SimpScalar.lean b/backends/lean/Aeneas/SimpScalar/SimpScalar.lean index 76839cd8f..2ca05c3e8 100644 --- a/backends/lean/Aeneas/SimpScalar/SimpScalar.lean +++ b/backends/lean/Aeneas/SimpScalar/SimpScalar.lean @@ -175,6 +175,13 @@ attribute [simp_scalar_simps] BitVec.setWidth_eq BitVec.ofNat_eq_ofNat def simpScalarTac (config : ScalarTac.CondSimpTacConfig) (args : ScalarTac.CondSimpPartialArgs) (loc : Utils.Location) : TacticM Unit := do let addSimpThms : TacticM (Array FVarId) := pure #[] + let hypsArgs : ScalarTac.CondSimpArgs := { + simpThms := #[← simpScalarHypsSimpExt.getTheorems, ← SimpBoolProp.simpBoolPropHypsSimpExt.getTheorems], + simprocs := #[← simpScalarHypsSimprocExt.getSimprocs, ← SimpBoolProp.simpBoolPropHypsSimprocExt.getSimprocs], + declsToUnfold := #[], + addSimpThms := #[], + hypsToUse := #[], + } let args : ScalarTac.CondSimpArgs := { simpThms := #[← simpScalarSimpExt.getTheorems, ← SimpBoolProp.simpBoolPropSimpExt.getTheorems], simprocs := #[← simpScalarSimprocExt.getSimprocs, ← SimpBoolProp.simpBoolPropSimprocExt.getSimprocs], @@ -184,7 +191,7 @@ def simpScalarTac (config : ScalarTac.CondSimpTacConfig) } ScalarTac.condSimpTac "simp_scalar" config {maxDischargeDepth := 2, failIfUnchanged := false, contextual := true} - args addSimpThms false loc + hypsArgs args addSimpThms false loc syntax (name := simp_scalar) "simp_scalar" Parser.Tactic.optConfig ("[" (term<|>"*"),* "]")? (location)? : tactic diff --git a/backends/lean/Aeneas/Std/Array/Array.lean b/backends/lean/Aeneas/Std/Array/Array.lean index 732430d94..f534c0e77 100644 --- a/backends/lean/Aeneas/Std/Array/Array.lean +++ b/backends/lean/Aeneas/Std/Array/Array.lean @@ -59,10 +59,10 @@ example : Result (Array Int (Usize.ofNat 2)) := do @[reducible] instance {α : Type u} {n : Usize} : GetElem? (Array α n) Nat α (fun a i => i < a.val.length) where getElem? a i := getElem? a.val i -@[simp, scalar_tac_simps, simp_lists_simps] +@[simp, scalar_tac_simps, simp_lists_hyps_simps] theorem Array.getElem?_Nat_eq {α : Type u} {n : Usize} (v : Array α n) (i : Nat) : v[i]? = v.val[i]? := by rfl -@[simp, scalar_tac_simps, simp_lists_simps] +@[simp, scalar_tac_simps, simp_lists_hyps_simps] theorem Array.getElem!_Nat_eq {α : Type u} [Inhabited α] {n : Usize} (v : Array α n) (i : Nat) : v[i]! = v.val[i]! := by simp only [instGetElem?ArrayNatLtLengthValListEqVal, List.getElem!_eq_getElem?_getD]; split <;> simp_all rfl @@ -73,13 +73,13 @@ theorem Array.getElem!_Nat_eq {α : Type u} [Inhabited α] {n : Usize} (v : Arra @[reducible] instance {α : Type u} {n : Usize} : GetElem? (Array α n) Usize α (fun a i => i.val < a.val.length) where getElem? a i := getElem? a.val i.val -@[simp, scalar_tac_simps] theorem Array.getElem?_Usize_eq {α : Type u} {n : Usize} (v : Array α n) (i : Usize) : v[i]? = v.val[i.val]? := by rfl -@[simp, scalar_tac_simps] theorem Array.getElem!_Usize_eq {α : Type u} [Inhabited α] {n : Usize} (v : Array α n) (i : Usize) : v[i]! = v.val[i.val]! := by +@[simp, scalar_tac_simps, simp_lists_hyps_simps] theorem Array.getElem?_Usize_eq {α : Type u} {n : Usize} (v : Array α n) (i : Usize) : v[i]? = v.val[i.val]? := by rfl +@[simp, scalar_tac_simps, simp_lists_hyps_simps] theorem Array.getElem!_Usize_eq {α : Type u} [Inhabited α] {n : Usize} (v : Array α n) (i : Usize) : v[i]! = v.val[i.val]! := by simp [instGetElem?ArrayUsizeLtNatValLengthValListEq]; split <;> simp_all rfl -@[simp, scalar_tac_simps] abbrev Array.get? {α : Type u} {n : Usize} (v : Array α n) (i : Nat) : Option α := getElem? v i -@[simp, scalar_tac_simps] abbrev Array.get! {α : Type u} {n : Usize} [Inhabited α] (v : Array α n) (i : Nat) : α := getElem! v i +@[simp, scalar_tac_simps, simp_lists_hyps_simps] abbrev Array.get? {α : Type u} {n : Usize} (v : Array α n) (i : Nat) : Option α := getElem? v i +@[simp, scalar_tac_simps, simp_lists_hyps_simps] abbrev Array.get! {α : Type u} {n : Usize} [Inhabited α] (v : Array α n) (i : Nat) : α := getElem! v i @[simp] abbrev Array.slice {α : Type u} {n : Usize} [Inhabited α] (v : Array α n) (i j : Nat) : List α := @@ -119,12 +119,12 @@ def Array.set {α : Type u} {n : Usize} (v: Array α n) (i: Usize) (x: α) : Arr def Array.set_opt {α : Type u} {n : Usize} (v: Array α n) (i: Usize) (x: Option α) : Array α n := ⟨ v.val.set_opt i.val x, by have := v.property; simp [*] ⟩ -@[simp, simp_lists_simps] +@[simp, scalar_tac_simps, simp_lists_hyps_simps] theorem Array.set_val_eq {α : Type u} {n : Usize} (v: Array α n) (i: Usize) (x: α) : (v.set i x).val = v.val.set i.val x := by simp [set] -@[simp, simp_lists_simps] +@[simp, scalar_tac_simps, simp_lists_hyps_simps] theorem Array.set_opt_val_eq {α : Type u} {n : Usize} (v: Array α n) (i: Usize) (x: Option α) : (v.set_opt i x).val = v.val.set_opt i.val x := by simp [set_opt] diff --git a/backends/lean/Aeneas/Std/Slice.lean b/backends/lean/Aeneas/Std/Slice.lean index b653fc687..74d75c9c2 100644 --- a/backends/lean/Aeneas/Std/Slice.lean +++ b/backends/lean/Aeneas/Std/Slice.lean @@ -54,10 +54,10 @@ theorem Slice.len_val {α : Type u} (v : Slice α) : (Slice.len v).val = v.lengt @[reducible] instance {α : Type u} : GetElem? (Slice α) Nat α (fun a i => i < a.val.length) where getElem? a i := getElem? a.val i -@[simp, scalar_tac_simps, simp_lists_simps] +@[simp, scalar_tac_simps, simp_lists_hyps_simps] theorem Slice.getElem?_Nat_eq {α : Type u} (v : Slice α) (i : Nat) : v[i]? = v.val[i]? := by rfl -@[simp, scalar_tac_simps, simp_lists_simps] +@[simp, scalar_tac_simps, simp_lists_hyps_simps] theorem Slice.getElem!_Nat_eq {α : Type u} [Inhabited α] (v : Slice α) (i : Nat) : v[i]! = v.val[i]! := by simp only [instGetElem?SliceNatLtLengthValListLeMax, List.getElem!_eq_getElem?_getD]; split <;> simp_all rfl @@ -68,13 +68,13 @@ theorem Slice.getElem!_Nat_eq {α : Type u} [Inhabited α] (v : Slice α) (i : N @[reducible] instance {α : Type u} : GetElem? (Slice α) Usize α (fun a i => i < a.val.length) where getElem? a i := getElem? a.val i.val -@[simp, scalar_tac_simps] theorem Slice.getElem?_Usize_eq {α : Type u} (v : Slice α) (i : Usize) : v[i]? = v.val[i.val]? := by rfl -@[simp, scalar_tac_simps] theorem Slice.getElem!_Usize_eq {α : Type u} [Inhabited α] (v : Slice α) (i : Usize) : v[i]! = v.val[i.val]! := by +@[simp, scalar_tac_simps, simp_lists_hyps_simps] theorem Slice.getElem?_Usize_eq {α : Type u} (v : Slice α) (i : Usize) : v[i]? = v.val[i.val]? := by rfl +@[simp, scalar_tac_simps, simp_lists_hyps_simps] theorem Slice.getElem!_Usize_eq {α : Type u} [Inhabited α] (v : Slice α) (i : Usize) : v[i]! = v.val[i.val]! := by simp only [instGetElem?SliceUsizeLtNatValLengthValListLeMax, List.getElem!_eq_getElem?_getD]; split <;> simp_all rfl -@[simp, scalar_tac_simps] abbrev Slice.get? {α : Type u} (v : Slice α) (i : Nat) : Option α := getElem? v i -@[simp, scalar_tac_simps] abbrev Slice.get! {α : Type u} [Inhabited α] (v : Slice α) (i : Nat) : α := getElem! v i +@[simp, scalar_tac_simps, simp_lists_hyps_simps] abbrev Slice.get? {α : Type u} (v : Slice α) (i : Nat) : Option α := getElem? v i +@[simp, scalar_tac_simps, simp_lists_hyps_simps] abbrev Slice.get! {α : Type u} [Inhabited α] (v : Slice α) (i : Nat) : α := getElem! v i def Slice.set {α : Type u} (v: Slice α) (i: Usize) (x: α) : Slice α := ⟨ v.val.set i.val x, by have := v.property; simp [*] ⟩ @@ -115,12 +115,12 @@ theorem Slice.index_usize_spec {α : Type u} [Inhabited α] (v: Slice α) (i: Us simp only [length, getElem?_Usize_eq, exists_eq_right] at * simp only [List.getElem?_eq_getElem, List.getElem!_eq_getElem?_getD, Option.getD_some, hbound] -@[simp, scalar_tac_simps, simp_lists_simps] +@[simp, scalar_tac_simps, simp_lists_hyps_simps] theorem Slice.set_val_eq {α : Type u} (v: Slice α) (i: Usize) (x: α) : (v.set i x) = v.val.set i.val x := by simp [set] -@[simp, scalar_tac_simps, simp_lists_simps] +@[simp, scalar_tac_simps, simp_lists_hyps_simps] theorem Slice.set_opt_val_eq {α : Type u} (v: Slice α) (i: Usize) (x: Option α) : (v.set_opt i x) = v.val.set_opt i.val x := by simp [set_opt] diff --git a/backends/lean/Aeneas/Std/Vec.lean b/backends/lean/Aeneas/Std/Vec.lean index fd58daea1..02a2e2604 100644 --- a/backends/lean/Aeneas/Std/Vec.lean +++ b/backends/lean/Aeneas/Std/Vec.lean @@ -60,10 +60,10 @@ theorem Vec.len_val {α : Type u} (v : Vec α) : (Vec.len v).val = v.length := getElem? a i := getElem? a.val i getElem! a i := getElem! a.val i -@[simp, scalar_tac_simps, simp_lists_simps] +@[simp, scalar_tac_simps, simp_lists_hyps_simps] theorem Vec.getElem?_Nat_eq {α : Type u} (v : Vec α) (i : Nat) : v[i]? = v.val[i]? := by rfl -@[simp, scalar_tac_simps, simp_lists_simps] +@[simp, scalar_tac_simps, simp_lists_hyps_simps] theorem Vec.getElem!_Nat_eq {α : Type u} [Inhabited α] (v : Vec α) (i : Nat) : v[i]! = v.val[i]! := by rfl @[reducible] instance {α : Type u} : GetElem (Vec α) Usize α (fun a i => i < a.val.length) where @@ -73,11 +73,11 @@ theorem Vec.getElem!_Nat_eq {α : Type u} [Inhabited α] (v : Vec α) (i : Nat) getElem? a i := getElem? a.val i.val getElem! a i := getElem! a.val i.val -@[simp, scalar_tac_simps] theorem Vec.getElem?_Usize_eq {α : Type u} (v : Vec α) (i : Usize) : v[i]? = v.val[i.val]? := by rfl -@[simp, scalar_tac_simps] theorem Vec.getElem!_Usize_eq {α : Type u} [Inhabited α] (v : Vec α) (i : Usize) : v[i]! = v.val[i.val]! := by rfl +@[simp, scalar_tac_simps, simp_lists_hyps_simps] theorem Vec.getElem?_Usize_eq {α : Type u} (v : Vec α) (i : Usize) : v[i]? = v.val[i.val]? := by rfl +@[simp, scalar_tac_simps, simp_lists_hyps_simps] theorem Vec.getElem!_Usize_eq {α : Type u} [Inhabited α] (v : Vec α) (i : Usize) : v[i]! = v.val[i.val]! := by rfl -@[simp, scalar_tac_simps] abbrev Vec.get? {α : Type u} (v : Vec α) (i : Nat) : Option α := getElem? v i -@[simp, scalar_tac_simps] abbrev Vec.get! {α : Type u} [Inhabited α] (v : Vec α) (i : Nat) : α := getElem! v i +@[simp, scalar_tac_simps, simp_lists_hyps_simps] abbrev Vec.get? {α : Type u} (v : Vec α) (i : Nat) : Option α := getElem? v i +@[simp, scalar_tac_simps, simp_lists_hyps_simps] abbrev Vec.get! {α : Type u} [Inhabited α] (v : Vec α) (i : Nat) : α := getElem! v i def Vec.set {α : Type u} (v: Vec α) (i: Usize) (x: α) : Vec α := ⟨ v.val.set i.val x, by have := v.property; simp [*] ⟩ @@ -85,12 +85,12 @@ def Vec.set {α : Type u} (v: Vec α) (i: Usize) (x: α) : Vec α := def Vec.set_opt {α : Type u} (v: Vec α) (i: Usize) (x: Option α) : Vec α := ⟨ v.val.set_opt i.val x, by have := v.property; simp [*] ⟩ -@[simp] +@[simp, scalar_tac_simps, simp_lists_hyps_simps] theorem Vec.set_val_eq {α : Type u} (v: Vec α) (i: Usize) (x: α) : (v.set i x) = v.val.set i.val x := by simp [set] -@[simp] +@[simp, scalar_tac_simps, simp_lists_hyps_simps] theorem Vec.set_opt_val_eq {α : Type u} (v: Vec α) (i: Usize) (x: Option α) : (v.set_opt i x) = v.val.set_opt i.val x := by simp [set_opt] diff --git a/backends/lean/Aeneas/ZModify/Init.lean b/backends/lean/Aeneas/ZModify/Init.lean index 184b73dc4..55c560073 100644 --- a/backends/lean/Aeneas/ZModify/Init.lean +++ b/backends/lean/Aeneas/ZModify/Init.lean @@ -3,6 +3,10 @@ open Lean Meta namespace Aeneas.ZModify +/-! +# ZMod-ify +-/ + /-- The `zmodify_simps` simp attribute. -/ initialize zmodifySimpExt : SimpExtension ← registerSimpAttr `zmodify_simps "\ @@ -16,4 +20,17 @@ initialize zmodifySimprocExt : Simp.SimprocExtension ← The `zmodify_simps_proc` attribute registers simp procedures to be used by `zmodify` during its preprocessing phase." (some zmodifySimprocsRef) +/-- The `zmodify_hyps_simp` simp attribute. -/ +initialize zmodifyHypsSimpExt : SimpExtension ← + registerSimpAttr `zmodify_hyps_simps "\ + The `zmodify_hyps_simp` attribute registers simp lemmas to be used by `zmodify`." + +initialize zmodifyHypsSimprocsRef : IO.Ref Simprocs ← IO.mkRef {} + +/-- The `zmodify_hyps_simp_proc` simp attribute for the simp rocs. -/ +initialize zmodifyHypsSimprocExt : Simp.SimprocExtension ← + Simp.registerSimprocAttr `zmodify_hyps_simps_proc "\ + The `zmodify_hyps_simp_proc` attribute registers simp procedures to be used by `zmodify` + during its preprocessing phase." (some zmodifyHypsSimprocsRef) + end Aeneas.ZModify diff --git a/backends/lean/Aeneas/ZModify/ZModify.lean b/backends/lean/Aeneas/ZModify/ZModify.lean index f96dc43d9..42d00a380 100644 --- a/backends/lean/Aeneas/ZModify/ZModify.lean +++ b/backends/lean/Aeneas/ZModify/ZModify.lean @@ -9,7 +9,7 @@ import Aeneas.ScalarTac.CondSimpTac import Aeneas.SimpBoolProp.SimpBoolProp /-! -# `zmodify` tactic +# ZMod-ify tactic The `zmodify` tactic is used to shift propositions about, e.g., `Nat`, to `ZMod`. This tactic is adapted from `zify`. @@ -52,6 +52,13 @@ def zmodifyTac (config : Config) let thm ← mkAppM thName #[n] Utils.addDeclTac (← Utils.mkFreshAnonPropUserName) thm (← inferType thm) (asLet := false) fun thm => pure thm.fvarId! pure #[← addThm ``Nat.lt_imp_eq_iff_eq_ZMod] + let hypsArgs : ScalarTac.CondSimpArgs := { + simpThms := #[← zmodifyHypsSimpExt.getTheorems, ← SimpBoolProp.simpBoolPropHypsSimpExt.getTheorems], + simprocs := #[← zmodifyHypsSimprocExt.getSimprocs, ← SimpBoolProp.simpBoolPropHypsSimprocExt.getSimprocs], + declsToUnfold := #[], + addSimpThms := #[], + hypsToUse := #[], + } let args : ScalarTac.CondSimpArgs := { -- Note that we also add the push_cast theorems simpThms := #[← zmodifySimpExt.getTheorems, ← SimpBoolProp.simpBoolPropSimpExt.getTheorems, ← Lean.Meta.NormCast.pushCastExt.getTheorems], @@ -61,7 +68,7 @@ def zmodifyTac (config : Config) hypsToUse := args.hypsToUse, } let config := { nonLin := config.nonLin, saturationPasses := config.saturationPasses } - ScalarTac.condSimpTac "zmodify" config {maxDischargeDepth := 2, failIfUnchanged := false, contextual := true} args addSimpThms false loc + ScalarTac.condSimpTac "zmodify" config {maxDischargeDepth := 2, failIfUnchanged := false, contextual := true} hypsArgs args addSimpThms false loc syntax (name := zmodify) "zmodify" Parser.Tactic.optConfig ("to" term)? ("[" (term<|>"*"),* "]")? (location)? : tactic From 664d533128adc9d9311ce6727ddb2f5ca4a58aac Mon Sep 17 00:00:00 2001 From: Son Ho Date: Tue, 24 Jun 2025 11:58:32 +0100 Subject: [PATCH 13/31] Fix a minor issue in the simpAll step of scalarTacPartialPreprocess --- backends/lean/Aeneas/ScalarTac/ScalarTac.lean | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/backends/lean/Aeneas/ScalarTac/ScalarTac.lean b/backends/lean/Aeneas/ScalarTac/ScalarTac.lean index 8c46e026b..afad32c73 100644 --- a/backends/lean/Aeneas/ScalarTac/ScalarTac.lean +++ b/backends/lean/Aeneas/ScalarTac/ScalarTac.lean @@ -328,11 +328,16 @@ def scalarTacPartialPreprocess (config : Config) (simpArgs : Simp.SimpArgs) (sta Simp.evalSimpAllAssumptions {failIfUnchanged := false, maxSteps := config.simpAllMaxSteps, maxDischargeDepth := 0} true simpArgs assumptions simpTarget - if r.isSome then trace[ScalarTac] "applySimpAll succeeded" - else trace[ScalarTac] "applySimpAll failed" match r with - | some (some (fvars, _)) => pure (some fvars) - | some none | none => pure (some assumptions) + | some r => + trace[ScalarTac] "applySimpAll succeeded" + match r with + | some (fvars, _) => pure (some fvars) + | none => pure none -- Goal proven through `simpAll` + | none => + trace[ScalarTac] "applySimpAll failed" + -- `simpAll` failed: let's just continue with the same state as before + pure (some assumptions) let r ← do if config.simpAllMaxSteps ≠ 0 ∧ assumptionsToPreprocess.size + nassumptions.size > 0 then /- Simplify the new assumptions with the old assumptions (it's often enough to go in that direction) -/ From 3374e89670881f663e6c12a86a7cc20d5178704a Mon Sep 17 00:00:00 2001 From: Son Ho Date: Tue, 24 Jun 2025 11:58:52 +0100 Subject: [PATCH 14/31] Make minor modifications --- backends/lean/Aeneas/Saturate/Tactic.lean | 5 +++-- backends/lean/Aeneas/ScalarTac/CondSimpTac.lean | 9 +++++---- tests/lean/Hashmap/Properties.lean | 3 +-- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/backends/lean/Aeneas/Saturate/Tactic.lean b/backends/lean/Aeneas/Saturate/Tactic.lean index 994621aea..4bdb6a334 100644 --- a/backends/lean/Aeneas/Saturate/Tactic.lean +++ b/backends/lean/Aeneas/Saturate/Tactic.lean @@ -382,6 +382,8 @@ private partial def visit (boundVars : Std.HashSet FVarId) (state : State) (e : Expr) : MetaM State := do + let e := e.consumeMData + -- trace[Saturate.explore] "Visiting {e}" -- Register the current assumption, if it is a conjunct inside an assumption let state := state.insertAssumption path e @@ -405,7 +407,6 @@ private partial def visit | some v => visit config none (depth + 1) exploreSubterms preprocessThm boundVars state v pure (boundVars, state) ) (boundVars, state) - let e := e.consumeMData match e with | .bvar _ | .fvar _ @@ -470,9 +471,9 @@ private partial def visitRecomputeAssumptions (state : State) (e : Expr) : MetaM State := do trace[Saturate.explore] "Visiting {e}" + let e := e.consumeMData -- Register the current assumption, if it is a conjunct inside an assumption let state := state.insertAssumption path e - let e := e.consumeMData match e with | .bvar _ | .fvar _ diff --git a/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean b/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean index aab149182..8230cc85a 100644 --- a/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean +++ b/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean @@ -86,13 +86,14 @@ def CondSimpArgs.toSimpArgs (args : CondSimpArgs) : Simp.SimpArgs := { def condSimpTacSimp (config : Simp.Config) (args : CondSimpArgs) (loc : Utils.Location) (toClear : Array FVarId := #[]) - (additionalAsms : Array FVarId := #[]) (state : Option (ScalarTac.State × Array FVarId)) : + (additionalHypsToUse : Array FVarId := #[]) (state : Option (ScalarTac.State × Array FVarId)) : TacticM (Option (Array FVarId)) := do withMainContext do let simpArgs := args.toSimpArgs - let simpArgs := { simpArgs with hypsToUse := simpArgs.hypsToUse ++ additionalAsms } + let simpArgs := { simpArgs with hypsToUse := simpArgs.hypsToUse ++ additionalHypsToUse } match state with | some (state, asms) => + trace[CondSimpTac] "condSimpTacSimp: scalarTac assumptions: {asms.map Expr.fvar}" /- Note that when calling `scalar_tac` we saturate only by looking at the target: we have already saturated by looking at the assumptions (we do this once and for all beforehand) -/ let dischargeWrapper ← Simp.tacticToDischarge (incrScalarTac {saturateAssumptions := false} state toClear asms) @@ -168,7 +169,7 @@ def condSimpTac | .targets hyps type => pure (Utils.Location.targets hyps type) let nloc ← if doFirstSimp then - match ← condSimpTacSimp simpConfig args loc toClear additionalSimpThms none with + match ← condSimpTacSimp simpConfig args loc (toClear := toClear) (additionalHypsToUse := additionalSimpThms) none with | none => return | some freshFvarIds => match loc with @@ -181,7 +182,7 @@ def condSimpTac TODO: scalar_tac should only be allowed to preprocess `scalarTacAsms`. TODO: we should preprocess those. -/ - let _ ← condSimpTacSimp simpConfig args nloc toClear additionalSimpThms (some (state, newAsms)) + let _ ← condSimpTacSimp simpConfig args nloc (toClear := toClear) (additionalHypsToUse := additionalSimpThms) (some (state, newAsms)) if (← getUnsolvedGoals) == [] then return /- Clear the additional assumptions -/ setGoals [← (← getMainGoal).tryClearMany hypsToUse] diff --git a/tests/lean/Hashmap/Properties.lean b/tests/lean/Hashmap/Properties.lean index 298837a12..2b7197f59 100644 --- a/tests/lean/Hashmap/Properties.lean +++ b/tests/lean/Hashmap/Properties.lean @@ -239,9 +239,8 @@ theorem new_with_capacity_spec . simp_all [HashMap.v, length] . fsimp [lookup] intro k - simp at Hnil -- TODO: this is annoying simp_lists [Hnil] - simp -- TODO: remove + simp @[progress] theorem new_spec (α : Type) : From b5d241385db54dfe9b1e94fc56033672298291b0 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Tue, 24 Jun 2025 12:54:35 +0100 Subject: [PATCH 15/31] Add some scalar_tac lemmas --- backends/lean/Aeneas/ScalarTac/Lemmas.lean | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/backends/lean/Aeneas/ScalarTac/Lemmas.lean b/backends/lean/Aeneas/ScalarTac/Lemmas.lean index b64532a49..527a2636d 100644 --- a/backends/lean/Aeneas/ScalarTac/Lemmas.lean +++ b/backends/lean/Aeneas/ScalarTac/Lemmas.lean @@ -1,6 +1,7 @@ import Mathlib.Data.ZMod.Basic import Aeneas.ScalarTac.ScalarTac import Aeneas.Std.Scalar.Core +import Aeneas.ReduceZMod.ReduceZMod namespace Aeneas @@ -245,6 +246,14 @@ attribute [scalar_tac a.toNat] Int.toNat_eq_max @[scalar_tac_simps] theorem Nat_neq_zero_iff (x : ℕ) : x ≠ 0 ↔ 0 < x := by omega +attribute [scalar_tac_simps] Nat.not_eq Int.not_eq + +/-! +# Casts +-/ + +attribute [scalar_tac_simps, simp_scalar_simps] Nat.cast_add Nat.cast_mul Nat.cast_ofNat + /-! # Min, Max -/ @@ -313,6 +322,8 @@ attribute [simp_scalar_simps] ZMod.val_natCast ZMod.val_intCast theorem ZMod.cast_intCast {n : ℕ} (a : ℤ) [NeZero n] : ((a : ZMod n).cast : ℤ) = a % ↑n := by simp only [ZMod.cast_eq_val, ZMod.val_intCast] +attribute [scalar_tac_simps, simp_scalar_simps] ReduceZMod.reduceZMod + /-! # Sets -/ From 01eab31374a7088027cdae00a36d8f3e6242f05b Mon Sep 17 00:00:00 2001 From: Son Ho Date: Tue, 24 Jun 2025 12:54:44 +0100 Subject: [PATCH 16/31] Add a simp_lists test --- backends/lean/Aeneas/SimpLists/Tests.lean | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/backends/lean/Aeneas/SimpLists/Tests.lean b/backends/lean/Aeneas/SimpLists/Tests.lean index 7d2c1d62f..387d4eea8 100644 --- a/backends/lean/Aeneas/SimpLists/Tests.lean +++ b/backends/lean/Aeneas/SimpLists/Tests.lean @@ -1,5 +1,6 @@ import Aeneas.SimpLists.SimpLists import Aeneas.List.List +import Aeneas.Std.Slice example [Inhabited α] (l : List α) (x : α) (i j : Nat) (hj : i ≠ j) : (l.set j x)[i]! = l[i]! := by simp_lists @@ -24,3 +25,19 @@ example (CList : Type) (l : CList) (get : CList → Nat → Bool) (set : CList example (CList : Type) (l : CList) (get : CList → Nat → Bool) (set : CList → Nat → Bool → CList) (h : ∀ i j l x, i ≠ j → get (set l i x) j = get l j) (i j : Nat) (hi : i < j) : get (set l i x) j = get l j := by simp_lists [*] + +example + (T : Type) + [Inhabited T] + (i : ℕ) + (tl : List T) + (h : i < tl.length + 1) + (hi : ¬i = 0) + (i1 : ℕ) + (_ : i1 = i - 1) + (_ : 1 ≤ i) + (x : T) + (_ : x = tl[i1]!) : + x = tl[i - 1]! + := by + simp_lists [*] From 4e29a89dbe7f0da8380cf8605a8ae990d150f7a8 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Tue, 24 Jun 2025 13:58:26 +0100 Subject: [PATCH 17/31] Simplify Int.cast_subNatNat in scalar_tac, simp_scalar, etc. --- backends/lean/Aeneas/ScalarTac/Lemmas.lean | 1 + backends/lean/Aeneas/ScalarTac/ScalarTac.lean | 2 +- backends/lean/Aeneas/SimpLists/Tests.lean | 11 +++++++++++ 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/backends/lean/Aeneas/ScalarTac/Lemmas.lean b/backends/lean/Aeneas/ScalarTac/Lemmas.lean index 527a2636d..c708dfdef 100644 --- a/backends/lean/Aeneas/ScalarTac/Lemmas.lean +++ b/backends/lean/Aeneas/ScalarTac/Lemmas.lean @@ -253,6 +253,7 @@ attribute [scalar_tac_simps] Nat.not_eq Int.not_eq -/ attribute [scalar_tac_simps, simp_scalar_simps] Nat.cast_add Nat.cast_mul Nat.cast_ofNat +attribute [scalar_tac_simps, simp_lists_hyps_simps, simp_scalar_hyps_simps] Int.cast_subNatNat /-! # Min, Max diff --git a/backends/lean/Aeneas/ScalarTac/ScalarTac.lean b/backends/lean/Aeneas/ScalarTac/ScalarTac.lean index afad32c73..f9b78894a 100644 --- a/backends/lean/Aeneas/ScalarTac/ScalarTac.lean +++ b/backends/lean/Aeneas/ScalarTac/ScalarTac.lean @@ -265,7 +265,7 @@ def scalarTacPreprocess (config : Config) : Tactic.TacticM Unit := do trace[ScalarTac] "Goal proven by preprocessing!" return trace[ScalarTac] "Goal after normCast: {← getMainGoal}" - -- Call `simp` again because `normCast` sometimes does weird things + -- Call `simp` again because `normCast` sometimes introduces strange terms let _ ← Simp.simpAt true {failIfUnchanged := false, maxDischargeDepth := 1} simpArgs .wildcard -- We might have proven the goal if (← getGoals).isEmpty then diff --git a/backends/lean/Aeneas/SimpLists/Tests.lean b/backends/lean/Aeneas/SimpLists/Tests.lean index 387d4eea8..e40c0539d 100644 --- a/backends/lean/Aeneas/SimpLists/Tests.lean +++ b/backends/lean/Aeneas/SimpLists/Tests.lean @@ -41,3 +41,14 @@ example x = tl[i - 1]! := by simp_lists [*] + +abbrev Zq := ZMod 3329 + +example + (x y : ℕ) + (l : List ℕ) + (h : ∀ (j : ℕ), (↑l[j]! : Zq) = (↑x : Zq) - (↑y : Zq)) + (j : ℕ) : + (↑l[j]! : Zq) = (↑x : Zq) - (↑y : Zq) + := by + simp_lists [h] From aa3e274792407b160c12dbe9b7b0479ebbb81409 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Tue, 24 Jun 2025 14:00:54 +0100 Subject: [PATCH 18/31] Cleanup a bit the simp_lists tests --- backends/lean/Aeneas/SimpLists/Tests.lean | 23 +++++++---------------- 1 file changed, 7 insertions(+), 16 deletions(-) diff --git a/backends/lean/Aeneas/SimpLists/Tests.lean b/backends/lean/Aeneas/SimpLists/Tests.lean index e40c0539d..134e55da5 100644 --- a/backends/lean/Aeneas/SimpLists/Tests.lean +++ b/backends/lean/Aeneas/SimpLists/Tests.lean @@ -2,27 +2,20 @@ import Aeneas.SimpLists.SimpLists import Aeneas.List.List import Aeneas.Std.Slice -example [Inhabited α] (l : List α) (x : α) (i j : Nat) (hj : i ≠ j) : (l.set j x)[i]! = l[i]! := by +example {α} [Inhabited α] (l : List α) (x : α) (i j : Nat) (hj : i ≠ j) : (l.set j x)[i]! = l[i]! := by simp_lists -example [Inhabited α] (l : List α) (x : α) (i : Nat) (hi : i < l.length) : (l.set i x)[i]! = x := by +example {α} [Inhabited α] (l : List α) (x : α) (i : Nat) (hi : i < l.length) : (l.set i x)[i]! = x := by simp_lists -/-- We use this lemma to "normalize" successive calls to `List.set` -/ -@[simp_lists_simps] -theorem List.set_comm_lt (a b : α) (n m : Nat) (l : List α) (h : n < m) : - (l.set m a).set n b = (l.set n b).set m a := by - rw [List.set_comm] - omega - -example [Inhabited α] (l : List α) (x y : α) (i j : Nat) (hj : i < j) : (l.set i x).set j y = (l.set j y).set i x := by +example {α} [Inhabited α] (l : List α) (x y : α) (i j : Nat) (hj : i < j) : (l.set i x).set j y = (l.set j y).set i x := by simp_lists -example (CList : Type) (l : CList) (get : CList → Nat → Bool) (set : CList → Nat → Bool → CList) +example (CList : Type) (l : CList) (x : Bool) (get : CList → Nat → Bool) (set : CList → Nat → Bool → CList) (h : ∀ i j l x, i ≠ j → get (set l i x) j = get l j) (i j : Nat) (hi : i < j) : get (set l i x) j = get l j := by simp_lists [h] -example (CList : Type) (l : CList) (get : CList → Nat → Bool) (set : CList → Nat → Bool → CList) +example (CList : Type) (l : CList) (x : Bool) (get : CList → Nat → Bool) (set : CList → Nat → Bool → CList) (h : ∀ i j l x, i ≠ j → get (set l i x) j = get l j) (i j : Nat) (hi : i < j) : get (set l i x) j = get l j := by simp_lists [*] @@ -42,13 +35,11 @@ example := by simp_lists [*] -abbrev Zq := ZMod 3329 - example (x y : ℕ) (l : List ℕ) - (h : ∀ (j : ℕ), (↑l[j]! : Zq) = (↑x : Zq) - (↑y : Zq)) + (h : ∀ (j : ℕ), (↑l[j]! : ZMod 3329) = (↑x : ZMod 3329) - (↑y : ZMod 3329)) (j : ℕ) : - (↑l[j]! : Zq) = (↑x : Zq) - (↑y : Zq) + (↑l[j]! : ZMod 3329) = (↑x : ZMod 3329) - (↑y : ZMod 3329) := by simp_lists [h] From 952fa50a7fbad730eaa96fa644ce809fe4ced0b6 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Tue, 24 Jun 2025 21:27:47 +0100 Subject: [PATCH 19/31] Improve tracting in progress and progress* and improve progress* --- backends/lean/Aeneas/Progress/Init.lean | 46 +-- backends/lean/Aeneas/Progress/Progress.lean | 143 +++++---- .../lean/Aeneas/Progress/ProgressStar.lean | 289 +++++++++++++----- backends/lean/Aeneas/Progress/Trace.lean | 1 - backends/lean/Aeneas/Utils.lean | 3 - 5 files changed, 319 insertions(+), 163 deletions(-) diff --git a/backends/lean/Aeneas/Progress/Init.lean b/backends/lean/Aeneas/Progress/Init.lean index 48a782ded..0b8b0034f 100644 --- a/backends/lean/Aeneas/Progress/Init.lean +++ b/backends/lean/Aeneas/Progress/Init.lean @@ -65,33 +65,33 @@ structure ProgressSpecDesc where postcond? : Option Expr section Methods - variable [MonadLiftT MetaM m] [MonadControlT MetaM m] [Monad m] [MonadOptions m] + variable {m} [MonadLiftT MetaM m] [MonadControlT MetaM m] [Monad m] [MonadOptions m] variable [MonadTrace m] [MonadLiftT IO m] [MonadRef m] [AddMessageContext m] variable [MonadError m] variable {a : Type} -/-- Given ty := ∀ xs.., ∃ zs.., program = res ∧ post?, destruct and run continuation -/ -def programTelescope[Inhabited (m α)] [Nonempty (m α)] (ty: Expr) - (k: (xs:Array (MVarId × BinderInfo)) → (zs:Array FVarId) → (program:Expr) → (res:Expr) → (post:Option Expr) → m α) -: m α := do - let ty := ty.consumeMData - unless ←isProp ty do - throwError "Expected a proposition, got {←inferType ty}" - -- ty == ∀ xs, ty₂ - let (xs, xs_bi, ty₂) ← forallMetaTelescope ty - trace[Progress] "Universally quantified arguments and assumptions: {xs}" - -- ty₂ == ∃ zs, ty₃ ≃ Exists {α} (fun zs => ty₃) - existsTelescope ty₂.consumeMData fun zs ty₃ => do - trace[Progress] "Existentials: {zs}" - trace[Progress] "Proposition after stripping the quantifiers: {ty₃}" - -- ty₃ == ty₄ ∧ post? - let (ty₄, post?) ← Utils.optSplitConj ty₃.consumeMData - trace[Progress] "After splitting the conjunction:\n- eq: {ty₄}\n- post: {post?}" - -- ty₄ == (program = res) - let (program, res) ← Utils.destEq ty₄.consumeMData - trace[Progress] "After splitting the equality:\n- lhs: {program}\n- rhs: {res}" - k (xs.map (·.mvarId!) |>.zip xs_bi) (zs.map (·.fvarId!)) program res post? + /-- Given ty := ∀ xs.., ∃ zs.., program = res ∧ post?, destruct and run continuation -/ + def monadTelescope {α} [Inhabited (m α)] [Nonempty (m α)] (ty: Expr) + (k: (xs:Array (MVarId × BinderInfo)) → (zs:Array FVarId) → (program:Expr) → (res:Expr) → (post:Option Expr) → m α) + : m α := do + let ty := ty.consumeMData + unless ←isProp ty do + throwError "Expected a proposition, got {←inferType ty}" + -- ty == ∀ xs, ty₂ + let (xs, xs_bi, ty₂) ← forallMetaTelescope ty + trace[Progress] "Universally quantified arguments and assumptions: {xs}" + -- ty₂ == ∃ zs, ty₃ ≃ Exists {α} (fun zs => ty₃) + existsTelescope ty₂.consumeMData fun zs ty₃ => do + trace[Progress] "Existentials: {zs}" + trace[Progress] "Proposition after stripping the quantifiers: {ty₃}" + -- ty₃ == ty₄ ∧ post? + let (ty₄, post?) ← Utils.optSplitConj ty₃.consumeMData + trace[Progress] "After splitting the conjunction:\n- eq: {ty₄}\n- post: {post?}" + -- ty₄ == (program = res) + let (program, res) ← Utils.destEq ty₄.consumeMData + trace[Progress] "After splitting the equality:\n- lhs: {program}\n- rhs: {res}" + k (xs.map (·.mvarId!) |>.zip xs_bi) (zs.map (·.fvarId!)) program res post? /- Analyze a goal or a progress theorem to decompose its arguments. @@ -113,7 +113,7 @@ def programTelescope[Inhabited (m α)] [Nonempty (m α)] (ty: Expr) def withProgressSpec [Inhabited (m a)] [Nonempty (m a)] (isGoal : Bool) (th : Expr) (k : ProgressSpecDesc → m a) : m a := do - programTelescope th fun xs evars mExpr ret post => do + monadTelescope th fun xs evars mExpr ret post => do -- Recursively destruct the monadic application to dive into the binds, -- if necessary (this is for when we use `withProgressSpec` inside of the `progress` tactic), -- and destruct the application to get the function name diff --git a/backends/lean/Aeneas/Progress/Progress.lean b/backends/lean/Aeneas/Progress/Progress.lean index f8c2de626..9711756bc 100644 --- a/backends/lean/Aeneas/Progress/Progress.lean +++ b/backends/lean/Aeneas/Progress/Progress.lean @@ -23,6 +23,8 @@ example (x y z : Std.U32) (_ : [> let z ← (x + y) <]) : True := by simp theorem eq_imp_prettyMonadEq {α : Type u} {x : Std.Result α} {y : α} (h : x = .ok y) : prettyMonadEq x y := by simp [prettyMonadEq, h] +def traceGoalWithNode (msg : String) : TacticM Unit := do + withTraceNode `Progress (fun _ => do pure msg) do trace[Progress] "{← getMainGoal}" -- TODO: the scalar types annoyingly often get reduced when we use the progress -- tactic. We should find a way of controling reduction. For now we use rewriting @@ -124,10 +126,17 @@ deriving Inhabited attribute [progress_post_simps] Std.IScalar.toNat Std.UScalar.ofNat_val_eq Std.IScalar.ofInt_val_eq +structure Args where + keep : Option Name + keepPretty : Option Name + ids : Array (Option Name) + splitPost : Bool + assumTac : TacticM Unit + solvePreconditionTac : TacticM Unit + open Result in -def progressWith (fExpr : Expr) (th : Expr) - (keep keepPretty : Option Name) (ids : Array (Option Name)) (splitPost : Bool) - (assumTac asmTac : TacticM Unit) : TacticM (Result ProgressWithOutput MessageData) := do +def progressWith (args : Args) (fExpr : Expr) (th : Expr) : TacticM (Result ProgressWithOutput MessageData) := do + withTraceNode `Progress (fun _ => pure m!"progressWith") do /- Apply the theorem We try to match the theorem with the goal In order to do so, we introduce meta-variables for all the parameters @@ -145,13 +154,15 @@ def progressWith (fExpr : Expr) (th : Expr) -- Normalize to inline the let-bindings let thTy ← normalizeLetBindings thTy trace[Progress] "After normalizing the let-bindings: {thTy}" - programTelescope thTy fun xs _zs thBody _ _ => do + monadTelescope thTy fun xs _zs thBody _ _ => do let (mvars, binders) := xs.unzip let mvars := mvars.map .mvar -- Match the body with the target trace[Progress] "Matching:\n- body:\n{thBody}\n- target:\n{fExpr}" let ok ← isDefEq thBody fExpr - if ¬ ok then throwError "Could not unify the theorem with the target:\n- theorem: {thBody}\n- target: {fExpr}" + if ¬ ok then + trace[Progress] "Could not unify the theorem with the target" + throwError "Could not unify the theorem with the target:\n- theorem: {thBody}\n- target: {fExpr}" let mgoal ← Tactic.getMainGoal postprocessAppMVars `progress mgoal mvars binders true true Term.synthesizeSyntheticMVarsNoPostponing @@ -160,7 +171,7 @@ def progressWith (fExpr : Expr) (th : Expr) -- Add the instantiated theorem to the assumptions (we apply it on the metavariables). let th := mkAppN th mvars trace[Progress] "Instantiated theorem reusing the metavariables: {th}" - let asmName ← do match keep with | none => mkFreshAnonPropUserName | some n => do pure n + let asmName ← do match args.keep with | none => mkFreshAnonPropUserName | some n => do pure n let thTy ← inferType th trace[Progress] "thTy (after application): {thTy}" /- Normalize the let-bindings (note that we already inlined the let bindings once above when analizing @@ -171,7 +182,7 @@ def progressWith (fExpr : Expr) (th : Expr) trace[Progress] "thTy (after normalizing let-bindings): {thTy}" Utils.addDeclTac asmName th thTy (asLet := false) fun thAsm => do let ngoal ← getMainGoal - trace[Progress] "current goal: {ngoal}" + traceGoalWithNode "current goal" trace[Progress] "current goal is assigned: {← ngoal.isAssigned}" /- The assumption should be of the shape: `∃ x1 ... xn, f args = ... ∧ ...` @@ -180,12 +191,12 @@ def progressWith (fExpr : Expr) (th : Expr) introduced variables. -/ let splitExistsEqAndPost (next : Result (Array FVarId) MessageData → TacticM (Result ProgressWithOutput MessageData)) : TacticM (Result ProgressWithOutput MessageData) := do - splitAllExistsTac thAsm ids.toList fun h ids => do + splitAllExistsTac thAsm args.ids.toList fun h ids => do /- Introduce the pretty equality if the user requests it. We take care of introducing it *before* splitting the post-conditions, so that those appear after it. -/ - match keepPretty with + match args.keepPretty with | none => pure () | some name => trace[Progress] "About to introduce the pretty equality" @@ -209,7 +220,7 @@ def progressWith (fExpr : Expr) (th : Expr) trace[Progress] "Introducing the \"pretty\" let binding" let e ← mkAppM ``eq_imp_prettyMonadEq #[h] Utils.addDeclTac name e (← inferType e) (asLet := false) fun _ => do - trace[Progress] "Introduced the \"pretty\" let binding: {← getMainGoal}" + traceGoalWithNode "Introduced the \"pretty\" let binding" /- Split the conjunctions. For the conjunctions, we split according once to separate the equality `f ... = .ret ...` @@ -234,7 +245,7 @@ def progressWith (fExpr : Expr) (th : Expr) -- We shouldn't simplify the goal with the equality, then simplify again. splitEqAndPost fun hEq hPost ids => do trace[Progress] "eq and post:\n{hEq} : {← inferType hEq}\n{hPost}" - trace[Progress] "current goal: {← getMainGoal}" + traceGoalWithNode "current goal" let r ← Simp.simpAt true { maxDischargeDepth := 1, failIfUnchanged := false} {simpThms := #[← progressSimpExt.getTheorems], hypsToUse := #[hEq.fvarId!]} (.targets #[] true) /- It may happen that at this point the goal is already solved (though this is rare) @@ -243,34 +254,37 @@ def progressWith (fExpr : Expr) (th : Expr) trace[Progress] "The main goal was solved!" next (Ok #[]) else - trace[Progress] "goal after applying the eq and simplifying the binds: {← getMainGoal}" + traceGoalWithNode "goal after applying the eq and simplifying the binds" -- TODO: remove this? (some types get unfolded too much: we "fold" them back) tryTac (do let _ ← Simp.simpAt true {} {addSimpThms := scalar_eqs} .wildcard_dep) - trace[Progress] "goal after folding back scalar types: {← getMainGoal}" + traceGoalWithNode "goal after folding back scalar types" -- Clear the equality, unless the user requests not to do so - if keep.isSome then pure () + if args.keep.isSome then pure () else do let mgoal ← getMainGoal - let mgoal ← mgoal.tryClearMany #[hEq.fvarId!] + let mgoal ← mgoal.tryClear hEq.fvarId! setGoals [mgoal] - trace[Progress] "Unsolved goals: {← getUnsolvedGoals}" - trace[Progress] "Goal after clearing the equality: {← getMainGoal}" + withMainContext do + withTraceNode `Progress (fun _ => pure m!"Unsolved goals") do trace[Progress] "Unsolved goals: {← getUnsolvedGoals}" + traceGoalWithNode "Goal after clearing the equality" -- Continue splitting following the post following the user's instructions match hPost with | none => -- Sanity check if ¬ ids.isEmpty then logWarning m!"Too many ids provided ({ids}): there is no postcondition to split" + trace[Progress] "No post to split" next (Ok #[]) | some hPost => do + trace[Progress] "Post to split: {hPost}" let rec splitPostWithIds (prevId : Name) (hPosts : List FVarId) (hPost : Expr) (ids0 : List (Option Name)) : - TacticM (Result (Array FVarId) MessageData) := do + TacticM (Result (Array FVarId) MessageData) := do match ids0 with | [] => /- We used all the user provided ids. Split the remaining conjunctions by using fresh ids if the user instructed to fully split the post-condition, otherwise stop -/ - if splitPost then + if args.splitPost then splitFullConjTac true hPost (λ asms => do pure (Ok (hPosts.reverse ++ (asms.map (fun x => x.fvarId!))).toArray)) else pure (Ok (hPost.fvarId! :: hPosts).reverse.toArray) @@ -287,12 +301,16 @@ def progressWith (fExpr : Expr) (th : Expr) else logWarning m!"Too many ids provided ({ids0}) not enough conjuncts to split in the postcondition" pure (Ok (hPost.fvarId! :: hPosts).reverse.toArray) + withMainContext do let curPostId := (← hPost.fvarId!.getDecl).userName let res ← splitPostWithIds curPostId [] hPost ids next res splitExistsEqAndPost fun res => do + trace[Progress] "Finished splitting the post" match res with - | Error msg => return (Error msg) -- Can we get there? We're using `return` + | Error msg => + trace[Progress] "" + return (Error msg) -- Can we get there? We're using `return` | Ok hPosts => trace[Progress] "type of hPosts: {← hPosts.mapM (·.getType >>= (liftM ∘ ppExpr))}" -- Update the set of goals @@ -319,8 +337,8 @@ def progressWith (fExpr : Expr) (th : Expr) pure ((← Utils.getMVarIds ty).size, g)) let ordPropGoals := (ordPropGoals.mergeSort (fun (mvars0, _) (mvars1, _) => mvars0 ≤ mvars1)).reverse setGoals (ordPropGoals.map Prod.snd) - allGoalsNoRecover (tryTac assumTac) - allGoalsNoRecover asmTac + allGoalsNoRecover (tryTac args.assumTac) + allGoalsNoRecover args.solvePreconditionTac -- Make sure we use the original order when presenting the preconditions to the user let newPropGoals ← newPropGoals.filterMapM (fun g => do if ← g.isAssigned then pure none else pure (some g)) /- Simplify the post-conditions in the main goal - note that we waited until now @@ -341,6 +359,7 @@ def progressWith (fExpr : Expr) (th : Expr) | none => -- We actually closed the goal: we shouldn't get there -- TODO: make this more robust + trace[Progress] "Unexpected: goal closed by simplifying the introduced post-conditions" throwError "Unexpected: goal closed by simplifying the introduced post-conditions" | some hPosts => -- Simplify the goal again @@ -382,9 +401,7 @@ def getFirstArg (args : Array Expr) : Option Expr := do /-- Helper: try to apply a theorem. Return the list of post-conditions we introduced if it succeeded. -/ -def tryApply (keep keepPretty : Option Name) (ids : Array (Option Name)) (splitPost : Bool) - (assumTac asmTac : TacticM Unit) (fExpr : Expr) - (kind : String) (th : Option Expr) : TacticM (Option ProgressWithOutput) := do +def tryApply (args : Args) (fExpr : Expr) (kind : String) (th : Option Expr) : TacticM (Option ProgressWithOutput) := do let res ← do match th with | none => @@ -395,7 +412,7 @@ def tryApply (keep keepPretty : Option Name) (ids : Array (Option Name)) (splitP -- Apply the theorem let res ← do try - let res ← progressWith fExpr th keep keepPretty ids splitPost assumTac asmTac + let res ← progressWith args fExpr th pure (some res) catch _ => pure none match res with @@ -403,9 +420,26 @@ def tryApply (keep keepPretty : Option Name) (ids : Array (Option Name)) (splitP | some (.Error msg) => throwError msg | none => pure none +/-- Try to progress with an assumption. + Return `some` if we succeed, `none` otherwise. +-/ +def tryAssumptions (args : Args) (fExpr : Expr) : TacticM (Option (ProgressGoals × UsedTheorem)) := do + withTraceNode `Progress (fun _ => pure m!"tryAssumptions") do run +where + run := + withMainContext do + let ctx ← Lean.MonadLCtx.getLCtx + let decls ← ctx.getAssumptions + for decl in decls.reverse do + trace[Progress] "Trying assumption: {decl.userName} : {decl.type}" + let res ← do try progressWith args fExpr decl.toExpr catch _ => continue + match res with + | .Ok res => return (some (res.toProgressGoals, .localHyp decl)) + | .Error msg => throwError msg + pure none + -- The array of ids are identifiers to use when introducing fresh variables -def progressAsmsOrLookupTheorem (keep keepPretty : Option Name) (withTh : Option Expr) - (ids : Array (Option Name)) (splitPost : Bool) (assumTac asmTac : TacticM Unit) : +def progressAsmsOrLookupTheorem (args : Args) (withTh : Option Expr) : TacticM (ProgressGoals × UsedTheorem) := do withMainContext do -- Retrieve the goal @@ -414,7 +448,7 @@ def progressAsmsOrLookupTheorem (keep keepPretty : Option Name) (withTh : Option /- There might be uninstantiated meta-variables in the goal that we need to instantiate (otherwise we will get stuck). -/ let goalTy ← instantiateMVars goalTy - trace[Progress] "goal: {goalTy}" + trace[Progress] "progressAsmsOrLookupTheorem: target: {goalTy}" /- Dive into the goal to lookup the theorem Remark: if we don't isolate the call to `withProgressSpec` to immediately "close" the terms immediately, we may end up with the error: @@ -425,6 +459,7 @@ def progressAsmsOrLookupTheorem (keep keepPretty : Option Name) (withTh : Option have the proper shape. -/ let fExpr ← do let isGoal := true + withTraceNode `Progress (fun _ => pure m!"Calling withProgressSpec to deconstruct the target") do withProgressSpec isGoal goalTy fun {fArgsExpr := fExpr, ..} => do trace[Progress] "Expression to match: {fExpr}" pure fExpr @@ -432,27 +467,22 @@ def progressAsmsOrLookupTheorem (keep keepPretty : Option Name) (withTh : Option -- Otherwise, lookup one. match withTh with | some th => do - match ← progressWith fExpr th keep keepPretty ids splitPost assumTac asmTac with + match ← progressWith args fExpr th with | .Ok res => -- Remark: exprToSyntax doesn't give the expected result return (res.toProgressGoals, .givenExpr th) | .Error msg => throwError msg | none => -- Try all the assumptions one by one and if it fails try to lookup a theorem. - let ctx ← Lean.MonadLCtx.getLCtx - let decls ← ctx.getDecls - for decl in decls.reverse do - trace[Progress] "Trying assumption: {decl.userName} : {decl.type}" - let res ← do try progressWith fExpr decl.toExpr keep keepPretty ids splitPost assumTac asmTac catch _ => continue - match res with - | .Ok res => return (res.toProgressGoals, .localHyp decl) - | .Error msg => throwError msg + if let some res ← tryAssumptions args fExpr then return res /- It failed: lookup the pspec theorems which match the expression *only if the function is a constant* -/ let fIsConst ← do fExpr.consumeMData.withApp fun mf _ => do pure mf.isConst - if ¬ fIsConst then throwError "Progress failed" + if ¬ fIsConst then + trace[Progress] "Progress failed: the target function is not a constant" + throwError "Progress failed" else do trace[Progress] "No assumption succeeded: trying to lookup a pspec theorem" let pspecs : Array Name ← do @@ -467,7 +497,7 @@ def progressAsmsOrLookupTheorem (keep keepPretty : Option Name) (withTh : Option -- Try the theorems one by one for pspec in pspecs do let pspecExpr ← Term.mkConst pspec - match ← tryApply keep keepPretty ids splitPost assumTac asmTac fExpr "pspec theorem" pspecExpr with + match ← tryApply args fExpr "pspec theorem" pspecExpr with | some res => return (res.toProgressGoals, .progressThm pspec) | none => pure () -- It failed: try to use the recursive assumptions @@ -479,12 +509,14 @@ def progressAsmsOrLookupTheorem (keep keepPretty : Option Name) (withTh : Option | .default | .implDetail => false | .auxDecl => true) for decl in decls.reverse do trace[Progress] "Trying recursive assumption: {decl.userName} : {decl.type}" - let res ← do try progressWith fExpr decl.toExpr keep keepPretty ids splitPost assumTac asmTac catch _ => continue + let res ← do try progressWith args fExpr decl.toExpr catch _ => continue match res with | .Ok res => return (res.toProgressGoals, .localHyp decl) | .Error msg => throwError msg -- Nothing worked: failed - throwError "Progress failed: could not find a local assumption or a theorem to apply" + let msg := "Progress failed: could not find a local assumption or a theorem to apply" + trace[Progress] msg + throwError msg syntax progressArgs := ("keep" binderIdent)? ("with" term)? ("as" " ⟨ " binderIdent,* " ⟩")? ("by" tacticSeq)? @@ -530,6 +562,7 @@ def parseProgressArgs def evalProgress (keep keepPretty : Option Name) (withArg: Option Expr) (ids: Array (Option Name)) (byTac : Option Syntax.Tactic) : TacticM Stats := do + withTraceNode `Progress (fun _ => pure m!"evalProgress") do /- Simplify the goal -- TODO: this might close it: we need to check that and abort if necessary, and properly track that in the `Stats` -/ let _ ← Simp.simpAt true { maxDischargeDepth := 1, failIfUnchanged := false} @@ -542,7 +575,7 @@ def evalProgress (keep keepPretty : Option Name) (withArg: Option Expr) (ids: Ar arithmetic goal, we skip (note that otherwise, scalarTac would try to prove a contradiction) -/ let scalarTac : TacticM Unit := do - trace[Progress] "Attempting to solve with `scalarTac`" + withTraceNode `Progress (fun _ => pure m!"Attempting to solve with `scalarTac`") do if ← ScalarTac.goalIsLinearInt then /- Also: we don't try to split the goal if it is a conjunction (it shouldn't be), but we split the disjunctions. -/ @@ -550,13 +583,10 @@ def evalProgress (keep keepPretty : Option Name) (withArg: Option Expr) (ids: Ar else throwError "Not a linear arithmetic goal" let simpLemmas ← Aeneas.ScalarTac.scalarTacSimpExt.getTheorems - let localAsms ← (← (← getLCtx).getDecls).filterMapM fun decl => do - if ← isProp decl.type then - pure (some decl.fvarId) - else pure none + let localAsms ← pure ((← (← getLCtx).getAssumptions).map LocalDecl.fvarId) let simpArgs : Simp.SimpArgs := {simpThms := #[simpLemmas], hypsToUse := localAsms.toArray} let simpTac : TacticM Unit := do - trace[Progress] "Attempting to solve with `simp [*]`" + withTraceNode `Progress (fun _ => pure m!"Attempting to solve with `simp [*]`") do -- Simplify the goal let r ← Simp.simpAt false { maxDischargeDepth := 1 } simpArgs (.targets #[] true) -- Raise an error if the goal is not proved @@ -564,20 +594,27 @@ def evalProgress (keep keepPretty : Option Name) (withArg: Option Expr) (ids: Ar /- We use our custom assumption tactic, which instantiates meta-variables only if there is a single assumption matching the goal. -/ let customAssumTac : TacticM Unit := do - trace[Progress] "Attempting to solve with `singleAssumptionTac`" + withTraceNode `Progress (fun _ => pure m!"Attempting to solve with `singleAssumptionTac`") do singleAssumptionTacCore singleAssumptionTacDtree /- Also use the tactic provided by the user, if there is -/ let byTac := match byTac with | none => [] - | some byTac => [evalTactic byTac] - let (goals, usedTheorem) ← progressAsmsOrLookupTheorem keep keepPretty withArg ids splitPost customAssumTac ( + | some byTac => [ + withTraceNode `Progress (fun _ => pure m!"Attempting to solve with the user tactic: `{byTac}`") do + evalTactic byTac] + let solvePreconditionTac := withMainContext do - trace[Progress] "trying to solve precondition: {← getMainGoal}" + withTraceNode `Progress (fun _ => pure m!"Trying to solve a precondition") do + trace[Progress] "Precondition: {← getMainGoal}" try firstTacSolve ([simpTac, scalarTac] ++ byTac) trace[Progress] "Precondition solved!" catch _ => - trace[Progress] "Precondition not solved") + trace[Progress] "Precondition not solved" + let args : Args := { + keep, keepPretty, ids, splitPost, assumTac := customAssumTac, solvePreconditionTac + } + let (goals, usedTheorem) ← progressAsmsOrLookupTheorem args withArg trace[Progress] "Progress done" return ⟨ goals, usedTheorem ⟩ diff --git a/backends/lean/Aeneas/Progress/ProgressStar.lean b/backends/lean/Aeneas/Progress/ProgressStar.lean index ae6a01a10..d1ef45fcd 100644 --- a/backends/lean/Aeneas/Progress/ProgressStar.lean +++ b/backends/lean/Aeneas/Progress/ProgressStar.lean @@ -131,6 +131,8 @@ end Bifurcation namespace ProgressStar +abbrev traceGoalWithNode := Progress.traceGoalWithNode + structure Config where preconditionTac: Option Syntax.Tactic := none /-- Should we use the special syntax `let* ⟨ ...⟩ ← ...` or the more standard syntax `progress with ... as ⟨ ... ⟩`? -/ @@ -138,8 +140,8 @@ structure Config where useCase' : Bool := false structure Info where - script: Array Syntax.Tactic := #[] - unsolvedGoals: List MVarId := [] + script: Array Syntax.Tactic := #[] -- TODO: update so that we get a tree + unsolvedGoals: Array MVarId := #[] instance: Append Info where append inf1 inf2 := { @@ -149,91 +151,194 @@ instance: Append Info where attribute [progress_simps] Aeneas.Std.bind_assoc_eq -partial def evalProgressStar(cfg: Config): TacticM Info := +inductive TargetKind where +| bind (fn : Name) +| switch (info : Bifurcation.Info) +| result +| unknown + +/- Smaller helper which we use to check in which situation we are -/ +def analyzeTarget : TacticM TargetKind := do + withTraceNode `Progress (fun _ => do pure m!"analyzeTarget") do + try + Progress.monadTelescope (← getMainTarget) fun _xs _zs program _res _post => do + let e ← Utils.normalizeLetBindings program + if let .const ``Bind.bind .. := e.getAppFn then + let #[_m, _self, _α, _β, _value, cont] := e.getAppArgs + | throwError "Expected bind to have 4 arguments, found {← e.getAppArgs.mapM (liftM ∘ ppExpr)}" + Utils.lambdaOne cont fun x _ => do + let name ← x.fvarId!.getUserName + pure (.bind name) + else if let .some bfInfo ← Bifurcation.Info.ofExpr e then + pure (.switch bfInfo) + else + pure .result + catch _ => pure .unknown + +partial def evalProgressStar (cfg: Config) : TacticM Info := withMainContext do focus do - trace[ProgressStar] "Simplifying the goal: {←(getMainTarget >>= (liftM ∘ ppExpr))}" - let r ← Simp.simpAt (simpOnly := true) - { maxDischargeDepth := 1, failIfUnchanged := false} - {simpThms := #[← Progress.progressSimpExt.getTheorems]} - (.targets #[] true) - /- We may have proven the goal already -/ - if r.isNone then - let progress_simps ← `(Parser.Tactic.simpLemma| $(mkIdent `progress_simps):term) - return ⟨ #[← `(tactic|simp [$progress_simps])], [] ⟩ + withTraceNode `Progress (fun _ => do pure m!"evalProgressStar") do /- Continue -/ - trace[ProgressStar] "After simplifying the goal: {← getMainTarget}" - let res ← traverseProgram cfg - setGoals res.unsolvedGoals - return res + let (info, mvarId) ← simplifyTarget + match mvarId with + | some _ => + let info' ← traverseProgram cfg + let info := info ++ info' + setGoals info.unsolvedGoals.toList + pure info + | none => pure info where + simplifyTarget : TacticM (Info × Option MVarId) := do + withTraceNode `Progress (fun _ => do pure m!"simplifyTarget") do + traceGoalWithNode "about to simplify goal" + let mvarId0 ← getMainGoal + let r ← Simp.simpAt (simpOnly := true) + { maxDischargeDepth := 1, failIfUnchanged := false} + {simpThms := #[← Progress.progressSimpExt.getTheorems]} + (.targets #[] true) + /- We may have proven the goal already -/ + let tac : Array Syntax.Tactic ← do + let genSimp : Bool ← do + if r.isNone then pure true + else do + pure ((← getMainGoal) != mvarId0) + if genSimp then + let progress_simps ← `(Parser.Tactic.simpLemma| $(mkIdent `progress_simps):term) + pure #[← `(tactic|simp only [$progress_simps])] + else pure #[] + let info : Info := ⟨ tac, #[] ⟩ + if r.isSome then traceGoalWithNode "after simplification" + else trace[Progress] "goal proved" + let goal ← do if r.isSome then pure (some (← getMainGoal)) else pure none + pure (info, goal) + traverseProgram (cfg : Config): TacticM Info := do withMainContext do - trace[ProgressStar] "traverseProgram: current goal: {← getMainGoal}" - try -- `programTelescope` can fail - Progress.programTelescope (← getMainTarget) fun _xs _zs program _res _post => do - let e ← Utils.normalizeLetBindings program - if let .const ``Bind.bind .. := e.getAppFn then - let #[_m, _self, _α, _β, _value, cont] := e.getAppArgs - | throwError "Expected bind to have 4 arguments, found {← e.getAppArgs.mapM (liftM ∘ ppExpr)}" - Utils.lambdaOne cont fun x _ => do - let name ← x.fvarId!.getUserName - let (info, mainGoal) ← onBind cfg name - trace[ProgressStar] "traverseProgram: after call to `onBind`: main goal is: {mainGoal}" - /- Continue, if necessary -/ - match mainGoal with - | none => - -- Stop - return info - | some mainGoal => - setGoals [mainGoal] - let restInfo ← traverseProgram cfg - return info ++ restInfo - else if let .some bfInfo ← Bifurcation.Info.ofExpr e then - let contsTaggedVals ← - bfInfo.branches.mapM fun br => do - Utils.lambdaTelescopeN br.toExpr br.numArgs fun xs _ => do - let names ← xs.mapM (·.fvarId!.getUserName) - return names - let (branchGoals, mkStx) ← onBif cfg bfInfo contsTaggedVals - /- Continue exploring from the subgoals -/ - let branchInfos ← branchGoals.mapM fun mainGoal => do - setGoals [mainGoal] - let restInfo ← traverseProgram cfg - pure restInfo - /- Put everything together -/ - mkStx branchInfos - else - let (info, mainGoal) ← onResult cfg - pure { info with unsolvedGoals := info.unsolvedGoals ++ mainGoal.toList} - catch _ => - return ({ script := #[←`(tactic| sorry)], unsolvedGoals := ← getUnsolvedGoals}) + withTraceNode `Progress (fun _ => do pure m!"traverseProgram") do + traceGoalWithNode "current goal" + let targetKind ← analyzeTarget + match targetKind with + | .bind varName => do + let (info, mainGoal) ← onBind cfg varName + /- Continue, if necessary -/ + match mainGoal with + | none => + -- Stop + trace[Progress] "stop" + return info + | some mainGoal => + setGoals [mainGoal] + let restInfo ← traverseProgram cfg + return info ++ restInfo + | .switch bfInfo => do + let contsTaggedVals ← + bfInfo.branches.mapM fun br => do + Utils.lambdaTelescopeN br.toExpr br.numArgs fun xs _ => do + let names ← xs.mapM (·.fvarId!.getUserName) + return names + let (branchGoals, mkStx) ← onBif cfg bfInfo contsTaggedVals + withTraceNode `Progress (fun _ => do pure m!"exploring branches") do + /- Continue exploring from the subgoals -/ + let branchInfos ← branchGoals.mapM fun mainGoal => do + setGoals [mainGoal] + let restInfo ← traverseProgram cfg + pure restInfo + /- Put everything together -/ + mkStx branchInfos + | .result => do + let (info, mainGoal) ← onResult cfg + pure { info with unsolvedGoals := info.unsolvedGoals ++ mainGoal.toList} + | .unknown => do + trace[Progress] "don't know what to do: inserting a sorry" + return ({ script := #[←`(tactic| sorry)], unsolvedGoals := (← getUnsolvedGoals).toArray}) onResult (cfg : Config) : TacticM (Info × Option MVarId) := do - trace[ProgressStar] "onResult: Since (· >>= pure) = id, we treat this result as a bind on id" - -- If we encounter `(do f a)` we process it as if it were `(do let res ← f a; return res)` - -- since (id = (· >>= pure)) and when we desugar the do block we have that - -- - -- (do f a) == f a - -- == (f a) >>= pure - -- == (do let res ← f a; return res) - -- - -- We known in advance the result of processing `return res`, which is to do nothing. - -- This allows us to prevent code duplication with the `onBind` function. - onBind cfg (.str .anonymous "res") - - onBind (cfg : Config) (name : Name) : TacticM (Info × Option MVarId) := do - trace[ProgressStar] "onBind (name={name})" + withTraceNode `Progress (fun _ => pure m!"onResult") do + /- If we encounter `(do f a)` we process it as if it were `(do let res ← f a; return res)` + since (id = (· >>= pure)) and when we desugar the do block we have that + + (do f a) == f a + == (f a) >>= pure + == (do let res ← f a; return res) + + We known in advance the result of processing `return res`, which is to do nothing. + This allows us to prevent code duplication with the `onBind` function. -/ + let res ← onBind cfg (.str .anonymous "res") + match res.snd with + | none => + trace[Progress] "done" + pure res + | some mvarId => + let (info', mvarId) ← onFinish mvarId + pure (res.fst ++ info', mvarId) + + onFinish (mvarId : MVarId) : TacticM (Info × Option MVarId) := do + withTraceNode `Progress (fun _ => pure m!"onFinish") do + setGoals [mvarId] + traceGoalWithNode "goal" + /- Simplify a bit -/ + let (info, mvarId) ← simplifyTarget + match mvarId with + | none => pure (info, mvarId) + | some mvarId => + /- Attempt to finish with a tactic -/ + -- `simp [*]` + let simpTac : TacticM Syntax.Tactic := do + let localAsms ← pure ((← (← getLCtx).getAssumptions).map LocalDecl.fvarId) + let simpArgs : Simp.SimpArgs := {hypsToUse := localAsms.toArray} + let r ← Simp.simpAt false { maxDischargeDepth := 1 } simpArgs (.targets #[] true) + -- Raise an error if the goal is not proved + if r.isSome then throwError "Goal not proved" + else `(tactic|simp [*]) + -- `scalar_tac` + let scalarTac : TacticM Syntax.Tactic := do + if ← ScalarTac.goalIsLinearInt then + /- Also: we don't try to split the goal if it is a conjunction + (it shouldn't be), but we split the disjunctions. -/ + ScalarTac.scalarTac { split := false } + `(tactic|scalar_tac) + else + throwError "Not a linear arithmetic goal" + -- TODO: add the tactic given by the user + let rec tryFinish (tacl : List (String × TacticM Syntax.Tactic)) : TacticM Syntax.Tactic := do + match tacl with + | [] => + trace[Progress] "could not prove the goal: inserting a sorry" + `(tactic| sorry) + | (name, tac) :: tacl => + let stx : Option Syntax.Tactic ← + withTraceNode `Progress (fun _ => pure m!"Attempting to solve with `{name}`") do + try + let stx ← tac + -- Check that there are no remaining goals + let gl ← Tactic.getUnsolvedGoals + if ¬ gl.isEmpty then throwError "tactic failed" + else pure (some stx) + catch _ => pure none + match stx with + | some stx => + trace[Progress] "goal solved" + pure stx + | none => tryFinish tacl + let tac ← tryFinish [("simp [*]", simpTac), ("scalar_tac", scalarTac)] + let info' : Info ← pure { script := #[tac], unsolvedGoals := (← getUnsolvedGoals).toArray} + pure (info ++ info', none) + + onBind (cfg : Config) (varName : Name) : TacticM (Info × Option MVarId) := do + withTraceNode `Progress (fun _ => pure m!"onBind ({varName})") do if let some {usedTheorem, preconditions, mainGoal } ← tryProgress then - trace[ProgressStar] "onBind: Can make progres: the new goal is: {mainGoal}, the unsolved preconditions are: {preconditions}" + withTraceNode `Progress (fun _ => pure m!"progress succeeded") do + trace[Progress] "New goal: {mainGoal}" + trace[Progress] "Unsolved preconditions: {preconditions}" let (preconditionTacs, unsolved) ← handleProgressPreconditions preconditions if ¬ preconditionTacs.isEmpty then - trace[ProgressStar] "onBind: Found {preconditionTacs.size} preconditions, left {unsolved.size} unsolved" + trace[Progress] "Found {preconditionTacs.size} preconditions, left {unsolved.size} unsolved" else - trace[ProgressStar] "onBind: all preconditions solved" + trace[Progress] "all preconditions solved" /- Update the main goal, if necessary -/ - let ids ← getIdsFromUsedTheorem name.eraseMacroScopes usedTheorem - trace[ProgressStar] "onBind: ids from used theorem: {ids}" + let ids ← getIdsFromUsedTheorem varName.eraseMacroScopes usedTheorem + trace[Progress] "ids from used theorem: {ids}" let mainGoal ← do mainGoal.mapM fun mainGoal => do if ¬ ids.isEmpty then renameInaccessibles mainGoal ids -- NOTE: Taken from renameI tactic else pure mainGoal @@ -249,15 +354,17 @@ where else `(tactic| progress with $(←usedTheorem.toSyntax) as ⟨$ids,*⟩) let info : Info := { script := #[currTac]++ preconditionTacs, -- TODO: Optimize - unsolvedGoals := unsolved.toList, + unsolvedGoals := unsolved, } pure (info, mainGoal) - else return ({ script := #[←`(tactic| sorry)], unsolvedGoals := ← getUnsolvedGoals}, none) + else + onFinish (← getMainGoal) onBif (cfg : Config) (bfInfo : Bifurcation.Info) (toBeProcessed : Array (Array Name)): TacticM (List MVarId × (List Info → TacticM Info)) := do - trace[ProgressStar] "onBif: encountered {bfInfo.kind}" + withTraceNode `Progress (fun _ => pure m!"onBif") do + trace[Progress] "onBif: encountered {bfInfo.kind}" if (←getGoals).isEmpty then - trace[ProgressStar] "onBif: no goals to be solved!" + trace[Progress] "onBif: no goals to be solved!" -- Tactic.focus fails if there are no goals to be solved. return ({}, fun infos => assert! (infos.length == 0); pure {}) Tactic.focus do @@ -265,7 +372,7 @@ where evalSplit splitStx -- let subgoals ← getUnsolvedGoals - trace[ProgressStar] "onBif: Bifurcation generated {subgoals.length} subgoals" + trace[Progress] "onBif: Bifurcation generated {subgoals.length} subgoals" unless subgoals.length == toBeProcessed.size do throwError "onBif: Expected {toBeProcessed.size} cases, found {subgoals.length}" let infos_mkBranchesStx ← (subgoals.zip toBeProcessed.toList).mapM fun (sg, names) => do @@ -323,12 +430,13 @@ where pure (← preconditions.mapM (fun _ => `(tactic| · sorry)), preconditions) getIdsFromUsedTheorem name usedTheorem: TacticM (Array _) := do + withTraceNode `Progress (fun _ => do pure m!"getIdsFromUsedTheorem") do let some thm ← usedTheorem.getType | throwError "Could not infer proposition of {usedTheorem}" - let (numElem, numPost) ← Progress.programTelescope thm + let (numElem, numPost) ← Progress.monadTelescope thm fun _xs zs _program _res postconds => do let numPost := Utils.numOfConjuncts <$> postconds |>.getD 0 - trace[ProgressStar] "Number of conjuncts for `{←liftM (Option.traverse ppExpr postconds)}` is {numPost}" + trace[Progress] "Number of conjuncts for `{←liftM (Option.traverse ppExpr postconds)}` is {numPost}" pure (zs.size, numPost) return makeIds (base := name) numElem numPost @@ -379,7 +487,7 @@ section Examples /-- info: Try this: -simp [progress_simps] +simp only [progress_simps] -/ #guard_msgs in example : True := by progress*? @@ -399,7 +507,22 @@ info: Try this: -/ #guard_msgs in example (x y : U32) (h : 2 * x.val + 2 * y.val + 4 ≤ U32.max) : - ∃ z, add1 x y = ok z := by + ∃ z, add1 x y = ok z := by + unfold add1 + progress*? + +/-- +info: Try this: + simp only [progress_simps] + let* ⟨ x2, x2_post ⟩ ← U32.add_spec + let* ⟨ x3, x3_post ⟩ ← U32.add_spec + let* ⟨ res, res_post ⟩ ← U32.add_spec + scalar_tac +-/ +#guard_msgs in +example (x y : U32) (h : 2 * x.val + 2 * y.val + 4 ≤ U32.max) : + let v := 2 * x.val + 2 * y.val + 4 + ∃ z, add1 x y = ok z ∧ z.val = v:= by unfold add1 progress*? diff --git a/backends/lean/Aeneas/Progress/Trace.lean b/backends/lean/Aeneas/Progress/Trace.lean index d3f80f150..246fccd8e 100644 --- a/backends/lean/Aeneas/Progress/Trace.lean +++ b/backends/lean/Aeneas/Progress/Trace.lean @@ -5,6 +5,5 @@ namespace Aeneas.Progress -- We can't define and use trace classes in the same file initialize registerTraceClass `Progress -initialize registerTraceClass `ProgressStar end Aeneas.Progress diff --git a/backends/lean/Aeneas/Utils.lean b/backends/lean/Aeneas/Utils.lean index 5fbbd8c93..fde30eb4b 100644 --- a/backends/lean/Aeneas/Utils.lean +++ b/backends/lean/Aeneas/Utils.lean @@ -342,9 +342,6 @@ def firstTacSolve (tacl : List (TacticM Unit)) : TacticM Unit := do match tacl with | [] => throwError "no tactic succeeded" | tac :: tacl => - -- Should use try ... catch or Lean.observing? - -- Generally speaking we should use Lean.observing? to restore the state, - -- but with tactics the try ... catch variant seems to work try do tac -- Check that there are no remaining goals From 2a022759269953b1f963870035cf98a3b42d7b1f Mon Sep 17 00:00:00 2001 From: Son Ho Date: Tue, 24 Jun 2025 21:52:35 +0100 Subject: [PATCH 20/31] Fix an issue in `progress` --- backends/lean/Aeneas/Progress/Progress.lean | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/backends/lean/Aeneas/Progress/Progress.lean b/backends/lean/Aeneas/Progress/Progress.lean index 9711756bc..04c097925 100644 --- a/backends/lean/Aeneas/Progress/Progress.lean +++ b/backends/lean/Aeneas/Progress/Progress.lean @@ -301,7 +301,6 @@ def progressWith (args : Args) (fExpr : Expr) (th : Expr) : TacticM (Result Prog else logWarning m!"Too many ids provided ({ids0}) not enough conjuncts to split in the postcondition" pure (Ok (hPost.fvarId! :: hPosts).reverse.toArray) - withMainContext do let curPostId := (← hPost.fvarId!.getDecl).userName let res ← splitPostWithIds curPostId [] hPost ids next res @@ -312,7 +311,10 @@ def progressWith (args : Args) (fExpr : Expr) (th : Expr) : TacticM (Result Prog trace[Progress] "" return (Error msg) -- Can we get there? We're using `return` | Ok hPosts => - trace[Progress] "type of hPosts: {← hPosts.mapM (·.getType >>= (liftM ∘ ppExpr))}" + /- Warning: the main goal may have been solved here, meaning it is unsafe to, + for instance, call `withMainContext`, or to print the `hPosts` + TODO: fix this. + -/ -- Update the set of goals let curGoals ← getUnsolvedGoals trace[Progress] "current goals: {curGoals}" From 1e46d4afadd82dfc71ecaf767bd2802751da2a72 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Wed, 25 Jun 2025 07:50:07 +0100 Subject: [PATCH 21/31] Make progress* always try scalar_tac on the final goal --- backends/lean/Aeneas/Progress/ProgressStar.lean | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/backends/lean/Aeneas/Progress/ProgressStar.lean b/backends/lean/Aeneas/Progress/ProgressStar.lean index d1ef45fcd..482c33296 100644 --- a/backends/lean/Aeneas/Progress/ProgressStar.lean +++ b/backends/lean/Aeneas/Progress/ProgressStar.lean @@ -293,13 +293,8 @@ where else `(tactic|simp [*]) -- `scalar_tac` let scalarTac : TacticM Syntax.Tactic := do - if ← ScalarTac.goalIsLinearInt then - /- Also: we don't try to split the goal if it is a conjunction - (it shouldn't be), but we split the disjunctions. -/ - ScalarTac.scalarTac { split := false } - `(tactic|scalar_tac) - else - throwError "Not a linear arithmetic goal" + ScalarTac.scalarTac {} + `(tactic|scalar_tac) -- TODO: add the tactic given by the user let rec tryFinish (tacl : List (String × TacticM Syntax.Tactic)) : TacticM Syntax.Tactic := do match tacl with From ac29ebab5b63f67accce9d63742dae03231e6344 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Wed, 25 Jun 2025 07:50:35 +0100 Subject: [PATCH 22/31] Marke the From....from_val_eq lemmas as `scalar_tac_simps` --- .../Aeneas/Std/Scalar/CoreConvertNum.lean | 114 ++++++++++++------ 1 file changed, 76 insertions(+), 38 deletions(-) diff --git a/backends/lean/Aeneas/Std/Scalar/CoreConvertNum.lean b/backends/lean/Aeneas/Std/Scalar/CoreConvertNum.lean index 715c7cdd6..7ae0df1ce 100644 --- a/backends/lean/Aeneas/Std/Scalar/CoreConvertNum.lean +++ b/backends/lean/Aeneas/Std/Scalar/CoreConvertNum.lean @@ -116,65 +116,84 @@ private theorem BitVec.setWidth_toNat_le (m : Nat) (x : BitVec n) (h : n ≤ m) apply Nat.mod_eq_of_lt omega -@[simp, scalar_tac_simps] theorem FromUsizeU8.from_val_eq (x : U8) : (FromUsizeU8.from x).val = x.val := by +@[simp, scalar_tac_simps, simp_scalar_simps] +theorem FromUsizeU8.from_val_eq (x : U8) : (FromUsizeU8.from x).val = x.val := by simp only [FromUsizeU8.from]; apply BitVec.setWidth_toNat_le cases System.Platform.numBits_eq <;> simp [*] -@[simp, scalar_tac_simps] theorem FromUsizeU16.from_val_eq (x : U16) : (FromUsizeU16.from x).val = x.val := by +@[simp, scalar_tac_simps, simp_scalar_simps] +theorem FromUsizeU16.from_val_eq (x : U16) : (FromUsizeU16.from x).val = x.val := by simp only [FromUsizeU16.from]; apply BitVec.setWidth_toNat_le cases System.Platform.numBits_eq <;> simp [*] -@[simp, scalar_tac_simps] theorem FromUsizeU32.from_val_eq (x : U32) : (FromUsizeU32.from x).val = x.val := by +@[simp, scalar_tac_simps, simp_scalar_simps] +theorem FromUsizeU32.from_val_eq (x : U32) : (FromUsizeU32.from x).val = x.val := by simp only [FromUsizeU32.from]; apply BitVec.setWidth_toNat_le cases System.Platform.numBits_eq <;> simp [*] -@[simp, scalar_tac_simps] theorem FromUsizeUsize.from_val_eq (x : Usize) : (FromUsizeUsize.from x).val = x.val := by +@[simp, scalar_tac_simps, simp_scalar_simps] +theorem FromUsizeUsize.from_val_eq (x : Usize) : (FromUsizeUsize.from x).val = x.val := by simp only [FromUsizeUsize.from]; apply BitVec.setWidth_toNat_le cases System.Platform.numBits_eq <;> simp [*] -@[simp, scalar_tac_simps] theorem FromU8U8.from_val_eq (x : U8) : (FromU8U8.from x).val = x.val := by +@[simp, scalar_tac_simps, simp_scalar_simps] +theorem FromU8U8.from_val_eq (x : U8) : (FromU8U8.from x).val = x.val := by simp only [FromU8U8.from]; apply BitVec.setWidth_toNat_le; simp -@[simp, scalar_tac_simps] theorem FromU16U8.from_val_eq (x : U8) : (FromU16U8.from x).val = x.val := by +@[simp, scalar_tac_simps, simp_scalar_simps] +theorem FromU16U8.from_val_eq (x : U8) : (FromU16U8.from x).val = x.val := by simp only [FromU16U8.from]; apply BitVec.setWidth_toNat_le; simp -@[simp, scalar_tac_simps] theorem FromU16U16.from_val_eq (x : U16) : (FromU16U16.from x).val = x.val := by +@[simp, scalar_tac_simps, simp_scalar_simps] +theorem FromU16U16.from_val_eq (x : U16) : (FromU16U16.from x).val = x.val := by simp only [FromU16U16.from]; apply BitVec.setWidth_toNat_le; simp -@[simp, scalar_tac_simps] theorem FromU32U8.from_val_eq (x : U8) : (FromU32U8.from x).val = x.val := by +@[simp, scalar_tac_simps, simp_scalar_simps] +theorem FromU32U8.from_val_eq (x : U8) : (FromU32U8.from x).val = x.val := by simp only [FromU32U8.from]; apply BitVec.setWidth_toNat_le; simp -@[simp, scalar_tac_simps] theorem FromU32U16.from_val_eq (x : U16) : (FromU32U16.from x).val = x.val := by +@[simp, scalar_tac_simps, simp_scalar_simps] +theorem FromU32U16.from_val_eq (x : U16) : (FromU32U16.from x).val = x.val := by simp only [FromU32U16.from]; apply BitVec.setWidth_toNat_le; simp -@[simp, scalar_tac_simps] theorem FromU32U32.from_val_eq (x : U32) : (FromU32U32.from x).val = x.val := by +@[simp, scalar_tac_simps, simp_scalar_simps] +theorem FromU32U32.from_val_eq (x : U32) : (FromU32U32.from x).val = x.val := by simp only [FromU32U32.from]; apply BitVec.setWidth_toNat_le; simp -@[simp, scalar_tac_simps] theorem FromU64U8.from_val_eq (x : U8) : (FromU64U8.from x).val = x.val := by +@[simp, scalar_tac_simps, simp_scalar_simps] +theorem FromU64U8.from_val_eq (x : U8) : (FromU64U8.from x).val = x.val := by simp only [FromU64U8.from]; apply BitVec.setWidth_toNat_le; simp -@[simp, scalar_tac_simps] theorem FromU64U16.from_val_eq (x : U16) : (FromU64U16.from x).val = x.val := by +@[simp, scalar_tac_simps, simp_scalar_simps] +theorem FromU64U16.from_val_eq (x : U16) : (FromU64U16.from x).val = x.val := by simp only [FromU64U16.from]; apply BitVec.setWidth_toNat_le; simp -@[simp, scalar_tac_simps] theorem FromU64U32.from_val_eq (x : U32) : (FromU64U32.from x).val = x.val := by +@[simp, scalar_tac_simps, simp_scalar_simps] +theorem FromU64U32.from_val_eq (x : U32) : (FromU64U32.from x).val = x.val := by simp only [FromU64U32.from]; apply BitVec.setWidth_toNat_le; simp -@[simp, scalar_tac_simps] theorem FromU64U64.from_val_eq (x : U64) : (FromU64U64.from x).val = x.val := by +@[simp, scalar_tac_simps, simp_scalar_simps] +theorem FromU64U64.from_val_eq (x : U64) : (FromU64U64.from x).val = x.val := by simp only [FromU64U64.from]; apply BitVec.setWidth_toNat_le; simp -@[simp, scalar_tac_simps] theorem FromU128U8.from_val_eq (x : U8) : (FromU128U8.from x).val = x.val := by +@[simp, scalar_tac_simps, simp_scalar_simps] +theorem FromU128U8.from_val_eq (x : U8) : (FromU128U8.from x).val = x.val := by simp only [FromU128U8.from]; apply BitVec.setWidth_toNat_le; simp -@[simp, scalar_tac_simps] theorem FromU128U16.from_val_eq (x : U16) : (FromU128U16.from x).val = x.val := by +@[simp, scalar_tac_simps, simp_scalar_simps] +theorem FromU128U16.from_val_eq (x : U16) : (FromU128U16.from x).val = x.val := by simp only [FromU128U16.from]; apply BitVec.setWidth_toNat_le; simp -@[simp, scalar_tac_simps] theorem FromU128U32.from_val_eq (x : U32) : (FromU128U32.from x).val = x.val := by +@[simp, scalar_tac_simps, simp_scalar_simps] +theorem FromU128U32.from_val_eq (x : U32) : (FromU128U32.from x).val = x.val := by simp only [FromU128U32.from]; apply BitVec.setWidth_toNat_le; simp -@[simp, scalar_tac_simps] theorem FromU128U64.from_val_eq (x : U64) : (FromU128U64.from x).val = x.val := by +@[simp, scalar_tac_simps, simp_scalar_simps] +theorem FromU128U64.from_val_eq (x : U64) : (FromU128U64.from x).val = x.val := by simp only [FromU128U64.from]; apply BitVec.setWidth_toNat_le; simp -@[simp, scalar_tac_simps] theorem FromU128U128.from_val_eq (x : U128) : (FromU128U128.from x).val = x.val := by +@[simp, scalar_tac_simps, simp_scalar_simps] +theorem FromU128U128.from_val_eq (x : U128) : (FromU128U128.from x).val = x.val := by simp only [FromU128U128.from]; apply BitVec.setWidth_toNat_le; simp @[simp, bvify_simps] theorem FromUsizeU8.from_bv_eq (x : U8) : (FromUsizeU8.from x).bv = x.bv.setWidth _ := by @@ -248,78 +267,97 @@ private theorem bmod_pow2_eq_of_inBounds' (n : ℕ) (x : ℤ) (h : 0 < n ∧ -2 simp [hn] at this apply this -@[simp, scalar_tac_simps] theorem FromIsizeI8.from_val_eq (x : I8) : (FromIsizeI8.from x).val = x.val := by +@[simp, scalar_tac_simps, simp_scalar_simps] +theorem FromIsizeI8.from_val_eq (x : I8) : (FromIsizeI8.from x).val = x.val := by cases System.Platform.numBits_eq <;> simp only [FromIsizeI8.from, IScalar.val, BitVec.signExtend] <;> simp <;> apply bmod_pow2_eq_of_inBounds' _ x.val (by scalar_tac) -@[simp, scalar_tac_simps] theorem FromIsizeI16.from_val_eq (x : I16) : (FromIsizeI16.from x).val = x.val := by +@[simp, scalar_tac_simps, simp_scalar_simps] +theorem FromIsizeI16.from_val_eq (x : I16) : (FromIsizeI16.from x).val = x.val := by cases System.Platform.numBits_eq <;> simp only [FromIsizeI16.from, IScalar.val, BitVec.signExtend] <;> simp <;> apply bmod_pow2_eq_of_inBounds' _ x.val (by scalar_tac) -@[simp, scalar_tac_simps] theorem FromIsizeI32.from_val_eq (x : I32) : (FromIsizeI32.from x).val = x.val := by +@[simp, scalar_tac_simps, simp_scalar_simps] +theorem FromIsizeI32.from_val_eq (x : I32) : (FromIsizeI32.from x).val = x.val := by cases System.Platform.numBits_eq <;> simp only [FromIsizeI32.from, IScalar.val, BitVec.signExtend] <;> simp <;> apply bmod_pow2_eq_of_inBounds' _ x.val (by scalar_tac) -@[simp, scalar_tac_simps] theorem FromIsizeIsize.from_val_eq (x : Isize) : (FromIsizeIsize.from x).val = x.val := by +@[simp, scalar_tac_simps, simp_scalar_simps] +theorem FromIsizeIsize.from_val_eq (x : Isize) : (FromIsizeIsize.from x).val = x.val := by cases System.Platform.numBits_eq <;> simp only [FromIsizeIsize.from, IScalar.val, BitVec.signExtend] <;> simp -@[simp, scalar_tac_simps] theorem FromI8I8.from_val_eq (x : I8) : (FromI8I8.from x).val = x.val := by +@[simp, scalar_tac_simps, simp_scalar_simps] +theorem FromI8I8.from_val_eq (x : I8) : (FromI8I8.from x).val = x.val := by simp only [FromI8I8.from, IScalar.val, BitVec.signExtend]; simp -@[simp, scalar_tac_simps] theorem FromI16I8.from_val_eq (x : I8) : (FromI16I8.from x).val = x.val := by +@[simp, scalar_tac_simps, simp_scalar_simps] +theorem FromI16I8.from_val_eq (x : I8) : (FromI16I8.from x).val = x.val := by simp only [FromI16I8.from, IScalar.val, BitVec.signExtend]; simp apply bmod_pow2_eq_of_inBounds' 16 x.val (by scalar_tac) -@[simp, scalar_tac_simps] theorem FromI16I16.from_val_eq (x : I16) : (FromI16I16.from x).val = x.val := by +@[simp, scalar_tac_simps, simp_scalar_simps] +theorem FromI16I16.from_val_eq (x : I16) : (FromI16I16.from x).val = x.val := by simp only [FromI16I16.from, IScalar.val, BitVec.signExtend]; simp -@[simp, scalar_tac_simps] theorem FromI32I8.from_val_eq (x : I8) : (FromI32I8.from x).val = x.val := by +@[simp, scalar_tac_simps, simp_scalar_simps] +theorem FromI32I8.from_val_eq (x : I8) : (FromI32I8.from x).val = x.val := by simp only [FromI32I8.from, IScalar.val, BitVec.signExtend]; simp apply bmod_pow2_eq_of_inBounds' 32 x.val (by scalar_tac) -@[simp, scalar_tac_simps] theorem FromI32I16.from_val_eq (x : I16) : (FromI32I16.from x).val = x.val := by +@[simp, scalar_tac_simps, simp_scalar_simps] +theorem FromI32I16.from_val_eq (x : I16) : (FromI32I16.from x).val = x.val := by simp only [FromI32I16.from, IScalar.val, BitVec.signExtend]; simp apply bmod_pow2_eq_of_inBounds' 32 x.val (by scalar_tac) -@[simp, scalar_tac_simps] theorem FromI32I32.from_val_eq (x : I32) : (FromI32I32.from x).val = x.val := by +@[simp, scalar_tac_simps, simp_scalar_simps] +theorem FromI32I32.from_val_eq (x : I32) : (FromI32I32.from x).val = x.val := by simp only [FromI32I32.from, IScalar.val, BitVec.signExtend]; simp -@[simp, scalar_tac_simps] theorem FromI64I8.from_val_eq (x : I8) : (FromI64I8.from x).val = x.val := by +@[simp, scalar_tac_simps, simp_scalar_simps] +theorem FromI64I8.from_val_eq (x : I8) : (FromI64I8.from x).val = x.val := by simp only [FromI64I8.from, IScalar.val, BitVec.signExtend]; simp apply bmod_pow2_eq_of_inBounds' 64 x.val (by scalar_tac) -@[simp, scalar_tac_simps] theorem FromI64I16.from_val_eq (x : I16) : (FromI64I16.from x).val = x.val := by +@[simp, scalar_tac_simps, simp_scalar_simps] +theorem FromI64I16.from_val_eq (x : I16) : (FromI64I16.from x).val = x.val := by simp only [FromI64I16.from, IScalar.val, BitVec.signExtend]; simp apply bmod_pow2_eq_of_inBounds' 64 x.val (by scalar_tac) -@[simp, scalar_tac_simps] theorem FromI64I32.from_val_eq (x : I32) : (FromI64I32.from x).val = x.val := by +@[simp, scalar_tac_simps, simp_scalar_simps] +theorem FromI64I32.from_val_eq (x : I32) : (FromI64I32.from x).val = x.val := by simp only [FromI64I32.from, IScalar.val, BitVec.signExtend]; simp apply bmod_pow2_eq_of_inBounds' 64 x.val (by scalar_tac) -@[simp, scalar_tac_simps] theorem FromI64I64.from_val_eq (x : I64) : (FromI64I64.from x).val = x.val := by +@[simp, scalar_tac_simps, simp_scalar_simps] +theorem FromI64I64.from_val_eq (x : I64) : (FromI64I64.from x).val = x.val := by simp only [FromI64I64.from, IScalar.val, BitVec.signExtend]; simp -@[simp, scalar_tac_simps] theorem FromI128I8.from_val_eq (x : I8) : (FromI128I8.from x).val = x.val := by +@[simp, scalar_tac_simps, simp_scalar_simps] +theorem FromI128I8.from_val_eq (x : I8) : (FromI128I8.from x).val = x.val := by simp only [FromI128I8.from, IScalar.val, BitVec.signExtend]; simp apply bmod_pow2_eq_of_inBounds' 128 x.val (by scalar_tac) -@[simp, scalar_tac_simps] theorem FromI128I16.from_val_eq (x : I16) : (FromI128I16.from x).val = x.val := by +@[simp, scalar_tac_simps, simp_scalar_simps] +theorem FromI128I16.from_val_eq (x : I16) : (FromI128I16.from x).val = x.val := by simp only [FromI128I16.from, IScalar.val, BitVec.signExtend]; simp apply bmod_pow2_eq_of_inBounds' 128 x.val (by scalar_tac) -@[simp, scalar_tac_simps] theorem FromI128I32.from_val_eq (x : I32) : (FromI128I32.from x).val = x.val := by +@[simp, scalar_tac_simps, simp_scalar_simps] +theorem FromI128I32.from_val_eq (x : I32) : (FromI128I32.from x).val = x.val := by simp only [FromI128I32.from, IScalar.val, BitVec.signExtend]; simp apply bmod_pow2_eq_of_inBounds' 128 x.val (by scalar_tac) -@[simp, scalar_tac_simps] theorem FromI128I64.from_val_eq (x : I64) : (FromI128I64.from x).val = x.val := by +@[simp, scalar_tac_simps, simp_scalar_simps] +theorem FromI128I64.from_val_eq (x : I64) : (FromI128I64.from x).val = x.val := by simp only [FromI128I64.from, IScalar.val, BitVec.signExtend]; simp apply bmod_pow2_eq_of_inBounds' 128 x.val (by scalar_tac) -@[simp, scalar_tac_simps] theorem FromI128I128.from_val_eq (x : I128) : (FromI128I128.from x).val = x.val := by +@[simp, scalar_tac_simps, simp_scalar_simps] +theorem FromI128I128.from_val_eq (x : I128) : (FromI128I128.from x).val = x.val := by simp only [FromI128I128.from, IScalar.val, BitVec.signExtend]; simp @[simp, bvify_simps] theorem FromIsizeI8.from_bv_eq (x : I8) : (FromIsizeI8.from x).bv = x.bv.signExtend _ := by From e6e4bd042cfc0961772375c9cc6c23b0c39a01d6 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Wed, 25 Jun 2025 09:56:23 +0100 Subject: [PATCH 23/31] Start cleaning up progress --- backends/lean/Aeneas/Progress/Progress.lean | 248 +++++++++--------- .../lean/Aeneas/Progress/ProgressStar.lean | 11 +- 2 files changed, 133 insertions(+), 126 deletions(-) diff --git a/backends/lean/Aeneas/Progress/Progress.lean b/backends/lean/Aeneas/Progress/Progress.lean index 04c097925..1cef37b14 100644 --- a/backends/lean/Aeneas/Progress/Progress.lean +++ b/backends/lean/Aeneas/Progress/Progress.lean @@ -108,34 +108,48 @@ inductive Result (T U : Type) where | Error : U → Result T U deriving Inhabited -structure ProgressGoals where +structure MainGoal where + goal : MVarId + /-- The post-conditions introduced in the context -/ + posts : Array FVarId + +structure Goals where /-- The preconditions that are left to prove -/ preconditions : Array MVarId - /-- the main goal, if it was not proven -/ - mainGoal : Option MVarId + /-- The main goal, if it was not proven -/ + mainGoal : Option MainGoal deriving Inhabited -structure Stats extends ProgressGoals where +structure Stats extends Goals where usedTheorem : UsedTheorem -structure ProgressWithOutput extends ProgressGoals where - /- The post-conditions introduced in the context -/ - posts : Array FVarId -deriving Inhabited - attribute [progress_post_simps] Std.IScalar.toNat Std.UScalar.ofNat_val_eq Std.IScalar.ofInt_val_eq structure Args where + /-- Should we preserve the monadic equality in the context? + + For instance, if making progress on: `let z ← x + y; ...` + We would introduce an assumption: `h : x + y = ok z` + -/ keep : Option Name + /-- Same as `keep`, but use a special wrapper so that the equality gets pretty printed to: + `[> let z ← x + y <]` + -/ keepPretty : Option Name + /-- Identifiers to use when introducing fresh variables -/ ids : Array (Option Name) + /-- Should we split the conjunctions in the post-condition? -/ splitPost : Bool + /-- Tactic to use to prove preconditions while instantiating meta-variables by + matching those preconditions with the assumptions in the context. -/ assumTac : TacticM Unit + /- Tactic to use to solve the preconditions -/ solvePreconditionTac : TacticM Unit open Result in -def progressWith (args : Args) (fExpr : Expr) (th : Expr) : TacticM (Result ProgressWithOutput MessageData) := do +def progressWith (args : Args) (fExpr : Expr) (th : Expr) : + TacticM Goals := do withTraceNode `Progress (fun _ => pure m!"progressWith") do /- Apply the theorem We try to match the theorem with the goal @@ -189,8 +203,8 @@ def progressWith (args : Args) (fExpr : Expr) (th : Expr) : TacticM (Result Prog We introduce the existentially quantified variables and split the top-most conjunction if there is one. We use the provided `ids` list to name the introduced variables. -/ - let splitExistsEqAndPost (next : Result (Array FVarId) MessageData → TacticM (Result ProgressWithOutput MessageData)) : - TacticM (Result ProgressWithOutput MessageData) := do + let splitExistsEqAndPost (next : Option MainGoal → TacticM Goals) : + TacticM Goals := do splitAllExistsTac thAsm args.ids.toList fun h ids => do /- Introduce the pretty equality if the user requests it. We take care of introducing it *before* splitting the post-conditions, so that those appear @@ -226,8 +240,8 @@ def progressWith (args : Args) (fExpr : Expr) (th : Expr) : TacticM (Result Prog For the conjunctions, we split according once to separate the equality `f ... = .ret ...` from the postcondition, if there is, then continue to split the postcondition if there are remaining ids. -/ - let splitEqAndPost (k : Expr → Option Expr → List (Option Name) → TacticM (Result ProgressWithOutput MessageData)) : - TacticM (Result ProgressWithOutput MessageData) := do + let splitEqAndPost (k : Expr → Option Expr → List (Option Name) → TacticM Goals) : + TacticM Goals := do let hTy ← inferType h if ← isConj hTy then let hName := (← h.fvarId!.getDecl).userName @@ -252,7 +266,7 @@ def progressWith (args : Args) (fExpr : Expr) (th : Expr) : TacticM (Result Prog TODO: not sure this is the best way of checking it -/ if r.isNone then trace[Progress] "The main goal was solved!" - next (Ok #[]) + next none else traceGoalWithNode "goal after applying the eq and simplifying the binds" -- TODO: remove this? (some types get unfolded too much: we "fold" them back) @@ -274,11 +288,11 @@ def progressWith (args : Args) (fExpr : Expr) (th : Expr) : TacticM (Result Prog if ¬ ids.isEmpty then logWarning m!"Too many ids provided ({ids}): there is no postcondition to split" trace[Progress] "No post to split" - next (Ok #[]) + next (some { goal := ← getMainGoal, posts := #[]}) | some hPost => do trace[Progress] "Post to split: {hPost}" let rec splitPostWithIds (prevId : Name) (hPosts : List FVarId) (hPost : Expr) (ids0 : List (Option Name)) : - TacticM (Result (Array FVarId) MessageData) := do + TacticM (Array FVarId) := do match ids0 with | [] => /- We used all the user provided ids. @@ -286,8 +300,8 @@ def progressWith (args : Args) (fExpr : Expr) (th : Expr) : TacticM (Result Prog instructed to fully split the post-condition, otherwise stop -/ if args.splitPost then splitFullConjTac true hPost (λ asms => do - pure (Ok (hPosts.reverse ++ (asms.map (fun x => x.fvarId!))).toArray)) - else pure (Ok (hPost.fvarId! :: hPosts).reverse.toArray) + pure (hPosts.reverse ++ (asms.map (fun x => x.fvarId!))).toArray) + else pure (hPost.fvarId! :: hPosts).reverse.toArray | nid :: ids => do trace[Progress] "Splitting post: {← inferType hPost}" -- Split @@ -300,92 +314,82 @@ def progressWith (args : Args) (fExpr : Expr) (th : Expr) : TacticM (Result Prog splitConjTac hPost (some (prevId, nid)) (λ nhAsm nhPost => splitPostWithIds nid (nhAsm.fvarId! :: hPosts) nhPost ids) else logWarning m!"Too many ids provided ({ids0}) not enough conjuncts to split in the postcondition" - pure (Ok (hPost.fvarId! :: hPosts).reverse.toArray) + pure (hPost.fvarId! :: hPosts).reverse.toArray let curPostId := (← hPost.fvarId!.getDecl).userName - let res ← splitPostWithIds curPostId [] hPost ids - next res - splitExistsEqAndPost fun res => do + let posts ← splitPostWithIds curPostId [] hPost ids + next (some ⟨ ← getMainGoal, posts⟩) + splitExistsEqAndPost fun mainGoal => do trace[Progress] "Finished splitting the post" - match res with - | Error msg => - trace[Progress] "" - return (Error msg) -- Can we get there? We're using `return` - | Ok hPosts => - /- Warning: the main goal may have been solved here, meaning it is unsafe to, - for instance, call `withMainContext`, or to print the `hPosts` - TODO: fix this. - -/ - -- Update the set of goals - let curGoals ← getUnsolvedGoals - trace[Progress] "current goals: {curGoals}" - let newGoals := mvars.map Expr.mvarId! - let newGoals ← newGoals.filterM fun mvar => not <$> mvar.isAssigned - trace[Progress] "new goals: {newGoals}" - -- Split between the goals which are propositions and the others - let (newPropGoals, newNonPropGoals) ← - newGoals.toList.partitionM fun mvar => do isProp (← mvar.getType) - trace[Progress] "Prop goals: {newPropGoals}" - trace[Progress] "Non prop goals: {← newNonPropGoals.mapM fun mvarId => do pure ((← mvarId.getDecl).userName, mvarId)}" - /- Try to solve the goals which are propositions - - We do this in several phases: - - we first use the "assumption" tactic to instantiate as many meta-variables as possible, and we do so by starting with the - preconditions with the highest number of meta-variables (this is a way of avoiding spurious instantiations) - - we then use the other tactic on the preconditions - -/ - let ordPropGoals ← - newPropGoals.mapM (fun g => do - let ty ← g.getType - pure ((← Utils.getMVarIds ty).size, g)) - let ordPropGoals := (ordPropGoals.mergeSort (fun (mvars0, _) (mvars1, _) => mvars0 ≤ mvars1)).reverse - setGoals (ordPropGoals.map Prod.snd) - allGoalsNoRecover (tryTac args.assumTac) - allGoalsNoRecover args.solvePreconditionTac - -- Make sure we use the original order when presenting the preconditions to the user - let newPropGoals ← newPropGoals.filterMapM (fun g => do if ← g.isAssigned then pure none else pure (some g)) - /- Simplify the post-conditions in the main goal - note that we waited until now - because by solving the preconditions we may have instantiated meta-variables. - We also simplify the goal again (to simplify let-bindings, etc.) -/ - setGoals curGoals - let hPosts ← - match curGoals with - | [] => pure (hPosts) - | [ _ ] => - -- Simplify the post-conditions - let args : Simp.SimpArgs := - {simpThms := #[← progressPostSimpExt.getTheorems], - simprocs := #[← ScalarTac.scalarTacSimprocExt.getSimprocs]} - let hPosts ← Simp.simpAt true { maxDischargeDepth := 0, failIfUnchanged := false } - args (.targets hPosts false) - match hPosts with - | none => - -- We actually closed the goal: we shouldn't get there - -- TODO: make this more robust - trace[Progress] "Unexpected: goal closed by simplifying the introduced post-conditions" - throwError "Unexpected: goal closed by simplifying the introduced post-conditions" - | some hPosts => - -- Simplify the goal again - tryTac do - let _ ← Simp.simpAt true { maxDischargeDepth := 1, failIfUnchanged := false} - {simpThms := #[← progressSimpExt.getTheorems], declsToUnfold := #[``pure]} (.targets #[] true) - -- - pure (hPosts) - | _ => throwError "Unexpected number of goals" - let curGoals ← getUnsolvedGoals - trace[Progress] "Main goal after simplifying the post-conditions and the target: {curGoals}" - /- Update the list of goals -/ - let newNonPropGoals ← newNonPropGoals.filterM fun mvar => not <$> mvar.isAssigned - let newGoals := newNonPropGoals ++ newPropGoals - trace[Progress] "Final remaining preconditions: {newGoals}" - setGoals (newGoals ++ curGoals) - trace[Progress] "progress: replaced the goals" - -- - let mainGoal ← do - match curGoals with - | [] => pure none - | [ g ] => pure (some g) - | _ => throwError "Unexpected number of goals" - pure (Ok ⟨ ⟨ newGoals.toArray, mainGoal ⟩, hPosts ⟩) + /- Warning: the main goal may have been solved here, meaning it is unsafe to, + for instance, call `withMainContext`, or to print the `hPosts` + TODO: fix this. + -/ + -- Update the set of goals + let curGoals ← getUnsolvedGoals + trace[Progress] "current goals: {curGoals}" + let newGoals := mvars.map Expr.mvarId! + let newGoals ← newGoals.filterM fun mvar => not <$> mvar.isAssigned + trace[Progress] "new goals: {newGoals}" + -- Split between the goals which are propositions and the others + let (newPropGoals, newNonPropGoals) ← + newGoals.toList.partitionM fun mvar => do isProp (← mvar.getType) + trace[Progress] "Prop goals: {newPropGoals}" + trace[Progress] "Non prop goals: {← newNonPropGoals.mapM fun mvarId => do pure ((← mvarId.getDecl).userName, mvarId)}" + /- Try to solve the goals which are propositions + + We do this in several phases: + - we first use the "assumption" tactic to instantiate as many meta-variables as possible, and we do so by starting with the + preconditions with the highest number of meta-variables (this is a way of avoiding spurious instantiations) + - we then use the other tactic on the preconditions + -/ + let ordPropGoals ← + newPropGoals.mapM (fun g => do + let ty ← g.getType + pure ((← Utils.getMVarIds ty).size, g)) + let ordPropGoals := (ordPropGoals.mergeSort (fun (mvars0, _) (mvars1, _) => mvars0 ≤ mvars1)).reverse + setGoals (ordPropGoals.map Prod.snd) + allGoalsNoRecover (tryTac args.assumTac) + allGoalsNoRecover args.solvePreconditionTac + -- Make sure we use the original order when presenting the preconditions to the user + let newPropGoals ← newPropGoals.filterMapM (fun g => do if ← g.isAssigned then pure none else pure (some g)) + /- Simplify the post-conditions in the main goal - note that we waited until now + because by solving the preconditions we may have instantiated meta-variables. + We also simplify the goal again (to simplify let-bindings, etc.) -/ + let mainGoal : Option MainGoal ← do + match mainGoal with + | none => pure none + | some mainGoal => + setGoals [mainGoal.goal] + -- Simplify the post-conditions + let args : Simp.SimpArgs := + {simpThms := #[← progressPostSimpExt.getTheorems], + simprocs := #[← ScalarTac.scalarTacSimprocExt.getSimprocs]} + let posts ← Simp.simpAt true { maxDischargeDepth := 0, failIfUnchanged := false } + args (.targets mainGoal.posts false) + match posts with + | none => + -- We actually closed the goal: we shouldn't get there + -- TODO: make this more robust + trace[Progress] "Goal closed by simplifying the introduced post-conditions" + pure none + | some posts => + -- Simplify the goal again + let r ← Simp.simpAt true { maxDischargeDepth := 1, failIfUnchanged := false} + {simpThms := #[← progressSimpExt.getTheorems], declsToUnfold := #[``pure]} (.targets #[] true) + if r.isSome then + pure (some ({ goal := ← getMainGoal, posts} : MainGoal)) + else pure none + trace[Progress] "Main goal after simplifying the post-conditions and the target: {curGoals}" + /- Update the list of goals - TODO: move this elsewhere -/ + let newGoals ← (newNonPropGoals ++ newPropGoals).filterM fun mvar => not <$> mvar.isAssigned + trace[Progress] "Final remaining preconditions: {newGoals}" + let curGoal := + match mainGoal with + | none => [] + | some goal => [goal.goal] + setGoals (newGoals ++ curGoal) + trace[Progress] "replaced the goals" + pure ({ preconditions := newGoals.toArray, mainGoal }) /-- Small utility: if `args` is not empty, return the name of the app in the first arg, if it is a const. -/ @@ -403,7 +407,8 @@ def getFirstArg (args : Array Expr) : Option Expr := do /-- Helper: try to apply a theorem. Return the list of post-conditions we introduced if it succeeded. -/ -def tryApply (args : Args) (fExpr : Expr) (kind : String) (th : Option Expr) : TacticM (Option ProgressWithOutput) := do +def tryApply (args : Args) (fExpr : Expr) (kind : String) (th : Option Expr) : + TacticM (Option Goals) := do let res ← do match th with | none => @@ -418,14 +423,14 @@ def tryApply (args : Args) (fExpr : Expr) (kind : String) (th : Option Expr) : T pure (some res) catch _ => pure none match res with - | some (.Ok res) => pure (some res) - | some (.Error msg) => throwError msg + | some res => pure (some res) | none => pure none /-- Try to progress with an assumption. Return `some` if we succeed, `none` otherwise. -/ -def tryAssumptions (args : Args) (fExpr : Expr) : TacticM (Option (ProgressGoals × UsedTheorem)) := do +def tryAssumptions (args : Args) (fExpr : Expr) : + TacticM (Option (Goals × UsedTheorem)) := do withTraceNode `Progress (fun _ => pure m!"tryAssumptions") do run where run := @@ -434,15 +439,14 @@ where let decls ← ctx.getAssumptions for decl in decls.reverse do trace[Progress] "Trying assumption: {decl.userName} : {decl.type}" - let res ← do try progressWith args fExpr decl.toExpr catch _ => continue - match res with - | .Ok res => return (some (res.toProgressGoals, .localHyp decl)) - | .Error msg => throwError msg + try + let goal ← progressWith args fExpr decl.toExpr + return (some (goal, .localHyp decl)) + catch _ => continue pure none --- The array of ids are identifiers to use when introducing fresh variables def progressAsmsOrLookupTheorem (args : Args) (withTh : Option Expr) : - TacticM (ProgressGoals × UsedTheorem) := do + TacticM (Goals × UsedTheorem) := do withMainContext do -- Retrieve the goal let mgoal ← Tactic.getMainGoal @@ -469,11 +473,8 @@ def progressAsmsOrLookupTheorem (args : Args) (withTh : Option Expr) : -- Otherwise, lookup one. match withTh with | some th => do - match ← progressWith args fExpr th with - | .Ok res => - -- Remark: exprToSyntax doesn't give the expected result - return (res.toProgressGoals, .givenExpr th) - | .Error msg => throwError msg + let goals ← progressWith args fExpr th + return (goals, .givenExpr th) | none => -- Try all the assumptions one by one and if it fails try to lookup a theorem. if let some res ← tryAssumptions args fExpr then return res @@ -500,7 +501,7 @@ def progressAsmsOrLookupTheorem (args : Args) (withTh : Option Expr) : for pspec in pspecs do let pspecExpr ← Term.mkConst pspec match ← tryApply args fExpr "pspec theorem" pspecExpr with - | some res => return (res.toProgressGoals, .progressThm pspec) + | some goals => return (goals, .progressThm pspec) | none => pure () -- It failed: try to use the recursive assumptions trace[Progress] "Failed using a pspec theorem: trying to use a recursive assumption" @@ -509,12 +510,13 @@ def progressAsmsOrLookupTheorem (args : Args) (withTh : Option Expr) : let decls ← ctx.getAllDecls let decls := decls.filter (λ decl => match decl.kind with | .default | .implDetail => false | .auxDecl => true) + -- TODO: introduce a helper for this for decl in decls.reverse do trace[Progress] "Trying recursive assumption: {decl.userName} : {decl.type}" - let res ← do try progressWith args fExpr decl.toExpr catch _ => continue - match res with - | .Ok res => return (res.toProgressGoals, .localHyp decl) - | .Error msg => throwError msg + try + let goals ← progressWith args fExpr decl.toExpr + return (goals, .localHyp decl) + catch _ => continue -- Nothing worked: failed let msg := "Progress failed: could not find a local assumption or a theorem to apply" trace[Progress] msg diff --git a/backends/lean/Aeneas/Progress/ProgressStar.lean b/backends/lean/Aeneas/Progress/ProgressStar.lean index 482c33296..a6ddde4f3 100644 --- a/backends/lean/Aeneas/Progress/ProgressStar.lean +++ b/backends/lean/Aeneas/Progress/ProgressStar.lean @@ -324,7 +324,11 @@ where withTraceNode `Progress (fun _ => pure m!"onBind ({varName})") do if let some {usedTheorem, preconditions, mainGoal } ← tryProgress then withTraceNode `Progress (fun _ => pure m!"progress succeeded") do - trace[Progress] "New goal: {mainGoal}" + match mainGoal with + | none => trace[Progress] "Main goal solved" + | some goal => + withTraceNode `Progress (fun _ => pure m!"New main goal:") do + trace[Progress] "{goal.goal}" trace[Progress] "Unsolved preconditions: {preconditions}" let (preconditionTacs, unsolved) ← handleProgressPreconditions preconditions if ¬ preconditionTacs.isEmpty then @@ -335,8 +339,9 @@ where let ids ← getIdsFromUsedTheorem varName.eraseMacroScopes usedTheorem trace[Progress] "ids from used theorem: {ids}" let mainGoal ← do mainGoal.mapM fun mainGoal => do - if ¬ ids.isEmpty then renameInaccessibles mainGoal ids -- NOTE: Taken from renameI tactic - else pure mainGoal + if ¬ ids.isEmpty then + renameInaccessibles mainGoal.goal ids -- NOTE: Taken from renameI tactic + else pure mainGoal.goal /- Generate the tactic scripts for the preconditions -/ let currTac ← if cfg.prettyPrintedProgress then From 9cf90dec6ceec78688c7165516c0db260ded7c7d Mon Sep 17 00:00:00 2001 From: Son Ho Date: Wed, 25 Jun 2025 10:33:23 +0100 Subject: [PATCH 24/31] Cleanup Progress further --- backends/lean/Aeneas/Progress/Progress.lean | 399 ++++++++++---------- 1 file changed, 205 insertions(+), 194 deletions(-) diff --git a/backends/lean/Aeneas/Progress/Progress.lean b/backends/lean/Aeneas/Progress/Progress.lean index 1cef37b14..8ddc4883f 100644 --- a/backends/lean/Aeneas/Progress/Progress.lean +++ b/backends/lean/Aeneas/Progress/Progress.lean @@ -99,14 +99,6 @@ def getType: UsedTheorem -> MetaM (Option Expr) return none end UsedTheorem -/- Type to propagate the errors of `progressWith`. - We need this because we use the exceptions to backtrack, when trying to - use the assumptions for instance. When there is actually an error we want - to propagate to the user, we return it. -/ -inductive Result (T U : Type) where -| Ok : T → Result T U -| Error : U → Result T U -deriving Inhabited structure MainGoal where goal : MVarId @@ -147,19 +139,20 @@ structure Args where /- Tactic to use to solve the preconditions -/ solvePreconditionTac : TacticM Unit -open Result in -def progressWith (args : Args) (fExpr : Expr) (th : Expr) : - TacticM Goals := do - withTraceNode `Progress (fun _ => pure m!"progressWith") do +/-- Attempt to match a given theorem with the monadic call in the goal, + and introduce the instantiated theorem in the context if it succeeds. + + If the instantiation succeeds, we return the introduced meta-variables as + well as the new assumption corresponding to the instantiated theorem (the + expression is an fvar). We raise an exception otherwise. + -/ +def tryMatch (args : Args) (fExpr : Expr) (th : Expr) : + TacticM (Array MVarId × Expr) := do + withTraceNode `Progress (fun _ => pure m!"tryMatch") do /- Apply the theorem We try to match the theorem with the goal - In order to do so, we introduce meta-variables for all the parameters - (i.e., quantified variables and assumpions), and unify those with the goal. - Remark: we do not introduce meta-variables for the quantified variables - which don't appear in the function arguments (we want to let them - quantified). - We also make sure that all the meta variables which appear in the - function arguments have been instantiated + In order to match the theorem with the goal, we introduce meta-variables for all + the parameters (i.e., quantified variables and assumpions), and unify those with the goal. -/ /- There might be meta-variables in the type if the theorem comes from a local declaration, especially if this declaration was introduced by a tactic -/ @@ -194,154 +187,142 @@ def progressWith (args : Args) (fExpr : Expr) (th : Expr) : -- TODO: actually we might want to let the user insert them in the context let thTy ← normalizeLetBindings thTy trace[Progress] "thTy (after normalizing let-bindings): {thTy}" - Utils.addDeclTac asmName th thTy (asLet := false) fun thAsm => do - let ngoal ← getMainGoal - traceGoalWithNode "current goal" - trace[Progress] "current goal is assigned: {← ngoal.isAssigned}" - /- The assumption should be of the shape: - `∃ x1 ... xn, f args = ... ∧ ...` - We introduce the existentially quantified variables and split the top-most - conjunction if there is one. We use the provided `ids` list to name the - introduced variables. -/ - let splitExistsEqAndPost (next : Option MainGoal → TacticM Goals) : - TacticM Goals := do - splitAllExistsTac thAsm args.ids.toList fun h ids => do - /- Introduce the pretty equality if the user requests it. - We take care of introducing it *before* splitting the post-conditions, so that those appear - after it. - -/ - match args.keepPretty with - | none => pure () - | some name => - trace[Progress] "About to introduce the pretty equality" - let hTy ← inferType h - trace[Progress] "introPrettyEq: h: {hTy}" - let h ← do - if ← isConj hTy then do - mkAppM ``And.left #[h] - else do pure h - /- Do *not* introduce an equality if the return type is `()` -/ - let hTy ← inferType h - hTy.withApp fun _ args => do -- Deconstruct the equality - trace[Progress] "Checking if type is (): after deconstructing the equality: {args}" - args[0]!.withApp fun _ args => do -- Deconstruct the `Result` - trace[Progress] "Checking if type is (): after deconstructing Result: {args}" - let arg0 := args[0]! - if arg0.isConst ∧ (arg0.constName == ``Unit ∨ arg0.constName == ``PUnit) then - trace[Progress] "Not introducing a pretty equality because the output type is `()`" - else - trace[Progress] "h: {← inferType h}" - trace[Progress] "Introducing the \"pretty\" let binding" - let e ← mkAppM ``eq_imp_prettyMonadEq #[h] - Utils.addDeclTac name e (← inferType e) (asLet := false) fun _ => do - traceGoalWithNode "Introduced the \"pretty\" let binding" - - /- Split the conjunctions. - For the conjunctions, we split according once to separate the equality `f ... = .ret ...` - from the postcondition, if there is, then continue to split the postcondition if there - are remaining ids. -/ - let splitEqAndPost (k : Expr → Option Expr → List (Option Name) → TacticM Goals) : - TacticM Goals := do - let hTy ← inferType h - if ← isConj hTy then - let hName := (← h.fvarId!.getDecl).userName - let (optIds, ids) ← do - match ids with - | [] => do pure (some (hName, ← mkFreshAnonPropUserName), []) - | none :: ids => do pure (some (hName, ← mkFreshAnonPropUserName), ids) - | some id :: ids => do pure (some (hName, id), ids) - splitConjTac h optIds (fun hEq hPost => k hEq (some hPost) ids) - else - k h none ids - /- Simplify the target by using the equality and some monad simplifications, - then continue splitting the post-condition -/ - -- TODO: this is dangerous if we want to use a local assumption to make progress. - -- We shouldn't simplify the goal with the equality, then simplify again. - splitEqAndPost fun hEq hPost ids => do - trace[Progress] "eq and post:\n{hEq} : {← inferType hEq}\n{hPost}" - traceGoalWithNode "current goal" - let r ← Simp.simpAt true { maxDischargeDepth := 1, failIfUnchanged := false} - {simpThms := #[← progressSimpExt.getTheorems], hypsToUse := #[hEq.fvarId!]} (.targets #[] true) - /- It may happen that at this point the goal is already solved (though this is rare) - TODO: not sure this is the best way of checking it -/ - if r.isNone then - trace[Progress] "The main goal was solved!" - next none - else - traceGoalWithNode "goal after applying the eq and simplifying the binds" - -- TODO: remove this? (some types get unfolded too much: we "fold" them back) - tryTac (do let _ ← Simp.simpAt true {} {addSimpThms := scalar_eqs} .wildcard_dep) - traceGoalWithNode "goal after folding back scalar types" - -- Clear the equality, unless the user requests not to do so - if args.keep.isSome then pure () - else do - let mgoal ← getMainGoal - let mgoal ← mgoal.tryClear hEq.fvarId! - setGoals [mgoal] - withMainContext do - withTraceNode `Progress (fun _ => pure m!"Unsolved goals") do trace[Progress] "Unsolved goals: {← getUnsolvedGoals}" - traceGoalWithNode "Goal after clearing the equality" - -- Continue splitting following the post following the user's instructions - match hPost with - | none => - -- Sanity check - if ¬ ids.isEmpty then - logWarning m!"Too many ids provided ({ids}): there is no postcondition to split" - trace[Progress] "No post to split" - next (some { goal := ← getMainGoal, posts := #[]}) - | some hPost => do - trace[Progress] "Post to split: {hPost}" - let rec splitPostWithIds (prevId : Name) (hPosts : List FVarId) (hPost : Expr) (ids0 : List (Option Name)) : - TacticM (Array FVarId) := do - match ids0 with - | [] => - /- We used all the user provided ids. - Split the remaining conjunctions by using fresh ids if the user - instructed to fully split the post-condition, otherwise stop -/ - if args.splitPost then - splitFullConjTac true hPost (λ asms => do - pure (hPosts.reverse ++ (asms.map (fun x => x.fvarId!))).toArray) - else pure (hPost.fvarId! :: hPosts).reverse.toArray - | nid :: ids => do - trace[Progress] "Splitting post: {← inferType hPost}" - -- Split - let nid ← do - match nid with - | none => mkFreshAnonPropUserName - | some nid => pure nid - trace[Progress] "\n- prevId: {prevId}\n- nid: {nid}\n- remaining ids: {ids}" - if ← isConj (← inferType hPost) then - splitConjTac hPost (some (prevId, nid)) (λ nhAsm nhPost => splitPostWithIds nid (nhAsm.fvarId! :: hPosts) nhPost ids) - else - logWarning m!"Too many ids provided ({ids0}) not enough conjuncts to split in the postcondition" - pure (hPost.fvarId! :: hPosts).reverse.toArray - let curPostId := (← hPost.fvarId!.getDecl).userName - let posts ← splitPostWithIds curPostId [] hPost ids - next (some ⟨ ← getMainGoal, posts⟩) - splitExistsEqAndPost fun mainGoal => do - trace[Progress] "Finished splitting the post" - /- Warning: the main goal may have been solved here, meaning it is unsafe to, - for instance, call `withMainContext`, or to print the `hPosts` - TODO: fix this. - -/ - -- Update the set of goals - let curGoals ← getUnsolvedGoals - trace[Progress] "current goals: {curGoals}" - let newGoals := mvars.map Expr.mvarId! - let newGoals ← newGoals.filterM fun mvar => not <$> mvar.isAssigned - trace[Progress] "new goals: {newGoals}" - -- Split between the goals which are propositions and the others - let (newPropGoals, newNonPropGoals) ← - newGoals.toList.partitionM fun mvar => do isProp (← mvar.getType) - trace[Progress] "Prop goals: {newPropGoals}" - trace[Progress] "Non prop goals: {← newNonPropGoals.mapM fun mvarId => do pure ((← mvarId.getDecl).userName, mvarId)}" - /- Try to solve the goals which are propositions - - We do this in several phases: - - we first use the "assumption" tactic to instantiate as many meta-variables as possible, and we do so by starting with the - preconditions with the highest number of meta-variables (this is a way of avoiding spurious instantiations) - - we then use the other tactic on the preconditions + Utils.addDeclTac asmName th thTy (asLet := false) fun thAsm => pure (mvars.map Expr.mvarId!, thAsm) + +/-- Under the condition that `thAsm` is of the shape: + `∃ x1 ... xn, f args = ... ∧ ...` + + introduce the existentially quantified variables and split the top-most + conjunction if there is one. We use the provided `ids` list to name the + introduced variables. +-/ +def splitExistsEqAndPost (args : Args) (thAsm : Expr) : + TacticM (Option MainGoal) := do + withTraceNode `Progress (fun _ => pure m!"splitExistsEqAndPost") do + splitAllExistsTac thAsm args.ids.toList fun h ids => do + /- Introduce the pretty equality if the user requests it. + We take care of introducing it *before* splitting the post-conditions, so that those appear + after it. -/ + match args.keepPretty with + | none => pure () + | some name => + trace[Progress] "About to introduce the pretty equality" + let hTy ← inferType h + trace[Progress] "introPrettyEq: h: {hTy}" + let h ← do + if ← isConj hTy then do + mkAppM ``And.left #[h] + else do pure h + /- Do *not* introduce an equality if the return type is `()` -/ + let hTy ← inferType h + hTy.withApp fun _ args => do -- Deconstruct the equality + trace[Progress] "Checking if type is (): after deconstructing the equality: {args}" + args[0]!.withApp fun _ args => do -- Deconstruct the `Result` + trace[Progress] "Checking if type is (): after deconstructing Result: {args}" + let arg0 := args[0]! + if arg0.isConst ∧ (arg0.constName == ``Unit ∨ arg0.constName == ``PUnit) then + trace[Progress] "Not introducing a pretty equality because the output type is `()`" + else + trace[Progress] "h: {← inferType h}" + trace[Progress] "Introducing the \"pretty\" let binding" + let e ← mkAppM ``eq_imp_prettyMonadEq #[h] + Utils.addDeclTac name e (← inferType e) (asLet := false) fun _ => do + traceGoalWithNode "Introduced the \"pretty\" let binding" + + /- Split the conjunctions. + For the conjunctions, we split according once to separate the equality `f ... = .ret ...` + from the postcondition, if there is, then continue to split the postcondition if there + are remaining ids. -/ + let splitEqAndPost (k : Expr → Option Expr → List (Option Name) → TacticM (Option MainGoal)) : + TacticM (Option MainGoal) := do + let hTy ← inferType h + if ← isConj hTy then + let hName := (← h.fvarId!.getDecl).userName + let (optIds, ids) ← do + match ids with + | [] => do pure (some (hName, ← mkFreshAnonPropUserName), []) + | none :: ids => do pure (some (hName, ← mkFreshAnonPropUserName), ids) + | some id :: ids => do pure (some (hName, id), ids) + splitConjTac h optIds (fun hEq hPost => k hEq (some hPost) ids) + else + k h none ids + /- Simplify the target by using the equality and some monad simplifications, + then continue splitting the post-condition -/ + -- TODO: this is dangerous if we want to use a local assumption to make progress. + -- We shouldn't simplify the goal with the equality, then simplify again. + splitEqAndPost fun hEq hPost ids => do + trace[Progress] "eq and post:\n{hEq} : {← inferType hEq}\n{hPost}" + traceGoalWithNode "current goal" + let r ← Simp.simpAt true { maxDischargeDepth := 1, failIfUnchanged := false} + {simpThms := #[← progressSimpExt.getTheorems], hypsToUse := #[hEq.fvarId!]} (.targets #[] true) + /- It may happen that at this point the goal is already solved (though this is rare) + TODO: not sure this is the best way of checking it -/ + if r.isNone then + trace[Progress] "The main goal was solved!" + pure none + else + traceGoalWithNode "goal after applying the eq and simplifying the binds" + -- TODO: remove this? (some types get unfolded too much: we "fold" them back) + tryTac (do let _ ← Simp.simpAt true {} {addSimpThms := scalar_eqs} .wildcard_dep) + traceGoalWithNode "goal after folding back scalar types" + -- Clear the equality, unless the user requests not to do so + if args.keep.isSome then pure () + else do + let mgoal ← getMainGoal + let mgoal ← mgoal.tryClear hEq.fvarId! + setGoals [mgoal] + withMainContext do + withTraceNode `Progress (fun _ => pure m!"Unsolved goals") do trace[Progress] "Unsolved goals: {← getUnsolvedGoals}" + traceGoalWithNode "Goal after clearing the equality" + -- Continue splitting following the post following the user's instructions + match hPost with + | none => + -- Sanity check + if ¬ ids.isEmpty then + logWarning m!"Too many ids provided ({ids}): there is no postcondition to split" + trace[Progress] "No post to split" + pure (some { goal := ← getMainGoal, posts := #[]}) + | some hPost => do + trace[Progress] "Post to split: {hPost}" + let rec splitPostWithIds (prevId : Name) (hPosts : List FVarId) (hPost : Expr) (ids0 : List (Option Name)) : + TacticM (Array FVarId) := do + match ids0 with + | [] => + /- We used all the user provided ids. + Split the remaining conjunctions by using fresh ids if the user + instructed to fully split the post-condition, otherwise stop -/ + if args.splitPost then + splitFullConjTac true hPost (λ asms => do + pure (hPosts.reverse ++ (asms.map (fun x => x.fvarId!))).toArray) + else pure (hPost.fvarId! :: hPosts).reverse.toArray + | nid :: ids => do + trace[Progress] "Splitting post: {← inferType hPost}" + -- Split + let nid ← do + match nid with + | none => mkFreshAnonPropUserName + | some nid => pure nid + trace[Progress] "\n- prevId: {prevId}\n- nid: {nid}\n- remaining ids: {ids}" + if ← isConj (← inferType hPost) then + splitConjTac hPost (some (prevId, nid)) (λ nhAsm nhPost => splitPostWithIds nid (nhAsm.fvarId! :: hPosts) nhPost ids) + else + logWarning m!"Too many ids provided ({ids0}) not enough conjuncts to split in the postcondition" + pure (hPost.fvarId! :: hPosts).reverse.toArray + let curPostId := (← hPost.fvarId!.getDecl).userName + let posts ← splitPostWithIds curPostId [] hPost ids + pure (some ⟨ ← getMainGoal, posts⟩) + +/-- Attempt to solve the preconditions. + + We do this in several phases: + - we first use the "assumption" tactic to instantiate as many meta-variables as possible, + and we do so by starting with the preconditions with the highest number of meta-variables + (this is a way of avoiding spurious instantiations). This helps with the second phase. + - we then use the other tactic on the preconditions + -/ +def trySolvePreconditions (args : Args) (newPropGoals : List MVarId) : TacticM (List MVarId) := do + withTraceNode `Progress (fun _ => pure m!"trySolvePreconditions") do let ordPropGoals ← newPropGoals.mapM (fun g => do let ty ← g.getType @@ -350,39 +331,69 @@ def progressWith (args : Args) (fExpr : Expr) (th : Expr) : setGoals (ordPropGoals.map Prod.snd) allGoalsNoRecover (tryTac args.assumTac) allGoalsNoRecover args.solvePreconditionTac - -- Make sure we use the original order when presenting the preconditions to the user - let newPropGoals ← newPropGoals.filterMapM (fun g => do if ← g.isAssigned then pure none else pure (some g)) + -- Make sure we preserve the order when presenting the preconditions to the user + newPropGoals.filterMapM (fun g => do if ← g.isAssigned then pure none else pure (some g)) + +/-- Post-process the main goal. + + The main thing we do is simplify the post-conditions. -/ +def postprocessMainGoal (mainGoal : Option MainGoal) : TacticM (Option MainGoal) := do + withTraceNode `Progress (fun _ => pure m!"postprocessMainGoal") do + match mainGoal with + | none => pure none + | some mainGoal => + setGoals [mainGoal.goal] + -- Simplify the post-conditions + let args : Simp.SimpArgs := + {simpThms := #[← progressPostSimpExt.getTheorems], + simprocs := #[← ScalarTac.scalarTacSimprocExt.getSimprocs]} + let posts ← Simp.simpAt true { maxDischargeDepth := 0, failIfUnchanged := false } + args (.targets mainGoal.posts false) + match posts with + | none => + -- We actually closed the goal: we shouldn't get there + -- TODO: make this more robust + trace[Progress] "Goal closed by simplifying the introduced post-conditions" + pure none + | some posts => + -- Simplify the goal again + let r ← Simp.simpAt true { maxDischargeDepth := 1, failIfUnchanged := false} + {simpThms := #[← progressSimpExt.getTheorems], declsToUnfold := #[``pure]} (.targets #[] true) + if r.isSome then + pure (some ({ goal := ← getMainGoal, posts} : MainGoal)) + else pure none + +def progressWith (args : Args) (fExpr : Expr) (th : Expr) : + TacticM Goals := do + withTraceNode `Progress (fun _ => pure m!"progressWith") do + -- Attempt to instantiate the theorem and introduce it in the context + let (newGoals, thAsm) ← tryMatch args fExpr th + withMainContext do + traceGoalWithNode "current goal" + -- Destruct the existential quantifiers and split the conjunctions + let mainGoal ← splitExistsEqAndPost args thAsm + -- Split between the goals which are propositions and the others + let newGoals ← newGoals.filterM fun mvar => not <$> mvar.isAssigned + withTraceNode `Progress (fun _ => pure m!"new goals") do trace[Progress] "{newGoals}" + let (newPropGoals, newNonPropGoals) ← + newGoals.toList.partitionM fun mvar => do isProp (← mvar.getType) + withTraceNode `Progress (fun _ => pure m!"prop goals") do trace[Progress] "{newPropGoals}" + withTraceNode `Progress (fun _ => pure m!"non prop goals") do + trace[Progress] "{← newNonPropGoals.mapM fun mvarId => do pure ((← mvarId.getDecl).userName, mvarId)}" + -- Attempt to solve the goals which are propositions + let newPropGoals ← trySolvePreconditions args newPropGoals /- Simplify the post-conditions in the main goal - note that we waited until now because by solving the preconditions we may have instantiated meta-variables. We also simplify the goal again (to simplify let-bindings, etc.) -/ - let mainGoal : Option MainGoal ← do - match mainGoal with - | none => pure none - | some mainGoal => - setGoals [mainGoal.goal] - -- Simplify the post-conditions - let args : Simp.SimpArgs := - {simpThms := #[← progressPostSimpExt.getTheorems], - simprocs := #[← ScalarTac.scalarTacSimprocExt.getSimprocs]} - let posts ← Simp.simpAt true { maxDischargeDepth := 0, failIfUnchanged := false } - args (.targets mainGoal.posts false) - match posts with - | none => - -- We actually closed the goal: we shouldn't get there - -- TODO: make this more robust - trace[Progress] "Goal closed by simplifying the introduced post-conditions" - pure none - | some posts => - -- Simplify the goal again - let r ← Simp.simpAt true { maxDischargeDepth := 1, failIfUnchanged := false} - {simpThms := #[← progressSimpExt.getTheorems], declsToUnfold := #[``pure]} (.targets #[] true) - if r.isSome then - pure (some ({ goal := ← getMainGoal, posts} : MainGoal)) - else pure none - trace[Progress] "Main goal after simplifying the post-conditions and the target: {curGoals}" + let mainGoal ← postprocessMainGoal mainGoal + if let some mainGoal := mainGoal then + withTraceNode `Progress + (fun _ => pure m!"Main goal after simplifying the post-conditions and the target") do + trace[Progress] "{mainGoal.goal}" /- Update the list of goals - TODO: move this elsewhere -/ let newGoals ← (newNonPropGoals ++ newPropGoals).filterM fun mvar => not <$> mvar.isAssigned - trace[Progress] "Final remaining preconditions: {newGoals}" + withTraceNode `Progress (fun _ => pure m!"Final remaining preconditions") do + trace[Progress] "{newGoals}" let curGoal := match mainGoal with | none => [] From ffdb32fb807df3f81b603c608c342bdf4563df48 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Wed, 25 Jun 2025 11:34:58 +0100 Subject: [PATCH 25/31] Improve the tracing in saturate, scalarTac and condSimpTac --- backends/lean/Aeneas/Progress/Progress.lean | 3 +- backends/lean/Aeneas/Saturate/Tactic.lean | 13 +- .../lean/Aeneas/ScalarTac/CondSimpTac.lean | 121 ++++++++++++------ backends/lean/Aeneas/ScalarTac/Init.lean | 2 - backends/lean/Aeneas/ScalarTac/ScalarTac.lean | 24 ++-- backends/lean/Aeneas/Utils.lean | 6 + 6 files changed, 117 insertions(+), 52 deletions(-) diff --git a/backends/lean/Aeneas/Progress/Progress.lean b/backends/lean/Aeneas/Progress/Progress.lean index 8ddc4883f..968ed0d6b 100644 --- a/backends/lean/Aeneas/Progress/Progress.lean +++ b/backends/lean/Aeneas/Progress/Progress.lean @@ -23,8 +23,7 @@ example (x y z : Std.U32) (_ : [> let z ← (x + y) <]) : True := by simp theorem eq_imp_prettyMonadEq {α : Type u} {x : Std.Result α} {y : α} (h : x = .ok y) : prettyMonadEq x y := by simp [prettyMonadEq, h] -def traceGoalWithNode (msg : String) : TacticM Unit := do - withTraceNode `Progress (fun _ => do pure msg) do trace[Progress] "{← getMainGoal}" +def traceGoalWithNode (msg : String) : TacticM Unit := Utils.traceGoalWithNode `Progress msg -- TODO: the scalar types annoyingly often get reduced when we use the progress -- tactic. We should find a way of controling reduction. For now we use rewriting diff --git a/backends/lean/Aeneas/Saturate/Tactic.lean b/backends/lean/Aeneas/Saturate/Tactic.lean index 4bdb6a334..d73da2c24 100644 --- a/backends/lean/Aeneas/Saturate/Tactic.lean +++ b/backends/lean/Aeneas/Saturate/Tactic.lean @@ -151,6 +151,7 @@ def mkExprFromPath (path : AsmPath) : MetaM Expr := do def State.insertPartialMatch (state : State) (boundVars : Std.HashSet FVarId) (pmatch : PartialMatch) : MetaM State := do + withTraceNode `Saturate (fun _ => pure m!"insertPartialMatch") do trace[Saturate.insertPartialMatch] "insertPartialMatch: {pmatch}" let mut state := state /- Check if there remains patterns: if no then the match is total -/ @@ -204,6 +205,7 @@ def State.insertPartialMatch (state : State) def checkIfPatDefEq (preprocessThm : Option (Array Expr → Expr → MetaM Unit)) (pat : Expr) (numBinders : Nat) (e : Expr) : MetaM (Option (Array Expr)) := do + withTraceNode `Saturate (fun _ => pure m!"checkIfPatDefEq") do -- Strip the binders, introduce meta-variables at the same time, and match let (mvars, _, pat) ← lambdaMetaTelescope pat (some numBinders) match preprocessThm with | none => pure () | some preprocessThm => preprocessThm mvars pat @@ -237,6 +239,7 @@ def matchExprWithRules (boundVars : Std.HashSet FVarId) (state : State) (e : Expr) : MetaM State := do + withTraceNode `Saturate (fun _ => pure m!"matchExprWithRules") do let mut state := state for rules in state.rules do let exprs ← rules.rules.getMatch e @@ -298,6 +301,7 @@ def matchExprWithPartialMatches (boundVars : Std.HashSet FVarId) (state : State) (e : Expr) : MetaM State := do + withTraceNode `Saturate (fun _ => pure m!"matchExprWithPartialMatches") do let mut state := state let exprs ← state.pmatches.getMatch e trace[Saturate.explore] "Potential matches: {exprs}" @@ -351,6 +355,7 @@ def matchExpr (boundVars : Std.HashSet FVarId) (state : State) (e : Expr) : MetaM State := do + withTraceNode `Saturate (fun _ => pure m!"matchExpr") do trace[Saturate.explore] "Matching: {e}" /- First check if the expression contains bound vars: if it does, then we don't match it -/ if ¬ boundVars.isEmpty then @@ -382,6 +387,7 @@ private partial def visit (boundVars : Std.HashSet FVarId) (state : State) (e : Expr) : MetaM State := do + withTraceNode `Saturate (fun _ => pure m!"visit") do let e := e.consumeMData -- trace[Saturate.explore] "Visiting {e}" @@ -552,6 +558,7 @@ partial def evalSaturateCore (exploreTarget : Bool := true) : TacticM (State × Array FVarId) := do + withTraceNode `Saturate (fun _ => pure m!"evalSaturateCore") do withMainContext do trace[Saturate] "Exploring goal: {← getMainGoal}" -- Explore @@ -598,6 +605,7 @@ partial def evalSaturateCore trace[Saturate] "Finished exploring the goal. Matched:\n{state.matched.toList}" let addAssumptions (state : State) (allFVars : Array FVarId) : TacticM (Array FVarId × Array FVarId × Std.HashMap Expr AsmPath) := do + withTraceNode `Saturate (fun _ => pure m!"addAssumptions") do withMainContext do let matched := state.matched.toArray let mut assumptions : Std.HashMap Expr AsmPath := state.assumptions @@ -627,6 +635,7 @@ partial def evalSaturateCore (assumptions : Std.HashMap Expr AsmPath) : TacticM (State × Array FVarId × Array FVarId × Std.HashMap Expr AsmPath) := do + withTraceNode `Saturate (fun _ => pure m!"saturateExtra") do withMainContext do trace[Saturate] "state.pmatches (num of partial matches: {state.pmatches.size}):\n{state.pmatches.toArray.map Prod.snd}" trace[Saturate] "state.assumptions: {state.assumptions.toArray}" @@ -657,7 +666,7 @@ partial def evalSaturateCore (state, allFVars, newFVars, assumptions) ← saturateExtra state allFVars newFVars assumptions withMainContext do - trace[Saturate] "Introduced the assumptions in the context" + trace[Saturate] "Finished saturating" -- Display the diagnostics information trace[Saturate.diagnostics] "Saturate diagnostics info: {state.diagnostics.toArray}" @@ -671,6 +680,7 @@ def recomputeAssumptions (declsToExplore : Array FVarId) : TacticM State := do + withTraceNode `Saturate (fun _ => pure m!"recomputeAssumptions") do withMainContext do trace[Saturate] "Exploring goal: {← getMainGoal}" let ignore := state.assumptions.fold (fun ignore asm _ => ignore.insert asm) state.ignore @@ -704,6 +714,7 @@ partial def evalSaturate {α} (next : Array FVarId → TacticM α) : TacticM α := do + withTraceNode `Saturate (fun _ => pure m!"evalSaturate") do -- Retrieve the rule sets let env ← getEnv let s := satAttr.map fun s => s.ext.getState env diff --git a/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean b/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean index 8230cc85a..acd3d7652 100644 --- a/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean +++ b/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean @@ -16,7 +16,9 @@ structure CondSimpPartialArgs where addSimpThms : Array Name := #[] hypsToUse : Array FVarId := #[] -def condSimpParseArgs (tacName : String) (args : TSyntaxArray [`term, `token.«*»]) : TacticM CondSimpPartialArgs := do +def condSimpParseArgs (tacName : String) (args : TSyntaxArray [`term, `token.«*»]) : + TacticM CondSimpPartialArgs := do + withTraceNode `ScalarTac (fun _ => pure m!"condSimpParseArgs") do let mut declsToUnfold := #[] let mut addSimpThms := #[] let mut hypsToUse := #[] @@ -27,14 +29,14 @@ def condSimpParseArgs (tacName : String) (args : TSyntaxArray [`term, `token.«* | `($stx:ident) => do match (← getLCtx).findFromUserName? stx.getId with | .some decl => - trace[CondSimpTac] "arg (local decl): {stx.raw}" + trace[ScalarTac] "arg (local decl): {stx.raw}" if decl.isLet then declsToUnfold := declsToUnfold.push decl.userName else hypsToUse := hypsToUse.push decl.fvarId | .none => -- Not a local declaration - trace[CondSimpTac] "arg (theorem): {stx.raw}" + trace[ScalarTac] "arg (theorem): {stx.raw}" let some e ← Lean.Elab.Term.resolveId? stx (withInfo := true) | throwError m!"Could not find theorem: {arg}" if let .const name _ := e then @@ -50,17 +52,17 @@ def condSimpParseArgs (tacName : String) (args : TSyntaxArray [`term, `token.«* | _ => throwError m!"{name} is not a theorem, an axiom or a definition" else throwError m!"Unexpected: {arg}" | term => do - trace[CondSimpTac] "term kind: {term.raw.getKind}" + trace[ScalarTac] "term kind: {term.raw.getKind}" if term.raw.getKind == `token.«*» then - trace[CondSimpTac] "found token: *" + trace[ScalarTac] "found token: *" let decls ← (← getLCtx).getDecls let decls ← decls.filterMapM ( fun d => do if ← isProp d.type then pure (some d.fvarId) else pure none) - trace[CondSimpTac] "filtered decls: {decls.map Expr.fvar}" + trace[ScalarTac] "filtered decls: {decls.map Expr.fvar}" hypsToUse := hypsToUse.append decls.toArray else -- TODO: we need to make that work - trace[CondSimpTac] "arg (term): {term}" + trace[ScalarTac] "arg (term): {term}" throwError m!"Unimplemented: arbitrary terms are not supported yet as arguments to `{tacName}` (received: {arg})" pure ⟨ declsToUnfold, addSimpThms, hypsToUse ⟩ @@ -88,12 +90,13 @@ def condSimpTacSimp (config : Simp.Config) (args : CondSimpArgs) (loc : Utils.Lo (toClear : Array FVarId := #[]) (additionalHypsToUse : Array FVarId := #[]) (state : Option (ScalarTac.State × Array FVarId)) : TacticM (Option (Array FVarId)) := do + withTraceNode `ScalarTac (fun _ => pure m!"condSimpTacSimp") do withMainContext do let simpArgs := args.toSimpArgs let simpArgs := { simpArgs with hypsToUse := simpArgs.hypsToUse ++ additionalHypsToUse } match state with | some (state, asms) => - trace[CondSimpTac] "condSimpTacSimp: scalarTac assumptions: {asms.map Expr.fvar}" + trace[ScalarTac] "scalarTac assumptions: {asms.map Expr.fvar}" /- Note that when calling `scalar_tac` we saturate only by looking at the target: we have already saturated by looking at the assumptions (we do this once and for all beforehand) -/ let dischargeWrapper ← Simp.tacticToDischarge (incrScalarTac {saturateAssumptions := false} state toClear asms) @@ -105,35 +108,43 @@ def condSimpTacSimp (config : Simp.Config) (args : CondSimpArgs) (loc : Utils.Lo | none => Simp.simpAt true config simpArgs loc -/-- A helper to define tactics which perform conditional simplifications with `scalar_tac` as a discharger. -/ -def condSimpTac - (tacName : String) (config : CondSimpTacConfig) - (simpConfig : Simp.Config) (hypsArgs args : CondSimpArgs) - (addSimpThms : TacticM (Array FVarId)) (doFirstSimp : Bool) - (loc : Utils.Location) : TacticM Unit := do - Elab.Tactic.focus do +structure PreprocessResult where + args : CondSimpArgs + toClear : Array FVarId + hypsToUse : Array FVarId + state : State + oldAsms : Array FVarId + newAsms : Array FVarId + additionalSimpThms : Array FVarId + +/-- Preprocess the goal. + Return `none` if the preprocessing actually solves the goal. +-/ +def condSimpTacPreprocess (config : CondSimpTacConfig) (hypsArgs args : CondSimpArgs) + (addSimpThms : TacticM (Array FVarId)) : TacticM (Option PreprocessResult) := do + withTraceNode `ScalarTac (fun _ => pure m!"condSimpTacPreprocess") do withMainContext do /- Concatenate the arguments for the conditional rewritings: we only want to restrict the simplifications performed on the higher-order hypotheses, but we need all the simplifications performed to those to be applied to the rest of the context as well. -/ let args := args ++ hypsArgs - trace[CondSimpTac] "Initial goal: {← getMainGoal}" + traceGoalWithNode `ScalarTac "Initial goal" /- First duplicate the propositions in the context: we need the preprocessing of `scalar_tac` to modify the assumptions, but we need to preserve a copy so that we can present a clean state to the user later (and pretend nothing happened). Note that we do this in two times: we want to treat the simp theorems provided by the user in `args` separately from the other assumptions. -/ let allAssumptions ← pure (← (← getLCtx).getAssumptions).toArray - trace[CondSimpTac] "allAssumptions: {allAssumptions.map fun d => Expr.fvar d.fvarId}" + trace[ScalarTac] "allAssumptions: {allAssumptions.map fun d => Expr.fvar d.fvarId}" let (_, hypsToUse) ← Utils.duplicateAssumptions (some args.hypsToUse) withMainContext do - trace[CondSimpTac] "Goal after duplicating the hyps to use: {← getMainGoal}" - trace[CondSimpTac] "hypsToUse: {hypsToUse.map Expr.fvar}" + traceGoalWithNode `ScalarTac "Goal after duplicating the hyps to use" + trace[ScalarTac] "hypsToUse: {hypsToUse.map Expr.fvar}" /- -/ let (oldAsms, newAsms) ← Utils.duplicateAssumptions (some (allAssumptions.map LocalDecl.fvarId)) let toClear := oldAsms withMainContext do - trace[CondSimpTac] "Goal after duplicating the assumptions: {← getMainGoal}" - trace[CondSimpTac] "newAsms: {newAsms.map Expr.fvar}" + traceGoalWithNode `ScalarTac "Goal after duplicating the assumptions" + trace[ScalarTac] "newAsms: {newAsms.map Expr.fvar}" /- Preprocess the assumptions -/ let scalarConfig : ScalarTac.Config := { nonLin := config.nonLin, saturationPasses := config.saturationPasses } let state ← State.new scalarConfig @@ -141,27 +152,40 @@ def condSimpTac Note that we do not inline the local let-declarations: we will do this only for the "regular" assumptions and the target. -/ let some (_, hypsToUse) ← scalarTacPartialPreprocess scalarConfig hypsArgs.toSimpArgs state (zetaDelta := false) #[] hypsToUse false - | trace[CondSimpTac] "Goal proven through preprocessing!"; return + | trace[ScalarTac] "Goal proved through preprocessing!"; return none withMainContext do - trace[CondSimpTac] "Goal after preprocessing the hyps to use ({hypsToUse.map Expr.fvar}): {← getMainGoal}" + withTraceNode `ScalarTac (fun _ => pure m!"Goal after preprocessing the hyps to use ({hypsToUse.map Expr.fvar})") do + trace[ScalarTac] "{← getMainGoal}" /- Remove the `forall'` and simplify the hyps to use -/ let simpHypsToUseArgs := { hypsArgs with hypsToUse := #[], declsToUnfold := #[``forall'] } let some hypsToUse ← Simp.simpAt true {failIfUnchanged := false, maxDischargeDepth := 0} simpHypsToUseArgs (.targets hypsToUse false) - | trace[ScalarTac] "Goal proven by preprocessing!"; return + | trace[ScalarTac] "Goal proved by preprocessing!"; return none let args := { args with hypsToUse } withMainContext do - trace[CondSimpTac] "Goal after simplifying the preprocessed hyps to use ({hypsToUse.map Expr.fvar}): {← getMainGoal}" + withTraceNode `ScalarTac (fun _ => pure m!"Goal after simplifying the preprocessed hyps to use ({hypsToUse.map Expr.fvar})") do + trace[ScalarTac] "{← getMainGoal}" /- Preprocess the "regular" assumptions -/ let some (state, newAsms) ← scalarTacPartialPreprocess scalarConfig (← ScalarTac.getSimpArgs) state (zetaDelta := true) #[] newAsms false - | trace[CondSimpTac] "Goal proven through preprocessing!"; return + | trace[ScalarTac] "Goal proved through preprocessing!"; return none withMainContext do - trace[CondSimpTac] "Goal after the initial preprocessing: {← getMainGoal}" - trace[CondSimpTac] "newAsms: {newAsms.map Expr.fvar}" + traceGoalWithNode `ScalarTac "Goal after the initial preprocessing" + trace[ScalarTac] "newAsms: {newAsms.map Expr.fvar}" /- Introduce the additional simp theorems -/ let additionalSimpThms ← addSimpThms - trace[CondSimpTac] "Goal after adding the additional simp assumptions: {← getMainGoal}" + traceGoalWithNode `ScalarTac "Goal after adding the additional simp assumptions" + pure (some { args, toClear, hypsToUse, state, oldAsms, newAsms, additionalSimpThms }) + +def condSimpTacCore + (tacName : String) + (simpConfig : Simp.Config) + (doFirstSimp : Bool) + (loc : Utils.Location) + (res : PreprocessResult) : TacticM (Option MVarId) := do + withTraceNode `ScalarTac (fun _ => pure m!"CondSimpTacCore") do + let { args, toClear, hypsToUse := _, state, oldAsms, newAsms, additionalSimpThms } := res /- Simplify the targets (note that we preserve the new assumptions for `scalar_tac`) -/ + withMainContext do let loc ← do match loc with | .wildcard => pure (Utils.Location.targets oldAsms true) @@ -170,26 +194,47 @@ def condSimpTac let nloc ← if doFirstSimp then match ← condSimpTacSimp simpConfig args loc (toClear := toClear) (additionalHypsToUse := additionalSimpThms) none with - | none => return + | none => trace[ScalarTac] "Goal proved through preprocessing!"; return none | some freshFvarIds => match loc with | .wildcard => pure (Utils.Location.targets freshFvarIds true) | .wildcard_dep => throwError "{tacName} does not support using location `Utils.Location.wildcard_dep`" | .targets _ type => pure (Utils.Location.targets freshFvarIds type) else pure loc - trace[CondSimpTac] "Goal after simplifying: {← getMainGoal}" + traceGoalWithNode `ScalarTac "Goal after simplifying" /- Simplify the targets by using `scalar_tac` as a discharger. TODO: scalar_tac should only be allowed to preprocess `scalarTacAsms`. TODO: we should preprocess those. -/ let _ ← condSimpTacSimp simpConfig args nloc (toClear := toClear) (additionalHypsToUse := additionalSimpThms) (some (state, newAsms)) - if (← getUnsolvedGoals) == [] then return + if (← getUnsolvedGoals) == [] then pure none + else pure (some (← getMainGoal)) + +def condSimpTacClear (res : PreprocessResult) : TacticM Unit := do + withTraceNode `ScalarTac (fun _ => pure m!"CondSimpTacClear") do + setGoals [← (← getMainGoal).tryClearMany res.hypsToUse] + traceGoalWithNode `ScalarTac "Goal after clearing the duplicated hypotheses to use" + setGoals [← (← getMainGoal).tryClearMany res.newAsms] + traceGoalWithNode `ScalarTac "Goal after clearing the duplicated assumptions" + setGoals [← (← getMainGoal).tryClearMany res.additionalSimpThms] + traceGoalWithNode `ScalarTac "Goal after clearing the additional theorems" + +/-- A helper to define tactics which perform conditional simplifications with `scalar_tac` as a discharger. -/ +def condSimpTac + (tacName : String) (config : CondSimpTacConfig) + (simpConfig : Simp.Config) (hypsArgs args : CondSimpArgs) + (addSimpThms : TacticM (Array FVarId)) (doFirstSimp : Bool) + (loc : Utils.Location) : TacticM Unit := do + withTraceNode `ScalarTac (fun _ => pure m!"CondSimpTac") do + Elab.Tactic.focus do + /- Preprocess -/ + let some res ← + condSimpTacPreprocess config hypsArgs args addSimpThms + | trace[ScalarTac] "Goal proved through preprocessing!"; return -- goal proved + /- Simplify the targets (note that we preserve the new assumptions for `scalar_tac`) -/ + let some _ ← condSimpTacCore tacName simpConfig doFirstSimp loc res + | trace[ScalarTac] "Goal proved!"; return -- goal proved /- Clear the additional assumptions -/ - setGoals [← (← getMainGoal).tryClearMany hypsToUse] - trace[CondSimpTac] "Goal after clearing the duplicated hypotheses to use: {← getMainGoal}" - setGoals [← (← getMainGoal).tryClearMany newAsms] - trace[CondSimpTac] "Goal after clearing the duplicated assumptions: {← getMainGoal}" - setGoals [← (← getMainGoal).tryClearMany additionalSimpThms] - trace[CondSimpTac] "Goal after clearing the additional theorems: {← getMainGoal}" + condSimpTacClear res end Aeneas.ScalarTac diff --git a/backends/lean/Aeneas/ScalarTac/Init.lean b/backends/lean/Aeneas/ScalarTac/Init.lean index 9bb6cafe4..a50bdcd17 100644 --- a/backends/lean/Aeneas/ScalarTac/Init.lean +++ b/backends/lean/Aeneas/ScalarTac/Init.lean @@ -7,8 +7,6 @@ namespace Aeneas.ScalarTac # Tracing -/ -initialize registerTraceClass `CondSimpTac - -- We can't define and use trace classes in the same file initialize registerTraceClass `ScalarTac diff --git a/backends/lean/Aeneas/ScalarTac/ScalarTac.lean b/backends/lean/Aeneas/ScalarTac/ScalarTac.lean index f9b78894a..6b6129939 100644 --- a/backends/lean/Aeneas/ScalarTac/ScalarTac.lean +++ b/backends/lean/Aeneas/ScalarTac/ScalarTac.lean @@ -148,6 +148,7 @@ def scalarTacSaturateForward {α} (satState : Option Saturate.State) (declsToExplore : Option (Array FVarId)) (f : Saturate.State → Array FVarId → TacticM α) : TacticM α := do + withTraceNode `ScalarTac (fun _ => pure m!"scalarTacSaturateForward") do withMainContext do /- We always use the rule set `Aeneas.ScalarTac`, but also need to add other rule sets locally activated by the user. The `Aeneas.ScalarTacNonLin` rule set has a special treatment as @@ -214,6 +215,7 @@ def simpAsmsTarget (simpOnly : Bool) (config : Simp.Config) (args : Simp.SimpArg /- Boosting a bit the `omega` tac. -/ def scalarTacPreprocess (config : Config) : Tactic.TacticM Unit := do + withTraceNode `ScalarTac (fun _ => pure m!"scalarTacPreprocess") do Tactic.withMainContext do -- Pre-preprocessing /- We simplify a first time before saturating the context. @@ -221,7 +223,7 @@ def scalarTacPreprocess (config : Config) : Tactic.TacticM Unit := do for the saturation phase, and it also often allows to get rid of some dependently typed expressions such as `UScalar.ofNat`. -/ - trace[ScalarTac] "Original goal before preprocessing: {← getMainGoal}" + traceGoalWithNode `ScalarTac "Original goal" let simpArgs : Simp.SimpArgs ← getSimpArgs let r ← simpAsmsTarget true {dsimp := false, failIfUnchanged := false, maxDischargeDepth := 1} -- Remove the forall quantifiers to prepare for the call of `simp_all` (we @@ -231,11 +233,11 @@ def scalarTacPreprocess (config : Config) : Tactic.TacticM Unit := do if r.isNone then trace[ScalarTac] "Goal proven by preprocessing!" return - trace[ScalarTac] "Goal after first simplification: {← getMainGoal}" + traceGoalWithNode `ScalarTac "Goal after first simplification" -- Apply the forward rules if config.saturate then scalarTacSaturateForward config.toSaturateConfig none none (fun _ _ => pure ()) - trace[ScalarTac] "Goal after saturation: {← getMainGoal}" + traceGoalWithNode `ScalarTac "Goal after saturation" -- Apply `simpAll` if config.simpAllMaxSteps ≠ 0 then tryTac do @@ -250,28 +252,28 @@ def scalarTacPreprocess (config : Config) : Tactic.TacticM Unit := do if (← getGoals).isEmpty then trace[ScalarTac] "Goal proven by preprocessing!" return - trace[ScalarTac] "Goal after simpAll: {← getMainGoal}" + traceGoalWithNode `ScalarTac "Goal after simpAll" -- Call `simp` again, this time to inline the let-bindings (otherwise, omega doesn't always manage to deal with them) let _ ← Simp.simpAt true {zetaDelta := true, failIfUnchanged := false, maxDischargeDepth := 1} simpArgs .wildcard -- We might have proven the goal if (← getGoals).isEmpty then trace[ScalarTac] "Goal proven by preprocessing!" return - trace[ScalarTac] "Goal after 2nd simp (with zetaDelta): {← getMainGoal}" + traceGoalWithNode `ScalarTac "Goal after 2nd simp (with zetaDelta)" -- Apply normCast let _ ← Utils.normCastAt .wildcard -- We might have proven the goal if (← getGoals).isEmpty then trace[ScalarTac] "Goal proven by preprocessing!" return - trace[ScalarTac] "Goal after normCast: {← getMainGoal}" + traceGoalWithNode `ScalarTac "Goal after normCast" -- Call `simp` again because `normCast` sometimes introduces strange terms let _ ← Simp.simpAt true {failIfUnchanged := false, maxDischargeDepth := 1} simpArgs .wildcard -- We might have proven the goal if (← getGoals).isEmpty then trace[ScalarTac] "Goal proven by preprocessing!" return - trace[ScalarTac] "Goal after 2nd call to simpAt: {← getMainGoal}" + traceGoalWithNode `ScalarTac "Goal after 2nd call to simpAt" structure State where saturateState : Saturate.State @@ -288,6 +290,7 @@ def State.new (config : Config) : MetaM State := do def scalarTacPartialPreprocess (config : Config) (simpArgs : Simp.SimpArgs) (state : State) (zetaDelta : Bool) (hypsToUseForSimp assumptionsToPreprocess : Array FVarId) (simpTarget : Bool) : Tactic.TacticM (Option (State × Array FVarId)) := do + withTraceNode `ScalarTac (fun _ => pure m!"scalarTacPartialPreprocess") do Tactic.focus do Tactic.withMainContext do -- Pre-preprocessing @@ -296,7 +299,7 @@ def scalarTacPartialPreprocess (config : Config) (simpArgs : Simp.SimpArgs) (sta for the saturation phase, and it also often allows to get rid of some dependently typed expressions such as `UScalar.ofNat`. -/ - trace[ScalarTac] "Original goal before preprocessing: {← getMainGoal}" + traceGoalWithNode `ScalarTac "Original goal before preprocessing" let r ← Simp.simpAt true {dsimp := false, failIfUnchanged := false, maxDischargeDepth := 1} /- Remove the forall quantifiers to prepare for the call of `simp_all` (we don't want `simp_all` to use assumptions of the shape `∀ x, P x`)) -/ @@ -398,6 +401,7 @@ elab "scalar_tac_preprocess" config:Parser.Tactic.optConfig : tactic => do scalarTacPreprocess config def scalarTacCore (config : Config) : Tactic.TacticM Unit := do + withTraceNode `ScalarTac (fun _ => pure m!"scalarTacCore") do Tactic.withMainContext do Tactic.focus do let simpArgs : Simp.SimpArgs ← getSimpArgs @@ -439,6 +443,7 @@ def scalarTacCore (config : Config) : Tactic.TacticM Unit := do This works also with strict inequalities. -/ def scalarTac (config : Config) : TacticM Unit := do + withTraceNode `ScalarTac (fun _ => pure m!"scalarTac") do Tactic.withMainContext do let error : TacticM Unit := do let g ← Tactic.getMainGoal @@ -477,6 +482,7 @@ elab "scalar_tac" config:Parser.Tactic.optConfig : tactic => do TODO: do we really need the config? -/ def incrScalarTac (config : Config) (state : State) (toClear : Array FVarId) (assumptions : Array FVarId) : TacticM Unit := do + withTraceNode `ScalarTac (fun _ => pure m!"incrScalarTac") do Tactic.focus do Tactic.withMainContext do /- Clear the useless assumptions -/ @@ -485,7 +491,7 @@ def incrScalarTac (config : Config) (state : State) (toClear : Array FVarId) (as /- Saturate by exploring only the goal -/ let some (_, _) ← scalarTacPartialPreprocess config (← ScalarTac.getSimpArgs) state (zetaDelta := true) assumptions #[] true | trace[ScalarTac] "incrScalarTac: goal proven by preprocessing" - trace[ScalarTac] "Goal after final preprocessing: {← getMainGoal}" + trace[ScalarTac] "Goal after preprocessing: {← getMainGoal}" /- Call omega -/ trace[ScalarTac] "Calling omega" Tactic.Omega.omegaTactic {} diff --git a/backends/lean/Aeneas/Utils.lean b/backends/lean/Aeneas/Utils.lean index fde30eb4b..f3b659bfe 100644 --- a/backends/lean/Aeneas/Utils.lean +++ b/backends/lean/Aeneas/Utils.lean @@ -1671,6 +1671,12 @@ def optElabTerm (e : Option (TSyntax `term)) : TacticM (Option Expr) := do | none => pure none | some e => pure (some (← Lean.Elab.Tactic.elabTerm e none)) +def traceGoalWithNode (cls : Name) (msg : String) : TacticM Unit := do + withTraceNode cls (fun _ => do pure msg) do + if ← isTracingEnabledFor cls then do + addTrace cls m!"{← getMainGoal}" + else pure () + end Utils end Aeneas From 45d01cd11081cd1388be85c2ee6308c4de018154 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Wed, 25 Jun 2025 11:54:50 +0100 Subject: [PATCH 26/31] Make prettyMonadEq a Type rather than a Prop --- backends/lean/Aeneas/Progress/Progress.lean | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/backends/lean/Aeneas/Progress/Progress.lean b/backends/lean/Aeneas/Progress/Progress.lean index 968ed0d6b..c25635089 100644 --- a/backends/lean/Aeneas/Progress/Progress.lean +++ b/backends/lean/Aeneas/Progress/Progress.lean @@ -12,7 +12,7 @@ open Lean Elab Term Meta Tactic open Utils /-- A special definition that we use to introduce pretty-printed terms in the context -/ -def prettyMonadEq {α : Type u} (x : Std.Result α) (y : α) : Prop := x = .ok y +@[irreducible] def prettyMonadEq {α : Type u} (_ : Std.Result α) (_ : α) : Type := Unit macro:max "[> " "let" y:term " ← " x:term " <]" : term => `(prettyMonadEq $x $y) @@ -21,7 +21,9 @@ def unexpPrettyMonadEqofNat : Lean.PrettyPrinter.Unexpander | `($_ $x $y) => `([ example (x y z : Std.U32) (_ : [> let z ← (x + y) <]) : True := by simp -theorem eq_imp_prettyMonadEq {α : Type u} {x : Std.Result α} {y : α} (h : x = .ok y) : prettyMonadEq x y := by simp [prettyMonadEq, h] +def eq_imp_prettyMonadEq {α : Type u} {x : Std.Result α} {y : α} (_ : x = .ok y) : prettyMonadEq x y := by + unfold prettyMonadEq + constructor def traceGoalWithNode (msg : String) : TacticM Unit := Utils.traceGoalWithNode `Progress msg From 0f9f90e94dfb83263eee49b1014af969cf092dcd Mon Sep 17 00:00:00 2001 From: Son Ho Date: Wed, 25 Jun 2025 11:55:04 +0100 Subject: [PATCH 27/31] Update the rule timed-lean in the Makefile --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index b92264fc7..59ca9f861 100644 --- a/Makefile +++ b/Makefile @@ -178,7 +178,7 @@ test-%: build-dev # Replay the Lean tests and time them .PHONY: timed-lean timed-lean: - cd tests/lean && find . -type f -iname "*.lean" -not -path "./.lake/*" -exec printf "\n{}\n" \; -exec lake env time lean {} \; >& timing.out + cd tests/lean && find . -type f -iname "*.lean" -not -path "./.lake/*" -exec printf "\n{}\n" \; -exec lake env time lean {} \; # ============================================================================= # Nix From fca500a7fcad74dbc8287b401e184358cde2a41d Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 27 Jun 2025 13:18:53 +0200 Subject: [PATCH 28/31] Add a state to progress --- backends/lean/Aeneas/Progress/Progress.lean | 215 ++++++++++++++---- .../lean/Aeneas/Progress/ProgressStar.lean | 3 +- .../lean/Aeneas/ScalarTac/CondSimpTac.lean | 4 +- backends/lean/Aeneas/ScalarTac/ScalarTac.lean | 10 +- backends/lean/Aeneas/Utils.lean | 11 +- 5 files changed, 192 insertions(+), 51 deletions(-) diff --git a/backends/lean/Aeneas/Progress/Progress.lean b/backends/lean/Aeneas/Progress/Progress.lean index c25635089..5ddd13816 100644 --- a/backends/lean/Aeneas/Progress/Progress.lean +++ b/backends/lean/Aeneas/Progress/Progress.lean @@ -101,16 +101,38 @@ def getType: UsedTheorem -> MetaM (Option Expr) end UsedTheorem +/-- The `progress` state, which is optionally used by `progress`, allows to + factor out computations between several calls to `progress`. It is useful + for instance for `progress*`. +-/ +structure State where + scalarTacState : ScalarTac.State + /-- Assumptions in the context which are only used by `scalar_tac` -/ + scalarTacAsms : Array FVarId + /-- Assumptions which should be visible to the user (those are the declarations + which were not introduced by `scalar_tac`) -/ + userAsms : Array FVarId + /-- Assumptions that we could try to use with `progress` -/ + progressAsms : Array LocalDecl + /-- The declarations introduced for the recursive calls -/ + recDecls : Array LocalDecl + +/-- The configuration to use for `scalar_tac` -/ +def scalarTacConfig : ScalarTac.Config := { split := false } + structure MainGoal where goal : MVarId /-- The post-conditions introduced in the context -/ posts : Array FVarId +structure MainGoalWithState extends MainGoal where + state : Option State + structure Goals where /-- The preconditions that are left to prove -/ preconditions : Array MVarId /-- The main goal, if it was not proven -/ - mainGoal : Option MainGoal + mainGoal : Option MainGoalWithState deriving Inhabited structure Stats extends Goals where @@ -338,7 +360,8 @@ def trySolvePreconditions (args : Args) (newPropGoals : List MVarId) : TacticM ( /-- Post-process the main goal. The main thing we do is simplify the post-conditions. -/ -def postprocessMainGoal (mainGoal : Option MainGoal) : TacticM (Option MainGoal) := do +def postprocessMainGoal (mainGoal : Option MainGoal) : + TacticM (Option MainGoal) := do withTraceNode `Progress (fun _ => pure m!"postprocessMainGoal") do match mainGoal with | none => pure none @@ -364,7 +387,46 @@ def postprocessMainGoal (mainGoal : Option MainGoal) : TacticM (Option MainGoal) pure (some ({ goal := ← getMainGoal, posts} : MainGoal)) else pure none -def progressWith (args : Args) (fExpr : Expr) (th : Expr) : +/-- Check if an assumption could be used as the theorem given to `progress` -/ +def checkAssumptionIsUsableWithProgress (fvarId : FVarId) : TacticM Bool := do + try + withProgressSpec (isGoal := false) (.fvar fvarId) fun _ => pure true + catch _ => pure false + +/-- We assume that if there is a `mainGoal`, it is the current goal -/ +def updateState (mainGoal : Option MainGoal) (state : Option State) : TacticM (Option MainGoalWithState) := do + withTraceNode `Progress (fun _ => pure m!"updateState") do + /- We update the state only if there is still a main goal -/ + let some mainGoal := mainGoal + | trace[Progress] "No main goal"; return none + let some state := state + | trace[Progress] "No state"; return (some { toMainGoal := mainGoal, state := none }) + /- Add the post-conditions to the set of user declarations -/ + let { scalarTacState, scalarTacAsms, userAsms, progressAsms, recDecls } := state + let userAsms := userAsms ++ mainGoal.posts + /- Filter the post-conditions which could be used as candidates for `progress` -/ + let progressAsms' ← mainGoal.posts.filterM checkAssumptionIsUsableWithProgress + let progressAsms' ← liftM (progressAsms'.mapM FVarId.getDecl) + let progressAsms := progressAsms ++ progressAsms' + /- Duplicate the post-conditions -/ + let (_, newPosts) ← duplicateAssumptions mainGoal.posts + /- Preprocess -/ + let some (scalarTacState, newPosts) ← + ScalarTac.partialPreprocess scalarTacConfig (← ScalarTac.getSimpArgs) + scalarTacState (zetaDelta := true) + (hypsToUseForSimp := state.scalarTacAsms) + (assumptionsToPreprocess := newPosts) + (simpTarget := false) + | trace[Progress] "Goal proven by preprocessing!"; return none + let scalarTacAsms := scalarTacAsms ++ newPosts + let mainGoal := { mainGoal with goal := ← getMainGoal } + let mainGoal : MainGoalWithState := { + toMainGoal := mainGoal, + state := some { scalarTacState, scalarTacAsms, userAsms, progressAsms, recDecls }, + } + pure (some mainGoal) + +def progressWith (args : Args) (state : Option State) (fExpr : Expr) (th : Expr) : TacticM Goals := do withTraceNode `Progress (fun _ => pure m!"progressWith") do -- Attempt to instantiate the theorem and introduce it in the context @@ -391,6 +453,8 @@ def progressWith (args : Args) (fExpr : Expr) (th : Expr) : withTraceNode `Progress (fun _ => pure m!"Main goal after simplifying the post-conditions and the target") do trace[Progress] "{mainGoal.goal}" + /- Update the state with the new assumptions -/ + let mainGoal ← updateState mainGoal state /- Update the list of goals - TODO: move this elsewhere -/ let newGoals ← (newNonPropGoals ++ newPropGoals).filterM fun mvar => not <$> mvar.isAssigned withTraceNode `Progress (fun _ => pure m!"Final remaining preconditions") do @@ -401,7 +465,7 @@ def progressWith (args : Args) (fExpr : Expr) (th : Expr) : | some goal => [goal.goal] setGoals (newGoals ++ curGoal) trace[Progress] "replaced the goals" - pure ({ preconditions := newGoals.toArray, mainGoal }) + pure { preconditions := newGoals.toArray, mainGoal } /-- Small utility: if `args` is not empty, return the name of the app in the first arg, if it is a const. -/ @@ -419,7 +483,7 @@ def getFirstArg (args : Array Expr) : Option Expr := do /-- Helper: try to apply a theorem. Return the list of post-conditions we introduced if it succeeded. -/ -def tryApply (args : Args) (fExpr : Expr) (kind : String) (th : Option Expr) : +def tryApply (args : Args) (state : Option State) (fExpr : Expr) (kind : String) (th : Option Expr) : TacticM (Option Goals) := do let res ← do match th with @@ -431,7 +495,7 @@ def tryApply (args : Args) (fExpr : Expr) (kind : String) (th : Option Expr) : -- Apply the theorem let res ← do try - let res ← progressWith args fExpr th + let res ← progressWith args state fExpr th pure (some res) catch _ => pure none match res with @@ -441,23 +505,57 @@ def tryApply (args : Args) (fExpr : Expr) (kind : String) (th : Option Expr) : /-- Try to progress with an assumption. Return `some` if we succeed, `none` otherwise. -/ -def tryAssumptions (args : Args) (fExpr : Expr) : +def tryAssumptions (args : Args) (state : Option State) (fExpr : Expr) : TacticM (Option (Goals × UsedTheorem)) := do withTraceNode `Progress (fun _ => pure m!"tryAssumptions") do run where run := withMainContext do - let ctx ← Lean.MonadLCtx.getLCtx - let decls ← ctx.getAssumptions + /- If there is a state: try the assumptions in the state, otherwise try all the assumptions + (in case there is a state, it means there are a lot of additional assumptions introduced + for `scalar_tac`, so we definitely do not want to try all of those). + -/ + let decls ← do match state with + | none => + let ctx ← Lean.MonadLCtx.getLCtx + pure (← ctx.getAssumptions).toArray + | some state => pure state.progressAsms for decl in decls.reverse do trace[Progress] "Trying assumption: {decl.userName} : {decl.type}" try - let goal ← progressWith args fExpr decl.toExpr + let goal ← progressWith args state fExpr decl.toExpr return (some (goal, .localHyp decl)) catch _ => continue pure none -def progressAsmsOrLookupTheorem (args : Args) (withTh : Option Expr) : +def getRecDecls : TacticM (Array LocalDecl) := do + let ctx ← Lean.MonadLCtx.getLCtx + let decls ← ctx.getAllDecls + pure (decls.filter (λ decl => match decl.kind with + | .default | .implDetail => false | .auxDecl => true)).toArray + +/-- Attempt to use a recursive assumption -/ +def progressWithRecDecl (args : Args) (state : Option State) (fExpr : Expr) : + TacticM (Goals × UsedTheorem) := do + withTraceNode `Progress (fun _ => pure m!"progressWithRecDecl") do + -- We try to apply the assumptions of kind "auxDecl" + let decls ← do + match state with + | none => getRecDecls + | some state => pure state.recDecls + -- TODO: introduce a helper for this + for decl in decls.reverse do + trace[Progress] "Trying recursive assumption: {decl.userName} : {decl.type}" + try + let goals ← progressWith args state fExpr decl.toExpr + return (goals, .localHyp decl) + catch _ => continue + -- Nothing worked: failed + let msg := "Progress failed: could not find a local assumption or a theorem to apply" + trace[Progress] msg + throwError msg + +def progressAsmsOrLookupTheorem (args : Args) (state : Option State) (withTh : Option Expr) : TacticM (Goals × UsedTheorem) := do withMainContext do -- Retrieve the goal @@ -485,11 +583,11 @@ def progressAsmsOrLookupTheorem (args : Args) (withTh : Option Expr) : -- Otherwise, lookup one. match withTh with | some th => do - let goals ← progressWith args fExpr th + let goals ← progressWith args state fExpr th return (goals, .givenExpr th) | none => -- Try all the assumptions one by one and if it fails try to lookup a theorem. - if let some res ← tryAssumptions args fExpr then return res + if let some res ← tryAssumptions args state fExpr then return res /- It failed: lookup the pspec theorems which match the expression *only if the function is a constant* -/ let fIsConst ← do @@ -512,27 +610,12 @@ def progressAsmsOrLookupTheorem (args : Args) (withTh : Option Expr) : -- Try the theorems one by one for pspec in pspecs do let pspecExpr ← Term.mkConst pspec - match ← tryApply args fExpr "pspec theorem" pspecExpr with + match ← tryApply args state fExpr "pspec theorem" pspecExpr with | some goals => return (goals, .progressThm pspec) | none => pure () -- It failed: try to use the recursive assumptions trace[Progress] "Failed using a pspec theorem: trying to use a recursive assumption" - -- We try to apply the assumptions of kind "auxDecl" - let ctx ← Lean.MonadLCtx.getLCtx - let decls ← ctx.getAllDecls - let decls := decls.filter (λ decl => match decl.kind with - | .default | .implDetail => false | .auxDecl => true) - -- TODO: introduce a helper for this - for decl in decls.reverse do - trace[Progress] "Trying recursive assumption: {decl.userName} : {decl.type}" - try - let goals ← progressWith args fExpr decl.toExpr - return (goals, .localHyp decl) - catch _ => continue - -- Nothing worked: failed - let msg := "Progress failed: could not find a local assumption or a theorem to apply" - trace[Progress] msg - throwError msg + progressWithRecDecl args state fExpr syntax progressArgs := ("keep" binderIdent)? ("with" term)? ("as" " ⟨ " binderIdent,* " ⟩")? ("by" tacticSeq)? @@ -575,7 +658,45 @@ def parseProgressArgs return (keep?, withTh?, ids, byTac) | _ => throwUnsupportedSyntax -def evalProgress (keep keepPretty : Option Name) (withArg: Option Expr) (ids: Array (Option Name)) +/-- Initialize the `progress` state -/ +def State.init : TacticM (Option State) := do + withTraceNode `Progress (fun _ => pure m!"State.init") do + /- Retrieve the recursive declarations -/ + let recDecls ← getRecDecls + /- Duplicate and preprocess the assumptions -/ + let (oldAsms, newAsms) ← duplicateAssumptions + let scalarTacState ← ScalarTac.State.new scalarTacConfig + let some (scalarTacState, newAsms) ← + ScalarTac.partialPreprocess scalarTacConfig (← ScalarTac.getSimpArgs) + scalarTacState (zetaDelta := true) + (hypsToUseForSimp := #[]) + (assumptionsToPreprocess := newAsms) + (simpTarget := false) + | trace[Progress] "Goal proven through preprocessing!"; return none + /- Compute the list of assumptions which can be used as progress theorems -/ + let progressAsms ← liftM ((← oldAsms.filterM checkAssumptionIsUsableWithProgress).mapM FVarId.getDecl) + /- -/ + pure (some { + scalarTacState, + scalarTacAsms := newAsms, + userAsms := oldAsms, + progressAsms, + recDecls, + }) + +/-- Cleanup the context to remove all the auxiliary assumptions introduced for the progress state + and which should not be shown to the user -/ +def State.clean (state : State) : TacticM Unit := do + setGoals [← (← getMainGoal).tryClearMany state.scalarTacAsms] + +def cleanState (state : Option State) : TacticM Unit := do + match state with + | none => pure () + | some state => state.clean + +def evalProgress + (state : Option State) + (keep keepPretty : Option Name) (withArg: Option Expr) (ids: Array (Option Name)) (byTac : Option Syntax.Tactic) : TacticM Stats := do withTraceNode `Progress (fun _ => pure m!"evalProgress") do @@ -586,16 +707,27 @@ def evalProgress (keep keepPretty : Option Name) (withArg: Option Expr) (ids: Ar withMainContext do let splitPost := true /- Preprocessing step for `singleAssumptionTac` -/ - let singleAssumptionTacDtree ← singleAssumptionTacPreprocess + let singleAssumptionTacDtree ← do + let decls := + match state with + | none => none + | some state => some state.userAsms + singleAssumptionTacPreprocess decls /- For scalarTac we have a fast track: if the goal is not a linear arithmetic goal, we skip (note that otherwise, scalarTac would try to prove a contradiction) -/ let scalarTac : TacticM Unit := do withTraceNode `Progress (fun _ => pure m!"Attempting to solve with `scalarTac`") do if ← ScalarTac.goalIsLinearInt then - /- Also: we don't try to split the goal if it is a conjunction - (it shouldn't be), but we split the disjunctions. -/ - ScalarTac.scalarTac { split := false } + /- Use the precomputed state, if there is -/ + match state with + | none => + /- Also: we don't try to split the goal if it is a conjunction + (it shouldn't be), but we split the disjunctions. -/ + ScalarTac.scalarTac scalarTacConfig + | some state => + ScalarTac.incrScalarTac scalarTacConfig state.scalarTacState (toClear := state.userAsms) + (assumptions := state.scalarTacAsms) else throwError "Not a linear arithmetic goal" let simpLemmas ← Aeneas.ScalarTac.scalarTacSimpExt.getTheorems @@ -603,6 +735,8 @@ def evalProgress (keep keepPretty : Option Name) (withArg: Option Expr) (ids: Ar let simpArgs : Simp.SimpArgs := {simpThms := #[simpLemmas], hypsToUse := localAsms.toArray} let simpTac : TacticM Unit := do withTraceNode `Progress (fun _ => pure m!"Attempting to solve with `simp [*]`") do + -- Cleanup the context + cleanState state -- Simplify the goal let r ← Simp.simpAt false { maxDischargeDepth := 1 } simpArgs (.targets #[] true) -- Raise an error if the goal is not proved @@ -617,6 +751,9 @@ def evalProgress (keep keepPretty : Option Name) (withArg: Option Expr) (ids: Ar | none => [] | some byTac => [ withTraceNode `Progress (fun _ => pure m!"Attempting to solve with the user tactic: `{byTac}`") do + -- Cleanup the context + cleanState state + -- Call the tactic evalTactic byTac] let solvePreconditionTac := withMainContext do @@ -630,17 +767,17 @@ def evalProgress (keep keepPretty : Option Name) (withArg: Option Expr) (ids: Ar let args : Args := { keep, keepPretty, ids, splitPost, assumTac := customAssumTac, solvePreconditionTac } - let (goals, usedTheorem) ← progressAsmsOrLookupTheorem args withArg + let (goals, usedTheorem) ← progressAsmsOrLookupTheorem args state withArg trace[Progress] "Progress done" return ⟨ goals, usedTheorem ⟩ elab (name := progress) "progress" args:progressArgs : tactic => do let (keep?, withArg, ids, byTac) ← parseProgressArgs args - evalProgress keep? none withArg ids byTac *> return () + evalProgress none keep? none withArg ids byTac *> return () elab tk:"progress?" args:progressArgs : tactic => do let (keep?, withArg, ids, byTac) ← parseProgressArgs args - let stats ← evalProgress keep? none withArg ids byTac + let stats ← evalProgress none keep? none withArg ids byTac let mut stxArgs := args.raw if stxArgs[1].isNone then let withArg := mkNullNode #[mkAtom "with", ←stats.usedTheorem.toSyntax] @@ -696,7 +833,7 @@ def parseLetProgress elab tk:letProgress : tactic => do withMainContext do let (withArg, suggest, ids, byTac) ← parseLetProgress tk - let stats ← evalProgress none (some (.str .anonymous "_")) withArg ids byTac + let stats ← evalProgress none none (some (.str .anonymous "_")) withArg ids byTac let mut stxArgs := tk.raw if suggest then trace[Progress] "suggest is true" diff --git a/backends/lean/Aeneas/Progress/ProgressStar.lean b/backends/lean/Aeneas/Progress/ProgressStar.lean index a6ddde4f3..bd436c974 100644 --- a/backends/lean/Aeneas/Progress/ProgressStar.lean +++ b/backends/lean/Aeneas/Progress/ProgressStar.lean @@ -405,7 +405,8 @@ where return (infos, mkStx) tryProgress := do - try some <$> Progress.evalProgress none (some (.str .anonymous "_")) none #[] none + let state := none + try some <$> Progress.evalProgress state none (some (.str .anonymous "_")) none #[] none catch _ => pure none handleProgressPreconditions (preconditions : Array MVarId) : TacticM (Array Syntax.Tactic × Array MVarId) := do diff --git a/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean b/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean index acd3d7652..735b81070 100644 --- a/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean +++ b/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean @@ -151,7 +151,7 @@ def condSimpTacPreprocess (config : CondSimpTacConfig) (hypsArgs args : CondSimp /- First the hyps to use. Note that we do not inline the local let-declarations: we will do this only for the "regular" assumptions and the target. -/ - let some (_, hypsToUse) ← scalarTacPartialPreprocess scalarConfig hypsArgs.toSimpArgs state (zetaDelta := false) #[] hypsToUse false + let some (_, hypsToUse) ← partialPreprocess scalarConfig hypsArgs.toSimpArgs state (zetaDelta := false) #[] hypsToUse false | trace[ScalarTac] "Goal proved through preprocessing!"; return none withMainContext do withTraceNode `ScalarTac (fun _ => pure m!"Goal after preprocessing the hyps to use ({hypsToUse.map Expr.fvar})") do @@ -166,7 +166,7 @@ def condSimpTacPreprocess (config : CondSimpTacConfig) (hypsArgs args : CondSimp withTraceNode `ScalarTac (fun _ => pure m!"Goal after simplifying the preprocessed hyps to use ({hypsToUse.map Expr.fvar})") do trace[ScalarTac] "{← getMainGoal}" /- Preprocess the "regular" assumptions -/ - let some (state, newAsms) ← scalarTacPartialPreprocess scalarConfig (← ScalarTac.getSimpArgs) state (zetaDelta := true) #[] newAsms false + let some (state, newAsms) ← partialPreprocess scalarConfig (← ScalarTac.getSimpArgs) state (zetaDelta := true) #[] newAsms false | trace[ScalarTac] "Goal proved through preprocessing!"; return none withMainContext do traceGoalWithNode `ScalarTac "Goal after the initial preprocessing" diff --git a/backends/lean/Aeneas/ScalarTac/ScalarTac.lean b/backends/lean/Aeneas/ScalarTac/ScalarTac.lean index 6b6129939..00de2f4cc 100644 --- a/backends/lean/Aeneas/ScalarTac/ScalarTac.lean +++ b/backends/lean/Aeneas/ScalarTac/ScalarTac.lean @@ -214,7 +214,7 @@ def simpAsmsTarget (simpOnly : Bool) (config : Simp.Config) (args : Simp.SimpArg Aeneas.Simp.simpAt simpOnly config args (.targets props true) /- Boosting a bit the `omega` tac. -/ -def scalarTacPreprocess (config : Config) : Tactic.TacticM Unit := do +def preprocess (config : Config) : Tactic.TacticM Unit := do withTraceNode `ScalarTac (fun _ => pure m!"scalarTacPreprocess") do Tactic.withMainContext do -- Pre-preprocessing @@ -287,7 +287,7 @@ def State.new (config : Config) : MetaM State := do let saturateState := Saturate.State.new rules pure { saturateState } -def scalarTacPartialPreprocess (config : Config) (simpArgs : Simp.SimpArgs) (state : State) +def partialPreprocess (config : Config) (simpArgs : Simp.SimpArgs) (state : State) (zetaDelta : Bool) (hypsToUseForSimp assumptionsToPreprocess : Array FVarId) (simpTarget : Bool) : Tactic.TacticM (Option (State × Array FVarId)) := do withTraceNode `ScalarTac (fun _ => pure m!"scalarTacPartialPreprocess") do @@ -398,7 +398,7 @@ def scalarTacPartialPreprocess (config : Config) (simpArgs : Simp.SimpArgs) (sta elab "scalar_tac_preprocess" config:Parser.Tactic.optConfig : tactic => do let config ← elabConfig config - scalarTacPreprocess config + preprocess config def scalarTacCore (config : Config) : Tactic.TacticM Unit := do withTraceNode `ScalarTac (fun _ => pure m!"scalarTacCore") do @@ -412,7 +412,7 @@ def scalarTacCore (config : Config) : Tactic.TacticM Unit := do Tactic.setGoals [g] -- Preprocess - wondering if we should do this before or after splitting -- the goal. I think before leads to a smaller proof term? - allGoalsNoRecover (scalarTacPreprocess config) + allGoalsNoRecover (preprocess config) allGoalsNoRecover do if config.split then do trace[ScalarTac] "Splitting the goal" @@ -489,7 +489,7 @@ def incrScalarTac (config : Config) (state : State) (toClear : Array FVarId) (as let mvarId ← (← getMainGoal).tryClearMany toClear setGoals [mvarId] /- Saturate by exploring only the goal -/ - let some (_, _) ← scalarTacPartialPreprocess config (← ScalarTac.getSimpArgs) state (zetaDelta := true) assumptions #[] true + let some (_, _) ← partialPreprocess config (← ScalarTac.getSimpArgs) state (zetaDelta := true) assumptions #[] true | trace[ScalarTac] "incrScalarTac: goal proven by preprocessing" trace[ScalarTac] "Goal after preprocessing: {← getMainGoal}" /- Call omega -/ diff --git a/backends/lean/Aeneas/Utils.lean b/backends/lean/Aeneas/Utils.lean index f3b659bfe..5a48ea327 100644 --- a/backends/lean/Aeneas/Utils.lean +++ b/backends/lean/Aeneas/Utils.lean @@ -435,10 +435,13 @@ partial def getMVarIds (e : Expr) (hs : Std.HashSet MVarId := Std.HashSet.emptyW def assumptionTac : TacticM Unit := liftMetaTactic fun mvarId => do mvarId.assumption; pure [] -def filterAssumptionTacPreprocess : TacticM (DiscrTree FVarId) := do +def filterAssumptionTacPreprocess (decls : Option (Array FVarId) := none) : TacticM (DiscrTree FVarId) := do let mut dtree := DiscrTree.empty - for decl in (← (← getLCtx).getDecls) do - dtree ← dtree.insert decl.type decl.fvarId + let decls ← match decls with + | none => pure ((← (← getLCtx).getDecls).toArray.map LocalDecl.fvarId) + | some decls => pure decls + for decl in decls do + dtree ← dtree.insert (← decl.getType) decl pure dtree /-- Return `true` if managed to close goal `mvarId` using an assumption. -/ @@ -503,7 +506,7 @@ def getAllMatchingAssumptions (type : Expr) : MetaM (List (LocalDecl × Name)) : restoreState s pure x -def singleAssumptionTacPreprocess := filterAssumptionTacPreprocess +def singleAssumptionTacPreprocess (decls := none) := filterAssumptionTacPreprocess decls def singleAssumptionTacCore (dtree : DiscrTree FVarId) : TacticM Unit := do withMainContext do From 0fc31ee86683d622431b9f1789b85f784f4399c6 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 27 Jun 2025 13:19:49 +0200 Subject: [PATCH 29/31] Use the state in progress* --- backends/lean/Aeneas/Progress/Progress.lean | 45 ++--- .../lean/Aeneas/Progress/ProgressStar.lean | 156 +++++++++++++----- backends/lean/Aeneas/ScalarTac/ScalarTac.lean | 70 +++++--- backends/lean/Aeneas/Simp/Simp.lean | 3 + .../lean/Aeneas/Simp/SimpAllAssumptions.lean | 1 + backends/lean/Aeneas/Utils.lean | 17 ++ 6 files changed, 201 insertions(+), 91 deletions(-) diff --git a/backends/lean/Aeneas/Progress/Progress.lean b/backends/lean/Aeneas/Progress/Progress.lean index 5ddd13816..b7ab08ae7 100644 --- a/backends/lean/Aeneas/Progress/Progress.lean +++ b/backends/lean/Aeneas/Progress/Progress.lean @@ -393,23 +393,17 @@ def checkAssumptionIsUsableWithProgress (fvarId : FVarId) : TacticM Bool := do withProgressSpec (isGoal := false) (.fvar fvarId) fun _ => pure true catch _ => pure false -/-- We assume that if there is a `mainGoal`, it is the current goal -/ -def updateState (mainGoal : Option MainGoal) (state : Option State) : TacticM (Option MainGoalWithState) := do - withTraceNode `Progress (fun _ => pure m!"updateState") do - /- We update the state only if there is still a main goal -/ - let some mainGoal := mainGoal - | trace[Progress] "No main goal"; return none - let some state := state - | trace[Progress] "No state"; return (some { toMainGoal := mainGoal, state := none }) +def State.update (state : State) (asms : Array FVarId) : TacticM (Option (State × MVarId)) := do + withTraceNode `Progress (fun _ => pure m!"State.update") do /- Add the post-conditions to the set of user declarations -/ let { scalarTacState, scalarTacAsms, userAsms, progressAsms, recDecls } := state - let userAsms := userAsms ++ mainGoal.posts + let userAsms := userAsms ++ asms /- Filter the post-conditions which could be used as candidates for `progress` -/ - let progressAsms' ← mainGoal.posts.filterM checkAssumptionIsUsableWithProgress + let progressAsms' ← asms.filterM checkAssumptionIsUsableWithProgress let progressAsms' ← liftM (progressAsms'.mapM FVarId.getDecl) let progressAsms := progressAsms ++ progressAsms' /- Duplicate the post-conditions -/ - let (_, newPosts) ← duplicateAssumptions mainGoal.posts + let (_, newPosts) ← duplicateAssumptions asms /- Preprocess -/ let some (scalarTacState, newPosts) ← ScalarTac.partialPreprocess scalarTacConfig (← ScalarTac.getSimpArgs) @@ -419,12 +413,21 @@ def updateState (mainGoal : Option MainGoal) (state : Option State) : TacticM (O (simpTarget := false) | trace[Progress] "Goal proven by preprocessing!"; return none let scalarTacAsms := scalarTacAsms ++ newPosts - let mainGoal := { mainGoal with goal := ← getMainGoal } - let mainGoal : MainGoalWithState := { - toMainGoal := mainGoal, - state := some { scalarTacState, scalarTacAsms, userAsms, progressAsms, recDecls }, - } - pure (some mainGoal) + let mainGoal ← getMainGoal + let state := { scalarTacState, scalarTacAsms, userAsms, progressAsms, recDecls } + pure (some (state, mainGoal)) + +def updateState (mainGoal : Option MainGoal) (state : Option State) : TacticM (Option MainGoalWithState) := do + withTraceNode `Progress (fun _ => pure m!"updateState") do + /- We update the state only if there is still a main goal -/ + let some mainGoal := mainGoal + | trace[Progress] "No main goal"; return none + let some state := state + | trace[Progress] "No state"; return (some { toMainGoal := mainGoal, state := none }) + /- Update -/ + let some (state, goal) ← state.update mainGoal.posts + | return none + pure (some { mainGoal with state, goal }) def progressWith (args : Args) (state : Option State) (fExpr : Expr) (th : Expr) : TacticM Goals := do @@ -731,12 +734,12 @@ def evalProgress else throwError "Not a linear arithmetic goal" let simpLemmas ← Aeneas.ScalarTac.scalarTacSimpExt.getTheorems - let localAsms ← pure ((← (← getLCtx).getAssumptions).map LocalDecl.fvarId) - let simpArgs : Simp.SimpArgs := {simpThms := #[simpLemmas], hypsToUse := localAsms.toArray} + let localAsms ← match state with + | none => pure ((← (← getLCtx).getAssumptions).map LocalDecl.fvarId).toArray + | some state => pure state.userAsms + let simpArgs : Simp.SimpArgs := {simpThms := #[simpLemmas], hypsToUse := localAsms} let simpTac : TacticM Unit := do withTraceNode `Progress (fun _ => pure m!"Attempting to solve with `simp [*]`") do - -- Cleanup the context - cleanState state -- Simplify the goal let r ← Simp.simpAt false { maxDischargeDepth := 1 } simpArgs (.targets #[] true) -- Raise an error if the goal is not proved diff --git a/backends/lean/Aeneas/Progress/ProgressStar.lean b/backends/lean/Aeneas/Progress/ProgressStar.lean index bd436c974..e04937b50 100644 --- a/backends/lean/Aeneas/Progress/ProgressStar.lean +++ b/backends/lean/Aeneas/Progress/ProgressStar.lean @@ -143,6 +143,13 @@ structure Info where script: Array Syntax.Tactic := #[] -- TODO: update so that we get a tree unsolvedGoals: Array MVarId := #[] +structure State where + progressState : Progress.State + +structure MainGoal where + state : State + goal : MVarId + instance: Append Info where append inf1 inf2 := { script := inf1.script ++ inf2.script, @@ -157,6 +164,9 @@ inductive TargetKind where | result | unknown +def State.init : TacticM (Option State) := do + pure (Option.map State.mk (← Progress.State.init)) + /- Smaller helper which we use to check in which situation we are -/ def analyzeTarget : TacticM TargetKind := do withTraceNode `Progress (fun _ => do pure m!"analyzeTarget") do @@ -178,11 +188,16 @@ def analyzeTarget : TacticM TargetKind := do partial def evalProgressStar (cfg: Config) : TacticM Info := withMainContext do focus do withTraceNode `Progress (fun _ => do pure m!"evalProgressStar") do - /- Continue -/ + /- Initialize the state -/ + let some state ← State.init + | trace[Progress] "Goal proven while initializing the progress state" + return ⟨ #[ ← `(tactic|scalar_tac)], #[] ⟩ + /- Simplify the target -/ let (info, mvarId) ← simplifyTarget match mvarId with | some _ => - let info' ← traverseProgram cfg + /- Go as far as possible with `progress` -/ + let info' ← traverseProgram cfg state let info := info ++ info' setGoals info.unsolvedGoals.toList pure info @@ -213,23 +228,23 @@ where let goal ← do if r.isSome then pure (some (← getMainGoal)) else pure none pure (info, goal) - traverseProgram (cfg : Config): TacticM Info := do + traverseProgram (cfg : Config) (state : State) : TacticM Info := do withMainContext do withTraceNode `Progress (fun _ => do pure m!"traverseProgram") do traceGoalWithNode "current goal" let targetKind ← analyzeTarget match targetKind with | .bind varName => do - let (info, mainGoal) ← onBind cfg varName + let (info, mainGoal) ← onBind cfg state varName /- Continue, if necessary -/ match mainGoal with | none => -- Stop trace[Progress] "stop" return info - | some mainGoal => - setGoals [mainGoal] - let restInfo ← traverseProgram cfg + | some state => + setGoals [state.goal] + let restInfo ← traverseProgram cfg state.state return info ++ restInfo | .switch bfInfo => do let contsTaggedVals ← @@ -237,23 +252,22 @@ where Utils.lambdaTelescopeN br.toExpr br.numArgs fun xs _ => do let names ← xs.mapM (·.fvarId!.getUserName) return names - let (branchGoals, mkStx) ← onBif cfg bfInfo contsTaggedVals + let (branchGoals, mkStx) ← onBif cfg state bfInfo contsTaggedVals withTraceNode `Progress (fun _ => do pure m!"exploring branches") do /- Continue exploring from the subgoals -/ - let branchInfos ← branchGoals.mapM fun mainGoal => do - setGoals [mainGoal] - let restInfo ← traverseProgram cfg + let branchInfos ← branchGoals.mapM fun state => do + setGoals [state.goal] + let restInfo ← traverseProgram cfg state.state pure restInfo /- Put everything together -/ mkStx branchInfos | .result => do - let (info, mainGoal) ← onResult cfg - pure { info with unsolvedGoals := info.unsolvedGoals ++ mainGoal.toList} + onResult cfg state | .unknown => do trace[Progress] "don't know what to do: inserting a sorry" return ({ script := #[←`(tactic| sorry)], unsolvedGoals := (← getUnsolvedGoals).toArray}) - onResult (cfg : Config) : TacticM (Info × Option MVarId) := do + onResult (cfg : Config) (state : State) : TacticM Info := do withTraceNode `Progress (fun _ => pure m!"onResult") do /- If we encounter `(do f a)` we process it as if it were `(do let res ← f a; return res)` since (id = (· >>= pure)) and when we desugar the do block we have that @@ -264,24 +278,27 @@ where We known in advance the result of processing `return res`, which is to do nothing. This allows us to prevent code duplication with the `onBind` function. -/ - let res ← onBind cfg (.str .anonymous "res") + let res ← onBind cfg state (.str .anonymous "res") match res.snd with | none => trace[Progress] "done" - pure res - | some mvarId => - let (info', mvarId) ← onFinish mvarId - pure (res.fst ++ info', mvarId) + pure res.fst + | some goal => + let info' ← onFinish goal + pure (res.fst ++ info') - onFinish (mvarId : MVarId) : TacticM (Info × Option MVarId) := do + onFinish (goal : MainGoal) : TacticM Info := do withTraceNode `Progress (fun _ => pure m!"onFinish") do - setGoals [mvarId] + setGoals [goal.goal] traceGoalWithNode "goal" /- Simplify a bit -/ let (info, mvarId) ← simplifyTarget match mvarId with - | none => pure (info, mvarId) + | none => pure info | some mvarId => + setGoals [mvarId] + /- Clean the goal to remove the assumptions introduced for `scalar_tac` -/ + goal.state.progressState.clean /- Attempt to finish with a tactic -/ -- `simp [*]` let simpTac : TacticM Syntax.Tactic := do @@ -318,11 +335,11 @@ where | none => tryFinish tacl let tac ← tryFinish [("simp [*]", simpTac), ("scalar_tac", scalarTac)] let info' : Info ← pure { script := #[tac], unsolvedGoals := (← getUnsolvedGoals).toArray} - pure (info ++ info', none) + pure (info ++ info') - onBind (cfg : Config) (varName : Name) : TacticM (Info × Option MVarId) := do + onBind (cfg : Config) (state : State) (varName : Name) : TacticM (Info × Option MainGoal) := do withTraceNode `Progress (fun _ => pure m!"onBind ({varName})") do - if let some {usedTheorem, preconditions, mainGoal } ← tryProgress then + if let some {usedTheorem, preconditions, mainGoal } ← tryProgress state then withTraceNode `Progress (fun _ => pure m!"progress succeeded") do match mainGoal with | none => trace[Progress] "Main goal solved" @@ -330,7 +347,7 @@ where withTraceNode `Progress (fun _ => pure m!"New main goal:") do trace[Progress] "{goal.goal}" trace[Progress] "Unsolved preconditions: {preconditions}" - let (preconditionTacs, unsolved) ← handleProgressPreconditions preconditions + let (preconditionTacs, unsolved) ← handleProgressPreconditions state preconditions if ¬ preconditionTacs.isEmpty then trace[Progress] "Found {preconditionTacs.size} preconditions, left {unsolved.size} unsolved" else @@ -340,8 +357,8 @@ where trace[Progress] "ids from used theorem: {ids}" let mainGoal ← do mainGoal.mapM fun mainGoal => do if ¬ ids.isEmpty then - renameInaccessibles mainGoal.goal ids -- NOTE: Taken from renameI tactic - else pure mainGoal.goal + pure { mainGoal with goal := ← renameInaccessibles mainGoal.goal ids } -- NOTE: Taken from renameI tactic + else pure mainGoal /- Generate the tactic scripts for the preconditions -/ let currTac ← if cfg.prettyPrintedProgress then @@ -356,26 +373,67 @@ where script := #[currTac]++ preconditionTacs, -- TODO: Optimize unsolvedGoals := unsolved, } + let mainGoal ← + match mainGoal with + | none => pure none + | some goal => + let progressState ← do + match goal.state with + | some state => pure state | none => throwError "Unreachable" + pure (some { + state := { state with progressState }, + goal := goal.goal, + }) pure (info, mainGoal) else - onFinish (← getMainGoal) + let info ← onFinish { goal := ← getMainGoal, state } + pure (info, none) - onBif (cfg : Config) (bfInfo : Bifurcation.Info) (toBeProcessed : Array (Array Name)): TacticM (List MVarId × (List Info → TacticM Info)) := do + onBif (cfg : Config) (state : State) (bfInfo : Bifurcation.Info) (toBeProcessed : Array (Array Name)) : + TacticM (List MainGoal × (List Info → TacticM Info)) := do withTraceNode `Progress (fun _ => pure m!"onBif") do trace[Progress] "onBif: encountered {bfInfo.kind}" if (←getGoals).isEmpty then + -- TODO: we shouldn't get there trace[Progress] "onBif: no goals to be solved!" -- Tactic.focus fails if there are no goals to be solved. return ({}, fun infos => assert! (infos.length == 0); pure {}) Tactic.focus do - let splitStx ← `(tactic| split) - evalSplit splitStx - -- - let subgoals ← getUnsolvedGoals + /- Split, then update the state with the new assumptions-/ + let some subgoals ← Utils.splitTarget? (← getMainGoal) + | trace[Progress] "Could not split" + -- TODO: make more robust + throwError "Could not split the goal" + let subgoals : List (Option (State × MVarId)) ← subgoals.mapM fun (mvarId, fvars) => do + setGoals [mvarId] + let { progressState } := state + match ← progressState.update fvars with + | none => pure none + | some (progressState, mvarId) => pure (some ({ progressState }, mvarId)) trace[Progress] "onBif: Bifurcation generated {subgoals.length} subgoals" unless subgoals.length == toBeProcessed.size do throwError "onBif: Expected {toBeProcessed.size} cases, found {subgoals.length}" - let infos_mkBranchesStx ← (subgoals.zip toBeProcessed.toList).mapM fun (sg, names) => do + + let (solved, nonSolved) := (subgoals.zip toBeProcessed.toList).partition (Option.isSome ∘ Prod.fst) + let solvedStx : Info ← pure ⟨ #[← `(tactic|scalar_tac)], #[] ⟩ + let rec split (subgoals : List (Option (State × MVarId) × Array Name)) : + TacticM (List (State × MVarId × Array Name) × (List Info → TacticM (List Info))) := do + match subgoals with + | [] => pure ([], fun _ => pure []) + | (sg, names) :: subgoals => + let (nonSolved, mkStx) ← split subgoals + match sg with + | none => pure (nonSolved, fun infos => do pure (solvedStx :: (← mkStx infos))) + | some (state, sg) => + pure ((state, sg, names) :: nonSolved, + fun infos => + match infos with + | [] => throwError "Unexpected" + | info :: infos => do + pure (info :: (← mkStx infos))) + let (subgoals, groupSolvedUnsolvedStx) ← split (subgoals.zip toBeProcessed.toList) + + let infos_mkBranchesStx ← subgoals.mapM fun (state, sg, names) => do setGoals [sg] -- TODO: rename the variables let mkStx (branchTacs : Array Syntax.Tactic) : TacticM (TSyntax `tactic) := do @@ -388,10 +446,11 @@ where let caseArgs := makeCaseArgs (←sg.getTag) names if cfg.useCase' then `(tactic| case' $caseArgs => $branchTacs*) else `(tactic|. $branchTacs*) - pure (sg, mkStx) + pure (({goal := sg, state} : MainGoal), mkStx) let (infos, mkBranchesStx) := infos_mkBranchesStx.unzip let mkStx (infos : List Info) : TacticM Info := do + -- Process the branches that were left to be solved unless infos.length == mkBranchesStx.length do throwError "onBif: Expected {mkBranchesStx.length} infos, found {infos.length}" let infos := infos.zip mkBranchesStx @@ -399,17 +458,26 @@ where let stx ← mkBranchStx info.script pure ({ unsolvedGoals := info.unsolvedGoals, script := #[stx] } : Info) + -- Add the branches which were already solved + let infos ← groupSolvedUnsolvedStx infos + -- Put everything together let infos := infos.foldr (fun info acc => info ++ acc) {} - pure (({script:=#[splitStx]} : Info) ++ infos) + pure (({script:=#[← `(tactic|split)]} : Info) ++ infos) return (infos, mkStx) - tryProgress := do - let state := none - try some <$> Progress.evalProgress state none (some (.str .anonymous "_")) none #[] none + tryProgress (state : State) := do + try some <$> Progress.evalProgress state.progressState none (some (.str .anonymous "_")) none #[] none catch _ => pure none - handleProgressPreconditions (preconditions : Array MVarId) : TacticM (Array Syntax.Tactic × Array MVarId) := do + handleProgressPreconditions (state : State) (preconditions : Array MVarId) : + TacticM (Array Syntax.Tactic × Array MVarId) := do + -- Cleanup the state + let preconditions ← preconditions.mapM fun g => do + setGoals [g] + state.progressState.clean + getMainGoal + -- if let .some tac := cfg.preconditionTac then let trySolve (sg : MVarId) : TacticM (Syntax.Tactic × Option MVarId) := do setGoals [sg] @@ -572,7 +640,7 @@ x y : U32 h✝ : b = true ⊢ ↑x + ↑y ≤ U32.max -case intro.hmax +case hmax b : Bool x y : U32 h✝ : b = true @@ -581,7 +649,7 @@ _ : [> let x2 ← x + y <] x2_post : ↑x2 = ↑x + ↑y ⊢ ↑x2 + ↑x2 ≤ U32.max -case intro.hmax +case hmax b : Bool x y : U32 h✝ : b = true @@ -599,7 +667,7 @@ x y : U32 h✝ : ¬b = true ⊢ ↑x + ↑y ≤ U32.max -case intro.hmax +case hmax b : Bool x y✝ : U32 h✝ : ¬b = true diff --git a/backends/lean/Aeneas/ScalarTac/ScalarTac.lean b/backends/lean/Aeneas/ScalarTac/ScalarTac.lean index 00de2f4cc..ea392ad86 100644 --- a/backends/lean/Aeneas/ScalarTac/ScalarTac.lean +++ b/backends/lean/Aeneas/ScalarTac/ScalarTac.lean @@ -300,27 +300,31 @@ def partialPreprocess (config : Config) (simpArgs : Simp.SimpArgs) (state : Stat typed expressions such as `UScalar.ofNat`. -/ traceGoalWithNode `ScalarTac "Original goal before preprocessing" - let r ← Simp.simpAt true {dsimp := false, failIfUnchanged := false, maxDischargeDepth := 1} - /- Remove the forall quantifiers to prepare for the call of `simp_all` (we - don't want `simp_all` to use assumptions of the shape `∀ x, P x`)) -/ - {simpArgs with addSimpThms := #[``forall_eq_forall']} - -- TODO: it would be good to always simplify the target before exploring it to saturate - (.targets assumptionsToPreprocess simpTarget) + let r ← + withTraceNode `ScalarTac (fun _ => pure m!"simpAt assumptionsToPreprocess simpTarget") do + Simp.simpAt true {dsimp := false, failIfUnchanged := false, maxDischargeDepth := 1} + /- Remove the forall quantifiers to prepare for the call of `simp_all` (we + don't want `simp_all` to use assumptions of the shape `∀ x, P x`)) -/ + {simpArgs with addSimpThms := #[``forall_eq_forall']} + -- TODO: it would be good to always simplify the target before exploring it to saturate + (.targets assumptionsToPreprocess simpTarget) -- We might have proven the goal let some assumptions := r | trace[ScalarTac] "Goal proven by preprocessing!"; return none - trace[ScalarTac] "Goal after first simplification: {← getMainGoal}" + traceGoalWithNode `ScalarTac "Goal after first simplification" -- Apply the forward rules let satConfig := config.toSaturateConfig let satConfig := { satConfig with saturateAssumptions := satConfig.saturateAssumptions && config.saturate, saturateTarget := satConfig.saturateTarget && config.saturate, } - scalarTacSaturateForward satConfig state.saturateState (some assumptions) fun saturateState nassumptions => do + let (saturateState, nassumptions) ← scalarTacSaturateForward satConfig state.saturateState (some assumptions) + fun saturateState nassumptions => pure (saturateState, nassumptions) + withMainContext do let state := { state with saturateState } let finish : Tactic.TacticM (Option (State × Array FVarId)) := do - trace[ScalarTac] "Goal after saturation: {← getMainGoal}" + traceGoalWithNode `ScalarTac "Goal after saturation" -- Apply `simpAll` to the *new assumptions* let applySimpAll assumptions simpTarget := do let r ← tryTactic? do @@ -344,51 +348,65 @@ def partialPreprocess (config : Config) (simpArgs : Simp.SimpArgs) (state : Stat let r ← do if config.simpAllMaxSteps ≠ 0 ∧ assumptionsToPreprocess.size + nassumptions.size > 0 then /- Simplify the new assumptions with the old assumptions (it's often enough to go in that direction) -/ - let some nassumptions ← Simp.simpAt true {dsimp := false, failIfUnchanged := false, maxDischargeDepth := 1} - {simpArgs with hypsToUse := hypsToUseForSimp ++ assumptions} (.targets nassumptions false) + let some nassumptions ← + withTraceNode `ScalarTac (fun _ => pure m!"simpAt: nassumptions ") do + Simp.simpAt true {dsimp := false, failIfUnchanged := false, maxDischargeDepth := 1} + {simpArgs with hypsToUse := hypsToUseForSimp ++ assumptions} (.targets nassumptions false) | trace[ScalarTac] "Goal proven by preprocessing!"; return none /- Apply `simpAll` on the new assumptions -/ - let some nassumptions ← applySimpAll nassumptions false + let some nassumptions ← + withTraceNode `ScalarTac (fun _ => pure m!"simpAll: nassumptions ") do + applySimpAll nassumptions false | trace[ScalarTac] "Goal proven by preprocessing!"; return none /- We apply `simpAll` to all the assumptions -/ + withTraceNode `ScalarTac (fun _ => pure m!"simpAll: all assumptions ") do applySimpAll (hypsToUseForSimp ++ assumptions ++ nassumptions) simpTarget else if hypsToUseForSimp.size > 0 && simpTarget then /- Even though there is nothing to preprocess, we want to simplify the goal by using the hypotheses to use, to make sure we propagate equalities for instance -/ - let some _ ← Simp.simpAt true {failIfUnchanged := false, maxDischargeDepth := 0} - { hypsToUse := hypsToUseForSimp } (.targets #[] simpTarget) + let some _ ← + withTraceNode `ScalarTac (fun _ => pure m!"simpAt (no assumptions to preprocess)") do + Simp.simpAt true {failIfUnchanged := false, maxDischargeDepth := 0} + { hypsToUse := hypsToUseForSimp } (.targets #[] simpTarget) | trace[ScalarTac] "Goal proven by preprocessing!"; return none pure (some assumptions) else pure (some assumptions) let some assumptions := r | trace[ScalarTac] "Goal proven by preprocessing!"; return none - trace[ScalarTac] "Goal after simpAll: {← getMainGoal}" + traceGoalWithNode `ScalarTac "Goal after simpAll" -- Call `simp` again, this time to inline the let-bindings (otherwise, omega doesn't always manage to deal with them) let r ← do + withTraceNode `ScalarTac (fun _ => pure m!"simpAt") do if zetaDelta then Simp.simpAt true {zetaDelta := true, failIfUnchanged := false, maxDischargeDepth := 1} simpArgs (.targets assumptions simpTarget) else pure assumptions -- We might have proven the goal let some assumptions := r | trace[ScalarTac] "Goal proven by preprocessing!"; return none - trace[ScalarTac] "Goal after 2nd simp (with zetaDelta): {← getMainGoal}" + traceGoalWithNode `ScalarTac "Goal after 2nd simp (with zetaDelta)" -- Apply normCast - let some assumptions ← Utils.normCastAt (.targets assumptions simpTarget) + let some assumptions ← + withTraceNode `ScalarTac (fun _ => pure m!"normCast") do + Utils.normCastAt (.targets assumptions simpTarget) | trace[ScalarTac] "Goal proven by preprocessing!"; return none - trace[ScalarTac] "Goal after normCast: {← getMainGoal}" + traceGoalWithNode `ScalarTac "Goal after normCast" -- Call `simp` again because `normCast` sometimes does weird things - let some assumptions ← Simp.simpAt true {failIfUnchanged := false, maxDischargeDepth := 1} simpArgs - (.targets assumptions simpTarget) - | trace[ScalarTac] "Goal proven by preprocessing!"; return none - trace[ScalarTac] "Goal after 2nd call to simpAt: {← getMainGoal}" + let some assumptions ← + withTraceNode `ScalarTac (fun _ => pure m!"simpAt") do + Simp.simpAt true {failIfUnchanged := false, maxDischargeDepth := 1} simpArgs + (.targets assumptions simpTarget) + | trace[ScalarTac] "Goal proven by preprocessing!"; return none + traceGoalWithNode `ScalarTac "Goal after 2nd call to simpAt" /- Remove the occurrences of `forall'` in the target -/ if simpTarget then - let some _ ← Simp.simpAt true {failIfUnchanged := false, maxDischargeDepth := 0} - { declsToUnfold := #[``forall'] } (.targets #[] simpTarget) - | trace[ScalarTac] "Goal proven by preprocessing!"; return none - trace[ScalarTac] "Goal after eliminating `forall'` in the target: {← getMainGoal}" + let some _ ← + withTraceNode `ScalarTac (fun _ => pure m!"simpAt") do + Simp.simpAt true {failIfUnchanged := false, maxDischargeDepth := 0} + { declsToUnfold := #[``forall'] } (.targets #[] simpTarget) + | trace[ScalarTac] "Goal proven by preprocessing!"; return none + traceGoalWithNode `ScalarTac "Goal after eliminating `forall'` in the target" -- We modified the assumptions in the context so we need to update the state accordingly let saturateState ← Saturate.recomputeAssumptions state.saturateState none assumptions let state := { state with saturateState } diff --git a/backends/lean/Aeneas/Simp/Simp.lean b/backends/lean/Aeneas/Simp/Simp.lean index a74f6ad5a..174ba3e79 100644 --- a/backends/lean/Aeneas/Simp/Simp.lean +++ b/backends/lean/Aeneas/Simp/Simp.lean @@ -136,6 +136,7 @@ where /- Call the simp tactic. -/ def simpAt (simpOnly : Bool) (config : Simp.Config) (args : SimpArgs) (loc : Utils.Location) : TacticM (Option (Array FVarId)) := do + withTraceNode `Simp (fun _ => pure m!"simpAt") do -- Initialize the simp context let (ctx, simprocs) ← mkSimpCtx simpOnly config .simp args -- Apply the simplifier @@ -146,6 +147,7 @@ def simpAt (simpOnly : Bool) (config : Simp.Config) (args : SimpArgs) (loc : Uti -/ def dsimpAt (simpOnly : Bool) (config : Simp.Config) (args : SimpArgs) (loc : Location) : TacticM Unit := do + withTraceNode `Simp (fun _ => pure m!"dsimpAt") do -- Initialize the simp context let (ctx, simprocs) ← mkSimpCtx simpOnly config .dsimp args -- Apply the simplifier @@ -154,6 +156,7 @@ def dsimpAt (simpOnly : Bool) (config : Simp.Config) (args : SimpArgs) (loc : Lo -- Call the simpAll tactic def simpAll (config : Simp.Config) (simpOnly : Bool) (args : SimpArgs) : TacticM Unit := do + withTraceNode `Simp (fun _ => pure m!"simpAll") do -- Initialize the simp context let (ctx, simprocs) ← mkSimpCtx simpOnly config .simpAll args -- Apply the simplifier diff --git a/backends/lean/Aeneas/Simp/SimpAllAssumptions.lean b/backends/lean/Aeneas/Simp/SimpAllAssumptions.lean index c56789564..901cc3b93 100644 --- a/backends/lean/Aeneas/Simp/SimpAllAssumptions.lean +++ b/backends/lean/Aeneas/Simp/SimpAllAssumptions.lean @@ -189,6 +189,7 @@ open Utils Simp Elab Tactic in def simpAllAssumptions (config : Simp.Config) (simpOnly : Bool) (args : SimpArgs) (mvarId : MVarId) (fvars : Array FVarId) (target : Bool) : MetaM (Option (Array FVarId × Std.HashMap FVarId FVarId × MVarId)) := do + withTraceNode `Simp (fun _ => pure m!"simpAllAssumptions") do -- Initialize the simp context let (ctx, simprocs) ← mkSimpCtx simpOnly config .simpAll args -- Apply the simplifier diff --git a/backends/lean/Aeneas/Utils.lean b/backends/lean/Aeneas/Utils.lean index 5a48ea327..52c811599 100644 --- a/backends/lean/Aeneas/Utils.lean +++ b/backends/lean/Aeneas/Utils.lean @@ -1680,6 +1680,23 @@ def traceGoalWithNode (cls : Name) (msg : String) : TacticM Unit := do addTrace cls m!"{← getMainGoal}" else pure () +def splitTarget? (mvarId : MVarId) (splitIte := true) : MetaM (Option (List (MVarId × Array FVarId))) := do + mvarId.withContext do + -- Split + let some mvars ← Lean.Meta.splitTarget? mvarId splitIte + | return none + -- Compute the set of fvar ids + let decl ← mvarId.getDecl + let decls ← pure ((← decl.lctx.getDecls).map LocalDecl.fvarId) + let declsSet := Std.HashSet.ofList decls + -- Check which fvars are new + let mvars ← mvars.mapM fun mvarId => do + let decl ← mvarId.getDecl + let decls ← pure ((← decl.lctx.getDecls).map LocalDecl.fvarId) + let decls := decls.filter fun d => ¬ (declsSet.contains d) + pure (mvarId, decls.toArray) + pure (some mvars) + end Utils end Aeneas From 44501badf939731fd8ae7468a1c218117cda9981 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 27 Jun 2025 19:31:11 +0200 Subject: [PATCH 30/31] Fix issues with ScalarTac.partialPreprocess --- backends/lean/Aeneas/Progress/Progress.lean | 33 +-- .../lean/Aeneas/Progress/ProgressStar.lean | 4 +- .../lean/Aeneas/ScalarTac/CondSimpTac.lean | 12 +- backends/lean/Aeneas/ScalarTac/ScalarTac.lean | 202 ++++++++++-------- backends/lean/Aeneas/Simp/Simp.lean | 3 + backends/lean/Aeneas/SimpLists/Tests.lean | 38 ++++ backends/lean/Aeneas/Utils.lean | 14 +- 7 files changed, 192 insertions(+), 114 deletions(-) diff --git a/backends/lean/Aeneas/Progress/Progress.lean b/backends/lean/Aeneas/Progress/Progress.lean index b7ab08ae7..2ff20dd57 100644 --- a/backends/lean/Aeneas/Progress/Progress.lean +++ b/backends/lean/Aeneas/Progress/Progress.lean @@ -352,7 +352,13 @@ def trySolvePreconditions (args : Args) (newPropGoals : List MVarId) : TacticM ( pure ((← Utils.getMVarIds ty).size, g)) let ordPropGoals := (ordPropGoals.mergeSort (fun (mvars0, _) (mvars1, _) => mvars0 ≤ mvars1)).reverse setGoals (ordPropGoals.map Prod.snd) - allGoalsNoRecover (tryTac args.assumTac) + allGoalsNoRecover (do + traceGoalWithNode "Attempting to solve with `singleAssumptionTac`" + try + args.assumTac + trace[Progress] "Goal solved!" + catch _ => trace[Progress] "Goal not solved" + ) allGoalsNoRecover args.solvePreconditionTac -- Make sure we preserve the order when presenting the preconditions to the user newPropGoals.filterMapM (fun g => do if ← g.isAssigned then pure none else pure (some g)) @@ -407,7 +413,7 @@ def State.update (state : State) (asms : Array FVarId) : TacticM (Option (State /- Preprocess -/ let some (scalarTacState, newPosts) ← ScalarTac.partialPreprocess scalarTacConfig (← ScalarTac.getSimpArgs) - scalarTacState (zetaDelta := true) + scalarTacState (zetaDelta := true) (simpAllAssumptions := false) (hypsToUseForSimp := state.scalarTacAsms) (assumptionsToPreprocess := newPosts) (simpTarget := false) @@ -567,7 +573,7 @@ def progressAsmsOrLookupTheorem (args : Args) (state : Option State) (withTh : O /- There might be uninstantiated meta-variables in the goal that we need to instantiate (otherwise we will get stuck). -/ let goalTy ← instantiateMVars goalTy - trace[Progress] "progressAsmsOrLookupTheorem: target: {goalTy}" + withTraceNode `Progress (fun _ => pure m!"target") do trace[Progress] "{goalTy}" /- Dive into the goal to lookup the theorem Remark: if we don't isolate the call to `withProgressSpec` to immediately "close" the terms immediately, we may end up with the error: @@ -671,7 +677,7 @@ def State.init : TacticM (Option State) := do let scalarTacState ← ScalarTac.State.new scalarTacConfig let some (scalarTacState, newAsms) ← ScalarTac.partialPreprocess scalarTacConfig (← ScalarTac.getSimpArgs) - scalarTacState (zetaDelta := true) + scalarTacState (zetaDelta := true) (simpAllAssumptions := false) (hypsToUseForSimp := #[]) (assumptionsToPreprocess := newAsms) (simpTarget := false) @@ -710,11 +716,14 @@ def evalProgress withMainContext do let splitPost := true /- Preprocessing step for `singleAssumptionTac` -/ - let singleAssumptionTacDtree ← do - let decls := + let singleAssumptionTacDtree ← + withTraceNode `Progress (fun _ => pure m!"preprocess singleAssumptionTac") do + let decls ← match state with - | none => none - | some state => some state.userAsms + | none => pure none + | some state => + trace[Progress] "Using state decls: {state.userAsms.map Expr.fvar}" + pure (some state.userAsms) singleAssumptionTacPreprocess decls /- For scalarTac we have a fast track: if the goal is not a linear arithmetic goal, we skip (note that otherwise, scalarTac would try @@ -730,7 +739,7 @@ def evalProgress ScalarTac.scalarTac scalarTacConfig | some state => ScalarTac.incrScalarTac scalarTacConfig state.scalarTacState (toClear := state.userAsms) - (assumptions := state.scalarTacAsms) + (assumptions := state.scalarTacAsms) (simpAllAssumptions := false) else throwError "Not a linear arithmetic goal" let simpLemmas ← Aeneas.ScalarTac.scalarTacSimpExt.getTheorems @@ -747,8 +756,8 @@ def evalProgress /- We use our custom assumption tactic, which instantiates meta-variables only if there is a single assumption matching the goal. -/ let customAssumTac : TacticM Unit := do - withTraceNode `Progress (fun _ => pure m!"Attempting to solve with `singleAssumptionTac`") do - singleAssumptionTacCore singleAssumptionTacDtree + let asms := state.map State.userAsms + singleAssumptionTacCore singleAssumptionTacDtree asms /- Also use the tactic provided by the user, if there is -/ let byTac := match byTac with | none => [] @@ -761,7 +770,7 @@ def evalProgress let solvePreconditionTac := withMainContext do withTraceNode `Progress (fun _ => pure m!"Trying to solve a precondition") do - trace[Progress] "Precondition: {← getMainGoal}" + traceGoalWithNode "Precondition" try firstTacSolve ([simpTac, scalarTac] ++ byTac) trace[Progress] "Precondition solved!" diff --git a/backends/lean/Aeneas/Progress/ProgressStar.lean b/backends/lean/Aeneas/Progress/ProgressStar.lean index e04937b50..4fcbdff6a 100644 --- a/backends/lean/Aeneas/Progress/ProgressStar.lean +++ b/backends/lean/Aeneas/Progress/ProgressStar.lean @@ -265,6 +265,7 @@ where onResult cfg state | .unknown => do trace[Progress] "don't know what to do: inserting a sorry" + state.progressState.clean return ({ script := #[←`(tactic| sorry)], unsolvedGoals := (← getUnsolvedGoals).toArray}) onResult (cfg : Config) (state : State) : TacticM Info := do @@ -340,7 +341,8 @@ where onBind (cfg : Config) (state : State) (varName : Name) : TacticM (Info × Option MainGoal) := do withTraceNode `Progress (fun _ => pure m!"onBind ({varName})") do if let some {usedTheorem, preconditions, mainGoal } ← tryProgress state then - withTraceNode `Progress (fun _ => pure m!"progress succeeded") do + trace[Progress] "progressSucceeded with {preconditions.size} unsolved preconditions" + withTraceNode `Progress (fun _ => pure m!"post-processing") do match mainGoal with | none => trace[Progress] "Main goal solved" | some goal => diff --git a/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean b/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean index 735b81070..3eba296d1 100644 --- a/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean +++ b/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean @@ -99,7 +99,9 @@ def condSimpTacSimp (config : Simp.Config) (args : CondSimpArgs) (loc : Utils.Lo trace[ScalarTac] "scalarTac assumptions: {asms.map Expr.fvar}" /- Note that when calling `scalar_tac` we saturate only by looking at the target: we have already saturated by looking at the assumptions (we do this once and for all beforehand) -/ - let dischargeWrapper ← Simp.tacticToDischarge (incrScalarTac {saturateAssumptions := false} state toClear asms) + let dischargeWrapper ← + Simp.tacticToDischarge (incrScalarTac {saturateAssumptions := false} state + toClear (simpAllAssumptions := false) asms) dischargeWrapper.with fun discharge? => do -- Initialize the simp context let (ctx, simprocs) ← Simp.mkSimpCtx true config .simp simpArgs @@ -151,7 +153,9 @@ def condSimpTacPreprocess (config : CondSimpTacConfig) (hypsArgs args : CondSimp /- First the hyps to use. Note that we do not inline the local let-declarations: we will do this only for the "regular" assumptions and the target. -/ - let some (_, hypsToUse) ← partialPreprocess scalarConfig hypsArgs.toSimpArgs state (zetaDelta := false) #[] hypsToUse false + let some (_, hypsToUse) ← + partialPreprocess scalarConfig hypsArgs.toSimpArgs state (zetaDelta := false) + (simpAllAssumptions := false) #[] hypsToUse false | trace[ScalarTac] "Goal proved through preprocessing!"; return none withMainContext do withTraceNode `ScalarTac (fun _ => pure m!"Goal after preprocessing the hyps to use ({hypsToUse.map Expr.fvar})") do @@ -166,7 +170,9 @@ def condSimpTacPreprocess (config : CondSimpTacConfig) (hypsArgs args : CondSimp withTraceNode `ScalarTac (fun _ => pure m!"Goal after simplifying the preprocessed hyps to use ({hypsToUse.map Expr.fvar})") do trace[ScalarTac] "{← getMainGoal}" /- Preprocess the "regular" assumptions -/ - let some (state, newAsms) ← partialPreprocess scalarConfig (← ScalarTac.getSimpArgs) state (zetaDelta := true) #[] newAsms false + let some (state, newAsms) ← + partialPreprocess scalarConfig (← ScalarTac.getSimpArgs) state (zetaDelta := true) + (simpAllAssumptions := false) #[] newAsms false | trace[ScalarTac] "Goal proved through preprocessing!"; return none withMainContext do traceGoalWithNode `ScalarTac "Goal after the initial preprocessing" diff --git a/backends/lean/Aeneas/ScalarTac/ScalarTac.lean b/backends/lean/Aeneas/ScalarTac/ScalarTac.lean index ea392ad86..5a1a3b20a 100644 --- a/backends/lean/Aeneas/ScalarTac/ScalarTac.lean +++ b/backends/lean/Aeneas/ScalarTac/ScalarTac.lean @@ -288,7 +288,8 @@ def State.new (config : Config) : MetaM State := do pure { saturateState } def partialPreprocess (config : Config) (simpArgs : Simp.SimpArgs) (state : State) - (zetaDelta : Bool) (hypsToUseForSimp assumptionsToPreprocess : Array FVarId) (simpTarget : Bool) : + (zetaDelta : Bool) (simpAllAssumptions : Bool) + (hypsToUseForSimp assumptionsToPreprocess : Array FVarId) (simpTarget : Bool) : Tactic.TacticM (Option (State × Array FVarId)) := do withTraceNode `ScalarTac (fun _ => pure m!"scalarTacPartialPreprocess") do Tactic.focus do @@ -300,7 +301,7 @@ def partialPreprocess (config : Config) (simpArgs : Simp.SimpArgs) (state : Stat typed expressions such as `UScalar.ofNat`. -/ traceGoalWithNode `ScalarTac "Original goal before preprocessing" - let r ← + let some assumptions ← withTraceNode `ScalarTac (fun _ => pure m!"simpAt assumptionsToPreprocess simpTarget") do Simp.simpAt true {dsimp := false, failIfUnchanged := false, maxDischargeDepth := 1} /- Remove the forall quantifiers to prepare for the call of `simp_all` (we @@ -308,8 +309,6 @@ def partialPreprocess (config : Config) (simpArgs : Simp.SimpArgs) (state : Stat {simpArgs with addSimpThms := #[``forall_eq_forall']} -- TODO: it would be good to always simplify the target before exploring it to saturate (.targets assumptionsToPreprocess simpTarget) - -- We might have proven the goal - let some assumptions := r | trace[ScalarTac] "Goal proven by preprocessing!"; return none traceGoalWithNode `ScalarTac "Goal after first simplification" -- Apply the forward rules @@ -318,101 +317,116 @@ def partialPreprocess (config : Config) (simpArgs : Simp.SimpArgs) (state : Stat saturateAssumptions := satConfig.saturateAssumptions && config.saturate, saturateTarget := satConfig.saturateTarget && config.saturate, } - let (saturateState, nassumptions) ← scalarTacSaturateForward satConfig state.saturateState (some assumptions) + let (saturateState, satAssumptions) ← scalarTacSaturateForward satConfig state.saturateState (some assumptions) fun saturateState nassumptions => pure (saturateState, nassumptions) withMainContext do let state := { state with saturateState } - - let finish : Tactic.TacticM (Option (State × Array FVarId)) := do - traceGoalWithNode `ScalarTac "Goal after saturation" - -- Apply `simpAll` to the *new assumptions* - let applySimpAll assumptions simpTarget := do - let r ← tryTactic? do - /- By setting the maxDischargeDepth at 0, we make sure that assumptions of the shape `∀ x, P x → ...` - will not have any effect. This is important because it often happens that the user instantiates - one such assumptions with specific arguments, meaning that if we call `simpAll` naively, those - instantiations will get simplified to `True` and thus eliminated. -/ - Simp.evalSimpAllAssumptions - {failIfUnchanged := false, maxSteps := config.simpAllMaxSteps, maxDischargeDepth := 0} true - simpArgs assumptions simpTarget + traceGoalWithNode `ScalarTac "Goal after saturation" + /- Simplify the new assumptions with the old assumptions (it's often enough to go in that direction) -/ + let some assumptions ← + withTraceNode `ScalarTac (fun _ => pure m!"simpAt") do + Simp.simpAt true {dsimp := false, failIfUnchanged := false, maxDischargeDepth := 1} + {simpArgs with hypsToUse := hypsToUseForSimp} (.targets assumptions false) + | trace[ScalarTac] "Goal proven by preprocessing!"; return none + traceGoalWithNode `ScalarTac "Goal after simplifying the new assumptions with the old" + /- Simplify the assumptions introduced through saturation with the old and new assumptions -/ + let some satAssumptions ← do + withTraceNode `ScalarTac (fun _ => pure m!"simpAt") do + Simp.simpAt true {dsimp := false, failIfUnchanged := false, maxDischargeDepth := 1} + {simpArgs with hypsToUse := hypsToUseForSimp ++ assumptions} (.targets satAssumptions false) + | trace[ScalarTac] "Goal proven by preprocessing!"; return none + traceGoalWithNode `ScalarTac "Goal after simplifying satAssumptions" + /- Small helper -/ + let applySimpAll assumptions simpTarget := do + let r ← tryTactic? do + /- By setting the maxDischargeDepth at 0, we make sure that assumptions of the shape `∀ x, P x → ...` + will not have any effect. This is important because it often happens that the user instantiates + one such assumptions with specific arguments, meaning that if we call `simpAll` naively, those + instantiations will get simplified to `True` and thus eliminated. -/ + Simp.evalSimpAllAssumptions + {failIfUnchanged := false, maxSteps := config.simpAllMaxSteps, maxDischargeDepth := 0} true + simpArgs assumptions simpTarget + match r with + | some r => + trace[ScalarTac] "applySimpAll succeeded" match r with - | some r => - trace[ScalarTac] "applySimpAll succeeded" - match r with - | some (fvars, _) => pure (some fvars) - | none => pure none -- Goal proven through `simpAll` - | none => - trace[ScalarTac] "applySimpAll failed" - -- `simpAll` failed: let's just continue with the same state as before - pure (some assumptions) - let r ← do - if config.simpAllMaxSteps ≠ 0 ∧ assumptionsToPreprocess.size + nassumptions.size > 0 then - /- Simplify the new assumptions with the old assumptions (it's often enough to go in that direction) -/ - let some nassumptions ← - withTraceNode `ScalarTac (fun _ => pure m!"simpAt: nassumptions ") do - Simp.simpAt true {dsimp := false, failIfUnchanged := false, maxDischargeDepth := 1} - {simpArgs with hypsToUse := hypsToUseForSimp ++ assumptions} (.targets nassumptions false) - | trace[ScalarTac] "Goal proven by preprocessing!"; return none - /- Apply `simpAll` on the new assumptions -/ - let some nassumptions ← - withTraceNode `ScalarTac (fun _ => pure m!"simpAll: nassumptions ") do - applySimpAll nassumptions false - | trace[ScalarTac] "Goal proven by preprocessing!"; return none - /- We apply `simpAll` to all the assumptions -/ - withTraceNode `ScalarTac (fun _ => pure m!"simpAll: all assumptions ") do - applySimpAll (hypsToUseForSimp ++ assumptions ++ nassumptions) simpTarget - else if hypsToUseForSimp.size > 0 && simpTarget then - /- Even though there is nothing to preprocess, we want to simplify the goal by using the hypotheses to use, - to make sure we propagate equalities for instance -/ - let some _ ← - withTraceNode `ScalarTac (fun _ => pure m!"simpAt (no assumptions to preprocess)") do - Simp.simpAt true {failIfUnchanged := false, maxDischargeDepth := 0} - { hypsToUse := hypsToUseForSimp } (.targets #[] simpTarget) - | trace[ScalarTac] "Goal proven by preprocessing!"; return none - pure (some assumptions) - else - pure (some assumptions) - let some assumptions := r - | trace[ScalarTac] "Goal proven by preprocessing!"; return none - traceGoalWithNode `ScalarTac "Goal after simpAll" + | some (fvars, _) => pure (some fvars) + | none => pure none -- Goal proven through `simpAll` + | none => + trace[ScalarTac] "applySimpAll failed" + -- `simpAll` failed: let's just continue with the same state as before + pure (some assumptions) + /- Apply simpAll to all the assumptions - we do it subset by subset: it is likely + that applying simp_all on the new assumptions will finish the work, meaning that + applying simp_all to *all* the assumptions will actually do nothing -/ + let some assumptions ← applySimpAll assumptions false + | trace[ScalarTac] "Goal proven by preprocessing!"; return none + traceGoalWithNode `ScalarTac "Goal after simpAll on assumptions" + let some satAssumptions ← applySimpAll satAssumptions false + | trace[ScalarTac] "Goal proven by preprocessing!"; return none + traceGoalWithNode `ScalarTac "Goal after simpAll on satAssumptions" + let assumptions := assumptions ++ satAssumptions + let some assumptions ← applySimpAll assumptions false + | trace[ScalarTac] "Goal proven by preprocessing!"; return none + traceGoalWithNode `ScalarTac "Goal after simpAll on assumptions and satAssumptions" + let some assumptions ← do + if config.simpAllMaxSteps ≠ 0 ∧ hypsToUseForSimp.size + assumptions.size > 0 ∧ simpAllAssumptions then + withTraceNode `ScalarTac (fun _ => pure m!"simpAll: all assumptions ") do + applySimpAll (hypsToUseForSimp ++ assumptions) simpTarget + else if (hypsToUseForSimp.size + assumptions.size > 0) && simpTarget then + /- Even though there is nothing to preprocess, we want to simplify the goal by using the `hypsToUseForSimp` + to use, to make sure we propagate equalities for instance -/ + let some _ ← + withTraceNode `ScalarTac (fun _ => pure m!"simpAt (no assumptions to preprocess)") do + Simp.simpAt true {failIfUnchanged := false, maxDischargeDepth := 0} + { hypsToUse := hypsToUseForSimp ++ assumptions } (.targets #[] simpTarget) + | trace[ScalarTac] "Goal proven by preprocessing!"; return none + pure (some assumptions) + else + pure (some assumptions) + | trace[ScalarTac] "Goal proven by preprocessing!"; return none + traceGoalWithNode `ScalarTac "Goal after simpAll" - -- Call `simp` again, this time to inline the let-bindings (otherwise, omega doesn't always manage to deal with them) - let r ← do - withTraceNode `ScalarTac (fun _ => pure m!"simpAt") do - if zetaDelta then - Simp.simpAt true {zetaDelta := true, failIfUnchanged := false, maxDischargeDepth := 1} simpArgs (.targets assumptions simpTarget) - else pure assumptions - -- We might have proven the goal - let some assumptions := r - | trace[ScalarTac] "Goal proven by preprocessing!"; return none - traceGoalWithNode `ScalarTac "Goal after 2nd simp (with zetaDelta)" - -- Apply normCast - let some assumptions ← - withTraceNode `ScalarTac (fun _ => pure m!"normCast") do - Utils.normCastAt (.targets assumptions simpTarget) + -- Call `simp` again, this time to inline the let-bindings (otherwise, omega doesn't always manage to deal with them) + let r ← do + withTraceNode `ScalarTac (fun _ => pure m!"simpAt") do + if zetaDelta then + Simp.simpAt true {zetaDelta := true, failIfUnchanged := false, maxDischargeDepth := 1} simpArgs + (.targets assumptions simpTarget) + else pure assumptions + -- We might have proven the goal + let some assumptions := r + | trace[ScalarTac] "Goal proven by preprocessing!"; return none + traceGoalWithNode `ScalarTac "Goal after 2nd simp (with zetaDelta)" + -- Apply normCast + let some assumptions ← + withTraceNode `ScalarTac (fun _ => pure m!"normCast") do + Utils.normCastAt (.targets assumptions simpTarget) + | trace[ScalarTac] "Goal proven by preprocessing!"; return none + traceGoalWithNode `ScalarTac "Goal after normCast" + -- Call `simp` again because `normCast` sometimes does weird things + let some assumptions ← + withTraceNode `ScalarTac (fun _ => pure m!"simpAt") do + Simp.simpAt true {failIfUnchanged := false, maxDischargeDepth := 1} simpArgs + (.targets assumptions simpTarget) | trace[ScalarTac] "Goal proven by preprocessing!"; return none - traceGoalWithNode `ScalarTac "Goal after normCast" - -- Call `simp` again because `normCast` sometimes does weird things - let some assumptions ← + traceGoalWithNode `ScalarTac "Goal after 2nd call to simpAt" + /- Remove the occurrences of `forall'` in the target -/ + if simpTarget then + let some _ ← withTraceNode `ScalarTac (fun _ => pure m!"simpAt") do - Simp.simpAt true {failIfUnchanged := false, maxDischargeDepth := 1} simpArgs - (.targets assumptions simpTarget) + Simp.simpAt true {failIfUnchanged := false, maxDischargeDepth := 0} + { declsToUnfold := #[``forall'] } (.targets #[] simpTarget) | trace[ScalarTac] "Goal proven by preprocessing!"; return none - traceGoalWithNode `ScalarTac "Goal after 2nd call to simpAt" - /- Remove the occurrences of `forall'` in the target -/ - if simpTarget then - let some _ ← - withTraceNode `ScalarTac (fun _ => pure m!"simpAt") do - Simp.simpAt true {failIfUnchanged := false, maxDischargeDepth := 0} - { declsToUnfold := #[``forall'] } (.targets #[] simpTarget) - | trace[ScalarTac] "Goal proven by preprocessing!"; return none - traceGoalWithNode `ScalarTac "Goal after eliminating `forall'` in the target" - -- We modified the assumptions in the context so we need to update the state accordingly - let saturateState ← Saturate.recomputeAssumptions state.saturateState none assumptions - let state := { state with saturateState } - -- We're done - return some (state, assumptions) - finish + traceGoalWithNode `ScalarTac "Goal after eliminating `forall'` in the target" + -- We modified the assumptions in the context so we need to update the state accordingly + let saturateState ← + Saturate.recomputeAssumptions state.saturateState none + -- TODO: do something smarter + (if simpAllAssumptions then assumptions else hypsToUseForSimp ++ assumptions) + let state := { state with saturateState } + -- We're done + return some (state, assumptions) elab "scalar_tac_preprocess" config:Parser.Tactic.optConfig : tactic => do let config ← elabConfig config @@ -499,7 +513,8 @@ elab "scalar_tac" config:Parser.Tactic.optConfig : tactic => do saturate the context. TODO: do we really need the config? -/ -def incrScalarTac (config : Config) (state : State) (toClear : Array FVarId) (assumptions : Array FVarId) : TacticM Unit := do +def incrScalarTac (config : Config) (state : State) (toClear : Array FVarId) + (simpAllAssumptions : Bool) (assumptions : Array FVarId) : TacticM Unit := do withTraceNode `ScalarTac (fun _ => pure m!"incrScalarTac") do Tactic.focus do Tactic.withMainContext do @@ -507,7 +522,8 @@ def incrScalarTac (config : Config) (state : State) (toClear : Array FVarId) (as let mvarId ← (← getMainGoal).tryClearMany toClear setGoals [mvarId] /- Saturate by exploring only the goal -/ - let some (_, _) ← partialPreprocess config (← ScalarTac.getSimpArgs) state (zetaDelta := true) assumptions #[] true + let some (_, _) ← partialPreprocess config (← ScalarTac.getSimpArgs) state (zetaDelta := true) + simpAllAssumptions assumptions #[] true | trace[ScalarTac] "incrScalarTac: goal proven by preprocessing" trace[ScalarTac] "Goal after preprocessing: {← getMainGoal}" /- Call omega -/ diff --git a/backends/lean/Aeneas/Simp/Simp.lean b/backends/lean/Aeneas/Simp/Simp.lean index 174ba3e79..3c51569f3 100644 --- a/backends/lean/Aeneas/Simp/Simp.lean +++ b/backends/lean/Aeneas/Simp/Simp.lean @@ -137,6 +137,7 @@ where def simpAt (simpOnly : Bool) (config : Simp.Config) (args : SimpArgs) (loc : Utils.Location) : TacticM (Option (Array FVarId)) := do withTraceNode `Simp (fun _ => pure m!"simpAt") do + withMainContext do -- Initialize the simp context let (ctx, simprocs) ← mkSimpCtx simpOnly config .simp args -- Apply the simplifier @@ -148,6 +149,7 @@ def simpAt (simpOnly : Bool) (config : Simp.Config) (args : SimpArgs) (loc : Uti def dsimpAt (simpOnly : Bool) (config : Simp.Config) (args : SimpArgs) (loc : Location) : TacticM Unit := do withTraceNode `Simp (fun _ => pure m!"dsimpAt") do + withMainContext do -- Initialize the simp context let (ctx, simprocs) ← mkSimpCtx simpOnly config .dsimp args -- Apply the simplifier @@ -157,6 +159,7 @@ def dsimpAt (simpOnly : Bool) (config : Simp.Config) (args : SimpArgs) (loc : Lo def simpAll (config : Simp.Config) (simpOnly : Bool) (args : SimpArgs) : TacticM Unit := do withTraceNode `Simp (fun _ => pure m!"simpAll") do + withMainContext do -- Initialize the simp context let (ctx, simprocs) ← mkSimpCtx simpOnly config .simpAll args -- Apply the simplifier diff --git a/backends/lean/Aeneas/SimpLists/Tests.lean b/backends/lean/Aeneas/SimpLists/Tests.lean index 134e55da5..9db890174 100644 --- a/backends/lean/Aeneas/SimpLists/Tests.lean +++ b/backends/lean/Aeneas/SimpLists/Tests.lean @@ -2,6 +2,8 @@ import Aeneas.SimpLists.SimpLists import Aeneas.List.List import Aeneas.Std.Slice +open Aeneas Std + example {α} [Inhabited α] (l : List α) (x : α) (i j : Nat) (hj : i ≠ j) : (l.set j x)[i]! = l[i]! := by simp_lists @@ -43,3 +45,39 @@ example (↑l[j]! : ZMod 3329) = (↑x : ZMod 3329) - (↑y : ZMod 3329) := by simp_lists [h] + +/- This example comes from SymCrypt -/ +set_option trace.profiler true in +example + (f : Std.Array U16 256#usize) + (g : Std.Array U16 256#usize) + (B0 : ℕ) + (B1 : ℕ) + (i0 : ℕ) + (i : Usize) + (c0 : U32) + (c1 : U32) + (paDst0 : Std.Array U32 256#usize) + (paDst : Std.Array U32 256#usize) + (hi0 : i0 < 128) + (hc0bound : (↑c0 : ℕ) ≤ B0 + B1) + (hc1bound : (↑c1 : ℕ) ≤ B0 + B1) + (hc1 : (↑c1 : ℕ) = + (↑paDst0[(↑i : ℕ) + 1]! : ℕ) + (↑f[(↑i : ℕ)]! : ℕ) * (↑g[(↑i : ℕ) + 1]! : ℕ) + + (↑f[(↑i : ℕ) + 1]! : ℕ) * (↑g[(↑i : ℕ)]! : ℕ)) + (hi : (↑i : ℕ) = 2 * i0) + (paDst1 : Std.Array U32 256#usize) + (paDst1_post : paDst1 = paDst.set i c0) + (i' : Usize) + (i'_post : (↑i' : ℕ) = (↑i : ℕ) + 1) + (paDst2 : Std.Array U32 256#usize) + (paDst2_post : paDst2 = paDst1.set i' c1) + (h0 : ∀ j < i0, (↑paDst[2 * j]! : ℕ) ≤ B0 + B1 ∧ (↑paDst[2 * j + 1]! : ℕ) ≤ B0 + B1) + (h1 : ∀ (j : ℕ), i0 ≤ j → j < 128 → (↑paDst0[2 * j]! : ℕ) ≤ B0 ∧ (↑paDst0[2 * j + 1]! : ℕ) ≤ B0) + (h3 : ∀ (j : ℕ), i0 ≤ j → j < 128 → paDst[2 * j]! = paDst0[2 * j]! ∧ paDst[2 * j + 1]! = paDst0[2 * j + 1]!) + (j : ℕ) + (hj : j < i0 + 1) + (hjeq : j = i0) : + (↑paDst2[2 * j]! : ℕ) ≤ B0 + B1 ∧ (↑paDst2[2 * j + 1]! : ℕ) ≤ B0 + B1 + := by + simp_lists [*] diff --git a/backends/lean/Aeneas/Utils.lean b/backends/lean/Aeneas/Utils.lean index 52c811599..68c92a85f 100644 --- a/backends/lean/Aeneas/Utils.lean +++ b/backends/lean/Aeneas/Utils.lean @@ -446,6 +446,7 @@ def filterAssumptionTacPreprocess (decls : Option (Array FVarId) := none) : Tact /-- Return `true` if managed to close goal `mvarId` using an assumption. -/ def filterAssumptionTacCore (dtree : DiscrTree FVarId) : TacticM Bool := do + withTraceNode `Utils (fun _ => pure m!"filterAssumptionTacCore") do withMainContext do let g ← getMainGoal let type ← instantiateMVars (← g.getType) @@ -490,9 +491,11 @@ example (x y : Nat) (_ : y * 3000 ≤ 1) (_ : x * 3000 ≤ 1) : y * 3000 ≤ 1 : fassumption -- List all the local declarations matching the goal -def getAllMatchingAssumptions (type : Expr) : MetaM (List (LocalDecl × Name)) := do +def getAllMatchingAssumptions (type : Expr) (candidates : Option (Array FVarId)) : MetaM (List (LocalDecl × Name)) := do let typeType ← inferType type - let decls ← (← getLCtx).getAllDecls + let decls ← match candidates with + | some candidates => pure (← candidates.mapM FVarId.getDecl).toList + | none => (← getLCtx).getDecls decls.filterMapM fun localDecl => do -- Make sure we revert the meta-variables instantiations by saving the state and restoring it let s ← saveState @@ -508,7 +511,8 @@ def getAllMatchingAssumptions (type : Expr) : MetaM (List (LocalDecl × Name)) : def singleAssumptionTacPreprocess (decls := none) := filterAssumptionTacPreprocess decls -def singleAssumptionTacCore (dtree : DiscrTree FVarId) : TacticM Unit := do +def singleAssumptionTacCore (dtree : DiscrTree FVarId) + (candidates : Option (Array FVarId)) : TacticM Unit := do withMainContext do let mvarId ← getMainGoal mvarId.checkNotAssigned `sassumption @@ -528,7 +532,7 @@ def singleAssumptionTacCore (dtree : DiscrTree FVarId) : TacticM Unit := do several times, but discrimination trees don't work if the expression we match over contains meta-variables. -/ - match ← (getAllMatchingAssumptions goal) with + match ← (getAllMatchingAssumptions goal candidates) with | [(localDecl, _)] => /- There is a single assumption which matches the goal: use it Note that we need to call isDefEq again to properly instantiate the meta-variables -/ @@ -549,7 +553,7 @@ def singleAssumptionTacCore (dtree : DiscrTree FVarId) : TacticM Unit := do -/ def singleAssumptionTac : TacticM Unit := do let dtree ← singleAssumptionTacPreprocess - singleAssumptionTacCore dtree + singleAssumptionTacCore dtree none elab "sassumption " : tactic => do singleAssumptionTac From 3f953fe95be939f0aa717ffb0f3851a75dd68149 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 27 Jun 2025 20:06:59 +0200 Subject: [PATCH 31/31] Fix an issue in partialPreprocess --- backends/lean/Aeneas/Progress/Progress.lean | 19 ++++++---- .../lean/Aeneas/Progress/ProgressStar.lean | 23 ++++++++++++ .../lean/Aeneas/ScalarTac/CondSimpTac.lean | 24 +++++++----- backends/lean/Aeneas/ScalarTac/ScalarTac.lean | 37 ++++++++----------- backends/lean/Aeneas/SimpLists/Tests.lean | 1 - 5 files changed, 64 insertions(+), 40 deletions(-) diff --git a/backends/lean/Aeneas/Progress/Progress.lean b/backends/lean/Aeneas/Progress/Progress.lean index 2ff20dd57..56b049513 100644 --- a/backends/lean/Aeneas/Progress/Progress.lean +++ b/backends/lean/Aeneas/Progress/Progress.lean @@ -109,6 +109,7 @@ structure State where scalarTacState : ScalarTac.State /-- Assumptions in the context which are only used by `scalar_tac` -/ scalarTacAsms : Array FVarId + scalarTacHypsForPres : Array FVarId /-- Assumptions which should be visible to the user (those are the declarations which were not introduced by `scalar_tac`) -/ userAsms : Array FVarId @@ -402,7 +403,7 @@ def checkAssumptionIsUsableWithProgress (fvarId : FVarId) : TacticM Bool := do def State.update (state : State) (asms : Array FVarId) : TacticM (Option (State × MVarId)) := do withTraceNode `Progress (fun _ => pure m!"State.update") do /- Add the post-conditions to the set of user declarations -/ - let { scalarTacState, scalarTacAsms, userAsms, progressAsms, recDecls } := state + let { scalarTacState, scalarTacAsms, scalarTacHypsForPres, userAsms, progressAsms, recDecls } := state let userAsms := userAsms ++ asms /- Filter the post-conditions which could be used as candidates for `progress` -/ let progressAsms' ← asms.filterM checkAssumptionIsUsableWithProgress @@ -411,16 +412,17 @@ def State.update (state : State) (asms : Array FVarId) : TacticM (Option (State /- Duplicate the post-conditions -/ let (_, newPosts) ← duplicateAssumptions asms /- Preprocess -/ - let some (scalarTacState, newPosts) ← + let some (scalarTacState, postsForPres, newPosts) ← ScalarTac.partialPreprocess scalarTacConfig (← ScalarTac.getSimpArgs) - scalarTacState (zetaDelta := true) (simpAllAssumptions := false) + scalarTacState (zetaDelta := true) (hypsToUseForSimp := state.scalarTacAsms) (assumptionsToPreprocess := newPosts) (simpTarget := false) | trace[Progress] "Goal proven by preprocessing!"; return none let scalarTacAsms := scalarTacAsms ++ newPosts + let scalarTacHypsForPres := scalarTacHypsForPres ++ postsForPres let mainGoal ← getMainGoal - let state := { scalarTacState, scalarTacAsms, userAsms, progressAsms, recDecls } + let state := { scalarTacState, scalarTacAsms, scalarTacHypsForPres, userAsms, progressAsms, recDecls } pure (some (state, mainGoal)) def updateState (mainGoal : Option MainGoal) (state : Option State) : TacticM (Option MainGoalWithState) := do @@ -675,9 +677,9 @@ def State.init : TacticM (Option State) := do /- Duplicate and preprocess the assumptions -/ let (oldAsms, newAsms) ← duplicateAssumptions let scalarTacState ← ScalarTac.State.new scalarTacConfig - let some (scalarTacState, newAsms) ← + let some (scalarTacState, asmsForPres, newAsms) ← ScalarTac.partialPreprocess scalarTacConfig (← ScalarTac.getSimpArgs) - scalarTacState (zetaDelta := true) (simpAllAssumptions := false) + scalarTacState (zetaDelta := true) (hypsToUseForSimp := #[]) (assumptionsToPreprocess := newAsms) (simpTarget := false) @@ -688,6 +690,7 @@ def State.init : TacticM (Option State) := do pure (some { scalarTacState, scalarTacAsms := newAsms, + scalarTacHypsForPres := asmsForPres, userAsms := oldAsms, progressAsms, recDecls, @@ -696,7 +699,7 @@ def State.init : TacticM (Option State) := do /-- Cleanup the context to remove all the auxiliary assumptions introduced for the progress state and which should not be shown to the user -/ def State.clean (state : State) : TacticM Unit := do - setGoals [← (← getMainGoal).tryClearMany state.scalarTacAsms] + setGoals [← (← getMainGoal).tryClearMany (state.scalarTacAsms ++ state.scalarTacHypsForPres)] def cleanState (state : Option State) : TacticM Unit := do match state with @@ -739,7 +742,7 @@ def evalProgress ScalarTac.scalarTac scalarTacConfig | some state => ScalarTac.incrScalarTac scalarTacConfig state.scalarTacState (toClear := state.userAsms) - (assumptions := state.scalarTacAsms) (simpAllAssumptions := false) + (assumptions := state.scalarTacAsms) else throwError "Not a linear arithmetic goal" let simpLemmas ← Aeneas.ScalarTac.scalarTacSimpExt.getTheorems diff --git a/backends/lean/Aeneas/Progress/ProgressStar.lean b/backends/lean/Aeneas/Progress/ProgressStar.lean index 4fcbdff6a..18bc277a5 100644 --- a/backends/lean/Aeneas/Progress/ProgressStar.lean +++ b/backends/lean/Aeneas/Progress/ProgressStar.lean @@ -684,6 +684,29 @@ example b (x y : U32) : unfold add2 progress*? +def add16 (x : U32) := do + let y ← x + x + let y ← y + x + let y ← y + x + let y ← y + x + let y ← y + x + let y ← y + x + let y ← y + x + let y ← y + x + let y ← y + x + let y ← y + x + let y ← y + x + let y ← y + x + let y ← y + x + let y ← y + x + let y ← y + x + pure y + +example (x : U32) (h : 16 * x ≤ U32.max) : + ∃ y, add16 x = ok y := by + unfold add16 + progress* + end Examples end Aeneas diff --git a/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean b/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean index 3eba296d1..9b3494dd2 100644 --- a/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean +++ b/backends/lean/Aeneas/ScalarTac/CondSimpTac.lean @@ -100,8 +100,7 @@ def condSimpTacSimp (config : Simp.Config) (args : CondSimpArgs) (loc : Utils.Lo /- Note that when calling `scalar_tac` we saturate only by looking at the target: we have already saturated by looking at the assumptions (we do this once and for all beforehand) -/ let dischargeWrapper ← - Simp.tacticToDischarge (incrScalarTac {saturateAssumptions := false} state - toClear (simpAllAssumptions := false) asms) + Simp.tacticToDischarge (incrScalarTac {saturateAssumptions := false} state toClear asms) dischargeWrapper.with fun discharge? => do -- Initialize the simp context let (ctx, simprocs) ← Simp.mkSimpCtx true config .simp simpArgs @@ -117,6 +116,7 @@ structure PreprocessResult where state : State oldAsms : Array FVarId newAsms : Array FVarId + asmsForPres : Array FVarId additionalSimpThms : Array FVarId /-- Preprocess the goal. @@ -143,7 +143,6 @@ def condSimpTacPreprocess (config : CondSimpTacConfig) (hypsArgs args : CondSimp trace[ScalarTac] "hypsToUse: {hypsToUse.map Expr.fvar}" /- -/ let (oldAsms, newAsms) ← Utils.duplicateAssumptions (some (allAssumptions.map LocalDecl.fvarId)) - let toClear := oldAsms withMainContext do traceGoalWithNode `ScalarTac "Goal after duplicating the assumptions" trace[ScalarTac] "newAsms: {newAsms.map Expr.fvar}" @@ -153,9 +152,9 @@ def condSimpTacPreprocess (config : CondSimpTacConfig) (hypsArgs args : CondSimp /- First the hyps to use. Note that we do not inline the local let-declarations: we will do this only for the "regular" assumptions and the target. -/ - let some (_, hypsToUse) ← + let some (_, assumptionsForPres, hypsToUse) ← partialPreprocess scalarConfig hypsArgs.toSimpArgs state (zetaDelta := false) - (simpAllAssumptions := false) #[] hypsToUse false + #[] hypsToUse false | trace[ScalarTac] "Goal proved through preprocessing!"; return none withMainContext do withTraceNode `ScalarTac (fun _ => pure m!"Goal after preprocessing the hyps to use ({hypsToUse.map Expr.fvar})") do @@ -170,9 +169,9 @@ def condSimpTacPreprocess (config : CondSimpTacConfig) (hypsArgs args : CondSimp withTraceNode `ScalarTac (fun _ => pure m!"Goal after simplifying the preprocessed hyps to use ({hypsToUse.map Expr.fvar})") do trace[ScalarTac] "{← getMainGoal}" /- Preprocess the "regular" assumptions -/ - let some (state, newAsms) ← + let some (state, assumptionsForPres', newAsms) ← partialPreprocess scalarConfig (← ScalarTac.getSimpArgs) state (zetaDelta := true) - (simpAllAssumptions := false) #[] newAsms false + #[] newAsms false | trace[ScalarTac] "Goal proved through preprocessing!"; return none withMainContext do traceGoalWithNode `ScalarTac "Goal after the initial preprocessing" @@ -180,7 +179,9 @@ def condSimpTacPreprocess (config : CondSimpTacConfig) (hypsArgs args : CondSimp /- Introduce the additional simp theorems -/ let additionalSimpThms ← addSimpThms traceGoalWithNode `ScalarTac "Goal after adding the additional simp assumptions" - pure (some { args, toClear, hypsToUse, state, oldAsms, newAsms, additionalSimpThms }) + let toClear := oldAsms + let asmsForPres := assumptionsForPres ++ assumptionsForPres' + pure (some { args, toClear, hypsToUse, state, oldAsms, newAsms, asmsForPres, additionalSimpThms }) def condSimpTacCore (tacName : String) @@ -189,7 +190,7 @@ def condSimpTacCore (loc : Utils.Location) (res : PreprocessResult) : TacticM (Option MVarId) := do withTraceNode `ScalarTac (fun _ => pure m!"CondSimpTacCore") do - let { args, toClear, hypsToUse := _, state, oldAsms, newAsms, additionalSimpThms } := res + let { args, toClear, hypsToUse := _, state, oldAsms, newAsms, asmsForPres, additionalSimpThms } := res /- Simplify the targets (note that we preserve the new assumptions for `scalar_tac`) -/ withMainContext do let loc ← do @@ -212,7 +213,8 @@ def condSimpTacCore TODO: scalar_tac should only be allowed to preprocess `scalarTacAsms`. TODO: we should preprocess those. -/ - let _ ← condSimpTacSimp simpConfig args nloc (toClear := toClear) (additionalHypsToUse := additionalSimpThms) (some (state, newAsms)) + let _ ← condSimpTacSimp simpConfig args nloc (toClear := toClear ++ asmsForPres) + (additionalHypsToUse := additionalSimpThms) (some (state, newAsms)) if (← getUnsolvedGoals) == [] then pure none else pure (some (← getMainGoal)) @@ -220,6 +222,8 @@ def condSimpTacClear (res : PreprocessResult) : TacticM Unit := do withTraceNode `ScalarTac (fun _ => pure m!"CondSimpTacClear") do setGoals [← (← getMainGoal).tryClearMany res.hypsToUse] traceGoalWithNode `ScalarTac "Goal after clearing the duplicated hypotheses to use" + setGoals [← (← getMainGoal).tryClearMany res.asmsForPres] + traceGoalWithNode `ScalarTac "Goal after clearing the asmsForPres" setGoals [← (← getMainGoal).tryClearMany res.newAsms] traceGoalWithNode `ScalarTac "Goal after clearing the duplicated assumptions" setGoals [← (← getMainGoal).tryClearMany res.additionalSimpThms] diff --git a/backends/lean/Aeneas/ScalarTac/ScalarTac.lean b/backends/lean/Aeneas/ScalarTac/ScalarTac.lean index 5a1a3b20a..36acda5c2 100644 --- a/backends/lean/Aeneas/ScalarTac/ScalarTac.lean +++ b/backends/lean/Aeneas/ScalarTac/ScalarTac.lean @@ -288,9 +288,8 @@ def State.new (config : Config) : MetaM State := do pure { saturateState } def partialPreprocess (config : Config) (simpArgs : Simp.SimpArgs) (state : State) - (zetaDelta : Bool) (simpAllAssumptions : Bool) - (hypsToUseForSimp assumptionsToPreprocess : Array FVarId) (simpTarget : Bool) : - Tactic.TacticM (Option (State × Array FVarId)) := do + (zetaDelta : Bool) (hypsToUseForSimp assumptionsToPreprocess : Array FVarId) (simpTarget : Bool) : + Tactic.TacticM (Option (State × Array FVarId × Array FVarId)) := do withTraceNode `ScalarTac (fun _ => pure m!"scalarTacPartialPreprocess") do Tactic.focus do Tactic.withMainContext do @@ -301,7 +300,7 @@ def partialPreprocess (config : Config) (simpArgs : Simp.SimpArgs) (state : Stat typed expressions such as `UScalar.ofNat`. -/ traceGoalWithNode `ScalarTac "Original goal before preprocessing" - let some assumptions ← + let some assumptionsForPres ← withTraceNode `ScalarTac (fun _ => pure m!"simpAt assumptionsToPreprocess simpTarget") do Simp.simpAt true {dsimp := false, failIfUnchanged := false, maxDischargeDepth := 1} /- Remove the forall quantifiers to prepare for the call of `simp_all` (we @@ -311,15 +310,20 @@ def partialPreprocess (config : Config) (simpArgs : Simp.SimpArgs) (state : Stat (.targets assumptionsToPreprocess simpTarget) | trace[ScalarTac] "Goal proven by preprocessing!"; return none traceGoalWithNode `ScalarTac "Goal after first simplification" + let (_, assumptions) ← duplicateAssumptions assumptionsForPres -- Apply the forward rules let satConfig := config.toSaturateConfig let satConfig := { satConfig with saturateAssumptions := satConfig.saturateAssumptions && config.saturate, saturateTarget := satConfig.saturateTarget && config.saturate, } - let (saturateState, satAssumptions) ← scalarTacSaturateForward satConfig state.saturateState (some assumptions) + let (saturateState, satAssumptions) ← + scalarTacSaturateForward satConfig state.saturateState (some assumptionsForPres) fun saturateState nassumptions => pure (saturateState, nassumptions) withMainContext do + -- TODO: we want to post-process the assumptionsForPres': `saturate` should give us this possibility + let (assumptionsForPres', satAssumptions) ← duplicateAssumptions satAssumptions + let assumptionsForPres := assumptionsForPres ++ assumptionsForPres' let state := { state with saturateState } traceGoalWithNode `ScalarTac "Goal after saturation" /- Simplify the new assumptions with the old assumptions (it's often enough to go in that direction) -/ @@ -370,12 +374,9 @@ def partialPreprocess (config : Config) (simpArgs : Simp.SimpArgs) (state : Stat | trace[ScalarTac] "Goal proven by preprocessing!"; return none traceGoalWithNode `ScalarTac "Goal after simpAll on assumptions and satAssumptions" let some assumptions ← do - if config.simpAllMaxSteps ≠ 0 ∧ hypsToUseForSimp.size + assumptions.size > 0 ∧ simpAllAssumptions then - withTraceNode `ScalarTac (fun _ => pure m!"simpAll: all assumptions ") do - applySimpAll (hypsToUseForSimp ++ assumptions) simpTarget - else if (hypsToUseForSimp.size + assumptions.size > 0) && simpTarget then - /- Even though there is nothing to preprocess, we want to simplify the goal by using the `hypsToUseForSimp` - to use, to make sure we propagate equalities for instance -/ + /- Even though there is nothing to preprocess, we want to simplify the goal by using the `hypsToUseForSimp` + to use, to make sure we propagate equalities for instance -/ + if (hypsToUseForSimp.size + assumptions.size > 0) && simpTarget then let some _ ← withTraceNode `ScalarTac (fun _ => pure m!"simpAt (no assumptions to preprocess)") do Simp.simpAt true {failIfUnchanged := false, maxDischargeDepth := 0} @@ -419,14 +420,8 @@ def partialPreprocess (config : Config) (simpArgs : Simp.SimpArgs) (state : Stat { declsToUnfold := #[``forall'] } (.targets #[] simpTarget) | trace[ScalarTac] "Goal proven by preprocessing!"; return none traceGoalWithNode `ScalarTac "Goal after eliminating `forall'` in the target" - -- We modified the assumptions in the context so we need to update the state accordingly - let saturateState ← - Saturate.recomputeAssumptions state.saturateState none - -- TODO: do something smarter - (if simpAllAssumptions then assumptions else hypsToUseForSimp ++ assumptions) - let state := { state with saturateState } -- We're done - return some (state, assumptions) + return some (state, assumptionsForPres, assumptions) elab "scalar_tac_preprocess" config:Parser.Tactic.optConfig : tactic => do let config ← elabConfig config @@ -514,7 +509,7 @@ elab "scalar_tac" config:Parser.Tactic.optConfig : tactic => do TODO: do we really need the config? -/ def incrScalarTac (config : Config) (state : State) (toClear : Array FVarId) - (simpAllAssumptions : Bool) (assumptions : Array FVarId) : TacticM Unit := do + (assumptions : Array FVarId) : TacticM Unit := do withTraceNode `ScalarTac (fun _ => pure m!"incrScalarTac") do Tactic.focus do Tactic.withMainContext do @@ -522,8 +517,8 @@ def incrScalarTac (config : Config) (state : State) (toClear : Array FVarId) let mvarId ← (← getMainGoal).tryClearMany toClear setGoals [mvarId] /- Saturate by exploring only the goal -/ - let some (_, _) ← partialPreprocess config (← ScalarTac.getSimpArgs) state (zetaDelta := true) - simpAllAssumptions assumptions #[] true + let some _ ← partialPreprocess config (← ScalarTac.getSimpArgs) state (zetaDelta := true) + assumptions #[] true | trace[ScalarTac] "incrScalarTac: goal proven by preprocessing" trace[ScalarTac] "Goal after preprocessing: {← getMainGoal}" /- Call omega -/ diff --git a/backends/lean/Aeneas/SimpLists/Tests.lean b/backends/lean/Aeneas/SimpLists/Tests.lean index 9db890174..d7bb68777 100644 --- a/backends/lean/Aeneas/SimpLists/Tests.lean +++ b/backends/lean/Aeneas/SimpLists/Tests.lean @@ -47,7 +47,6 @@ example simp_lists [h] /- This example comes from SymCrypt -/ -set_option trace.profiler true in example (f : Std.Array U16 256#usize) (g : Std.Array U16 256#usize)