Skip to content

Commit 56a00e6

Browse files
nomeataalgebraic-dev
authored andcommitted
perf: when matching on values, avoid generating hyps when not needed (#11508)
This PR avoids generating hyps when not needed (i.e. if there is a catch-all so no completeness checking needed) during matching on values. This tweak was made possible by #11220.
1 parent e84a2af commit 56a00e6

File tree

4 files changed

+99
-42
lines changed

4 files changed

+99
-42
lines changed

src/Lean/Meta/Match/CaseValues.lean

Lines changed: 29 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,14 @@ import Lean.Meta.Tactic.Subst
1212

1313
namespace Lean.Meta
1414

15-
structure CaseValueSubgoal where
16-
mvarId : MVarId
17-
newH : FVarId
18-
deriving Inhabited
19-
2015
/--
21-
Split goal `... |- C x` into two subgoals
22-
`..., (h : x = value) |- C x`
23-
`..., (h : x != value) |- C x`
24-
where `fvarId` is `x`s id.
16+
Split goal `... |- C x`,, where `fvarId` is `x`s id, into two subgoals
17+
`..., |- (h : x = value) → C x`
18+
`..., |- (h : x != value) → C x`
2519
The type of `x` must have decidable equality.
2620
-/
2721
def caseValue (mvarId : MVarId) (fvarId : FVarId) (value : Expr) (hName : Name := `h)
28-
: MetaM (CaseValueSubgoal × CaseValueSubgoal) :=
22+
: MetaM (MVarId × MVarId) :=
2923
mvarId.withContext do
3024
let tag ← mvarId.getTag
3125
mvarId.checkNotAssigned `caseValue
@@ -38,15 +32,7 @@ def caseValue (mvarId : MVarId) (fvarId : FVarId) (value : Expr) (hName : Name :
3832
let elseMVar ← mkFreshExprSyntheticOpaqueMVar elseTarget tag
3933
let val ← mkAppOptM `dite #[none, xEqValue, none, thenMVar, elseMVar]
4034
mvarId.assign val
41-
let (elseH, elseMVarId) ← elseMVar.mvarId!.intro1P
42-
let elseSubgoal := { mvarId := elseMVarId, newH := elseH }
43-
let (thenH, thenMVarId) ← thenMVar.mvarId!.intro1P
44-
thenMVarId.withContext do
45-
trace[Meta] "searching for decl"
46-
let _ ← thenH.getDecl
47-
trace[Meta] "found decl"
48-
let thenSubgoal := { mvarId := thenMVarId, newH := thenH }
49-
pure (thenSubgoal, elseSubgoal)
35+
return (thenMVar.mvarId!, elseMVar.mvarId!)
5036

5137
public structure CaseValuesSubgoal where
5238
mvarId : MVarId
@@ -55,34 +41,44 @@ public structure CaseValuesSubgoal where
5541
deriving Inhabited
5642

5743
/--
58-
Split goal `... |- C x` into values.size + 1 subgoals
59-
1) `..., (h_1 : x = value[0]) |- C value[0]`
44+
Split goal `... |- C x`, where `fvarId` is `x`s id, into `values.size + 1` subgoals
45+
1) `..., (h_1 : x = value[0]) |- C value[0]`
6046
...
61-
n) `..., (h_n : x = value[n - 1]) |- C value[n - 1]`
47+
n) `..., (h_n : x = value[n - 1]) |- C value[n - 1]`
6248
n+1) `..., (h_1 : x != value[0]) ... (h_n : x != value[n-1]) |- C x`
6349
where `n = values.size`
64-
where `fvarId` is `x`s id.
6550
The type of `x` must have decidable equality.
6651
6752
Remark: the last subgoal is for the "else" catchall case, and its `subst` is `{}`.
6853
Remark: the field `newHs` has size 1 forall but the last subgoal.
6954
70-
If `substNewEqs = true`, then the new `h_i` equality hypotheses are substituted in the first `n` cases.
55+
If `needsHyps = false` then the else case comes without hypotheses.
7156
-/
72-
public def caseValues (mvarId : MVarId) (fvarId : FVarId) (values : Array Expr) (hNamePrefix := `h) : MetaM (Array CaseValuesSubgoal) :=
57+
public def caseValues (mvarId : MVarId) (fvarId : FVarId) (values : Array Expr) (hNamePrefix := `h)
58+
(needHyps := true) : MetaM (Array CaseValuesSubgoal) :=
7359
let rec loop : Nat → MVarId → List Expr → Array FVarId → Array CaseValuesSubgoal → MetaM (Array CaseValuesSubgoal)
7460
| _, mvarId, [], _, _ => throwTacticEx `caseValues mvarId "list of values must not be empty"
7561
| i, mvarId, v::vs, hs, subgoals => do
76-
let (thenSubgoal, elseSubgoal) ← caseValue mvarId fvarId v (hNamePrefix.appendIndexAfter i)
77-
appendTagSuffix thenSubgoal.mvarId ((`case).appendIndexAfter i)
78-
let thenMVarId ← thenSubgoal.mvarId.tryClearMany hs
79-
let (subst, mvarId) ← substCore thenMVarId thenSubgoal.newH (symm := false) {} (clearH := true)
80-
let subgoals := subgoals.push { mvarId := mvarId, newHs := #[], subst := subst }
62+
let (thenMVarId, elseMVarId) ← caseValue mvarId fvarId v (hNamePrefix.appendIndexAfter i)
63+
appendTagSuffix thenMVarId ((`case).appendIndexAfter i)
64+
let thenMVarId ← thenMVarId.tryClearMany hs
65+
let (thenH, thenMVarId) ← thenMVarId.intro1P
66+
let (subst, thenMVarId) ← substCore thenMVarId thenH (symm := false) {} (clearH := true)
67+
let subgoals := subgoals.push { mvarId := thenMVarId, newHs := #[], subst := subst }
68+
let (hs', elseMVarId) ←
69+
if needHyps then
70+
let (elseH, elseMVarId) ← elseMVarId.intro1P
71+
pure (hs.push elseH, elseMVarId)
72+
else
73+
let elseMVarId ← elseMVarId.intro1_
74+
pure (hs, elseMVarId)
8175
match vs with
8276
| [] => do
83-
appendTagSuffix elseSubgoal.mvarId ((`case).appendIndexAfter (i+1))
84-
pure $ subgoals.push { mvarId := elseSubgoal.mvarId, newHs := hs.push elseSubgoal.newH, subst := {} }
85-
| vs => loop (i+1) elseSubgoal.mvarId vs (hs.push elseSubgoal.newH) subgoals
77+
appendTagSuffix elseMVarId ((`case).appendIndexAfter (i+1))
78+
pure $ subgoals.push { mvarId := elseMVarId, newHs := hs', subst := {} }
79+
| vs =>
80+
loop (i+1) elseMVarId vs hs' subgoals
81+
8682
loop 1 mvarId values.toList #[] #[]
8783

8884
end Lean.Meta

src/Lean/Meta/Match/Match.lean

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -722,11 +722,23 @@ private def isFirstPatternVar (alt : Alt) : Bool :=
722722
| .var _ :: _ => true
723723
| _ => false
724724

725+
private def Pattern.isRefutable : Pattern → Bool
726+
| .var _ => false
727+
| .inaccessible _ => false
728+
| .as _ p _ => p.isRefutable
729+
| .arrayLit .. => true
730+
| .ctor .. => true
731+
| .val .. => true
732+
733+
private def triviallyComplete (p : Problem) : Bool :=
734+
!p.alts.isEmpty && p.alts.getLast!.patterns.all (!·.isRefutable)
735+
725736
private def processValue (p : Problem) : MetaM (Array Problem) := do
726737
trace[Meta.Match.match] "value step"
727738
let x :: xs := p.vars | unreachable!
728739
let values := collectValues p
729-
let subgoals ← caseValues p.mvarId x.fvarId! values
740+
let needHyps := !triviallyComplete p || p.alts.any (!·.notAltIdxs.isEmpty)
741+
let subgoals ← caseValues p.mvarId x.fvarId! values (needHyps := needHyps)
730742
subgoals.mapIdxM fun i subgoal => do
731743
trace[Meta.Match.match] "processValue subgoal\n{MessageData.ofGoal subgoal.mvarId}"
732744
if h : i < values.size then
@@ -900,14 +912,6 @@ private def moveToFront (p : Problem) (i : Nat) : Problem :=
900912
else
901913
p
902914

903-
def Pattern.isRefutable : Pattern → Bool
904-
| .var _ => false
905-
| .inaccessible _ => false
906-
| .as _ p _ => p.isRefutable
907-
| .arrayLit .. => true
908-
| .ctor .. => true
909-
| .val .. => true
910-
911915
/--
912916
Returns the index of the first pattern in the first alternative that is refutable
913917
(i.e. not a variable or inaccessible pattern). We want to handle these first

src/Lean/Meta/Tactic/Intro.lean

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,23 @@ does not start with a forall, lambda or let. -/
161161
abbrev _root_.Lean.MVarId.intro1P (mvarId : MVarId) : MetaM (FVarId × MVarId) :=
162162
intro1Core mvarId true
163163

164+
/--
165+
Given a goal `... |- β → α`, returns a goal `... ⊢ α`.
166+
Like `intro h; clear h`, but without ever appending to the local context.
167+
-/
168+
def _root_.Lean.MVarId.intro1_ (mvarId : MVarId) : MetaM MVarId := do
169+
mvarId.withContext do
170+
let target ← mvarId.getType'
171+
match target with
172+
| .forallE n β α bi =>
173+
if α.hasLooseBVars then
174+
throwError "intro1_: expected arrow type\n{mvarId}"
175+
let tag ← mvarId.getTag
176+
let newMVar ← mkFreshExprSyntheticOpaqueMVar α tag
177+
mvarId.assign (.lam n β newMVar bi)
178+
return newMVar.mvarId!
179+
| _ => throwError "intro1_: expected arrow type\n{mvarId}"
180+
164181
/--
165182
Calculate the number of new hypotheses that would be created by `intros`,
166183
i.e. the number of binders which can be introduced without unfolding definitions.

tests/lean/run/match_nat.lean

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
def f : Nat → Nat
2+
| 0 => 0
3+
| 10 => 1
4+
| 100 => 2
5+
| _ => 3
6+
7+
8+
/--
9+
info: def f.match_1.{u_1} : (motive : Nat → Sort u_1) →
10+
(x : Nat) → (Unit → motive 0) → (Unit → motive 10) → (Unit → motive 100) → ((x : Nat) → motive x) → motive x
11+
-/
12+
#guard_msgs in
13+
#print sig f.match_1
14+
15+
16+
/--
17+
info: private def f.match_1.splitter.{u_1} : (motive : Nat → Sort u_1) →
18+
(x : Nat) →
19+
(Unit → motive 0) →
20+
(Unit → motive 10) →
21+
(Unit → motive 100) → ((x : Nat) → (x = 0 → False) → (x = 10 → False) → (x = 100 → False) → motive x) → motive x
22+
-/
23+
#guard_msgs in
24+
#print sig f.match_1.splitter
25+
26+
/--
27+
info: private theorem f.match_1.eq_4.{u_1} : ∀ (motive : Nat → Sort u_1) (x : Nat) (h_1 : Unit → motive 0)
28+
(h_2 : Unit → motive 10) (h_3 : Unit → motive 100) (h_4 : (x : Nat) → motive x),
29+
(x = 0 → False) →
30+
(x = 10 → False) →
31+
(x = 100 → False) →
32+
(match x with
33+
| 0 => h_1 ()
34+
| 10 => h_2 ()
35+
| 100 => h_3 ()
36+
| x => h_4 x) =
37+
h_4 x
38+
-/
39+
#guard_msgs in
40+
#print sig f.match_1.eq_4

0 commit comments

Comments
 (0)