Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/Init/Grind/Order.lean
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ theorem nat_eq (a b : Nat) (x y : Int) : NatCast.natCast a = x → NatCast.natCa
intro _ _; subst x y; intro h
exact Int.natCast_inj.mp h

theorem of_nat_eq (a b : Nat) (x y : Int) : NatCast.natCast a = x → NatCast.natCast b = y → a = b → x = y := by
intro _ _; subst x y; intro; simp [*]

theorem le_of_not_le {α} [LE α] [Std.IsLinearPreorder α]
{a b : α} : ¬ a ≤ b → b ≤ a := by
intro h
Expand Down
5 changes: 4 additions & 1 deletion src/Lean/Meta/Tactic/Grind/Arith/CommRing/DenoteExpr.lean
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ variable [Monad M] [MonadError M] [MonadLiftT MetaM M] [MonadCanon M] [MonadRing
def denoteNum (k : Int) : M Expr := do
let ring ← getRing
let n := mkRawNatLit k.natAbs
let ofNatInst := mkApp3 (mkConst ``Grind.Semiring.ofNat [ring.u]) ring.type ring.semiringInst n
let ofNatInst ← if let some inst ← synthInstance? (mkApp2 (mkConst ``OfNat [ring.u]) ring.type n) then
pure inst
else
pure <| mkApp3 (mkConst ``Grind.Semiring.ofNat [ring.u]) ring.type ring.semiringInst n
let n := mkApp3 (mkConst ``OfNat.ofNat [ring.u]) ring.type n ofNatInst
if k < 0 then
return mkApp (← getNegFn) n
Expand Down
41 changes: 24 additions & 17 deletions src/Lean/Meta/Tactic/Grind/Order/Assert.lean
Original file line number Diff line number Diff line change
Expand Up @@ -148,23 +148,18 @@ def propagatePending : OrderM Unit := do
- `h₁ : ↑ue' = ue`
- `h₂ : ↑ve' = ve`
- `h : ue = ve`
**Note**: We currently only support `Nat`. Thus `↑a` is actually
`NatCast.natCast a`. If we decide to support arbitrary semirings
in this module, we must adjust this code.
-/
pushEq ue' ve' <| mkApp7 (mkConst ``Grind.Order.nat_eq) ue' ve' ue ve h₁ h₂ h
where
/--
If `e` is an auxiliary term used to represent some term `a`, returns
`some (a, h)` s.t. `h : ↑a = e`
**Note**: We currently only support `Nat`. Thus `↑a` is actually
`NatCast.natCast a`. If we decide to support arbitrary semirings
in this module, we must adjust this code.
-/
getOriginal? (e : Expr) : GoalM (Option (Expr × Expr)) := do
if let some r := (← get').termMapInv.find? { expr := e } then
return some r
else
let_expr NatCast.natCast _ _ a := e | return none
let h ← mkEqRefl e
return some (a, h)
return (← get').termMapInv.find? { expr := e }

/--
Returns `true` if `e` is already `True` in the `grind` core.
Expand Down Expand Up @@ -233,10 +228,7 @@ def checkEq (u v : NodeId) (k : Weight) : OrderM Unit := do
pushToPropagate <| .eq u v
else
/-
Check whether `ue` and `ve` are auxiliary terms used to encode `Nat` terms.
**Note**: `getOriginal?` is currently hard coded to the `Nat` case since
it is the only type we map to rings. If in the future, we want to support
arbitrary `Semiring`s, we must adjust this code.
Check whether `ue` and `ve` are auxiliary terms.
-/
let some ue ← getOriginal? ue | return ()
let some ve ← getOriginal? ve | return ()
Expand All @@ -245,9 +237,7 @@ def checkEq (u v : NodeId) (k : Weight) : OrderM Unit := do
pushToPropagate <| .eq u v
where
getOriginal? (e : Expr) : GoalM (Option Expr) := do
let_expr NatCast.natCast _ _ a := e
| let some (a, _) := (← get').termMapInv.find? { expr := e } | return none
return some a
let some (a, _) := (← get').termMapInv.find? { expr := e } | return none
return some a

/-- Finds constrains and equalities to be propagated. -/
Expand Down Expand Up @@ -378,14 +368,31 @@ builtin_grind_propagator propagateLT ↓LT.lt := propagateIneq

public def processNewEq (a b : Expr) : GoalM Unit := do
unless isSameExpr a b do
let h ← mkEqProof a b
if let some (a', h₁) ← getAuxTerm? a then
let some (b', h₂) ← getAuxTerm? b | return ()
/-
We have
- `h : a = b`
- `h₁ : ↑a = a'`
- `h₂ : ↑b = b'`
-/
let h := mkApp7 (mkConst ``Grind.Order.of_nat_eq) a b a' b' h₁ h₂ h
go a' b' h
else
go a b h
where
getAuxTerm? (e : Expr) : GoalM (Option (Expr × Expr)) := do
return (← get').termMap.find? { expr := e }

go (a b h : Expr) : GoalM Unit := do
let some id₁ ← getStructIdOf? a | return ()
let some id₂ ← getStructIdOf? b | return ()
unless id₁ == id₂ do return ()
OrderM.run id₁ do
trace[grind.order.assert] "{a} = {b}"
let u ← getNodeId a
let v ← getNodeId b
let h ← mkEqProof a b
if (← isRing) then
let h₁ := mkApp3 (← mkOrdRingPrefix ``Grind.Order.le_of_eq_1_k) a b h
let h₂ := mkApp3 (← mkOrdRingPrefix ``Grind.Order.le_of_eq_2_k) a b h
Expand Down
54 changes: 35 additions & 19 deletions src/Lean/Meta/Tactic/Grind/Order/Internalize.lean
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,20 @@ def isForbiddenParent (parent? : Option Expr) : Bool :=
-/
true
else if let some parent := parent? then
(getType? parent |>.isSome)
||
/-
**Note**: We currently ignore `•`. We may reconsider it in the future.
-/
match_expr parent with
| HSMul.hSMul _ _ _ _ _ _ => true
| Nat.cast _ _ _ => true
| NatCast.natCast _ _ _ => true
| Grind.IntModule.OfNatModule.toQ _ _ _ => true
| _ => false
if parent.isEq then
false
else
(getType? parent |>.isSome)
||
/-
**Note**: We currently ignore `•`. We may reconsider it in the future.
-/
match_expr parent with
| HSMul.hSMul _ _ _ _ _ _ => true
| Nat.cast _ _ _ => true
| NatCast.natCast _ _ _ => true
| Grind.IntModule.OfNatModule.toQ _ _ _ => true
| _ => false
else
false

Expand Down Expand Up @@ -173,6 +176,12 @@ def setStructId (e : Expr) : OrderM Unit := do
exprToStructId := s.exprToStructId.insert { expr := e } structId
}

def updateTermMap (e eNew h : Expr) : GoalM Unit := do
modify' fun s => { s with
termMap := s.termMap.insert { expr := e } (eNew, h)
termMapInv := s.termMapInv.insert { expr := eNew } (e, h)
}

def mkNode (e : Expr) : OrderM NodeId := do
if let some nodeId := (← getStruct).nodeMap.find? { expr := e } then
return nodeId
Expand All @@ -192,7 +201,19 @@ def mkNode (e : Expr) : OrderM NodeId := do
-/
if (← alreadyInternalized e) then
orderExt.markTerm e
if let some e' ← getOriginal? e then
orderExt.markTerm e'
return nodeId
where
getOriginal? (e : Expr) : GoalM (Option Expr) := do
if let some (e', _) := (← get').termMapInv.find? { expr := e } then
return some e'
let_expr NatCast.natCast _ _ a := e | return none
if (← alreadyInternalized a) then
updateTermMap a e (← mkEqRefl e)
return some a
else
return none

def internalizeCnstr (e : Expr) (kind : CnstrKind) (lhs rhs : Expr) : OrderM Unit := do
let some c ← mkCnstr? e kind lhs rhs | return ()
Expand Down Expand Up @@ -267,6 +288,7 @@ def toOffsetTerm? (e : Expr) : OrderM (Option OffsetTermResult) := do

def internalizeTerm (e : Expr) : OrderM Unit := do
let some r ← toOffsetTerm? e | return ()
if e == r.a && r.k == 0 then return ()
let x ← mkNode e
let y ← mkNode r.a
let h₁ ← mkOrdRingPrefix ``Grind.Order.le_of_offset_eq_1_k
Expand All @@ -276,12 +298,6 @@ def internalizeTerm (e : Expr) : OrderM Unit := do
let h₂ := mkApp4 h₂ e r.a (toExpr r.k) r.h
addEdge y x { k := -r.k } h₂

def updateTermMap (e eNew h : Expr) : GoalM Unit := do
modify' fun s => { s with
termMap := s.termMap.insert { expr := e } (eNew, h)
termMapInv := s.termMapInv.insert { expr := eNew } (e, h)
}

open Arith.Cutsat in
def adaptNat (e : Expr) : GoalM Expr := do
if let some (eNew, _) := (← get').termMap.find? { expr := e } then
Expand Down Expand Up @@ -317,7 +333,7 @@ def adapt (α : Expr) (e : Expr) : GoalM (Expr × Expr) := do
else
return (α, e)

def alreadyInternalized (e : Expr) : OrderM Bool := do
def alreadyInternalizedHere (e : Expr) : OrderM Bool := do
let s ← getStruct
return s.cnstrs.contains { expr := e } || s.nodeMap.contains { expr := e }

Expand All @@ -327,7 +343,7 @@ public def internalize (e : Expr) (parent? : Option Expr) : GoalM Unit := do
let (α, e) ← adapt α e
if isForbiddenParent parent? then return ()
if let some structId ← getStructId? α then OrderM.run structId do
if (← alreadyInternalized e) then return ()
if (← alreadyInternalizedHere e) then return ()
match_expr e with
| LE.le _ _ lhs rhs => internalizeCnstr e .le lhs rhs
| LT.lt _ _ lhs rhs => if (← hasLt) then internalizeCnstr e .lt lhs rhs
Expand Down
6 changes: 6 additions & 0 deletions tests/lean/run/grind_dep_match_overlap.lean
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,9 @@ example (a b : Vec α 2) : h a b = 20 := by

example (a b : Vec α 2) : h a b = 20 := by
grind (splits := 4) [h.eq_def, Vec]

example (a b : Vec α 2) : h a b = 20 := by
grind -offset [h.eq_def, Vec]

example (a b : Vec α 2) : h a b = 20 := by
grind -offset (splits := 4) [h.eq_def, Vec]
3 changes: 3 additions & 0 deletions tests/lean/run/grind_fun_singleton.lean
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,6 @@ example [Inhabited α] : ((fun (_ : α) => x = a + 1) = fun (_ : α) => True)

example : c = 5 → ((fun (_ : Nat × Nat) => { down := a + c = b + 5 : ULift Prop }) = fun (_ : Nat × Nat) => { down := c < 10 : ULift Prop }) → a = b := by
grind

example : c = 5 → ((fun (_ : Nat × Nat) => { down := a + c = b + 5 : ULift Prop }) = fun (_ : Nat × Nat) => { down := c < 10 : ULift Prop }) → a = b := by
grind -offset
3 changes: 3 additions & 0 deletions tests/lean/run/grind_match_eq_propagation.lean
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ def h (v w : Vec α n) : Nat :=
example : h a b > 0 := by
grind [h.eq_def]

example : h a b > 0 := by
grind -offset [h.eq_def]

-- TODO: introduce casts while instantiating equation theorems for `h.match_1`
-- example (a b : Vec α 2) : h a b = 20 := by
-- unfold h
Expand Down
12 changes: 12 additions & 0 deletions tests/lean/run/grind_order_eq.lean
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,15 @@ example (a b : Nat) (f : Nat → Int) : a ≤ b + 1 → b + 1 ≤ c → c ≤ a

example (a b : Nat) (f : Nat → Int) : a ≤ b + 1 → b + 1 ≤ a → f (1 + a) = f (1 + b + 1) := by
grind -offset -mbtc -lia -linarith (splits := 0)

example
: 2*n_1 + a = 1 → 2*n_1 + a = n_2 + 1 → n = 1 → n = n_3 + 1 → n_2 ≠ n_3 → False := by
grind -lia -linarith -offset -ring (splits := 0)

example
: a = b → a ≤ b + 1 := by
grind -lia -linarith -offset -ring (splits := 0) only

example
: a = b + 1 → a ≤ b + 2 := by
grind -lia -linarith -offset -ring (splits := 0) only
1 change: 1 addition & 0 deletions tests/lean/run/grind_sort_eqc.lean
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ h_2 : ¬f (f x) = g x x
[eqc] {z, g y z}
[eqc] {0, f x + -1 * f (f x) + -1, f (f x) + -1 * f (f (f x)) + -1, f (f (f x)) + -1 * f (f (f (f x))) + -1}
[eqc] {g z y, g y x}
[eqc] {f x + -1 * f (f x), f (f x) + -1 * f (f (f x)), f (f (f x)) + -1 * f (f (f (f x)))}
[ematch] E-matching patterns
[thm] feq: [@f #4 #3 #0]
[thm] geq: [@g #2 #1 #0 #0]
Expand Down
Loading