77prelude
88public import Lean.Meta.Tactic.Grind.Order.OrderM
99import Init.Data.Int.OfNat
10+ import Init.Grind.Module.Envelope
11+ import Init.Grind.Order
1012import Lean.Meta.Tactic.Grind.Arith.CommRing.SafePoly
1113import Lean.Meta.Tactic.Grind.Arith.CommRing.Reify
1214import Lean.Meta.Tactic.Grind.Arith.CommRing.DenoteExpr
@@ -15,6 +17,7 @@ import Lean.Meta.Tactic.Grind.Arith.Cutsat.Nat
1517import Lean.Meta.Tactic.Grind.Order.StructId
1618import Lean.Meta.Tactic.Grind.Order.Util
1719import Lean.Meta.Tactic.Grind.Order.Assert
20+ import Lean.Meta.Tactic.Grind.Order.Proof
1821namespace Lean.Meta.Grind.Order
1922
2023open Arith CommRing
@@ -40,8 +43,24 @@ def getType? (e : Expr) : Option Expr :=
4043 | _ => none
4144
4245def isForbiddenParent (parent? : Option Expr) : Bool :=
43- if let some parent := parent? then
44- getType? parent |>.isSome
46+ if isIntModuleVirtualParent parent? then
47+ /-
48+ **Note** : `linarith` uses a virtual parent to mark auxiliary declarations used to encode
49+ terms into an `IntModule`.
50+ -/
51+ true
52+ else if let some parent := parent? then
53+ (getType? parent |>.isSome)
54+ ||
55+ /-
56+ **Note** : We currently ignore `•`. We may reconsider it in the future.
57+ -/
58+ match_expr parent with
59+ | HSMul.hSMul _ _ _ _ _ _ => true
60+ | Nat.cast _ _ _ => true
61+ | NatCast.natCast _ _ _ => true
62+ | Grind.IntModule.OfNatModule.toQ _ _ _ => true
63+ | _ => false
4564 else
4665 false
4766
@@ -206,27 +225,89 @@ def internalizeCnstr (e : Expr) (kind : CnstrKind) (lhs rhs : Expr) : OrderM Uni
206225 s.cnstrsOf.insert (u, v) cs
207226 }
208227
228+ /--
229+ Normalization result. A nested term `e` is normalized as `a + k` and
230+ `h` is a proof for `e = a + k`
231+ -/
232+ structure OffsetTermResult where
233+ a : Expr
234+ k : Int
235+ h : Expr
236+
237+ def toOffsetTermCommRing? (e : Expr) : RingM (Option OffsetTermResult) := do
238+ let some e ← reify? e (skipVar := false ) | return none
239+ let p ← e.toPolyM
240+ let k := p.getConst
241+ let p := p.addConst (-k)
242+ let a ← shareCommon (← p.denoteExpr)
243+ let h ← mkTermEqProof e (.add (p.toExpr) (.intCast k))
244+ return some { a, k, h }
245+
246+ def toOffsetTermNonCommRing? (e : Expr) : NonCommRingM (Option OffsetTermResult) := do
247+ let some e ← ncreify? e (skipVar := false ) | return none
248+ let p := e.toPoly_nc
249+ let k := p.getConst
250+ let p := p.addConst (-k)
251+ let a ← shareCommon (← p.denoteExpr)
252+ let h ← mkNonCommTermEqProof e (.add (p.toExpr) (.intCast k))
253+ return some { a, k, h }
254+
255+ def toOffsetTerm? (e : Expr) : OrderM (Option OffsetTermResult) := do
256+ let s ← getStruct
257+ /-
258+ **Note** : If it is not a partial order, then it is not worth internalizing term
259+ since we will not be able to propagate implied equalities back to core.
260+ -/
261+ unless s.isPartialInst?.isSome do return none
262+ let some ringId := s.ringId? | return none
263+ if s.isCommRing then
264+ RingM.run ringId <| toOffsetTermCommRing? e
265+ else
266+ NonCommRingM.run ringId <| toOffsetTermNonCommRing? e
267+
268+ def internalizeTerm (e : Expr) : OrderM Unit := do
269+ let some r ← toOffsetTerm? e | return ()
270+ let x ← mkNode e
271+ let y ← mkNode r.a
272+ let h₁ ← mkOrdRingPrefix ``Grind.Order.le_of_offset_eq_1_k
273+ let h₁ := mkApp4 h₁ e r.a (toExpr r.k) r.h
274+ addEdge x y { k := r.k } h₁
275+ let h₂ ← mkOrdRingPrefix ``Grind.Order.le_of_offset_eq_2_k
276+ let h₂ := mkApp4 h₂ e r.a (toExpr r.k) r.h
277+ addEdge y x { k := -r.k } h₂
278+
279+ def updateTermMap (e eNew h : Expr) : GoalM Unit := do
280+ modify' fun s => { s with
281+ termMap := s.termMap.insert { expr := e } (eNew, h)
282+ termMapInv := s.termMapInv.insert { expr := eNew } (e, h)
283+ }
284+
209285open Arith.Cutsat in
210286def adaptNat (e : Expr) : GoalM Expr := do
211- let (eNew, h) ← match_expr e with
212- | LE.le _ _ lhs rhs =>
213- let (lhs', h₁) ← natToInt lhs
214- let (rhs', h₂) ← natToInt rhs
215- let eNew := mkIntLE lhs' rhs'
216- let h := mkApp6 (mkConst ``Nat.ToInt.le_eq) lhs rhs lhs' rhs' h₁ h₂
217- pure (eNew, h)
218- | LT.lt _ _ lhs rhs =>
219- let (lhs', h₁) ← natToInt lhs
220- let (rhs', h₂) ← natToInt rhs
221- let eNew := mkIntLT lhs' rhs'
222- let h := mkApp6 (mkConst ``Nat.ToInt.lt_eq) lhs rhs lhs' rhs' h₁ h₂
223- pure (eNew, h)
287+ if let some (eNew, _) := (← get').termMap.find? { expr := e } then
288+ return eNew
289+ else match_expr e with
290+ | LE.le _ _ lhs rhs => adaptCnstr lhs rhs (isLT := false )
291+ | LT.lt _ _ lhs rhs => adaptCnstr lhs rhs (isLT := true )
292+ | HAdd.hAdd _ _ _ _ _ _ => adaptTerm
293+ | HSub.hSub _ _ _ _ _ _ => adaptTerm
294+ | OfNat.ofNat _ _ _ => adaptTerm
224295 | _ => return e
225- modify' fun s => { s with
226- cnstrsMap := s.cnstrsMap.insert { expr := e } (eNew, h)
227- cnstrsMapInv := s.cnstrsMapInv.insert { expr := eNew } (e, h)
228- }
229- return eNew
296+ where
297+ adaptCnstr (lhs rhs : Expr) (isLT : Bool) : GoalM Expr := do
298+ let (lhs', h₁) ← natToInt lhs
299+ let (rhs', h₂) ← natToInt rhs
300+ let eNew := if isLT then mkIntLT lhs' rhs' else mkIntLE lhs' rhs'
301+ let h := mkApp6
302+ (mkConst (if isLT then ``Nat.ToInt.lt_eq else ``Nat.ToInt.le_eq))
303+ lhs rhs lhs' rhs' h₁ h₂
304+ updateTermMap e eNew h
305+ return eNew
306+
307+ adaptTerm : GoalM Expr := do
308+ let (eNew, h) ← natToInt e
309+ updateTermMap e eNew h
310+ return eNew
230311
231312def adapt (α : Expr) (e : Expr) : GoalM (Expr × Expr) := do
232313 -- **Note** : We currently only adapt `Nat` expressions
@@ -250,8 +331,9 @@ public def internalize (e : Expr) (parent? : Option Expr) : GoalM Unit := do
250331 match_expr e with
251332 | LE.le _ _ lhs rhs => internalizeCnstr e .le lhs rhs
252333 | LT.lt _ _ lhs rhs => if (← hasLt) then internalizeCnstr e .lt lhs rhs
253- | _ =>
254- -- **Note** : We currently do not internalize offset terms nested in other terms.
255- return ()
334+ | HAdd.hAdd _ _ _ _ _ _ => internalizeTerm e
335+ | HSub.hSub _ _ _ _ _ _ => internalizeTerm e
336+ | OfNat.ofNat _ _ _ => internalizeTerm e
337+ | _ => return ()
256338
257339end Lean.Meta.Grind.Order
0 commit comments