Skip to content

Commit dc4d7f6

Browse files
committed
perf: when matching on values, avoid generating hyps when not needed
This PR avoids generating hyps when not needed (i.e. if there is a catch-all so no completeness checking needed) during matching on values. This tweak was made possible by #11220.
1 parent 035b886 commit dc4d7f6

File tree

3 files changed

+65
-11
lines changed

3 files changed

+65
-11
lines changed

src/Lean/Meta/Match/CaseValues.lean

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,11 @@ structure CaseValuesSubgoal where
8282
Remark: the field `newHs` has size 1 forall but the last subgoal.
8383
8484
If `substNewEqs = true`, then the new `h_i` equality hypotheses are substituted in the first `n` cases.
85+
86+
If `needsHyps = false` then the else case has the hypotheses cleared.
8587
-/
86-
def caseValues (mvarId : MVarId) (fvarId : FVarId) (values : Array Expr) (hNamePrefix := `h) (substNewEqs := false) : MetaM (Array CaseValuesSubgoal) :=
88+
def caseValues (mvarId : MVarId) (fvarId : FVarId) (values : Array Expr) (hNamePrefix := `h)
89+
(substNewEqs := false) (needHyps := true) : MetaM (Array CaseValuesSubgoal) :=
8790
let rec loop : Nat → MVarId → List Expr → Array FVarId → Array CaseValuesSubgoal → MetaM (Array CaseValuesSubgoal)
8891
| _, mvarId, [], _, _ => throwTacticEx `caseValues mvarId "list of values must not be empty"
8992
| i, mvarId, v::vs, hs, subgoals => do
@@ -103,7 +106,14 @@ def caseValues (mvarId : MVarId) (fvarId : FVarId) (values : Array Expr) (hNameP
103106
| [] => do
104107
appendTagSuffix elseSubgoal.mvarId ((`case).appendIndexAfter (i+1))
105108
pure $ subgoals.push { mvarId := elseSubgoal.mvarId, newHs := hs.push elseSubgoal.newH, subst := {} }
106-
| vs => loop (i+1) elseSubgoal.mvarId vs (hs.push elseSubgoal.newH) subgoals
109+
| vs =>
110+
let (mvarId', hs') ←
111+
if needHyps then
112+
pure (elseSubgoal.mvarId, hs.push elseSubgoal.newH)
113+
else
114+
pure (← elseSubgoal.mvarId.tryClear elseSubgoal.newH, hs)
115+
loop (i+1) mvarId' vs hs' subgoals
116+
107117
loop 1 mvarId values.toList #[] #[]
108118

109119
end Lean.Meta

src/Lean/Meta/Match/Match.lean

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -720,11 +720,23 @@ private def isFirstPatternVar (alt : Alt) : Bool :=
720720
| .var _ :: _ => true
721721
| _ => false
722722

723+
private def Pattern.isRefutable : Pattern → Bool
724+
| .var _ => false
725+
| .inaccessible _ => false
726+
| .as _ p _ => p.isRefutable
727+
| .arrayLit .. => true
728+
| .ctor .. => true
729+
| .val .. => true
730+
731+
private def triviallyComplete (p : Problem) : Bool :=
732+
!p.alts.isEmpty && p.alts.getLast!.patterns.all (!·.isRefutable)
733+
723734
private def processValue (p : Problem) : MetaM (Array Problem) := do
724735
trace[Meta.Match.match] "value step"
725736
let x :: xs := p.vars | unreachable!
726737
let values := collectValues p
727-
let subgoals ← caseValues p.mvarId x.fvarId! values (substNewEqs := true)
738+
let needHyps := !triviallyComplete p || p.alts.any (!·.notAltIdxs.isEmpty)
739+
let subgoals ← caseValues p.mvarId x.fvarId! values (substNewEqs := true) (needHyps := needHyps)
728740
subgoals.mapIdxM fun i subgoal => do
729741
trace[Meta.Match.match] "processValue subgoal\n{MessageData.ofGoal subgoal.mvarId}"
730742
if h : i < values.size then
@@ -898,14 +910,6 @@ private def moveToFront (p : Problem) (i : Nat) : Problem :=
898910
else
899911
p
900912

901-
def Pattern.isRefutable : Pattern → Bool
902-
| .var _ => false
903-
| .inaccessible _ => false
904-
| .as _ p _ => p.isRefutable
905-
| .arrayLit .. => true
906-
| .ctor .. => true
907-
| .val .. => true
908-
909913
/--
910914
Returns the index of the first pattern in the first alternative that is refutable
911915
(i.e. not a variable or inaccessible pattern). We want to handle these first

tests/lean/run/match_nat.lean

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
def f : Nat → Nat
2+
| 0 => 0
3+
| 10 => 1
4+
| 100 => 2
5+
| _ => 3
6+
7+
8+
/--
9+
info: def f.match_1.{u_1} : (motive : Nat → Sort u_1) →
10+
(x : Nat) → (Unit → motive 0) → (Unit → motive 10) → (Unit → motive 100) → ((x : Nat) → motive x) → motive x
11+
-/
12+
#guard_msgs in
13+
#print sig f.match_1
14+
15+
16+
/--
17+
info: private def f.match_1.splitter.{u_1} : (motive : Nat → Sort u_1) →
18+
(x : Nat) →
19+
(Unit → motive 0) →
20+
(Unit → motive 10) →
21+
(Unit → motive 100) → ((x : Nat) → (x = 0 → False) → (x = 10 → False) → (x = 100 → False) → motive x) → motive x
22+
-/
23+
#guard_msgs in
24+
#print sig f.match_1.splitter
25+
26+
/--
27+
info: private theorem f.match_1.eq_4.{u_1} : ∀ (motive : Nat → Sort u_1) (x : Nat) (h_1 : Unit → motive 0)
28+
(h_2 : Unit → motive 10) (h_3 : Unit → motive 100) (h_4 : (x : Nat) → motive x),
29+
(x = 0 → False) →
30+
(x = 10 → False) →
31+
(x = 100 → False) →
32+
(match x with
33+
| 0 => h_1 ()
34+
| 10 => h_2 ()
35+
| 100 => h_3 ()
36+
| x => h_4 x) =
37+
h_4 x
38+
-/
39+
#guard_msgs in
40+
#print sig f.match_1.eq_4

0 commit comments

Comments
 (0)