@@ -43,12 +43,20 @@ structure Suggestion where
4343 The score of the suggestion, as a probability that this suggestion should be used.
4444 -/
4545 score : Float
46+ /--
47+ Optional flag associated with the suggestion, e.g. "←" or "=",
48+ if the premise selection algorithm is aware of the tactic consuming the results,
49+ and wants to suggest modifiers for this suggestion.
50+ E.g. this supports calling `simp` in the reverse direction,
51+ or telling `grind` or `aesop` to use the theorem in a particular way.
52+ -/
53+ flag : Option String := none
4654
4755structure Config where
4856 /--
4957 The maximum number of suggestions to return.
5058 -/
51- maxSuggestions : Option Nat := none
59+ maxSuggestions? : Option Nat := none
5260 /--
5361 The tactic that is calling the premise selection, e.g. `simp`, `grind`, or `aesop`.
5462 This may be used to adjust the score of the suggestions
@@ -71,8 +79,62 @@ structure Config where
7179 -/
7280 hint : Option String := none
7381
82+ def Config.maxSuggestions (c : Config) : Nat :=
83+ c.maxSuggestions?.getD 100
84+
7485abbrev Selector : Type := MVarId → Config → MetaM (Array Suggestion)
7586
87+ /--
88+ Construct a `Selector` (which acts on an `MVarId`)
89+ from a function which takes the pretty printed goal.
90+ -/
91+ def ppSelector (selector : String → Config → MetaM (Array Suggestion)) (g : MVarId) (c : Config) :
92+ MetaM (Array Suggestion) := do
93+ selector (toString (← Meta.ppGoal g)) c
94+
95+ namespace Selector
96+
97+ /--
98+ Respect the `Config.filter` option by implementing it as a post-filter.
99+ If a premise selection implementation does not natively handle the filter,
100+ it should be wrapped with this function.
101+ -/
102+ def postFilter (selector : Selector) : Selector := fun g c => do
103+ let suggestions ← selector g { c with filter := fun _ => pure true }
104+ suggestions.filterM (fun s => c.filter s.name)
105+
106+ /--
107+ Wrapper for `Selector` that ensures
108+ the `maxSuggestions` field in `Config` is respected, post-hoc.
109+ -/
110+ def maxSuggestions (selector : Selector) : Selector := fun g c => do
111+ let suggestions ← selector g c
112+ match c.maxSuggestions? with
113+ | none => return suggestions
114+ | some max => return suggestions.take max
115+
116+ /-- Combine two premise selectors, returning the best suggestions. -/
117+ def combine (selector₁ : Selector) (selector₂ : Selector) : Selector := fun g c => do
118+ let suggestions₁ ← selector₁ g c
119+ let suggestions₂ ← selector₂ g c
120+
121+ let mut dedupMap : Std.HashMap (Name × Option String) Suggestion := {}
122+
123+ for s in suggestions₁ ++ suggestions₂ do
124+ let key := (s.name, s.flag)
125+ dedupMap := dedupMap.alter key fun
126+ | none => some s
127+ | some s' => if s.score > s'.score then some s else some s'
128+
129+ let deduped := dedupMap.valuesArray
130+ let sorted := deduped.qsort (fun s₁ s₂ => s₁.score > s₂.score)
131+
132+ match c.maxSuggestions? with
133+ | none => return sorted
134+ | some max => return sorted.take max
135+
136+ end Selector
137+
76138section DenyList
77139
78140/-- Premises from a module whose name has one of the following components are not retrieved. -/
@@ -123,7 +185,7 @@ def empty : Selector := fun _ _ => pure #[]
123185def random (gen : StdGen := ⟨37 , 59 ⟩) : Selector := fun _ cfg => do
124186 IO.stdGenRef.set gen
125187 let env ← getEnv
126- let max := cfg.maxSuggestions.getD 10
188+ let max := cfg.maxSuggestions
127189 let consts := env.const2ModIdx.keysArray
128190 let mut suggestions := #[]
129191 while suggestions.size < max do
0 commit comments