@@ -85,19 +85,23 @@ def unfoldNamedPattern (e : Expr) : MetaM Expr := do
8585
8686 This can be used to use the alternative of a match expression in its splitter.
8787 -/
88- partial def forallAltVarsTelescope (altType : Expr) (altNumParams numDiscrEqs : Nat )
88+ partial def forallAltVarsTelescope (altType : Expr) (altInfo : AltParamInfo )
8989 (k : (patVars : Array Expr) → (args : Array Expr) → (mask : Array Bool) → (type : Expr) → MetaM α) : MetaM α := do
90- go #[] #[] #[] 0 altType
90+ assert! altInfo.numOverlaps = 0
91+ if altInfo.hasUnitThunk then
92+ let type ← whnfForall altType
93+ let type ← Match.unfoldNamedPattern type
94+ let type ← instantiateForall type #[mkConst ``Unit.unit]
95+ k #[] #[mkConst ``Unit.unit] #[false ] type
96+ else
97+ go #[] #[] #[] 0 altType
9198where
9299 go (ys : Array Expr) (args : Array Expr) (mask : Array Bool) (i : Nat) (type : Expr) : MetaM α := do
93100 let type ← whnfForall type
94- if i < altNumParams - numDiscrEqs then
95- let Expr.forallE n d b .. := type
96- | throwError "expecting {altNumParams} parameters, excluding {numDiscrEqs} equalities, but found type{indentExpr altType}"
97101
98- -- Handle the special case of `Unit` parameters.
99- if i = 0 && altNumParams - numDiscrEqs = 1 && d.isConstOf ``Unit && !b.hasLooseBVars then
100- return ← k #[] #[mkConst ``Unit.unit] #[ false ] b
102+ if i < altInfo.numFields then
103+ let Expr.forallE n d b .. := type
104+ | throwError "expecting {altInfo.numFields} parameters, but found type{indentExpr altType}"
101105
102106 let d ← Match.unfoldNamedPattern d
103107 withLocalDeclD n d fun y => do
@@ -145,17 +149,17 @@ where
145149
146150 This can be used to use the alternative of a match expression in its splitter.
147151 -/
148- partial def forallAltTelescope (altType : Expr) (altNumParams numDiscrEqs : Nat)
152+ partial def forallAltTelescope (altType : Expr) (altInfo : AltParamInfo) ( numDiscrEqs : Nat)
149153 (k : (ys : Array Expr) → (eqs : Array Expr) → (args : Array Expr) → (mask : Array Bool) → (type : Expr) → MetaM α)
150154 : MetaM α := do
151- forallAltVarsTelescope altType altNumParams numDiscrEqs fun ys args mask altType => do
155+ forallAltVarsTelescope altType altInfo fun ys args mask altType => do
152156 go ys #[] args mask 0 altType
153157where
154158 go (ys : Array Expr) (eqs : Array Expr) (args : Array Expr) (mask : Array Bool) (i : Nat) (type : Expr) : MetaM α := do
155159 let type ← whnfForall type
156160 if i < numDiscrEqs then
157161 let Expr.forallE n d b .. := type
158- | throwError "expecting {altNumParams} parameters, including { numDiscrEqs} equalities, but found type{indentExpr altType}"
162+ | throwError "expecting {numDiscrEqs} equalities, but found type{indentExpr altType}"
159163 let arg ← if let some (_, _, rhs) ← matchEq? d then
160164 mkEqRefl rhs
161165 else if let some (_, _, _, rhs) ← matchHEq? d then
@@ -270,16 +274,17 @@ private partial def withSplitterAlts (altTypes : Array Expr) (f : Array Expr →
270274 f xs
271275 go 0 #[]
272276
273- private abbrev ConvertM := ReaderT (FVarIdMap (Expr × Nat × Array Bool)) $ StateRefT (Array MVarId) MetaM
277+ private abbrev ConvertM := ReaderT (FVarIdMap (Expr × AltParamInfo × Array Bool)) $ StateRefT (Array MVarId) MetaM
274278
275279/--
276280 Construct a proof for the splitter generated by `mkEquationsFor`.
277281 The proof uses the definition of the `match`-declaration as a template (argument `template`).
278282 - `alts` are free variables corresponding to alternatives of the `match` auxiliary declaration being processed.
279283 - `altNews` are the new free variables which contains additional hypotheses that ensure they are only used
280- when the previous overlapping alternatives are not applicable. -/
284+ when the previous overlapping alternatives are not applicable.
285+ - `altInfos` refers to the splitter -/
281286private partial def mkSplitterProof (matchDeclName : Name) (template : Expr) (alts altsNew : Array Expr)
282- (altsNewNumParams : Array Nat ) (altArgMasks : Array (Array Bool)) (numDiscrEqs : Nat ) : MetaM Expr := do
287+ (altInfos : Array AltParamInfo ) (altArgMasks : Array (Array Bool)) : MetaM Expr := do
283288 trace[Meta.Match.matchEqs] "proof template: {template}"
284289 let map := mkMap
285290 let (proof, mvarIds) ← convertTemplate template |>.run map |>.run #[]
@@ -289,10 +294,10 @@ private partial def mkSplitterProof (matchDeclName : Name) (template : Expr) (al
289294 solveOverlap mvarId
290295 instantiateMVars proof
291296where
292- mkMap : FVarIdMap (Expr × Nat × Array Bool) := Id.run do
297+ mkMap : FVarIdMap (Expr × AltParamInfo × Array Bool) := Id.run do
293298 let mut m := {}
294- for alt in alts, altNew in altsNew, numParams in altsNewNumParams , argMask in altArgMasks do
295- m := m.insert alt.fvarId! (altNew, numParams , argMask)
299+ for alt in alts, altNew in altsNew, altInfo in altInfos , argMask in altArgMasks do
300+ m := m.insert alt.fvarId! (altNew, altInfo , argMask)
296301 return m
297302
298303 trimFalseTrail (argMask : Array Bool) : Array Bool :=
@@ -452,9 +457,9 @@ where
452457 return .done (← convertCastEqRec e)
453458 else
454459 let Expr.fvar fvarId .. := e.getAppFn | return .continue
455- let some (altNew, numParams , argMask) := (← read).get? fvarId | return .continue
456- trace[Meta.Match.matchEqs] ">> argMask: {argMask}, numParams : {numParams }, e: {e}, alsNew: {altNew}, "
457- if numParams + numDiscrEqs = 0 then
460+ let some (altNew, altParamInfo , argMask) := (← read).get? fvarId | return .continue
461+ trace[Meta.Match.matchEqs] ">> argMask: {argMask}, altParamInfo : {repr altParamInfo }, e: {e}, alsNew: {altNew}, "
462+ if altParamInfo.hasUnitThunk then
458463 let eNew := mkApp altNew (mkConst ``Unit.unit)
459464 return TransformStep.done eNew
460465 let mut newArgs := #[]
@@ -465,8 +470,7 @@ where
465470 if includeArg then
466471 newArgs := newArgs.push arg
467472 let eNew := mkAppN altNew newArgs
468- /- Recall that `numParams` does not include the `numDiscrEqs` equalities associated with discriminants of the form `h : discr`. -/
469- let (mvars, _, _) ← forallMetaBoundedTelescope (← inferType eNew) (numParams - newArgs.size) (kind := MetavarKind.syntheticOpaque)
473+ let (mvars, _, _) ← forallMetaBoundedTelescope (← inferType eNew) altParamInfo.numOverlaps (kind := MetavarKind.syntheticOpaque)
470474 modify fun s => s ++ (mvars.map (·.mvarId!))
471475 let eNew := mkAppN eNew mvars
472476 return TransformStep.done eNew
@@ -528,14 +532,14 @@ where go baseName splitterName := withConfig (fun c => { c with etaStruct := .no
528532 let mut notAlts := #[]
529533 let mut idx := 1
530534 let mut splitterAltTypes := #[]
531- let mut splitterAltNumParams := #[]
535+ let mut splitterAltInfos := #[]
532536 let mut altArgMasks := #[] -- masks produced by `forallAltTelescope`
533537 for i in *...alts.size do
534- let altNumParams := matchInfo.altNumParams [i]!
538+ let altInfo := matchInfo.altInfos [i]!
535539 let thmName := Name.str baseName eqnThmSuffixBase |>.appendIndexAfter idx
536540 eqnNames := eqnNames.push thmName
537- let (notAlt, splitterAltType, splitterAltNumParam , argMask) ←
538- forallAltTelescope (← inferType alts[i]!) altNumParams numDiscrEqs
541+ let (notAlt, splitterAltType, splitterAltInfo , argMask) ←
542+ forallAltTelescope (← inferType alts[i]!) altInfo numDiscrEqs
539543 fun ys eqs rhsArgs argMask altResultType => do
540544 let patterns := altResultType.getAppArgs
541545 let mut hs := #[]
@@ -548,12 +552,13 @@ where go baseName splitterName := withConfig (fun c => { c with etaStruct := .no
548552 let splitterAltType ← mkForallFVars eqs altResultType
549553 let splitterAltType ← mkArrowN hs splitterAltType
550554 let splitterAltType ← mkForallFVars ys splitterAltType
551- let splitterAltType ← if splitterAltType == altResultType then
555+ let hasUnitThunk := splitterAltType == altResultType
556+ let splitterAltType ← if hasUnitThunk then
552557 mkArrow (mkConst ``Unit) splitterAltType
553558 else
554559 pure splitterAltType
555560 let splitterAltType ← unfoldNamedPattern splitterAltType
556- let splitterAltNumParam := hs .size + ys .size
561+ let splitterAltInfo := { numFields := ys .size, numOverlaps := hs .size, hasUnitThunk }
557562 -- Create a proposition for representing terms that do not match `patterns`
558563 let mut notAlt := mkConst ``False
559564 for discr in discrs.toArray.reverse, pattern in patterns.reverse do
@@ -576,10 +581,10 @@ where go baseName splitterName := withConfig (fun c => { c with etaStruct := .no
576581 type := thmType
577582 value := thmVal
578583 }
579- return (notAlt, splitterAltType, splitterAltNumParam , argMask)
584+ return (notAlt, splitterAltType, splitterAltInfo , argMask)
580585 notAlts := notAlts.push notAlt
581586 splitterAltTypes := splitterAltTypes.push splitterAltType
582- splitterAltNumParams := splitterAltNumParams .push splitterAltNumParam
587+ splitterAltInfos := splitterAltInfos .push splitterAltInfo
583588 altArgMasks := altArgMasks.push argMask
584589 trace[Meta.Match.matchEqs] "splitterAltType: {splitterAltType}"
585590 idx := idx + 1
@@ -595,7 +600,7 @@ where go baseName splitterName := withConfig (fun c => { c with etaStruct := .no
595600 let template := mkAppN (mkConst constInfo.name us) (params ++ #[motive] ++ discrs ++ alts)
596601 let template ← deltaExpand template (· == constInfo.name)
597602 let template := template.headBeta
598- mkLambdaFVars splitterParams (← mkSplitterProof matchDeclName template alts altsNew splitterAltNumParams altArgMasks numDiscrEqs )
603+ mkLambdaFVars splitterParams (← mkSplitterProof matchDeclName template alts altsNew splitterAltInfos altArgMasks)
599604 addAndCompile <| Declaration.defnDecl {
600605 name := splitterName
601606 levelParams := constInfo.levelParams
@@ -605,7 +610,8 @@ where go baseName splitterName := withConfig (fun c => { c with etaStruct := .no
605610 safety := .safe
606611 }
607612 setInlineAttribute splitterName
608- let result := { eqnNames, splitterName, splitterAltNumParams }
613+ let splitterMatchInfo := { matchInfo with altInfos := splitterAltInfos }
614+ let result := { eqnNames, splitterName, splitterMatchInfo }
609615 registerMatchEqns matchDeclName result
610616
611617/- We generate the equations and splitter on demand, and do not save them on .olean files. -/
@@ -651,12 +657,12 @@ where go baseName := withConfig (fun c => { c with etaStruct := .none }) do
651657 let mut notAlts := #[]
652658 let mut idx := 1
653659 for i in *...alts.size do
654- let altNumParams := matchInfo.altNumParams [i]!
660+ let altInfo := matchInfo.altInfos [i]!
655661 let thmName := (Name.str baseName congrEqnThmSuffixBase).appendIndexAfter idx
656662 eqnNames := eqnNames.push thmName
657663 let notAlt ← do
658664 let alt := alts[i]!
659- Match.forallAltVarsTelescope (← inferType alt) altNumParams numDiscrEqs fun altVars args _mask altResultType => do
665+ Match.forallAltVarsTelescope (← inferType alt) altInfo fun altVars args _mask altResultType => do
660666 let patterns ← forallTelescope altResultType fun _ t => pure t.getAppArgs
661667 let mut heqsTypes := #[]
662668 assert! patterns.size == discrs.size
0 commit comments