Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 48 additions & 40 deletions src/Lean/Meta/Tactic/Grind/Arith/Linear/Reify.lean
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ def isZeroInst (struct : Struct) (inst : Expr) : Bool :=
isSameExpr struct.zero.appArg! inst
def isHMulInst (struct : Struct) (inst : Expr) : Bool :=
isSameExpr struct.hmulFn.appArg! inst
def isHMulNatInst (struct : Struct) (inst : Expr) : Bool :=
isSameExpr struct.hmulNatFn.appArg! inst
def isSMulInst (struct : Struct) (inst : Expr) : Bool :=
if let some smulFn := struct.smulFn? then
isSameExpr smulFn.appArg! inst
Expand All @@ -35,52 +37,14 @@ If `skipVar` is `true`, then the result is `none` if `e` is not an interpreted `
We use `skipVar := false` when processing inequalities, and `skipVar := true` for equalities and disequalities
-/
partial def reify? (e : Expr) (skipVar : Bool) : LinearM (Option LinExpr) := do
let toVar (e : Expr) : LinearM LinExpr := do
return .var (← mkVar e)
let asVar (e : Expr) : LinearM LinExpr := do
reportInstIssue e
return .var (← mkVar e)
let isOfNatZero (e : Expr) : LinearM Bool := do
withDefault <| isDefEq e (← getStruct).ofNatZero
let rec go (e : Expr) : LinearM LinExpr := do
match_expr e with
| HAdd.hAdd _ _ _ i a b =>
if isAddInst (← getStruct) i then return .add (← go a) (← go b) else asVar e
| HSub.hSub _ _ _ i a b =>
if isSubInst (← getStruct) i then return .sub (← go a) (← go b) else asVar e
| HMul.hMul _ _ _ i a b =>
if isHMulInst (← getStruct) i then
let some k ← getIntValue? a | pure ()
return .mul k (← go b)
asVar e
| HSMul.hSMul _ _ _ i a b =>
if isSMulInst (← getStruct) i then
let some k ← getIntValue? a | pure ()
return .mul k (← go b)
asVar e
| Neg.neg _ i a =>
if isNegInst (← getStruct) i then return .neg (← go a) else asVar e
| Zero.zero _ i =>
if isZeroInst (← getStruct) i then return .zero else asVar e
| OfNat.ofNat _ _ _ =>
if (← isOfNatZero e) then return .zero else toVar e
| _ => toVar e
let asTopVar (e : Expr) : LinearM (Option LinExpr) := do
reportInstIssue e
if skipVar then
return none
else
return some (← asVar e)
match_expr e with
| HAdd.hAdd _ _ _ i a b =>
if isAddInst (← getStruct ) i then return some (.add (← go a) (← go b)) else asTopVar e
| HSub.hSub _ _ _ i a b =>
if isSubInst (← getStruct ) i then return some (.sub (← go a) (← go b)) else asTopVar e
| HMul.hMul _ _ _ i a b =>
if isHMulInst (← getStruct) i then
let some k ← getIntValue? a | pure ()
return some (.mul k (← go b))
asTopVar e
let some r ← processHMul i a b | asTopVar e
return some r
| HSMul.hSMul _ _ _ i a b =>
if isSMulInst (← getStruct) i then
let some k ← getIntValue? a | pure ()
Expand All @@ -97,5 +61,49 @@ partial def reify? (e : Expr) (skipVar : Bool) : LinearM (Option LinExpr) := do
return none
else
return some (← toVar e)
where
toVar (e : Expr) : LinearM LinExpr := do
return .var (← mkVar e)
asVar (e : Expr) : LinearM LinExpr := do
reportInstIssue e
return .var (← mkVar e)
asTopVar (e : Expr) : LinearM (Option LinExpr) := do
reportInstIssue e
if skipVar then
return none
else
return some (← asVar e)
isOfNatZero (e : Expr) : LinearM Bool := do
withDefault <| isDefEq e (← getStruct).ofNatZero
processHMul (i a b : Expr) : LinearM (Option LinExpr) := do
if isHMulInst (← getStruct) i then
let some k ← getIntValue? a | return none
return some (.mul k (← go b))
else if isHMulNatInst (← getStruct) i then
let some k ← getNatValue? a | return none
return some (.mul k (← go b))
return none
go (e : Expr) : LinearM LinExpr := do
match_expr e with
| HAdd.hAdd _ _ _ i a b =>
if isAddInst (← getStruct) i then return .add (← go a) (← go b) else asVar e
| HSub.hSub _ _ _ i a b =>
if isSubInst (← getStruct) i then return .sub (← go a) (← go b) else asVar e
| HMul.hMul _ _ _ i a b =>
let some r ← processHMul i a b | asVar e
return r
| HSMul.hSMul _ _ _ i a b =>
if isSMulInst (← getStruct) i then
let some k ← getIntValue? a | pure ()
return .mul k (← go b)
asVar e
| Neg.neg _ i a =>
if isNegInst (← getStruct) i then return .neg (← go a) else asVar e
| Zero.zero _ i =>
if isZeroInst (← getStruct) i then return .zero else asVar e
| OfNat.ofNat _ _ _ =>
if (← isOfNatZero e) then return .zero else toVar e
| _ => toVar e


end Lean.Meta.Grind.Arith.Linear
9 changes: 8 additions & 1 deletion src/Lean/Meta/Tactic/Grind/Arith/Linear/StructId.lean
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ where
let .some inst ← trySynthInstance instType
| throwError "`grind linarith` failed to find instance{indentExpr instType}"
return inst
let getHMulNatInst : GoalM Expr := do
let instType := mkApp3 (mkConst ``HMul [0, u, u]) Nat.mkType type type
let .some inst ← trySynthInstance instType
| throwError "`grind linarith` failed to find instance{indentExpr instType}"
return inst
let checkToFieldDefEq? (parentInst? : Option Expr) (inst? : Option Expr) (toFieldName : Name) : GoalM (Option Expr) := do
let some parentInst := parentInst? | return none
let some inst := inst? | return none
Expand Down Expand Up @@ -100,6 +105,8 @@ where
let negFn ← internalizeFn <| mkApp2 (mkConst ``Neg.neg [u]) type negInst
let hmulInst ← getHMulInst
let hmulFn ← internalizeFn <| mkApp4 (mkConst ``HMul.hMul [0, u, u]) Int.mkType type type hmulInst
let hmulNatInst ← getHMulNatInst
let hmulNatFn ← internalizeFn <| mkApp4 (mkConst ``HMul.hMul [0, u, u]) Nat.mkType type type hmulNatInst
ensureToFieldDefEq zeroInst intModuleInst ``Grind.IntModule.toZero
ensureToHomoFieldDefEq addInst intModuleInst ``Grind.IntModule.toAdd ``instHAdd
ensureToHomoFieldDefEq subInst intModuleInst ``Grind.IntModule.toSub ``instHSub
Expand Down Expand Up @@ -152,7 +159,7 @@ where
let id := (← get').structs.size
let struct : Struct := {
id, type, u, intModuleInst, preorderInst, isOrdInst, partialInst?, linearInst?, noNatDivInst?
leFn, ltFn, addFn, subFn, negFn, hmulFn, smulFn?, zero, one?
leFn, ltFn, addFn, subFn, negFn, hmulFn, hmulNatFn, smulFn?, zero, one?
ringInst?, commRingInst?, ringIsOrdInst?, ringId?, ofNatZero
}
modify' fun s => { s with structs := s.structs.push struct }
Expand Down
1 change: 1 addition & 0 deletions src/Lean/Meta/Tactic/Grind/Arith/Linear/Types.lean
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ structure Struct where
ltFn : Expr
addFn : Expr
hmulFn : Expr
hmulNatFn : Expr
smulFn? : Option Expr
subFn : Expr
negFn : Expr
Expand Down
76 changes: 76 additions & 0 deletions tests/lean/run/grind_ord_module.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
open Lean.Grind

variable (R : Type u) [IntModule R] [LinearOrder R] [IntModule.IsOrdered R]

example (a b c : R) (h : a < b) : a + c < b + c := by grind
example (a b c : R) (h : a < b) : c + a < c + b := by grind
example (a b : R) (h : a < b) : -b < -a := by grind

/--
trace: [grind.linarith.model] a := 0
[grind.linarith.model] b := 1
-/
#guard_msgs (drop error, trace) in
set_option trace.grind.linarith.model true in
example (a b : R) (h : a < b) : -a < -b := by grind

example (a b c : R) (h : a ≤ b) : a + c ≤ b + c := by grind
example (a b c : R) (h : a ≤ b) : c + a ≤ c + b := by grind
example (a b : R) (h : a ≤ b) : -b ≤ -a := by grind

/--
trace: [grind.linarith.model] a := 0
[grind.linarith.model] b := 1
-/
#guard_msgs (drop error, trace) in
set_option trace.grind.linarith.model true in
example (a b : R) (h : a ≤ b) : -a ≤ -b := by grind

example (a : R) (h : 0 < a) : 0 ≤ a := by grind
example (a : R) (h : 0 < a) : -2 * a < 0 := by grind

example (a b c : R) (_ : a ≤ b) (_ : b ≤ c) : a ≤ c := by grind
example (a b c : R) (_ : a ≤ b) (_ : b < c) : a < c := by grind
example (a b c : R) (_ : a < b) (_ : b ≤ c) : a < c := by grind
example (a b c : R) (_ : a < b) (_ : b < c) : a < c := by grind

example (a : R) (h : 2 * a < 0) : a < 0 := by grind
example (a : R) (h : 2 * a < 0) : 0 ≤ -a := by grind

example (a b : R) (_ : a < b) (_ : b < a) : False := by grind
example (a b : R) (_ : a < b ∧ b < a) : False := by grind
example (a b : R) (_ : a < b) : a ≠ b := by grind

example (a b c e v0 v1 : R) (h1 : v0 = 5 * a) (h2 : v1 = 3 * b) (h3 : v0 + v1 + c = 10 * e) :
v0 + 5 * e + (v1 - 3 * e) + (c - 2 * e) = 10 * e := by
grind

example (x y z : Int) (h1 : 2 * x < 3 * y) (h2 : -4 * x + 2 * z < 0) (h3 : 12 * y - 4 * z < 0) : False := by
grind
example (x y z : R) (h1 : 2 * x < 3 * y) (h2 : -4 * x + 2 * z < 0) (h3 : 12 * y - 4 * z < 0) : False := by
grind

example (x y z : Int) (h1 : 2 * x < 3 * y) (h2 : -4 * x + 2 * z < 0) (h3 : x * y < 5) (h3 : 12 * y - 4 * z < 0) :
False := by grind
example (x y z : R) (h1 : 2 * x < 3 * y) (h2 : -4 * x + 2 * z < 0) (h3 : 12 * y - 4 * z < 0) :
False := by grind

example (x y z : Int) (hx : x ≤ 3 * y) (h2 : y ≤ 2 * z) (h3 : x ≥ 6 * z) : x = 3*y := by
grind
example (x y z : R) (hx : x ≤ 3 * y) (h2 : y ≤ 2 * z) (h3 : x ≥ 6 * z) : x = 3 * y := by
grind

example (x y z : Int) (h1 : 2 * x < 3 * y) (h2 : -4 * x + 2 * z < 0) (h3 : x * y < 5) : ¬ 12*y - 4* z < 0 := by
grind
example (x y z : R) (h1 : 2 * x < 3 * y) (h2 : -4 * x + 2 * z < 0) : ¬ 12 * y - 4 * z < 0 := by
grind

example (x y z : Int) (hx : ¬ x > 3 * y) (h2 : ¬ y > 2 * z) (h3 : x ≥ 6 * z) : x = 3 * y := by
grind
example (x y z : R) (hx : ¬ x > 3 * y) (h2 : ¬ y > 2 * z) (h3 : x ≥ 6 * z) : x = 3 * y := by
grind

example (x y z : Nat) (hx : x ≤ 3 * y) (h2 : y ≤ 2 * z) (h3 : x ≥ 6 * z) : x = 3 * y := by
grind
example (x y z : R) (hx : x ≤ 3 * y) (h2 : y ≤ 2 * z) (h3 : x ≥ 6 * z) : x = 3 * y := by
grind
24 changes: 24 additions & 0 deletions tests/lean/run/grind_preord_module.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
open Lean.Grind

-- `grind linarith` currently does not support negation of linear constraints.
variable (R : Type u) [IntModule R] [Preorder R] [IntModule.IsOrdered R]

example (a b : R) (_ : a < b) (_ : b < a) : False := by grind
example (a b : R) (_ : a < b ∧ b < a) : False := by grind
example (a b : R) (_ : a < b) : a ≠ b := by grind

example (x y z : Int) (h1 : 2 * x < 3 * y) (h2 : -4 * x + 2 * z < 0) (h3 : 12 * y - 4 * z < 0) : False := by
grind
example (x y z : R) (h1 : 2 * x < 3 * y) (h2 : -4 * x + 2 * z < 0) (h3 : 12 * y - 4 * z < 0) : False := by
grind

example (x y z : Int) (h1 : 2 * x < 3 * y) (h2 : -4 * x + 2 * z < 0) (h3 : x * y < 5) (h3 : 12 * y - 4 * z < 0) :
False := by grind
example (x y z : R) (h1 : 2 * x < 3 * y) (h2 : -4 * x + 2 * z < 0) (h3 : 12 * y - 4 * z < 0) :
False := by grind

-- It does cancel the double negation in the following two examples
example (x y z : Int) (h1 : 2 * x < 3 * y) (h2 : -4 * x + 2 * z < 0) (h3 : x * y < 5) : ¬ 12*y - 4* z < 0 := by
grind
example (x y z : R) (h1 : 2 * x < 3 * y) (h2 : -4 * x + 2 * z < 0) : ¬ 12 * y - 4 * z < 0 := by
grind
Loading