@@ -19,17 +19,23 @@ def findStructCtorInfo? (typeName : Name) : CoreM (Option ConstructorVal) := do
1919 let some (.ctorInfo ctorInfo) := (← getEnv).find? ctorName | return none
2020 return ctorInfo
2121
22- def mkFieldParamsForCtorType (e : Expr) (numParams : Nat): CompilerM (Array Param) := do
23- let rec loop (params : Array Param) (e : Expr) (numParams : Nat): CompilerM (Array Param) := do
24- match e with
25- | .forallE name type body _ =>
26- if numParams == 0 then
27- let param ← mkParam name (← toMonoType type) false
28- loop (params.push param) body numParams
29- else
30- loop params body (numParams - 1 )
31- | _ => return params
32- loop #[] e numParams
22+ def mkFieldParamsForCtorType (ctorType : Expr) (numParams : Nat) (numFields : Nat)
23+ : CompilerM (Array Param) := do
24+ let mut type := ctorType
25+ for _ in [0 :numParams] do
26+ match type with
27+ | .forallE _ _ body _ =>
28+ type := body
29+ | _ => unreachable!
30+ let mut fields := Array.emptyWithCapacity numFields
31+ for _ in [0 :numFields] do
32+ match type with
33+ | .forallE name fieldType body _ =>
34+ let param ← mkParam name (← toMonoType fieldType) false
35+ fields := fields.push param
36+ type := body
37+ | _ => unreachable!
38+ return fields
3339
3440structure State where
3541 projMap : Std.HashMap FVarId (Array FVarId) := {}
@@ -57,8 +63,7 @@ partial def visitCode (code : Code) : M Code := do
5763 visitCode k
5864 else
5965 let some ctorInfo ← findStructCtorInfo? typeName | panic! "expected struct constructor"
60- let params ← mkFieldParamsForCtorType ctorInfo.type ctorInfo.numParams
61- assert! params.size == ctorInfo.numFields
66+ let params ← mkFieldParamsForCtorType ctorInfo.type ctorInfo.numParams ctorInfo.numFields
6267 let fvars := params.map (·.fvarId)
6368 modify fun s => { s with projMap := s.projMap.insert base fvars,
6469 fvarMap := s.fvarMap.insert decl.fvarId fvars[i]! }
0 commit comments