Skip to content

Commit c7f57d6

Browse files
authored
fix: avoid unnecessary branching in match compilation (#10763)
This PR improves match compilation: Branch on variables in the order suggested by the first remaining alternative, and do not branch when the first remaining alternative does not require it. This fixes #10749. With `set_option backwards.match.rowMajor false` the old behavior can be turned on. (For now this is an experiment to get familiar with the code and the whole problem domain. It is likely overly naive.)
1 parent 275f907 commit c7f57d6

File tree

6 files changed

+110
-39
lines changed

6 files changed

+110
-39
lines changed

src/Lean/Meta/Match/Match.lean

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,15 @@ public section
1717

1818
namespace Lean.Meta.Match
1919

20+
register_builtin_option backwards.match.rowMajor : Bool := {
21+
defValue := true
22+
group := "bootstrap"
23+
descr := "If true (the default), match compilation will split the discrimnants based \
24+
on position of the first constructor pattern in the first alternative. If false, \
25+
it splits them from left to right, which can lead to unnecessary code bloat."
26+
}
27+
28+
2029
private def mkIncorrectNumberOfPatternsMsg [ToMessageData α]
2130
(discrepancyKind : String) (expected actual : Nat) (pats : List α) :=
2231
let patternsMsg := MessageData.joinSep (pats.map toMessageData) ", "
@@ -747,6 +756,46 @@ private def checkNextPatternTypes (p : Problem) : MetaM Unit := do
747756
unless (← isDefEq xType eType) do
748757
throwError "Type mismatch in pattern: Pattern{indentExpr e}\n{← mkHasTypeButIsExpectedMsg eType xType}"
749758

759+
private def List.moveToFront [Inhabited α] (as : List α) (i : Nat) : List α :=
760+
let rec loop : (as : List α) → (i : Nat) → α × List α
761+
| [], _ => unreachable!
762+
| a::as, 0 => (a, as)
763+
| a::as, i+1 =>
764+
let (b, bs) := loop as i
765+
(b, a::bs)
766+
let (b, bs) := loop as i
767+
b :: bs
768+
769+
/-- Move variable `#i` to the beginning of the to-do list `p.vars`. -/
770+
private def moveToFront (p : Problem) (i : Nat) : Problem :=
771+
if i == 0 then
772+
p
773+
else if i < p.vars.length then
774+
{ p with
775+
vars := List.moveToFront p.vars i
776+
alts := p.alts.map fun alt => { alt with patterns := List.moveToFront alt.patterns i }
777+
}
778+
else
779+
p
780+
781+
def Pattern.isRefutable : Pattern → Bool
782+
| .var _ => false
783+
| .inaccessible _ => false
784+
| .as _ p _ => p.isRefutable
785+
| .arrayLit .. => true
786+
| .ctor .. => true
787+
| .val .. => true
788+
789+
/--
790+
Returns the index of the first pattern in the first alternative that is refutable
791+
(i.e. not a variable or inaccessible pattern). We want to handle these first
792+
so that the generated code branches in the order suggested by the user's code.
793+
-/
794+
private def firstRefutablePattern (p : Problem) : Option Nat :=
795+
match p.alts with
796+
| alt:: _ => alt.patterns.findIdx? (·.isRefutable)
797+
| _ => none
798+
750799
def isExFalsoTransition (p : Problem) : MetaM Bool := do
751800
if p.alts.isEmpty then
752801
withGoalOf p do
@@ -778,6 +827,21 @@ private partial def process (p : Problem) : StateRefT State MetaM Unit := do
778827
process p
779828
return
780829

830+
if backwards.match.rowMajor.get (← getOptions) then
831+
match firstRefutablePattern p with
832+
| some i =>
833+
if i > 0 then
834+
traceStep ("move var to front")
835+
process (moveToFront p i)
836+
return
837+
| none =>
838+
if 1 < p.alts.length then
839+
traceStep ("drop all but first alt")
840+
-- all patterns are irrefutable, we can drop all other alts
841+
let p := { p with alts := p.alts.take 1 }
842+
process p
843+
return
844+
781845
if (← isNatValueTransition p) then
782846
traceStep ("nat value to constructor")
783847
process (← expandNatValuePattern p)

src/Lean/Meta/Tactic/Cases.lean

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -237,11 +237,11 @@ partial def unifyEqs? (numEqs : Nat) (mvarId : MVarId) (subst : FVarSubst) (case
237237
return none
238238

239239
private def unifyCasesEqs (numEqs : Nat) (subgoals : Array CasesSubgoal) : MetaM (Array CasesSubgoal) :=
240-
subgoals.foldlM (init := #[]) fun subgoals s => do
240+
subgoals.filterMapM fun s => do
241241
match (← unifyEqs? numEqs s.mvarId s.subst s.ctorName) with
242-
| none => pure subgoals
242+
| none => pure none
243243
| some (mvarId, subst) =>
244-
return subgoals.push { s with
244+
return some { s with
245245
mvarId := mvarId,
246246
subst := subst,
247247
fields := s.fields.map (subst.apply ·)

tests/lean/run/issue10749.lean

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def test (a : List Nat) : Nat :=
1212
/--
1313
info: def test.match_1.{u_1} : (motive : List Nat → Sort u_1) →
1414
(a : List Nat) → ((x : List Nat) → motive x) → (Unit → motive []) → motive a :=
15-
fun motive a h_1 h_2 => List.casesOn a (h_1 []) fun head tail => h_1 (head :: tail)
15+
fun motive a h_1 h_2 => h_1 a
1616
-/
1717
#guard_msgs in #print test.match_1
1818

@@ -31,7 +31,7 @@ info: def test2.match_1.{u_1} : (motive : List Nat → List Nat → Sort u_1)
3131
(tail : List Nat) → (head_1 : Nat) → (tail_1 : List Nat) → motive (head :: tail) (head_1 :: tail_1)) →
3232
motive a b :=
3333
fun motive a b h_1 h_2 h_3 =>
34-
List.casesOn a (List.casesOn b (h_1 []) fun head tail => h_1 (head :: tail)) fun head tail =>
34+
List.casesOn a (h_1 b) fun head tail =>
3535
List.casesOn b (h_2 (head :: tail)) fun head_1 tail_1 => h_3 head tail head_1 tail_1
3636
-/
3737
#guard_msgs in #print test2.match_1
@@ -51,8 +51,7 @@ info: def test3.match_1.{u_1} : (motive : List Nat → Bool → Sort u_1) →
5151
((x : List Nat) → motive x true) →
5252
((x : Bool) → motive [] x) → ((x : List Nat) → (x_1 : Bool) → motive x x_1) → motive a b :=
5353
fun motive a b h_1 h_2 h_3 =>
54-
List.casesOn a (Bool.casesOn b (h_2 false) (h_1 [])) fun head tail =>
55-
Bool.casesOn b (h_3 (head :: tail) false) (h_1 (head :: tail))
54+
Bool.casesOn b (List.casesOn a (h_2 false) fun head tail => h_3 (head :: tail) false) (h_1 a)
5655
-/
5756
#guard_msgs in #print test3.match_1
5857

@@ -79,29 +78,33 @@ info: def test4.match_1.{u_1} : (motive : Bool → Bool → Bool → Bool → Bo
7978
((x x_5 x_6 x_7 : Bool) → motive true x x_5 x_6 x_7) →
8079
((x x_5 x_6 x_7 x_8 : Bool) → motive x x_5 x_6 x_7 x_8) → motive x x_1 x_2 x_3 x_4 :=
8180
fun motive x x_1 x_2 x_3 x_4 h_1 h_2 h_3 h_4 h_5 h_6 =>
82-
Bool.casesOn x
83-
(Bool.casesOn x_1
84-
(Bool.casesOn x_2
85-
(Bool.casesOn x_3 (Bool.casesOn x_4 (h_6 false false false false false) (h_1 false false false false))
86-
(Bool.casesOn x_4 (h_2 false false false false) (h_1 false false false true)))
87-
(Bool.casesOn x_3 (Bool.casesOn x_4 (h_3 false false false false) (h_1 false false true false))
88-
(Bool.casesOn x_4 (h_2 false false true false) (h_1 false false true true))))
89-
(Bool.casesOn x_2
90-
(Bool.casesOn x_3 (Bool.casesOn x_4 (h_4 false false false false) (h_1 false true false false))
91-
(Bool.casesOn x_4 (h_2 false true false false) (h_1 false true false true)))
92-
(Bool.casesOn x_3 (Bool.casesOn x_4 (h_3 false true false false) (h_1 false true true false))
93-
(Bool.casesOn x_4 (h_2 false true true false) (h_1 false true true true)))))
94-
(Bool.casesOn x_1
81+
Bool.casesOn x_4
82+
(Bool.casesOn x_3
9583
(Bool.casesOn x_2
96-
(Bool.casesOn x_3 (Bool.casesOn x_4 (h_5 false false false false) (h_1 true false false false))
97-
(Bool.casesOn x_4 (h_2 true false false false) (h_1 true false false true)))
98-
(Bool.casesOn x_3 (Bool.casesOn x_4 (h_3 true false false false) (h_1 true false true false))
99-
(Bool.casesOn x_4 (h_2 true false true false) (h_1 true false true true))))
100-
(Bool.casesOn x_2
101-
(Bool.casesOn x_3 (Bool.casesOn x_4 (h_4 true false false false) (h_1 true true false false))
102-
(Bool.casesOn x_4 (h_2 true true false false) (h_1 true true false true)))
103-
(Bool.casesOn x_3 (Bool.casesOn x_4 (h_3 true true false false) (h_1 true true true false))
104-
(Bool.casesOn x_4 (h_2 true true true false) (h_1 true true true true)))))
84+
(Bool.casesOn x_1 (Bool.casesOn x (h_6 false false false false false) (h_5 false false false false))
85+
(h_4 x false false false))
86+
(h_3 x x_1 false false))
87+
(h_2 x x_1 x_2 false))
88+
(h_1 x x_1 x_2 x_3)
10589
-/
10690
#guard_msgs in
10791
#print test4.match_1
92+
93+
-- Just testing the backwards compatibility option
94+
95+
set_option match.ignoreUnusedAlts true in
96+
set_option backwards.match.rowMajor false in
97+
def testOld (a : List Nat) : Nat :=
98+
match a with
99+
| _ => 3
100+
| [] => 4
101+
102+
-- Has unnecessary `casesOn`
103+
104+
/--
105+
info: def testOld.match_1.{u_1} : (motive : List Nat → Sort u_1) →
106+
(a : List Nat) → ((x : List Nat) → motive x) → (Unit → motive []) → motive a :=
107+
fun motive a h_1 h_2 => List.casesOn a (h_1 []) fun head tail => h_1 (head :: tail)
108+
-/
109+
#guard_msgs in
110+
#print testOld.match_1

tests/lean/run/issue10794.lean

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
11
/--
2-
error: Dependent match elimination failed: Could not solve constraints
3-
true ≋ false
2+
error: Dependent elimination failed: Type mismatch when solving this alternative: it has type
3+
motive false
4+
but is expected to have type
5+
motive b✝
46
-/
57
#guard_msgs in
68
def test1 b := match b with
79
| .(false) => 1
810
| true => 2
911

1012
/--
11-
error: Dependent match elimination failed: Could not solve constraints
12-
true ≋ false
13+
error: Dependent elimination failed: Type mismatch when solving this alternative: it has type
14+
motive false ⋯
15+
but is expected to have type
16+
motive x✝¹ x✝
1317
-/
1418
#guard_msgs in
1519
def test2 : (b : Bool) → (h : b = false) → Nat

tests/lean/run/match1.lean

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,9 +137,9 @@ partial def natToBin' : (n : Nat) → List Bool
137137
/--
138138
error: Tactic `cases` failed with a nested error:
139139
Dependent elimination failed: Failed to solve equation
140-
Nat.zero = n✝.add n✝
140+
n✝¹.succ = n✝.add n✝
141141
at case `Parity.even` after processing
142-
Nat.zero, _
142+
(Nat.succ _), _
143143
the dependent pattern matcher can solve the following kinds of equations
144144
- <var> = <term> and <term> = <var>
145145
- <term> = <term> where the terms are definitionally equal

tests/lean/run/matchOverlapInaccesible.lean

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ else
1616
/--
1717
error: Tactic `cases` failed with a nested error:
1818
Dependent elimination failed: Failed to solve equation
19-
Nat.zero = n✝.add n✝
19+
n✝¹.succ = n✝.add n✝
2020
at case `Parity.even` after processing
21-
Nat.zero, _
21+
(Nat.succ _), _
2222
the dependent pattern matcher can solve the following kinds of equations
2323
- <var> = <term> and <term> = <var>
2424
- <term> = <term> where the terms are definitionally equal
@@ -56,9 +56,9 @@ def parity (n : MyNat) : Parity n := sorry
5656
/--
5757
error: Tactic `cases` failed with a nested error:
5858
Dependent elimination failed: Failed to solve equation
59-
zero = n✝.add n✝
59+
a✝.succ = n✝.add n✝
6060
at case `Parity.even` after processing
61-
zero, _
61+
(succ _), _
6262
the dependent pattern matcher can solve the following kinds of equations
6363
- <var> = <term> and <term> = <var>
6464
- <term> = <term> where the terms are definitionally equal

0 commit comments

Comments
 (0)