Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/Lean/Data/SMap.lean
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ namespace Lean
-/
structure SMap (α : Type u) (β : Type v) [BEq α] [Hashable α] where
stage₁ : Bool := true
/-- Imported constants. -/
map₁ : Std.HashMap α β := {}
/-- Local constants defined in the current module. -/
map₂ : PHashMap α β := {}

namespace SMap
Expand Down
1 change: 1 addition & 0 deletions src/Lean/PremiseSelection.lean
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ prelude
import Lean.PremiseSelection.Basic
import Lean.PremiseSelection.SymbolFrequency
import Lean.PremiseSelection.MePo
import Lean.PremiseSelection.SineQuaNon
109 changes: 103 additions & 6 deletions src/Lean/PremiseSelection/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,90 @@ Lean does not provide a default premise selector, so this module is intended to
with a downstream package which registers a premise selector.
-/

namespace Lean.Expr.FoldRelevantConstantsImpl

open Lean Meta

unsafe structure State where
visited : PtrSet Expr := mkPtrSet
visitedConsts : NameHashSet := {}

unsafe abbrev FoldM := StateT State MetaM

unsafe def fold {α : Type} (f : Name → α → MetaM α) (e : Expr) (acc : α) : FoldM α :=
let rec visit (e : Expr) (acc : α) : FoldM α := do
if (← get).visited.contains e then
return acc
modify fun s => { s with visited := s.visited.insert e }
if ← isProof e then
-- Don't visit proofs.
return acc
match e with
| .forallE n d b bi =>
let r ← visit d acc
withLocalDecl n bi d fun x =>
visit (b.instantiate1 x) r
| .lam n d b bi =>
let r ← visit d acc
withLocalDecl n bi d fun x =>
visit (b.instantiate1 x) r
| .mdata _ b => visit b acc
| .letE n t v b nondep =>
let r₁ ← visit t acc
let r₂ ← visit v r₁
withLetDecl n t v (nondep := nondep) fun x =>
visit (b.instantiate1 x) r₂
| .app f a =>
let fi ← getFunInfo f (some 1)
if fi.paramInfo[0]!.isInstImplicit then
-- Don't visit implicit arguments.
visit f acc
else
visit a (← visit f acc)
| .proj _ _ b => visit b acc
| .const c _ =>
if (← get).visitedConsts.contains c then
return acc
else
modify fun s => { s with visitedConsts := s.visitedConsts.insert c }
if ← isInstance c then
return acc
else
f c acc
| _ => return acc
visit e acc

@[inline] unsafe def foldUnsafe {α : Type} (e : Expr) (init : α) (f : Name → α → MetaM α) : MetaM α :=
(fold f e init).run' {}

end FoldRelevantConstantsImpl

/-- Apply `f` to every constant occurring in `e` once, skipping instance arguments and proofs. -/
@[implemented_by FoldRelevantConstantsImpl.foldUnsafe]
public opaque foldRelevantConstants {α : Type} (e : Expr) (init : α) (f : Name → α → MetaM α) : MetaM α := pure init

/-- Collect the constants occuring in `e` (once each), skipping instance arguments and proofs. -/
public def relevantConstants (e : Expr) : MetaM (Array Name) := foldRelevantConstants e #[] (fun n ns => return ns.push n)

/-- Collect the constants occuring in `e` (once each), skipping instance arguments and proofs. -/
public def relevantConstantsAsSet (e : Expr) : MetaM NameSet := foldRelevantConstants e ∅ (fun n ns => return ns.insert n)

end Lean.Expr

open Lean Meta MVarId in
public def Lean.MVarId.getConstants (g : MVarId) : MetaM NameSet := withContext g do
let mut c := (← g.getType).getUsedConstantsAsSet
for t in (← getLocalHyps) do
c := c ∪ (← inferType t).getUsedConstantsAsSet
return c

open Lean Meta MVarId in
public def Lean.MVarId.getRelevantConstants (g : MVarId) : MetaM NameSet := withContext g do
let mut c ← (← g.getType).relevantConstantsAsSet
for t in (← getLocalHyps) do
c := c ∪ (← (← inferType t).relevantConstantsAsSet)
return c

@[expose] public section

namespace Lean.PremiseSelection
Expand Down Expand Up @@ -130,25 +214,37 @@ end Selector

section DenyList

/-- Premises from a module whose name has one of the following components are not retrieved. -/
/--
Premises from a module whose name has one of the following components are not retrieved.

