Skip to content

Commit 4cbd1a4

Browse files
authored
feat: non-commutative semiring normalizer in grind (#10421)
This PR adds a normalizer for non-commutative semirings to `grind`. Examples: ```lean open Lean.Grind variable (R : Type u) [Semiring R] example (a b c : R) : a * (b + c) = a * c + a * b := by grind example (a b : R) : (a + 2 * b)^2 = a^2 + 2 * a * b + 2 * b * a + 4 * b^2 := by grind example (a b : R) : b^2 + (a + 2 * b)^2 = a^2 + 2 * a * b + b * (1+1) * a * 1 + 5 * b^2 := by grind example (a b : R) : a^3 + a^2*b + a*b*a + b*a^2 + a*b^2 + b*a*b + b^2*a + b^3 = (a+b)^3 := by grind ```
1 parent 20873d5 commit 4cbd1a4

File tree

11 files changed

+220
-78
lines changed

11 files changed

+220
-78
lines changed

src/Lean/Meta/Tactic/Grind/Arith/CommRing.lean

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ public import Lean.Meta.Tactic.Grind.Arith.CommRing.Internalize
1313
public import Lean.Meta.Tactic.Grind.Arith.CommRing.ToExpr
1414
public import Lean.Meta.Tactic.Grind.Arith.CommRing.RingM
1515
public import Lean.Meta.Tactic.Grind.Arith.CommRing.SemiringM
16+
public import Lean.Meta.Tactic.Grind.Arith.CommRing.NonCommRingM
17+
public import Lean.Meta.Tactic.Grind.Arith.CommRing.NonCommSemiringM
1618
public import Lean.Meta.Tactic.Grind.Arith.CommRing.Functions
1719
public import Lean.Meta.Tactic.Grind.Arith.CommRing.Reify
1820
public import Lean.Meta.Tactic.Grind.Arith.CommRing.EqCnstr

src/Lean/Meta/Tactic/Grind/Arith/CommRing/EqCnstr.lean

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,13 @@ private def inSameNonCommRing? (a b : Expr) : GoalM (Option Nat) := do
3838
unless ringId == ringId' do return none -- This can happen when we have heterogeneous equalities
3939
return ringId
4040

41+
/-- Returns `some semiringId` if `a` and `b` are elements of the same (non-commutative) semiring. -/
42+
private def inSameNonCommSemiring? (a b : Expr) : GoalM (Option Nat) := do
43+
let some semiringId ← getTermNonCommSemiringId? a | return none
44+
let some semiringId' ← getTermNonCommSemiringId? b | return none
45+
unless semiringId == semiringId' do return none -- This can happen when we have heterogeneous equalities
46+
return semiringId
47+
4148
def mkEqCnstr (p : Poly) (h : EqCnstrProof) : RingM EqCnstr := do
4249
let id := (← getCommRing).nextId
4350
let sugar := p.degree
@@ -62,7 +69,7 @@ private def toRingExpr? [Monad m] [MonadLiftT GrindM m] [MonadRing m] (e : Expr)
6269
Returns the semiring expression denoting the given Lean expression.
6370
Recall that we compute the semiring expressions during internalization.
6471
-/
65-
private def toSemiringExpr? (e : Expr) : SemiringM (Option SemiringExpr) := do
72+
private def toSemiringExpr? [Monad m] [MonadLiftT GrindM m] [MonadSemiring m] (e : Expr) : m (Option SemiringExpr) := do
6673
let semiring ← getSemiring
6774
if let some re := semiring.denote.find? { expr := e } then
6875
return some re
@@ -353,7 +360,7 @@ def processNewEq (a b : Expr) : GoalM Unit := do
353360
let some sb ← toSemiringExpr? b | return ()
354361
let lhs ← sa.denoteAsRingExpr
355362
let rhs ← sb.denoteAsRingExpr
356-
RingM.run (← getSemiring).ringId do
363+
RingM.run (← getCommSemiring).ringId do
357364
let some ra ← reify? lhs (skipVar := false) (gen := (← getGeneration a)) | return ()
358365
let some rb ← reify? rhs (skipVar := false) (gen := (← getGeneration b)) | return ()
359366
let p ← (ra.sub rb).toPolyM
@@ -415,7 +422,7 @@ private def processNewDiseqCommSemiring (a b : Expr) : SemiringM Unit := do
415422
let some sb ← toSemiringExpr? b | return ()
416423
let lhs ← sa.denoteAsRingExpr
417424
let rhs ← sb.denoteAsRingExpr
418-
RingM.run (← getSemiring).ringId do
425+
RingM.run (← getCommSemiring).ringId do
419426
let some ra ← reify? lhs (skipVar := false) (gen := (← getGeneration a)) | return ()
420427
let some rb ← reify? rhs (skipVar := false) (gen := (← getGeneration b)) | return ()
421428
let p ← (ra.sub rb).toPolyM
@@ -443,13 +450,21 @@ private def processNewDiseqNonCommRing (a b : Expr) : NonCommRingM Unit := do
443450
if ra.toPoly_nc == rb.toPoly_nc then
444451
setNonCommRingDiseqUnsat a b ra rb
445452

453+
private def processNewDiseqNonCommSemiring (a b : Expr) : NonCommSemiringM Unit := do
454+
let some sa ← toSemiringExpr? a | return ()
455+
let some sb ← toSemiringExpr? b | return ()
456+
if sa.toPolyS_nc == sb.toPolyS_nc then
457+
setNonCommSemiringDiseqUnsat a b sa sb
458+
446459
def processNewDiseq (a b : Expr) : GoalM Unit := do
447460
if let some ringId ← inSameRing? a b then RingM.run ringId do
448461
processNewDiseqCommRing a b
449462
else if let some semiringId ← inSameSemiring? a b then SemiringM.run semiringId do
450463
processNewDiseqCommSemiring a b
451464
else if let some ncRingId ← inSameNonCommRing? a b then NonCommRingM.run ncRingId do
452465
processNewDiseqNonCommRing a b
466+
else if let some ncSemiringId ← inSameNonCommSemiring? a b then NonCommSemiringM.run ncSemiringId do
467+
processNewDiseqNonCommSemiring a b
453468

454469
/--
455470
Returns `true` if the todo queue is not empty or the `recheck` flag is set to `true`

src/Lean/Meta/Tactic/Grind/Arith/CommRing/Internalize.lean

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,9 +147,15 @@ def internalize (e : Expr) (parent? : Option Expr) : GoalM Unit := do
147147
modifySemiring fun s => { s with denote := s.denote.insert { expr := e } re }
148148
else if let some ncRingId ← getNonCommRingId? type then NonCommRingM.run ncRingId do
149149
let some re ← ncreify? e | return ()
150-
trace_goal[grind.ring.internalize] "(non-comm) ring [ncRingId}]: {e}"
150+
trace_goal[grind.ring.internalize] "(non-comm) ring [{ncRingId}]: {e}"
151151
setTermNonCommRingId e
152152
ringExt.markTerm e
153153
modifyRing fun s => { s with denote := s.denote.insert { expr := e } re }
154+
else if let some ncSemiringId ← getNonCommSemiringId? type then NonCommSemiringM.run ncSemiringId do
155+
let some re ← ncsreify? e | return ()
156+
trace_goal[grind.ring.internalize] "(non-comm) semiring [{ncSemiringId}]: {e}"
157+
setTermNonCommSemiringId e
158+
ringExt.markTerm e
159+
modifySemiring fun s => { s with denote := s.denote.insert { expr := e } re }
154160

155161
end Lean.Meta.Grind.Arith.CommRing
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
/-
2+
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
Released under Apache 2.0 license as described in the file LICENSE.
4+
Authors: Leonardo de Moura
5+
-/
6+
module
7+
prelude
8+
public import Lean.Meta.Tactic.Grind.Arith.CommRing.SemiringM
9+
public section
10+
namespace Lean.Meta.Grind.Arith.CommRing
11+
12+
structure NonCommSemiringM.Context where
13+
semiringId : Nat
14+
15+
abbrev NonCommSemiringM := ReaderT NonCommSemiringM.Context GoalM
16+
17+
abbrev NonCommSemiringM.run (semiringId : Nat) (x : NonCommSemiringM α) : GoalM α :=
18+
x { semiringId }
19+
20+
instance : MonadCanon NonCommSemiringM where
21+
canonExpr e := do shareCommon (← canon e)
22+
synthInstance? e := Grind.synthInstance? e
23+
24+
protected def NonCommSemiringM.getSemiring : NonCommSemiringM Semiring := do
25+
let s ← get'
26+
let semiringId := (← read).semiringId
27+
if h : semiringId < s.ncSemirings.size then
28+
return s.ncSemirings[semiringId]
29+
else
30+
throwError "`grind` internal error, invalid semiringId"
31+
32+
protected def NonCommSemiringM.modifySemiring (f : Semiring → Semiring) : NonCommSemiringM Unit := do
33+
let semiringId := (← read).semiringId
34+
modify' fun s => { s with ncSemirings := s.ncSemirings.modify semiringId f }
35+
36+
instance : MonadSemiring NonCommSemiringM where
37+
getSemiring := NonCommSemiringM.getSemiring
38+
modifySemiring := NonCommSemiringM.modifySemiring
39+
40+
def getTermNonCommSemiringId? (e : Expr) : GoalM (Option Nat) := do
41+
return (← get').exprToNCSemiringId.find? { expr := e }
42+
43+
def setTermNonCommSemiringId (e : Expr) : NonCommSemiringM Unit := do
44+
let semiringId := (← read).semiringId
45+
if let some semiringId' ← getTermNonCommSemiringId? e then
46+
unless semiringId' == semiringId do
47+
reportIssue! "expression in two different semirings{indentExpr e}"
48+
return ()
49+
modify' fun s => { s with exprToNCSemiringId := s.exprToNCSemiringId.insert { expr := e } semiringId }
50+
51+
instance : MonadSetTermId NonCommSemiringM where
52+
setTermId e := setTermNonCommSemiringId e
53+
54+
end Lean.Meta.Grind.Arith.CommRing

src/Lean/Meta/Tactic/Grind/Arith/CommRing/Proof.lean

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ prelude
88
public import Lean.Meta.Tactic.Grind.Arith.CommRing.RingId
99
public import Lean.Meta.Tactic.Grind.Arith.CommRing.SemiringM
1010
public import Lean.Meta.Tactic.Grind.Arith.CommRing.NonCommRingM
11+
public import Lean.Meta.Tactic.Grind.Arith.CommRing.NonCommSemiringM
1112
import Init.Grind.Ring.CommSemiringAdapter
1213
import Lean.Data.RArray
1314
import Lean.Meta.Tactic.Grind.Diseq
@@ -31,7 +32,7 @@ private def toContextExpr [Monad m] [MonadLiftT MetaM m] [MonadCanon m] [MonadRi
3132
else
3233
RArray.toExpr ring.type id (RArray.leaf (mkApp (← getNatCastFn) (toExpr 0)))
3334

34-
private def toSContextExpr' (vars : Array Expr) : SemiringM Expr := do
35+
private def toSContextExpr' [Monad m] [MonadLiftT MetaM m] [MonadCanon m] [MonadSemiring m] (vars : Array Expr) : m Expr := do
3536
let semiring ← getSemiring
3637
if h : 0 < vars.size then
3738
RArray.toExpr semiring.type id (RArray.ofFn (vars[·]) h)
@@ -133,7 +134,7 @@ private def getSemiringIdOf : RingM Nat := do
133134
return semiringId
134135

135136
private def getSemiringOf : RingM CommSemiring := do
136-
SemiringM.run (← getSemiringIdOf) do getSemiring
137+
SemiringM.run (← getSemiringIdOf) do getCommSemiring
137138

138139
private def mkSemiringPrefix (declName : Name) : ProofM Expr := do
139140
let sctx ← getSContext
@@ -322,7 +323,7 @@ Given `a` and `b`, such that `a ≠ b` in the core and `sa` and `sb` their reifi
322323
terms s.t. `sa.toPoly == sb.toPoly`, close the goal.
323324
-/
324325
def setSemiringDiseqUnsat (a b : Expr) (sa sb : SemiringExpr) : SemiringM Unit := do
325-
let semiring ← getSemiring
326+
let semiring ← getCommSemiring
326327
let hne ← mkDiseqProof a b
327328
let usedVars := sa.collectVars >> sb.collectVars <| {}
328329
let vars' := usedVars.toArray
@@ -358,4 +359,23 @@ def setNonCommRingDiseqUnsat (a b : Expr) (ra rb : RingExpr) : NonCommRingM Unit
358359
let h := mkApp3 h (toExpr ra) (toExpr rb) eagerReflBoolTrue
359360
closeGoal (mkApp hne h)
360361

362+
/--
363+
Given `a` and `b`, such that `a ≠ b` in the core and `sa` and `sb` their reified semiring
364+
terms s.t. `sa.toPolyS_nc == sb.toPolyS_nc`, close the goal.
365+
-/
366+
def setNonCommSemiringDiseqUnsat (a b : Expr) (sa sb : SemiringExpr) : NonCommSemiringM Unit := do
367+
let semiring ← getSemiring
368+
let hne ← mkDiseqProof a b
369+
let usedVars := sa.collectVars >> sb.collectVars <| {}
370+
let vars' := usedVars.toArray
371+
let varRename := mkVarRename vars'
372+
let vars := semiring.vars
373+
let vars := vars'.map fun x => vars[x]!
374+
let sa := sa.renameVars varRename
375+
let sb := sb.renameVars varRename
376+
let ctx ← toSContextExpr' vars
377+
let h := mkApp3 (mkConst ``Grind.CommRing.eq_normS_nc [semiring.u]) semiring.type semiring.semiringInst ctx
378+
let h := mkApp3 h (toExpr sa) (toExpr sb) eagerReflBoolTrue
379+
closeGoal (mkApp hne h)
380+
361381
end Lean.Meta.Grind.Arith.CommRing

src/Lean/Meta/Tactic/Grind/Arith/CommRing/Reify.lean

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ prelude
88
public import Lean.Meta.Tactic.Grind.Arith.CommRing.RingM
99
public import Lean.Meta.Tactic.Grind.Arith.CommRing.SemiringM
1010
public import Lean.Meta.Tactic.Grind.Arith.CommRing.NonCommRingM
11+
public import Lean.Meta.Tactic.Grind.Arith.CommRing.NonCommSemiringM
1112
import Lean.Meta.Tactic.Grind.Simp
1213
import Lean.Meta.Tactic.Grind.Arith.CommRing.Functions
1314
public section
@@ -31,7 +32,7 @@ def isNatCastInst (inst : Expr) : m Bool :=
3132
return isSameExpr (← getNatCastFn).appArg! inst
3233

3334
private def reportAppIssue (e : Expr) : GoalM Unit := do
34-
reportIssue! "comm ring term with unexpected instance{indentExpr e}"
35+
reportIssue! "ring term with unexpected instance{indentExpr e}"
3536

3637
variable [MonadLiftT GoalM m] [MonadSetTermId m]
3738

@@ -119,25 +120,30 @@ partial def reifyCore? (e : Expr) (skipVar : Bool) (gen : Nat) : m (Option RingE
119120
return some (.num k)
120121
| _ => toTopVar e
121122

122-
partial def reify? (e : Expr) (skipVar := true) (gen : Nat := 0) : RingM (Option RingExpr) := do
123+
/-- Reify ring expression. -/
124+
def reify? (e : Expr) (skipVar := true) (gen : Nat := 0) : RingM (Option RingExpr) := do
123125
reifyCore? e skipVar gen
124126

125-
partial def ncreify? (e : Expr) (skipVar := true) (gen : Nat := 0) : NonCommRingM (Option RingExpr) := do
127+
/-- Reify non-commutative ring expression. -/
128+
def ncreify? (e : Expr) (skipVar := true) (gen : Nat := 0) : NonCommRingM (Option RingExpr) := do
126129
reifyCore? e skipVar gen
127130

128131
private def reportSAppIssue (e : Expr) : GoalM Unit := do
129-
reportIssue! "comm semiring term with unexpected instance{indentExpr e}"
132+
reportIssue! "semiring term with unexpected instance{indentExpr e}"
133+
134+
section
135+
variable [MonadLiftT GoalM m] [MonadError m] [Monad m] [MonadCanon m] [MonadSemiring m] [MonadSetTermId m]
130136

131137
/--
132138
Similar to `reify?` but for `CommSemiring`
133139
-/
134-
partial def sreify? (e : Expr) : SemiringM (Option SemiringExpr) := do
135-
let toVar (e : Expr) : SemiringM SemiringExpr := do
136-
return .var (← mkSVar e)
137-
let asVar (e : Expr) : SemiringM SemiringExpr := do
140+
partial def sreifyCore? (e : Expr) : m (Option SemiringExpr) := do
141+
let toVar (e : Expr) : m SemiringExpr := do
142+
return .var (← mkSVarCore e)
143+
let asVar (e : Expr) : m SemiringExpr := do
138144
reportSAppIssue e
139-
return .var (← mkSVar e)
140-
let rec go (e : Expr) : SemiringM SemiringExpr := do
145+
return .var (← mkSVarCore e)
146+
let rec go (e : Expr) : m SemiringExpr := do
141147
match_expr e with
142148
| HAdd.hAdd _ _ _ i a b =>
143149
if isSameExpr (← getAddFn').appArg! i then return .add (← go a) (← go b) else asVar e
@@ -156,9 +162,9 @@ partial def sreify? (e : Expr) : SemiringM (Option SemiringExpr) := do
156162
let some k ← getNatValue? n | toVar e
157163
return .num k
158164
| _ => toVar e
159-
let toTopVar (e : Expr) : SemiringM (Option SemiringExpr) := do
165+
let toTopVar (e : Expr) : m (Option SemiringExpr) := do
160166
return some (← toVar e)
161-
let asTopVar (e : Expr) : SemiringM (Option SemiringExpr) := do
167+
let asTopVar (e : Expr) : m (Option SemiringExpr) := do
162168
reportSAppIssue e
163169
toTopVar e
164170
match_expr e with
@@ -180,4 +186,14 @@ partial def sreify? (e : Expr) : SemiringM (Option SemiringExpr) := do
180186
return some (.num k)
181187
| _ => toTopVar e
182188

189+
end
190+
191+
/-- Reify semiring expression. -/
192+
def sreify? (e : Expr) : SemiringM (Option SemiringExpr) := do
193+
sreifyCore? e
194+
195+
/-- Reify non-commutative semiring expression. -/
196+
def ncsreify? (e : Expr) : NonCommSemiringM (Option SemiringExpr) := do
197+
sreifyCore? e
198+
183199
end Lean.Meta.Grind.Arith.CommRing

src/Lean/Meta/Tactic/Grind/Arith/CommRing/RingId.lean

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,4 +104,21 @@ where
104104
setCommSemiringId ringId id
105105
return some id
106106

107+
def getNonCommSemiringId? (type : Expr) : GoalM (Option Nat) := do
108+
if let some id? := (← get').ncstypeIdOf.find? { expr := type } then
109+
return id?
110+
else
111+
let id? ← go?
112+
modify' fun s => { s with ncstypeIdOf := s.ncstypeIdOf.insert { expr := type } id? }
113+
return id?
114+
where
115+
go? : GoalM (Option Nat) := do
116+
let u ← getDecLevel type
117+
let semiring := mkApp (mkConst ``Grind.Semiring [u]) type
118+
let some semiringInst ← synthInstance? semiring | return none
119+
let id := (← get').ncSemirings.size
120+
let semiring : Semiring := { id, type, u, semiringInst }
121+
modify' fun s => { s with ncSemirings := s.ncSemirings.push semiring }
122+
return some id
123+
107124
end Lean.Meta.Grind.Arith.CommRing

0 commit comments

Comments
 (0)