Skip to content

Commit 3b2705d

Browse files
kim-emhanwenzhu
andauthored
feat: helper functions for premise selection API (#10512)
This PR adds some helper functions for the premise selection API, to assist implementers. --------- Co-authored-by: Thomas Zhu <[email protected]>
1 parent 44a2b08 commit 3b2705d

File tree

2 files changed

+65
-3
lines changed

2 files changed

+65
-3
lines changed

src/Lean/PremiseSelection/Basic.lean

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,20 @@ structure Suggestion where
4343
The score of the suggestion, as a probability that this suggestion should be used.
4444
-/
4545
score : Float
46+
/--
47+
Optional flag associated with the suggestion, e.g. "←" or "=",
48+
if the premise selection algorithm is aware of the tactic consuming the results,
49+
and wants to suggest modifiers for this suggestion.
50+
E.g. this supports calling `simp` in the reverse direction,
51+
or telling `grind` or `aesop` to use the theorem in a particular way.
52+
-/
53+
flag : Option String := none
4654

4755
structure Config where
4856
/--
4957
The maximum number of suggestions to return.
5058
-/
51-
maxSuggestions : Option Nat := none
59+
maxSuggestions? : Option Nat := none
5260
/--
5361
The tactic that is calling the premise selection, e.g. `simp`, `grind`, or `aesop`.
5462
This may be used to adjust the score of the suggestions
@@ -71,8 +79,62 @@ structure Config where
7179
-/
7280
hint : Option String := none
7381

82+
def Config.maxSuggestions (c : Config) : Nat :=
83+
c.maxSuggestions?.getD 100
84+
7485
abbrev Selector : Type := MVarId → Config → MetaM (Array Suggestion)
7586

87+
/--
88+
Construct a `Selector` (which acts on an `MVarId`)
89+
from a function which takes the pretty printed goal.
90+
-/
91+
def ppSelector (selector : String → Config → MetaM (Array Suggestion)) (g : MVarId) (c : Config) :
92+
MetaM (Array Suggestion) := do
93+
selector (toString (← Meta.ppGoal g)) c
94+
95+
namespace Selector
96+
97+
/--
98+
Respect the `Config.filter` option by implementing it as a post-filter.
99+
If a premise selection implementation does not natively handle the filter,
100+
it should be wrapped with this function.
101+
-/
102+
def postFilter (selector : Selector) : Selector := fun g c => do
103+
let suggestions ← selector g { c with filter := fun _ => pure true }
104+
suggestions.filterM (fun s => c.filter s.name)
105+
106+
/--
107+
Wrapper for `Selector` that ensures
108+
the `maxSuggestions` field in `Config` is respected, post-hoc.
109+
-/
110+
def maxSuggestions (selector : Selector) : Selector := fun g c => do
111+
let suggestions ← selector g c
112+
match c.maxSuggestions? with
113+
| none => return suggestions
114+
| some max => return suggestions.take max
115+
116+
/-- Combine two premise selectors, returning the best suggestions. -/
117+
def combine (selector₁ : Selector) (selector₂ : Selector) : Selector := fun g c => do
118+
let suggestions₁ ← selector₁ g c
119+
let suggestions₂ ← selector₂ g c
120+
121+
let mut dedupMap : Std.HashMap (Name × Option String) Suggestion := {}
122+
123+
for s in suggestions₁ ++ suggestions₂ do
124+
let key := (s.name, s.flag)
125+
dedupMap := dedupMap.alter key fun
126+
| none => some s
127+
| some s' => if s.score > s'.score then some s else some s'
128+
129+
let deduped := dedupMap.valuesArray
130+
let sorted := deduped.qsort (fun s₁ s₂ => s₁.score > s₂.score)
131+
132+
match c.maxSuggestions? with
133+
| none => return sorted
134+
| some max => return sorted.take max
135+
136+
end Selector
137+
76138
section DenyList
77139

78140
/-- Premises from a module whose name has one of the following components are not retrieved. -/
@@ -123,7 +185,7 @@ def empty : Selector := fun _ _ => pure #[]
123185
def random (gen : StdGen := ⟨37, 59⟩) : Selector := fun _ cfg => do
124186
IO.stdGenRef.set gen
125187
let env ← getEnv
126-
let max := cfg.maxSuggestions.getD 10
188+
let max := cfg.maxSuggestions
127189
let consts := env.const2ModIdx.keysArray
128190
let mut suggestions := #[]
129191
while suggestions.size < max do

src/Lean/PremiseSelection/MePo.lean

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,6 @@ public def mepoSelector (useRarity : Bool) (p : Float := 0.6) (c : Float := 2.4)
8989
let suggestions := suggestions
9090
|>.map (fun (n, s) => { name := n, score := s })
9191
|>.reverse -- we favor constants that appear at the end of `env.constants`
92-
match config.maxSuggestions with
92+
match config.maxSuggestions? with
9393
| none => return suggestions
9494
| some k => return suggestions.take k

0 commit comments

Comments
 (0)