Skip to content

Commit 2295274

Browse files
committed
feat: multiple grind propagators per declaration
This PR allows users to declare additional `grind` constraint propagators for declarations that already include propagators in core.
1 parent 7ee3079 commit 2295274

File tree

2 files changed

+19
-17
lines changed

2 files changed

+19
-17
lines changed

src/Lean/Meta/Tactic/Grind/Main.lean

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -53,18 +53,18 @@ def mkMethods (fallback : Fallback) : CoreM Methods := do
5353
return {
5454
fallback
5555
propagateUp := fun e => do
56-
propagateForallPropUp e
57-
propagateReflCmp e
58-
let .const declName _ := e.getAppFn | return ()
59-
propagateProjEq e
60-
if let some prop := builtinPropagators.up[declName]? then
61-
prop e
56+
propagateForallPropUp e
57+
propagateReflCmp e
58+
let .const declName _ := e.getAppFn | return ()
59+
propagateProjEq e
60+
if let some props := builtinPropagators.up[declName]? then
61+
props.forM fun prop => prop e
6262
propagateDown := fun e => do
63-
propagateForallPropDown e
64-
propagateLawfulEqCmp e
65-
let .const declName _ := e.getAppFn | return ()
66-
if let some prop := builtinPropagators.down[declName]? then
67-
prop e
63+
propagateForallPropDown e
64+
propagateLawfulEqCmp e
65+
let .const declName _ := e.getAppFn | return ()
66+
if let some props := builtinPropagators.down[declName]? then
67+
props.forM fun prop => prop e
6868
}
6969

7070
-- A `simp` discharger that does not use assumptions.

src/Lean/Meta/Tactic/Grind/PropagatorAttr.lean

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,16 @@ import Init.Grind
1111
public section
1212
namespace Lean.Meta.Grind
1313

14+
abbrev PropagatorMap := Std.HashMap Name (List Propagator)
15+
16+
def PropagatorMap.insert (m : PropagatorMap) (declName : Name) (p : Propagator) : PropagatorMap :=
17+
let ps := m[declName]? |>.getD []
18+
Std.HashMap.insert m declName (p :: ps)
19+
1420
/-- Builtin propagators. -/
1521
structure BuiltinPropagators where
16-
up : Std.HashMap Name Propagator := {}
17-
down : Std.HashMap Name Propagator := {}
22+
up : PropagatorMap := {}
23+
down : PropagatorMap := {}
1824
deriving Inhabited
1925

2026
builtin_initialize builtinPropagatorsRef : IO.Ref BuiltinPropagators ← IO.mkRef {}
@@ -23,12 +29,8 @@ private def registerBuiltinPropagatorCore (declName : Name) (up : Bool) (proc :
2329
unless (← initializing) do
2430
throw (IO.userError s!"invalid builtin `grind` propagator declaration, it can only be registered during initialization")
2531
if up then
26-
if (← builtinPropagatorsRef.get).up.contains declName then
27-
throw (IO.userError s!"invalid builtin `grind` upward propagator `{declName}`, it has already been declared")
2832
builtinPropagatorsRef.modify fun { up, down } => { up := up.insert declName proc, down }
2933
else
30-
if (← builtinPropagatorsRef.get).down.contains declName then
31-
throw (IO.userError s!"invalid builtin `grind` downward propagator `{declName}`, it has already been declared")
3234
builtinPropagatorsRef.modify fun { up, down } => { up, down := down.insert declName proc }
3335

3436
def registerBuiltinUpwardPropagator (declName : Name) (proc : Propagator) : IO Unit :=

0 commit comments

Comments
 (0)