Skip to content

Commit 8dc7ac4

Browse files
committed
more
1 parent 9c5361d commit 8dc7ac4

File tree

6 files changed

+207
-108
lines changed

6 files changed

+207
-108
lines changed

src/Lean/Compiler/IR/ToIR.lean

Lines changed: 88 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,25 @@ public section
2020
namespace Lean.IR
2121

2222
open Lean.Compiler (LCNF.Alt LCNF.Arg LCNF.Code LCNF.Decl LCNF.DeclValue LCNF.LCtx LCNF.LetDecl
23-
LCNF.LetValue LCNF.LitValue LCNF.Param LCNF.getMonoDecl?)
23+
LCNF.LetValue LCNF.LitValue LCNF.Param LCNF.getMonoDecl? LCNF.FVarSubst LCNF.MonadFVarSubst
24+
LCNF.MonadFVarSubstState LCNF.addSubst LCNF.normLetValue)
2425

2526
namespace ToIR
2627

2728
structure BuilderState where
2829
vars : Std.HashMap FVarId Arg := {}
2930
joinPoints : Std.HashMap FVarId JoinPointId := {}
3031
nextId : Nat := 1
32+
subst : LCNF.FVarSubst := {}
3133

3234
abbrev M := StateRefT BuilderState CoreM
3335

36+
instance : LCNF.MonadFVarSubst M false where
37+
getSubst := return (← get).subst
38+
39+
instance : LCNF.MonadFVarSubstState M where
40+
modifySubst f := modify fun s => { s with subst := f s.subst }
41+
3442
def M.run (x : M α) : CoreM α := do
3543
x.run' {}
3644

@@ -119,41 +127,57 @@ partial def lowerCode (c : LCNF.Code) : M FnBody := do
119127
let joinPointId ← getJoinPointValue fvarId
120128
return .jmp joinPointId (← args.mapM lowerArg)
121129
| .cases cases =>
122-
-- `casesOn` for inductive predicates should have already been expanded.
123-
let .var varId := (← getFVarValue cases.discr) | unreachable!
124-
return .case cases.typeName
125-
varId
126-
(← nameToIRType cases.typeName)
127-
(← cases.alts.mapM (lowerAlt varId))
130+
if let some info ← hasTrivialStructure? cases.typeName then
131+
assert! cases.alts.size == 1
132+
let .alt ctorName ps k := cases.alts[0]! | unreachable!
133+
assert! ctorName == info.ctorName
134+
assert! info.fieldIdx < ps.size
135+
let p := ps[info.fieldIdx]!
136+
LCNF.addSubst p.fvarId (.fvar cases.discr)
137+
lowerCode k
138+
else
139+
-- `casesOn` for inductive predicates should have already been expanded.
140+
let .var varId := (← getFVarValue cases.discr) | unreachable!
141+
return .case cases.typeName
142+
varId
143+
(← nameToIRType cases.typeName)
144+
(← cases.alts.mapM (lowerAlt varId))
128145
| .return fvarId =>
129146
return .ret (← getFVarValue fvarId)
130147
| .unreach .. => return .unreachable
131148
| .fun .. => panic! "all local functions should be λ-lifted"
132149

