Skip to content

Commit a7f67ed

Browse files
committed
feat: sine qua non premise selection
1 parent e2f5938 commit a7f67ed

File tree

10 files changed

+418
-111
lines changed

10 files changed

+418
-111
lines changed

src/Lean/Data/SMap.lean

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ namespace Lean
3232
-/
3333
structure SMap (α : Type u) (β : Type v) [BEq α] [Hashable α] where
3434
stage₁ : Bool := true
35+
/-- Imported constants. -/
3536
map₁ : Std.HashMap α β := {}
37+
/-- Local constants defined in the current module. -/
3638
map₂ : PHashMap α β := {}
3739

3840
namespace SMap

src/Lean/PremiseSelection.lean

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ prelude
99
import Lean.PremiseSelection.Basic
1010
import Lean.PremiseSelection.SymbolFrequency
1111
import Lean.PremiseSelection.MePo
12+
import Lean.PremiseSelection.SineQuaNon

src/Lean/PremiseSelection/Basic.lean

Lines changed: 103 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,90 @@ Lean does not provide a default premise selector, so this module is intended to
2727
with a downstream package which registers a premise selector.
2828
-/
2929

30+
namespace Lean.Expr.FoldRelevantConstantsImpl
31+
32+
open Lean Meta
33+
34+
unsafe structure State where
35+
visited : PtrSet Expr := mkPtrSet
36+
visitedConsts : NameHashSet := {}
37+
38+
unsafe abbrev FoldM := StateT State MetaM
39+
40+
unsafe def fold {α : Type} (f : Name → α → MetaM α) (e : Expr) (acc : α) : FoldM α :=
41+
let rec visit (e : Expr) (acc : α) : FoldM α := do
42+
if (← get).visited.contains e then
43+
return acc
44+
modify fun s => { s with visited := s.visited.insert e }
45+
if ← isProof e then
46+
-- Don't visit proofs.
47+
return acc
48+
match e with
49+
| .forallE n d b bi =>
50+
let r ← visit d acc
51+
withLocalDecl n bi d fun x =>
52+
visit (b.instantiate1 x) r
53+
| .lam n d b bi =>
54+
let r ← visit d acc
55+
withLocalDecl n bi d fun x =>
56+
visit (b.instantiate1 x) r
57+
| .mdata _ b => visit b acc
58+
| .letE n t v b nondep =>
59+
let r₁ ← visit t acc
60+
let r₂ ← visit v r₁
61+
withLetDecl n t v (nondep := nondep) fun x =>
62+
visit (b.instantiate1 x) r₂
63+
| .app f a =>
64+
let fi ← getFunInfo f (some 1)
65+
if fi.paramInfo[0]!.isInstImplicit then
66+
-- Don't visit implicit arguments.
67+
visit f acc
68+
else
69+
visit a (← visit f acc)
70+
| .proj _ _ b => visit b acc
71+
| .const c _ =>
72+
if (← get).visitedConsts.contains c then
73+
return acc
74+
else
75+
modify fun s => { s with visitedConsts := s.visitedConsts.insert c }
76+
if ← isInstance c then
77+
return acc
78+
else
79+
f c acc
80+
| _ => return acc
81+
visit e acc
82+
83+
@[inline] unsafe def foldUnsafe {α : Type} (e : Expr) (init : α) (f : Name → α → MetaM α) : MetaM α :=
84+
(fold f e init).run' {}
85+
86+
end FoldRelevantConstantsImpl
87+
88+
/-- Apply `f` to every constant occurring in `e` once, skipping instance arguments and proofs. -/
89+
@[implemented_by FoldRelevantConstantsImpl.foldUnsafe]
90+
public opaque foldRelevantConstants {α : Type} (e : Expr) (init : α) (f : Name → α → MetaM α) : MetaM α := pure init
91+
92+
/-- Collect the constants occuring in `e` (once each), skipping instance arguments and proofs. -/
93+
public def relevantConstants (e : Expr) : MetaM (Array Name) := foldRelevantConstants e #[] (fun n ns => return ns.push n)
94+
95+
/-- Collect the constants occuring in `e` (once each), skipping instance arguments and proofs. -/
96+
public def relevantConstantsAsSet (e : Expr) : MetaM NameSet := foldRelevantConstants e ∅ (fun n ns => return ns.insert n)
97+
98+
end Lean.Expr
99+
100+
open Lean Meta MVarId in
101+
public def Lean.MVarId.getConstants (g : MVarId) : MetaM NameSet := withContext g do
102+
let mut c := (← g.getType).getUsedConstantsAsSet
103+
for t in (← getLocalHyps) do
104+
c := c ∪ (← inferType t).getUsedConstantsAsSet
105+
return c
106+
107+
open Lean Meta MVarId in
108+
public def Lean.MVarId.getRelevantConstants (g : MVarId) : MetaM NameSet := withContext g do
109+
let mut c ← (← g.getType).relevantConstantsAsSet
110+
for t in (← getLocalHyps) do
111+
c := c ∪ (← (← inferType t).relevantConstantsAsSet)
112+
return c
113+
30114
@[expose] public section
31115

32116
namespace Lean.PremiseSelection
@@ -130,25 +214,37 @@ end Selector
130214

131215
section DenyList
132216

