Skip to content

Commit fd50f05

Browse files
committed
fix: propagator for a^(n+m) in grind
This PR adds a propagator for `a^(n+m)` and removes its normalizer. This change was motivated by issue #10661 Closes #10661
1 parent 2f32110 commit fd50f05

File tree

5 files changed

+68
-12
lines changed

5 files changed

+68
-12
lines changed

src/Init/Grind/Ring/Basic.lean

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,9 @@ theorem pow_add (a : α) (k₁ k₂ : Nat) : a ^ (k₁ + k₂) = a^k₁ * a^k₂
201201
next => simp [pow_zero, mul_one]
202202
next k₂ ih => rw [Nat.add_succ, pow_succ, pow_succ, ih, mul_assoc]
203203

204+
theorem pow_add_congr (a r : α) (k k₁ k₂ : Nat) : k = k₁ + k₂ → a^k₁ * a^k₂ = r → a ^ k = r := by
205+
intros; subst k r; rw [pow_add]
206+
204207
theorem natCast_pow (x : Nat) (k : Nat) : ((x ^ k : Nat) : α) = (x : α) ^ k := by
205208
induction k
206209
next => simp [pow_zero, Nat.pow_zero, natCast_one]

src/Lean/Meta/Tactic/Grind/Arith/CommRing.lean

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ public import Lean.Meta.Tactic.Grind.Arith.CommRing.MonadCanon
2626
public import Lean.Meta.Tactic.Grind.Arith.CommRing.MonadRing
2727
public import Lean.Meta.Tactic.Grind.Arith.CommRing.MonadSemiring
2828
public import Lean.Meta.Tactic.Grind.Arith.CommRing.Action
29+
public import Lean.Meta.Tactic.Grind.Arith.CommRing.Power
2930
public section
3031
namespace Lean.Meta.Grind.Arith.CommRing
3132
builtin_initialize registerTraceClass `grind.ring
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
/-
2+
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
Released under Apache 2.0 license as described in the file LICENSE.
4+
Authors: Leonardo de Moura
5+
-/
6+
module
7+
prelude
8+
public import Lean.Meta.Tactic.Grind.Types
9+
import Init.Grind
10+
import Lean.Meta.Tactic.Grind.PropagatorAttr
11+
import Lean.Meta.Tactic.Grind.Simp
12+
import Lean.Meta.Tactic.Grind.Arith.Simproc
13+
import Lean.Meta.NatInstTesters
14+
public section
15+
namespace Lean.Meta.Grind.Arith.CommRing
16+
17+
builtin_grind_propagator propagatePower ↑HPow.hPow := fun e => do
18+
-- **Note**: We are not checking whether the `^` instance is the expected ones.
19+
let_expr HPow.hPow α n α' _ a b := e | return ()
20+
let_expr Nat := n | return ()
21+
unless isSameExpr α α' do return ()
22+
traverseEqc b fun bENode => do
23+
let b' := bENode.self
24+
match_expr b' with
25+
| HAdd.hAdd n₁ n₂ n₃ inst b₁ b₂ =>
26+
unless isSameExpr n n₁ && isSameExpr n n₂ && isSameExpr n n₃ do return ()
27+
unless (← isInstHAddNat inst) do return ()
28+
let pwFn := e.appFn!.appFn!
29+
let r ← mkMul (mkApp2 pwFn a b₁) (mkApp2 pwFn a b₂)
30+
let r ← preprocess r
31+
internalize r.expr (← getGeneration e)
32+
let some h ← mkSemiringThm ``Grind.Semiring.pow_add_congr α | return ()
33+
let h := mkApp7 h a r.expr b b₁ b₂ (← mkEqProof b b') (← r.getProof)
34+
pushEq e r.expr h
35+
| _ => return ()
36+
37+
end Lean.Meta.Grind.Arith.CommRing

src/Lean/Meta/Tactic/Grind/Arith/Simproc.lean

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,30 @@ import Init.Grind.Ring.Field
1313
public section
1414
namespace Lean.Meta.Grind.Arith
1515

16-
private def mkSemiringThm (declName : Name) (α : Expr) : MetaM (Option Expr) := do
16+
def mkSemiringThm (declName : Name) (α : Expr) : MetaM (Option Expr) := do
1717
let some u ← getDecLevel? α | return none
1818
let semiring := mkApp (mkConst ``Grind.Semiring [u]) α
1919
let some semiringInst ← synthInstanceMeta? semiring | return none
2020
return mkApp2 (mkConst declName [u]) α semiringInst
2121

2222
/--
23-
Applies `a^(m+n) = a^m * a^n`, `a^0 = 1`, `a^1 = a`.
23+
Applies `a^0 = 1`, `a^1 = a`.
2424
2525
We do normalize `a^0` and `a^1` when converting expressions into polynomials,
2626
but we need to normalize them here when for other preprocessing steps such as
2727
`a / b = a*b⁻¹`. If `b` is of the form `c^1`, it will be treated as an
28-
atom in the comm ring module.
28+
atom in the ring module.
29+
30+
**Note**: We used to expand `a^(n+m)` here, but it prevented `grind` from solving
31+
simple problems such as
32+
```
33+
example {k : Nat} (h : k - 1 + 1 = k) :
34+
2 ^ (k - 1 + 1) = 2 ^ k := by
35+
grind
36+
```
37+
We now use a propagator for `a^(n+m)` which adds the `a^n*a^m` to the equivalence class.
2938
-/
30-
builtin_simproc_decl expandPowAdd (_ ^ _) := fun e => do
39+
builtin_simproc_decl expandPow01 (_ ^ _) := fun e => do
3140
let_expr HPow.hPow α nat α' _ a k := e | return .continue
3241
let_expr Nat ← nat | return .continue
3342
if let some k ← getNatValue? k then
@@ -42,13 +51,7 @@ builtin_simproc_decl expandPowAdd (_ ^ _) := fun e => do
4251
return .done { expr := a, proof? := some (mkApp h a) }
4352
else
4453
return .continue
45-
else
46-
let_expr HAdd.hAdd _ _ _ _ m n := k | return .continue
47-
unless (← isDefEq α α') do return .continue
48-
let some h ← mkSemiringThm ``Grind.Semiring.pow_add α | return .continue
49-
let pwFn := e.appFn!.appFn!
50-
let r ← mkMul (mkApp2 pwFn a m) (mkApp2 pwFn a n)
51-
return .visit { expr := r, proof? := some (mkApp3 h a m n) }
54+
return .continue
5255

5356
private def notField : Std.HashSet Name :=
5457
[``Nat, ``Int, ``BitVec, ``UInt8, ``UInt16, ``UInt32, ``Int64, ``Int8, ``Int16, ``Int32, ``Int64].foldl (init := {}) (·.insert ·)
@@ -185,7 +188,7 @@ Add additional arithmetic simprocs
185188
-/
186189

187190
def addSimproc (s : Simprocs) : CoreM Simprocs := do
188-
let s ← s.add ``expandPowAdd (post := true)
191+
let s ← s.add ``expandPow01 (post := true)
189192
let s ← s.add ``expandDiv (post := true)
190193
let s ← s.add ``normNatAddInst (post := false)
191194
let s ← s.add ``normNatMulInst (post := false)

tests/lean/run/grind_10661.lean

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
example {k : Nat} (h : k - 1 + 1 = k) :
2+
2 ^ (k - 1 + 1) = 2 ^ k := by
3+
grind
4+
5+
example (h : a = b + c) : 2 ^ a = 2^b * 2^c := by
6+
grind
7+
8+
example (h : a = c + b) : 2 ^ a = 2^b * 2^c := by
9+
grind
10+
11+
example (h : a = 1 + b) : 2 ^ a = 2^b * 2 := by
12+
grind

0 commit comments

Comments
 (0)