Skip to content

Commit 03e905d

Browse files
authored
feat: hash consing with alpha equivalence in grind (#8479)
This PR implements hash-consing for `grind` that takes alpha equivalence into account.
1 parent 383f68f commit 03e905d

File tree

5 files changed

+132
-11
lines changed

5 files changed

+132
-11
lines changed
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
/-
2+
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
Released under Apache 2.0 license as described in the file LICENSE.
4+
Authors: Leonardo de Moura
5+
-/
6+
prelude
7+
import Lean.Meta.Tactic.Grind.ENodeKey
8+
9+
namespace Lean.Meta.Grind
10+
11+
private def hashChild (e : Expr) : UInt64 :=
12+
match e with
13+
| .bvar .. | .mvar .. | .const .. | .fvar .. | .sort .. | .lit .. =>
14+
hash e
15+
| .app .. | .letE .. | .forallE .. | .lam .. | .mdata .. | .proj .. =>
16+
(unsafe ptrAddrUnsafe e).toUInt64
17+
18+
private def alphaHash (e : Expr) : UInt64 :=
19+
match e with
20+
| .bvar .. | .mvar .. | .const .. | .fvar .. | .sort .. | .lit .. =>
21+
hash e
22+
| .app f a => mixHash (hashChild f) (hashChild a)
23+
| .letE _ _ v b _ => mixHash (hashChild v) (hashChild b)
24+
| .forallE _ d b _ | .lam _ d b _ => mixHash (hashChild d) (hashChild b)
25+
| .mdata _ b => mixHash 13 (hashChild b)
26+
| .proj n i b => mixHash (mixHash (hash n) (hash i)) (hashChild b)
27+
28+
private def alphaEq (e₁ e₂ : Expr) : Bool := Id.run do
29+
match e₁ with
30+
| .bvar .. | .mvar .. | .const .. | .fvar .. | .sort .. | .lit .. =>
31+
e₁ == e₂
32+
| .app f₁ a₁ =>
33+
let .app f₂ a₂ := e₂ | false
34+
isSameExpr f₁ f₂ && isSameExpr a₁ a₂
35+
| .letE _ _ v₁ b₁ _ =>
36+
let .letE _ _ v₂ b₂ _ := e₂ | false
37+
isSameExpr v₁ v₂ && isSameExpr b₁ b₂
38+
| .forallE _ d₁ b₁ _ =>
39+
let .forallE _ d₂ b₂ _ := e₂ | false
40+
isSameExpr d₁ d₂ && isSameExpr b₁ b₂
41+
| .lam _ d₁ b₁ _ =>
42+
let .lam _ d₂ b₂ _ := e₂ | false
43+
isSameExpr d₁ d₂ && isSameExpr b₁ b₂
44+
| .mdata d₁ b₁ =>
45+
let .mdata d₂ b₂ := e₂ | false
46+
return isSameExpr b₁ b₂ && d₁ == d₂
47+
| .proj n₁ i₁ b₁ =>
48+
let .proj n₂ i₂ b₂ := e₂ | false
49+
n₁ == n₂ && i₁ == i₂ && isSameExpr b₁ b₂
50+
51+
structure AlphaKey where
52+
expr : Expr
53+
54+
instance : Hashable AlphaKey where
55+
hash k := alphaHash k.expr
56+
57+
instance : BEq AlphaKey where
58+
beq k₁ k₂ := alphaEq k₁.expr k₂.expr
59+
60+
structure AlphaShareCommon.State where
61+
map : PHashMap ENodeKey Expr := {}
62+
set : PHashSet AlphaKey := {}
63+
64+
abbrev AlphaShareCommonM := StateM AlphaShareCommon.State
65+
66+
private def save (e : Expr) (r : Expr) : AlphaShareCommonM Expr := do
67+
if let some r := (← get).set.find? { expr := r } then
68+
let r := r.expr
69+
modify fun { set, map } => {
70+
set
71+
map := map.insert { expr := e } r
72+
}
73+
return r
74+
else
75+
modify fun { set, map } => {
76+
set := set.insert { expr := r }
77+
map := map.insert { expr := e } r |>.insert { expr := r } r
78+
}
79+
return r
80+
81+
private abbrev visit (e : Expr) (k : AlphaShareCommonM Expr) : AlphaShareCommonM Expr := do
82+
if let some r := (← get).map.find? { expr := e } then
83+
return r
84+
else
85+
save e (← k)
86+
87+
/-- Similar to `shareCommon`, but handles alpha-equivalence. -/
88+
def shareCommonAlpha (e : Expr) : AlphaShareCommonM Expr :=
89+
go e
90+
where
91+
go (e : Expr) : AlphaShareCommonM Expr := do
92+
match e with
93+
| .bvar .. | .mvar .. | .const .. | .fvar .. | .sort .. | .lit .. =>
94+
if let some r := (← get).set.find? { expr := e } then
95+
return r.expr
96+
else
97+
modify fun { set, map } => { set := set.insert { expr := e }, map }
98+
return e
99+
| .app f a =>
100+
visit e (return mkApp (← go f) (← go a))
101+
| .letE n t v b nd =>
102+
visit e (return mkLet n t (← go v) (← go b) nd)
103+
| .forallE n d b bi =>
104+
visit e (return mkForall n bi (← go d) (← go b))
105+
| .lam n d b bi =>
106+
visit e (return mkLambda n bi (← go d) (← go b))
107+
| .mdata d b =>
108+
visit e (return mkMData d (← go b))
109+
| .proj n i b =>
110+
visit e (return mkProj n i (← go b))
111+
112+
end Lean.Meta.Grind

src/Lean/Meta/Tactic/Grind/Main.lean

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,11 @@ def mkMethods (fallback : Fallback) : CoreM Methods := do
5555
}
5656

