Skip to content

Commit 0504e32

Browse files
authored
feat: add addEdge to grind order (#10596)
This PR implements the function for adding new edges to the graph used by `grind order`. The graph maintains the transitive closure of all asserted constraints.
1 parent fbfc769 commit 0504e32

File tree

6 files changed

+233
-41
lines changed

6 files changed

+233
-41
lines changed

src/Init/Grind/Order.lean

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,20 @@ theorem le_eq_false_of_lt {α} [LE α] [LT α] [Std.LawfulOrderLT α] [Std.IsPre
196196
have := Preorder.lt_irrefl a
197197
contradiction
198198

199+
theorem lt_eq_false_of_lt {α} [LE α] [LT α] [Std.LawfulOrderLT α] [Std.IsPreorder α]
200+
(a b : α) : a < b → (b < a) = False := by
201+
simp; intro h₁ h₂
202+
have := lt_trans h₁ h₂
203+
have := Preorder.lt_irrefl a
204+
contradiction
205+
206+
theorem lt_eq_false_of_le {α} [LE α] [LT α] [Std.LawfulOrderLT α] [Std.IsPreorder α]
207+
(a b : α) : a ≤ b → (b < a) = False := by
208+
simp; intro h₁ h₂
209+
have := le_lt_trans h₁ h₂
210+
have := Preorder.lt_irrefl a
211+
contradiction
212+
199213
theorem le_eq_false_of_le_k {α} [LE α] [LT α] [Std.LawfulOrderLT α] [Std.IsPreorder α] [Ring α] [OrderedRing α]
200214
(a b : α) (k₁ k₂ : Int) : (k₂ + k₁).blt' 0 → a ≤ b + k₁ → (b ≤ a + k₂) = False := by
201215
intro h₁; simp; intro h₂ h₃

src/Lean/Meta/Tactic/Grind/Order/Assert.lean

Lines changed: 147 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,23 +25,108 @@ where
2525
let p' ← getProof u p.w
2626
go (← mkTrans p' p v)
2727

28+
/--
29+
Given a new edge edge `u --(kuv)--> v` justified by proof `huv` s.t.
30+
it creates a negative cycle with the existing path `v --{kvu}-->* u`, i.e., `kuv + kvu < 0`,
31+
this function closes the current goal by constructing a proof of `False`.
32+
-/
33+
def setUnsat (u v : NodeId) (kuv : Weight) (huv : Expr) (kvu : Weight) : OrderM Unit := do
34+
let hvu ← mkProofForPath v u
35+
let u ← getExpr u
36+
let v ← getExpr v
37+
let h ← mkUnsatProof u v kuv huv kvu hvu
38+
closeGoal h
39+
40+
/-- Sets the new shortest distance `k` between nodes `u` and `v`. -/
41+
def setDist (u v : NodeId) (k : Weight) : OrderM Unit := do
42+
modifyStruct fun s => { s with
43+
targets := s.targets.modify u fun es => es.insert v k
44+
sources := s.sources.modify v fun es => es.insert u k
45+
}
46+
47+
def setProof (u v : NodeId) (p : ProofInfo) : OrderM Unit := do
48+
modifyStruct fun s => { s with
49+
proofs := s.proofs.modify u fun es => es.insert v p
50+
}
51+
52+
@[inline] def forEachSourceOf (u : NodeId) (f : NodeId → Weight → OrderM Unit) : OrderM Unit := do
53+
(← getStruct).sources[u]!.forM f
54+
55+
@[inline] def forEachTargetOf (u : NodeId) (f : NodeId → Weight → OrderM Unit) : OrderM Unit := do
56+
(← getStruct).targets[u]!.forM f
57+
58+
/-- Returns `true` if `k` is smaller than the shortest distance between `u` and `v` -/
59+
def isShorter (u v : NodeId) (k : Weight) : OrderM Bool := do
60+
if let some k' ← getDist? u v then
61+
return k < k'
62+
else
63+
return true
64+
2865
/-- Adds `p` to the list of things to be propagated. -/
2966
def pushToPropagate (p : ToPropagate) : OrderM Unit :=
3067
modifyStruct fun s => { s with propagate := p :: s.propagate }
3168

32-
/-
33-
def propagateEqTrue (e : Expr) (u v : NodeId) (k k' : Int) : OrderM Unit := do
69+
def propagateEqTrue (e : Expr) (u v : NodeId) (k k' : Weight) : OrderM Unit := do
3470
let kuv ← mkProofForPath u v
3571
let u ← getExpr u
3672
let v ← getExpr v
37-
pushEqTrue e <| mkPropagateEqTrueProof u v k kuv k'
73+
let h ← mkPropagateEqTrueProof u v k kuv k'
74+
pushEqTrue e h
3875

39-
private def propagateEqFalse (e : Expr) (u v : NodeId) (k k' : Int) : OrderM Unit := do
76+
def propagateEqFalse (e : Expr) (u v : NodeId) (k k' : Weight) : OrderM Unit := do
4077
let kuv ← mkProofForPath u v
4178
let u ← getExpr u
4279
let v ← getExpr v
43-
pushEqFalse e <| mkPropagateEqFalseProof u v k kuv k'
80+
let h ← mkPropagateEqFalseProof u v k kuv k'
81+
pushEqFalse e h
82+
83+
/-- Propagates all pending constraints and equalities and resets to "to do" list. -/
84+
private def propagatePending : OrderM Unit := do
85+
let todo := (← getStruct).propagate
86+
modifyStruct fun s => { s with propagate := [] }
87+
for p in todo do
88+
match p with
89+
| .eqTrue e u v k k' => propagateEqTrue e u v k k'
90+
| .eqFalse e u v k k' => propagateEqFalse e u v k k'
91+
| .eq u v =>
92+
let ue ← getExpr u
93+
let ve ← getExpr v
94+
unless (← isEqv ue ve) do
95+
let huv ← mkProofForPath u v
96+
let hvu ← mkProofForPath v u
97+
let h ← mkEqProofOfLeOfLe ue ve huv hvu
98+
pushEq ue ve h
99+
100+
def Cnstr.getWeight? (c : Cnstr α) : Option Weight :=
101+
match c.kind with
102+
| .le => some { k := c.k }
103+
| .lt => some { k := c.k, strict := true }
104+
| .eq => none
105+
106+
/--
107+
Given `e` represented by constraint `c` (from `u` to `v`).
108+
Checks whether `e = True` can be propagated using the path `u --(k)--> v`.
109+
If it can, adds a new entry to propagation list.
110+
-/
111+
def checkEqTrue (u v : NodeId) (k : Weight) (c : Cnstr NodeId) (e : Expr) : OrderM Bool := do
112+
let some k' := c.getWeight? | return false
113+
if k ≤ k' then
114+
pushToPropagate <| .eqTrue e u v k k'
115+
return true
116+
else
117+
return false
118+
119+
/--
120+
Given `e` represented by constraint `c` (from `v` to `u`).
121+
Checks whether `e = False` can be propagated using the path `u --(k)--> v`.
122+
If it can, adds a new entry to propagation list.
44123
-/
124+
def checkEqFalse (u v : NodeId) (k : Weight) (c : Cnstr NodeId) (e : Expr) : OrderM Bool := do
125+
let some k' := c.getWeight? | return false
126+
if (k + k').isNeg then
127+
pushToPropagate <| .eqFalse e u v k k'
128+
return true
129+
return false
45130

46131
/--
47132
Auxiliary function for implementing theory propagation.
@@ -51,8 +136,7 @@ associated with `(u, v)` IF
51136
- `e` is already assigned, or
52137
- `f c e` returns true
53138
-/
54-
@[inline]
55-
private def updateCnstrsOf (u v : NodeId) (f : Cnstr NodeId → Expr → OrderM Bool) : OrderM Unit := do
139+
@[inline] def updateCnstrsOf (u v : NodeId) (f : Cnstr NodeId → Expr → OrderM Bool) : OrderM Unit := do
56140
if let some cs := (← getStruct).cnstrsOf.find? (u, v) then
57141
let cs' ← cs.filterM fun (c, e) => do
58142
if (← isEqTrue e <||> isEqFalse e) then
@@ -61,6 +145,62 @@ private def updateCnstrsOf (u v : NodeId) (f : Cnstr NodeId → Expr → OrderM
61145
return !(← f c e)
62146
modifyStruct fun s => { s with cnstrsOf := s.cnstrsOf.insert (u, v) cs' }
63147

148+
/-- Equality propagation. -/
149+
def checkEq (u v : NodeId) (k : Weight) : OrderM Unit := do
150+
if !k.isZero then return ()
151+
let some k' ← getDist? v u | return ()
152+
if !k'.isZero then return ()
153+
let ue ← getExpr u
154+
let ve ← getExpr v
155+
if (← alreadyInternalized ue <&&> alreadyInternalized ve) then
156+
if (← isEqv ue ve) then return ()
157+
pushToPropagate <| .eq u v
158+
159+
/-- Finds constrains and equalities to be propagated. -/
160+
def checkToPropagate (u v : NodeId) (k : Weight) : OrderM Unit := do
161+
updateCnstrsOf u v fun c e => return !(← checkEqTrue u v k c e)
162+
updateCnstrsOf v u fun c e => return !(← checkEqFalse u v k c e)
163+
checkEq u v k
164+
165+
/--
166+
If `isShorter u v k`, updates the shortest distance between `u` and `v`.
167+
`w` is a node in the path from `u` to `v` such that `(← getProof? w v)` is `some`
168+
-/
169+
def updateIfShorter (u v : NodeId) (k : Weight) (w : NodeId) : OrderM Unit := do
170+
if (← isShorter u v k) then
171+
setDist u v k
172+
setProof u v (← getProof w v)
173+
checkToPropagate u v k
174+
175+
/--
176+
Adds an edge `u --(k) --> v` justified by the proof term `p`, and then
177+
if no negative cycle was created, updates the shortest distance of affected
178+
node pairs.
179+
-/
180+
def addEdge (u : NodeId) (v : NodeId) (k : Weight) (h : Expr) : OrderM Unit := do
181+
if (← isInconsistent) then return ()
182+
if let some k' ← getDist? v u then
183+
if (k + k').isNeg then
184+
setUnsat u v k h k'
185+
return ()
186+
if (← isShorter u v k) then
187+
setDist u v k
188+
setProof u v { w := u, k, proof := h }
189+
checkToPropagate u v k
190+
update
191+
propagatePending
192+
where
193+
update : OrderM Unit := do
194+
forEachTargetOf v fun j k₂ => do
195+
/- Check whether new path: `u -(k)-> v -(k₂)-> j` is shorter -/
196+
updateIfShorter u j (k+k₂) v
197+
forEachSourceOf u fun i k₁ => do
198+
/- Check whether new path: `i -(k₁)-> u -(k)-> v` is shorter -/
199+
updateIfShorter i v (k₁+k) u
200+
forEachTargetOf v fun j k₂ => do
201+
/- Check whether new path: `i -(k₁)-> u -(k)-> v -(k₂) -> j` is shorter -/
202+
updateIfShorter i j (k₁+k+k₂) v
203+
64204
def assertTrue (c : Cnstr NodeId) (p : Expr) : OrderM Unit := do
65205
trace[grind.order.assert] "{p} = True: {← c.pp}"
66206

src/Lean/Meta/Tactic/Grind/Order/OrderM.lean

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,4 +57,7 @@ def getCnstr? (e : Expr) : OrderM (Option (Cnstr NodeId)) :=
5757
def isRing : OrderM Bool :=
5858
return (← getStruct).ringId?.isSome
5959

60+
def isPartialOrder : OrderM Bool :=
61+
return (← getStruct).isPartialInst?.isSome
62+
6063
end Lean.Meta.Grind.Order

src/Lean/Meta/Tactic/Grind/Order/Proof.lean

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,13 @@ def mkLePreorderPrefix (declName : Name) : OrderM Expr := do
1616
let s ← getStruct
1717
return mkApp3 (mkConst declName [s.u]) s.type s.leInst s.isPreorderInst
1818

19+
/--
20+
Returns `declName α leInst isPartialInst`
21+
-/
22+
def mkLePartialPrefix (declName : Name) : OrderM Expr := do
23+
let s ← getStruct
24+
return mkApp3 (mkConst declName [s.u]) s.type s.leInst s.isPartialInst?.get!
25+
1926
/--
2027
Returns `declName α leInst ltInst lawfulOrderLtInst`
2128
-/
@@ -127,8 +134,13 @@ public def mkPropagateEqTrueProof (u v : Expr) (k : Weight) (huv : Expr) (k' : W
127134
/--
128135
`u < v → (v ≤ u) = False
129136
-/
130-
def mkPropagateEqFalseProofCore (u v : Expr) (huv : Expr) : OrderM Expr := do
131-
let h ← mkLeLtPreorderPrefix ``Grind.Order.le_eq_false_of_lt
137+
def mkPropagateEqFalseProofCore (u v : Expr) (k : Weight) (huv : Expr) (k' : Weight) : OrderM Expr := do
138+
let declName := match k'.strict, k.strict with
139+
| false, false => unreachable!
140+
| false, true => ``Grind.Order.le_eq_false_of_lt
141+
| true, false => ``Grind.Order.lt_eq_false_of_le
142+
| true, true => ``Grind.Order.lt_eq_false_of_lt
143+
let h ← mkLeLtPreorderPrefix declName
132144
return mkApp3 h u v huv
133145

134146
def mkPropagateEqFalseProofOffset (u v : Expr) (k : Weight) (huv : Expr) (k' : Weight) : OrderM Expr := do
@@ -148,7 +160,7 @@ public def mkPropagateEqFalseProof (u v : Expr) (k : Weight) (huv : Expr) (k' :
148160
if (← isRing) then
149161
mkPropagateEqFalseProofOffset u v k huv k'
150162
else
151-
mkPropagateEqFalseProofCore u v huv
163+
mkPropagateEqFalseProofCore u v k huv k'
152164

153165
def mkUnsatProofCore (u v : Expr) (k₁ : Weight) (h₁ : Expr) (k₂ : Weight) (h₂ : Expr) : OrderM Expr := do
154166
let h ← mkTransCoreProof u v u k₁.strict k₂.strict h₁ h₂
@@ -170,10 +182,14 @@ Returns a proof of `False` using a negative cycle composed of
170182
- `u --(k₁)--> v` with proof `h₁`
171183
- `v --(k₂)--> u` with proof `h₂`
172184
-/
173-
def mkUnsatProof (u v : Expr) (k₁ : Weight) (h₁ : Expr) (k₂ : Weight) (h₂ : Expr) : OrderM Expr := do
185+
public def mkUnsatProof (u v : Expr) (k₁ : Weight) (h₁ : Expr) (k₂ : Weight) (h₂ : Expr) : OrderM Expr := do
174186
if (← isRing) then
175187
mkUnsatProofOffset u v k₁ h₁ k₂ h₂
176188
else
177189
mkUnsatProofCore u v k₁ h₁ k₂ h₂
178190

191+
public def mkEqProofOfLeOfLe (u v : Expr) (h₁ : Expr) (h₂ : Expr) : OrderM Expr := do
192+
let h ← mkLePartialPrefix ``Grind.Order.eq_of_le_of_le
193+
return mkApp4 h u v h₁ h₂
194+
179195
end Lean.Meta.Grind.Order

src/Lean/Meta/Tactic/Grind/Order/Types.lean

Lines changed: 2 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -39,33 +39,6 @@ structure Weight where
3939
strict := false
4040
deriving Inhabited
4141

42-
def Weight.compare (a b : Weight) : Ordering :=
43-
if a.k < b.k then
44-
.lt
45-
else if b.k > a.k then
46-
.gt
47-
else if a.strict == b.strict then
48-
.eq
49-
else if !a.strict && b.strict then
50-
.lt
51-
else
52-
.gt
53-
54-
instance : Ord Weight where
55-
compare := Weight.compare
56-
57-
instance : LE Weight where
58-
le a b := compare a b ≠ .gt
59-
60-
instance : LT Weight where
61-
lt a b := compare a b = .lt
62-
63-
instance : DecidableLE Weight :=
64-
fun a b => inferInstanceAs (Decidable (compare a b ≠ .gt))
65-
66-
instance : DecidableLT Weight :=
67-
fun a b => inferInstanceAs (Decidable (compare a b = .lt))
68-
6942
/-- Auxiliary structure used for proof extraction. -/
7043
structure ProofInfo where
7144
w : NodeId
@@ -82,8 +55,8 @@ Thus, we store the information to be propagated into a list.
8255
See field `propagate` in `State`.
8356
-/
8457
inductive ToPropagate where
85-
| eqTrue (e : Expr) (u v : NodeId) (k k' : Int)
86-
| eqFalse (e : Expr) (u v : NodeId) (k k' : Int)
58+
| eqTrue (e : Expr) (u v : NodeId) (k k' : Weight)
59+
| eqFalse (e : Expr) (u v : NodeId) (k k' : Weight)
8760
| eq (u v : NodeId)
8861
deriving Inhabited
8962

src/Lean/Meta/Tactic/Grind/Order/Util.lean

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@ module
77
prelude
88
public import Lean.Meta.Tactic.Grind.Order.OrderM
99
import Lean.Meta.Tactic.Grind.Arith.Util
10+
public section
1011
namespace Lean.Meta.Grind.Order
1112

12-
public def Cnstr.pp (c : Cnstr NodeId) : OrderM MessageData := do
13+
def Cnstr.pp (c : Cnstr NodeId) : OrderM MessageData := do
1314
let u ← getExpr c.u
1415
let v ← getExpr c.v
1516
let op := match c.kind with
@@ -21,4 +22,49 @@ public def Cnstr.pp (c : Cnstr NodeId) : OrderM MessageData := do
2122
else
2223
return m!"{Arith.quoteIfArithTerm u} {op} {Arith.quoteIfArithTerm v}"
2324

25+
def Weight.compare (a b : Weight) : Ordering :=
26+
if a.k < b.k then
27+
.lt
28+
else if b.k > a.k then
29+
.gt
30+
else if a.strict == b.strict then
31+
.eq
32+
else if a.strict && !b.strict then
33+
/-
34+
**Note**: Recall that we view a constraint of the
35+
form `x < y + k` as `x ≤ y + (k - ε)` where `ε` is
36+
an "infinitesimal" positive value.
37+
Thus, `k - ε < k`
38+
-/
39+
.lt
40+
else
41+
.gt
42+
43+
instance : Ord Weight where
44+
compare := Weight.compare
45+
46+
instance : LE Weight where
47+
le a b := compare a b ≠ .gt
48+
49+
instance : LT Weight where
50+
lt a b := compare a b = .lt
51+
52+
instance : DecidableLE Weight :=
53+
fun a b => inferInstanceAs (Decidable (compare a b ≠ .gt))
54+
55+
instance : DecidableLT Weight :=
56+
fun a b => inferInstanceAs (Decidable (compare a b = .lt))
57+
58+
def Weight.add (a b : Weight) : Weight :=
59+
{ k := a.k + b.k, strict := a.strict || b.strict }
60+
61+
instance : Add Weight where
62+
add := Weight.add
63+
64+
def Weight.isNeg (a : Weight) : Bool :=
65+
a.k < 0 || (a.k == 0 && a.strict)
66+
67+
def Weight.isZero (a : Weight) : Bool :=
68+
a.k == 0 && !a.strict
69+
2470
end Lean.Meta.Grind.Order

0 commit comments

Comments
 (0)