Skip to content

Commit d7cc85d

Browse files
committed
perf: do not import non-meta IR for interpreter
1 parent eea8df7 commit d7cc85d

File tree

4 files changed

+125
-24
lines changed

4 files changed

+125
-24
lines changed

src/Lean/Compiler/IR/CompilerM.lean

Lines changed: 63 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ prelude
77
import Lean.Environment
88
import Lean.Compiler.IR.Basic
99
import Lean.Compiler.IR.Format
10+
import Lean.Compiler.MetaAttr
1011

1112
namespace Lean.IR
1213

@@ -82,13 +83,62 @@ private abbrev findAtSorted? (decls : Array Decl) (declName : Name) : Option Dec
8283
let tmpDecl := Decl.extern declName #[] default default
8384
decls.binSearch tmpDecl declLt
8485

86+
namespace CollectUsedDecls'
87+
88+
abbrev M := StateM NameSet
89+
90+
@[inline] def collect (f : FunId) : M Unit :=
91+
modify fun s => s.insert f
92+
93+
partial def collectFnBody : FnBody → M Unit
94+
| .vdecl _ _ v b =>
95+
match v with
96+
| .fap f _ => collect f *> collectFnBody b
97+
| .pap f _ => collect f *> collectFnBody b
98+
| _ => collectFnBody b
99+
| .jdecl _ _ v b => collectFnBody v *> collectFnBody b
100+
| .case _ _ _ alts => alts.forM fun alt => collectFnBody alt.body
101+
| e => unless e.isTerminal do collectFnBody e.body
102+
103+
def collectDecl : Decl → M NameSet
104+
| .fdecl (body := b) .. => collectFnBody b *> get
105+
| .extern .. => get
106+
107+
end CollectUsedDecls'
108+
109+
def collectUsedDecls' (decl : Decl) (used : NameSet := {}) : NameSet :=
110+
(CollectUsedDecls'.collectDecl decl).run' used
111+
112+
def getMetaClosure (m : DeclMap) (decls : Array Decl) : NameSet := Id.run do
113+
let mut toVisit := decls.map (·.name) |>.toList
114+
let mut res : NameSet := .ofList toVisit
115+
while !toVisit.isEmpty do
116+
let n :: toVisit' := toVisit | continue
117+
toVisit := toVisit'
118+
let some d := m.find? n | continue
119+
for d' in collectUsedDecls' d do
120+
if !res.contains d' then
121+
res := res.insert d'
122+
toVisit := d' :: toVisit
123+
return res
124+
85125
builtin_initialize declMapExt : SimplePersistentEnvExtension Decl DeclMap ←
86126
registerSimplePersistentEnvExtension {
87127
addImportedFn := fun _ => {}
88128
addEntryFn := fun s d => s.insert d.name d
89-
toArrayFn := fun s =>
90-
let decls := s.foldl (init := #[]) fun decls decl => decls.push decl
91-
sortDecls decls
129+
exportEntriesFnEx? := some fun env s entries level => Id.run do
130+
let decls := entries.foldl (init := #[]) fun decls decl => decls.push decl
131+
let mut entries := sortDecls decls
132+
if level < .ir then
133+
let closure := getMetaClosure s (decls.filter (isMeta env ·.name))
134+
entries := entries.filter fun
135+
| d@(.fdecl f xs ty b info) =>
136+
let n := match f with
137+
| .str n "_boxed" => n
138+
| n => n
139+
closure.contains n
140+
| _ => true
141+
return entries
92142
-- Written to on codegen environment branch but accessed from other elaboration branches when
93143
-- calling into the interpreter. We cannot use `async` as the IR declarations added may not
94144
-- share a name prefix with the top-level Lean declaration being compiled, e.g. from
@@ -97,12 +147,19 @@ builtin_initialize declMapExt : SimplePersistentEnvExtension Decl DeclMap ←
97147
replay? := some <| SimplePersistentEnvExtension.replayOfFilter (!·.contains ·.name) (fun s d => s.insert d.name d)
98148
}
99149

100-
@[export lean_ir_find_env_decl]
101-
def findEnvDecl (env : Environment) (declName : Name) : Option Decl :=
150+
@[export lean_ir_decl_map_ext_idx]
151+
private def getDeclMapExtIdx : Unit → Nat :=
152+
fun _ => declMapExt.toEnvExtension.idx
153+
154+
def findEnvDecl (env : Environment) (declName : Name) (level := OLeanLevel.ir) : Option Decl :=
102155
match env.getModuleIdxFor? declName with
103-
| some modIdx => findAtSorted? (declMapExt.getModuleEntries env modIdx) declName
156+
| some modIdx => findAtSorted? (declMapExt.getModuleEntries (level := level) env modIdx) declName
104157
| none => declMapExt.getState env |>.find? declName
105158

159+
@[export lean_ir_find_env_decl]
160+
private def findInterpreterDecl (env : Environment) (declName : Name) : Option Decl :=
161+
findEnvDecl env declName .exported
162+
106163
def findDecl (n : Name) : CompilerM (Option Decl) :=
107164
return findEnvDecl (← get).env n
108165

src/Lean/DocString/Extension.lean

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,11 @@ structure ModuleDoc where
5858
private builtin_initialize moduleDocExt : SimplePersistentEnvExtension ModuleDoc (PersistentArray ModuleDoc) ← registerSimplePersistentEnvExtension {
5959
addImportedFn := fun _ => {}
6060
addEntryFn := fun s e => s.push e
61-
exportEntriesFnEx? := some fun _ _ es => fun
62-
| .exported => #[]
63-
| _ => es.toArray
61+
exportEntriesFnEx? := some fun _ _ es level =>
62+
if level < .server then
63+
#[]
64+
else
65+
es.toArray
6466
}
6567

6668
def addMainModuleDoc (env : Environment) (doc : ModuleDoc) : Environment :=

src/Lean/Environment.lean

Lines changed: 56 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,11 @@ structure Environment where
549549
-/
550550
private base : VisibilityMap Kernel.Environment
551551
/--
552+
Additional imported environment extension state for use in codegen. Access via
553+
`getModuleEntries (level := .ir)`.
554+
-/
555+
private irBaseExts : Array EnvExtensionState := base.private.extensions
556+
/--
552557
Additional imported environment extension state for use in the language server. This field is
553558
identical to `base.extensions` in other contexts. Access via
554559
`getModuleEntries (level := .server)`.
@@ -1484,6 +1489,8 @@ abbrev ImportM := ReaderT Lean.ImportM.Context IO
14841489
inductive OLeanLevel where
14851490
/-- Information from exported contexts. -/
14861491
| exported
1492+
/-- Environment extension state for codegen. -/
1493+
| ir
14871494
/-- Environment extension state for the language server. -/
14881495
| server
14891496
/-- Private module data. -/
@@ -1571,12 +1578,15 @@ namespace PersistentEnvExtension
15711578
/--
15721579
Returns the data saved by `ext.exportEntriesFn` when `m` was elaborated. See docs on the function for
15731580
details. When using the module system, `level` can be used to select the level of data to retrieve,
1574-
but is limited to the maximum level actually imported: `exported` on the cmdline and `server` in the
1581+
but is limited to the maximum level actually imported: `ir` on the cmdline and `server` in the
15751582
language server. Higher levels will return the data of the maximum imported level.
15761583
-/
15771584
def getModuleEntries {α β σ : Type} [Inhabited σ] (ext : PersistentEnvExtension α β σ)
15781585
(env : Environment) (m : ModuleIdx) (level := OLeanLevel.exported) : Array α :=
1579-
let exts := if level = .exported then env.base.private.extensions else env.serverBaseExts
1586+
let exts := match level with
1587+
| .exported => env.base.private.extensions
1588+
| .ir => env.irBaseExts
1589+
| _ => env.serverBaseExts
15801590
-- safety: as in `getStateUnsafe`
15811591
unsafe (ext.toEnvExtension.getStateImpl exts).importedEntries[m]!
15821592

@@ -1715,6 +1725,7 @@ unsafe def Environment.freeRegions (env : Environment) : IO Unit :=
17151725

17161726
def OLeanLevel.adjustFileName (base : System.FilePath) : OLeanLevel → System.FilePath
17171727
| .exported => base
1728+
| ir => base.addExtension "ir"
17181729
| .server => base.addExtension "server"
17191730
| .private => base.addExtension "private"
17201731

@@ -1754,6 +1765,7 @@ def writeModule (env : Environment) (fname : System.FilePath) : IO Unit := do
17541765
return (level.adjustFileName fname, (← mkModuleData env level))
17551766
saveModuleDataParts env.mainModule #[
17561767
(← mkPart .exported),
1768+
(← mkPart .ir),
17571769
(← mkPart .server),
17581770
(← mkPart .private)]
17591771
else
@@ -1839,21 +1851,33 @@ private structure ImportedModule extends EffectiveImport where
18391851
/-- All loaded incremental compacted regions. -/
18401852
parts : Array (ModuleData × CompactedRegion)
18411853

1842-
/-- The main module data that will eventually be used to construct the kernel environment. -/
1843-
private def ImportedModule.mainModule? (self : ImportedModule) : Option ModuleData := do
1844-
let (baseMod, _) ← self.parts[0]?
1845-
self.parts[if baseMod.isModule && self.importAll then 2 else 0]?.map (·.1)
1846-
18471854
/-- The main module data that will eventually be used to construct the publicly accessible constants. -/
18481855
private def ImportedModule.publicModule? (self : ImportedModule) : Option ModuleData := do
18491856
let (baseMod, _) ← self.parts[0]?
18501857
return baseMod
18511858

1859+
private def ImportedModule.getData? (self : ImportedModule) (level : OLeanLevel) : Option ModuleData := do
1860+
-- Without the module system, we only have the exported level.
1861+
let level := if (← self.publicModule?).isModule then level else .exported
1862+
self.parts[level.toCtorIdx]?.map (·.1)
1863+
1864+
/-- The main module data that will eventually be used to construct the kernel environment. -/
1865+
private def ImportedModule.mainModule? (self : ImportedModule) : Option ModuleData :=
1866+
self.getData? (if self.importAll then OLeanLevel.private else .exported)
1867+
18521868
/-- The module data that should be used for server purposes. -/
18531869
private def ImportedModule.serverData? (self : ImportedModule) (level : OLeanLevel) :
1870+
Option ModuleData :=
1871+
-- fall back to .exported outside the server
1872+
self.getData? (if level ≥ .server then level else .exported)
1873+
1874+
/-- The module data that should be used for codegen purposes. -/
1875+
private def ImportedModule.irData? (self : ImportedModule) (level : OLeanLevel) :
18541876
Option ModuleData := do
1855-
let (baseMod, _) ← self.parts[0]?
1856-
self.parts[if baseMod.isModule && level != .exported then 1 else 0]?.map (·.1)
1877+
let level :=
1878+
if level ≥ .server then level
1879+
else .ir
1880+
self.getData? level
18571881

18581882
structure ImportState where
18591883
private moduleNameMap : Std.HashMap Name ImportedModule := {}
@@ -1888,12 +1912,15 @@ private def findOLeanParts (mod : Name) : IO (Array System.FilePath) := do
18881912
let mut fnames := #[mFile]
18891913
-- Opportunistically load all available parts.
18901914
-- Necessary because the import level may be upgraded a later import.
1891-
let sFile := OLeanLevel.server.adjustFileName mFile
1915+
let sFile := OLeanLevel.ir.adjustFileName mFile
18921916
if (← sFile.pathExists) then
18931917
fnames := fnames.push sFile
1894-
let pFile := OLeanLevel.private.adjustFileName mFile
1895-
if (← pFile.pathExists) then
1896-
fnames := fnames.push pFile
1918+
let sFile := OLeanLevel.server.adjustFileName mFile
1919+
if (← sFile.pathExists) then
1920+
fnames := fnames.push sFile
1921+
let pFile := OLeanLevel.private.adjustFileName mFile
1922+
if (← pFile.pathExists) then
1923+
fnames := fnames.push pFile
18971924
return fnames
18981925

18991926
partial def importModulesCore
@@ -2084,9 +2111,23 @@ def finalizeImport (s : ImportState) (imports : Array Import) (opts : Options) (
20842111
base.public := publicBase
20852112
realizedImportedConsts? := none
20862113
}
2087-
env := env.setCheckedSync { env.base.private with extensions := (← setImportedEntries env.base.private.extensions moduleData) }
2114+
let mut extensions ← setImportedEntries env.base.private.extensions moduleData
2115+
let extDescrs ← persistentEnvExtensionsRef.get
2116+
if let some declMapExt := extDescrs.find? (·.name == `Lean.IR.declMapExt) then
2117+
for h : modIdx in [:modules.size] do
2118+
let mod := modules[modIdx]
2119+
if mod.irPhases != .runtime then
2120+
if let some irData := mod.irData? level then
2121+
if let some (_, entries) := irData.entries.find? (·.1 == declMapExt.name) then
2122+
extensions := unsafe declMapExt.toEnvExtension.modifyStateImpl extensions fun s =>
2123+
{ s with importedEntries := s.importedEntries.setIfInBounds modIdx entries }
2124+
env := env.setCheckedSync { env.base.private with extensions }
2125+
let irData := modules.filterMap (·.irData? level)
20882126
let serverData := modules.filterMap (·.serverData? level)
2089-
env := { env with serverBaseExts := (← setImportedEntries env.base.private.extensions serverData) }
2127+
env := { env with
2128+
irBaseExts := (← setImportedEntries env.base.private.extensions irData)
2129+
serverBaseExts := (← setImportedEntries env.base.private.extensions serverData)
2130+
}
20902131
if leakEnv then
20912132
/- Mark persistent a first time before `finalizePersistenExtensions`, which
20922133
avoids costly MT markings when e.g. an interpreter closure (which

src/library/compiler/ir_interpreter.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -885,6 +885,7 @@ class interpreter {
885885
throw exception(sstream() << "cannot evaluate `[init]` declaration '" << fn << "' in the same module");
886886
}
887887
push_frame(e.m_decl, m_arg_stack.size());
888+
lean_always_assert(decl_tag(e.m_decl) == decl_kind::Fun);
888889
value r = eval_body(decl_fun_body(e.m_decl));
889890
pop_frame(r, decl_type(e.m_decl));
890891
if (!type_is_scalar(t)) {

0 commit comments

Comments
 (0)