|
| 1 | +/- |
| 2 | +Copyright (c) 2025 Lean FRO, LLC. All rights reserved. |
| 3 | +Released under Apache 2.0 license as described in the file LICENSE. |
| 4 | +Authors: Joachim Breitner |
| 5 | +-/ |
| 6 | +prelude |
| 7 | +import Lean.AddDecl |
| 8 | +import Lean.Meta.AppBuilder |
| 9 | +import Lean.Meta.CompletionName |
| 10 | + |
| 11 | +/-! |
| 12 | +This module produces a construction for the `noConfusionType` that is linear in size in the number of |
| 13 | +constructors of the inductive type. This is in contrast to the previous construction (definde in |
| 14 | +`no_confusion.cpp`), that is quadratic in size due to nested `.brecOn` applications. |
| 15 | +
|
| 16 | +We still use the old construction when processing the prelude, for the few inductives that we need |
| 17 | +until below (`Nat`, `Bool`, `Decidable`). |
| 18 | +
|
| 19 | +The main trick is to use a `withCtor` helper that is like a match with one constructor pattern and |
| 20 | +one catch-all pattern, and thus linear in size. And the helper itself is a single function |
| 21 | +definition, rather than one for each constructor, using a `withCtorType` helper in the function. |
| 22 | +
|
| 23 | +See the `linearNoConfusion.lean` test for exemplary output of this translation (checked to be |
| 24 | +up-to-date). |
| 25 | +
|
| 26 | +The `withCtor` functions could be generally useful, but for that they should probably eliminate into |
| 27 | +`Sort _` rather than `Type _`, and then writing the `withCtorType` function runs into universe level |
| 28 | +confusion, which may be solvable if the kernel had more complete univere level normalization. |
| 29 | +Until then we put these helper in the `noConfusionType` namespace to indicate that they are not |
| 30 | +general purpose. |
| 31 | +
|
| 32 | +This module is written in a rather manual style, constructing the `Expr` directly. It's best |
| 33 | +read with the expected output to the side. |
| 34 | +-/ |
| 35 | + |
| 36 | +namespace Lean.NoConfusionLinear |
| 37 | + |
| 38 | +open Meta |
| 39 | + |
| 40 | +/-- |
| 41 | +List of constants that the linear `noConfusionType` construction depends on. |
| 42 | +-/ |
| 43 | +def deps : Array Lean.Name := |
| 44 | + #[ ``Nat.lt, ``cond, ``Nat, ``PUnit, ``Eq, ``Not, ``dite, ``Nat.decEq, ``Nat.blt ] |
| 45 | + |
| 46 | +def mkNatLookupTable (n : Expr) (es : Array Expr) (default : Expr) : MetaM Expr := do |
| 47 | + let type ← inferType default |
| 48 | + let u ← getLevel type |
| 49 | + let rec go (start stop : Nat) (hstart : start < stop := by omega) (hstop : stop ≤ es.size := by omega) : MetaM Expr := do |
| 50 | + if h : start + 1 = stop then |
| 51 | + return es[start] |
| 52 | + else |
| 53 | + let mid := (start + stop) / 2 |
| 54 | + let low ← go start mid |
| 55 | + let high ← go mid stop |
| 56 | + return mkApp4 (mkConst ``cond [u]) type (mkApp2 (mkConst ``Nat.blt) n (mkRawNatLit mid)) low high |
| 57 | + if h : es.size = 0 then |
| 58 | + pure default |
| 59 | + else |
| 60 | + go 0 es.size |
| 61 | + |
| 62 | +def mkWithCtorTypeName (indName : Name) : Name := |
| 63 | + Name.str indName "noConfusionType" |>.str "withCtorType" |
| 64 | + |
| 65 | +def mkWithCtorName (indName : Name) : Name := |
| 66 | + Name.str indName "noConfusionType" |>.str "withCtor" |
| 67 | + |
| 68 | +def mkNoConfusionTypeName (indName : Name) : Name := |
| 69 | + Name.str indName "noConfusionType" |
| 70 | + |
| 71 | +def mkWithCtorType (indName : Name) : MetaM Unit := do |
| 72 | + let ConstantInfo.inductInfo info ← getConstInfo indName | unreachable! |
| 73 | + let casesOnName := mkCasesOnName indName |
| 74 | + let casesOnInfo ← getConstVal casesOnName |
| 75 | + let v::us := casesOnInfo.levelParams.map mkLevelParam | panic! "unexpected universe levels on `casesOn`" |
| 76 | + let indTyCon := mkConst indName us |
| 77 | + let indTyKind ← inferType indTyCon |
| 78 | + let indLevel ← getLevel indTyKind |
| 79 | + let e ← forallBoundedTelescope indTyKind info.numParams fun xs _ => do |
| 80 | + withLocalDeclD `P (mkSort v.succ) fun P => do |
| 81 | + withLocalDeclD `ctorIdx (mkConst ``Nat) fun ctorIdx => do |
| 82 | + let default ← mkArrow (mkConst ``PUnit [indLevel]) P |
| 83 | + let es ← info.ctors.toArray.mapM fun ctorName => do |
| 84 | + let ctor := mkAppN (mkConst ctorName us) xs |
| 85 | + let ctorType ← inferType ctor |
| 86 | + let argType ← forallTelescope ctorType fun ys _ => |
| 87 | + mkForallFVars ys P |
| 88 | + mkArrow (mkConst ``PUnit [indLevel]) argType |
| 89 | + let e ← mkNatLookupTable ctorIdx es default |
| 90 | + mkLambdaFVars ((xs.push P).push ctorIdx) e |
| 91 | + |
| 92 | + let declName := mkWithCtorTypeName indName |
| 93 | + addAndCompile (.defnDecl (← mkDefinitionValInferrringUnsafe |
| 94 | + (name := declName) |
| 95 | + (levelParams := casesOnInfo.levelParams) |
| 96 | + (type := (← inferType e)) |
| 97 | + (value := e) |
| 98 | + (hints := ReducibilityHints.abbrev) |
| 99 | + )) |
| 100 | + modifyEnv fun env => addToCompletionBlackList env declName |
| 101 | + modifyEnv fun env => addProtected env declName |
| 102 | + setReducibleAttribute declName |
| 103 | + |
| 104 | +def mkWithCtor (indName : Name) : MetaM Unit := do |
| 105 | + let ConstantInfo.inductInfo info ← getConstInfo indName | unreachable! |
| 106 | + let withCtorTypeName := mkWithCtorTypeName indName |
| 107 | + let casesOnName := mkCasesOnName indName |
| 108 | + let casesOnInfo ← getConstVal casesOnName |
| 109 | + let v::us := casesOnInfo.levelParams.map mkLevelParam | panic! "unexpected universe levels on `casesOn`" |
| 110 | + let indTyCon := mkConst indName us |
| 111 | + let indTyKind ← inferType indTyCon |
| 112 | + let indLevel ← getLevel indTyKind |
| 113 | + let e ← forallBoundedTelescope indTyKind info.numParams fun xs t => do |
| 114 | + withLocalDeclD `P (mkSort v.succ) fun P => do |
| 115 | + withLocalDeclD `ctorIdx (mkConst ``Nat) fun ctorIdx => do |
| 116 | + let withCtorTypeNameApp := mkAppN (mkConst withCtorTypeName (v :: us)) (xs.push P) |
| 117 | + let kType := mkApp withCtorTypeNameApp ctorIdx |
| 118 | + withLocalDeclD `k kType fun k => |
| 119 | + withLocalDeclD `k' P fun k' => |
| 120 | + forallBoundedTelescope t info.numIndices fun ys t' => do |
| 121 | + let t' ← whnfD t' |
| 122 | + assert! t'.isSort |
| 123 | + withLocalDeclD `x (mkAppN indTyCon (xs ++ ys)) fun x => do |
| 124 | + let e := mkConst (mkCasesOnName indName) (v.succ :: us) |
| 125 | + let e := mkAppN e xs |
| 126 | + let motive ← mkLambdaFVars (ys.push x) P |
| 127 | + let e := mkApp e motive |
| 128 | + let e := mkAppN e ys |
| 129 | + let e := mkApp e x |
| 130 | + let alts ← info.ctors.toArray.mapIdxM fun i ctorName => do |
| 131 | + let ctor := mkAppN (mkConst ctorName us) xs |
| 132 | + let ctorType ← inferType ctor |
| 133 | + forallTelescope ctorType fun zs _ => do |
| 134 | + let heq := mkApp3 (mkConst ``Eq [1]) (mkConst ``Nat) ctorIdx (mkRawNatLit i) |
| 135 | + let «then» ← withLocalDeclD `h heq fun h => do |
| 136 | + let e ← mkEqNDRec (motive := withCtorTypeNameApp) k h |
| 137 | + let e := mkApp e (mkConst ``PUnit.unit [indLevel]) |
| 138 | + let e := mkAppN e zs |
| 139 | + -- ``Eq.ndrec |
| 140 | + mkLambdaFVars #[h] e |
| 141 | + let «else» ← withLocalDeclD `h (mkNot heq) fun h => |
| 142 | + mkLambdaFVars #[h] k' |
| 143 | + let alt := mkApp5 (mkConst ``dite [v.succ]) |
| 144 | + P heq (mkApp2 (mkConst ``Nat.decEq) ctorIdx (mkRawNatLit i)) |
| 145 | + «then» «else» |
| 146 | + mkLambdaFVars zs alt |
| 147 | + let e := mkAppN e alts |
| 148 | + mkLambdaFVars (xs ++ #[P, ctorIdx, k, k'] ++ ys ++ #[x]) e |
| 149 | + |
| 150 | + let declName := mkWithCtorName indName |
| 151 | + -- not compiled to avoid old code generator bug #1774 |
| 152 | + addDecl (.defnDecl (← mkDefinitionValInferrringUnsafe |
| 153 | + (name := declName) |
| 154 | + (levelParams := casesOnInfo.levelParams) |
| 155 | + (type := (← inferType e)) |
| 156 | + (value := e) |
| 157 | + (hints := ReducibilityHints.abbrev) |
| 158 | + )) |
| 159 | + modifyEnv fun env => addToCompletionBlackList env declName |
| 160 | + modifyEnv fun env => addProtected env declName |
| 161 | + setReducibleAttribute declName |
| 162 | + |
| 163 | +def mkNoConfusionTypeLinear (indName : Name) : MetaM Unit := do |
| 164 | + let declName := mkNoConfusionTypeName indName |
| 165 | + let ConstantInfo.inductInfo info ← getConstInfo indName | unreachable! |
| 166 | + let casesOnName := mkCasesOnName indName |
| 167 | + let casesOnInfo ← getConstVal casesOnName |
| 168 | + let v::us := casesOnInfo.levelParams.map mkLevelParam | panic! "unexpected universe levels on `casesOn`" |
| 169 | + let e := mkConst casesOnName (v.succ::us) |
| 170 | + let t ← inferType e |
| 171 | + let e ← forallBoundedTelescope t info.numParams fun xs t => do |
| 172 | + let e := mkAppN e xs |
| 173 | + let PType := mkSort v |
| 174 | + withLocalDeclD `P PType fun P => do |
| 175 | + let motive ← forallTelescope (← whnfD t).bindingDomain! fun ys _ => |
| 176 | + mkLambdaFVars ys PType |
| 177 | + let t ← instantiateForall t #[motive] |
| 178 | + let e := mkApp e motive |
| 179 | + forallBoundedTelescope t info.numIndices fun ys t => do |
| 180 | + let e := mkAppN e ys |
| 181 | + let xType := mkAppN (mkConst indName us) (xs ++ ys) |
| 182 | + withLocalDeclD `x1 xType fun x1 => do |
| 183 | + withLocalDeclD `x2 xType fun x2 => do |
| 184 | + let t ← instantiateForall t #[x1] |
| 185 | + let e := mkApp e x1 |
| 186 | + forallBoundedTelescope t info.numCtors fun alts _ => do |
| 187 | + let alts' ← alts.mapIdxM fun i alt => do |
| 188 | + let altType ← inferType alt |
| 189 | + forallTelescope altType fun zs1 _ => do |
| 190 | + let alt := mkConst (mkWithCtorName indName) (v :: us) |
| 191 | + let alt := mkAppN alt xs |
| 192 | + let alt := mkApp alt PType |
| 193 | + let alt := mkApp alt (mkRawNatLit i) |
| 194 | + let k ← forallTelescopeReducing (← inferType alt).bindingDomain! fun zs2 _ => do |
| 195 | + let eqs ← (Array.zip zs1 zs2[1:]).filterMapM fun (z1,z2) => do |
| 196 | + if (← isProof z1) then |
| 197 | + return none |
| 198 | + else |
| 199 | + return some (← mkEqHEq z1 z2) |
| 200 | + let k ← mkArrowN eqs P |
| 201 | + let k ← mkArrow k P |
| 202 | + mkLambdaFVars zs2 k |
| 203 | + let alt := mkApp alt k |
| 204 | + let alt := mkApp alt P |
| 205 | + let alt := mkAppN alt ys |
| 206 | + let alt := mkApp alt x2 |
| 207 | + mkLambdaFVars zs1 alt |
| 208 | + let e := mkAppN e alts' |
| 209 | + let e ← mkLambdaFVars #[x1, x2] e |
| 210 | + let e ← mkLambdaFVars #[P] e |
| 211 | + let e ← mkLambdaFVars ys e |
| 212 | + let e ← mkLambdaFVars xs e |
| 213 | + pure e |
| 214 | + |
| 215 | + addDecl (.defnDecl (← mkDefinitionValInferrringUnsafe |
| 216 | + (name := declName) |
| 217 | + (levelParams := casesOnInfo.levelParams) |
| 218 | + (type := (← inferType e)) |
| 219 | + (value := e) |
| 220 | + (hints := ReducibilityHints.abbrev) |
| 221 | + )) |
| 222 | + modifyEnv fun env => addToCompletionBlackList env declName |
| 223 | + modifyEnv fun env => addProtected env declName |
| 224 | + setReducibleAttribute declName |
| 225 | + |
| 226 | +end Lean.NoConfusionLinear |
0 commit comments