@@ -5,6 +5,7 @@ Authors: Leonardo de Moura
55-/
66prelude
77import Lean.Compiler.ExternAttr
8+ import Lean.Compiler.ImplementedByAttr
89import Lean.Compiler.LCNF.MonoTypes
910import Lean.Compiler.LCNF.InferType
1011import Lean.Compiler.NoncomputableAttr
@@ -107,6 +108,24 @@ def LetDecl.toMono (decl : LetDecl) : ToMonoM LetDecl := do
107108 let value ← decl.value.toMono decl.fvarId
108109 decl.update type value
109110
111+ def mkFieldParamsForComputedFields (ctorType : Expr) (numParams : Nat) (numNewFields : Nat)
112+ (oldFields : Array Param) : ToMonoM (Array Param) := do
113+ let mut type := ctorType
114+ for _ in [0 :numParams] do
115+ match type with
116+ | .forallE _ _ body _ =>
117+ type := body
118+ | _ => unreachable!
119+ let mut newFields := Array.emptyWithCapacity (oldFields.size + numNewFields)
120+ for _ in [0 :numNewFields] do
121+ match type with
122+ | .forallE name fieldType body _ =>
123+ let param ← mkParam name (← toMonoType fieldType) false
124+ newFields := newFields.push param
125+ type := body
126+ | _ => unreachable!
127+ return newFields ++ oldFields
128+
110129mutual
111130
112131partial def FunDecl.toMono (decl : FunDecl) : ToMonoM FunDecl := do
@@ -278,12 +297,30 @@ partial def Code.toMono (code : Code) : ToMonoM Code := do
278297 else if let some info ← hasTrivialStructure? c.typeName then
279298 trivialStructToMono info c
280299 else
281- let type ← toMonoType c.resultType
282- let alts ← c.alts.mapM fun alt =>
283- match alt with
284- | .default k => return alt.updateCode (← k.toMono)
285- | .alt _ ps k => return alt.updateAlt! (← ps.mapM (·.toMono)) (← k.toMono)
286- return code.updateCases! type c.discr alts
300+ let resultType ← toMonoType c.resultType
301+ let env ← getEnv
302+ let some (.inductInfo inductInfo) := env.find? c.typeName | panic! "expected inductive type"
303+ let casesOnName := mkCasesOnName inductInfo.name
304+ if (getImplementedBy? env casesOnName).isSome then
305+ -- TODO: Enforce that this is only used for computed fields.
306+ let typeName := c.typeName ++ `_impl
307+ let alts ← c.alts.mapM fun alt => do
308+ match alt with
309+ | .default k => return alt.updateCode (← k.toMono)
310+ | .alt ctorName ps k =>
311+ let implCtorName := ctorName ++ `_impl
312+ let some (.ctorInfo ctorInfo) := env.find? implCtorName | panic! "expected constructor"
313+ let numNewFields := ctorInfo.numFields - ps.size
314+ let ps ← mkFieldParamsForComputedFields ctorInfo.type ctorInfo.numParams numNewFields ps
315+ let k ← k.toMono
316+ return .alt implCtorName ps k
317+ return .cases { discr := c.discr, resultType, typeName, alts }
318+ else
319+ let alts ← c.alts.mapM fun alt =>
320+ match alt with
321+ | .default k => return alt.updateCode (← k.toMono)
322+ | .alt _ ps k => return alt.updateAlt! (← ps.mapM (·.toMono)) (← k.toMono)
323+ return code.updateCases! resultType c.discr alts
287324
288325end
289326
0 commit comments