Skip to content

Commit 5cc0a10

Browse files
authored
refactor: use Match.AltParamInfo also for splitters (#11261)
This PR continues the homogenization between matchers and splitters, following up on #11256. In particular it removes the ambiguity whether `numParams` includes the `discrEqns` or not.
1 parent 1b6fba4 commit 5cc0a10

File tree

5 files changed

+60
-49
lines changed

5 files changed

+60
-49
lines changed

src/Lean/Meta/Match/MatchEqs.lean

Lines changed: 40 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -85,19 +85,23 @@ def unfoldNamedPattern (e : Expr) : MetaM Expr := do
8585
8686
This can be used to use the alternative of a match expression in its splitter.
8787
-/
88-
partial def forallAltVarsTelescope (altType : Expr) (altNumParams numDiscrEqs : Nat)
88+
partial def forallAltVarsTelescope (altType : Expr) (altInfo : AltParamInfo)
8989
(k : (patVars : Array Expr) → (args : Array Expr) → (mask : Array Bool) → (type : Expr) → MetaM α) : MetaM α := do
90-
go #[] #[] #[] 0 altType
90+
assert! altInfo.numOverlaps = 0
91+
if altInfo.hasUnitThunk then
92+
let type ← whnfForall altType
93+
let type ← Match.unfoldNamedPattern type
94+
let type ← instantiateForall type #[mkConst ``Unit.unit]
95+
k #[] #[mkConst ``Unit.unit] #[false] type
96+
else
97+
go #[] #[] #[] 0 altType
9198
where
9299
go (ys : Array Expr) (args : Array Expr) (mask : Array Bool) (i : Nat) (type : Expr) : MetaM α := do
93100
let type ← whnfForall type
94-
if i < altNumParams - numDiscrEqs then
95-
let Expr.forallE n d b .. := type
96-
| throwError "expecting {altNumParams} parameters, excluding {numDiscrEqs} equalities, but found type{indentExpr altType}"
97101

98-
-- Handle the special case of `Unit` parameters.
99-
if i = 0 && altNumParams - numDiscrEqs = 1 && d.isConstOf ``Unit && !b.hasLooseBVars then
100-
return ← k #[] #[mkConst ``Unit.unit] #[false] b
102+
if i < altInfo.numFields then
103+
let Expr.forallE n d b .. := type
104+
| throwError "expecting {altInfo.numFields} parameters, but found type{indentExpr altType}"
101105

102106
let d ← Match.unfoldNamedPattern d
103107
withLocalDeclD n d fun y => do
@@ -145,17 +149,17 @@ where
145149
146150
This can be used to use the alternative of a match expression in its splitter.
147151
-/
148-
partial def forallAltTelescope (altType : Expr) (altNumParams numDiscrEqs : Nat)
152+
partial def forallAltTelescope (altType : Expr) (altInfo : AltParamInfo) (numDiscrEqs : Nat)
149153
(k : (ys : Array Expr) → (eqs : Array Expr) → (args : Array Expr) → (mask : Array Bool) → (type : Expr) → MetaM α)
150154
: MetaM α := do
151-
forallAltVarsTelescope altType altNumParams numDiscrEqs fun ys args mask altType => do
155+
forallAltVarsTelescope altType altInfo fun ys args mask altType => do
152156
go ys #[] args mask 0 altType
153157
where
154158
go (ys : Array Expr) (eqs : Array Expr) (args : Array Expr) (mask : Array Bool) (i : Nat) (type : Expr) : MetaM α := do
155159
let type ← whnfForall type
156160
if i < numDiscrEqs then
157161
let Expr.forallE n d b .. := type
158-
| throwError "expecting {altNumParams} parameters, including {numDiscrEqs} equalities, but found type{indentExpr altType}"
162+
| throwError "expecting {numDiscrEqs} equalities, but found type{indentExpr altType}"
159163
let arg ← if let some (_, _, rhs) ← matchEq? d then
160164
mkEqRefl rhs
161165
else if let some (_, _, _, rhs) ← matchHEq? d then
@@ -270,16 +274,17 @@ private partial def withSplitterAlts (altTypes : Array Expr) (f : Array Expr →
270274
f xs
271275
go 0 #[]
272276

273-
private abbrev ConvertM := ReaderT (FVarIdMap (Expr × Nat × Array Bool)) $ StateRefT (Array MVarId) MetaM
277+
private abbrev ConvertM := ReaderT (FVarIdMap (Expr × AltParamInfo × Array Bool)) $ StateRefT (Array MVarId) MetaM
274278

275279
/--
276280
Construct a proof for the splitter generated by `mkEquationsFor`.
277281
The proof uses the definition of the `match`-declaration as a template (argument `template`).
278282
- `alts` are free variables corresponding to alternatives of the `match` auxiliary declaration being processed.
279283
- `altNews` are the new free variables which contains additional hypotheses that ensure they are only used
280-
when the previous overlapping alternatives are not applicable. -/
284+
when the previous overlapping alternatives are not applicable.
285+
- `altInfos` refers to the splitter -/
281286
private partial def mkSplitterProof (matchDeclName : Name) (template : Expr) (alts altsNew : Array Expr)
282-
(altsNewNumParams : Array Nat) (altArgMasks : Array (Array Bool)) (numDiscrEqs : Nat) : MetaM Expr := do
287+
(altInfos : Array AltParamInfo) (altArgMasks : Array (Array Bool)) : MetaM Expr := do
283288
trace[Meta.Match.matchEqs] "proof template: {template}"
284289
let map := mkMap
285290
let (proof, mvarIds) ← convertTemplate template |>.run map |>.run #[]
@@ -289,10 +294,10 @@ private partial def mkSplitterProof (matchDeclName : Name) (template : Expr) (al
289294
solveOverlap mvarId
290295
instantiateMVars proof
291296
where
292-
mkMap : FVarIdMap (Expr × Nat × Array Bool) := Id.run do
297+
mkMap : FVarIdMap (Expr × AltParamInfo × Array Bool) := Id.run do
293298
let mut m := {}
294-
for alt in alts, altNew in altsNew, numParams in altsNewNumParams, argMask in altArgMasks do
295-
m := m.insert alt.fvarId! (altNew, numParams, argMask)
299+
for alt in alts, altNew in altsNew, altInfo in altInfos, argMask in altArgMasks do
300+
m := m.insert alt.fvarId! (altNew, altInfo, argMask)
296301
return m
297302

298303
trimFalseTrail (argMask : Array Bool) : Array Bool :=
@@ -452,9 +457,9 @@ where
452457
return .done (← convertCastEqRec e)
453458
else
454459
let Expr.fvar fvarId .. := e.getAppFn | return .continue
455-
let some (altNew, numParams, argMask) := (← read).get? fvarId | return .continue
456-
trace[Meta.Match.matchEqs] ">> argMask: {argMask}, numParams: {numParams}, e: {e}, alsNew: {altNew}, "
457-
if numParams + numDiscrEqs = 0 then
460+
let some (altNew, altParamInfo, argMask) := (← read).get? fvarId | return .continue
461+
trace[Meta.Match.matchEqs] ">> argMask: {argMask}, altParamInfo: {repr altParamInfo}, e: {e}, alsNew: {altNew}, "
462+
if altParamInfo.hasUnitThunk then
458463
let eNew := mkApp altNew (mkConst ``Unit.unit)
459464
return TransformStep.done eNew
460465
let mut newArgs := #[]
@@ -465,8 +470,7 @@ where
465470
if includeArg then
466471
newArgs := newArgs.push arg
467472
let eNew := mkAppN altNew newArgs
468-
/- Recall that `numParams` does not include the `numDiscrEqs` equalities associated with discriminants of the form `h : discr`. -/
469-
let (mvars, _, _) ← forallMetaBoundedTelescope (← inferType eNew) (numParams - newArgs.size) (kind := MetavarKind.syntheticOpaque)
473+
let (mvars, _, _) ← forallMetaBoundedTelescope (← inferType eNew) altParamInfo.numOverlaps (kind := MetavarKind.syntheticOpaque)
470474
modify fun s => s ++ (mvars.map (·.mvarId!))
471475
let eNew := mkAppN eNew mvars
472476
return TransformStep.done eNew
@@ -528,14 +532,14 @@ where go baseName splitterName := withConfig (fun c => { c with etaStruct := .no
528532
let mut notAlts := #[]
529533
let mut idx := 1
530534
let mut splitterAltTypes := #[]
531-
let mut splitterAltNumParams := #[]
535+
let mut splitterAltInfos := #[]
532536
let mut altArgMasks := #[] -- masks produced by `forallAltTelescope`
533537
for i in *...alts.size do
534-
let altNumParams := matchInfo.altNumParams[i]!
538+
let altInfo := matchInfo.altInfos[i]!
535539
let thmName := Name.str baseName eqnThmSuffixBase |>.appendIndexAfter idx
536540
eqnNames := eqnNames.push thmName
537-
let (notAlt, splitterAltType, splitterAltNumParam, argMask) ←
538-
forallAltTelescope (← inferType alts[i]!) altNumParams numDiscrEqs
541+
let (notAlt, splitterAltType, splitterAltInfo, argMask) ←
542+
forallAltTelescope (← inferType alts[i]!) altInfo numDiscrEqs
539543
fun ys eqs rhsArgs argMask altResultType => do
540544
let patterns := altResultType.getAppArgs
541545
let mut hs := #[]
@@ -548,12 +552,13 @@ where go baseName splitterName := withConfig (fun c => { c with etaStruct := .no
548552
let splitterAltType ← mkForallFVars eqs altResultType
549553
let splitterAltType ← mkArrowN hs splitterAltType
550554
let splitterAltType ← mkForallFVars ys splitterAltType
551-
let splitterAltType ← if splitterAltType == altResultType then
555+
let hasUnitThunk := splitterAltType == altResultType
556+
let splitterAltType ← if hasUnitThunk then
552557
mkArrow (mkConst ``Unit) splitterAltType
553558
else
554559
pure splitterAltType
555560
let splitterAltType ← unfoldNamedPattern splitterAltType
556-
let splitterAltNumParam := hs.size + ys.size
561+
let splitterAltInfo := { numFields := ys.size, numOverlaps := hs.size, hasUnitThunk }
557562
-- Create a proposition for representing terms that do not match `patterns`
558563
let mut notAlt := mkConst ``False
559564
for discr in discrs.toArray.reverse, pattern in patterns.reverse do
@@ -576,10 +581,10 @@ where go baseName splitterName := withConfig (fun c => { c with etaStruct := .no
576581
type := thmType
577582
value := thmVal
578583
}
579-
return (notAlt, splitterAltType, splitterAltNumParam, argMask)
584+
return (notAlt, splitterAltType, splitterAltInfo, argMask)
580585
notAlts := notAlts.push notAlt
581586
splitterAltTypes := splitterAltTypes.push splitterAltType
582-
splitterAltNumParams := splitterAltNumParams.push splitterAltNumParam
587+
splitterAltInfos := splitterAltInfos.push splitterAltInfo
583588
altArgMasks := altArgMasks.push argMask
584589
trace[Meta.Match.matchEqs] "splitterAltType: {splitterAltType}"
585590
idx := idx + 1
@@ -595,7 +600,7 @@ where go baseName splitterName := withConfig (fun c => { c with etaStruct := .no
595600
let template := mkAppN (mkConst constInfo.name us) (params ++ #[motive] ++ discrs ++ alts)
596601
let template ← deltaExpand template (· == constInfo.name)
597602
let template := template.headBeta
598-
mkLambdaFVars splitterParams (← mkSplitterProof matchDeclName template alts altsNew splitterAltNumParams altArgMasks numDiscrEqs)
603+
mkLambdaFVars splitterParams (← mkSplitterProof matchDeclName template alts altsNew splitterAltInfos altArgMasks)
599604
addAndCompile <| Declaration.defnDecl {
600605
name := splitterName
601606
levelParams := constInfo.levelParams
@@ -605,7 +610,8 @@ where go baseName splitterName := withConfig (fun c => { c with etaStruct := .no
605610
safety := .safe
606611
}
607612
setInlineAttribute splitterName
608-
let result := { eqnNames, splitterName, splitterAltNumParams }
613+
let splitterMatchInfo := { matchInfo with altInfos := splitterAltInfos }
614+
let result := { eqnNames, splitterName, splitterMatchInfo }
609615
registerMatchEqns matchDeclName result
610616

611617
/- We generate the equations and splitter on demand, and do not save them on .olean files. -/
@@ -651,12 +657,12 @@ where go baseName := withConfig (fun c => { c with etaStruct := .none }) do
651657
let mut notAlts := #[]
652658
let mut idx := 1
653659
for i in *...alts.size do
654-
let altNumParams := matchInfo.altNumParams[i]!
660+
let altInfo := matchInfo.altInfos[i]!
655661
let thmName := (Name.str baseName congrEqnThmSuffixBase).appendIndexAfter idx
656662
eqnNames := eqnNames.push thmName
657663
let notAlt ← do
658664
let alt := alts[i]!
659-
Match.forallAltVarsTelescope (← inferType alt) altNumParams numDiscrEqs fun altVars args _mask altResultType => do
665+
Match.forallAltVarsTelescope (← inferType alt) altInfo fun altVars args _mask altResultType => do
660666
let patterns ← forallTelescope altResultType fun _ t => pure t.getAppArgs
661667
let mut heqsTypes := #[]
662668
assert! patterns.size == discrs.size

src/Lean/Meta/Match/MatchEqsExt.lean

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ module
77

88
prelude
99
public import Lean.Meta.Basic
10+
public import Lean.Meta.Match.Basic
1011
import Lean.Meta.Eqns
1112

1213
public section
@@ -16,8 +17,8 @@ namespace Lean.Meta.Match
1617
structure MatchEqns where
1718
eqnNames : Array Name
1819
splitterName : Name
19-
splitterAltNumParams : Array Nat
20-
deriving Inhabited, Repr
20+
splitterMatchInfo : MatcherInfo
21+
deriving Inhabited, Repr
2122

2223
def MatchEqns.size (e : MatchEqns) : Nat :=
2324
e.eqnNames.size

src/Lean/Meta/Match/MatcherApp/Transform.lean

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -190,10 +190,10 @@ def withUserNames {n} [MonadControlT MetaM n] [Monad n]
190190
-/
191191
private def forallAltTelescope'
192192
{n} [Monad n] [MonadControlT MetaM n]
193-
{α} (origAltType : Expr) (numParams numDiscrEqs : Nat)
193+
{α} (origAltType : Expr) (altInfo : Match.AltParamInfo)
194194
(k : Array Expr → Array Expr → n α) : n α := do
195195
map2MetaM (fun k =>
196-
Match.forallAltVarsTelescope origAltType numParams numDiscrEqs
196+
Match.forallAltVarsTelescope origAltType altInfo
197197
fun ys args _mask _bodyType => k ys args
198198
) k
199199

@@ -308,28 +308,30 @@ def transform
308308
let mut alts' := #[]
309309
for altIdx in *...matcherApp.alts.size,
310310
alt in matcherApp.alts,
311-
numParams in matcherApp.altNumParams,
312-
splitterNumParams in matchEqns.splitterAltNumParams,
311+
altInfo in matcherApp.altInfos,
312+
splitterAltInfo in matchEqns.splitterMatchInfo.altInfos,
313313
origAltType in origAltTypes,
314314
altType in altTypes do
315-
let alt' ← forallAltTelescope' origAltType (numParams - numDiscrEqs) 0 fun ys args => do
315+
assert! altInfo.numOverlaps = 0
316+
let alt' ← forallAltTelescope' origAltType altInfo fun ys args => do
317+
assert! ys.size == splitterAltInfo.numFields
316318
let altType ← instantiateForall altType ys
317319
-- Look past the thunking unit parameter, if present
318-
let altType ← if splitterNumParams + numDiscrEqs = 0 then
320+
let altType ← if altInfo.hasUnitThunk then
319321
instantiateForall altType #[mkConst ``Unit.unit]
320322
else
321323
pure altType
322324
-- The splitter inserts its extra parameters after the first ys.size parameters, before
323325
-- the parameters for the numDiscrEqs
324-
let alt' ← forallBoundedTelescope altType (splitterNumParams - ys.size) fun ys2 altType => do
326+
let alt' ← forallBoundedTelescope altType splitterAltInfo.numOverlaps fun ys2 altType => do
325327
forallBoundedTelescope altType numDiscrEqs fun ys3 altType => do
326328
forallBoundedTelescope altType extraEqualities fun ys4 altType => do
327329
let altParams := args ++ ys3
328330
let alt ← try instantiateLambda alt altParams
329331
catch _ => throwError "unexpected matcher application, insufficient number of parameters in alternative"
330332
let alt' ← onAlt altIdx altType altParams alt
331333
mkLambdaFVars (ys ++ ys2 ++ ys3 ++ ys4) alt'
332-
let alt' ← if splitterNumParams + numDiscrEqs = 0 then
334+
let alt' ← if splitterAltInfo.hasUnitThunk then
333335
-- The splitter expects a thunked alternative, but we don't want the `x : Unit` to be in
334336
-- the context (e.g. in functional induction), so use Function.const rather than a lambda
335337
mkAppM ``Function.const #[mkConst ``Unit, alt']

src/Lean/Meta/Match/MatcherInfo.lean

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,12 @@ namespace Match
1616
structure DiscrInfo where
1717
/-- `some h` if the discriminant is annotated with `h:` -/
1818
hName? : Option Name := none
19-
deriving Inhabited
19+
deriving Inhabited, Repr
2020

2121

2222
structure Overlaps where
2323
map : Std.HashMap Nat (Std.TreeSet Nat) := {}
24+
deriving Inhabited, Repr
2425

2526
def Overlaps.insert (o : Overlaps) (overlapping overlapped : Nat) : Overlaps where
2627
map := o.map.alter overlapped fun s? => some ((s?.getD {}).insert overlapping)
@@ -40,7 +41,7 @@ structure AltParamInfo where
4041
numOverlaps : Nat
4142
/-- Whether this alternatie has an artifcial `Unit` parameter -/
4243
hasUnitThunk : Bool
43-
deriving Inhabited
44+
deriving Inhabited, Repr
4445

4546
/--
4647
A "matcher" auxiliary declaration has the following structure:
@@ -63,6 +64,7 @@ structure MatcherInfo where
6364
-/
6465
discrInfos : Array DiscrInfo
6566
overlaps : Overlaps := {}
67+
deriving Inhabited, Repr
6668

6769
@[expose] def MatcherInfo.numAlts (info : MatcherInfo) : Nat :=
6870
info.altInfos.size

src/Lean/Meta/Tactic/Split.lean

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -284,14 +284,14 @@ def applyMatchSplitter (mvarId : MVarId) (matcherDeclName : Name) (us : Array Le
284284
trace[split.debug] "after check splitter"
285285
let mvarIds ← mvarId.applyN splitter matchEqns.size
286286
let (_, mvarIds) ← mvarIds.foldlM (init := (0, [])) fun (i, mvarIds) mvarId => do
287-
let numParams := matchEqns.splitterAltNumParams[i]!
287+
let altInfo := matchEqns.splitterMatchInfo.altInfos[i]!
288288
let mvarId ←
289-
if numParams + info.getNumDiscrEqs = 0 then
289+
if altInfo.hasUnitThunk then
290290
trace[split.debug] "introducing unit param for alt {(i : Nat)}"
291291
let (unitFvarId, mvarId) ← mvarId.intro1
292292
mvarId.tryClear unitFvarId
293293
else
294-
let (_, mvarId) ← mvarId.introN numParams
294+
let (_, mvarId) ← mvarId.introN (altInfo.numFields + altInfo.numOverlaps)
295295
pure mvarId
296296
trace[split.debug] "before unifyEqs\n{mvarId}"
297297
match (← Cases.unifyEqs? (info.getNumDiscrEqs + numEqs) mvarId {}) with

0 commit comments

Comments
 (0)