Skip to content

Commit 7b80cd2

Browse files
authored
feat: closed term extraction in the new compiler (#8458)
This PR adds closed term extraction to the new compiler, closely following the approach in the old compiler. In the future, we will explore some ideas to improve upon this approach.
1 parent 21846eb commit 7b80cd2

File tree

2 files changed

+157
-0
lines changed

2 files changed

+157
-0
lines changed
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
/-
2+
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
3+
Released under Apache 2.0 license as described in the file LICENSE.
4+
Authors: Cameron Zwarich
5+
-/
6+
prelude
7+
import Lean.Compiler.ClosedTermCache
8+
import Lean.Compiler.NeverExtractAttr
9+
import Lean.Compiler.LCNF.Basic
10+
import Lean.Compiler.LCNF.InferType
11+
import Lean.Compiler.LCNF.Internalize
12+
import Lean.Compiler.LCNF.MonoTypes
13+
import Lean.Compiler.LCNF.PassManager
14+
import Lean.Compiler.LCNF.ToExpr
15+
16+
namespace Lean.Compiler.LCNF
17+
namespace ExtractClosed
18+
19+
abbrev ExtractM := StateRefT (Array CodeDecl) CompilerM
20+
21+
mutual
22+
23+
partial def extractLetValue (v : LetValue) : ExtractM Unit := do
24+
match v with
25+
| .const _ _ args => args.forM extractArg
26+
| .fvar fnVar args =>
27+
extractFVar fnVar
28+
args.forM extractArg
29+
| .proj _ _ baseVar => extractFVar baseVar
30+
| .lit _ | .erased => return ()
31+
32+
partial def extractArg (arg : Arg) : ExtractM Unit := do
33+
match arg with
34+
| .fvar fvarId => extractFVar fvarId
35+
| .type _ | .erased => return ()
36+
37+
partial def extractFVar (fvarId : FVarId) : ExtractM Unit := do
38+
if let some letDecl ← findLetDecl? fvarId then
39+
modify fun decls => decls.push (.let letDecl)
40+
extractLetValue letDecl.value
41+
42+
end
43+
44+
def isIrrelevantArg (arg : Arg) : Bool :=
45+
match arg with
46+
| .erased | .type _ => true
47+
| .fvar _ => false
48+
49+
structure Context where
50+
baseName : Name
51+
sccDecls : Array Decl
52+
53+
structure State where
54+
decls : Array Decl := {}
55+
56+
abbrev M := ReaderT Context $ StateRefT State CompilerM
57+
58+
mutual
59+
60+
partial def shouldExtractLetValue (isRoot : Bool) (v : LetValue) : M Bool := do
61+
match v with
62+
| .lit (.str _) => return true
63+
| .lit (.nat v) =>
64+
-- The old compiler's implementation used the runtime's `is_scalar` function, which
65+
-- introduces a dependency on the architecture used by the compiler.
66+
return v >= Nat.pow 2 63
67+
| .lit _ | .erased => return !isRoot
68+
| .const name _ args =>
69+
if (← read).sccDecls.any (·.name == name) then
70+
return false
71+
if hasNeverExtractAttribute (← getEnv) name then
72+
return false
73+
if isRoot then
74+
if let some constInfo := (← getEnv).find? name then
75+
let shouldExtract := match constInfo with
76+
| .defnInfo val => val.type.isForall
77+
| .ctorInfo _ => !(args.all isIrrelevantArg)
78+
| _ => true
79+
if !shouldExtract then
80+
return false
81+
args.allM shouldExtractArg
82+
| .fvar fnVar args => return (← shouldExtractFVar fnVar) && (← args.allM shouldExtractArg)
83+
| .proj _ _ baseVar => shouldExtractFVar baseVar
84+
85+
partial def shouldExtractArg (arg : Arg) : M Bool := do
86+
match arg with
87+
| .fvar fvarId => shouldExtractFVar fvarId
88+
| .type _ | .erased => return true
89+
90+
partial def shouldExtractFVar (fvarId : FVarId) : M Bool := do
91+
if let some letDecl ← findLetDecl? fvarId then
92+
shouldExtractLetValue false letDecl.value
93+
else
94+
return false
95+
96+
end
97+
98+
mutual
99+
100+
partial def visitCode (code : Code) : M Code := do
101+
match code with
102+
| .let decl k =>
103+
if (← shouldExtractLetValue true decl.value) then
104+
let ⟨_, decls⟩ ← extractLetValue decl.value |>.run {}
105+
let decls := decls.reverse.push (.let decl)
106+
let decls ← decls.mapM Internalize.internalizeCodeDecl |>.run' {}
107+
let closedCode := attachCodeDecls decls (.return decls.back!.fvarId)
108+
let closedExpr := closedCode.toExpr
109+
let env ← getEnv
110+
let name ← if let some closedTermName := getClosedTermName? env closedExpr then
111+
eraseCode closedCode
112+
pure closedTermName
113+
else
114+
let name := (← read).baseName ++ (`_closedTerm).appendIndexAfter (← get).decls.size
115+
cacheClosedTermName env closedExpr name |> setEnv
116+
let decl := { name, levelParams := [], type := decl.type, params := #[],
117+
value := .code closedCode, inlineAttr? := some .noinline }
118+
decl.saveMono
119+
modify fun s => { s with decls := s.decls.push decl }
120+
pure name
121+
let decl ← decl.updateValue (.const name [] #[])
122+
return code.updateLet! decl (← visitCode k)
123+
else
124+
return code.updateLet! decl (← visitCode k)
125+
| .fun decl k =>
126+
let decl ← decl.updateValue (← visitCode decl.value)
127+
return code.updateFun! decl (← visitCode k)
128+
| .jp decl k =>
129+
let decl ← decl.updateValue (← visitCode decl.value)
130+
return code.updateFun! decl (← visitCode k)
131+
| .cases cases =>
132+
let alts ← cases.alts.mapM (fun alt => do return alt.updateCode (← visitCode alt.getCode))
133+
return code.updateAlts! alts
134+
| .jmp .. | .return _ | .unreach .. => return code
135+
136+
end
137+
138+
def visitDecl (decl : Decl) : M Decl := do
139+
let value ← decl.value.mapCodeM visitCode
140+
return { decl with value }
141+
142+
end ExtractClosed
143+
144+
partial def Decl.extractClosed (decl : Decl) (sccDecls : Array Decl) : CompilerM (Array Decl) := do
145+
let ⟨decl, s⟩ ← ExtractClosed.visitDecl decl |>.run { baseName := decl.name, sccDecls } |>.run {}
146+
return s.decls.push decl
147+
148+
def extractClosed : Pass where
149+
phase := .mono
150+
name := `extractClosed
151+
run := fun decls =>
152+
decls.foldlM (init := #[]) fun newDecls decl => return newDecls ++ (← decl.extractClosed decls)
153+
154+
builtin_initialize registerTraceClass `Compiler.extractClosed (inherited := true)
155+
156+
end Lean.Compiler.LCNF

src/Lean/Compiler/LCNF/Passes.lean

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import Lean.Compiler.LCNF.FloatLetIn
1919
import Lean.Compiler.LCNF.ReduceArity
2020
import Lean.Compiler.LCNF.ElimDeadBranches
2121
import Lean.Compiler.LCNF.StructProjCases
22+
import Lean.Compiler.LCNF.ExtractClosed
2223

2324
namespace Lean.Compiler.LCNF
2425

0 commit comments

Comments
 (0)