133-
/-- Premises from a module whose name has one of the following components are not retrieved. -/
217+
/--
218+
Premises from a module whose name has one of the following components are not retrieved.
219+
220+
Use `run_cmd modifyEnv fun env => moduleDenyListExt.addEntry env module` to add a module to the deny list.
221+
-/
134222
builtin_initialize moduleDenyListExt : SimplePersistentEnvExtension String (List String) ←
135223
registerSimplePersistentEnvExtension {
136224
addEntryFn := (·.cons)
137-
addImportedFn := mkStateFromImportedEntries (·.cons) []
225+
addImportedFn := mkStateFromImportedEntries (·.cons) ["Lake", "Lean", "Internal", "Tactic"]
138226
}
139227

140-
/-- A premise whose name has one of the following components is not retrieved. -/
228+
/--
229+
A premise whose name has one of the following components is not retrieved.
230+
231+
Use `run_cmd modifyEnv fun env => nameDenyListExt.addEntry env name` to add a name to the deny list.
232+
-/
141233
builtin_initialize nameDenyListExt : SimplePersistentEnvExtension String (List String) ←
142234
registerSimplePersistentEnvExtension {
143235
addEntryFn := (·.cons)
144-
addImportedFn := mkStateFromImportedEntries (·.cons) []
236+
addImportedFn := mkStateFromImportedEntries (·.cons) ["Lake", "Lean", "Internal", "Tactic"]
145237
}
146238

147-
/-- A premise whose `type.getForallBody.getAppFn` is a constant that has one of these prefixes is not retrieved. -/
239+
/--
240+
A premise whose `type.getForallBody.getAppFn` is a constant that has one of these prefixes is not retrieved.
241+
242+
Use `run_cmd modifyEnv fun env => typePrefixDenyListExt.addEntry env typePrefix` to add a type prefix to the deny list.
243+
-/
148244
builtin_initialize typePrefixDenyListExt : SimplePersistentEnvExtension Name (List Name) ←
149245
registerSimplePersistentEnvExtension {
150246
addEntryFn := (·.cons)
151-
addImportedFn := mkStateFromImportedEntries (·.cons) []
247+
addImportedFn := mkStateFromImportedEntries (·.cons) [`Lake, `Lean]
152248
}
153249

154250
def isDeniedModule (env : Environment) (moduleName : Name) : Bool :=
@@ -157,6 +253,7 @@ def isDeniedModule (env : Environment) (moduleName : Name) : Bool :=
157253
def isDeniedPremise (env : Environment) (name : Name) : Bool := Id.run do
158254
if name == ``sorryAx then return true
159255
if name.isInternalDetail then return true
256+
if Lean.Meta.isInstanceCore env name then return true
160257
if (nameDenyListExt.getState env).any (fun p => name.anyS (· == p)) then return true
161258
if let some moduleIdx := env.getModuleIdxFor? name then
162259
let moduleName := env.header.moduleNames[moduleIdx.toNat]!

src/Lean/PremiseSelection/MePo.lean

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ module
77

88
prelude
99
public import Lean.PremiseSelection.Basic
10+
import Lean.PremiseSelection.SymbolFrequency
1011
import Lean.Meta.Basic
1112

1213
/-!
@@ -24,14 +25,6 @@ namespace Lean.PremiseSelection.MePo
2425

2526
builtin_initialize registerTraceClass `mepo
2627

27-
def symbolFrequency (env : Environment) : NameMap Nat := Id.run do
28-
-- TODO: ideally this would use a precomputed frequency map, as this is too slow.
29-
let mut map := {}
30-
for (_, ci) in env.constants do
31-
for n' in ci.type.getUsedConstantsAsSet do
32-
map := map.alter n' fun i? => some (i?.getD 0 + 1)
33-
return map
34-
3528
def weightedScore (weight : Name → Float) (relevant candidate : NameSet) : Float :=
3629
let S := candidate
3730
let R := relevant ∩ S
@@ -71,26 +64,19 @@ def mepo (initialRelevant : NameSet) (score : NameSet → NameSet → Float) (ac
7164
p := p + (1 - p) / c
7265
return accepted.qsort (fun a b => a.score > b.score)
7366

74-
open Lean Meta MVarId in
75-
def _root_.Lean.MVarId.getConstants (g : MVarId) : MetaM NameSet := withContext g do
76-
let mut c := (← g.getType).getUsedConstantsAsSet
77-
for t in (← getLocalHyps) do
78-
c := c ∪ (← inferType t).getUsedConstantsAsSet
79-
return c
80-
8167
end MePo
8268

8369
open MePo
8470

8571
-- The values of p := 0.6 and c := 2.4 are taken from the MePo paper, and need to be tuned.
8672
public def mepoSelector (useRarity : Bool) (p : Float := 0.6) (c : Float := 2.4) : Selector := fun g config => do
87-
let constants ← g.getConstants
73+
let constants ← g.getRelevantConstants
8874
let env ← getEnv
89-
let score := if useRarity then
90-
let frequency := symbolFrequency env
91-
frequencyScore (frequency.getD · 0)
75+
let score if useRarity then do
76+
let frequency ← symbolFrequencyMap
77+
pure <| frequencyScore (fun n => frequency.getD n 0)
9278
else
93-
unweightedScore
79+
pure <| unweightedScore
9480
let accept := fun ci => return !isDeniedPremise env ci.name
9581
let suggestions ← mepo constants score accept config.maxSuggestions p c
9682
let suggestions := suggestions

0 commit comments

Comments
 (0)