Skip to content

Commit 5a4117d

Browse files
committed
feat: don't count symbols in instances and proofs
1 parent a50a6e2 commit 5a4117d

File tree

3 files changed

+111
-11
lines changed

3 files changed

+111
-11
lines changed

src/Lean/Meta/InferType.lean

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,7 @@ partial def isProofQuick : Expr → MetaM LBool
415415

416416
end
417417

418+
/-- Check if `e` is a proof, i.e. the type of `e` is a proposition. -/
418419
def isProof (e : Expr) : MetaM Bool := do
419420
match (← isProofQuick e) with
420421
| .true => return true

src/Lean/PremiseSelection/SymbolFrequency.lean

Lines changed: 83 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ module
77

88
prelude
99
public import Lean.CoreM
10+
import Lean.Meta.InferType
11+
import Lean.Meta.FunInfo
12+
import Lean.AddDecl
1013

1114
/-!
1215
# Symbol frequency
@@ -16,6 +19,73 @@ This module provides a persistent environment extension for computing the freque
1619

1720
namespace Lean.PremiseSelection
1821

22+
namespace FoldRelevantConstsImpl
23+
24+
open Lean Meta
25+
26+
unsafe structure State where
27+
visited : PtrSet Expr := mkPtrSet
28+
visitedConsts : NameHashSet := {}
29+
30+
unsafe abbrev FoldM := StateT State MetaM
31+
32+
unsafe def fold {α : Type} (f : Name → α → MetaM α) (e : Expr) (acc : α) : FoldM α :=
33+
let rec visit (e : Expr) (acc : α) : FoldM α := do
34+
if (← get).visited.contains e then
35+
return acc
36+
modify fun s => { s with visited := s.visited.insert e }
37+
if ← isProof e then
38+
-- Don't visit proofs.
39+
return acc
40+
match e with
41+
| .forallE n d b bi =>
42+
let r ← visit d acc
43+
withLocalDecl n bi d fun x =>
44+
visit (b.instantiate1 x) r
45+
| .lam n d b bi =>
46+
let r ← visit d acc
47+
withLocalDecl n bi d fun x =>
48+
visit (b.instantiate1 x) r
49+
| .mdata _ b => visit b acc
50+
| .letE n t v b nondep =>
51+
let r₁ ← visit t acc
52+
let r₂ ← visit v r₁
53+
withLetDecl n t v (nondep := nondep) fun x =>
54+
visit (b.instantiate1 x) r₂
55+
| .app f a =>
56+
let fi ← getFunInfo f (some 1)
57+
if fi.paramInfo[0]!.isInstImplicit then
58+
-- Don't visit instance-implicit arguments.
59+
visit f acc
60+
else
61+
visit a (← visit f acc)
62+
| .proj _ _ b => visit b acc
63+
| .const c _ =>
64+
if (← get).visitedConsts.contains c then
65+
return acc
66+
else
67+
modify fun s => { s with visitedConsts := s.visitedConsts.insert c };
68+
f c acc
69+
| _ => return acc
70+
visit e acc
71+
72+
@[inline] unsafe def foldUnsafe {α : Type} (e : Expr) (init : α) (f : Name → α → MetaM α) : MetaM α :=
73+
(fold f e init).run' {}
74+
75+
end FoldRelevantConstsImpl
76+
77+
/-- Apply `f` to every constant occurring in `e` once, skipping instance arguments and proofs. -/
78+
@[implemented_by FoldRelevantConstsImpl.foldUnsafe]
79+
opaque foldRelevantConsts {α : Type} (e : Expr) (init : α) (f : Name → α → MetaM α) : MetaM α := pure init
80+
81+
/-- Helper function for running `MetaM` code during module export. We have nothing but an `Environment` available. -/
82+
private def runMetaM [Inhabited α] (env : Environment) (x : MetaM α) : α :=
83+
match unsafe unsafeEIO ((((withoutExporting x).run' {} {}).run' { fileName := "symbolFrequency", fileMap := default } { env })) with
84+
| Except.ok a => a
85+
| Except.error ex => panic! match unsafe unsafeIO ex.toMessageData.toString with
86+
| Except.ok s => s
87+
| Except.error ex => ex.toString
88+
1989
/--
2090
The state is just an array of arrays of maps.
2191
We don't assemble these on import for efficiency reasons: most modules will not query this extension.
@@ -30,26 +100,24 @@ builtin_initialize symbolFrequencyExt : PersistentEnvExtension (NameMap Nat) Emp
30100
mkInitial := pure ∅
31101
addImportedFn := fun mapss _ => pure mapss
32102
addEntryFn := nofun
33-
exportEntriesFnEx := fun env _ _ =>
34-
let r := env.constants.map₂.foldl (init := (∅ : NameMap Nat)) (fun acc n ci =>
35-
if n.isInternalDetail then
36-
acc
103+
exportEntriesFnEx := fun env _ _ => runMetaM env do
104+
let r ← env.constants.map₂.foldlM (init := (∅ : NameMap Nat)) (fun acc n ci => do
105+
if n.isInternalDetail || !Lean.wasOriginallyTheorem env n then
106+
pure acc
37107
else
38-
-- It would be nice if we could discard all proof sub-terms here,
39-
-- but there doesn't seem to be a good way to do that.
40-
ci.type.foldConsts (init := acc) fun n' acc => acc.alter n' fun i? => some (i?.getD 0 + 1))
41-
#[r]
108+
foldRelevantConsts ci.type (init := acc) fun n' acc => pure (acc.alter n' fun i? => some (i?.getD 0 + 1)))
109+
return #[r]
42110
statsFn := fun _ => "symbol frequency extension"
43111
}
44112

