Skip to content

Commit 5d50433

Browse files
authored
fix: allow arbitrary sorts in structural recursion over reflexive inductive types (#7639)
This PR changes the generated `below` and `brecOn` implementations for reflexive inductive types to support motives in `Sort u` rather than `Type u`. Closes #7638
1 parent 812bab6 commit 5d50433

File tree

6 files changed

+66
-75
lines changed

6 files changed

+66
-75
lines changed

src/Lean/Elab/PreDefinition/Structural/BRecOn.lean

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -240,14 +240,7 @@ def mkBRecOnConst (recArgInfos : Array RecArgInfo) (positions : Positions)
240240
let indGroup := recArgInfos[0]!.indGroupInst
241241
let motive := motives[0]!
242242
let brecOnUniv ← lambdaTelescope motive fun _ type => getLevel type
243-
let indInfo ← getConstInfoInduct indGroup.all[0]!
244-
let useBInductionOn := indInfo.isReflexive && brecOnUniv == levelZero
245-
let brecOnUniv ←
246-
if indInfo.isReflexive && brecOnUniv != levelZero then
247-
decLevel brecOnUniv
248-
else
249-
pure brecOnUniv
250-
let brecOnCons := fun idx => indGroup.brecOn useBInductionOn brecOnUniv idx
243+
let brecOnCons := fun idx => indGroup.brecOn false brecOnUniv idx
251244
-- Pick one as a prototype
252245
let brecOnAux := brecOnCons 0
253246
-- Infer the type of the packed motive arguments

src/Lean/Elab/PreDefinition/Structural/FindRecArg.lean

Lines changed: 30 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -70,41 +70,38 @@ def getRecArgInfo (fnName : Name) (fixedParamPerm : FixedParamPerm) (xs : Array
7070
throwError "it is a let-binding"
7171
let xType ← whnfD localDecl.type
7272
matchConstInduct xType.getAppFn (fun _ => throwError "its type is not an inductive") fun indInfo us => do
73-
if indInfo.isReflexive && !(← hasConst (mkBInductionOnName indInfo.name)) && !(← isInductivePredicate indInfo.name) then
74-
throwError "its type {indInfo.name} is a reflexive inductive, but {mkBInductionOnName indInfo.name} does not exist and it is not an inductive predicate"
73+
let indArgs : Array Expr := xType.getAppArgs
74+
let indParams : Array Expr := indArgs[0:indInfo.numParams]
75+
let indIndices : Array Expr := indArgs[indInfo.numParams:]
76+
if !indIndices.all Expr.isFVar then
77+
throwError "its type {indInfo.name} is an inductive family and indices are not variables{indentExpr xType}"
78+
else if !indIndices.allDiff then
79+
throwError "its type {indInfo.name} is an inductive family and indices are not pairwise distinct{indentExpr xType}"
7580
else
76-
let indArgs : Array Expr := xType.getAppArgs
77-
let indParams : Array Expr := indArgs[0:indInfo.numParams]
78-
let indIndices : Array Expr := indArgs[indInfo.numParams:]
79-
if !indIndices.all Expr.isFVar then
80-
throwError "its type {indInfo.name} is an inductive family and indices are not variables{indentExpr xType}"
81-
else if !indIndices.allDiff then
82-
throwError "its type {indInfo.name} is an inductive family and indices are not pairwise distinct{indentExpr xType}"
83-
else
84-
let ys := fixedParamPerm.pickVarying xs
85-
match (← hasBadIndexDep? ys indIndices) with
86-
| some (index, y) =>
87-
throwError "its type {indInfo.name} is an inductive family{indentExpr xType}\nand index{indentExpr index}\ndepends on the non index{indentExpr y}"
81+
let ys := fixedParamPerm.pickVarying xs
82+
match (← hasBadIndexDep? ys indIndices) with
83+
| some (index, y) =>
84+
throwError "its type {indInfo.name} is an inductive family{indentExpr xType}\nand index{indentExpr index}\ndepends on the non index{indentExpr y}"
85+
| none =>
86+
match (← hasBadParamDep? ys indParams) with
87+
| some (indParam, y) =>
88+
throwError "its type is an inductive datatype{indentExpr xType}\nand the datatype parameter{indentExpr indParam}\ndepends on the function parameter{indentExpr y}\nwhich is not fixed."
8889
| none =>
89-
match (← hasBadParamDep? ys indParams) with
90-
| some (indParam, y) =>
91-
throwError "its type is an inductive datatype{indentExpr xType}\nand the datatype parameter{indentExpr indParam}\ndepends on the function parameter{indentExpr y}\nwhich is not fixed."
92-
| none =>
93-
let indAll := indInfo.all.toArray
94-
let .some indIdx := indAll.idxOf? indInfo.name | panic! "{indInfo.name} not in {indInfo.all}"
95-
let indicesPos := indIndices.map fun index => match xs.idxOf? index with | some i => i | none => unreachable!
96-
let indGroupInst := {
97-
IndGroupInfo.ofInductiveVal indInfo with
98-
levels := us
99-
params := indParams }
100-
return { fnName := fnName
101-
fixedParamPerm := fixedParamPerm
102-
recArgPos := i
103-
indicesPos := indicesPos
104-
indGroupInst := indGroupInst
105-
indIdx := indIdx }
106-
else
107-
throwError "the index #{i+1} exceeds {xs.size}, the number of parameters"
90+
let indAll := indInfo.all.toArray
91+
let .some indIdx := indAll.idxOf? indInfo.name | panic! "{indInfo.name} not in {indInfo.all}"
92+
let indicesPos := indIndices.map fun index => match xs.idxOf? index with | some i => i | none => unreachable!
93+
let indGroupInst := {
94+
IndGroupInfo.ofInductiveVal indInfo with
95+
levels := us
96+
params := indParams }
97+
return { fnName := fnName
98+
fixedParamPerm := fixedParamPerm
99+
recArgPos := i
100+
indicesPos := indicesPos
101+
indGroupInst := indGroupInst
102+
indIdx := indIdx }
103+
else
104+
throwError "the index #{i+1} exceeds {xs.size}, the number of parameters"
108105

109106
/--
110107
Collects the `RecArgInfos` for one function, and returns a report for why the others were not

src/Lean/Meta/Constructions/BRecOn.lean

Lines changed: 8 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,6 @@ import Lean.Meta.PProdN
1313
namespace Lean
1414
open Meta
1515

16-
/-- Transforms `e : xᵢ → (t₁ ×' t₂)` into `(xᵢ → t₁) ×' (xᵢ → t₂) -/
17-
private def etaPProd (xs : Array Expr) (e : Expr) : MetaM Expr := do
18-
if xs.isEmpty then return e
19-
let r := mkAppN e xs
20-
let r₁ ← mkLambdaFVars xs (← mkPProdFstM r)
21-
let r₂ ← mkLambdaFVars xs (← mkPProdSndM r)
22-
mkPProdMk r₁ r₂
23-
2416
/--
2517
If `minorType` is the type of a minor premies of a recursor, such as
2618
```
@@ -40,7 +32,6 @@ of type
4032
private def buildBelowMinorPremise (rlvl : Level) (motives : Array Expr) (minorType : Expr) : MetaM Expr :=
4133
forallTelescope minorType fun minor_args _ => do go #[] minor_args.toList
4234
where
43-
ibelow := rlvl matches .zero
4435
go (prods : Array Expr) : List Expr → MetaM Expr
4536
| [] => PProdN.pack rlvl prods
4637
| arg::args => do
@@ -50,8 +41,7 @@ where
5041
let name ← arg.fvarId!.getUserName
5142
let type' ← forallTelescope argType fun args _ => mkForallFVars args (.sort rlvl)
5243
withLocalDeclD name type' fun arg' => do
53-
let snd ← mkForallFVars arg_args (mkAppN arg' arg_args)
54-
let e' ← mkPProd argType snd
44+
let e' ← mkForallFVars arg_args <| ← mkPProd arg_type (mkAppN arg' arg_args)
5545
mkLambdaFVars #[arg'] (← go (prods.push e') args)
5646
else
5747
mkLambdaFVars #[arg] (← go prods args)
@@ -86,8 +76,6 @@ private def mkBelowFromRec (recName : Name) (ibelow reflexive : Bool) (nParams :
8676
let refType :=
8777
if ibelow then
8878
recVal.type.instantiateLevelParams [lvlParam] [0]
89-
else if reflexive then
90-
recVal.type.instantiateLevelParams [lvlParam] [lvl.succ]
9179
else
9280
recVal.type
9381

@@ -116,12 +104,9 @@ private def mkBelowFromRec (recName : Name) (ibelow reflexive : Bool) (nParams :
116104
if ibelow then
117105
0
118106
else if reflexive then
119-
if let .max 1 ilvl' := ilvl then
120-
mkLevelMax' (.succ lvl) ilvl'
121-
else
122-
mkLevelMax' (.succ lvl) ilvl
107+
mkLevelMax ilvl lvl
123108
else
124-
mkLevelMax' 1 lvl
109+
mkLevelMax 1 lvl
125110

126111
let mut val := .const recName (rlvl.succ :: lvls)
127112
-- add parameters
@@ -168,8 +153,8 @@ private def mkBelowOrIBelow (indName : Name) (ibelow : Bool) : MetaM Unit := do
168153
let belowName := belowName.appendIndexAfter (i + 1)
169154
mkBelowFromRec recName ibelow indVal.isReflexive indVal.numParams belowName
170155

171-
def mkBelow (declName : Name) : MetaM Unit := mkBelowOrIBelow declName true
172-
def mkIBelow (declName : Name) : MetaM Unit := mkBelowOrIBelow declName false
156+
def mkBelow (declName : Name) : MetaM Unit := mkBelowOrIBelow declName false
157+
def mkIBelow (declName : Name) : MetaM Unit := mkBelowOrIBelow declName true
173158

174159
/--
175160
If `minorType` is the type of a minor premies of a recursor, such as
@@ -207,8 +192,7 @@ private def buildBRecOnMinorPremise (rlvl : Level) (motives : Array Expr)
207192
let type' ← mkForallFVars arg_args
208193
(← mkPProd arg_type (mkAppN belows[idx]! arg_type_args) )
209194
withLocalDeclD name type' fun arg' => do
210-
let r ← etaPProd arg_args arg'
211-
mkLambdaFVars #[arg'] (← go (prods.push r) args)
195+
mkLambdaFVars #[arg'] (← go (prods.push arg') args)
212196
else
213197
mkLambdaFVars #[arg] (← go prods args)
214198
go #[] minor_args.toList
@@ -251,8 +235,6 @@ private def mkBRecOnFromRec (recName : Name) (ind reflexive : Bool) (nParams : N
251235
let refType :=
252236
if ind then
253237
recVal.type.instantiateLevelParams [lvlParam] [0]
254-
else if reflexive then
255-
recVal.type.instantiateLevelParams [lvlParam] [lvl.succ]
256238
else
257239
recVal.type
258240

@@ -279,12 +261,9 @@ private def mkBRecOnFromRec (recName : Name) (ind reflexive : Bool) (nParams : N
279261
if ind then
280262
0
281263
else if reflexive then
282-
if let .max 1 ilvl' := ilvl then
283-
mkLevelMax' (.succ lvl) ilvl'
284-
else
285-
mkLevelMax' (.succ lvl) ilvl
264+
mkLevelMax ilvl lvl
286265
else
287-
mkLevelMax' 1 lvl
266+
mkLevelMax 1 lvl
288267

289268
-- One `below` for each motive, with the same motive parameters
290269
let blvls := if ind then lvls else lvl::lvls

src/library/suffixes.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@ Author: Leonardo de Moura
88

99
namespace lean {
1010
constexpr char const * g_rec = "rec";
11-
constexpr char const * g_brec_on = "brecOn";
12-
constexpr char const * g_binduction_on = "binductionOn";
1311
constexpr char const * g_cases_on = "casesOn";
1412
constexpr char const * g_no_confusion = "noConfusion";
1513
constexpr char const * g_no_confusion_type = "noConfusionType";

tests/lean/run/7638.lean

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
inductive Foo : Type
2+
| mk : Foo → Foo
3+
4+
inductive Bar : Type
5+
| mk : (Unit → Bar) → Bar
6+
7+
def Foo.elim {α : Sort u} : Foo → α
8+
| ⟨foo⟩ => elim foo
9+
termination_by structural foo => foo
10+
11+
def Bar.elim {α : Sort u} : Bar → α
12+
| ⟨bar⟩ => elim (bar ())
13+
termination_by structural bar => bar
14+
15+
inductive StressTest : Type 5
16+
| f (x : Type 4 → StressTest)
17+
| g (x : Type 3 → StressTest)
18+
| h (x : Type 4 → StressTest) (y : Type 3 → StressTest)
19+
20+
def StressTest.elim {α : Sort u} : StressTest → α
21+
| f x => elim (x (Type 3))
22+
| g x => elim (x (Type 2))
23+
| h x _y => elim (x (Type 3))
24+
termination_by structural t => t

tests/lean/run/issue4650.lean

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ inductive Foo1 : Sort (max 1 u) where
44
| intro: (h : Nat → Foo1) → Foo1
55

66
/--
7-
info: Foo1.below.{u_1, u} {motive : Foo1.{u} → Type u_1} (t : Foo1.{u}) : Sort (max (u_1 + 1) u)
7+
info: Foo1.below.{u_1, u} {motive : Foo1.{u} → Sort u_1} (t : Foo1.{u}) : Sort (max (max 1 u) u_1)
88
-/
99
#guard_msgs in
1010
#check Foo1.below
@@ -13,15 +13,15 @@ inductive Foo2 : Sort (max u 1) where
1313
| intro: (h : Nat → Foo2) → Foo2
1414

1515
/--
16-
info: Foo2.below.{u_1, u} {motive : Foo2.{u} → Type u_1} (t : Foo2.{u}) : Sort (max (u_1 + 1) u 1)
16+
info: Foo2.below.{u_1, u} {motive : Foo2.{u} → Sort u_1} (t : Foo2.{u}) : Sort (max (max u 1) u_1)
1717
-/
1818
#guard_msgs in
1919
#check Foo2.below
2020

2121
inductive Foo3 : Sort (u+1) where
2222
| intro: (h : Nat → Foo3) → Foo3
2323

24-
/-- info: Foo3.below.{u_1, u} {motive : Foo3.{u} → Type u_1} (t : Foo3.{u}) : Type (max u_1 u) -/
24+
/-- info: Foo3.below.{u_1, u} {motive : Foo3.{u} → Sort u_1} (t : Foo3.{u}) : Sort (max (u + 1) u_1) -/
2525
#guard_msgs in
2626
#check Foo3.below
2727

0 commit comments

Comments
 (0)