@@ -6,6 +6,7 @@ Authors: Leonardo de Moura
66module
77prelude
88public import Lean.Meta.Tactic.Grind.AC.Util
9+ public import Lean.Meta.Tactic.Grind.CheckResult
910import Lean.Meta.Tactic.Grind.AC.DenoteExpr
1011import Lean.Meta.Tactic.Grind.AC.Proof
1112import Lean.Meta.Tactic.Grind.AC.Seq
@@ -445,60 +446,65 @@ def mkEqData (e : Expr) (r : AC.Expr) : ACM EqData := do
445446
446447abbrev PropagateEqMap := Std.HashMap AC.Seq EqData
447448
448- private def propagateEqs : ACM Unit := do
449- if (← isInconsistent) then return ()
449+ private def propagateEqs : ACM Bool := do
450+ if (← isInconsistent) then return false
450451 /-
451452 This is a very simple procedure that does not use any indexing data-structure.
452453 We don't even cache the simplified expressions.
453454 TODO: optimize
454455 -/
455- let mut map : PropagateEqMap := {}
456- for e in (← getStruct).vars do
457- if (← checkMaxSteps) then return ()
458- let r ← asACExpr e
459- map ← process map e r
460- for (e, r) in (← getStruct).denoteEntries do
461- if (← checkMaxSteps) then return ()
462- map ← process map e r
456+ let go : StateT (Bool × PropagateEqMap) ACM Unit := do
457+ for e in (← getStruct).vars do
458+ if (← checkMaxSteps) then return
459+ let r ← asACExpr e
460+ process e r
461+ for (e, r) in (← getStruct).denoteEntries do
462+ if (← checkMaxSteps) then return
463+ process e r
464+ let (_, (propagated, _)) ← go.run (false , {})
465+ return propagated
463466where
464- process (map : PropagateEqMap) ( e : Expr) (r : AC.Expr) : ACM PropagateEqMap := do
467+ process (e : Expr) (r : AC.Expr) : StateT (Bool × PropagateEqMap) ACM Unit := do
465468 let d ← mkEqData e r
466469 let s := d.c.rhs
467470 trace[grind.debug.ac.eq] "{e}, s: {← s.denoteExpr}"
468- if let some d' := map [s]? then
471+ if let some d' := (← get). 2 [s]? then
469472 trace[grind.debug.ac.eq] "found [{← isEqv d.e d'.e}]: {d.e}, {d'.e}"
470473 unless (← isEqv d.e d'.e) do
471474 propagateEq d.e d'.e d.r d'.r d.c d'.c
472- return map
475+ modify fun s => ( true , s. 2 )
473476 else
474- return map.insert s d
477+ modify fun (propagated, map) => (propagated, map .insert s d)
475478
476- private def checkStruct : ACM Bool := do
477- unless (← needCheck) do return false
479+ private def checkStruct : ACM CheckResult := do
480+ unless (← needCheck) do return .none
478481 trace_goal[grind.debug.ac.check] "{(← getStruct).op}"
479482 repeat
480483 checkSystem "ac"
481484 let some c ← getNext? | break
482485 trace_goal[grind.debug.ac.check] "{← c.denoteExpr}"
483486 c.addToBasis
484- if (← isInconsistent) then return true
485- if (← checkMaxSteps) then return true
487+ if (← isInconsistent) then return .closed
488+ if (← checkMaxSteps) then return .progress
486489 checkDiseqs
487- propagateEqs
488490 modifyStruct fun s => { s with recheck := false }
489- return true
491+ if (← propagateEqs) then return .propagated
492+ return .progress
490493
491- def check : GoalM Bool := do profileitM Exception "grind ac" (← getOptions) do
492- if (← checkMaxSteps) then return false
493- let mut progress := false
494+ def check : GoalM CheckResult := do profileitM Exception "grind ac" (← getOptions) do
495+ if (← checkMaxSteps) then return .none
496+ let mut result : CheckResult := .none
494497 checkInvariants
495498 try
496499 for opId in *...(← get').structs.size do
497500 let r ← ACM.run opId checkStruct
498- progress := progress || r
499- if (← isInconsistent) then return true
500- return progress
501+ result := result.join r
502+ if (← isInconsistent) then return .closed
503+ return result
501504 finally
502505 checkInvariants
503506
507+ def check' : GoalM Bool :=
508+ return (← check) != .none
509+
504510end Lean.Meta.Grind.AC
0 commit comments