@@ -5,6 +5,7 @@ Authors: Leonardo de Moura
55-/
66prelude
77import Lean.Compiler.ExternAttr
8+ import Lean.Compiler.ImplementedByAttr
89import Lean.Compiler.LCNF.MonoTypes
910import Lean.Compiler.LCNF.InferType
1011import Lean.Compiler.NoncomputableAttr
@@ -107,6 +108,23 @@ def LetDecl.toMono (decl : LetDecl) : ToMonoM LetDecl := do
107108 let value ← decl.value.toMono decl.fvarId
108109 decl.update type value
109110
111+ def mkFieldParamsForComputedFields (ctorType : Expr) (numParams : Nat) (numNewFields : Nat)
112+ (oldFields : Array Param) : ToMonoM (Array Param) := do
113+ let mut type := ctorType
114+ for _ in [0 :numParams] do
115+ match type with
116+ | .forallE _ _ body _ =>
117+ type := body
118+ | _ => unreachable!
119+ let mut newFields := Array.emptyWithCapacity (oldFields.size + numNewFields)
120+ for _ in [0 :numNewFields] do
121+ match type with
122+ | .forallE name fieldType body _ =>
123+ newFields := newFields.push (← mkParam name fieldType false )
124+ type := body
125+ | _ => unreachable!
126+ return newFields ++ oldFields
127+
110128mutual
111129
112130partial def FunDecl.toMono (decl : FunDecl) : ToMonoM FunDecl := do
@@ -278,12 +296,30 @@ partial def Code.toMono (code : Code) : ToMonoM Code := do
278296 else if let some info ← hasTrivialStructure? c.typeName then
279297 trivialStructToMono info c
280298 else
281- let type ← toMonoType c.resultType
282- let alts ← c.alts.mapM fun alt =>
283- match alt with
284- | .default k => return alt.updateCode (← k.toMono)
285- | .alt _ ps k => return alt.updateAlt! (← ps.mapM (·.toMono)) (← k.toMono)
286- return code.updateCases! type c.discr alts
299+ let resultType ← toMonoType c.resultType
300+ let env ← getEnv
301+ let some (.inductInfo inductInfo) := env.find? c.typeName | panic! "expected inductive type"
302+ let casesOnName := mkCasesOnName inductInfo.name
303+ if (getImplementedBy? env casesOnName).isSome then
304+ -- TODO: Enforce that this is only used for computed fields.
305+ let typeName := c.typeName ++ `_impl
306+ let alts ← c.alts.mapM fun alt => do
307+ match alt with
308+ | .default k => return alt.updateCode (← k.toMono)
309+ | .alt ctorName ps k =>
310+ let implCtorName := ctorName ++ `_impl
311+ let some (.ctorInfo ctorInfo) := env.find? implCtorName | panic! "expected constructor"
312+ let numNewFields := ctorInfo.numFields - ps.size
313+ let ps ← mkFieldParamsForComputedFields ctorInfo.type ctorInfo.numParams numNewFields ps
314+ let k ← k.toMono
315+ return .alt implCtorName ps k
316+ return .cases { discr := c.discr, resultType, typeName, alts }
317+ else
318+ let alts ← c.alts.mapM fun alt =>
319+ match alt with
320+ | .default k => return alt.updateCode (← k.toMono)
321+ | .alt _ ps k => return alt.updateAlt! (← ps.mapM (·.toMono)) (← k.toMono)
322+ return code.updateCases! resultType c.discr alts
287323
288324end
289325
0 commit comments