@@ -10,6 +10,7 @@ public import Lean.Meta.Match.Match
1010public import Lean.Meta.Match.MatchEqsExt
1111public import Lean.Meta.Tactic.Refl
1212public import Lean.Meta.Tactic.Delta
13+ import Lean.Meta.Tactic.CasesOnStuckLHS
1314import Lean.Meta.Tactic.SplitIf
1415import Lean.Meta.Match.SimpH
1516import Lean.Meta.Match.SolveOverlap
@@ -18,53 +19,8 @@ public section
1819
1920namespace Lean.Meta
2021
21- /--
22- Helper method for `proveCondEqThm`. Given a goal of the form `C.rec ... xMajor = rhs`,
23- apply `cases xMajor`. -/
24- partial def casesOnStuckLHS (mvarId : MVarId) : MetaM (Array MVarId) := do
25- let target ← mvarId.getType
26- if let some (_, lhs, _) ← matchEq? target then
27- if let some fvarId ← findFVar? lhs then
28- return (← mvarId.cases fvarId).map fun s => s.mvarId
29- throwError "'casesOnStuckLHS' failed"
30- where
31- findFVar? (e : Expr) : MetaM (Option FVarId) := do
32- match e.getAppFn with
33- | Expr.proj _ _ e => findFVar? e
34- | f =>
35- if !f.isConst then
36- return none
37- else
38- let declName := f.constName!
39- let args := e.getAppArgs
40- match (← getProjectionFnInfo? declName) with
41- | some projInfo =>
42- if projInfo.numParams < args.size then
43- findFVar? args[projInfo.numParams]!
44- else
45- return none
46- | none =>
47- matchConstRec f (fun _ => return none) fun recVal _ => do
48- if recVal.getMajorIdx >= args.size then
49- return none
50- let major := args[recVal.getMajorIdx]!.consumeMData
51- if major.isFVar then
52- return some major.fvarId!
53- else
54- return none
55-
56- def casesOnStuckLHS? (mvarId : MVarId) : MetaM (Option (Array MVarId)) := do
57- try casesOnStuckLHS mvarId catch _ => return none
58-
5922namespace Match
6023
61- def unfoldNamedPattern (e : Expr) : MetaM Expr := do
62- let visit (e : Expr) : MetaM TransformStep := do
63- if let some e := isNamedPattern? e then
64- if let some eNew ← unfoldDefinition? e then
65- return TransformStep.visit eNew
66- return .continue
67- Meta.transform e (pre := visit)
6824
6925/--
7026 Similar to `forallTelescopeReducing`, but
13288
13389
13490/--
135- Extension of `forallAltTelescope ` that continues further:
91+ Extension of `forallAltVarsTelescope ` that continues further:
13692
13793 Equality parameters associated with the `h : discr` notation are replaced with `rfl` proofs.
13894 Recall that this kind of parameter always occurs after the parameters corresponding to pattern
@@ -262,18 +218,6 @@ where
262218 (throwError "failed to generate equality theorems for `match` expression `{matchDeclName}`\n {MessageData.ofGoal mvarId}" )
263219 subgoals.forM (go · (depth+1 ))
264220
265-
266- /-- Construct new local declarations `xs` with types `altTypes`, and then execute `f xs` -/
267- private partial def withSplitterAlts (altTypes : Array Expr) (f : Array Expr → MetaM α) : MetaM α := do
268- let rec go (i : Nat) (xs : Array Expr) : MetaM α := do
269- if h : i < altTypes.size then
270- let hName := (`h).appendIndexAfter (i+1 )
271- withLocalDeclD hName altTypes[i] fun x =>
272- go (i+1 ) (xs.push x)
273- else
274- f xs
275- go 0 #[]
276-
277221/--
278222 Create new alternatives (aka minor premises) by replacing `discrs` with `patterns` at `alts`.
279223 Recall that `alts` depends on `discrs` when `numDiscrEqs > 0`, where `numDiscrEqs` is the number of discriminants
@@ -322,7 +266,7 @@ where go baseName splitterName := withConfig (fun c => { c with etaStruct := .no
322266 let us := constInfo.levelParams.map mkLevelParam
323267 let some matchInfo ← getMatcherInfo? matchDeclName | throwError "`{matchDeclName}` is not a matcher function"
324268 let numDiscrEqs := getNumEqsFromDiscrInfos matchInfo.discrInfos
325- forallTelescopeReducing constInfo.type fun xs matchResultType => do
269+ forallTelescopeReducing constInfo.type fun xs _matchResultType => do
326270 let mut eqnNames := #[]
327271 let params := xs[*...matchInfo.numParams]
328272 let motive := xs[matchInfo.getMotivePos]!
@@ -331,16 +275,15 @@ where go baseName splitterName := withConfig (fun c => { c with etaStruct := .no
331275 let discrs := xs[firstDiscrIdx...(firstDiscrIdx + matchInfo.numDiscrs)]
332276 let mut notAlts := #[]
333277 let mut idx := 1
334- let mut splitterAltTypes := #[]
335278 let mut splitterAltInfos := #[]
336279 let mut altArgMasks := #[] -- masks produced by `forallAltTelescope`
337280 for i in *...alts.size do
338281 let altInfo := matchInfo.altInfos[i]!
339282 let thmName := Name.str baseName eqnThmSuffixBase |>.appendIndexAfter idx
340283 eqnNames := eqnNames.push thmName
341- let (notAlt, splitterAltType, splitterAltInfo, argMask) ←
284+ let (notAlt, splitterAltInfo, argMask) ←
342285 forallAltTelescope (← inferType alts[i]!) altInfo numDiscrEqs
343- fun ys eqs rhsArgs argMask altResultType => do
286+ fun ys _eqs rhsArgs argMask altResultType => do
344287 let patterns := altResultType.getAppArgs
345288 let mut hs := #[]
346289 for overlappedBy in matchInfo.overlaps.overlapping i do
@@ -349,15 +292,7 @@ where go baseName splitterName := withConfig (fun c => { c with etaStruct := .no
349292 if let some h ← simpH? h patterns.size then
350293 hs := hs.push h
351294 trace[Meta.Match.matchEqs] "hs: {hs}"
352- let splitterAltType ← mkForallFVars eqs altResultType
353- let splitterAltType ← mkArrowN hs splitterAltType
354- let splitterAltType ← mkForallFVars ys splitterAltType
355- let hasUnitThunk := splitterAltType == altResultType
356- let splitterAltType ← if hasUnitThunk then
357- mkArrow (mkConst ``Unit) splitterAltType
358- else
359- pure splitterAltType
360- let splitterAltType ← unfoldNamedPattern splitterAltType
295+ let hasUnitThunk := ys.isEmpty && hs.isEmpty && numDiscrEqs = 0
361296 let splitterAltInfo := { numFields := ys.size, numOverlaps := hs.size, hasUnitThunk }
362297 -- Create a proposition for representing terms that do not match `patterns`
363298 let mut notAlt := mkConst ``False
@@ -381,31 +316,23 @@ where go baseName splitterName := withConfig (fun c => { c with etaStruct := .no
381316 type := thmType
382317 value := thmVal
383318 }
384- return (notAlt, splitterAltType, splitterAltInfo, argMask)
319+ return (notAlt, splitterAltInfo, argMask)
385320 notAlts := notAlts.push notAlt
386- splitterAltTypes := splitterAltTypes.push splitterAltType
387321 splitterAltInfos := splitterAltInfos.push splitterAltInfo
388322 altArgMasks := altArgMasks.push argMask
389- trace[Meta.Match.matchEqs] "splitterAltType: {splitterAltType}"
390323 idx := idx + 1
391324 let splitterMatchInfo : MatcherInfo := { matchInfo with altInfos := splitterAltInfos }
392325
393326 let needsSplitter := !matchInfo.overlaps.isEmpty || (constInfo.type.find? (isNamedPattern )).isSome
394327
395328 if needsSplitter then
396- -- Define splitter with conditional/refined alternatives
397- withSplitterAlts splitterAltTypes fun altsNew => do
398- let splitterParams := params.toArray ++ #[motive] ++ discrs.toArray ++ altsNew
399- let splitterType ← mkForallFVars splitterParams matchResultType
400- trace[Meta.Match.matchEqs] "splitterType: {splitterType}"
401-
402- withMkMatcherInput matchDeclName (unfoldNamed := true ) fun matcherInput => do
403- let matcherInput := { matcherInput with
404- matcherName := splitterName
405- isSplitter := some matchInfo.overlaps
406- }
407- let res ← Match.mkMatcher matcherInput
408- res.addMatcher -- TODO: Do not set matcherinfo for the splitter!
329+ withMkMatcherInput matchDeclName (unfoldNamed := true ) fun matcherInput => do
330+ let matcherInput := { matcherInput with
331+ matcherName := splitterName
332+ isSplitter := some matchInfo.overlaps
333+ }
334+ let res ← Match.mkMatcher matcherInput
335+ res.addMatcher -- TODO: Do not set matcherinfo for the splitter!
409336 else
410337 assert! matchInfo.altInfos == splitterAltInfos
411338 -- This match statement does not need a splitter, we can use itself for that.
0 commit comments