Skip to content

Commit acf1cf5

Browse files
committed
Try to treat casesOn as matchers almost everywhere
1 parent 549a97f commit acf1cf5

File tree

11 files changed

+70
-40
lines changed

11 files changed

+70
-40
lines changed

src/Lean/Compiler/LCNF/ToDecl.lean

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ Inline auxiliary `matcher` applications.
4040
partial def inlineMatchers (e : Expr) : CoreM Expr :=
4141
Meta.MetaM.run' <| Meta.transform e fun e => do
4242
let .const declName us := e.getAppFn | return .continue
43-
let some info ← Meta.getMatcherInfo? declName | return .continue
43+
let some info ← Meta.getMatcherInfo? declName (alsoCasesOn := false) | return .continue
4444
let numArgs := e.getAppNumArgs
4545
if numArgs > info.arity then
4646
return .continue

src/Lean/Meta/Match/MatchEqs.lean

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ private def _root_.Lean.MVarId.contradictionQuick (mvarId : MVarId) : MetaM Bool
4646
return false
4747

4848
/--
49-
Helper method for `proveCondEqThm`. Given a goal of the form `C.rec ... xMajor = rhs`,
49+
Helper method for `proveCondEqThm`. Given a goal of the form `C.casesOn ... xMajor ... = rhs`,
5050
apply `cases xMajor`. -/
5151
partial def casesOnStuckLHS (mvarId : MVarId) : MetaM (Array MVarId) := do
5252
let target ← mvarId.getType
@@ -71,14 +71,24 @@ where
7171
else
7272
return none
7373
| none =>
74-
matchConstRec f (fun _ => return none) fun recVal _ => do
75-
if recVal.getMajorIdx >= args.size then
74+
if isCasesOnRecursor (← getEnv) declName then
75+
let some matcherInfo := getMatcherInfoForCasesOn (← getEnv) declName | return none
76+
if matcherInfo.getFirstDiscrPos >= args.size then
7677
return none
77-
let major := args[recVal.getMajorIdx]!.consumeMData
78+
let major := args[matcherInfo.getLastDiscrPos]!.consumeMData
7879
if major.isFVar then
7980
return some major.fvarId!
8081
else
8182
return none
83+
else
84+
matchConstRec f (fun _ => return none) fun recVal _ => do
85+
if recVal.getMajorIdx >= args.size then
86+
return none
87+
let major := args[recVal.getMajorIdx]!.consumeMData
88+
if major.isFVar then
89+
return some major.fvarId!
90+
else
91+
return none
8292

8393
def casesOnStuckLHS? (mvarId : MVarId) : MetaM (Option (Array MVarId)) := do
8494
try casesOnStuckLHS mvarId catch _ => return none
@@ -751,7 +761,7 @@ def getEquationsForImpl (matchDeclName : Name) : MetaM MatchEqns := do
751761
where go baseName splitterName := withConfig (fun c => { c with etaStruct := .none }) do
752762
let constInfo ← getConstInfo matchDeclName
753763
let us := constInfo.levelParams.map mkLevelParam
754-
let some matchInfo ← getMatcherInfo? matchDeclName | throwError "`{matchDeclName}` is not a matcher function"
764+
let some matchInfo ← getMatcherInfo? matchDeclName (alsoCasesOn := true) | throwError "`{matchDeclName}` is not a matcher function"
755765
let numDiscrEqs := getNumEqsFromDiscrInfos matchInfo.discrInfos
756766
forallTelescopeReducing constInfo.type fun xs matchResultType => do
757767
let mut eqnNames := #[]
@@ -867,7 +877,7 @@ where go baseName := withConfig (fun c => { c with etaStruct := .none }) do
867877
withConfig (fun c => { c with etaStruct := .none }) do
868878
let constInfo ← getConstInfo matchDeclName
869879
let us := constInfo.levelParams.map mkLevelParam
870-
let some matchInfo ← getMatcherInfo? matchDeclName | throwError "`{matchDeclName}` is not a matcher function"
880+
let some matchInfo ← getMatcherInfo? matchDeclName (alsoCasesOn := true) | throwError "`{matchDeclName}` is not a matcher function"
871881
let numDiscrEqs := matchInfo.getNumDiscrEqs
872882
forallTelescopeReducing constInfo.type fun xs _matchResultType => do
873883
let mut eqnNames := #[]

