Skip to content

Commit f7031c7

Browse files
authored
perf: in match splitters, thunk alts if needed (#11239)
This PR adds a `Unit` assumption to alternatives of the splitter that would otherwise not have arguments. This fixes #11211. In practice these argument-less alternatives did not cause wrong behavior, as the motive when used with `split` is always a function type. But it is better to be safe here (maybe someone uses splitters in other ways), it may increase the effectiveness of #10184 and simplifies #11220. The perf impact is insignificant in the grand scheme of things on stdlib, but the change is effective: ``` ~/lean4 $ build/release/stage1/bin/lean tests/lean/run/matchSplitStats.lean 969 splitters found 455 splitters are const defs ~/lean4 $ build/release/stage2/bin/lean tests/lean/run/matchSplitStats.lean 969 splitters found 829 splitters are const defs ```
1 parent 9fc9048 commit f7031c7

File tree

6 files changed

+146
-18
lines changed

6 files changed

+146
-18
lines changed

src/Lean/Meta/Match/MatchEqs.lean

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -279,8 +279,7 @@ private abbrev ConvertM := ReaderT (FVarIdMap (Expr × Nat × Array Bool)) $ Sta
279279
- `altNews` are the new free variables which contains additional hypotheses that ensure they are only used
280280
when the previous overlapping alternatives are not applicable. -/
281281
private partial def mkSplitterProof (matchDeclName : Name) (template : Expr) (alts altsNew : Array Expr)
282-
(altsNewNumParams : Array Nat)
283-
(altArgMasks : Array (Array Bool)) : MetaM Expr := do
282+
(altsNewNumParams : Array Nat) (altArgMasks : Array (Array Bool)) (numDiscrEqs : Nat) : MetaM Expr := do
284283
trace[Meta.Match.matchEqs] "proof template: {template}"
285284
let map := mkMap
286285
let (proof, mvarIds) ← convertTemplate template |>.run map |>.run #[]
@@ -454,7 +453,10 @@ where
454453
else
455454
let Expr.fvar fvarId .. := e.getAppFn | return .continue
456455
let some (altNew, numParams, argMask) := (← read).get? fvarId | return .continue
457-
trace[Meta.Match.matchEqs] ">> argMask: {argMask}, e: {e}, {altNew}"
456+
trace[Meta.Match.matchEqs] ">> argMask: {argMask}, numParams: {numParams}, e: {e}, alsNew: {altNew}, "
457+
if numParams + numDiscrEqs = 0 then
458+
let eNew := mkApp altNew (mkConst ``Unit.unit)
459+
return TransformStep.done eNew
458460
let mut newArgs := #[]
459461
let argMask := trimFalseTrail argMask
460462
unless e.getAppNumArgs ≥ argMask.size do
@@ -463,7 +465,7 @@ where
463465
if includeArg then
464466
newArgs := newArgs.push arg
465467
let eNew := mkAppN altNew newArgs
466-
/- Recall that `numParams` does not include the equalities associated with discriminants of the form `h : discr`. -/
468+
/- Recall that `numParams` does not include the `numDiscrEqs` equalities associated with discriminants of the form `h : discr`. -/
467469
let (mvars, _, _) ← forallMetaBoundedTelescope (← inferType eNew) (numParams - newArgs.size) (kind := MetavarKind.syntheticOpaque)
468470
modify fun s => s ++ (mvars.map (·.mvarId!))
469471
let eNew := mkAppN eNew mvars
@@ -543,7 +545,13 @@ where go baseName splitterName := withConfig (fun c => { c with etaStruct := .no
543545
if let some h ← simpH? h patterns.size then
544546
hs := hs.push h
545547
trace[Meta.Match.matchEqs] "hs: {hs}"
546-
let splitterAltType ← mkForallFVars ys (← hs.foldrM (init := (← mkForallFVars eqs altResultType)) (mkArrow · ·))
548+
let splitterAltType ← mkForallFVars eqs altResultType
549+
let splitterAltType ← mkArrowN hs splitterAltType
550+
let splitterAltType ← mkForallFVars ys splitterAltType
551+
let splitterAltType ← if splitterAltType == altResultType then
552+
mkArrow (mkConst ``Unit) splitterAltType
553+
else
554+
pure splitterAltType
547555
let splitterAltType ← unfoldNamedPattern splitterAltType
548556
let splitterAltNumParam := hs.size + ys.size
549557
-- Create a proposition for representing terms that do not match `patterns`
@@ -580,14 +588,14 @@ where go baseName splitterName := withConfig (fun c => { c with etaStruct := .no
580588
let splitterParams := params.toArray ++ #[motive] ++ discrs.toArray ++ altsNew
581589
let splitterType ← mkForallFVars splitterParams matchResultType
582590
trace[Meta.Match.matchEqs] "splitterType: {splitterType}"
583-
let template := mkAppN (mkConst constInfo.name us) (params ++ #[motive] ++ discrs ++ alts)
584-
let template ← deltaExpand template (· == constInfo.name)
585-
let template := template.headBeta
586591
let splitterVal ←
587592
if (← isDefEq splitterType constInfo.type) then
588593
pure <| mkConst constInfo.name us
589594
else
590-
mkLambdaFVars splitterParams (← mkSplitterProof matchDeclName template alts altsNew splitterAltNumParams altArgMasks)
595+
let template := mkAppN (mkConst constInfo.name us) (params ++ #[motive] ++ discrs ++ alts)
596+
let template ← deltaExpand template (· == constInfo.name)
597+
let template := template.headBeta
598+
mkLambdaFVars splitterParams (← mkSplitterProof matchDeclName template alts altsNew splitterAltNumParams altArgMasks numDiscrEqs)
591599
addAndCompile <| Declaration.defnDecl {
592600
name := splitterName
593601
levelParams := constInfo.levelParams

src/Lean/Meta/Match/MatcherApp/Transform.lean

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,16 +311,27 @@ def transform
311311
altType in altTypes do
312312
let alt' ← forallAltTelescope' origAltType (numParams - numDiscrEqs) 0 fun ys args => do
313313
let altType ← instantiateForall altType ys
314+
-- Look past the thunking unit parameter, if present
315+
let altType ← if splitterNumParams + numDiscrEqs = 0 then
316+
instantiateForall altType #[mkConst ``Unit.unit]
317+
else
318+
pure altType
314319
-- The splitter inserts its extra parameters after the first ys.size parameters, before
315320
-- the parameters for the numDiscrEqs
316-
forallBoundedTelescope altType (splitterNumParams - ys.size) fun ys2 altType => do
321+
let alt' ← forallBoundedTelescope altType (splitterNumParams - ys.size) fun ys2 altType => do
317322
forallBoundedTelescope altType numDiscrEqs fun ys3 altType => do
318323
forallBoundedTelescope altType extraEqualities fun ys4 altType => do
319324
let altParams := args ++ ys3
320325
let alt ← try instantiateLambda alt altParams
321326
catch _ => throwError "unexpected matcher application, insufficient number of parameters in alternative"
322327
let alt' ← onAlt altIdx altType altParams alt
323328
mkLambdaFVars (ys ++ ys2 ++ ys3 ++ ys4) alt'
329+
let alt' ← if splitterNumParams + numDiscrEqs = 0 then
330+
-- The splitter expects a thunked alternative, but we don't want the `x : Unit` to be in
331+
-- the context (e.g. in functional induction), so use Function.const rather than a lambda
332+
mkAppM ``Function.const #[mkConst ``Unit, alt']
333+
else
334+
pure alt'
324335
alts' := alts'.push alt'
325336

326337
remaining' := remaining' ++ (← onRemaining matcherApp.remaining)

src/Lean/Meta/Tactic/Split.lean

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -285,9 +285,16 @@ def applyMatchSplitter (mvarId : MVarId) (matcherDeclName : Name) (us : Array Le
285285
let mvarIds ← mvarId.applyN splitter matchEqns.size
286286
let (_, mvarIds) ← mvarIds.foldlM (init := (0, [])) fun (i, mvarIds) mvarId => do
287287
let numParams := matchEqns.splitterAltNumParams[i]!
288-
let (_, mvarId) ← mvarId.introN numParams
288+
let mvarId ←
289+
if numParams + info.getNumDiscrEqs = 0 then
290+
trace[split.debug] "introducing unit param for alt {(i : Nat)}"
291+
let (unitFvarId, mvarId) ← mvarId.intro1
292+
mvarId.tryClear unitFvarId
293+
else
294+
let (_, mvarId) ← mvarId.introN numParams
295+
pure mvarId
289296
trace[split.debug] "before unifyEqs\n{mvarId}"
290-
match (← Cases.unifyEqs? (numEqs + info.getNumDiscrEqs) mvarId {}) with
297+
match (← Cases.unifyEqs? (info.getNumDiscrEqs + numEqs) mvarId {}) with
291298
| none => return (i+1, mvarIds) -- case was solved
292299
| some (mvarId, fvarSubst) =>
293300
trace[split.debug] "after unifyEqs\n{mvarId}"

tests/lean/run/casesOnSameCtor.lean

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,16 +38,18 @@ info: Vec.match_on_same_ctor.{u_1, u} {α : Type u}
3838
/--
3939
info: Vec.match_on_same_ctor.splitter.{u_1, u} {α : Type u}
4040
{motive : {a : Nat} → (t t_1 : Vec α a) → t.ctorIdx = t_1.ctorIdx → Sort u_1} {a✝ : Nat} (t t✝ : Vec α a✝)
41-
(h : t.ctorIdx = t✝.ctorIdx) (h_1 : motive nil nil ⋯)
41+
(h : t.ctorIdx = t✝.ctorIdx) (h_1 : Unit → motive nil nil ⋯)
4242
(h_2 : (a : α) → (n : Nat) → (a_1 : Vec α n) → (a' : α) → (a'_1 : Vec α n) → motive (cons a a_1) (cons a' a'_1) ⋯) :
4343
motive t t✝ h
4444
-/
4545
#guard_msgs in
4646
#check Vec.match_on_same_ctor.splitter
4747

48-
-- Since there is no overlap, the splitter is equal to the matcher
49-
-- (I wonder if we should use this in general in MatchEq)
50-
example : @Vec.match_on_same_ctor = @Vec.match_on_same_ctor.splitter := by rfl
48+
-- After #11211 this is no longer true. Should we thunk the same-ctor-construction?
49+
50+
-- -- Since there is no overlap, the splitter is equal to the matcher
51+
-- -- (I wonder if we should use this in general in MatchEq)
52+
-- example : @Vec.match_on_same_ctor = @Vec.match_on_same_ctor.splitter := by rfl
5153

5254
/--
5355
info: Vec.match_on_same_ctor.eq_2.{u_1, u} {α : Type u}

tests/lean/run/issue11211.lean

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
/-!
2+
Checks that splitters have `Unit →` thunks and that nothing is confused because of that.
3+
-/
4+
5+
set_option linter.unusedVariables false
6+
7+
8+
-- set_option trace.Meta.Match.matchEqs true
9+
10+
def f (xs : List Nat) : Nat :=
11+
match xs with
12+
| [] => 1
13+
| _ => 2
14+
15+
/--
16+
info: def f.match_1.{u_1} : (motive : List Nat → Sort u_1) →
17+
(xs : List Nat) → (Unit → motive []) → ((x : List Nat) → motive x) → motive xs
18+
-/
19+
#guard_msgs in
20+
#print sig f.match_1
21+
22+
/--
23+
info: private def f.match_1.splitter.{u_1} : (motive : List Nat → Sort u_1) →
24+
(xs : List Nat) → (Unit → motive []) → ((x : List Nat) → (x = [] → False) → motive x) → motive xs
25+
-/
26+
#guard_msgs(pass trace, all) in
27+
#print sig f.match_1.splitter
28+
29+
30+
/--
31+
info: private theorem f.match_1.congr_eq_1.{u_1} : ∀ (motive : List Nat → Sort u_1) (xs : List Nat) (h_1 : Unit → motive [])
32+
(h_2 : (x : List Nat) → motive x),
33+
xs = [] →
34+
(match xs with
35+
| [] => h_1 ()
36+
| x => h_2 x) ≍
37+
h_1 ()
38+
-/
39+
#guard_msgs(pass trace, all) in
40+
#print sig f.match_1.congr_eq_1
41+
42+
-- set_option trace.split.debug true
43+
44+
theorem test1: f n ≤ 2 := by
45+
unfold f
46+
split <;> grind
47+
48+
49+
theorem test2 : f n ≤ 2 := by
50+
unfold f
51+
grind
52+
53+
/--
54+
info: theorem f.fun_cases : ∀ (motive : List Nat → Prop),
55+
motive [] → (∀ (xs : List Nat), (xs = [] → False) → motive xs) → ∀ (xs : List Nat), motive xs
56+
-/
57+
#guard_msgs(pass trace, all) in
58+
#print sig f.fun_cases
59+
60+
def Option_map (f : α → β) : Option α → Option β
61+
| some x => some (f x)
62+
| none => none
63+
64+
/--
65+
info: def Option_map.match_1.{u_1, u_2} : {α : Type u_1} →
66+
(motive : Option α → Sort u_2) → (x : Option α) → ((x : α) → motive (some x)) → (Unit → motive none) → motive x
67+
-/
68+
#guard_msgs in
69+
#print sig Option_map.match_1
70+
71+
/--
72+
info: private def Option_map.match_1.splitter.{u_1, u_2} : {α : Type u_1} →
73+
(motive : Option α → Sort u_2) → (x : Option α) → ((x : α) → motive (some x)) → (Unit → motive none) → motive x :=
74+
@Option_map.match_1
75+
-/
76+
#guard_msgs in
77+
#print Option_map.match_1.splitter
78+
79+
/--
80+
info: theorem Option_map.fun_cases.{u_1} : ∀ {α : Type u_1} (motive : Option α → Prop),
81+
(∀ (x : α), motive (some x)) → motive none → ∀ (x : Option α), motive x
82+
-/
83+
#guard_msgs(pass trace, all) in
84+
#print sig Option_map.fun_cases
85+
86+
def List_map (f : α → β) (l : List α) : List β := match _ : l with
87+
| x::xs => f x :: List_map f xs
88+
| [] => []
89+
termination_by l
90+
91+
def foo₁ (a : Nat) (ha : a = 37) :=
92+
(match h : a with | 42 => 23 | n => n) = 37
93+
94+
/--
95+
info: private def foo₁.match_1.splitter.{u_1} : (motive : Nat → Sort u_1) →
96+
(a : Nat) → (a = 42 → motive 42) → ((n : Nat) → (n = 42 → False) → a = n → motive n) → motive a
97+
-/
98+
#guard_msgs in
99+
#print sig foo₁.match_1.splitter

tests/lean/run/matchSparse.lean

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,9 @@ info: expensive.match_1.{u_1} (motive : Expr → Expr → Sort u_1) (x✝ x✝¹
3737
/--
3838
info: expensive.match_1.splitter.{u_1} (motive : Expr → Expr → Sort u_1) (x✝ x✝¹ : Expr)
3939
(h_1 :
40-
motive (((sort zero.succ).app (sort zero.succ)).app (sort zero.succ))
41-
(((sort zero.succ).app (sort zero.succ)).app (sort zero.succ)))
40+
Unit →
41+
motive (((sort zero.succ).app (sort zero.succ)).app (sort zero.succ))
42+
(((sort zero.succ).app (sort zero.succ)).app (sort zero.succ)))
4243
(h_2 :
4344
(x x_1 : Expr) →
4445
(x = ((sort zero.succ).app (sort zero.succ)).app (sort zero.succ) →

0 commit comments

Comments
 (0)