113+
/-- A global `IO.Ref` containing the symbol frequency map. This is initialized on first use. -/
45114
builtin_initialize symbolFrequencyMapRef : IO.Ref (Option (NameMap Nat)) ← IO.mkRef none
46115

47-
open Lean Core
48-
49116
private local instance : Zero (NameMap Nat) := ⟨∅⟩
50117
private local instance : Add (NameMap Nat) where
51118
add x y := y.foldl (init := x) fun x' n c => x'.insert n (x'.getD n 0 + c)
52119

120+
/-- The symbol frequency map for imported constants. This is initialized on first use. -/
53121
def symbolFrequencyMap : CoreM (NameMap Nat) := do
54122
match ← symbolFrequencyMapRef.get with
55123
| some map => return map
@@ -59,6 +127,10 @@ def symbolFrequencyMap : CoreM (NameMap Nat) := do
59127
symbolFrequencyMapRef.set (some map)
60128
return map
61129

62-
/-- Return the number of times a `Name` appears in the signatures of (non-internal) declarations in the environment. -/
130+
/--
131+
Return the number of times a `Name` appears
132+
in the signatures of (non-internal) theorems in the imported environment,
133+
skipping instance arguments and proofs.
134+
-/
63135
public def symbolFrequency (n : Name) : CoreM Nat :=
64136
return (← symbolFrequencyMap) |>.getD n 0
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
module
2+
3+
import all Lean.PremiseSelection.SymbolFrequency
4+
import all Init.Data.Array.Basic
5+
6+
open Lean PremiseSelection
7+
8+
/-- info: [List, Eq, HAppend.hAppend] -/
9+
#guard_msgs in
10+
run_meta do
11+
let ci ← getConstInfo `List.append_assoc
12+
let consts ← foldRelevantConsts ci.type (init := #[]) (fun n ns => return ns.push n)
13+
logInfo m!"{consts}"
14+
15+
/-- info: [List, Ne, HAppend.hAppend, List.nil, Eq, List.head] -/
16+
#guard_msgs in
17+
run_meta do
18+
let ci ← getConstInfo `List.head_append_right
19+
let consts ← foldRelevantConsts ci.type (init := #[]) (fun n ns => return ns.push n)
20+
logInfo m!"{consts}"
21+
22+
/-- info: [Array, Nat, LT.lt, Array.size, HAdd.hAdd, OfNat.ofNat, Array.swap, Not] -/
23+
#guard_msgs in
24+
run_meta do
25+
let ci ← getConstInfo `Array.eraseIdx.induct
26+
let consts ← foldRelevantConsts ci.type (init := #[]) (fun n ns => return ns.push n)
27+
logInfo m!"{consts}"

0 commit comments

Comments
 (0)