133150
partial def lowerLet (decl : LCNF.LetDecl) (k : LCNF.Code) : M FnBody := do
134-
match decl.value with
151+
let value ← LCNF.normLetValue decl.value
152+
match value with
135153
| .lit litValue =>
136154
let var ← bindVar decl.fvarId
137155
let ⟨litValue, type⟩ := lowerLitValue litValue
138156
return .vdecl var type (.lit litValue) (← lowerCode k)
139157
| .proj typeName i fvarId =>
140-
match (← getFVarValue fvarId) with
141-
| .var varId =>
142-
let some (.inductInfo { ctors := [ctorName], .. }) := (← Lean.getEnv).find? typeName
143-
| panic! "projection of non-structure type"
144-
let ⟨ctorInfo, fields⟩ ← getCtorLayout ctorName
145-
let ⟨result, type⟩ := lowerProj varId ctorInfo fields[i]!
146-
match result with
147-
| .expr e =>
148-
let var ← bindVar decl.fvarId
149-
return .vdecl var type e (← lowerCode k)
158+
if let some info ← hasTrivialStructure? typeName then
159+
if info.fieldIdx == i then
160+
LCNF.addSubst decl.fvarId (.fvar fvarId)
161+
else
162+
bindErased decl.fvarId
163+
lowerCode k
164+
else
165+
match (← getFVarValue fvarId) with
166+
| .var varId =>
167+
let some (.inductInfo { ctors := [ctorName], .. }) := (← Lean.getEnv).find? typeName
168+
| panic! "projection of non-structure type"
169+
let ⟨ctorInfo, fields⟩ ← getCtorLayout ctorName
170+
let ⟨result, type⟩ := lowerProj varId ctorInfo fields[i]!
171+
match result with
172+
| .expr e =>
173+
let var ← bindVar decl.fvarId
174+
return .vdecl var type e (← lowerCode k)
175+
| .erased | .void =>
176+
bindErased decl.fvarId
177+
lowerCode k
150178
| .erased =>
151179
bindErased decl.fvarId
152180
lowerCode k
153-
| .void => unreachable!
154-
| .erased =>
155-
bindErased decl.fvarId
156-
lowerCode k
157181
| .const name _ args =>
158182
let irArgs ← args.mapM lowerArg
159183
if let some decl ← findDecl name then
@@ -163,43 +187,48 @@ partial def lowerLet (decl : LCNF.LetDecl) (k : LCNF.Code) : M FnBody := do
163187
let env ← Lean.getEnv
164188
match env.find? name with
165189
| some (.ctorInfo ctorVal) =>
166-
let type ← nameToIRType ctorVal.induct
167-
if type.isScalar then
168-
let var ← bindVar decl.fvarId
169-
return .vdecl var type (.lit (.num ctorVal.cidx)) (← lowerCode k)
170-
171-
let ⟨ctorInfo, fields⟩ ← getCtorLayout name
172-
let irArgs := irArgs.extract (start := ctorVal.numParams)
173-
if irArgs.size != fields.size then
174-
-- An overapplied constructor arises from compiler
175-
-- transformations on unreachable code
176-
return .unreachable
177-
178-
let objArgs : Array Arg ← do
179-
let mut result : Array Arg := #[]
180-
for h : i in *...fields.size do
181-
match fields[i] with
182-
| .object .. =>
183-
result := result.push irArgs[i]!
184-
| .usize .. | .scalar .. | .erased | .void => pure ()
185-
pure result
186-
let objVar ← bindVar decl.fvarId
187-
let rec lowerNonObjectFields (_ : Unit) : M FnBody :=
188-
let rec loop (i : Nat) : M FnBody := do
189-
match irArgs[i]? with
190-
| some (.var varId) =>
191-
match fields[i]! with
192-
| .usize usizeIdx =>
193-
let k ← loop (i + 1)
194-
return .uset objVar usizeIdx varId k
195-
| .scalar _ offset argType =>
196-
let k ← loop (i + 1)
197-
return .sset objVar (ctorInfo.size + ctorInfo.usize) offset varId argType k
198-
| .object .. | .erased | .void => loop (i + 1)
199-
| some .erased => loop (i + 1)
200-
| none => lowerCode k
201-
loop 0
202-
return .vdecl objVar ctorInfo.type (.ctor ctorInfo objArgs) (← lowerNonObjectFields ())
190+
if let some info ← hasTrivialStructure? ctorVal.induct then
191+
let arg := args[info.numParams + info.fieldIdx]!
192+
LCNF.addSubst decl.fvarId arg
193+
lowerCode k
194+
else
195+
let type ← nameToIRType ctorVal.induct
196+
if type.isScalar then
197+
let var ← bindVar decl.fvarId
198+
return .vdecl var type (.lit (.num ctorVal.cidx)) (← lowerCode k)
199+
200+
let ⟨ctorInfo, fields⟩ ← getCtorLayout name
201+
let irArgs := irArgs.extract (start := ctorVal.numParams)
202+
if irArgs.size != fields.size then
203+
-- An overapplied constructor arises from compiler
204+
-- transformations on unreachable code
205+
return .unreachable
206+
207+
let objArgs : Array Arg ← do
208+
let mut result : Array Arg := #[]
209+
for h : i in *...fields.size do
210+
match fields[i] with
211+
| .object .. =>
212+
result := result.push irArgs[i]!
213+
| .usize .. | .scalar .. | .erased | .void => pure ()
214+
pure result
215+
let objVar ← bindVar decl.fvarId
216+
let rec lowerNonObjectFields (_ : Unit) : M FnBody :=
217+
let rec loop (i : Nat) : M FnBody := do
218+
match irArgs[i]? with
219+
| some (.var varId) =>
220+
match fields[i]! with
221+
| .usize usizeIdx =>
222+
let k ← loop (i + 1)
223+
return .uset objVar usizeIdx varId k
224+
| .scalar _ offset argType =>
225+
let k ← loop (i + 1)
226+
return .sset objVar (ctorInfo.size + ctorInfo.usize) offset varId argType k
227+
| .object .. | .erased | .void => loop (i + 1)
228+
| some .erased => loop (i + 1)
229+
| none => lowerCode k
230+
loop 0
231+
return .vdecl objVar ctorInfo.type (.ctor ctorInfo objArgs) (← lowerNonObjectFields ())
203232
| some (.defnInfo ..) | some (.opaqueInfo ..) =>
204233
mkFap name irArgs
205234
| some (.axiomInfo ..) | .some (.quotInfo ..) | .some (.inductInfo ..) | .some (.thmInfo ..) =>

