Skip to content

Commit fa01115

Browse files
committed
fix: enable more optimizations on inductives with computed fields in the new compiler
1 parent 8aa003b commit fa01115

File tree

2 files changed

+44
-37
lines changed

2 files changed

+44
-37
lines changed

src/Lean/Compiler/LCNF/ToLCNF.lean

Lines changed: 1 addition & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -568,33 +568,6 @@ where
568568
let result := .fvar auxDecl.fvarId
569569
mkOverApplication result args casesInfo.arity
570570

571-
visitCasesImplementedBy (casesInfo : CasesInfo) (f : Expr) (args : Array Expr) : M Arg := do
572-
let mut args := args
573-
let discr := args[casesInfo.discrPos]!
574-
if discr matches .fvar _ then
575-
let typeName := casesInfo.declName.getPrefix
576-
let .inductInfo indVal ← getConstInfo typeName | unreachable!
577-
args ← args.mapIdxM fun i arg => do
578-
unless casesInfo.altsRange.start <= i && i < casesInfo.altsRange.stop do return arg
579-
let altIdx := i - casesInfo.altsRange.start
580-
let numParams := casesInfo.altNumParams[altIdx]!
581-
let ctorName := indVal.ctors[altIdx]!
582-
583-
-- We simplify `casesOn` arguments that simply reconstruct the discriminant and replace
584-
-- them with the actual discriminant. This is required for hash consing to work correctly,
585-
-- and should eventually be fixed by changing the elaborated term to use the original
586-
-- variable.
587-
Meta.MetaM.run' <| Meta.lambdaBoundedTelescope arg numParams fun paramExprs body => do
588-
let fn := body.getAppFn
589-
let args := body.getAppArgs
590-
let args := args.map fun arg =>
591-
if arg.getAppFn.constName? == some ctorName && arg.getAppArgs == paramExprs then
592-
discr
593-
else
594-
arg
595-
Meta.mkLambdaFVars paramExprs (mkAppN fn args)
596-
visitAppDefaultConst f args
597-
598571
visitCtor (arity : Nat) (e : Expr) : M Arg :=
599572
etaIfUnderApplied e arity do
600573
visitAppDefaultConst e.getAppFn e.getAppArgs
@@ -715,10 +688,7 @@ where
715688
else if declName == ``False.rec || declName == ``Empty.rec || declName == ``False.casesOn || declName == ``Empty.casesOn then
716689
visitFalseRec e
717690
else if let some casesInfo ← getCasesInfo? declName then
718-
if (getImplementedBy? (← getEnv) declName).isSome then
719-
e.withApp (visitCasesImplementedBy casesInfo)
720-
else
721-
visitCases casesInfo e
691+
visitCases casesInfo e
722692
else if let some arity ← getCtorArity? declName then
723693
visitCtor arity e
724694
else if isNoConfusion (← getEnv) declName then

src/Lean/Compiler/LCNF/ToMono.lean

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ Authors: Leonardo de Moura
55
-/
66
prelude
77
import Lean.Compiler.ExternAttr
8+
import Lean.Compiler.ImplementedByAttr
89
import Lean.Compiler.LCNF.MonoTypes
910
import Lean.Compiler.LCNF.InferType
1011
import Lean.Compiler.NoncomputableAttr
@@ -107,6 +108,24 @@ def LetDecl.toMono (decl : LetDecl) : ToMonoM LetDecl := do
107108
let value ← decl.value.toMono decl.fvarId
108109
decl.update type value
109110

111+
def mkFieldParamsForComputedFields (ctorType : Expr) (numParams : Nat) (numNewFields : Nat)
112+
(oldFields : Array Param) : ToMonoM (Array Param) := do
113+
let mut type := ctorType
114+
for _ in [0:numParams] do
115+
match type with
116+
| .forallE _ _ body _ =>
117+
type := body
118+
| _ => unreachable!
119+
let mut newFields := Array.emptyWithCapacity (oldFields.size + numNewFields)
120+
for _ in [0:numNewFields] do
121+
match type with
122+
| .forallE name fieldType body _ =>
123+
let param ← mkParam name (← toMonoType fieldType) false
124+
newFields := newFields.push param
125+
type := body
126+
| _ => unreachable!
127+
return newFields ++ oldFields
128+
110129
mutual
111130

112131
partial def FunDecl.toMono (decl : FunDecl) : ToMonoM FunDecl := do
@@ -278,12 +297,30 @@ partial def Code.toMono (code : Code) : ToMonoM Code := do
278297
else if let some info ← hasTrivialStructure? c.typeName then
279298
trivialStructToMono info c
280299
else
281-
let type ← toMonoType c.resultType
282-
let alts ← c.alts.mapM fun alt =>
283-
match alt with
284-
| .default k => return alt.updateCode (← k.toMono)
285-
| .alt _ ps k => return alt.updateAlt! (← ps.mapM (·.toMono)) (← k.toMono)
286-
return code.updateCases! type c.discr alts
300+
let resultType ← toMonoType c.resultType
301+
let env ← getEnv
302+
let some (.inductInfo inductInfo) := env.find? c.typeName | panic! "expected inductive type"
303+
let casesOnName := mkCasesOnName inductInfo.name
304+
if (getImplementedBy? env casesOnName).isSome then
305+
-- TODO: Enforce that this is only used for computed fields.
306+
let typeName := c.typeName ++ `_impl
307+
let alts ← c.alts.mapM fun alt => do
308+
match alt with
309+
| .default k => return alt.updateCode (← k.toMono)
310+
| .alt ctorName ps k =>
311+
let implCtorName := ctorName ++ `_impl
312+
let some (.ctorInfo ctorInfo) := env.find? implCtorName | panic! "expected constructor"
313+
let numNewFields := ctorInfo.numFields - ps.size
314+
let ps ← mkFieldParamsForComputedFields ctorInfo.type ctorInfo.numParams numNewFields ps
315+
let k ← k.toMono
316+
return .alt implCtorName ps k
317+
return .cases { discr := c.discr, resultType, typeName, alts }
318+
else
319+
let alts ← c.alts.mapM fun alt =>
320+
match alt with
321+
| .default k => return alt.updateCode (← k.toMono)
322+
| .alt _ ps k => return alt.updateAlt! (← ps.mapM (·.toMono)) (← k.toMono)
323+
return code.updateCases! resultType c.discr alts
287324

288325
end
289326

0 commit comments

Comments
 (0)