@@ -25,23 +25,108 @@ where
2525 let p' ← getProof u p.w
2626 go (← mkTrans p' p v)
2727
28+ /--
29+ Given a new edge edge `u --(kuv)--> v` justified by proof `huv` s.t.
30+ it creates a negative cycle with the existing path `v --{kvu}-->* u`, i.e., `kuv + kvu < 0`,
31+ this function closes the current goal by constructing a proof of `False`.
32+ -/
33+ def setUnsat (u v : NodeId) (kuv : Weight) (huv : Expr) (kvu : Weight) : OrderM Unit := do
34+ let hvu ← mkProofForPath v u
35+ let u ← getExpr u
36+ let v ← getExpr v
37+ let h ← mkUnsatProof u v kuv huv kvu hvu
38+ closeGoal h
39+
40+ /-- Sets the new shortest distance `k` between nodes `u` and `v`. -/
41+ def setDist (u v : NodeId) (k : Weight) : OrderM Unit := do
42+ modifyStruct fun s => { s with
43+ targets := s.targets.modify u fun es => es.insert v k
44+ sources := s.sources.modify v fun es => es.insert u k
45+ }
46+
47+ def setProof (u v : NodeId) (p : ProofInfo) : OrderM Unit := do
48+ modifyStruct fun s => { s with
49+ proofs := s.proofs.modify u fun es => es.insert v p
50+ }
51+
52+ @[inline] def forEachSourceOf (u : NodeId) (f : NodeId → Weight → OrderM Unit) : OrderM Unit := do
53+ (← getStruct).sources[u]!.forM f
54+
55+ @[inline] def forEachTargetOf (u : NodeId) (f : NodeId → Weight → OrderM Unit) : OrderM Unit := do
56+ (← getStruct).targets[u]!.forM f
57+
58+ /-- Returns `true` if `k` is smaller than the shortest distance between `u` and `v` -/
59+ def isShorter (u v : NodeId) (k : Weight) : OrderM Bool := do
60+ if let some k' ← getDist? u v then
61+ return k < k'
62+ else
63+ return true
64+
2865/-- Adds `p` to the list of things to be propagated. -/
2966def pushToPropagate (p : ToPropagate) : OrderM Unit :=
3067 modifyStruct fun s => { s with propagate := p :: s.propagate }
3168
32- /-
33- def propagateEqTrue (e : Expr) (u v : NodeId) (k k' : Int) : OrderM Unit := do
69+ def propagateEqTrue (e : Expr) (u v : NodeId) (k k' : Weight) : OrderM Unit := do
3470 let kuv ← mkProofForPath u v
3571 let u ← getExpr u
3672 let v ← getExpr v
37- pushEqTrue e <| mkPropagateEqTrueProof u v k kuv k'
73+ let h ← mkPropagateEqTrueProof u v k kuv k'
74+ pushEqTrue e h
3875
39- private def propagateEqFalse (e : Expr) (u v : NodeId) (k k' : Int ) : OrderM Unit := do
76+ def propagateEqFalse (e : Expr) (u v : NodeId) (k k' : Weight ) : OrderM Unit := do
4077 let kuv ← mkProofForPath u v
4178 let u ← getExpr u
4279 let v ← getExpr v
43- pushEqFalse e <| mkPropagateEqFalseProof u v k kuv k'
80+ let h ← mkPropagateEqFalseProof u v k kuv k'
81+ pushEqFalse e h
82+
83+ /-- Propagates all pending constraints and equalities and resets to "to do" list. -/
84+ private def propagatePending : OrderM Unit := do
85+ let todo := (← getStruct).propagate
86+ modifyStruct fun s => { s with propagate := [] }
87+ for p in todo do
88+ match p with
89+ | .eqTrue e u v k k' => propagateEqTrue e u v k k'
90+ | .eqFalse e u v k k' => propagateEqFalse e u v k k'
91+ | .eq u v =>
92+ let ue ← getExpr u
93+ let ve ← getExpr v
94+ unless (← isEqv ue ve) do
95+ let huv ← mkProofForPath u v
96+ let hvu ← mkProofForPath v u
97+ let h ← mkEqProofOfLeOfLe ue ve huv hvu
98+ pushEq ue ve h
99+
100+ def Cnstr.getWeight? (c : Cnstr α) : Option Weight :=
101+ match c.kind with
102+ | .le => some { k := c.k }
103+ | .lt => some { k := c.k, strict := true }
104+ | .eq => none
105+
106+ /--
107+ Given `e` represented by constraint `c` (from `u` to `v`).
108+ Checks whether `e = True` can be propagated using the path `u --(k)--> v`.
109+ If it can, adds a new entry to propagation list.
110+ -/
111+ def checkEqTrue (u v : NodeId) (k : Weight) (c : Cnstr NodeId) (e : Expr) : OrderM Bool := do
112+ let some k' := c.getWeight? | return false
113+ if k ≤ k' then
114+ pushToPropagate <| .eqTrue e u v k k'
115+ return true
116+ else
117+ return false
118+
119+ /--
120+ Given `e` represented by constraint `c` (from `v` to `u`).
121+ Checks whether `e = False` can be propagated using the path `u --(k)--> v`.
122+ If it can, adds a new entry to propagation list.
44123-/
124+ def checkEqFalse (u v : NodeId) (k : Weight) (c : Cnstr NodeId) (e : Expr) : OrderM Bool := do
125+ let some k' := c.getWeight? | return false
126+ if (k + k').isNeg then
127+ pushToPropagate <| .eqFalse e u v k k'
128+ return true
129+ return false
45130
46131/--
47132Auxiliary function for implementing theory propagation.
@@ -51,8 +136,7 @@ associated with `(u, v)` IF
51136- `e` is already assigned, or
52137- `f c e` returns true
53138 -/
54- @[inline]
55- private def updateCnstrsOf (u v : NodeId) (f : Cnstr NodeId → Expr → OrderM Bool) : OrderM Unit := do
139+ @[inline] def updateCnstrsOf (u v : NodeId) (f : Cnstr NodeId → Expr → OrderM Bool) : OrderM Unit := do
56140 if let some cs := (← getStruct).cnstrsOf.find? (u, v) then
57141 let cs' ← cs.filterM fun (c, e) => do
58142 if (← isEqTrue e <||> isEqFalse e) then
@@ -61,6 +145,62 @@ private def updateCnstrsOf (u v : NodeId) (f : Cnstr NodeId → Expr → OrderM
61145 return !(← f c e)
62146 modifyStruct fun s => { s with cnstrsOf := s.cnstrsOf.insert (u, v) cs' }
63147
148+ /-- Equality propagation. -/
149+ def checkEq (u v : NodeId) (k : Weight) : OrderM Unit := do
150+ if !k.isZero then return ()
151+ let some k' ← getDist? v u | return ()
152+ if !k'.isZero then return ()
153+ let ue ← getExpr u
154+ let ve ← getExpr v
155+ if (← alreadyInternalized ue <&&> alreadyInternalized ve) then
156+ if (← isEqv ue ve) then return ()
157+ pushToPropagate <| .eq u v
158+
159+ /-- Finds constrains and equalities to be propagated. -/
160+ def checkToPropagate (u v : NodeId) (k : Weight) : OrderM Unit := do
161+ updateCnstrsOf u v fun c e => return !(← checkEqTrue u v k c e)
162+ updateCnstrsOf v u fun c e => return !(← checkEqFalse u v k c e)
163+ checkEq u v k
164+
165+ /--
166+ If `isShorter u v k`, updates the shortest distance between `u` and `v`.
167+ `w` is a node in the path from `u` to `v` such that `(← getProof? w v)` is `some`
168+ -/
169+ def updateIfShorter (u v : NodeId) (k : Weight) (w : NodeId) : OrderM Unit := do
170+ if (← isShorter u v k) then
171+ setDist u v k
172+ setProof u v (← getProof w v)
173+ checkToPropagate u v k
174+
175+ /--
176+ Adds an edge `u --(k) --> v` justified by the proof term `p`, and then
177+ if no negative cycle was created, updates the shortest distance of affected
178+ node pairs.
179+ -/
180+ def addEdge (u : NodeId) (v : NodeId) (k : Weight) (h : Expr) : OrderM Unit := do
181+ if (← isInconsistent) then return ()
182+ if let some k' ← getDist? v u then
183+ if (k + k').isNeg then
184+ setUnsat u v k h k'
185+ return ()
186+ if (← isShorter u v k) then
187+ setDist u v k
188+ setProof u v { w := u, k, proof := h }
189+ checkToPropagate u v k
190+ update
191+ propagatePending
192+ where
193+ update : OrderM Unit := do
194+ forEachTargetOf v fun j k₂ => do
195+ /- Check whether new path: `u -(k)-> v -(k₂)-> j` is shorter -/
196+ updateIfShorter u j (k+k₂) v
197+ forEachSourceOf u fun i k₁ => do
198+ /- Check whether new path: `i -(k₁)-> u -(k)-> v` is shorter -/
199+ updateIfShorter i v (k₁+k) u
200+ forEachTargetOf v fun j k₂ => do
201+ /- Check whether new path: `i -(k₁)-> u -(k)-> v -(k₂) -> j` is shorter -/
202+ updateIfShorter i j (k₁+k+k₂) v
203+
64204def assertTrue (c : Cnstr NodeId) (p : Expr) : OrderM Unit := do
65205 trace[grind.order.assert] "{p} = True: {← c.pp}"
66206
0 commit comments