Skip to content

Commit b8c941d

Browse files
authored
chore: use mutual inductives for data structures in the new compiler (#8332)
This PR changes the types `AltCore`, `FunDeclCore` and `CasesCore` used in the IRs of the new compiler into the mutual inductives `Alt`, `FunDecl` and `Cases`.
1 parent 995fa47 commit b8c941d

File tree

10 files changed

+56
-55
lines changed

10 files changed

+56
-55
lines changed

src/Lean/Compiler/IR/Basic.lean

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -245,9 +245,11 @@ structure Param where
245245
@[export lean_ir_mk_param]
246246
def mkParam (x : VarId) (borrow : Bool) (ty : IRType) : Param := ⟨x, borrow, ty⟩
247247

248-
inductive AltCore (FnBody : Type) : Type where
249-
| ctor (info : CtorInfo) (b : FnBody) : AltCore FnBody
250-
| default (b : FnBody) : AltCore FnBody
248+
mutual
249+
250+
inductive Alt where
251+
| ctor (info : CtorInfo) (b : FnBody) : Alt
252+
| default (b : FnBody) : Alt
251253

252254
inductive FnBody where
253255
/-- `let x : ty := e; b` -/
@@ -271,12 +273,14 @@ inductive FnBody where
271273
| dec (x : VarId) (n : Nat) (c : Bool) (persistent : Bool) (b : FnBody)
272274
| del (x : VarId) (b : FnBody)
273275
| mdata (d : MData) (b : FnBody)
274-
| case (tid : Name) (x : VarId) (xType : IRType) (cs : Array (AltCore FnBody))
276+
| case (tid : Name) (x : VarId) (xType : IRType) (cs : Array Alt)
275277
| ret (x : Arg)
276278
/-- Jump to join point `j` -/
277279
| jmp (j : JoinPointId) (ys : Array Arg)
278280
| unreachable
279281

282+
end
283+
280284
instance : Inhabited FnBody := ⟨FnBody.unreachable⟩
281285

282286
abbrev FnBody.nil := FnBody.unreachable
@@ -285,17 +289,13 @@ abbrev FnBody.nil := FnBody.unreachable
285289
@[export lean_ir_mk_jdecl] def mkJDecl (j : JoinPointId) (xs : Array Param) (v : FnBody) (b : FnBody) : FnBody := FnBody.jdecl j xs v b
286290
@[export lean_ir_mk_uset] def mkUSet (x : VarId) (i : Nat) (y : VarId) (b : FnBody) : FnBody := FnBody.uset x i y b
287291
@[export lean_ir_mk_sset] def mkSSet (x : VarId) (i : Nat) (offset : Nat) (y : VarId) (ty : IRType) (b : FnBody) : FnBody := FnBody.sset x i offset y ty b
288-
@[export lean_ir_mk_case] def mkCase (tid : Name) (x : VarId) (cs : Array (AltCore FnBody)) : FnBody :=
292+
@[export lean_ir_mk_case] def mkCase (tid : Name) (x : VarId) (cs : Array Alt) : FnBody :=
289293
-- Type field `xType` is set by `explicitBoxing` compiler pass.
290294
FnBody.case tid x IRType.object cs
291295
@[export lean_ir_mk_ret] def mkRet (x : Arg) : FnBody := FnBody.ret x
292296
@[export lean_ir_mk_jmp] def mkJmp (j : JoinPointId) (ys : Array Arg) : FnBody := FnBody.jmp j ys
293297
@[export lean_ir_mk_unreachable] def mkUnreachable : Unit → FnBody := fun _ => FnBody.unreachable
294298

295-
abbrev Alt := AltCore FnBody
296-
@[match_pattern] abbrev Alt.ctor := @AltCore.ctor FnBody
297-
@[match_pattern] abbrev Alt.default := @AltCore.default FnBody
298-
299299
instance : Inhabited Alt := ⟨Alt.default default⟩
300300

301301
def FnBody.isTerminal : FnBody → Bool
@@ -341,19 +341,19 @@ def FnBody.setBody : FnBody → FnBody → FnBody
341341
let c := b.resetBody
342342
(c, b')
343343

344-
def AltCore.body : Alt → FnBody
344+
def Alt.body : Alt → FnBody
345345
| Alt.ctor _ b => b
346346
| Alt.default b => b
347347

348-
def AltCore.setBody : Alt → FnBody → Alt
348+
def Alt.setBody : Alt → FnBody → Alt
349349
| Alt.ctor c _, b => Alt.ctor c b
350350
| Alt.default _, b => Alt.default b
351351

352-
@[inline] def AltCore.modifyBody (f : FnBody → FnBody) : AltCore FnBody → Alt
352+
@[inline] def Alt.modifyBody (f : FnBody → FnBody) : Alt → Alt
353353
| Alt.ctor c b => Alt.ctor c (f b)
354354
| Alt.default b => Alt.default (f b)
355355

356-
@[inline] def AltCore.mmodifyBody {m : TypeType} [Monad m] (f : FnBody → m FnBody) : AltCore FnBody → m Alt
356+
@[inline] def Alt.mmodifyBody {m : TypeType} [Monad m] (f : FnBody → m FnBody) : Alt → m Alt
357357
| Alt.ctor c b => Alt.ctor c <$> f b
358358
| Alt.default b => Alt.default <$> f b
359359

src/Lean/Compiler/IR/ToIR.lean

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import Lean.Environment
1515

1616
namespace Lean.IR
1717

18-
open Lean.Compiler (LCNF.AltCore LCNF.Arg LCNF.Code LCNF.Decl LCNF.DeclValue LCNF.LCtx LCNF.LetDecl
18+
open Lean.Compiler (LCNF.Alt LCNF.Arg LCNF.Code LCNF.Decl LCNF.DeclValue LCNF.LCtx LCNF.LetDecl
1919
LCNF.LetValue LCNF.LitValue LCNF.Param LCNF.getMonoDecl?)
2020

2121
namespace ToIR
@@ -346,7 +346,7 @@ partial def lowerLet (decl : LCNF.LetDecl) (k : LCNF.Code) : M FnBody := do
346346
| some (.joinPoint ..) | none => panic! "unexpected value"
347347
| .erased => mkErased ()
348348

349-
partial def lowerAlt (discr : VarId) (a : LCNF.AltCore LCNF.Code) : M (AltCore FnBody) := do
349+
partial def lowerAlt (discr : VarId) (a : LCNF.Alt) : M Alt := do
350350
match a with
351351
| .alt ctorName params code =>
352352
let ⟨ctorInfo, fields⟩ ← getCtorInfo ctorName

src/Lean/Compiler/LCNF/Basic.lean

Lines changed: 29 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,6 @@ structure Param where
3333
def Param.toExpr (p : Param) : Expr :=
3434
.fvar p.fvarId
3535

36-
inductive AltCore (Code : Type) where
37-
| alt (ctorName : Name) (params : Array Param) (code : Code)
38-
| default (code : Code)
39-
deriving Inhabited
40-
4136
inductive LitValue where
4237
| natVal (val : Nat)
4338
| strVal (val : String)
@@ -136,42 +131,48 @@ structure LetDecl where
136131
value : LetValue
137132
deriving Inhabited, BEq
138133

139-
structure FunDeclCore (Code : Type) where
134+
mutual
135+
136+
inductive Alt where
137+
| alt (ctorName : Name) (params : Array Param) (code : Code)
138+
| default (code : Code)
139+
140+
structure FunDecl where
140141
fvarId : FVarId
141142
binderName : Name
142143
params : Array Param
143144
type : Expr
144145
value : Code
145-
deriving Inhabited
146-
147-
def FunDeclCore.getArity (decl : FunDeclCore Code) : Nat :=
148-
decl.params.size
149146

150-
structure CasesCore (Code : Type) where
147+
structure Cases where
151148
typeName : Name
152149
resultType : Expr
153150
discr : FVarId
154-
alts : Array (AltCore Code)
151+
alts : Array Alt
155152
deriving Inhabited
156153

157154
inductive Code where
158155
| let (decl : LetDecl) (k : Code)
159-
| fun (decl : FunDeclCore Code) (k : Code)
160-
| jp (decl : FunDeclCore Code) (k : Code)
156+
| fun (decl : FunDecl) (k : Code)
157+
| jp (decl : FunDecl) (k : Code)
161158
| jmp (fvarId : FVarId) (args : Array Arg)
162-
| cases (cases : CasesCore Code)
159+
| cases (cases : Cases)
163160
| return (fvarId : FVarId)
164161
| unreach (type : Expr)
165162
deriving Inhabited
166163

167-
abbrev Alt := AltCore Code
168-
abbrev FunDecl := FunDeclCore Code
169-
abbrev Cases := CasesCore Code
164+
end
165+
166+
deriving instance Inhabited for Alt
167+
deriving instance Inhabited for FunDecl
168+
169+
def FunDecl.getArity (decl : FunDecl) : Nat :=
170+
decl.params.size
170171

171172
/--
172173
Return the constructor names that have an explicit (non-default) alternative.
173174
-/
174-
def CasesCore.getCtorNames (c : Cases) : NameSet :=
175+
def Cases.getCtorNames (c : Cases) : NameSet :=
175176
c.alts.foldl (init := {}) fun ctorNames alt =>
176177
match alt with
177178
| .default _ => ctorNames
@@ -241,15 +242,15 @@ instance : BEq Code where
241242
instance : BEq FunDecl where
242243
beq := FunDecl.beq
243244

244-
def AltCore.getCode : Alt → Code
245+
def Alt.getCode : Alt → Code
245246
| .default k => k
246247
| .alt _ _ k => k
247248

248-
def AltCore.getParams : Alt → Array Param
249+
def Alt.getParams : Alt → Array Param
249250
| .default _ => #[]
250251
| .alt _ ps _ => ps
251252

252-
def AltCore.forCodeM [Monad m] (alt : Alt) (f : Code → m Unit) : m Unit := do
253+
def Alt.forCodeM [Monad m] (alt : Alt) (f : Code → m Unit) : m Unit := do
253254
match alt with
254255
| .default k => f k
255256
| .alt _ _ k => f k
@@ -259,14 +260,14 @@ private unsafe def updateAltCodeImp (alt : Alt) (k' : Code) : Alt :=
259260
| .default k => if ptrEq k k' then alt else .default k'
260261
| .alt ctorName ps k => if ptrEq k k' then alt else .alt ctorName ps k'
261262

262-
@[implemented_by updateAltCodeImp] opaque AltCore.updateCode (alt : Alt) (c : Code) : Alt
263+
@[implemented_by updateAltCodeImp] opaque Alt.updateCode (alt : Alt) (c : Code) : Alt
263264

264265
private unsafe def updateAltImp (alt : Alt) (ps' : Array Param) (k' : Code) : Alt :=
265266
match alt with
266267
| .alt ctorName ps k => if ptrEq k k' && ptrEq ps ps' then alt else .alt ctorName ps' k'
267268
| _ => unreachable!
268269

269-
@[implemented_by updateAltImp] opaque AltCore.updateAlt! (alt : Alt) (ps' : Array Param) (k' : Code) : Alt
270+
@[implemented_by updateAltImp] opaque Alt.updateAlt! (alt : Alt) (ps' : Array Param) (k' : Code) : Alt
270271

271272
@[inline] private unsafe def updateAltsImp (c : Code) (alts : Array Alt) : Code :=
272273
match c with
@@ -364,9 +365,9 @@ Low-level update `FunDecl` function. It does not update the local context.
364365
Consider using `FunDecl.update : LetDecl → Expr → Array Param → Code → CompilerM FunDecl` if you want the local context
365366
to be updated.
366367
-/
367-
@[implemented_by updateFunDeclCoreImp] opaque FunDeclCore.updateCore (decl: FunDecl) (type : Expr) (params : Array Param) (value : Code) : FunDecl
368+
@[implemented_by updateFunDeclCoreImp] opaque FunDecl.updateCore (decl : FunDecl) (type : Expr) (params : Array Param) (value : Code) : FunDecl
368369

369-
def CasesCore.extractAlt! (cases : Cases) (ctorName : Name) : Alt × Cases :=
370+
def Cases.extractAlt! (cases : Cases) (ctorName : Name) : Alt × Cases :=
370371
let found i := (cases.alts[i], { cases with alts := cases.alts.eraseIdx i })
371372
if let some i := cases.alts.findFinIdx? fun | .alt ctorName' .. => ctorName == ctorName' | _ => false then
372373
found i
@@ -375,7 +376,7 @@ def CasesCore.extractAlt! (cases : Cases) (ctorName : Name) : Alt × Cases :=
375376
else
376377
unreachable!
377378

378-
def AltCore.mapCodeM [Monad m] (alt : Alt) (f : Code → m Code) : m Alt := do
379+
def Alt.mapCodeM [Monad m] (alt : Alt) (f : Code → m Code) : m Alt := do
379380
return alt.updateCode (← f alt.getCode)
380381

381382
def Code.isDecl : Code → Bool
@@ -678,7 +679,7 @@ private partial def collectParams (ps : Array Param) (s : FVarIdSet) : FVarIdSet
678679
ps.foldl (init := s) fun s p => collectType p.type s
679680

680681
mutual
681-
partial def FunDeclCore.collectUsed (decl : FunDecl) (s : FVarIdSet := {}) : FVarIdSet :=
682+
partial def FunDecl.collectUsed (decl : FunDecl) (s : FVarIdSet := {}) : FVarIdSet :=
682683
decl.value.collectUsed <| collectParams decl.params <| collectType decl.type s
683684

684685
partial def Code.collectUsed (code : Code) (s : FVarIdSet := {}) : FVarIdSet :=

src/Lean/Compiler/LCNF/Bind.lean

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def isEtaExpandCandidateCore (type : Expr) (params : Array Param) : Bool :=
9999
let valueArity := params.size
100100
typeArity > valueArity
101101

102-
abbrev FunDeclCore.isEtaExpandCandidate (decl : FunDecl) : Bool :=
102+
abbrev FunDecl.isEtaExpandCandidate (decl : FunDecl) : Bool :=
103103
isEtaExpandCandidateCore decl.type decl.params
104104

105105
def etaExpandCore (type : Expr) (params : Array Param) (value : Code) : CompilerM (Array Param × Code) := do
@@ -118,7 +118,7 @@ def etaExpandCore? (type : Expr) (params : Array Param) (value : Code) : Compile
118118
else
119119
return none
120120

121-
def FunDeclCore.etaExpand (decl : FunDecl) : CompilerM FunDecl := do
121+
def FunDecl.etaExpand (decl : FunDecl) : CompilerM FunDecl := do
122122
let some (params, value) ← etaExpandCore? decl.type decl.params decl.value | return decl
123123
decl.update decl.type params value
124124

src/Lean/Compiler/LCNF/CompilerM.lean

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -395,20 +395,20 @@ private unsafe def updateLetDeclImp (decl : LetDecl) (type : Expr) (value : LetV
395395
def LetDecl.updateValue (decl : LetDecl) (value : LetValue) : CompilerM LetDecl :=
396396
decl.update decl.type value
397397

398-
private unsafe def updateFunDeclImp (decl: FunDecl) (type : Expr) (params : Array Param) (value : Code) : CompilerM FunDecl := do
398+
private unsafe def updateFunDeclImp (decl : FunDecl) (type : Expr) (params : Array Param) (value : Code) : CompilerM FunDecl := do
399399
if ptrEq type decl.type && ptrEq params decl.params && ptrEq value decl.value then
400400
return decl
401401
else
402402
let decl := { decl with type, params, value }
403403
modifyLCtx fun lctx => lctx.addFunDecl decl
404404
return decl
405405

406-
@[implemented_by updateFunDeclImp] opaque FunDeclCore.update (decl: FunDecl) (type : Expr) (params : Array Param) (value : Code) : CompilerM FunDecl
406+
@[implemented_by updateFunDeclImp] opaque FunDecl.update (decl : FunDecl) (type : Expr) (params : Array Param) (value : Code) : CompilerM FunDecl
407407

408-
abbrev FunDeclCore.update' (decl : FunDecl) (type : Expr) (value : Code) : CompilerM FunDecl :=
408+
abbrev FunDecl.update' (decl : FunDecl) (type : Expr) (value : Code) : CompilerM FunDecl :=
409409
decl.update type decl.params value
410410

411-
abbrev FunDeclCore.updateValue (decl : FunDecl) (value : Code) : CompilerM FunDecl :=
411+
abbrev FunDecl.updateValue (decl : FunDecl) (value : Code) : CompilerM FunDecl :=
412412
decl.update decl.type decl.params value
413413

414414
@[inline] def normParam [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m t] (p : Param) : m Param := do

src/Lean/Compiler/LCNF/InferType.lean

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ def Code.inferParamType (params : Array Param) (code : Code) : CompilerM Expr :=
262262
let xs := params.map fun p => .fvar p.fvarId
263263
InferType.mkForallFVars xs type |>.run {}
264264

265-
def AltCore.inferType (alt : Alt) : CompilerM Expr :=
265+
def Alt.inferType (alt : Alt) : CompilerM Expr :=
266266
alt.getCode.inferType
267267

268268
def mkAuxLetDecl (e : LetValue) (prefixName := `_x) : CompilerM LetDecl := do

src/Lean/Compiler/LCNF/Probing.lean

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ where
5454
| .fun decl k | .jp decl k =>
5555
go decl.value
5656
go k
57-
| .cases (cases : CasesCore Code) => cases.alts.forM (go ·.getCode)
57+
| .cases cs => cs.alts.forM (go ·.getCode)
5858
| .jmp .. | .return .. | .unreach .. => return ()
5959
start (decls : Array Decl) : StateRefT (Array LetValue) CompilerM Unit :=
6060
decls.forM (·.value.forCodeM go)

src/Lean/Compiler/LCNF/Renaming.lean

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def LetDecl.applyRenaming (decl : LetDecl) (r : Renaming) : CompilerM LetDecl :=
2929
return decl
3030

3131
mutual
32-
partial def FunDeclCore.applyRenaming (decl : FunDecl) (r : Renaming) : CompilerM FunDecl := do
32+
partial def FunDecl.applyRenaming (decl : FunDecl) (r : Renaming) : CompilerM FunDecl := do
3333
if let some binderName := r.find? decl.fvarId then
3434
let decl := { decl with binderName }
3535
modifyLCtx fun lctx => lctx.addFunDecl decl

src/Lean/Compiler/LCNF/ToExpr.lean

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ private def Arg.toExprM (arg : Arg) : ToExprM Expr :=
7979
return arg.toExpr.abstract' (← read) (← get)
8080

8181
mutual
82-
partial def FunDeclCore.toExprM (decl : FunDecl) : ToExprM Expr :=
82+
partial def FunDecl.toExprM (decl : FunDecl) : ToExprM Expr :=
8383
withParams decl.params do mkLambdaM decl.params (← decl.value.toExprM)
8484

8585
partial def Code.toExprM (code : Code) : ToExprM Expr := do
@@ -107,7 +107,7 @@ end
107107
def Code.toExpr (code : Code) (xs : Array FVarId := #[]) : Expr :=
108108
run' code.toExprM xs
109109

110-
def FunDeclCore.toExpr (decl : FunDecl) (xs : Array FVarId := #[]) : Expr :=
110+
def FunDecl.toExpr (decl : FunDecl) (xs : Array FVarId := #[]) : Expr :=
111111
run' decl.toExprM xs
112112

113113
end Lean.Compiler.LCNF

src/Lean/Compiler/LCNF/ToMono.lean

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def LetDecl.toMono (decl : LetDecl) : ToMonoM LetDecl := do
8383

8484
mutual
8585

86-
partial def FunDeclCore.toMono (decl : FunDecl) : ToMonoM FunDecl := do
86+
partial def FunDecl.toMono (decl : FunDecl) : ToMonoM FunDecl := do
8787
let type ← toMonoType decl.type
8888
let params ← decl.params.mapM (·.toMono)
8989
let value ← decl.value.toMono

0 commit comments

Comments
 (0)