diff --git a/src/Lean/Elab/PreDefinition/Structural/Eqns.lean b/src/Lean/Elab/PreDefinition/Structural/Eqns.lean index 26400a74dda9..cb8ef3c12d43 100644 --- a/src/Lean/Elab/PreDefinition/Structural/Eqns.lean +++ b/src/Lean/Elab/PreDefinition/Structural/Eqns.lean @@ -8,10 +8,12 @@ module prelude public import Lean.Elab.PreDefinition.FixedParams import Lean.Elab.PreDefinition.EqnsUtils -import Lean.Meta.Tactic.Split +import Lean.Meta.Tactic.CasesOnStuckLHS +import Lean.Meta.Tactic.Delta import Lean.Meta.Tactic.Simp.Main import Lean.Meta.Tactic.Delta import Lean.Meta.Tactic.CasesOnStuckLHS +import Lean.Meta.Tactic.Split namespace Lean.Elab open Meta diff --git a/src/Lean/Meta/Constructions/CasesOnSameCtor.lean b/src/Lean/Meta/Constructions/CasesOnSameCtor.lean index 07ecbfd4de64..8e0d41e32237 100644 --- a/src/Lean/Meta/Constructions/CasesOnSameCtor.lean +++ b/src/Lean/Meta/Constructions/CasesOnSameCtor.lean @@ -213,7 +213,9 @@ public def mkCasesOnSameCtor (declName : Name) (indName : Name) : MetaM Unit := numDiscrs := info.numIndices + 3 altInfos uElimPos? := some 0 - discrInfos := #[{}, {}, {}]} + discrInfos := #[{}, {}, {}] + overlaps := {} + } -- Compare attributes with `mkMatcherAuxDefinition` withExporting (isExporting := !isPrivateName declName) do diff --git a/src/Lean/Meta/IndPredBelow.lean b/src/Lean/Meta/IndPredBelow.lean index 94bda8ef3754..30a9f9794d62 100644 --- a/src/Lean/Meta/IndPredBelow.lean +++ b/src/Lean/Meta/IndPredBelow.lean @@ -319,7 +319,7 @@ public partial def mkBelowMatcher (matcherApp : MatcherApp) (belowParams : Array (ctx : RecursionContext) (transformAlt : RecursionContext → Expr → MetaM Expr) : MetaM (Option (Expr × MetaM Unit)) := withTraceNode `Meta.IndPredBelow.match (return m!"{exceptEmoji ·} {matcherApp.toExpr} and {belowParams}") do - let mut input ← getMkMatcherInputInContext matcherApp + let mut input ← getMkMatcherInputInContext matcherApp (unfoldNamed := false) let mut discrs := matcherApp.discrs let mut matchTypeAdd := #[] -- #[(discrIdx, ), ...] let mut i := discrs.size diff --git a/src/Lean/Meta/Match/Basic.lean b/src/Lean/Meta/Match/Basic.lean index c1a45fb66ab8..ade3e580163a 100644 --- a/src/Lean/Meta/Match/Basic.lean +++ b/src/Lean/Meta/Match/Basic.lean @@ -150,6 +150,11 @@ structure Alt where After we perform additional case analysis, their types become definitionally equal. -/ cnstrs : List (Expr × Expr) + /-- + Indices of previous alternatives that this alternative expects a not-that-proofs. + (When producing a splitter, and in the future also for source-level overlap hypotheses.) + -/ + notAltIdxs : Array Nat deriving Inhabited namespace Alt diff --git a/src/Lean/Meta/Match/Match.lean b/src/Lean/Meta/Match/Match.lean index f7157f11d094..456a6b0b8f9e 100644 --- a/src/Lean/Meta/Match/Match.lean +++ b/src/Lean/Meta/Match/Match.lean @@ -12,7 +12,11 @@ public import Lean.Meta.GeneralizeTelescope public import Lean.Meta.Match.Basic public import Lean.Meta.Match.MatcherApp.Basic public import Lean.Meta.Match.MVarRenaming +public import Lean.Meta.Match.MVarRenaming +import Lean.Meta.Match.SimpH +import Lean.Meta.Match.SolveOverlap import Lean.Meta.HasNotBit +import Lean.Meta.Match.NamedPatterns public section @@ -92,34 +96,62 @@ where /-- Given a list of `AltLHS`, create a minor premise for each one, convert them into `Alt`, and then execute `k` -/ private def withAlts {α} (motive : Expr) (discrs : Array Expr) (discrInfos : Array DiscrInfo) - (lhss : List AltLHS) (k : List Alt → Array Expr → Array AltParamInfo → MetaM α) : MetaM α := - loop lhss [] #[] #[] + (lhss : List AltLHS) (isSplitter : Option Overlaps) + (k : List Alt → Array Expr → Array AltParamInfo → MetaM α) : MetaM α := + loop lhss [] #[] #[] #[] where - mkMinorType (xs : Array Expr) (lhs : AltLHS) : MetaM Expr := + mkSplitterHyps (idx : Nat) (lhs : AltLHS) (notAlts : Array Expr) : MetaM (Array Expr × Array Nat) := do + withExistingLocalDecls lhs.fvarDecls do + let patterns ← lhs.patterns.toArray.mapM (Pattern.toExpr · (annotate := true)) + let mut hs := #[] + let mut notAltIdxs := #[] + for overlappingIdx in isSplitter.get!.overlapping idx do + let notAlt := notAlts[overlappingIdx]! + let h ← instantiateForall notAlt patterns + if let some h ← simpH? h patterns.size then + notAltIdxs := notAltIdxs.push overlappingIdx + hs := hs.push h + trace[Meta.Match.debug] "hs for {lhs.ref}: {hs}" + return (hs, notAltIdxs) + + mkMinorType (xs : Array Expr) (lhs : AltLHS) (notAltHs : Array Expr): MetaM Expr := withExistingLocalDecls lhs.fvarDecls do let args ← lhs.patterns.toArray.mapM (Pattern.toExpr · (annotate := true)) let minorType := mkAppN motive args withEqs discrs args discrInfos fun eqs => do - mkForallFVars (xs ++ eqs) minorType + let minorType ← mkForallFVars eqs minorType + let minorType ← mkArrowN notAltHs minorType + mkForallFVars xs minorType + + mkNotAlt (xs : Array Expr) (lhs : AltLHS) : MetaM Expr := do + withExistingLocalDecls lhs.fvarDecls do + let mut notAlt := mkConst ``False + for discr in discrs.reverse, pattern in lhs.patterns.reverse do + notAlt ← mkArrow (← mkEqHEq discr (← pattern.toExpr)) notAlt + notAlt ← mkForallFVars (discrs ++ xs) notAlt + return notAlt - loop (lhss : List AltLHS) (alts : List Alt) (minors : Array Expr) (altInfos : Array AltParamInfo) : MetaM α := do + loop (lhss : List AltLHS) (alts : List Alt) (minors : Array Expr) (altInfos : Array AltParamInfo) (notAlts : Array Expr) : MetaM α := do match lhss with | [] => k alts.reverse minors altInfos | lhs::lhss => + let idx := alts.length let xs := lhs.fvarDecls.toArray.map LocalDecl.toExpr - let minorType ← mkMinorType xs lhs - let hasParams := !xs.isEmpty || discrInfos.any fun info => info.hName?.isSome + let (notAltHs, notAltIdxs) ← if isSplitter.isSome then mkSplitterHyps idx lhs notAlts else pure (#[], #[]) + let minorType ← mkMinorType xs lhs notAltHs + let notAlt ← mkNotAlt xs lhs + let hasParams := !xs.isEmpty || !notAltHs.isEmpty || discrInfos.any fun info => info.hName?.isSome let minorType := if hasParams then minorType else mkSimpleThunkType minorType - let idx := alts.length let minorName := (`h).appendIndexAfter (idx+1) trace[Meta.Match.debug] "minor premise {minorName} : {minorType}" withLocalDeclD minorName minorType fun minor => do let rhs := if hasParams then mkAppN minor xs else mkApp minor (mkConst `Unit.unit) let minors := minors.push minor - let altInfos := altInfos.push { numFields := xs.size, numOverlaps := 0, hasUnitThunk := !hasParams } + let altInfos := altInfos.push { numFields := xs.size, numOverlaps := notAltHs.size, hasUnitThunk := !hasParams } let fvarDecls ← lhs.fvarDecls.mapM instantiateLocalDeclMVars - let alts := { ref := lhs.ref, idx := idx, rhs := rhs, fvarDecls := fvarDecls, patterns := lhs.patterns, cnstrs := [] } :: alts - loop lhss alts minors altInfos + let alt := { ref := lhs.ref, idx := idx, rhs := rhs, fvarDecls := fvarDecls, patterns := lhs.patterns, cnstrs := [], notAltIdxs := notAltIdxs } + let alts := alt :: alts + loop lhss alts minors altInfos (notAlts.push notAlt) structure State where /-- Used alternatives -/ @@ -338,7 +370,7 @@ where return (p, (lhs, rhs) :: cnstrs) /-- -Solve pending alternative constraints. +Solve pending alternative constraints and overlap assumptions. If all constraints can be solved perform assignment `mvarId := alt.rhs`, else throw error. -/ private partial def solveCnstrs (mvarId : MVarId) (alt : Alt) : StateRefT State MetaM Unit := do @@ -350,13 +382,19 @@ where | none => let alt ← filterTrivialCnstrs alt if alt.cnstrs.isEmpty then - let eType ← inferType alt.rhs - let targetType ← mvarId.getType - unless (← isDefEqGuarded targetType eType) do - trace[Meta.Match.match] "assignGoalOf failed {eType} =?= {targetType}" - throwErrorAt alt.ref "Dependent elimination failed: Type mismatch when solving this alternative: it {← mkHasTypeButIsExpectedMsg eType targetType}" - mvarId.assign alt.rhs - modify fun s => { s with used := s.used.insert alt.idx } + mvarId.withContext do + let eType ← inferType alt.rhs + let (notAltsMVarIds, _, eType) ← forallMetaBoundedTelescope eType alt.notAltIdxs.size + unless notAltsMVarIds.size = alt.notAltIdxs.size do + throwErrorAt alt.ref "Incorrect number of overlap hypotheses in the right-hand-side, expected {alt.notAltIdxs.size}:{indentExpr eType}" + let targetType ← mvarId.getType + unless (← isDefEqGuarded targetType eType) do + trace[Meta.Match.match] "assignGoalOf failed {eType} =?= {targetType}" + throwErrorAt alt.ref "Dependent elimination failed: Type mismatch when solving this alternative: it {← mkHasTypeButIsExpectedMsg eType targetType}" + for notAltMVarId in notAltsMVarIds do + solveOverlap notAltMVarId.mvarId! + mvarId.assign (mkAppN alt.rhs notAltsMVarIds) + modify fun s => { s with used := s.used.insert alt.idx } else trace[Meta.Match.match] "alt has unsolved cnstrs:\n{← alt.toMessageData}" let mut msg := m!"Dependent match elimination failed: Could not solve constraints" @@ -636,7 +674,7 @@ private def processConstructor (p : Problem) : MetaM (Array Problem) := do | .var _ :: _ => expandVarIntoCtor alt ctorName | .inaccessible _ :: _ => processInaccessibleAsCtor alt ctorName | _ => unreachable! - return { mvarId := subgoal.mvarId, vars := newVars, alts := newAlts, examples := examples } + return { p with mvarId := subgoal.mvarId, vars := newVars, alts := newAlts, examples := examples } else -- A catch-all case let subst := subgoal.subst @@ -647,7 +685,7 @@ private def processConstructor (p : Problem) : MetaM (Array Problem) := do | .ctor .. :: _ => false | _ => true let newAlts := newAlts.map fun alt => alt.applyFVarSubst subst - return { mvarId := subgoal.mvarId, alts := newAlts, vars := newVars, examples := examples } + return { p with mvarId := subgoal.mvarId, alts := newAlts, vars := newVars, examples := examples } private def processNonVariable (p : Problem) : MetaM Problem := withGoalOf p do let x :: xs := p.vars | unreachable! @@ -708,7 +746,7 @@ private def processValue (p : Problem) : MetaM (Array Problem) := do alt.replaceFVarId fvarId value | _ => unreachable! let newVars := xs.map fun x => x.applyFVarSubst subst - return { mvarId := subgoal.mvarId, vars := newVars, alts := newAlts, examples := examples } + return { p with mvarId := subgoal.mvarId, vars := newVars, alts := newAlts, examples := examples } else -- else branch for value let newAlts := p.alts.filter isFirstPatternVar @@ -764,7 +802,7 @@ private def processArrayLit (p : Problem) : MetaM (Array Problem) := do let α ← getArrayArgType <| subst.apply x expandVarIntoArrayLit { alt with patterns := ps } fvarId α size | _ => unreachable! - return { mvarId := subgoal.mvarId, vars := newVars, alts := newAlts, examples := examples } + return { p with mvarId := subgoal.mvarId, vars := newVars, alts := newAlts, examples := examples } else -- else branch let newAlts := p.alts.filter isFirstPatternVar @@ -1018,7 +1056,7 @@ private builtin_initialize matcherExt : EnvExtension (PHashMap MatcherKey Name) /-- Similar to `mkAuxDefinition`, but uses the cache `matcherExt`. It also returns an Boolean that indicates whether a new matcher function was added to the environment or not. -/ -def mkMatcherAuxDefinition (name : Name) (type : Expr) (value : Expr) : MetaM (Expr × Option (MatcherInfo → MetaM Unit)) := do +def mkMatcherAuxDefinition (name : Name) (type : Expr) (value : Expr) (isSplitter : Bool) : MetaM (Expr × Option (MatcherInfo → MetaM Unit)) := do trace[Meta.Match.debug] "{name} : {type} := {value}" let compile := bootstrap.genMatcherCode.get (← getOptions) let result ← Closure.mkValueTypeClosure type value (zetaDelta := false) @@ -1026,10 +1064,12 @@ def mkMatcherAuxDefinition (name : Name) (type : Expr) (value : Expr) : MetaM (E let mkMatcherConst name := mkAppN (mkConst name result.levelArgs.toList) result.exprArgs let key := { value := result.value, compile, isPrivate := env.header.isModule && isPrivateName name } - let mut nameNew? := (matcherExt.getState env).find? key - if nameNew?.isNone && key.isPrivate then - -- private contexts may reuse public matchers - nameNew? := (matcherExt.getState env).find? { key with isPrivate := false } + let mut nameNew? := none + unless isSplitter do + nameNew? := (matcherExt.getState env).find? key + if nameNew?.isNone && key.isPrivate then + -- private contexts may reuse public matchers + nameNew? := (matcherExt.getState env).find? { key with isPrivate := false } match nameNew? with | some nameNew => return (mkMatcherConst nameNew, none) | none => @@ -1040,8 +1080,9 @@ def mkMatcherAuxDefinition (name : Name) (type : Expr) (value : Expr) : MetaM (E -- matcher bodies should always be exported, if not private anyway withExporting do addDecl decl - modifyEnv fun env => matcherExt.modifyState env fun s => s.insert key name - addMatcherInfo name mi + unless isSplitter do + modifyEnv fun env => matcherExt.modifyState env fun s => s.insert key name + addMatcherInfo name mi setInlineAttribute name enableRealizationsForConst name if compile then @@ -1053,6 +1094,7 @@ structure MkMatcherInput where matchType : Expr discrInfos : Array DiscrInfo lhss : List AltLHS + isSplitter : Option Overlaps := none def MkMatcherInput.numDiscrs (m : MkMatcherInput) := m.discrInfos.size @@ -1093,7 +1135,7 @@ The generated matcher has the structure described at `MatcherInfo`. The motive a where `v` is a universe parameter or 0 if `B[a_1, ..., a_n]` is a proposition. -/ def mkMatcher (input : MkMatcherInput) : MetaM MatcherResult := withCleanLCtxFor input do - let ⟨matcherName, matchType, discrInfos, lhss⟩ := input + let {matcherName, matchType, discrInfos, lhss, isSplitter} := input let numDiscrs := discrInfos.size checkNumPatterns numDiscrs lhss forallBoundedTelescope matchType numDiscrs fun discrs matchTypeBody => do @@ -1116,7 +1158,7 @@ def mkMatcher (input : MkMatcherInput) : MetaM MatcherResult := withCleanLCtxFor | negSucc n => succ n ``` which is defined **before** `Int.decLt` -/ - let (matcher, addMatcher) ← mkMatcherAuxDefinition matcherName type val + let (matcher, addMatcher) ← mkMatcherAuxDefinition matcherName type val (isSplitter := input.isSplitter.isSome) trace[Meta.Match.debug] "matcher levels: {matcher.getAppFn.constLevels!}, uElim: {uElimGen}" let uElimPos? ← getUElimPos? matcher.getAppFn.constLevels! uElimGen discard <| isLevelDefEq uElimGen uElim @@ -1152,7 +1194,7 @@ def mkMatcher (input : MkMatcherInput) : MetaM MatcherResult := withCleanLCtxFor let isEqMask ← eqs.mapM fun eq => return (← inferType eq).isEq return (mvarType, isEqMask) trace[Meta.Match.debug] "target: {mvarType}" - withAlts motive discrs discrInfos lhss fun alts minors altInfos => do + withAlts motive discrs discrInfos lhss isSplitter fun alts minors altInfos => do let mvar ← mkFreshExprMVar mvarType trace[Meta.Match.debug] "goal\n{mvar.mvarId!}" let examples := discrs'.toList.map fun discr => Example.var discr.fvarId! @@ -1176,7 +1218,7 @@ def mkMatcher (input : MkMatcherInput) : MetaM MatcherResult := withCleanLCtxFor else let mvarType := mkAppN motive discrs trace[Meta.Match.debug] "target: {mvarType}" - withAlts motive discrs discrInfos lhss fun alts minors altInfos => do + withAlts motive discrs discrInfos lhss isSplitter fun alts minors altInfos => do let mvar ← mkFreshExprMVar mvarType let examples := discrs.toList.map fun discr => Example.var discr.fvarId! let (_, s) ← (process { mvarId := mvar.mvarId!, vars := discrs.toList, alts := alts, examples := examples }).run {} @@ -1185,7 +1227,7 @@ def mkMatcher (input : MkMatcherInput) : MetaM MatcherResult := withCleanLCtxFor let val ← mkLambdaFVars args mvar mkMatcher type val altInfos s -def getMkMatcherInputInContext (matcherApp : MatcherApp) : MetaM MkMatcherInput := do +def getMkMatcherInputInContext (matcherApp : MatcherApp) (unfoldNamed : Bool) : MetaM MkMatcherInput := do let matcherName := matcherApp.matcherName let some matcherInfo ← getMatcherInfo? matcherName | throwError "Internal error during match expression elaboration: Could not find a matcher named `{matcherName}`" @@ -1204,6 +1246,7 @@ def getMkMatcherInputInContext (matcherApp : MatcherApp) : MetaM MkMatcherInput let lhss ← forallBoundedTelescope matcherType (some matcherApp.alts.size) fun alts _ => alts.mapM fun alt => do let ty ← inferType alt + let ty ← if unfoldNamed then unfoldNamedPattern ty else pure ty forallTelescope ty fun xs body => do let xs ← xs.filterM fun x => dependsOn body x.fvarId! body.withApp fun _ args => do @@ -1217,18 +1260,17 @@ def getMkMatcherInputInContext (matcherApp : MatcherApp) : MetaM MkMatcherInput return { matcherName, matchType, discrInfos := matcherInfo.discrInfos, lhss := lhss.toList } -/-- This function is only used for testing purposes -/ -def withMkMatcherInput (matcherName : Name) (k : MkMatcherInput → MetaM α) : MetaM α := do +def withMkMatcherInput (matcherName : Name) (unfoldNamed : Bool) (k : MkMatcherInput → MetaM α) : MetaM α := do let some matcherInfo ← getMatcherInfo? matcherName - | throwError "Internal error during match expression elaboration: Could not find a matcher named `{matcherName}`" + | throwError "withMkMatcherInput: {.ofConstName matcherName} is not a matcher" let matcherConst ← getConstInfo matcherName - forallBoundedTelescope matcherConst.type (some matcherInfo.arity) fun xs _ => do - let matcherApp ← mkConstWithLevelParams matcherConst.name - let matcherApp := mkAppN matcherApp xs - let some matcherApp ← matchMatcherApp? matcherApp - | throwError "Internal error during match expression elaboration: Could not find a matcher app named `{matcherApp}`" - let mkMatcherInput ← getMkMatcherInputInContext matcherApp - k mkMatcherInput + forallBoundedTelescope matcherConst.type matcherInfo.arity fun xs _ => do + let matcherApp ← mkConstWithLevelParams matcherConst.name + let matcherApp := mkAppN matcherApp xs + let some matcherApp ← matchMatcherApp? matcherApp + | throwError "withMkMatcherInput: {.ofConstName matcherName} does not produce a matcher application" + let mkMatcherInput ← getMkMatcherInputInContext matcherApp unfoldNamed + k mkMatcherInput end Match diff --git a/src/Lean/Meta/Match/MatchEqs.lean b/src/Lean/Meta/Match/MatchEqs.lean index 85607998ef23..bd1666f338a8 100644 --- a/src/Lean/Meta/Match/MatchEqs.lean +++ b/src/Lean/Meta/Match/MatchEqs.lean @@ -110,220 +110,6 @@ where (throwError "failed to generate equality theorems for `match` expression `{matchDeclName}`\n{MessageData.ofGoal mvarId}") subgoals.forM (go · (depth+1)) - -/-- Construct new local declarations `xs` with types `altTypes`, and then execute `f xs` -/ -private partial def withSplitterAlts (altTypes : Array Expr) (f : Array Expr → MetaM α) : MetaM α := do - let rec go (i : Nat) (xs : Array Expr) : MetaM α := do - if h : i < altTypes.size then - let hName := (`h).appendIndexAfter (i+1) - withLocalDeclD hName altTypes[i] fun x => - go (i+1) (xs.push x) - else - f xs - go 0 #[] - -private abbrev ConvertM := ReaderT (FVarIdMap (Expr × AltParamInfo × Array Bool)) $ StateRefT (Array MVarId) MetaM - -/-- - Construct a proof for the splitter generated by `mkEquationsFor`. - The proof uses the definition of the `match`-declaration as a template (argument `template`). - - `alts` are free variables corresponding to alternatives of the `match` auxiliary declaration being processed. - - `altNews` are the new free variables which contains additional hypotheses that ensure they are only used - when the previous overlapping alternatives are not applicable. - - `altInfos` refers to the splitter -/ -private partial def mkSplitterProof (matchDeclName : Name) (template : Expr) (alts altsNew : Array Expr) - (altInfos : Array AltParamInfo) (altArgMasks : Array (Array Bool)) : MetaM Expr := do - trace[Meta.Match.matchEqs] "proof template: {template}" - let map := mkMap - let (proof, mvarIds) ← convertTemplate template |>.run map |>.run #[] - trace[Meta.Match.matchEqs] "splitter proof: {proof}" - for mvarId in mvarIds do - let mvarId ← mvarId.tryClearMany (alts.map (·.fvarId!)) - solveOverlap mvarId - instantiateMVars proof -where - mkMap : FVarIdMap (Expr × AltParamInfo × Array Bool) := Id.run do - let mut m := {} - for alt in alts, altNew in altsNew, altInfo in altInfos, argMask in altArgMasks do - m := m.insert alt.fvarId! (altNew, altInfo, argMask) - return m - - trimFalseTrail (argMask : Array Bool) : Array Bool := - if argMask.isEmpty then - argMask - else if !argMask.back! then - trimFalseTrail argMask.pop - else - argMask - - /-- - Auxiliary function used at `convertTemplate` to decide whether to use `convertCastEqRec`. - See `convertCastEqRec`. -/ - isCastEqRec (e : Expr) : ConvertM Bool := do - -- TODO: we do not handle `Eq.rec` since we never found an example that needed it. - -- If we find one we must extend `convertCastEqRec`. - unless e.isAppOf ``Eq.ndrec do return false - unless e.getAppNumArgs > 6 do return false - for arg in e.getAppArgs[6...*] do - if arg.isFVar && (← read).contains arg.fvarId! then - return true - return true - - /-- - Auxiliary function used at `convertTemplate`. It is needed when the auxiliary `match` declaration had to refine the type of its - minor premises during dependent pattern match. For an example, consider - ``` - inductive Foo : Nat → Type _ - | nil : Foo 0 - | cons (t: Foo l): Foo l - - def Foo.bar (t₁: Foo l₁): Foo l₂ → Bool - | cons s₁ => t₁.bar s₁ - | _ => false - attribute [simp] Foo.bar - ``` - The auxiliary `Foo.bar.match_1` is of the form - ``` - def Foo.bar.match_1.{u_1} : {l₂ : Nat} → - (t₂ : Foo l₂) → - (motive : Foo l₂ → Sort u_1) → - (t₂ : Foo l₂) → ((s₁ : Foo l₂) → motive (Foo.cons s₁)) → ((x : Foo l₂) → motive x) → motive t₂ := - fun {l₂} t₂ motive t₂_1 h_1 h_2 => - (fun t₂_2 => - Foo.casesOn (motive := fun a x => l₂ = a → t₂_1 ≍ x → motive t₂_1) t₂_2 - (fun h => - Eq.ndrec (motive := fun {l₂} => - (t₂ t₂ : Foo l₂) → - (motive : Foo l₂ → Sort u_1) → - ((s₁ : Foo l₂) → motive (Foo.cons s₁)) → ((x : Foo l₂) → motive x) → t₂ ≍ Foo.nil → motive t₂) - (fun t₂ t₂ motive h_1 h_2 h => Eq.symm (eq_of_heq h) ▸ h_2 Foo.nil) (Eq.symm h) t₂ t₂_1 motive h_1 h_2) --- HERE - fun {l} t h => - Eq.ndrec (motive := fun {l} => (t : Foo l) → t₂_1 ≍ Foo.cons t → motive t₂_1) - (fun t h => Eq.symm (eq_of_heq h) ▸ h_1 t) h t) - t₂_1 (Eq.refl l₂) (HEq.refl t₂_1) - ``` - The `HERE` comment marks the place where the type of `Foo.bar.match_1` minor premises `h_1` and `h_2` is being "refined" - using `Eq.ndrec`. - - This function will adjust the motive and minor premise of the `Eq.ndrec` to reflect the new minor premises used in the - corresponding splitter theorem. - - We may have to extend this function to handle `Eq.rec` too. - - This function was added to address issue #1179 - -/ - convertCastEqRec (e : Expr) : ConvertM Expr := do - assert! (← isCastEqRec e) - e.withApp fun f args => do - let mut argsNew := args - let mut isAlt := #[] - for i in 6...args.size do - let arg := argsNew[i]! - if arg.isFVar then - match (← read).get? arg.fvarId! with - | some (altNew, _, _) => - argsNew := argsNew.set! i altNew - trace[Meta.Match.matchEqs] "arg: {arg} : {← inferType arg}, altNew: {altNew} : {← inferType altNew}" - isAlt := isAlt.push true - | none => - argsNew := argsNew.set! i (← convertTemplate arg) - isAlt := isAlt.push false - else - argsNew := argsNew.set! i (← convertTemplate arg) - isAlt := isAlt.push false - assert! isAlt.size == args.size - 6 - let rhs := args[4]! - let motive := args[2]! - -- Construct new motive using the splitter theorem minor premise types. - let motiveNew ← lambdaTelescope motive fun motiveArgs body => do - unless motiveArgs.size == 1 do - throwError "unexpected `Eq.ndrec` motive while creating splitter/eliminator theorem for `{matchDeclName}`, expected lambda with 1 binder{indentExpr motive}" - let x := motiveArgs[0]! - forallTelescopeReducing body fun motiveTypeArgs resultType => do - unless motiveTypeArgs.size >= isAlt.size do - throwError "unexpected `Eq.ndrec` motive while creating splitter/eliminator theorem for `{matchDeclName}`, expected arrow with at least #{isAlt.size} binders{indentExpr body}" - let rec go (i : Nat) (motiveTypeArgsNew : Array Expr) : ConvertM Expr := do - assert! motiveTypeArgsNew.size == i - if h : i < motiveTypeArgs.size then - let motiveTypeArg := motiveTypeArgs[i] - if i < isAlt.size && isAlt[i]! then - let altNew := argsNew[6+i]! -- Recall that `Eq.ndrec` has 6 arguments - let altTypeNew ← inferType altNew - trace[Meta.Match.matchEqs] "altNew: {altNew} : {altTypeNew}" - -- Replace `rhs` with `x` (the lambda binder in the motive) - let mut altTypeNewAbst := (← kabstract altTypeNew rhs).instantiate1 x - -- Replace args[6...(6+i)] with `motiveTypeArgsNew` - for j in *...i do - altTypeNewAbst := (← kabstract altTypeNewAbst argsNew[6+j]!).instantiate1 motiveTypeArgsNew[j]! - let localDecl ← motiveTypeArg.fvarId!.getDecl - withLocalDecl localDecl.userName localDecl.binderInfo altTypeNewAbst fun motiveTypeArgNew => - go (i+1) (motiveTypeArgsNew.push motiveTypeArgNew) - else - go (i+1) (motiveTypeArgsNew.push motiveTypeArg) - else - mkLambdaFVars motiveArgs (← mkForallFVars motiveTypeArgsNew resultType) - go 0 #[] - trace[Meta.Match.matchEqs] "new motive: {motiveNew}" - unless (← isTypeCorrect motiveNew) do - throwError "failed to construct new type correct motive for `Eq.ndrec` while creating splitter/eliminator theorem for `{matchDeclName}`{indentExpr motiveNew}" - argsNew := argsNew.set! 2 motiveNew - -- Construct the new minor premise for the `Eq.ndrec` application. - -- First, we use `eqRecNewPrefix` to infer the new minor premise binders for `Eq.ndrec` - let eqRecNewPrefix := mkAppN f argsNew[*...3] -- `Eq.ndrec` minor premise is the fourth argument. - let .forallE _ minorTypeNew .. ← whnf (← inferType eqRecNewPrefix) | unreachable! - trace[Meta.Match.matchEqs] "new minor type: {minorTypeNew}" - let minor := args[3]! - let minorNew ← forallBoundedTelescope minorTypeNew isAlt.size fun minorArgsNew _ => do - let mut minorBodyNew := minor - -- We have to extend the mapping to make sure `convertTemplate` can "fix" occurrences of the refined minor premises - let mut m ← read - for h : i in *...isAlt.size do - if isAlt[i] then - -- `convertTemplate` will correct occurrences of the alternative - let alt := args[6+i]! -- Recall that `Eq.ndrec` has 6 arguments - let some (_, numParams, argMask) := m.get? alt.fvarId! | unreachable! - -- We add a new entry to `m` to make sure `convertTemplate` will correct the occurrences of the alternative - m := m.insert minorArgsNew[i]!.fvarId! (minorArgsNew[i]!, numParams, argMask) - unless minorBodyNew.isLambda do - throwError "unexpected `Eq.ndrec` minor premise while creating splitter/eliminator theorem for `{matchDeclName}`, expected lambda with at least #{isAlt.size} binders{indentExpr minor}" - minorBodyNew := minorBodyNew.bindingBody! - minorBodyNew := minorBodyNew.instantiateRev minorArgsNew - trace[Meta.Match.matchEqs] "minor premise new body before convertTemplate:{indentExpr minorBodyNew}" - minorBodyNew ← withReader (fun _ => m) <| convertTemplate minorBodyNew - trace[Meta.Match.matchEqs] "minor premise new body after convertTemplate:{indentExpr minorBodyNew}" - mkLambdaFVars minorArgsNew minorBodyNew - unless (← isTypeCorrect minorNew) do - throwError "failed to construct new type correct minor premise for `Eq.ndrec` while creating splitter/eliminator theorem for `{matchDeclName}`{indentExpr minorNew}" - argsNew := argsNew.set! 3 minorNew - -- trace[Meta.Match.matchEqs] "argsNew: {argsNew}" - trace[Meta.Match.matchEqs] "found cast target {e}" - return mkAppN f argsNew - - convertTemplate (e : Expr) : ConvertM Expr := - transform e fun e => do - if (← isCastEqRec e) then - return .done (← convertCastEqRec e) - else - let Expr.fvar fvarId .. := e.getAppFn | return .continue - let some (altNew, altParamInfo, argMask) := (← read).get? fvarId | return .continue - trace[Meta.Match.matchEqs] ">> argMask: {argMask}, altParamInfo: {repr altParamInfo}, e: {e}, alsNew: {altNew}, " - if altParamInfo.hasUnitThunk then - let eNew := mkApp altNew (mkConst ``Unit.unit) - return TransformStep.done eNew - let mut newArgs := #[] - let argMask := trimFalseTrail argMask - unless e.getAppNumArgs ≥ argMask.size do - throwError "unexpected occurrence of `match`-expression alternative (aka minor premise) while creating splitter/eliminator theorem for `{matchDeclName}`, minor premise is partially applied{indentExpr e}\npossible solution if you are matching on inductive families: add its indices as additional discriminants" - for arg in e.getAppArgs, includeArg in argMask do - if includeArg then - newArgs := newArgs.push arg - let eNew := mkAppN altNew newArgs - let (mvars, _, _) ← forallMetaBoundedTelescope (← inferType eNew) altParamInfo.numOverlaps (kind := MetavarKind.syntheticOpaque) - modify fun s => s ++ (mvars.map (·.mvarId!)) - let eNew := mkAppN eNew mvars - return TransformStep.done eNew - - /-- Create new alternatives (aka minor premises) by replacing `discrs` with `patterns` at `alts`. Recall that `alts` depends on `discrs` when `numDiscrEqs > 0`, where `numDiscrEqs` is the number of discriminants @@ -364,13 +150,15 @@ def getEquationsForImpl (matchDeclName : Name) : MetaM MatchEqns := do -- `realizeConst` as well as for looking up the resultant environment extension state via -- `getState`. realizeConst matchDeclName splitterName (go baseName splitterName) - return matchEqnsExt.getState (asyncMode := .async .asyncEnv) (asyncDecl := splitterName) (← getEnv) |>.map.find! matchDeclName + match matchEqnsExt.getState (asyncMode := .async .asyncEnv) (asyncDecl := splitterName) (← getEnv) |>.map.find? matchDeclName with + | some eqns => return eqns + | none => throwError "failed to retrieve match equations for `{matchDeclName}` after realization" where go baseName splitterName := withConfig (fun c => { c with etaStruct := .none }) do let constInfo ← getConstInfo matchDeclName let us := constInfo.levelParams.map mkLevelParam let some matchInfo ← getMatcherInfo? matchDeclName | throwError "`{matchDeclName}` is not a matcher function" let numDiscrEqs := getNumEqsFromDiscrInfos matchInfo.discrInfos - forallTelescopeReducing constInfo.type fun xs matchResultType => do + forallTelescopeReducing constInfo.type fun xs _matchResultType => do let mut eqnNames := #[] let params := xs[*...matchInfo.numParams] let motive := xs[matchInfo.getMotivePos]! @@ -379,16 +167,15 @@ where go baseName splitterName := withConfig (fun c => { c with etaStruct := .no let discrs := xs[firstDiscrIdx...(firstDiscrIdx + matchInfo.numDiscrs)] let mut notAlts := #[] let mut idx := 1 - let mut splitterAltTypes := #[] let mut splitterAltInfos := #[] let mut altArgMasks := #[] -- masks produced by `forallAltTelescope` for i in *...alts.size do let altInfo := matchInfo.altInfos[i]! let thmName := Name.str baseName eqnThmSuffixBase |>.appendIndexAfter idx eqnNames := eqnNames.push thmName - let (notAlt, splitterAltType, splitterAltInfo, argMask) ← + let (notAlt, splitterAltInfo, argMask) ← forallAltTelescope (← inferType alts[i]!) altInfo numDiscrEqs - fun ys eqs rhsArgs argMask altResultType => do + fun ys _eqs rhsArgs argMask altResultType => do let patterns := altResultType.getAppArgs let mut hs := #[] for overlappedBy in matchInfo.overlaps.overlapping i do @@ -397,15 +184,7 @@ where go baseName splitterName := withConfig (fun c => { c with etaStruct := .no if let some h ← simpH? h patterns.size then hs := hs.push h trace[Meta.Match.matchEqs] "hs: {hs}" - let splitterAltType ← mkForallFVars eqs altResultType - let splitterAltType ← mkArrowN hs splitterAltType - let splitterAltType ← mkForallFVars ys splitterAltType - let hasUnitThunk := splitterAltType == altResultType - let splitterAltType ← if hasUnitThunk then - mkArrow (mkConst ``Unit) splitterAltType - else - pure splitterAltType - let splitterAltType ← unfoldNamedPattern splitterAltType + let hasUnitThunk := ys.isEmpty && hs.isEmpty && numDiscrEqs = 0 let splitterAltInfo := { numFields := ys.size, numOverlaps := hs.size, hasUnitThunk } -- Create a proposition for representing terms that do not match `patterns` let mut notAlt := mkConst ``False @@ -429,38 +208,38 @@ where go baseName splitterName := withConfig (fun c => { c with etaStruct := .no type := thmType value := thmVal } - return (notAlt, splitterAltType, splitterAltInfo, argMask) + return (notAlt, splitterAltInfo, argMask) notAlts := notAlts.push notAlt - splitterAltTypes := splitterAltTypes.push splitterAltType splitterAltInfos := splitterAltInfos.push splitterAltInfo altArgMasks := altArgMasks.push argMask - trace[Meta.Match.matchEqs] "splitterAltType: {splitterAltType}" idx := idx + 1 - -- Define splitter with conditional/refined alternatives - withSplitterAlts splitterAltTypes fun altsNew => do - let splitterParams := params.toArray ++ #[motive] ++ discrs.toArray ++ altsNew - let splitterType ← mkForallFVars splitterParams matchResultType - trace[Meta.Match.matchEqs] "splitterType: {splitterType}" - let splitterVal ← - if (← isDefEq splitterType constInfo.type) then - pure <| mkConst constInfo.name us - else - let template := mkAppN (mkConst constInfo.name us) (params ++ #[motive] ++ discrs ++ alts) - let template ← deltaExpand template (· == constInfo.name) - let template := template.headBeta - mkLambdaFVars splitterParams (← mkSplitterProof matchDeclName template alts altsNew splitterAltInfos altArgMasks) + let splitterMatchInfo : MatcherInfo := { matchInfo with altInfos := splitterAltInfos } + + let needsSplitter := !matchInfo.overlaps.isEmpty || (constInfo.type.find? (isNamedPattern )).isSome + + if needsSplitter then + withMkMatcherInput matchDeclName (unfoldNamed := true) fun matcherInput => do + let matcherInput := { matcherInput with + matcherName := splitterName + isSplitter := some matchInfo.overlaps + } + let res ← Match.mkMatcher matcherInput + res.addMatcher -- TODO: Do not set matcherinfo for the splitter! + else + assert! matchInfo.altInfos == splitterAltInfos + -- This match statement does not need a splitter, we can use itself for that. + -- (We still have to generate a declaration to satisfy the realizable constant) addAndCompile <| Declaration.defnDecl { name := splitterName levelParams := constInfo.levelParams - type := splitterType - value := splitterVal + type := constInfo.type + value := mkConst matchDeclName us hints := .abbrev safety := .safe } setInlineAttribute splitterName - let splitterMatchInfo := { matchInfo with altInfos := splitterAltInfos } - let result := { eqnNames, splitterName, splitterMatchInfo } - registerMatchEqns matchDeclName result + let result := { eqnNames, splitterName, splitterMatchInfo } + registerMatchEqns matchDeclName result /- We generate the equations and splitter on demand, and do not save them on .olean files. -/ builtin_initialize matchCongrEqnsExt : EnvExtension (PHashMap Name (Array Name)) ← diff --git a/src/Lean/Meta/Match/MatcherApp/Basic.lean b/src/Lean/Meta/Match/MatcherApp/Basic.lean index 0f35e5ab7970..119ed082898c 100644 --- a/src/Lean/Meta/Match/MatcherApp/Basic.lean +++ b/src/Lean/Meta/Match/MatcherApp/Basic.lean @@ -67,6 +67,7 @@ def matchMatcherApp? [Monad m] [MonadEnv m] [MonadError m] (e : Expr) (alsoCases matcherName := declName matcherLevels := declLevels.toArray uElimPos?, discrInfos, params, motive, discrs, alts, remaining, altInfos + overlaps := {} -- CasesOn constructor have no overlaps } return none diff --git a/src/Lean/Meta/Match/MatcherInfo.lean b/src/Lean/Meta/Match/MatcherInfo.lean index a1e70946cb1d..5ae8b9026edd 100644 --- a/src/Lean/Meta/Match/MatcherInfo.lean +++ b/src/Lean/Meta/Match/MatcherInfo.lean @@ -23,6 +23,9 @@ structure Overlaps where map : Std.HashMap Nat (Std.TreeSet Nat) := {} deriving Inhabited, Repr +def Overlaps.isEmpty (o : Overlaps) : Bool := + o.map.isEmpty + def Overlaps.insert (o : Overlaps) (overlapping overlapped : Nat) : Overlaps where map := o.map.alter overlapped fun s? => some ((s?.getD {}).insert overlapping) @@ -41,29 +44,32 @@ structure AltParamInfo where numOverlaps : Nat /-- Whether this alternatie has an artifcial `Unit` parameter -/ hasUnitThunk : Bool -deriving Inhabited, Repr +deriving Inhabited, Repr, BEq /-- -A "matcher" auxiliary declaration has the following structure: -- `numParams` parameters -- motive -- `numDiscrs` discriminators (aka major premises) -- `altInfos.size` alternatives (aka minor premises) with parameter structure information -- `uElimPos?` is `some pos` when the matcher can eliminate in different universe levels, and - `pos` is the position of the universe level parameter that specifies the elimination universe. - It is `none` if the matcher only eliminates into `Prop`. -- `overlaps` indicates which alternatives may overlap another +Information about the structure of a matcher declaration -/ structure MatcherInfo where + /-- Number of parameters -/ numParams : Nat + /-- Number of discriminants -/ numDiscrs : Nat + /-- Parameter structure information for each alternative -/ altInfos : Array AltParamInfo + /-- + `uElimPos?` is `some pos` when the matcher can eliminate in different universe levels, and + `pos` is the position of the universe level parameter that specifies the elimination universe. + It is `none` if the matcher only eliminates into `Prop`. + -/ uElimPos? : Option Nat /-- - `discrInfos[i] = { hName? := some h }` if the i-th discriminant was annotated with `h :`. + `discrInfos[i] = { hName? := some h }` if the i-th discriminant was annotated with `h :`. -/ discrInfos : Array DiscrInfo - overlaps : Overlaps := {} + /-- + (Conservative approximation of) which alternatives may overlap another. + -/ + overlaps : Overlaps deriving Inhabited, Repr @[expose] def MatcherInfo.numAlts (info : MatcherInfo) : Nat := diff --git a/stage0/src/stdlib_flags.h b/stage0/src/stdlib_flags.h index 79a0e58eddae..f4d5f7ddaece 100644 --- a/stage0/src/stdlib_flags.h +++ b/stage0/src/stdlib_flags.h @@ -1,5 +1,7 @@ #include "util/options.h" +// please update stage0 + namespace lean { options get_default_options() { options opts; diff --git a/tests/lean/run/casesOnSameCtor.lean b/tests/lean/run/casesOnSameCtor.lean index 2be962b6ba93..8c9aba0fac24 100644 --- a/tests/lean/run/casesOnSameCtor.lean +++ b/tests/lean/run/casesOnSameCtor.lean @@ -38,8 +38,8 @@ info: Vec.match_on_same_ctor.{u_1, u} {α : Type u} /-- info: Vec.match_on_same_ctor.splitter.{u_1, u} {α : Type u} {motive : {a : Nat} → (t t_1 : Vec α a) → t.ctorIdx = t_1.ctorIdx → Sort u_1} {a✝ : Nat} (t t✝ : Vec α a✝) - (h : t.ctorIdx = t✝.ctorIdx) (h_1 : Unit → motive nil nil ⋯) - (h_2 : (a : α) → (n : Nat) → (a_1 : Vec α n) → (a' : α) → (a'_1 : Vec α n) → motive (cons a a_1) (cons a' a'_1) ⋯) : + (h : t.ctorIdx = t✝.ctorIdx) (nil : Unit → motive nil nil ⋯) + (cons : (a : α) → {n : Nat} → (a_1 : Vec α n) → (a' : α) → (a'_1 : Vec α n) → motive (cons a a_1) (cons a' a'_1) ⋯) : motive t t✝ h -/ #guard_msgs in diff --git a/tests/lean/run/issue8274.lean b/tests/lean/run/issue8274.lean index 663cf3f308ed..62854f6d65b6 100644 --- a/tests/lean/run/issue8274.lean +++ b/tests/lean/run/issue8274.lean @@ -11,7 +11,7 @@ info: private def myTest.match_1.splitter.{u_1} : (motive : List Bool → Sort u (x : List Bool) → ((x_1 : Bool) → (xs : List Bool) → x = x_1 :: xs → motive (x_1 :: xs)) → (x = [] → motive []) → motive x := fun motive x h_1 h_2 => - List.casesOn (motive := fun x_1 => x = x_1 → motive x_1) x h_2 (fun head tail => h_1 head tail) ⋯ + (fun x_1 => List.casesOn (motive := fun x_2 => x = x_2 → motive x_2) x_1 h_2 fun head tail => h_1 head tail) x ⋯ -/ #guard_msgs in #print myTest.match_1.splitter diff --git a/tests/lean/run/match2.lean b/tests/lean/run/match2.lean index 4e90d65de1f1..65ce1481558b 100644 --- a/tests/lean/run/match2.lean +++ b/tests/lean/run/match2.lean @@ -1,7 +1,9 @@ import Lean +set_option linter.unusedVariables false + def checkWithMkMatcherInput (matcher : Lean.Name) : Lean.MetaM Unit := - Lean.Meta.Match.withMkMatcherInput matcher fun input => do + Lean.Meta.Match.withMkMatcherInput matcher (unfoldNamed := false) fun input => do let res ← Lean.Meta.Match.mkMatcher input let origMatcher ← Lean.getConstInfo matcher if not <| input.matcherName == matcher then diff --git a/tests/lean/run/matchSparse.lean b/tests/lean/run/matchSparse.lean index 729d0b97c714..ff513d0f5618 100644 --- a/tests/lean/run/matchSparse.lean +++ b/tests/lean/run/matchSparse.lean @@ -9,6 +9,25 @@ def simple : Lean.Expr → Bool | .sort _ => true | _ => false +/-- +info: def simple.match_1.{u_1} : (motive : Expr → Sort u_1) → + (x : Expr) → ((u : Level) → motive (sort u)) → ((x : Expr) → motive x) → motive x := +fun motive x h_1 h_2 => simple._sparseCasesOn_1 x (fun u => h_1 u) fun h => h_2 x +-/ +#guard_msgs in +#print simple.match_1 + +-- Check that the splitter re-uses the sparseCasesOn generated for the matcher: + +/-- +info: private def simple.match_1.splitter.{u_1} : (motive : Expr → Sort u_1) → + (x : Expr) → + ((u : Level) → motive (sort u)) → ((x : Expr) → (∀ (u : Level), x = sort u → False) → motive x) → motive x := +fun motive x h_1 h_2 => simple._sparseCasesOn_1 x (fun u => h_1 u) fun h => h_2 x ⋯ +-/ +#guard_msgs in +#print simple.match_1.splitter + def expensive : Lean.Expr → Lean.Expr → Bool | .app (.app (.sort 1) (.sort 1)) (.sort 1), .app (.app (.sort 1) (.sort 1)) (.sort 1) => false | _, _ => true @@ -49,6 +68,7 @@ info: expensive.match_1.splitter.{u_1} (motive : Expr → Expr → Sort u_1) (x -/ #guard_msgs in #check expensive.match_1.splitter + /-- info: expensive.match_1.eq_1.{u_1} (motive : Expr → Expr → Sort u_1) (h_1 : diff --git a/tests/lean/run/split1.lean b/tests/lean/run/split1.lean index d3112573453c..a9f81019f56d 100644 --- a/tests/lean/run/split1.lean +++ b/tests/lean/run/split1.lean @@ -1,3 +1,6 @@ +-- set_option trace.Meta.Match.match true +-- set_option trace.Meta.Match.matchEqs true + def f (xs : List Nat) : Nat := match xs with | [] => 1