src/Lean/Compiler/IR/ToIRType.lean

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ public section
1717
namespace Lean
1818
namespace IR
1919

20-
open Lean.Compiler (LCNF.CacheExtension LCNF.isTypeFormerType LCNF.toLCNFType LCNF.toMonoType)
20+
open Lean.Compiler (LCNF.CacheExtension LCNF.isTypeFormerType LCNF.toLCNFType LCNF.toMonoType
21+
LCNF.TrivialStructureInfo LCNF.getOtherDeclBaseType LCNF.getParamTypes LCNF.instantiateForall
22+
LCNF.Irrelevant.hasTrivialStructure?)
2123

2224
def irTypeForEnum (numCtors : Nat) : IRType :=
2325
if numCtors == 1 then
@@ -34,6 +36,31 @@ def irTypeForEnum (numCtors : Nat) : IRType :=
3436
builtin_initialize irTypeExt : LCNF.CacheExtension Name IRType ←
3537
LCNF.CacheExtension.register
3638

39+
builtin_initialize trivialStructureInfoExt :
40+
LCNF.CacheExtension Name (Option LCNF.TrivialStructureInfo) ←
41+
LCNF.CacheExtension.register
42+
43+
def getRelevantCtorFields (ctorName : Name) : CoreM (Array Bool) := do
44+
let .ctorInfo info ← getConstInfo ctorName | unreachable!
45+
Meta.MetaM.run' do
46+
Meta.forallTelescopeReducing info.type fun xs _ => do
47+
let mut result := #[]
48+
for x in xs[info.numParams...*] do
49+
let type ← Meta.inferType x
50+
let cond := !(← Meta.isProp type <||> Meta.isTypeFormerType type <||> pure type.isRealWorld)
51+
result := result.push cond
52+
return result
53+
54+
/--
55+
The idea of this function is the same as in `ToMono`, however the notion of "irrelevancy" has
56+
changed because we now have the `void` type which can only be erased in impure context and thus at
57+
earliest at the conversion from mono to IR.
58+
-/
59+
def hasTrivialStructure? (declName : Name) : CoreM (Option LCNF.TrivialStructureInfo) := do
60+
let irrelevantType type :=
61+
Meta.isProp type <||> Meta.isTypeFormerType type <||> pure type.isRealWorld
62+
LCNF.Irrelevant.hasTrivialStructure? trivialStructureInfoExt irrelevantType declName
63+
3764
def nameToIRType (name : Name) : CoreM IRType := do
3865
match (← irTypeExt.find? name) with
3966
| some type => return type
@@ -86,13 +113,13 @@ private def isAnyProducingType (type : Lean.Expr) : Bool :=
86113
| .forallE _ _ b _ => isAnyProducingType b
87114
| _ => false
88115

