@@ -129,11 +129,42 @@ def propagatePending : OrderM Unit := do
129129 | .eq u v =>
130130 let ue ← getExpr u
131131 let ve ← getExpr v
132- unless (← isEqv ue ve) do
132+ if (← alreadyInternalized ue <&&> alreadyInternalized ve) then
133+ unless (← isEqv ue ve) do
134+ let huv ← mkProofForPath u v
135+ let hvu ← mkProofForPath v u
136+ let h ← mkEqProofOfLeOfLe ue ve huv hvu
137+ pushEq ue ve h
138+ -- Checks whether `ue` and `ve` are auxiliary terms
139+ let some (ue', h₁) ← getOriginal? ue | continue
140+ let some (ve', h₂) ← getOriginal? ve | continue
141+ if (← alreadyInternalized ue' <&&> alreadyInternalized ve') then
142+ unless (← isEqv ue' ve') do
133143 let huv ← mkProofForPath u v
134144 let hvu ← mkProofForPath v u
135145 let h ← mkEqProofOfLeOfLe ue ve huv hvu
136- pushEq ue ve h
146+ /-
147+ We have
148+ - `h₁ : ↑ue' = ue`
149+ - `h₂ : ↑ve' = ve`
150+ - `h : ue = ve`
151+ -/
152+ pushEq ue' ve' <| mkApp7 (mkConst ``Grind.Order.nat_eq) ue' ve' ue ve h₁ h₂ h
153+ where
154+ /--
155+ If `e` is an auxiliary term used to represent some term `a`, returns
156+ `some (a, h)` s.t. `h : ↑a = e`
157+ **Note** : We currently only support `Nat`. Thus `↑a` is actually
158+ `NatCast.natCast a`. If we decide to support arbitrary semirings
159+ in this module, we must adjust this code.
160+ -/
161+ getOriginal? (e : Expr) : GoalM (Option (Expr × Expr)) := do
162+ if let some r := (← get').termMapInv.find? { expr := e } then
163+ return some r
164+ else
165+ let_expr NatCast.natCast _ _ a := e | return none
166+ let h ← mkEqRefl e
167+ return some (a, h)
137168
138169/--
139170Returns `true` if `e` is already `True` in the `grind` core.
@@ -190,6 +221,7 @@ Traverses the constraints `c` (representing an expression `e`) s.t.
190221
191222/-- Equality propagation. -/
192223def checkEq (u v : NodeId) (k : Weight) : OrderM Unit := do
224+ if u == v then return ()
193225 if (← isPartialOrder) then
194226 if !k.isZero then return ()
195227 let some k' ← getDist? v u | return ()
@@ -199,6 +231,24 @@ def checkEq (u v : NodeId) (k : Weight) : OrderM Unit := do
199231 if (← alreadyInternalized ue <&&> alreadyInternalized ve) then
200232 if (← isEqv ue ve) then return ()
201233 pushToPropagate <| .eq u v
234+ else
235+ /-
236+ Check whether `ue` and `ve` are auxiliary terms used to encode `Nat` terms.
237+ **Note** : `getOriginal?` is currently hard coded to the `Nat` case since
238+ it is the only type we map to rings. If in the future, we want to support
239+ arbitrary `Semiring`s, we must adjust this code.
240+ -/
241+ let some ue ← getOriginal? ue | return ()
242+ let some ve ← getOriginal? ve | return ()
243+ if (← alreadyInternalized ue <&&> alreadyInternalized ve) then
244+ if (← isEqv ue ve) then return ()
245+ pushToPropagate <| .eq u v
246+ where
247+ getOriginal? (e : Expr) : GoalM (Option Expr) := do
248+ let_expr NatCast.natCast _ _ a := e
249+ | let some (a, _) := (← get').termMapInv.find? { expr := e } | return none
250+ return some a
251+ return some a
202252
203253/-- Finds constrains and equalities to be propagated. -/
204254def checkToPropagate (u v : NodeId) (k : Weight) : OrderM Unit := do
0 commit comments