Skip to content

Commit a62b616

Browse files
committed
feat: closed term extraction in the new compiler
1 parent 21846eb commit a62b616

File tree

2 files changed

+157
-0
lines changed

2 files changed

+157
-0
lines changed
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
/-
2+
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
3+
Released under Apache 2.0 license as described in the file LICENSE.
4+
Authors: Cameron Zwarich
5+
-/
6+
prelude
7+
import Lean.Compiler.ClosedTermCache
8+
import Lean.Compiler.NeverExtractAttr
9+
import Lean.Compiler.LCNF.Basic
10+
import Lean.Compiler.LCNF.InferType
11+
import Lean.Compiler.LCNF.Internalize
12+
import Lean.Compiler.LCNF.MonoTypes
13+
import Lean.Compiler.LCNF.PassManager
14+
import Lean.Compiler.LCNF.ToExpr
15+
16+
namespace Lean.Compiler.LCNF
17+
namespace ExtractClosed
18+
19+
abbrev ExtractM := StateRefT (Array CodeDecl) CompilerM
20+
21+
mutual
22+
23+
partial def extractLetValue (v : LetValue) : ExtractM Unit := do
24+
match v with
25+
| .const _ _ args => args.forM extractArg
26+
| .fvar fnVar args =>
27+
extractFVar fnVar
28+
args.forM extractArg
29+
| .proj _ _ baseVar => extractFVar baseVar
30+
| .lit _ | .erased => return ()
31+
32+
partial def extractArg (arg : Arg) : ExtractM Unit := do
33+
match arg with
34+
| .fvar fvarId => extractFVar fvarId
35+
| .type _ | .erased => return ()
36+
37+
partial def extractFVar (fvarId : FVarId) : ExtractM Unit := do
38+
if let some letDecl ← findLetDecl? fvarId then
39+
modify fun decls => decls.push (.let letDecl)
40+
extractLetValue letDecl.value
41+
42+
end
43+
44+
def isIrrelevantArg (arg : Arg) : Bool :=
45+
match arg with
46+
| .erased | .type _ => true
47+
| .fvar _ => false
48+
49+
structure Context where
50+
baseName : Name
51+
sccDecls : Array Decl
52+
53+
structure State where
54+
decls : Array Decl := {}
55+
56+
abbrev M := ReaderT Context $ StateRefT State CompilerM
57+
58+
mutual
59+
60+
partial def shouldExtractLetValue (isRoot : Bool) (v : LetValue) : M Bool := do
61+
match v with
62+
| .lit (.str _) => return true
63+
| .lit (.nat v) =>
64+
-- The old compiler's implementation used the runtime's `is_scalar` function, which
65+
-- introduces a dependency on the architecture used by the compiler.
66+
return v >= Nat.pow 2 63
67+
| .lit _ | .erased => return !isRoot
68+
| .const name _ args =>
69+
if (← read).sccDecls.any (·.name == name) then
70+
return false
71+
if hasNeverExtractAttribute (← getEnv) name then
72+
return false
73+
if isRoot then
74+
if let some constInfo := (← getEnv).find? name then
75+
let shouldExtract := match constInfo with
76+
| .defnInfo val => val.type.isForall
77+
| .ctorInfo _ => !(args.all isIrrelevantArg)
78+
| _ => true
79+
if !shouldExtract then
80+
return false
81+
args.allM shouldExtractArg
82+
| .fvar fnVar args => return (← shouldExtractFVar fnVar) && (← args.allM shouldExtractArg)
83+
| .proj _ _ baseVar => shouldExtractFVar baseVar
84+
85+
partial def shouldExtractArg (arg : Arg) : M Bool := do
86+
match arg with
87+
| .fvar fvarId => shouldExtractFVar fvarId
88+
| .type _ | .erased => return true
89+
90+
partial def shouldExtractFVar (fvarId : FVarId) : M Bool := do
91+
if let some letDecl ← findLetDecl? fvarId then
92+
shouldExtractLetValue false letDecl.value
93+
else
94+
return false
95+
96+
end
97+
98+
mutual
99+
100+
partial def visitCode (code : Code) : M Code := do
101+
match code with
102+
| .let decl k =>
103+
if (← shouldExtractLetValue true decl.value) then
104+
let ⟨_, decls⟩ ← extractLetValue decl.value |>.run {}
105+
let decls := decls.reverse.push (.let decl)
106+
let decls ← decls.mapM Internalize.internalizeCodeDecl |>.run' {}
107+
let closedCode := attachCodeDecls decls (.return decls.back!.fvarId)
108+
let closedExpr := closedCode.toExpr
109+
let env ← getEnv
110+
let name ← if let some closedTermName := getClosedTermName? env closedExpr then
111+
eraseCode closedCode
112+
pure closedTermName
113+
else
114+
let name := (← read).baseName ++ (`_closedTerm).appendIndexAfter (← get).decls.size
115+
cacheClosedTermName env closedExpr name |> setEnv
116+
let decl := { name, levelParams := [], type := decl.type, params := #[],
117+
value := .code closedCode, inlineAttr? := some .noinline }
118+
decl.saveMono
119+
modify fun s => { s with decls := s.decls.push decl }
120+
pure name
121+
let decl ← decl.updateValue (.const name [] #[])
122+
return code.updateLet! decl (← visitCode k)
123+
else
124+
return code.updateLet! decl (← visitCode k)
125+
| .fun decl k =>
126+
let decl ← decl.updateValue (← visitCode decl.value)
127+
return code.updateFun! decl (← visitCode k)
128+
| .jp decl k =>
129+
let decl ← decl.updateValue (← visitCode decl.value)
130+
return code.updateFun! decl (← visitCode k)
131+
| .cases cases =>
132+
let alts ← cases.alts.mapM (fun alt => do return alt.updateCode (← visitCode alt.getCode))
133+
return code.updateAlts! alts
134+
| .jmp .. | .return _ | .unreach .. => return code
135+
136+
end
137+
138+
def visitDecl (decl : Decl) : M Decl := do
139+
let value ← decl.value.mapCodeM visitCode
140+
return { decl with value }
141+
142+
end ExtractClosed
143+
144+
partial def Decl.extractClosed (decl : Decl) (sccDecls : Array Decl) : CompilerM (Array Decl) := do
145+
let ⟨decl, s⟩ ← ExtractClosed.visitDecl decl |>.run { baseName := decl.name, sccDecls } |>.run {}
146+
return s.decls.push decl
147+
148+
def extractClosed : Pass where
149+
phase := .mono
150+
name := `extractClosed
151+
run := fun decls =>
152+
decls.foldlM (init := #[]) fun newDecls decl => return newDecls ++ (← decl.extractClosed decls)
153+
154+
builtin_initialize registerTraceClass `Compiler.extractClosed (inherited := true)
155+
156+
end Lean.Compiler.LCNF

src/Lean/Compiler/LCNF/Passes.lean

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import Lean.Compiler.LCNF.FloatLetIn
1919
import Lean.Compiler.LCNF.ReduceArity
2020
import Lean.Compiler.LCNF.ElimDeadBranches
2121
import Lean.Compiler.LCNF.StructProjCases
22+
import Lean.Compiler.LCNF.ExtractClosed
2223

2324
namespace Lean.Compiler.LCNF
2425

0 commit comments

Comments
 (0)