Skip to content

Commit 2a63b39

Browse files
authored
fix: ring module in grind (#8713)
This PR fixes a bug in the commutative ring module used in `grind`. It was missing simplification opportunities.
1 parent 0b2884b commit 2a63b39

File tree

9 files changed

+64
-57
lines changed

9 files changed

+64
-57
lines changed

src/Lean/Meta/Tactic/Grind/Arith/CommRing.lean

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,5 +38,6 @@ builtin_initialize registerTraceClass `grind.debug.ring.proof
3838
builtin_initialize registerTraceClass `grind.debug.ring.check
3939
builtin_initialize registerTraceClass `grind.debug.ring.impEq
4040
builtin_initialize registerTraceClass `grind.debug.ring.simpBasis
41+
builtin_initialize registerTraceClass `grind.debug.ring.basis
4142

4243
end Lean

src/Lean/Meta/Tactic/Grind/Arith/CommRing/EqCnstr.lean

Lines changed: 31 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,11 @@ then the leading coefficient of the equation must also divide `k`
4848
def _root_.Lean.Grind.CommRing.Mon.findSimp? (k : Int) (m : Mon) : RingM (Option EqCnstr) := do
4949
let checkCoeff ← checkCoeffDvd
5050
let noZeroDiv ← noZeroDivisors
51-
let rec go : Mon → RingM (Option EqCnstr)
52-
| .unit => return none
53-
| .mult pw m' => do
54-
for c in (← getRing).varToBasis[pw.x]! do
55-
if !checkCoeff || noZeroDiv || (c.p.lc ∣ k) then
56-
if c.p.divides m then
57-
return some c
58-
go m'
59-
go m
51+
for c in (← getRing).basis do
52+
if !checkCoeff || noZeroDiv || (c.p.lc ∣ k) then
53+
if c.p.divides m then
54+
return some c
55+
return none
6056

6157
/--
6258
Returns `some c`, where `c` is an equation from the basis whose leading monomial divides some
@@ -129,6 +125,7 @@ def EqCnstr.checkConstant (c : EqCnstr) : RingM Bool := do
129125
c.setUnsat
130126
else
131127
-- Remark: we currently don't do anything if the characteristic is not known.
128+
-- TODO: if `k.natAbs` is `1`, we could set all terms of this ring `0`.
132129
trace_goal[grind.ring.assert.discard] "{← c.denoteExpr}"
133130
return true
134131

@@ -153,10 +150,9 @@ private def addSorted (c : EqCnstr) : List EqCnstr → List EqCnstr
153150
c' :: addSorted c cs
154151

155152
def addToBasisCore (c : EqCnstr) : RingM Unit := do
156-
let .add _ m _ := c.p | return ()
157-
let .mult pw _ := m | return ()
153+
trace[grind.debug.ring.basis] "{← c.denoteExpr}"
158154
modifyRing fun s => { s with
159-
varToBasis := s.varToBasis.modify pw.x (addSorted c)
155+
basis := addSorted c s.basis
160156
recheck := true
161157
}
162158

@@ -168,18 +164,12 @@ def EqCnstr.addToQueue (c : EqCnstr) : RingM Unit := do
168164
def EqCnstr.superposeWith (c : EqCnstr) : RingM Unit := do
169165
if (← checkMaxSteps) then return ()
170166
let .add _ m _ := c.p | return ()
171-
go m
172-
where
173-
go : Mon → RingM Unit
174-
| .unit => return ()
175-
| .mult pw m => do
176-
let x := pw.x
177-
let cs := (← getRing).varToBasis[x]!
178-
for c' in cs do
179-
let r ← c.p.spolM c'.p
180-
trace_goal[grind.ring.superpose] "{← c.denoteExpr}\nwith: {← c'.denoteExpr}\nresult: {← r.spol.denoteExpr} = 0"
181-
addToQueue (← mkEqCnstr r.spol <| .superpose r.k₁ r.m₁ c r.k₂ r.m₂ c')
182-
go m
167+
for c' in (← getRing).basis do
168+
let .add _ m' _ := c'.p | pure ()
169+
if m.sharesVar m' then
170+
let r ← c.p.spolM c'.p
171+
trace_goal[grind.ring.superpose] "{← c.denoteExpr}\nwith: {← c'.denoteExpr}\nresult: {← r.spol.denoteExpr} = 0"
172+
addToQueue (← mkEqCnstr r.spol <| .superpose r.k₁ r.m₁ c r.k₂ r.m₂ c')
183173

184174
/--
185175
Tries to convert the leading monomial into a monic one.
@@ -215,25 +205,23 @@ def EqCnstr.toMonic (c : EqCnstr) : RingM EqCnstr := do
215205
def EqCnstr.simplifyBasis (c : EqCnstr) : RingM Unit := do
216206
trace[grind.debug.ring.simpBasis] "using: {← c.denoteExpr}"
217207
let .add _ m _ := c.p | return ()
218-
let rec go (m' : Mon) : RingM Unit := do
219-
match m' with
220-
| .unit => return ()
221-
| .mult pw m' => goVar m pw.x; go m'
222-
go m
223-
where
224-
goVar (m : Mon) (x : Var) : RingM Unit := do
225-
let cs := (← getRing).varToBasis[x]!
226-
if cs.isEmpty then return ()
227-
modifyRing fun s => { s with varToBasis := s.varToBasis.set x {} }
228-
for c' in cs do
229-
trace[grind.debug.ring.simpBasis] "target: {← c'.denoteExpr}"
230-
let .add _ m' _ := c'.p | pure ()
231-
if m.divides m' then
232-
let c'' ← c'.simplifyWithExhaustively c
233-
trace[grind.debug.ring.simpBasis] "simplified: {← c''.denoteExpr}"
234-
addToQueue c''
235-
else
236-
addToBasisCore c'
208+
let rec go (basis : List EqCnstr) (acc : List EqCnstr) : RingM (List EqCnstr) := do
209+
match basis with
210+
| [] => return acc.reverse
211+
| c' :: basis =>
212+
match c'.p with
213+
| .add _ m' _ =>
214+
if m.divides m' then
215+
let c'' ← c'.simplifyWithExhaustively c
216+
trace[grind.debug.ring.simpBasis] "simplified: {← c''.denoteExpr}"
217+
unless (← checkConstant c'') do
218+
addToQueue c''
219+
go basis acc
220+
else
221+
go basis (c' :: acc)
222+
| _ => go basis (c' :: acc)
223+
let basis ← go (← getRing).basis []
224+
modifyRing fun s => { s with basis }
237225

238226
def EqCnstr.addToBasisAfterSimp (c : EqCnstr) : RingM Unit := do
239227
let c ← c.toMonic

src/Lean/Meta/Tactic/Grind/Arith/CommRing/Inv.lean

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,8 @@ private def checkPoly (p : Poly) : RingM Unit := do
2929

3030
private def checkBasis : RingM Unit := do
3131
let mut x := 0
32-
for cs in (← getRing).varToBasis do
33-
for c in cs do
34-
checkPoly c.p
35-
let .add _ m _ := c.p | unreachable!
36-
let .mult pw _ := m | unreachable!
37-
assert! pw.x == x
32+
for c in (← getRing).basis do
33+
checkPoly c.p
3834
x := x + 1
3935

4036
private def checkQueue : RingM Unit := do

src/Lean/Meta/Tactic/Grind/Arith/CommRing/PP.lean

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,8 @@ private def push (msgs : Array MessageData) (msg? : Option MessageData) : Array
2424

2525
def ppBasis? : ReaderT Ring MetaM (Option MessageData) := do
2626
let mut basis := #[]
27-
for cs in (← getRing).varToBasis do
28-
for c in cs do
29-
basis := basis.push (toTraceElem (← c.denoteExpr))
27+
for c in (← getRing).basis do
28+
basis := basis.push (toTraceElem (← c.denoteExpr))
3029
return toOption `basis "Basis" basis
3130

3231
def ppDiseqs? : ReaderT Ring MetaM (Option MessageData) := do

src/Lean/Meta/Tactic/Grind/Arith/CommRing/Poly.lean

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,16 @@ prelude
77
import Init.Grind.CommRing.Poly
88
namespace Lean.Grind.CommRing
99

10+
/-- `sharesVar m₁ m₂` returns `true` if `m₁` and `m₂` shares at least one variable. -/
11+
def Mon.sharesVar : Mon → Mon → Bool
12+
| .unit, _ => false
13+
| _, .unit => false
14+
| .mult pw₁ m₁, .mult pw₂ m₂ =>
15+
match compare pw₁.x pw₂.x with
16+
| .eq => true
17+
| .lt => sharesVar m₁ (.mult pw₂ m₂)
18+
| .gt => sharesVar (.mult pw₁ m₁) m₂
19+
1020
/-- `lcm m₁ m₂` returns the least common multiple of the given monomials. -/
1121
def Mon.lcm : Mon → Mon → Mon
1222
| .unit, m₂ => m₂

src/Lean/Meta/Tactic/Grind/Arith/CommRing/Types.lean

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -165,10 +165,11 @@ structure Ring where
165165
/-- Equations to process. -/
166166
queue : Queue := {}
167167
/--
168-
Mapping from variables `x` to equations such that the smallest variable
169-
in the leading monomial is `x`.
168+
The basis is currently just a list. If this is a performance bottleneck, we should use
169+
a better data-structure. For examples, we could use a simple indexing for the linear case
170+
where we map variable in the leading monomial to `EqCnstr`.
170171
-/
171-
varToBasis : PArray (List EqCnstr) := {}
172+
basis : List EqCnstr := {}
172173
/-- Disequalities. -/
173174
-- TODO: add indexing
174175
diseqs : PArray DiseqCnstr := {}

src/Lean/Meta/Tactic/Grind/Arith/CommRing/Var.lean

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ def mkVar (e : Expr) : RingM Var := do
1616
modifyRing fun s => { s with
1717
vars := s.vars.push e
1818
varMap := s.varMap.insert { expr := e } var
19-
varToBasis := s.varToBasis.push []
2019
}
2120
setTermRingId e
2221
markAsCommRingTerm e

tests/lean/run/grind_ring_1.lean

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,16 @@ set_option trace.grind.ring.assert.queue true in
6666
example (x y : Int) : x + 16*y^2 - 7*x^2 = 0 → False := by
6767
fail_if_success grind
6868
sorry
69+
70+
/--
71+
trace: [grind.debug.ring.basis] a ^ 2 * b + -1 = 0
72+
[grind.debug.ring.basis] a * b ^ 2 + -1 * b = 0
73+
[grind.debug.ring.basis] a * b + -1 * b = 0
74+
[grind.debug.ring.basis] b + -1 = 0
75+
[grind.debug.ring.basis] a + -1 = 0
76+
-/
77+
#guard_msgs (drop error, trace) in
78+
set_option trace.grind.debug.ring.basis true in
79+
example [CommRing α] (a b c : α)
80+
: a^2*b = 1 → a*b^2 = b → False := by
81+
grind

tests/lean/run/grind_ring_2.lean

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ example (a b c : Int) (f : Int → Nat)
126126
f (a^4 + b^4) + f (9 - c^4) ≠ 1 := by
127127
grind
128128

129-
example [CommRing α] (a b c : α) (f : α → Nat)
129+
example [CommRing α] [NoNatZeroDivisors α] (a b c : α) (f : α → Nat)
130130
: a + b + c = 3
131131
a^2 + b^2 + c^2 = 5
132132
a^3 + b^3 + c^3 = 7

0 commit comments

Comments
 (0)