|
| 1 | +/- |
| 2 | +Copyright (c) 2025 Lean FRO, LLC. All rights reserved. |
| 3 | +Released under Apache 2.0 license as described in the file LICENSE. |
| 4 | +Authors: Cameron Zwarich |
| 5 | +-/ |
| 6 | +prelude |
| 7 | +import Lean.Compiler.ClosedTermCache |
| 8 | +import Lean.Compiler.NeverExtractAttr |
| 9 | +import Lean.Compiler.LCNF.Basic |
| 10 | +import Lean.Compiler.LCNF.InferType |
| 11 | +import Lean.Compiler.LCNF.Internalize |
| 12 | +import Lean.Compiler.LCNF.MonoTypes |
| 13 | +import Lean.Compiler.LCNF.PassManager |
| 14 | +import Lean.Compiler.LCNF.ToExpr |
| 15 | + |
| 16 | +namespace Lean.Compiler.LCNF |
| 17 | +namespace ExtractClosed |
| 18 | + |
| 19 | +abbrev ExtractM := StateRefT (Array CodeDecl) CompilerM |
| 20 | + |
| 21 | +mutual |
| 22 | + |
| 23 | +partial def extractLetValue (v : LetValue) : ExtractM Unit := do |
| 24 | + match v with |
| 25 | + | .const _ _ args => args.forM extractArg |
| 26 | + | .fvar fnVar args => |
| 27 | + extractFVar fnVar |
| 28 | + args.forM extractArg |
| 29 | + | .proj _ _ baseVar => extractFVar baseVar |
| 30 | + | .lit _ | .erased => return () |
| 31 | + |
| 32 | +partial def extractArg (arg : Arg) : ExtractM Unit := do |
| 33 | + match arg with |
| 34 | + | .fvar fvarId => extractFVar fvarId |
| 35 | + | .type _ | .erased => return () |
| 36 | + |
| 37 | +partial def extractFVar (fvarId : FVarId) : ExtractM Unit := do |
| 38 | + if let some letDecl ← findLetDecl? fvarId then |
| 39 | + modify fun decls => decls.push (.let letDecl) |
| 40 | + extractLetValue letDecl.value |
| 41 | + |
| 42 | +end |
| 43 | + |
| 44 | +def isIrrelevantArg (arg : Arg) : Bool := |
| 45 | + match arg with |
| 46 | + | .erased | .type _ => true |
| 47 | + | .fvar _ => false |
| 48 | + |
| 49 | +structure Context where |
| 50 | + baseName : Name |
| 51 | + sccDecls : Array Decl |
| 52 | + |
| 53 | +structure State where |
| 54 | + decls : Array Decl := {} |
| 55 | + |
| 56 | +abbrev M := ReaderT Context $ StateRefT State CompilerM |
| 57 | + |
| 58 | +mutual |
| 59 | + |
| 60 | +partial def shouldExtractLetValue (isRoot : Bool) (v : LetValue) : M Bool := do |
| 61 | + match v with |
| 62 | + | .lit (.str _) => return true |
| 63 | + | .lit (.nat v) => |
| 64 | + -- The old compiler's implementation used the runtime's `is_scalar` function, which |
| 65 | + -- introduces a dependency on the architecture used by the compiler. |
| 66 | + return v >= Nat.pow 2 63 |
| 67 | + | .lit _ | .erased => return !isRoot |
| 68 | + | .const name _ args => |
| 69 | + if (← read).sccDecls.any (·.name == name) then |
| 70 | + return false |
| 71 | + if hasNeverExtractAttribute (← getEnv) name then |
| 72 | + return false |
| 73 | + if isRoot then |
| 74 | + if let some constInfo := (← getEnv).find? name then |
| 75 | + let shouldExtract := match constInfo with |
| 76 | + | .defnInfo val => val.type.isForall |
| 77 | + | .ctorInfo _ => !(args.all isIrrelevantArg) |
| 78 | + | _ => true |
| 79 | + if !shouldExtract then |
| 80 | + return false |
| 81 | + args.allM shouldExtractArg |
| 82 | + | .fvar fnVar args => return (← shouldExtractFVar fnVar) && (← args.allM shouldExtractArg) |
| 83 | + | .proj _ _ baseVar => shouldExtractFVar baseVar |
| 84 | + |
| 85 | +partial def shouldExtractArg (arg : Arg) : M Bool := do |
| 86 | + match arg with |
| 87 | + | .fvar fvarId => shouldExtractFVar fvarId |
| 88 | + | .type _ | .erased => return true |
| 89 | + |
| 90 | +partial def shouldExtractFVar (fvarId : FVarId) : M Bool := do |
| 91 | + if let some letDecl ← findLetDecl? fvarId then |
| 92 | + shouldExtractLetValue false letDecl.value |
| 93 | + else |
| 94 | + return false |
| 95 | + |
| 96 | +end |
| 97 | + |
| 98 | +mutual |
| 99 | + |
| 100 | +partial def visitCode (code : Code) : M Code := do |
| 101 | + match code with |
| 102 | + | .let decl k => |
| 103 | + if (← shouldExtractLetValue true decl.value) then |
| 104 | + let ⟨_, decls⟩ ← extractLetValue decl.value |>.run {} |
| 105 | + let decls := decls.reverse.push (.let decl) |
| 106 | + let decls ← decls.mapM Internalize.internalizeCodeDecl |>.run' {} |
| 107 | + let closedCode := attachCodeDecls decls (.return decls.back!.fvarId) |
| 108 | + let closedExpr := closedCode.toExpr |
| 109 | + let env ← getEnv |
| 110 | + let name ← if let some closedTermName := getClosedTermName? env closedExpr then |
| 111 | + eraseCode closedCode |
| 112 | + pure closedTermName |
| 113 | + else |
| 114 | + let name := (← read).baseName ++ (`_closedTerm).appendIndexAfter (← get).decls.size |
| 115 | + cacheClosedTermName env closedExpr name |> setEnv |
| 116 | + let decl := { name, levelParams := [], type := decl.type, params := #[], |
| 117 | + value := .code closedCode, inlineAttr? := some .noinline } |
| 118 | + decl.saveMono |
| 119 | + modify fun s => { s with decls := s.decls.push decl } |
| 120 | + pure name |
| 121 | + let decl ← decl.updateValue (.const name [] #[]) |
| 122 | + return code.updateLet! decl (← visitCode k) |
| 123 | + else |
| 124 | + return code.updateLet! decl (← visitCode k) |
| 125 | + | .fun decl k => |
| 126 | + let decl ← decl.updateValue (← visitCode decl.value) |
| 127 | + return code.updateFun! decl (← visitCode k) |
| 128 | + | .jp decl k => |
| 129 | + let decl ← decl.updateValue (← visitCode decl.value) |
| 130 | + return code.updateFun! decl (← visitCode k) |
| 131 | + | .cases cases => |
| 132 | + let alts ← cases.alts.mapM (fun alt => do return alt.updateCode (← visitCode alt.getCode)) |
| 133 | + return code.updateAlts! alts |
| 134 | + | .jmp .. | .return _ | .unreach .. => return code |
| 135 | + |
| 136 | +end |
| 137 | + |
| 138 | +def visitDecl (decl : Decl) : M Decl := do |
| 139 | + let value ← decl.value.mapCodeM visitCode |
| 140 | + return { decl with value } |
| 141 | + |
| 142 | +end ExtractClosed |
| 143 | + |
| 144 | +partial def Decl.extractClosed (decl : Decl) (sccDecls : Array Decl) : CompilerM (Array Decl) := do |
| 145 | + let ⟨decl, s⟩ ← ExtractClosed.visitDecl decl |>.run { baseName := decl.name, sccDecls } |>.run {} |
| 146 | + return s.decls.push decl |
| 147 | + |
| 148 | +def extractClosed : Pass where |
| 149 | + phase := .mono |
| 150 | + name := `extractClosed |
| 151 | + run := fun decls => |
| 152 | + decls.foldlM (init := #[]) fun newDecls decl => return newDecls ++ (← decl.extractClosed decls) |
| 153 | + |
| 154 | +builtin_initialize registerTraceClass `Compiler.extractClosed (inherited := true) |
| 155 | + |
| 156 | +end Lean.Compiler.LCNF |
0 commit comments