@@ -211,30 +211,47 @@ where
211211 /- Check whether new path: `i -(k₁)-> u -(k)-> v -(k₂) -> j` is shorter -/
212212 updateIfShorter i j (k₁+k+k₂) v
213213
214- def assertIneqTrue (c : Cnstr NodeId) (e : Expr) : OrderM Unit := do
214+ /--
215+ Asserts constraint `c` associated with `e` where `he : e = True`.
216+ -/
217+ def assertIneqTrue (c : Cnstr NodeId) (e : Expr) (he : Expr) : OrderM Unit := do
215218 trace[grind.order.assert] "{← c.pp}"
216- let he ← mkEqTrueProof e
217219 let h ← if let some h := c.h? then
218220 pure <| mkApp4 (mkConst ``Grind.Order.eq_mp) e c.e h (mkOfEqTrueCore e he)
219221 else
220222 pure <| mkOfEqTrueCore e he
221223 let k : Weight := { k := c.k, strict := c.kind matches .lt }
222224 addEdge c.u c.v k h
223225
224- def assertIneqFalse (c : Cnstr NodeId) (_e : Expr) : OrderM Unit := do
226+ /--
227+ Asserts constraint `c` associated with `e` where `he : e = False`.
228+ -/
229+ def assertIneqFalse (c : Cnstr NodeId) (_e : Expr) (_he : Expr) : OrderM Unit := do
225230 trace[grind.order.assert] "¬ {← c.pp}"
226231
227232def getStructIdOf? (e : Expr) : GoalM (Option Nat) := do
228233 return (← get').exprToStructId.find? { expr := e }
229234
230235def propagateIneq (e : Expr) : GoalM Unit := do
231- let some structId ← getStructIdOf? e | return ()
232- OrderM.run structId do
233- let some c ← getCnstr? e | return ()
234- if (← isEqTrue e) then
235- assertIneqTrue c e
236- else if (← isEqFalse e) then
237- assertIneqFalse c e
236+ if let some (e', he) := (← get').cnstrsMap.find? { expr := e } then
237+ go e' (some he)
238+ else
239+ go e none
240+ where
241+ go (e' : Expr) (he? : Option Expr) : GoalM Unit := do
242+ let some structId ← getStructIdOf? e' | return ()
243+ OrderM.run structId do
244+ let some c ← getCnstr? e' | return ()
245+ if (← isEqTrue e) then
246+ let mut h ← mkEqTrueProof e
247+ if let some he := he? then
248+ h := mkApp4 (mkConst ``Grind.Order.eq_trans_true') e e' he h
249+ assertIneqTrue c e' h
250+ else if (← isEqFalse e) then
251+ let mut h ← mkEqFalseProof e
252+ if let some he := he? then
253+ h := mkApp4 (mkConst ``Grind.Order.eq_trans_false') e e' he h
254+ assertIneqFalse c e' h
238255
239256builtin_grind_propagator propagateLE ↓LE.le := propagateIneq
240257builtin_grind_propagator propagateLT ↓LT.lt := propagateIneq
0 commit comments