Use `run_cmd modifyEnv fun env => moduleDenyListExt.addEntry env module` to add a module to the deny list.
-/
builtin_initialize moduleDenyListExt : SimplePersistentEnvExtension String (List String) ←
registerSimplePersistentEnvExtension {
addEntryFn := (·.cons)
addImportedFn := mkStateFromImportedEntries (·.cons) []
addImportedFn := mkStateFromImportedEntries (·.cons) ["Lake", "Lean", "Internal", "Tactic"]
}

/-- A premise whose name has one of the following components is not retrieved. -/
/--
A premise whose name has one of the following components is not retrieved.

Use `run_cmd modifyEnv fun env => nameDenyListExt.addEntry env name` to add a name to the deny list.
-/
builtin_initialize nameDenyListExt : SimplePersistentEnvExtension String (List String) ←
registerSimplePersistentEnvExtension {
addEntryFn := (·.cons)
addImportedFn := mkStateFromImportedEntries (·.cons) []
addImportedFn := mkStateFromImportedEntries (·.cons) ["Lake", "Lean", "Internal", "Tactic"]
}

/-- A premise whose `type.getForallBody.getAppFn` is a constant that has one of these prefixes is not retrieved. -/
/--
A premise whose `type.getForallBody.getAppFn` is a constant that has one of these prefixes is not retrieved.

Use `run_cmd modifyEnv fun env => typePrefixDenyListExt.addEntry env typePrefix` to add a type prefix to the deny list.
-/
builtin_initialize typePrefixDenyListExt : SimplePersistentEnvExtension Name (List Name) ←
registerSimplePersistentEnvExtension {
addEntryFn := (·.cons)
addImportedFn := mkStateFromImportedEntries (·.cons) []
addImportedFn := mkStateFromImportedEntries (·.cons) [`Lake, `Lean]
}

def isDeniedModule (env : Environment) (moduleName : Name) : Bool :=
Expand All @@ -157,6 +253,7 @@ def isDeniedModule (env : Environment) (moduleName : Name) : Bool :=
def isDeniedPremise (env : Environment) (name : Name) : Bool := Id.run do
if name == ``sorryAx then return true
if name.isInternalDetail then return true
if Lean.Meta.isInstanceCore env name then return true
if (nameDenyListExt.getState env).any (fun p => name.anyS (· == p)) then return true
if let some moduleIdx := env.getModuleIdxFor? name then
let moduleName := env.header.moduleNames[moduleIdx.toNat]!
Expand Down
26 changes: 6 additions & 20 deletions src/Lean/PremiseSelection/MePo.lean
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ module

prelude
public import Lean.PremiseSelection.Basic
import Lean.PremiseSelection.SymbolFrequency
import Lean.Meta.Basic

/-!
Expand All @@ -24,14 +25,6 @@ namespace Lean.PremiseSelection.MePo

builtin_initialize registerTraceClass `mepo

def symbolFrequency (env : Environment) : NameMap Nat := Id.run do
-- TODO: ideally this would use a precomputed frequency map, as this is too slow.
let mut map := {}
for (_, ci) in env.constants do
for n' in ci.type.getUsedConstantsAsSet do
map := map.alter n' fun i? => some (i?.getD 0 + 1)
return map

def weightedScore (weight : Name → Float) (relevant candidate : NameSet) : Float :=
let S := candidate
let R := relevant ∩ S
Expand Down Expand Up @@ -71,26 +64,19 @@ def mepo (initialRelevant : NameSet) (score : NameSet → NameSet → Float) (ac
p := p + (1 - p) / c
return accepted.qsort (fun a b => a.score > b.score)

open Lean Meta MVarId in
def _root_.Lean.MVarId.getConstants (g : MVarId) : MetaM NameSet := withContext g do
let mut c := (← g.getType).getUsedConstantsAsSet
for t in (← getLocalHyps) do
c := c ∪ (← inferType t).getUsedConstantsAsSet
return c

end MePo

open MePo

-- The values of p := 0.6 and c := 2.4 are taken from the MePo paper, and need to be tuned.
public def mepoSelector (useRarity : Bool) (p : Float := 0.6) (c : Float := 2.4) : Selector := fun g config => do
let constants ← g.getConstants
let constants ← g.getRelevantConstants
let env ← getEnv
let score := if useRarity then
let frequency := symbolFrequency env
frequencyScore (frequency.getD · 0)
let score if useRarity then do
let frequency ← symbolFrequencyMap
pure <| frequencyScore (fun n => frequency.getD n 0)
else
unweightedScore
pure <| unweightedScore
let accept := fun ci => return !isDeniedPremise env ci.name
let suggestions ← mepo constants score accept config.maxSuggestions p c
let suggestions := suggestions
Expand Down
Loading
Loading