Skip to content

Commit 6bfc268

Browse files
leodemouraalgebraic-dev
authored andcommitted
feat: heterogeneous (k : Nat) * (a : R) support in grind linarith (leanprover#8773)
This PR implements support for the heterogeneous `(k : Nat) * (a : R)` in ordered modules. Example: ```lean variable (R : Type u) [IntModule R] [LinearOrder R] [IntModule.IsOrdered R] example (x y z : R) (hx : x ≤ 3 * y) (h2 : y ≤ 2 * z) (h3 : x ≥ 6 * z) : x = 3 * y := by grind example (x y z : Int) (h1 : 2 * x < 3 * y) (h2 : -4 * x + 2 * z < 0) (h3 : x * y < 5) : ¬ 12*y - 4* z < 0 := by grind ```
1 parent 61868c1 commit 6bfc268

File tree

5 files changed

+157
-41
lines changed

5 files changed

+157
-41
lines changed

src/Lean/Meta/Tactic/Grind/Arith/Linear/Reify.lean

Lines changed: 48 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ def isZeroInst (struct : Struct) (inst : Expr) : Bool :=
1515
isSameExpr struct.zero.appArg! inst
1616
def isHMulInst (struct : Struct) (inst : Expr) : Bool :=
1717
isSameExpr struct.hmulFn.appArg! inst
18+
def isHMulNatInst (struct : Struct) (inst : Expr) : Bool :=
19+
isSameExpr struct.hmulNatFn.appArg! inst
1820
def isSMulInst (struct : Struct) (inst : Expr) : Bool :=
1921
if let some smulFn := struct.smulFn? then
2022
isSameExpr smulFn.appArg! inst
@@ -35,52 +37,14 @@ If `skipVar` is `true`, then the result is `none` if `e` is not an interpreted `
3537
We use `skipVar := false` when processing inequalities, and `skipVar := true` for equalities and disequalities
3638
-/
3739
partial def reify? (e : Expr) (skipVar : Bool) : LinearM (Option LinExpr) := do
38-
let toVar (e : Expr) : LinearM LinExpr := do
39-
return .var (← mkVar e)
40-
let asVar (e : Expr) : LinearM LinExpr := do
41-
reportInstIssue e
42-
return .var (← mkVar e)
43-
let isOfNatZero (e : Expr) : LinearM Bool := do
44-
withDefault <| isDefEq e (← getStruct).ofNatZero
45-
let rec go (e : Expr) : LinearM LinExpr := do
46-
match_expr e with
47-
| HAdd.hAdd _ _ _ i a b =>
48-
if isAddInst (← getStruct) i then return .add (← go a) (← go b) else asVar e
49-
| HSub.hSub _ _ _ i a b =>
50-
if isSubInst (← getStruct) i then return .sub (← go a) (← go b) else asVar e
51-
| HMul.hMul _ _ _ i a b =>
52-
if isHMulInst (← getStruct) i then
53-
let some k ← getIntValue? a | pure ()
54-
return .mul k (← go b)
55-
asVar e
56-
| HSMul.hSMul _ _ _ i a b =>
57-
if isSMulInst (← getStruct) i then
58-
let some k ← getIntValue? a | pure ()
59-
return .mul k (← go b)
60-
asVar e
61-
| Neg.neg _ i a =>
62-
if isNegInst (← getStruct) i then return .neg (← go a) else asVar e
63-
| Zero.zero _ i =>
64-
if isZeroInst (← getStruct) i then return .zero else asVar e
65-
| OfNat.ofNat _ _ _ =>
66-
if (← isOfNatZero e) then return .zero else toVar e
67-
| _ => toVar e
68-
let asTopVar (e : Expr) : LinearM (Option LinExpr) := do
69-
reportInstIssue e
70-
if skipVar then
71-
return none
72-
else
73-
return some (← asVar e)
7440
match_expr e with
7541
| HAdd.hAdd _ _ _ i a b =>
7642
if isAddInst (← getStruct ) i then return some (.add (← go a) (← go b)) else asTopVar e
7743
| HSub.hSub _ _ _ i a b =>
7844
if isSubInst (← getStruct ) i then return some (.sub (← go a) (← go b)) else asTopVar e
7945
| HMul.hMul _ _ _ i a b =>
80-
if isHMulInst (← getStruct) i then
81-
let some k ← getIntValue? a | pure ()
82-
return some (.mul k (← go b))
83-
asTopVar e
46+
let some r ← processHMul i a b | asTopVar e
47+
return some r
8448
| HSMul.hSMul _ _ _ i a b =>
8549
if isSMulInst (← getStruct) i then
8650
let some k ← getIntValue? a | pure ()
@@ -97,5 +61,49 @@ partial def reify? (e : Expr) (skipVar : Bool) : LinearM (Option LinExpr) := do
9761
return none
9862
else
9963
return some (← toVar e)
64+
where
65+
toVar (e : Expr) : LinearM LinExpr := do
66+
return .var (← mkVar e)
67+
asVar (e : Expr) : LinearM LinExpr := do
68+
reportInstIssue e
69+
return .var (← mkVar e)
70+
asTopVar (e : Expr) : LinearM (Option LinExpr) := do
71+
reportInstIssue e
72+
if skipVar then
73+
return none
74+
else
75+
return some (← asVar e)
76+
isOfNatZero (e : Expr) : LinearM Bool := do
77+
withDefault <| isDefEq e (← getStruct).ofNatZero
78+
processHMul (i a b : Expr) : LinearM (Option LinExpr) := do
79+
if isHMulInst (← getStruct) i then
80+
let some k ← getIntValue? a | return none
81+
return some (.mul k (← go b))
82+
else if isHMulNatInst (← getStruct) i then
83+
let some k ← getNatValue? a | return none
84+
return some (.mul k (← go b))
85+
return none
86+
go (e : Expr) : LinearM LinExpr := do
87+
match_expr e with
88+
| HAdd.hAdd _ _ _ i a b =>
89+
if isAddInst (← getStruct) i then return .add (← go a) (← go b) else asVar e
90+
| HSub.hSub _ _ _ i a b =>
91+
if isSubInst (← getStruct) i then return .sub (← go a) (← go b) else asVar e
92+
| HMul.hMul _ _ _ i a b =>
93+
let some r ← processHMul i a b | asVar e
94+
return r
95+
| HSMul.hSMul _ _ _ i a b =>
96+
if isSMulInst (← getStruct) i then
97+
let some k ← getIntValue? a | pure ()
98+
return .mul k (← go b)
99+
asVar e
100+
| Neg.neg _ i a =>
101+
if isNegInst (← getStruct) i then return .neg (← go a) else asVar e
102+
| Zero.zero _ i =>
103+
if isZeroInst (← getStruct) i then return .zero else asVar e
104+
| OfNat.ofNat _ _ _ =>
105+
if (← isOfNatZero e) then return .zero else toVar e
106+
| _ => toVar e
107+
100108

101109
end Lean.Meta.Grind.Arith.Linear

src/Lean/Meta/Tactic/Grind/Arith/Linear/StructId.lean

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,11 @@ where
6969
let .some inst ← trySynthInstance instType
7070
| throwError "`grind linarith` failed to find instance{indentExpr instType}"
7171
return inst
72+
let getHMulNatInst : GoalM Expr := do
73+
let instType := mkApp3 (mkConst ``HMul [0, u, u]) Nat.mkType type type
74+
let .some inst ← trySynthInstance instType
75+
| throwError "`grind linarith` failed to find instance{indentExpr instType}"
76+
return inst
7277
let checkToFieldDefEq? (parentInst? : Option Expr) (inst? : Option Expr) (toFieldName : Name) : GoalM (Option Expr) := do
7378
let some parentInst := parentInst? | return none
7479
let some inst := inst? | return none
@@ -100,6 +105,8 @@ where
100105
let negFn ← internalizeFn <| mkApp2 (mkConst ``Neg.neg [u]) type negInst
101106
let hmulInst ← getHMulInst
102107
let hmulFn ← internalizeFn <| mkApp4 (mkConst ``HMul.hMul [0, u, u]) Int.mkType type type hmulInst
108+
let hmulNatInst ← getHMulNatInst
109+
let hmulNatFn ← internalizeFn <| mkApp4 (mkConst ``HMul.hMul [0, u, u]) Nat.mkType type type hmulNatInst
103110
ensureToFieldDefEq zeroInst intModuleInst ``Grind.IntModule.toZero
104111
ensureToHomoFieldDefEq addInst intModuleInst ``Grind.IntModule.toAdd ``instHAdd
105112
ensureToHomoFieldDefEq subInst intModuleInst ``Grind.IntModule.toSub ``instHSub
@@ -152,7 +159,7 @@ where
152159
let id := (← get').structs.size
153160
let struct : Struct := {
154161
id, type, u, intModuleInst, preorderInst, isOrdInst, partialInst?, linearInst?, noNatDivInst?
155-
leFn, ltFn, addFn, subFn, negFn, hmulFn, smulFn?, zero, one?
162+
leFn, ltFn, addFn, subFn, negFn, hmulFn, hmulNatFn, smulFn?, zero, one?
156163
ringInst?, commRingInst?, ringIsOrdInst?, ringId?, ofNatZero
157164
}
158165
modify' fun s => { s with structs := s.structs.push struct }

src/Lean/Meta/Tactic/Grind/Arith/Linear/Types.lean

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ structure Struct where
9494
ltFn : Expr
9595
addFn : Expr
9696
hmulFn : Expr
97+
hmulNatFn : Expr
9798
smulFn? : Option Expr
9899
subFn : Expr
99100
negFn : Expr
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
open Lean.Grind
2+
3+
variable (R : Type u) [IntModule R] [LinearOrder R] [IntModule.IsOrdered R]
4+
5+
example (a b c : R) (h : a < b) : a + c < b + c := by grind
6+
example (a b c : R) (h : a < b) : c + a < c + b := by grind
7+
example (a b : R) (h : a < b) : -b < -a := by grind
8+
9+
/--
10+
trace: [grind.linarith.model] a := 0
11+
[grind.linarith.model] b := 1
12+
-/
13+
#guard_msgs (drop error, trace) in
14+
set_option trace.grind.linarith.model true in
15+
example (a b : R) (h : a < b) : -a < -b := by grind
16+
17+
example (a b c : R) (h : a ≤ b) : a + c ≤ b + c := by grind
18+
example (a b c : R) (h : a ≤ b) : c + a ≤ c + b := by grind
19+
example (a b : R) (h : a ≤ b) : -b ≤ -a := by grind
20+
21+
/--
22+
trace: [grind.linarith.model] a := 0
23+
[grind.linarith.model] b := 1
24+
-/
25+
#guard_msgs (drop error, trace) in
26+
set_option trace.grind.linarith.model true in
27+
example (a b : R) (h : a ≤ b) : -a ≤ -b := by grind
28+
29+
example (a : R) (h : 0 < a) : 0 ≤ a := by grind
30+
example (a : R) (h : 0 < a) : -2 * a < 0 := by grind
31+
32+
example (a b c : R) (_ : a ≤ b) (_ : b ≤ c) : a ≤ c := by grind
33+
example (a b c : R) (_ : a ≤ b) (_ : b < c) : a < c := by grind
34+
example (a b c : R) (_ : a < b) (_ : b ≤ c) : a < c := by grind
35+
example (a b c : R) (_ : a < b) (_ : b < c) : a < c := by grind
36+
37+
example (a : R) (h : 2 * a < 0) : a < 0 := by grind
38+
example (a : R) (h : 2 * a < 0) : 0 ≤ -a := by grind
39+
40+
example (a b : R) (_ : a < b) (_ : b < a) : False := by grind
41+
example (a b : R) (_ : a < b ∧ b < a) : False := by grind
42+
example (a b : R) (_ : a < b) : a ≠ b := by grind
43+
44+
example (a b c e v0 v1 : R) (h1 : v0 = 5 * a) (h2 : v1 = 3 * b) (h3 : v0 + v1 + c = 10 * e) :
45+
v0 + 5 * e + (v1 - 3 * e) + (c - 2 * e) = 10 * e := by
46+
grind
47+
48+
example (x y z : Int) (h1 : 2 * x < 3 * y) (h2 : -4 * x + 2 * z < 0) (h3 : 12 * y - 4 * z < 0) : False := by
49+
grind
50+
example (x y z : R) (h1 : 2 * x < 3 * y) (h2 : -4 * x + 2 * z < 0) (h3 : 12 * y - 4 * z < 0) : False := by
51+
grind
52+
53+
example (x y z : Int) (h1 : 2 * x < 3 * y) (h2 : -4 * x + 2 * z < 0) (h3 : x * y < 5) (h3 : 12 * y - 4 * z < 0) :
54+
False := by grind
55+
example (x y z : R) (h1 : 2 * x < 3 * y) (h2 : -4 * x + 2 * z < 0) (h3 : 12 * y - 4 * z < 0) :
56+
False := by grind
57+
58+
example (x y z : Int) (hx : x ≤ 3 * y) (h2 : y ≤ 2 * z) (h3 : x ≥ 6 * z) : x = 3*y := by
59+
grind
60+
example (x y z : R) (hx : x ≤ 3 * y) (h2 : y ≤ 2 * z) (h3 : x ≥ 6 * z) : x = 3 * y := by
61+
grind
62+
63+
example (x y z : Int) (h1 : 2 * x < 3 * y) (h2 : -4 * x + 2 * z < 0) (h3 : x * y < 5) : ¬ 12*y - 4* z < 0 := by
64+
grind
65+
example (x y z : R) (h1 : 2 * x < 3 * y) (h2 : -4 * x + 2 * z < 0) : ¬ 12 * y - 4 * z < 0 := by
66+
grind
67+
68+
example (x y z : Int) (hx : ¬ x > 3 * y) (h2 : ¬ y > 2 * z) (h3 : x ≥ 6 * z) : x = 3 * y := by
69+
grind
70+
example (x y z : R) (hx : ¬ x > 3 * y) (h2 : ¬ y > 2 * z) (h3 : x ≥ 6 * z) : x = 3 * y := by
71+
grind
72+
73+
example (x y z : Nat) (hx : x ≤ 3 * y) (h2 : y ≤ 2 * z) (h3 : x ≥ 6 * z) : x = 3 * y := by
74+
grind
75+
example (x y z : R) (hx : x ≤ 3 * y) (h2 : y ≤ 2 * z) (h3 : x ≥ 6 * z) : x = 3 * y := by
76+
grind
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
open Lean.Grind
2+
3+
-- `grind linarith` currently does not support negation of linear constraints.
4+
variable (R : Type u) [IntModule R] [Preorder R] [IntModule.IsOrdered R]
5+
6+
example (a b : R) (_ : a < b) (_ : b < a) : False := by grind
7+
example (a b : R) (_ : a < b ∧ b < a) : False := by grind
8+
example (a b : R) (_ : a < b) : a ≠ b := by grind
9+
10+
example (x y z : Int) (h1 : 2 * x < 3 * y) (h2 : -4 * x + 2 * z < 0) (h3 : 12 * y - 4 * z < 0) : False := by
11+
grind
12+
example (x y z : R) (h1 : 2 * x < 3 * y) (h2 : -4 * x + 2 * z < 0) (h3 : 12 * y - 4 * z < 0) : False := by
13+
grind
14+
15+
example (x y z : Int) (h1 : 2 * x < 3 * y) (h2 : -4 * x + 2 * z < 0) (h3 : x * y < 5) (h3 : 12 * y - 4 * z < 0) :
16+
False := by grind
17+
example (x y z : R) (h1 : 2 * x < 3 * y) (h2 : -4 * x + 2 * z < 0) (h3 : 12 * y - 4 * z < 0) :
18+
False := by grind
19+
20+
-- It does cancel the double negation in the following two examples
21+
example (x y z : Int) (h1 : 2 * x < 3 * y) (h2 : -4 * x + 2 * z < 0) (h3 : x * y < 5) : ¬ 12*y - 4* z < 0 := by
22+
grind
23+
example (x y z : R) (h1 : 2 * x < 3 * y) (h2 : -4 * x + 2 * z < 0) : ¬ 12 * y - 4 * z < 0 := by
24+
grind

0 commit comments

Comments
 (0)