Skip to content

Commit d41f39f

Browse files
authored
perf: sparse case splitting in match compilation (#10823)
This PR lets the match compilation procedure use sparse case analysis when the patterns only match on some but not all constructors of an inductive type. This way, less code is produce. Before, code handling each of the other cases was then optimized and commoned-up by later compilation pipeline, but that is wasteful to do. In some cases this will prevent Lean from noticing that a match statement is complete because it performs less case-splitting for the unreachable case. In this case, give explicit patterns to perform the deeper split with `by contradiction` as the right-hand side. At least temporarily, there is also the option to disable this behaviour with ``` set_option backwards.match.sparseCases false ```
1 parent 7459304 commit d41f39f

File tree

13 files changed

+359
-140
lines changed

13 files changed

+359
-140
lines changed

src/Init/Data/Int/Lemmas.lean

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -552,7 +552,7 @@ protected theorem mul_eq_zero {a b : Int} : a * b = 0 ↔ a = 0 ∨ b = 0 := by
552552
exact match a, b, h with
553553
| .ofNat 0, _, _ => by simp
554554
| _, .ofNat 0, _ => by simp
555-
| .ofNat (a+1), .negSucc b, h => by cases h
555+
| .ofNat (_+1), .negSucc _, h => by cases h
556556

557557
protected theorem mul_ne_zero {a b : Int} (a0 : a ≠ 0) (b0 : b ≠ 0) : a * b ≠ 0 :=
558558
Or.rec a0 b0 ∘ Int.mul_eq_zero.mp

src/Lean/Compiler/LCNF/Util.lean

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ prelude
99
public import Init.Data.FloatArray.Basic
1010
public import Lean.CoreM
1111
public import Lean.Util.Recognizers
12+
import Lean.Meta.Basic
1213

1314
public section
1415

src/Lean/Elab/Tactic/RCases.lean

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,7 @@ partial def rcasesCore (g : MVarId) (fs : FVarSubst) (clears : Array FVarId) (e
323323
let (v, g) ← g.intro x
324324
let (varsOut, g) ← g.introNP vars.size
325325
let fs' := (vars.zip varsOut).foldl (init := fs) fun fs (v, w) => fs.insert v (mkFVar w)
326-
pure ([(n, ps)], #[⟨⟨g, #[mkFVar v], fs'⟩, n⟩])
326+
pure ([(n, ps)], #[{mvarId := g, fields := #[mkFVar v], subst := fs', ctorName := n }])
327327
| ConstantInfo.inductInfo info, _ => do
328328
let (altVarNames, r) ← processConstructors pat.ref info.numParams #[] info.ctors pat.asAlts
329329
(r, ·) <$> g.cases e.fvarId! altVarNames (useNatCasesAuxOn := true)

src/Lean/Meta/Match/Match.lean

Lines changed: 112 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,22 @@ public section
1717

1818
namespace Lean.Meta.Match
1919

20+
register_builtin_option backwards.match.sparseCases : Bool := {
21+
defValue := true
22+
descr := "if true (the default), generate and use sparse case constructs when splitting inductive
23+
types. In some cases this will prevent Lean from noticing that a match statement is complete
24+
because it performs less case-splitting for the unreachable case. In this case, give explicit
25+
patterns to perform the deeper split with `by contradiction` as the right-hand side.
26+
,"
27+
}
28+
2029
register_builtin_option backwards.match.rowMajor : Bool := {
2130
defValue := true
22-
group := "bootstrap"
2331
descr := "If true (the default), match compilation will split the discrimnants based \
2432
on position of the first constructor pattern in the first alternative. If false, \
2533
it splits them from left to right, which can lead to unnecessary code bloat."
2634
}
2735

28-
2936
private def mkIncorrectNumberOfPatternsMsg [ToMessageData α]
3037
(discrepancyKind : String) (expected actual : Nat) (pats : List α) :=
3138
let patternsMsg := MessageData.joinSep (pats.map toMessageData) ", "
@@ -162,6 +169,12 @@ private def hasArrayLitPattern (p : Problem) : Bool :=
162169
| .arrayLit .. :: _ => true
163170
| _ => false
164171

172+
private def hasVarOrInaccessiblePattern (p : Problem) : Bool :=
173+
p.alts.any fun alt => match alt.patterns with
174+
| .inaccessible _ :: _ => true
175+
| .var _ :: _ => true
176+
| _ => false
177+
165178
private def isVariableTransition (p : Problem) : Bool :=
166179
p.alts.all fun alt => match alt.patterns with
167180
| .inaccessible _ :: _ => true
@@ -341,6 +354,35 @@ where
341354
msg := msg ++ m!"\n {lhs} ≋ {rhs}"
342355
throwErrorAt alt.ref msg
343356

357+
private abbrev isCtorIdxIneq? (e : Expr) : Option FVarId := do
358+
if let some (_, lhs, _rhs) := e.ne? then
359+
if
360+
lhs.isApp &&
361+
lhs.getAppFn.isConst &&
362+
(`ctorIdx).isSuffixOf lhs.getAppFn.constName! && -- This should be an env extension maybe
363+
lhs.appArg!.isFVar
364+
then
365+
return lhs.appArg!.fvarId!
366+
none
367+
368+
private partial def contradiction (mvarId : MVarId) : MetaM Bool := do
369+
mvarId.withContext do
370+
withTraceNode `Meta.Match.match (msg := (return m!"{exceptBoolEmoji ·} Match.contradiction")) do
371+
trace[Meta.Match.match] m!"Match.contradiction:\n{mvarId}"
372+
if (← mvarId.contradictionCore {}) then
373+
trace[Meta.Match.match] "Contradiction found!"
374+
return true
375+
else
376+
-- Try harder by splitting `ctorIdx x ≠ 23` assumptions
377+
for localDecl in (← getLCtx) do
378+
if let some fvarId := isCtorIdxIneq? localDecl.type then
379+
trace[Meta.Match.match] "splitting ctorIdx assumption {localDecl.type}"
380+
let subgoals ← mvarId.cases fvarId
381+
return ← subgoals.allM (contradiction ·.mvarId)
382+
383+
mvarId.admit
384+
return false
385+
344386
/--
345387
Try to solve the problem by using the first alternative whose pending constraints can be resolved.
346388
-/
@@ -352,10 +394,10 @@ where
352394
go (alts : List Alt) : StateRefT State MetaM Unit := do
353395
match alts with
354396
| [] =>
397+
let mvarId ← p.mvarId.exfalso
355398
/- TODO: allow users to configure which tactic is used to close leaves. -/
356-
unless (← p.mvarId.contradictionCore {}) do
357-
trace[Meta.Match.match] "missing alternative"
358-
p.mvarId.admit
399+
unless (← contradiction mvarId) do
400+
trace[Meta.Match.match] "contradiction failed, missing alternative"
359401
modify fun s => { s with counterExamples := p.examples :: s.counterExamples }
360402
| alt :: _ =>
361403
solveCnstrs p.mvarId alt
@@ -516,13 +558,29 @@ private def throwCasesException (p : Problem) (ex : Exception) : MetaM α := do
516558
"- <constructor> = <constructor>, examples: List.cons x xs = List.cons y ys, and List.cons x xs = List.nil"
517559
| _ => throw ex
518560

561+
private def collectCtors (p : Problem) : Array Name :=
562+
p.alts.foldl (init := #[]) fun ctors alt =>
563+
match alt.patterns with
564+
| .ctor n _ _ _ :: _ => if ctors.contains n then ctors else ctors.push n
565+
| _ => ctors
566+
519567
private def processConstructor (p : Problem) : MetaM (Array Problem) := do
520568
trace[Meta.Match.match] "constructor step"
521569
let x :: xs := p.vars | unreachable!
570+
let interestingCtors? ←
571+
-- We use a sparse case analysis only if there is at least one non-constructor pattern,
572+
-- but not just because there are constructors missing (in that case we benefit from
573+
-- the eager split in ruling out constructors by type or by a more explicit error message)
574+
if backwards.match.sparseCases.get (← getOptions) && hasVarOrInaccessiblePattern p then
575+
let ctors := collectCtors p
576+
trace[Meta.Match.match] "using sparse cases: {ctors}"
577+
pure (some ctors)
578+
else
579+
pure none
522580
let subgoals? ← commitWhenSome? do
523581
let subgoals ←
524582
try
525-
p.mvarId.cases x.fvarId!
583+
p.mvarId.cases x.fvarId! (interestingCtors? := interestingCtors?)
526584
catch ex =>
527585
if p.alts.isEmpty then
528586
/- If we have no alternatives and dependent pattern matching fails, then a "missing cases" error is better than a "stuck" error message. -/
@@ -545,28 +603,42 @@ private def processConstructor (p : Problem) : MetaM (Array Problem) := do
545603
return some subgoals
546604
let some subgoals := subgoals? | return #[{ p with vars := xs }]
547605
subgoals.mapM fun subgoal => subgoal.mvarId.withContext do
548-
let subst := subgoal.subst
549-
let fields := subgoal.fields.toList
550-
let newVars := fields ++ xs
551-
let newVars := newVars.map fun x => x.applyFVarSubst subst
552-
let subex := Example.ctor subgoal.ctorName <| fields.map fun field => match field with
553-
| .fvar fvarId => Example.var fvarId
554-
| _ => Example.underscore -- This case can happen due to dependent elimination
555-
let examples := p.examples.map <| Example.replaceFVarId x.fvarId! subex
556-
let examples := examples.map <| Example.applyFVarSubst subst
557-
let newAlts := p.alts.filter fun alt => match alt.patterns with
558-
| .ctor n .. :: _ => n == subgoal.ctorName
559-
| .var _ :: _ => true
560-
| .inaccessible _ :: _ => true
561-
| _ => false
562-
let newAlts := newAlts.map fun alt => alt.applyFVarSubst subst
563-
let newAlts ← newAlts.mapM fun alt => do
564-
match alt.patterns with
565-
| .ctor _ _ _ fields :: ps => return { alt with patterns := fields ++ ps }
566-
| .var _ :: _ => expandVarIntoCtor alt subgoal.ctorName
567-
| .inaccessible _ :: _ => processInaccessibleAsCtor alt subgoal.ctorName
568-
| _ => unreachable!
569-
return { mvarId := subgoal.mvarId, vars := newVars, alts := newAlts, examples := examples }
606+
-- withTraceNode `Meta.Match.match (msg := (return m!"{exceptEmoji ·} case {subgoal.ctorName}")) do
607+
if let some ctorName := subgoal.ctorName then
608+
-- A normal constructor case
609+
let subst := subgoal.subst
610+
let fields := subgoal.fields.toList
611+
let newVars := fields ++ xs
612+
let newVars := newVars.map fun x => x.applyFVarSubst subst
613+
let subex := Example.ctor ctorName <| fields.map fun field => match field with
614+
| .fvar fvarId => Example.var fvarId
615+
| _ => Example.underscore -- This case can happen due to dependent elimination
616+
let examples := p.examples.map <| Example.replaceFVarId x.fvarId! subex
617+
let examples := examples.map <| Example.applyFVarSubst subst
618+
let newAlts := p.alts.filter fun alt => match alt.patterns with
619+
| .ctor n .. :: _ => n == ctorName
620+
| .var _ :: _ => true
621+
| .inaccessible _ :: _ => true
622+
| _ => false
623+
let newAlts := newAlts.map fun alt => alt.applyFVarSubst subst
624+
let newAlts ← newAlts.mapM fun alt => do
625+
match alt.patterns with
626+
| .ctor _ _ _ fields :: ps => return { alt with patterns := fields ++ ps }
627+
| .var _ :: _ => expandVarIntoCtor alt ctorName
628+
| .inaccessible _ :: _ => processInaccessibleAsCtor alt ctorName
629+
| _ => unreachable!
630+
return { mvarId := subgoal.mvarId, vars := newVars, alts := newAlts, examples := examples }
631+
else
632+
-- A catch-all case
633+
let subst := subgoal.subst
634+
trace[Meta.Match.match] "constructor catch-all case"
635+
let examples := p.examples.map <| Example.applyFVarSubst subst
636+
let newVars := p.vars.map fun x => x.applyFVarSubst subst
637+
let newAlts := p.alts.filter fun alt => match alt.patterns with
638+
| .ctor .. :: _ => false
639+
| _ => true
640+
let newAlts := newAlts.map fun alt => alt.applyFVarSubst subst
641+
return { mvarId := subgoal.mvarId, alts := newAlts, vars := newVars, examples := examples }
570642

571643
private def processNonVariable (p : Problem) : MetaM Problem := withGoalOf p do
572644
let x :: xs := p.vars | unreachable!
@@ -808,6 +880,15 @@ def processExFalso (p : Problem) : MetaM Problem := do
808880
let mvarId' ← p.mvarId.exfalso
809881
return { p with mvarId := mvarId' }
810882

883+
private def tracedForM (xs : Array α) (process : α → StateRefT State MetaM Unit) : StateRefT State MetaM Unit :=
884+
if xs.size > 1 then
885+
for x in xs, i in [:xs.size] do
886+
withTraceNode `Meta.Match.match (msg := (return m!"{exceptEmoji ·} subgoal {i+1}/{xs.size}")) do
887+
process x
888+
else
889+
for x in xs do
890+
process x
891+
811892
private partial def process (p : Problem) : StateRefT State MetaM Unit := do
812893
traceState p
813894
if isDone p then
@@ -870,7 +951,7 @@ private partial def process (p : Problem) : StateRefT State MetaM Unit := do
870951

871952
if (← isConstructorTransition p) then
872953
let ps ← processConstructor p
873-
ps.forM process
954+
tracedForM ps process
874955
return
875956

876957
if isVariableTransition p then
@@ -881,12 +962,12 @@ private partial def process (p : Problem) : StateRefT State MetaM Unit := do
881962

882963
if isValueTransition p then
883964
let ps ← processValue p
884-
ps.forM process
965+
tracedForM ps process
885966
return
886967

887968
if isArrayLitTransition p then
888969
let ps ← processArrayLit p
889-
ps.forM process
970+
tracedForM ps process
890971
return
891972

892973
if (← hasNatValPattern p) then

src/Lean/Meta/Tactic/Cases.lean

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ prelude
99
public import Lean.Meta.Tactic.Induction
1010
public import Lean.Meta.Tactic.Acyclic
1111
public import Lean.Meta.Tactic.UnifyEq
12+
import Lean.Meta.Constructions.SparseCasesOn
13+
import Lean.Meta.Constructions.CtorIdx
1214

1315
public section
1416

@@ -149,7 +151,8 @@ def generalizeIndices (mvarId : MVarId) (fvarId : FVarId) : MetaM GeneralizeIndi
149151
generalizeIndices' mvarId fvarDecl.toExpr fvarDecl.userName
150152

151153
structure CasesSubgoal extends InductionSubgoal where
152-
ctorName : Name
154+
/-- The constructor of this subgoal. Is `none` in the catch-all of a sparse case match -/
155+
ctorName : Option Name
153156

154157
namespace Cases
155158

@@ -216,11 +219,13 @@ private def elimAuxIndices (s₁ : GeneralizeIndicesSubgoal) (s₂ : Array Cases
216219
private def toCasesSubgoals (s : Array InductionSubgoal) (ctorNames : Array Name) (majorFVarId : FVarId) (us : List Level) (params : Array Expr)
217220
: Array CasesSubgoal :=
218221
s.mapIdx fun i s =>
219-
let ctorName := ctorNames[i]!
220-
let ctorApp := mkAppN (mkAppN (mkConst ctorName us) params) s.fields
221-
let s := { s with subst := s.subst.insert majorFVarId ctorApp }
222-
{ ctorName := ctorName,
223-
toInductionSubgoal := s }
222+
if _ : i < ctorNames.size then
223+
let ctorName := ctorNames[i]
224+
let ctorApp := mkAppN (mkAppN (mkConst ctorName us) params) s.fields
225+
let subst := s.subst.erase majorFVarId |>.insert majorFVarId ctorApp
226+
{ s with ctorName := ctorName, subst}
227+
else
228+
{ s with ctorName := none }
224229

225230
partial def unifyEqs? (numEqs : Nat) (mvarId : MVarId) (subst : FVarSubst) (caseName? : Option Name := none): MetaM (Option (MVarId × FVarSubst)) := withIncRecDepth do
226231
if numEqs == 0 then
@@ -244,17 +249,33 @@ private def unifyCasesEqs (numEqs : Nat) (subgoals : Array CasesSubgoal) : MetaM
244249
}
245250

246251
private def inductionCasesOn (mvarId : MVarId) (majorFVarId : FVarId) (givenNames : Array AltVarNames) (ctx : Context)
247-
(useNatCasesAuxOn : Bool := false) : MetaM (Array CasesSubgoal) := mvarId.withContext do
252+
(useNatCasesAuxOn : Bool := false) (interestingCtors? : Option (Array Name) := none) :
253+
MetaM (Array CasesSubgoal) := mvarId.withContext do
248254
let majorType ← inferType (mkFVar majorFVarId)
249255
let (us, params) ← getInductiveUniverseAndParams majorType
250-
let mut casesOn := mkCasesOnName ctx.inductiveVal.name
251-
if useNatCasesAuxOn && ctx.inductiveVal.name == ``Nat && (← getEnv).contains ``Nat.casesAuxOn then
252-
casesOn := ``Nat.casesAuxOn
256+
257+
if let some interestingCtors := interestingCtors? then
258+
-- Avoid Init.Prelude complications
259+
let hasNE := (← getEnv).contains ``Ne
260+
-- We can only create a sparse casesOn if we have `ctorIdx` (in particular, if it is a type)
261+
let hasCtorIdx := (← getEnv).contains (mkCtorIdxName ctx.inductiveVal.name)
262+
if hasNE && hasCtorIdx && !interestingCtors.isEmpty &&
263+
interestingCtors.size < ctx.inductiveVal.ctors.length then
264+
let casesOn ← Lean.Meta.mkSparseCasesOn ctx.inductiveVal.name interestingCtors
265+
let s ← mvarId.induction majorFVarId casesOn givenNames
266+
return toCasesSubgoals s interestingCtors majorFVarId us params
267+
268+
let casesOn :=
269+
if useNatCasesAuxOn && ctx.inductiveVal.name == ``Nat && (← getEnv).contains ``Nat.casesAuxOn then
270+
``Nat.casesAuxOn
271+
else
272+
mkCasesOnName ctx.inductiveVal.name
253273
let ctors := ctx.inductiveVal.ctors.toArray
254274
let s ← mvarId.induction majorFVarId casesOn givenNames
255275
return toCasesSubgoals s ctors majorFVarId us params
256276

257-
def cases (mvarId : MVarId) (majorFVarId : FVarId) (givenNames : Array AltVarNames := #[]) (useNatCasesAuxOn : Bool := false) : MetaM (Array CasesSubgoal) := do
277+
def cases (mvarId : MVarId) (majorFVarId : FVarId) (givenNames : Array AltVarNames := #[])
278+
(useNatCasesAuxOn : Bool := false) (interestingCtors? : Option (Array Name) := none) : MetaM (Array CasesSubgoal) := do
258279
try
259280
mvarId.withContext do
260281
mvarId.checkNotAssigned `cases
@@ -268,10 +289,11 @@ def cases (mvarId : MVarId) (majorFVarId : FVarId) (givenNames : Array AltVarNam
268289
if (← hasIndepIndices ctx) then
269290
-- Simple case
270291
inductionCasesOn mvarId majorFVarId givenNames ctx (useNatCasesAuxOn := useNatCasesAuxOn)
292+
(interestingCtors? := interestingCtors?)
271293
else
272294
let s₁ ← generalizeIndices mvarId majorFVarId
273295
trace[Meta.Tactic.cases] "after generalizeIndices\n{MessageData.ofGoal s₁.mvarId}"
274-
let s₂ ← inductionCasesOn s₁.mvarId s₁.fvarId givenNames ctx
296+
let s₂ ← inductionCasesOn s₁.mvarId s₁.fvarId givenNames ctx (interestingCtors? := interestingCtors?)
275297
let s₂ ← elimAuxIndices s₁ s₂
276298
unifyCasesEqs s₁.numEqs s₂
277299
catch ex =>
@@ -288,8 +310,9 @@ Apply `casesOn` using the free variable `majorFVarId` as the major premise (aka
288310
It enables using `Nat.casesAuxOn` instead of `Nat.casesOn`,
289311
which causes case splits on `n : Nat` to be represented as `0` and `n' + 1` rather than as `Nat.zero` and `Nat.succ n'`.
290312
-/
291-
def _root_.Lean.MVarId.cases (mvarId : MVarId) (majorFVarId : FVarId) (givenNames : Array AltVarNames := #[]) (useNatCasesAuxOn : Bool := false) : MetaM (Array CasesSubgoal) :=
292-
Cases.cases mvarId majorFVarId givenNames (useNatCasesAuxOn := useNatCasesAuxOn)
313+
def _root_.Lean.MVarId.cases (mvarId : MVarId) (majorFVarId : FVarId) (givenNames : Array AltVarNames := #[]) (useNatCasesAuxOn : Bool := false)
314+
(interestingCtors? : Option (Array Name) := none) : MetaM (Array CasesSubgoal) :=
315+
Cases.cases mvarId majorFVarId givenNames (useNatCasesAuxOn := useNatCasesAuxOn) (interestingCtors? := interestingCtors?)
293316

294317
/--
295318
Keep applying `cases` on any hypothesis that satisfies `p`.

0 commit comments

Comments
 (0)