5757
def GrindM.run (x : GrindM α) (params : Params) (fallback : Fallback) : MetaM α := do
58-
let scState := ShareCommon.State.mk _
59-
let (falseExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``False)
60-
let (trueExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``True)
61-
let (bfalseExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``Bool.false)
62-
let (btrueExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``Bool.true)
63-
let (natZExpr, scState) := ShareCommon.State.shareCommon scState (mkNatLit 0)
58+
let (falseExpr, scState) := shareCommonAlpha (mkConst ``False) {}
59+
let (trueExpr, scState) := shareCommonAlpha (mkConst ``True) scState
60+
let (bfalseExpr, scState) := shareCommonAlpha (mkConst ``Bool.false) scState
61+
let (btrueExpr, scState) := shareCommonAlpha (mkConst ``Bool.true) scState
62+
let (natZExpr, scState) := shareCommonAlpha (mkNatLit 0) scState
6463
let simprocs := params.normProcs
6564
let simp := params.norm
6665
let config := params.config

src/Lean/Meta/Tactic/Grind/Types.lean

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ prelude
77
import Init.Grind.Tactics
88
import Init.Data.Queue
99
import Std.Data.TreeSet
10-
import Lean.Util.ShareCommon
1110
import Lean.HeadIndex
1211
import Lean.Meta.Basic
1312
import Lean.Meta.CongrTheorems
@@ -16,6 +15,7 @@ import Lean.Meta.Tactic.Simp.Types
1615
import Lean.Meta.Tactic.Util
1716
import Lean.Meta.Tactic.Ext
1817
import Lean.Meta.Tactic.Grind.ENodeKey
18+
import Lean.Meta.Tactic.Grind.AlphaShareCommon
1919
import Lean.Meta.Tactic.Grind.Attr
2020
import Lean.Meta.Tactic.Grind.ExtAttr
2121
import Lean.Meta.Tactic.Grind.Cases
@@ -117,7 +117,7 @@ private def emptySC : ShareCommon.State.{0} ShareCommon.objectFactory := ShareCo
117117
/-- State for the `GrindM` monad. -/
118118
structure State where
119119
/-- `ShareCommon` (aka `Hashconsing`) state. -/
120-
scState : ShareCommon.State.{0} ShareCommon.objectFactory := emptySC
120+
scState : AlphaShareCommon.State := {}
121121
/--
122122
Congruence theorems generated so far. Recall that for constant symbols
123123
we rely on the reserved name feature (i.e., `mkHCongrWithArityForConst?`).
@@ -232,8 +232,8 @@ Applies hash-consing to `e`. Recall that all expressions in a `grind` goal have
232232
been hash-consed. We perform this step before we internalize expressions.
233233
-/
234234
def shareCommon (e : Expr) : GrindM Expr := do
235-
let scState ← modifyGet fun s => (s.scState, { s with scState := emptySC })
236-
let (e, scState) := ShareCommon.State.shareCommon scState e
235+
let scState ← modifyGet fun s => (s.scState, { s with scState := {} })
236+
let (e, scState) := shareCommonAlpha e scState
237237
modify fun s => { s with scState }
238238
return e
239239

tests/lean/run/grind_heartbeats.lean

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ macro_rules
1212
| `(gen! $n:num) => `(op (f $n) (gen! $(Lean.quote (n.getNat - 1))))
1313

1414
/--
15-
trace: [grind.issues] (deterministic) timeout at `simp`, maximum number of heartbeats (5000) has been reached
15+
trace: [grind.issues] (deterministic) timeout at `isDefEq`, maximum number of heartbeats (5000) has been reached
1616
Use `set_option maxHeartbeats <num>` to set the limit.
1717
1818
Additional diagnostic information may be available using the `set_option diagnostics true` command.

tests/lean/run/grind_t1.lean

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,3 +457,13 @@ example (h : ∀ i, (¬i > 0) ∨ ∀ h : i ≠ 10, p i h) : p 5 (by decide) :=
457457
-- Similar to previous test.
458458
example (h : ∀ i, (∀ h : i ≠ 10, p i h) ∨ (¬i > 0)) : p 5 (by decide) := by
459459
grind
460+
461+
-- `grind` performs hash-consing modulo alpha-equivalence
462+
/--
463+
trace: [grind.assert] (f fun x => x) = a
464+
[grind.assert] ¬a = f fun x => x
465+
-/
466+
#guard_msgs (trace) in
467+
example (f : (Nat → Nat) → Nat) : f (fun x => x) = a → a = f (fun y => y) := by
468+
set_option trace.grind.assert true in
469+
grind

0 commit comments

Comments
 (0)