89-
def toIRType (type : Lean.Expr) : CoreM IRType := do
116+
partial def toIRType (type : Lean.Expr) : CoreM IRType := do
90117
match type with
91-
| .const name _ => nameToIRType name
118+
| .const name _ => visitApp name #[]
92119
| .app .. =>
93120
-- All mono types are in headBeta form.
94121
let .const name _ := type.getAppFn | unreachable!
95-
nameToIRType name
122+
visitApp name type.getAppArgs
96123
| .forallE _ _ b _ =>
97124
-- Type formers are erased, but can be used polymorphically as
98125
-- an arrow type producing `lcAny`. The runtime representation of
@@ -104,6 +131,14 @@ def toIRType (type : Lean.Expr) : CoreM IRType := do
104131
return .object
105132
| .mdata _ b => toIRType b
106133
| _ => unreachable!
134+
where
135+
visitApp (declName : Name) (args : Array Lean.Expr) : CoreM IRType := do
136+
if let some info ← hasTrivialStructure? declName then
137+
let ctorType ← LCNF.getOtherDeclBaseType info.ctorName []
138+
let monoType ← LCNF.toMonoType (LCNF.getParamTypes (← LCNF.instantiateForall ctorType args[*...info.numParams]))[info.fieldIdx]!
139+
toIRType monoType
140+
else
141+
nameToIRType declName
107142

108143
inductive CtorFieldInfo where
109144
| erased

src/Lean/Compiler/LCNF.lean

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,5 +44,6 @@ public import Lean.Compiler.LCNF.Closure
4444
public import Lean.Compiler.LCNF.LambdaLifting
4545
public import Lean.Compiler.LCNF.ReduceArity
4646
public import Lean.Compiler.LCNF.Probing
47+
public import Lean.Compiler.LCNF.Irrelevant
4748

4849
public section
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
/-
2+
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
3+
Released under Apache 2.0 license as described in the file LICENSE.
4+
Authors: Henrik Böving
5+
-/
6+
module
7+
8+
prelude
9+
public import Lean.Compiler.LCNF.CompilerM
10+
import Lean.Compiler.LCNF.BaseTypes
11+
import Lean.Compiler.LCNF.Util
12+
13+
namespace Lean.Compiler.LCNF
14+
15+
/--
16+
Given a constructor, return a bitmask `m` s.t. `m[i]` is true if field `i` is
17+
computationally relevant.
18+
-/
19+
def getRelevantCtorFields (ctorName : Name) (trivialType : Expr → MetaM Bool) :
20+
CoreM (Array Bool) := do
21+
let .ctorInfo info ← getConstInfo ctorName | unreachable!
22+
Meta.MetaM.run' do
23+
Meta.forallTelescopeReducing info.type fun xs _ => do
24+
let mut result := #[]
25+
for x in xs[info.numParams...*] do
26+
let type ← Meta.inferType x
27+
result := result.push !(← trivialType type)
28+
return result
29+
30+
/--
31+
We say a structure has a trivial structure if it has not builtin support in the runtime,
32+
it has only one constructor, and this constructor has only one relevant field.
33+
-/
34+
public structure TrivialStructureInfo where
35+
ctorName : Name
36+
numParams : Nat
37+
fieldIdx : Nat
38+
deriving Inhabited, Repr
39+
40+
/--
41+
Return `some fieldIdx` if `declName` is the name of an inductive datatype s.t.
42+
- It does not have builtin support in the runtime.
43+
- It has only one constructor.
44+
- This constructor has only one computationally relevant field.
45+
-/
46+
public def Irrelevant.hasTrivialStructure?
47+
(cacheExt : CacheExtension Name (Option TrivialStructureInfo))
48+
(trivialType : Expr → MetaM Bool) (declName : Name) : CoreM (Option TrivialStructureInfo) := do
49+
match (← cacheExt.find? declName) with
50+
| some info? => return info?
51+
| none =>
52+
let info? ← fillCache
53+
cacheExt.insert declName info?
54+
return info?
55+
where fillCache : CoreM (Option TrivialStructureInfo) := do
56+
if isRuntimeBuiltinType declName then return none
57+
let .inductInfo info ← getConstInfo declName | return none
58+
if info.isUnsafe || info.isRec then return none
59+
let [ctorName] := info.ctors | return none
60+
let ctorType ← getOtherDeclBaseType ctorName []
61+
if ctorType.isErased then return none
62+
let mask ← getRelevantCtorFields ctorName trivialType
63+
let mut result := none
64+
for h : i in *...mask.size do
65+
if mask[i] then
66+
if result.isSome then return none
67+
result := some { ctorName, fieldIdx := i, numParams := info.numParams }
68+
return result
69+
70+
71+
end Lean.Compiler.LCNF

0 commit comments

Comments
 (0)