src/Lean/Meta/Match/MatcherInfo.lean

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ def MatcherInfo.arity (info : MatcherInfo) : Nat :=
4646
def MatcherInfo.getFirstDiscrPos (info : MatcherInfo) : Nat :=
4747
info.numParams + 1
4848

49+
def MatcherInfo.getLastDiscrPos (info : MatcherInfo) : Nat :=
50+
info.numParams + 1 + info.numDiscrs - 1
51+
4952
def MatcherInfo.getDiscrRange (info : MatcherInfo) : Std.Rco Nat :=
5053
info.getFirstDiscrPos...(info.getFirstDiscrPos + info.numDiscrs)
5154

@@ -89,7 +92,7 @@ builtin_initialize extension : SimplePersistentEnvExtension Entry State ←
8992
asyncMode := .async .mainEnv
9093
exportEntriesFnEx? := some fun env _ entries _ =>
9194
-- Do not export info for private defs
92-
entries.filter (env.contains (skipRealize := false) ·.name) |>.toArray
95+
entries.filter (env.contains (skipRealize := true) ·.name) |>.toArray
9396
}
9497

9598
def addMatcherInfo (env : Environment) (matcherName : Name) (info : MatcherInfo) : Environment :=
@@ -128,36 +131,36 @@ def getMatcherInfoForCasesOn (env : Environment) (declName : Name) : Option Matc
128131
altNumParams
129132
}
130133

131-
def getMatcherInfoCore? (env : Environment) (declName : Name) (alsoCasesOn : Bool := false) : Option MatcherInfo :=
134+
def getMatcherInfoCore? (env : Environment) (declName : Name) (alsoCasesOn : Bool := true) : Option MatcherInfo :=
132135
if alsoCasesOn && isCasesOnRecursor env declName then
133136
getMatcherInfoForCasesOn env declName
134137
else
135138
Match.Extension.getMatcherInfo? env declName
136139

137-
def getMatcherInfo? [Monad m] [MonadEnv m] (declName : Name) (alsoCasesOn := false) : m (Option MatcherInfo) :=
140+
def getMatcherInfo? [Monad m] [MonadEnv m] (declName : Name) (alsoCasesOn := true) : m (Option MatcherInfo) :=
138141
return getMatcherInfoCore? (← getEnv) declName (alsoCasesOn := alsoCasesOn)
139142

140143
@[export lean_is_matcher]
141-
def isMatcherCore (env : Environment) (declName : Name) : Bool :=
142-
getMatcherInfoCore? env declName |>.isSome
144+
def isMatcherCore (env : Environment) (declName : Name) (alsoCasesOn := true) : Bool :=
145+
getMatcherInfoCore? env declName (alsoCasesOn := alsoCasesOn) |>.isSome
143146

144-
def isMatcher [Monad m] [MonadEnv m] (declName : Name) : m Bool :=
145-
return isMatcherCore (← getEnv) declName
147+
def isMatcher [Monad m] [MonadEnv m] (declName : Name) (alsoCasesOn := true) : m Bool :=
148+
return isMatcherCore (← getEnv) declName (alsoCasesOn := alsoCasesOn)
146149

147-
def isMatcherAppCore? (env : Environment) (e : Expr) : Option MatcherInfo :=
150+
def isMatcherAppCore? (env : Environment) (e : Expr) (alsoCasesOn := true) : Option MatcherInfo :=
148151
let fn := e.getAppFn
149152
if fn.isConst then
150-
if let some matcherInfo := getMatcherInfoCore? env fn.constName! then
153+
if let some matcherInfo := getMatcherInfoCore? env fn.constName! (alsoCasesOn := alsoCasesOn) then
151154
if e.getAppNumArgs ≥ matcherInfo.arity then some matcherInfo else none
152155
else
153156
none
154157
else
155158
none
156159

157-
def isMatcherAppCore (env : Environment) (e : Expr) : Bool :=
158-
isMatcherAppCore? env e |>.isSome
160+
def isMatcherAppCore (env : Environment) (e : Expr) (alsoCasesOn := true) : Bool :=
161+
isMatcherAppCore? env e (alsoCasesOn := alsoCasesOn) |>.isSome
159162

160-
def isMatcherApp [Monad m] [MonadEnv m] (e : Expr) : m Bool :=
161-
return isMatcherAppCore (← getEnv) e
163+
def isMatcherApp [Monad m] [MonadEnv m] (e : Expr) (alsoCasesOn := true) : m Bool :=
164+
return isMatcherAppCore (← getEnv) e (alsoCasesOn := alsoCasesOn)
162165

