Skip to content

Commit e2a947c

Browse files
authored
feat: track occurrences in linarith (leanprover#8801)
This PR implements the infrastructure for variable elimination in the `grind linarith` procedure.
1 parent 26946dd commit e2a947c

File tree

8 files changed

+76
-4
lines changed

8 files changed

+76
-4
lines changed

src/Lean/Meta/Tactic/Grind/Arith/Cutsat/LeCnstr.lean

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,11 +114,10 @@ def LeCnstr.assertImpl (c : LeCnstr) : GoalM Unit := do
114114
return ()
115115
let c ← refineWithDiseq c
116116
trace[grind.cutsat.assert.store] "{← c.pp}"
117+
c.p.updateOccs
117118
if a < 0 then
118-
c.p.updateOccs
119119
modify' fun s => { s with lowers := s.lowers.modify x (·.push c) }
120120
else
121-
c.p.updateOccs
122121
modify' fun s => { s with uppers := s.uppers.modify x (·.push c) }
123122
if (← c.satisfied) == .false then
124123
resetAssignmentFrom x

src/Lean/Meta/Tactic/Grind/Arith/Linear/IneqCnstr.lean

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def IneqCnstr.assert (c : IneqCnstr) : LinearM Unit := do
3838
trace[grind.linarith.trivial] "{← c.denoteExpr}"
3939
| .add a x _ =>
4040
trace[grind.linarith.assert.store] "{← c.denoteExpr}"
41+
c.p.updateOccs
4142
if a < 0 then
4243
modifyStruct fun s => { s with lowers := s.lowers.modify x (·.push c) }
4344
else

src/Lean/Meta/Tactic/Grind/Arith/Linear/Inv.lean

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,25 @@ def _root_.Lean.Grind.Linarith.Poly.checkCoeffs : Poly → Bool
2525
| .nil => true
2626
| .add k _ p => k != 0 && checkCoeffs p
2727

28+
def _root_.Lean.Grind.Linarith.Poly.checkOccs (p : Poly) : LinearM Unit := do
29+
let .add _ y p := p | return ()
30+
let rec go (p : Poly) : LinearM Unit := do
31+
let .add _ x p := p | return ()
32+
assert! (← getOccursOf x).contains y
33+
go p
34+
go p
35+
36+
def _root_.Lean.Grind.Linarith.Poly.checkNoElimVars (p : Poly) : LinearM Unit := do
37+
let .add _ x p := p | return ()
38+
assert! !(← eliminated x)
39+
checkNoElimVars p
40+
2841
def _root_.Lean.Grind.Linarith.Poly.checkCnstrOf (p : Poly) (x : Var) : LinearM Unit := do
2942
assert! p.isSorted
3043
assert! p.checkCoeffs
3144
unless (← inconsistent) do
32-
-- p.checkNoElimVars
33-
-- p.checkOccs
45+
p.checkNoElimVars
46+
p.checkOccs
3447
pure ()
3548
let .add _ y _ := p | unreachable!
3649
assert! x == y

src/Lean/Meta/Tactic/Grind/Arith/Linear/PropagateEq.lean

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def DiseqCnstr.assert (c : DiseqCnstr) : LinearM Unit := do
7272
setInconsistent (.diseq c)
7373
| .add _ x _ =>
7474
trace[grind.linarith.assert.store] "{← c.denoteExpr}"
75+
c.p.updateOccs
7576
modifyStruct fun s => { s with diseqs := s.diseqs.modify x (·.push c) }
7677
if (← c.satisfied) == .false then
7778
resetAssignmentFrom x

src/Lean/Meta/Tactic/Grind/Arith/Linear/StructId.lean

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ where
179179
-- Create `1` variable, and assert strict lower bound `0 < 1`
180180
let x ← mkVar one (mark := false)
181181
let p := Poly.add (-1) x .nil
182+
p.updateOccs
182183
modifyStruct fun s => { s with
183184
lowers := s.lowers.modify x fun cs => cs.push { p, h := .oneGtZero, strict := true }
184185
}

src/Lean/Meta/Tactic/Grind/Arith/Linear/Types.lean

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,14 @@ deriving instance Hashable for Poly
1919
deriving instance Hashable for Grind.Linarith.Expr
2020

2121
mutual
22+
/-- An equality constraint and its justification/proof. -/
23+
structure EqCnstr where
24+
p : Poly
25+
h : EqCnstrProof
26+
27+
inductive EqCnstrProof where
28+
| rfl -- TODO
29+
2230
/-- An inequality constraint and its justification/proof. -/
2331
structure IneqCnstr where
2432
p : Poly
@@ -58,6 +66,8 @@ end
5866
instance : Inhabited DiseqCnstr where
5967
default := { p := .nil, h := .core default default .zero .zero }
6068

69+
abbrev VarSet := RBTree Var compare
70+
6171
/--
6272
State for each algebraic structure by this module.
6373
Each type must at least implement the instance `IntModule`.
@@ -142,6 +152,24 @@ structure Struct where
142152
-/
143153
diseqSplits : PHashMap Poly FVarId := {}
144154
/--
155+
Mapping from variable to equation constraint used to eliminate it. `solved` variables should not occur in
156+
`diseqs`, `lowers`, or `uppers`.
157+
-/
158+
elimEqs : PArray (Option EqCnstr) := {}
159+
/--
160+
Elimination stack. For every variable in `elimStack`. If `x` in `elimStack`, then `elimEqs[x]` is not `none`.
161+
-/
162+
elimStack : List Var := []
163+
/--
164+
Mapping from variable to occurrences.
165+
For example, an entry `x ↦ {y, z}` means that `x` may occur in `lowers`, or `uppers`, or `diseqs` of
166+
variables `y` and `z`.
167+
If `x` occurs in `diseqs[y]`, `lowers[y]`, or `uppers[y]`, then `y` is in `occurs[x]`,
168+
but the reverse is not true.
169+
If `x` is in `elimStack`, then `occurs[x]` is the empty set.
170+
-/
171+
occurs : PArray VarSet := {}
172+
/--
145173
Linear constraints that are not supported.
146174
We use this information for diagnostics.
147175
TODO: store constraints instead.

src/Lean/Meta/Tactic/Grind/Arith/Linear/Util.lean

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,4 +194,31 @@ def inconsistent : LinearM Bool := do
194194
if (← isInconsistent) then return true
195195
return (← getStruct).conflict?.isSome
196196

197+
/-- Returns `true` if `x` has been eliminated using an equality constraint. -/
198+
def eliminated (x : Var) : LinearM Bool :=
199+
return (← getStruct).elimEqs[x]!.isSome
200+
201+
/-- Returns occurrences of `x`. -/
202+
def getOccursOf (x : Var) : LinearM VarSet :=
203+
return (← getStruct).occurs[x]!
204+
205+
/--
206+
Adds `y` as an occurrence of `x`.
207+
That is, `x` occurs in `lowers[y]`, `uppers[y]`, or `diseqs[y]`.
208+
-/
209+
def addOcc (x : Var) (y : Var) : LinearM Unit := do
210+
unless (← getOccursOf x).contains y do
211+
modifyStruct fun s => { s with occurs := s.occurs.modify x fun ys => ys.insert y }
212+
213+
/--
214+
Given `p` a polynomial being inserted into `lowers`, `uppers`, or `diseqs`,
215+
get its leading variable `y`, and adds `y` as an occurrence for the remaining variables in `p`.
216+
-/
217+
partial def _root_.Lean.Grind.Linarith.Poly.updateOccs (p : Poly) : LinearM Unit := do
218+
let .add _ y p := p | throwError "`grind linarith` internal error, unexpected constant polynomial"
219+
let rec go (p : Poly) : LinearM Unit := do
220+
let .add _ x p := p | return ()
221+
addOcc x y; go p
222+
go p
223+
197224
end Lean.Meta.Grind.Arith.Linear

src/Lean/Meta/Tactic/Grind/Arith/Linear/Var.lean

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ def mkVar (e : Expr) (mark := true) : LinearM Var := do
1919
lowers := s.lowers.push {}
2020
uppers := s.uppers.push {}
2121
diseqs := s.diseqs.push {}
22+
occurs := s.occurs.push {}
23+
elimEqs := s.elimEqs.push none
2224
}
2325
setTermStructId e
2426
if mark then

0 commit comments

Comments
 (0)