@@ -20,17 +20,25 @@ public section
2020namespace Lean.IR
2121
2222open Lean.Compiler (LCNF.Alt LCNF.Arg LCNF.Code LCNF.Decl LCNF.DeclValue LCNF.LCtx LCNF.LetDecl
23- LCNF.LetValue LCNF.LitValue LCNF.Param LCNF.getMonoDecl?)
23+ LCNF.LetValue LCNF.LitValue LCNF.Param LCNF.getMonoDecl? LCNF.FVarSubst LCNF.MonadFVarSubst
24+ LCNF.MonadFVarSubstState LCNF.addSubst LCNF.normLetValue)
2425
2526namespace ToIR
2627
2728structure BuilderState where
2829 vars : Std.HashMap FVarId Arg := {}
2930 joinPoints : Std.HashMap FVarId JoinPointId := {}
3031 nextId : Nat := 1
32+ subst : LCNF.FVarSubst := {}
3133
3234abbrev M := StateRefT BuilderState CoreM
3335
36+ instance : LCNF.MonadFVarSubst M false where
37+ getSubst := return (← get).subst
38+
39+ instance : LCNF.MonadFVarSubstState M where
40+ modifySubst f := modify fun s => { s with subst := f s.subst }
41+
3442def M.run (x : M α) : CoreM α := do
3543 x.run' {}
3644
@@ -119,41 +127,57 @@ partial def lowerCode (c : LCNF.Code) : M FnBody := do
119127 let joinPointId ← getJoinPointValue fvarId
120128 return .jmp joinPointId (← args.mapM lowerArg)
121129 | .cases cases =>
122- -- `casesOn` for inductive predicates should have already been expanded.
123- let .var varId := (← getFVarValue cases.discr) | unreachable!
124- return .case cases.typeName
125- varId
126- (← nameToIRType cases.typeName)
127- (← cases.alts.mapM (lowerAlt varId))
130+ if let some info ← hasTrivialStructure? cases.typeName then
131+ assert! cases.alts.size == 1
132+ let .alt ctorName ps k := cases.alts[0 ]! | unreachable!
133+ assert! ctorName == info.ctorName
134+ assert! info.fieldIdx < ps.size
135+ let p := ps[info.fieldIdx]!
136+ LCNF.addSubst p.fvarId (.fvar cases.discr)
137+ lowerCode k
138+ else
139+ -- `casesOn` for inductive predicates should have already been expanded.
140+ let .var varId := (← getFVarValue cases.discr) | unreachable!
141+ return .case cases.typeName
142+ varId
143+ (← nameToIRType cases.typeName)
144+ (← cases.alts.mapM (lowerAlt varId))
128145 | .return fvarId =>
129146 return .ret (← getFVarValue fvarId)
130147 | .unreach .. => return .unreachable
131148 | .fun .. => panic! "all local functions should be λ-lifted"
132149
133150partial def lowerLet (decl : LCNF.LetDecl) (k : LCNF.Code) : M FnBody := do
134- match decl.value with
151+ let value ← LCNF.normLetValue decl.value
152+ match value with
135153 | .lit litValue =>
136154 let var ← bindVar decl.fvarId
137155 let ⟨litValue, type⟩ := lowerLitValue litValue
138156 return .vdecl var type (.lit litValue) (← lowerCode k)
139157 | .proj typeName i fvarId =>
140- match (← getFVarValue fvarId) with
141- | .var varId =>
142- let some (.inductInfo { ctors := [ctorName], .. }) := (← Lean.getEnv).find? typeName
143- | panic! "projection of non-structure type"
144- let ⟨ctorInfo, fields⟩ ← getCtorLayout ctorName
145- let ⟨result, type⟩ := lowerProj varId ctorInfo fields[i]!
146- match result with
147- | .expr e =>
148- let var ← bindVar decl.fvarId
149- return .vdecl var type e (← lowerCode k)
158+ if let some info ← hasTrivialStructure? typeName then
159+ if info.fieldIdx == i then
160+ LCNF.addSubst decl.fvarId (.fvar fvarId)
161+ else
162+ bindErased decl.fvarId
163+ lowerCode k
164+ else
165+ match (← getFVarValue fvarId) with
166+ | .var varId =>
167+ let some (.inductInfo { ctors := [ctorName], .. }) := (← Lean.getEnv).find? typeName
168+ | panic! "projection of non-structure type"
169+ let ⟨ctorInfo, fields⟩ ← getCtorLayout ctorName
170+ let ⟨result, type⟩ := lowerProj varId ctorInfo fields[i]!
171+ match result with
172+ | .expr e =>
173+ let var ← bindVar decl.fvarId
174+ return .vdecl var type e (← lowerCode k)
175+ | .erased | .void =>
176+ bindErased decl.fvarId
177+ lowerCode k
150178 | .erased =>
151179 bindErased decl.fvarId
152180 lowerCode k
153- | .void => unreachable!
154- | .erased =>
155- bindErased decl.fvarId
156- lowerCode k
157181 | .const name _ args =>
158182 let irArgs ← args.mapM lowerArg
159183 if let some decl ← findDecl name then
@@ -163,43 +187,48 @@ partial def lowerLet (decl : LCNF.LetDecl) (k : LCNF.Code) : M FnBody := do
163187 let env ← Lean.getEnv
164188 match env.find? name with
165189 | some (.ctorInfo ctorVal) =>
166- let type ← nameToIRType ctorVal.induct
167- if type.isScalar then
168- let var ← bindVar decl.fvarId
169- return .vdecl var type (.lit (.num ctorVal.cidx)) (← lowerCode k)
170-
171- let ⟨ctorInfo, fields⟩ ← getCtorLayout name
172- let irArgs := irArgs.extract (start := ctorVal.numParams)
173- if irArgs.size != fields.size then
174- -- An overapplied constructor arises from compiler
175- -- transformations on unreachable code
176- return .unreachable
177-
178- let objArgs : Array Arg ← do
179- let mut result : Array Arg := #[]
180- for h : i in *...fields.size do
181- match fields[i] with
182- | .object .. =>
183- result := result.push irArgs[i]!
184- | .usize .. | .scalar .. | .erased | .void => pure ()
185- pure result
186- let objVar ← bindVar decl.fvarId
187- let rec lowerNonObjectFields (_ : Unit) : M FnBody :=
188- let rec loop (i : Nat) : M FnBody := do
189- match irArgs[i]? with
190- | some (.var varId) =>
191- match fields[i]! with
192- | .usize usizeIdx =>
193- let k ← loop (i + 1 )
194- return .uset objVar usizeIdx varId k
195- | .scalar _ offset argType =>
196- let k ← loop (i + 1 )
197- return .sset objVar (ctorInfo.size + ctorInfo.usize) offset varId argType k
198- | .object .. | .erased | .void => loop (i + 1 )
199- | some .erased => loop (i + 1 )
200- | none => lowerCode k
201- loop 0
202- return .vdecl objVar ctorInfo.type (.ctor ctorInfo objArgs) (← lowerNonObjectFields ())
190+ if let some info ← hasTrivialStructure? ctorVal.induct then
191+ let arg := args[info.numParams + info.fieldIdx]!
192+ LCNF.addSubst decl.fvarId arg
193+ lowerCode k
194+ else
195+ let type ← nameToIRType ctorVal.induct
196+ if type.isScalar then
197+ let var ← bindVar decl.fvarId
198+ return .vdecl var type (.lit (.num ctorVal.cidx)) (← lowerCode k)
199+
200+ let ⟨ctorInfo, fields⟩ ← getCtorLayout name
201+ let irArgs := irArgs.extract (start := ctorVal.numParams)
202+ if irArgs.size != fields.size then
203+ -- An overapplied constructor arises from compiler
204+ -- transformations on unreachable code
205+ return .unreachable
206+
207+ let objArgs : Array Arg ← do
208+ let mut result : Array Arg := #[]
209+ for h : i in *...fields.size do
210+ match fields[i] with
211+ | .object .. =>
212+ result := result.push irArgs[i]!
213+ | .usize .. | .scalar .. | .erased | .void => pure ()
214+ pure result
215+ let objVar ← bindVar decl.fvarId
216+ let rec lowerNonObjectFields (_ : Unit) : M FnBody :=
217+ let rec loop (i : Nat) : M FnBody := do
218+ match irArgs[i]? with
219+ | some (.var varId) =>
220+ match fields[i]! with
221+ | .usize usizeIdx =>
222+ let k ← loop (i + 1 )
223+ return .uset objVar usizeIdx varId k
224+ | .scalar _ offset argType =>
225+ let k ← loop (i + 1 )
226+ return .sset objVar (ctorInfo.size + ctorInfo.usize) offset varId argType k
227+ | .object .. | .erased | .void => loop (i + 1 )
228+ | some .erased => loop (i + 1 )
229+ | none => lowerCode k
230+ loop 0
231+ return .vdecl objVar ctorInfo.type (.ctor ctorInfo objArgs) (← lowerNonObjectFields ())
203232 | some (.defnInfo ..) | some (.opaqueInfo ..) =>
204233 mkFap name irArgs
205234 | some (.axiomInfo ..) | .some (.quotInfo ..) | .some (.inductInfo ..) | .some (.thmInfo ..) =>
0 commit comments