163166
end Lean.Meta

src/Lean/Meta/Tactic/Split.lean

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ private partial def generalizeMatchDiscrs (mvarId : MVarId) (matcherDeclName : N
150150
if discrs.all (·.isFVar) then
151151
trace[split.debug] "no need to generalize discriminants, all are fvars"
152152
return (discrs.map (·.fvarId!), #[], mvarId)
153-
let some matcherInfo ← getMatcherInfo? matcherDeclName | unreachable!
153+
let some matcherInfo ← getMatcherInfo? matcherDeclName (alsoCasesOn := true) | unreachable!
154154
let numDiscrEqs := matcherInfo.getNumDiscrEqs -- Number of `h : discr = pattern` equations
155155
let (targetNew, rfls) ←
156156
forallTelescope motiveType fun discrVars _ =>
@@ -249,7 +249,7 @@ private def substDiscrEqs (mvarId : MVarId) (fvarSubst : FVarSubst) (discrEqs :
249249
return mvarId
250250

251251
def applyMatchSplitter (mvarId : MVarId) (matcherDeclName : Name) (us : Array Level) (params : Array Expr) (discrs : Array Expr) : MetaM (List MVarId) := do
252-
let some info ← getMatcherInfo? matcherDeclName | throwInternalMisuseError m!"Internal error in `split` tactic: `{matcherDeclName}` is not an auxiliary declaration used to encode `match`-expressions"
252+
let some info ← getMatcherInfo? matcherDeclName (alsoCasesOn := true) | throwInternalMisuseError m!"Internal error in `split` tactic: `{matcherDeclName}` is not an auxiliary declaration used to encode `match`-expressions"
253253
let matchEqns ← Match.getEquationsFor matcherDeclName
254254
-- splitterPre does not have the correct universe elimination level, but this is fine, we only use it to compute the `motiveType`,
255255
-- and we only care about the `motiveType` arguments, and not the resulting `Sort u`.
@@ -302,7 +302,7 @@ def throwDiscrGenError (e : Expr) : MetaM α :=
302302
throwError (mkDiscrGenErrorMsg e)
303303

304304
def splitMatch (mvarId : MVarId) (e : Expr) : MetaM (List MVarId) := mvarId.withContext do
305-
let some app ← matchMatcherApp? e | throwInternalMisuseError m!"Internal error in `split` tactic: Match application expected{indentExpr e}"
305+
let some app ← matchMatcherApp? e (alsoCasesOn := true) | throwInternalMisuseError m!"Internal error in `split` tactic: Match application expected{indentExpr e}"
306306
let matchEqns ← Match.getEquationsFor app.matcherName
307307
let mvarIds ← applyMatchSplitter mvarId app.matcherName app.matcherLevels app.params app.discrs
308308
let (_, mvarIds) ← mvarIds.foldlM (init := (0, [])) fun (i, mvarIds) mvarId => do
@@ -315,7 +315,7 @@ end Split
315315
open Split
316316

317317
/--
318-
Splits an `if-then-else` of `match`-expression in the goal target.
318+
Splits an `if-then-else` or `match`-expression in the goal target.
319319
If `useNewSemantics` is `true`, the flag `backward.split` is ignored. Recall this flag only affects the split of `if-then-else` expressions.
320320
-/
321321
partial def splitTarget? (mvarId : MVarId) (splitIte := true) (useNewSemantics := false) : MetaM (Option (List MVarId)) := commitWhenSome? do mvarId.withContext do

src/Lean/Meta/Tactic/SplitIf.lean

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ private def isCandidate? (env : Environment) (ctx : Context) (e : Expr) : Option
4848
if numArgs >= 5 && !(e.getArg! 1 5).hasLooseBVars then
4949
return ret (e.getBoundedAppFn (numArgs - 5))
5050
if ctx.kind.considerMatch then
51-
if let some info := isMatcherAppCore? env e then
51+
if let some info := isMatcherAppCore? env e (alsoCasesOn := true) then
5252
let args := e.getAppArgs
5353
for i in info.getFirstDiscrPos...(info.getFirstDiscrPos + info.numDiscrs) do
5454
if args[i]!.hasLooseBVars then

src/Lean/Meta/WHNF.lean

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,7 @@ private def whnfMatcher (e : Expr) : MetaM Expr := do
528528
def reduceMatcher? (e : Expr) : MetaM ReduceMatcherResult := do
529529
let .const declName declLevels := e.getAppFn
530530
| return .notMatcher
531-
let some info ← getMatcherInfo? declName
531+
let some info ← getMatcherInfo? declName (alsoCasesOn := true)
532532
| return .notMatcher
533533
let args := e.getAppArgs
534534
let prefixSz := info.numParams + 1 + info.numDiscrs
@@ -870,7 +870,7 @@ mutual
870870
recordUnfold fInfo.name
871871
return some r
872872
| _ =>
873-
if (← getMatcherInfo? fInfo.name).isSome then
873+
if (← getMatcherInfo? fInfo.name (alsoCasesOn := false)).isSome then
874874
-- Recall that `whnfCore` tries to reduce "matcher" applications.
875875
return none
876876
else

tests/lean/run/diagRec.lean

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@ termination_by structural n
88
info: 573147844013817084101
99
---
1010
trace: [diag] Diagnostics
11-
[reduction] unfolded declarations (max: 596, num: 2):
12-
[reduction] Nat.rec ↦ 596
11+
[reduction] unfolded declarations (max: 995, num: 2):
12+
[reduction] Nat.rec ↦ 995
1313
[reduction] HAdd.hAdd ↦ 196
14-
[reduction] unfolded reducible declarations (max: 397, num: 1):
15-
[reduction] Nat.casesOn397
14+
[reduction] unfolded reducible declarations (max: 399, num: 1):
15+
[reduction] Nat.below399
1616
use `set_option diagnostics.threshold <num>` to control threshold for reporting counters
1717
-/
1818
#guard_msgs in

tests/lean/run/diagnostics.lean

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@ trace: [diag] Diagnostics
1919
[reduction] OfNat.ofNat ↦ 5
2020
[reduction] unfolded instances (max: 5, num: 1):
2121
[reduction] instOfNatNat ↦ 5
22-
[reduction] unfolded reducible declarations (max: 15, num: 1):
23-
[reduction] Nat.casesOn ↦ 15
2422
use `set_option diagnostics.threshold <num>` to control threshold for reporting counters
2523
-/
2624
#guard_msgs in
@@ -40,8 +38,6 @@ trace: [diag] Diagnostics
4038
[reduction] OfNat.ofNat ↦ 5
4139
[reduction] unfolded instances (max: 5, num: 1):
4240
[reduction] instOfNatNat ↦ 5
43-
[reduction] unfolded reducible declarations (max: 15, num: 1):
44-
[reduction] Nat.casesOn ↦ 15
4541
use `set_option diagnostics.threshold <num>` to control threshold for reporting counters
4642
-/
4743
#guard_msgs in

tests/lean/run/grind_indexmap_trace.lean

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,7 @@ example (m : IndexMap α β) (a : α) (h : a ∈ m) :
147147
info: Try this:
148148
[apply]
149149
instantiate only [= mem_indices_of_mem, insert]
150-
instantiate only [= getElem?_neg, = getElem?_pos, =_ HashMap.contains_iff_mem]
151-
instantiate only [=_ HashMap.contains_iff_mem]
150+
instantiate only [=_ HashMap.contains_iff_mem, = getElem?_neg, = getElem?_pos]
152151
cases #4ed2
153152
next =>
154153
cases #ffdf
@@ -181,7 +180,8 @@ example (m : IndexMap α β) (a a' : α) (b : β) :
181180
info: Try this:
182181
[apply]
183182
instantiate only [= mem_indices_of_mem, insert]
184-
instantiate only [=_ HashMap.contains_iff_mem, = getElem?_neg, = getElem?_pos]
183+
instantiate only [= getElem?_neg, = getElem?_pos, =_ HashMap.contains_iff_mem]
184+
instantiate only [=_ HashMap.contains_iff_mem]
185185
cases #4ed2
186186
next =>
187187
cases #ffdf

tests/lean/run/issue10876.lean

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,25 @@ fun x =>
2929
-/
3030
#guard_msgs in
3131
#print bar
32+
33+
-- Does `split` work?
34+
35+
-- set_option trace.split.debug true
36+
37+
/--
38+
trace: case h_1
39+
t✝ : Bool
40+
⊢ false = false
41+
42+
case h_2
43+
t✝ : Bool
44+
⊢ true = true
45+
-/
46+
#guard_msgs in
47+
theorem splitTest (b : Bool) : b = b.casesOn false true := by
48+
split
49+
trace_state
50+
· rfl
51+
· rfl
52+
53+
-- #print splitTest

0 commit comments

Comments
 (0)