Skip to content

Commit 632b688

Browse files
authored
feat: add an LCNF pass to convert structure projections to cases expressions (#8367)
This PR adds a new `structProjCases` pass to the new compiler, analogous to the `struct_cases_on` pass in the old compiler, which converts all projections from structs into `cases` expressions. When lowered to IR, this causes all of the projections from a single structure to be grouped together, which is an invariant relied upon by the IR RC passes (at least for linearity, if not general correctness).
1 parent c5335b6 commit 632b688

File tree

2 files changed

+135
-0
lines changed

2 files changed

+135
-0
lines changed

src/Lean/Compiler/LCNF/Passes.lean

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import Lean.Compiler.LCNF.LambdaLifting
1818
import Lean.Compiler.LCNF.FloatLetIn
1919
import Lean.Compiler.LCNF.ReduceArity
2020
import Lean.Compiler.LCNF.ElimDeadBranches
21+
import Lean.Compiler.LCNF.StructProjCases
2122

2223
namespace Lean.Compiler.LCNF
2324

@@ -76,6 +77,7 @@ def builtinPassManager : PassManager := {
7677
lambdaLifting,
7778
extendJoinPointContext (phase := .mono) (occurrence := 1),
7879
simp (occurrence := 5) (phase := .mono),
80+
structProjCases,
7981
cse (occurrence := 2) (phase := .mono),
8082
saveMono -- End of mono phase
8183
]
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
/-
2+
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
3+
Released under Apache 2.0 license as described in the file LICENSE.
4+
Authors: Cameron Zwarich
5+
-/
6+
prelude
7+
import Lean.Compiler.LCNF.Basic
8+
import Lean.Compiler.LCNF.InferType
9+
import Lean.Compiler.LCNF.MonoTypes
10+
import Lean.Compiler.LCNF.PassManager
11+
import Lean.Compiler.LCNF.PrettyPrinter
12+
13+
namespace Lean.Compiler.LCNF
14+
namespace StructProjCases
15+
16+
def findStructCtorInfo? (typeName : Name) : CoreM (Option ConstructorVal) := do
17+
let .inductInfo info ← getConstInfo typeName | return none
18+
let [ctorName] := info.ctors | return none
19+
let some (.ctorInfo ctorInfo) := (← getEnv).find? ctorName | return none
20+
return ctorInfo
21+
22+
def mkFieldParamsForCtorType (e : Expr) (numParams : Nat): CompilerM (Array Param) := do
23+
let rec loop (params : Array Param) (e : Expr) (numParams : Nat): CompilerM (Array Param) := do
24+
match e with
25+
| .forallE name type body _ =>
26+
if numParams == 0 then
27+
let param ← mkParam name (← toMonoType type) false
28+
loop (params.push param) body numParams
29+
else
30+
loop params body (numParams - 1)
31+
| _ => return params
32+
loop #[] e numParams
33+
34+
structure StructProjState where
35+
projMap : Std.HashMap FVarId (Array FVarId) := {}
36+
fvarMap : Std.HashMap FVarId FVarId := {}
37+
38+
abbrev M := StateRefT StructProjState CompilerM
39+
40+
def M.run (x : M α) : CompilerM α := do
41+
x.run' {}
42+
43+
def remapFVar (fvarId : FVarId) : M FVarId := do
44+
match (← get).fvarMap[fvarId]? with
45+
| some newFvarId => return newFvarId
46+
| none => return fvarId
47+
48+
mutual
49+
50+
partial def visitCode (code : Code) : M Code := do
51+
match code with
52+
| .let decl k =>
53+
match decl.value with
54+
| .proj typeName i base =>
55+
eraseLetDecl decl
56+
let base ← remapFVar base
57+
if let some projVars := (← get).projMap[base]? then
58+
modify fun s => { s with fvarMap := s.fvarMap.insert decl.fvarId projVars[i]! }
59+
visitCode k
60+
else
61+
let some ctorInfo ← findStructCtorInfo? typeName | panic! "expected struct constructor"
62+
let params ← mkFieldParamsForCtorType ctorInfo.type ctorInfo.numParams
63+
assert! params.size == ctorInfo.numFields
64+
let fvars := params.map (·.fvarId)
65+
modify fun s => { s with projMap := s.projMap.insert base fvars,
66+
fvarMap := s.fvarMap.insert decl.fvarId fvars[i]! }
67+
let k ← visitCode k
68+
modify fun s => { s with projMap := s.projMap.erase base }
69+
let resultType ← toMonoType (← k.inferType)
70+
let alts := #[.alt ctorInfo.name params k]
71+
return .cases { typeName, resultType, discr := base, alts }
72+
| _ => return code.updateLet! (← decl.updateValue (← visitLetValue decl.value)) (← visitCode k)
73+
| .fun decl k =>
74+
let decl ← decl.updateValue (← visitCode decl.value)
75+
return code.updateFun! decl (← visitCode k)
76+
| .jp decl k =>
77+
let decl ← decl.updateValue (← visitCode decl.value)
78+
return code.updateFun! decl (← visitCode k)
79+
| .jmp fvarId args =>
80+
return code.updateJmp! (← remapFVar fvarId) (← args.mapM visitArg)
81+
| .cases cases =>
82+
let discr ← remapFVar cases.discr
83+
if let #[.alt ctorName params k] := cases.alts then
84+
if let some projVars := (← get).projMap[discr]? then
85+
assert! projVars.size == params.size
86+
for param in params, projVar in projVars do
87+
modify fun s => { s with fvarMap := s.fvarMap.insert param.fvarId projVar }
88+
eraseParam param
89+
visitCode k
90+
else
91+
let fvars := params.map (·.fvarId)
92+
modify fun s => { s with projMap := s.projMap.insert discr fvars }
93+
let k ← visitCode k
94+
modify fun s => { s with projMap := s.projMap.erase discr }
95+
-- TODO: This should preserve the .alt allocation, but binding it to
96+
-- a variable above while also destructuring an array doesn't work.
97+
return code.updateCases! cases.resultType discr #[.alt ctorName params k]
98+
else
99+
let alts ← cases.alts.mapM (visitAlt ·)
100+
return code.updateCases! cases.resultType discr alts
101+
| .return fvarId => return code.updateReturn! (← remapFVar fvarId)
102+
| .unreach .. => return code
103+
104+
partial def visitLetValue (v : LetValue) : M LetValue := do
105+
match v with
106+
| .const _ _ args =>
107+
return v.updateArgs! (← args.mapM visitArg)
108+
| .fvar fvarId args =>
109+
return v.updateFVar! (← remapFVar fvarId) (← args.mapM visitArg)
110+
| .value _ | .erased => return v
111+
-- Projections should be handled directly by `visitCode`.
112+
| .proj .. => unreachable!
113+
114+
partial def visitAlt (alt : Alt) : M Alt := do
115+
return alt.updateCode (← visitCode alt.getCode)
116+
117+
partial def visitArg (arg : Arg) : M Arg :=
118+
match arg with
119+
| .fvar fvarId => return arg.updateFVar! (← remapFVar fvarId)
120+
| .type _ | .erased => return arg
121+
122+
end
123+
124+
def visitDecl (decl : Decl) : M Decl := do
125+
let value ← decl.value.mapCodeM (visitCode ·)
126+
return { decl with value }
127+
128+
end StructProjCases
129+
130+
def structProjCases : Pass :=
131+
.mkPerDeclaration `structProjCases (StructProjCases.visitDecl · |>.run) .mono
132+
133+
end Lean.Compiler.LCNF

0 commit comments

Comments
 (0)