Skip to content

Commit 5e40f4a

Browse files
authored
feat: linear-size noConfusionType construction (#8037)
This PR introduces a `noConfusionType` construction that’s sub-quadratic in size, and reduces faster. The previous `noConfusion` construction with two nested `match` statements is quadratic in size and reduction behavior. Using some helper definitions, a linear size construction is possible. With this, processing the RISC-V-AST definition from https://github.com/opencompl/sail-riscv-lean takes 6s instead of 60s. The previous construction is still used when processing the early prelude, and can be enabled elsewhere using `set_option backwards.linearNoConfusionType false`.
1 parent 2594a8e commit 5e40f4a

File tree

6 files changed

+356
-20
lines changed

6 files changed

+356
-20
lines changed

src/Lean/Meta/Constructions/NoConfusion.lean

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,13 @@ prelude
77
import Lean.AddDecl
88
import Lean.Meta.AppBuilder
99
import Lean.Meta.CompletionName
10+
import Lean.Meta.Constructions.NoConfusionLinear
11+
12+
13+
register_builtin_option backwards.linearNoConfusionType : Bool := {
14+
defValue := true
15+
descr := "use the linear-size construction for the `noConfusionType` declaration of an inductive type. Set to false to use the previous, simpler but quadratic-size construction. "
16+
}
1017

1118
namespace Lean
1219

@@ -21,12 +28,23 @@ def mkNoConfusionCore (declName : Name) : MetaM Unit := do
2128
let recInfo ← getConstInfo (mkRecName declName)
2229
unless recInfo.levelParams.length > indVal.levelParams.length do return
2330

24-
let name := Name.mkStr declName "noConfusionType"
25-
let decl ← ofExceptKernelException (mkNoConfusionTypeCoreImp (← getEnv) declName)
26-
addDecl decl
27-
setReducibleAttribute name
28-
modifyEnv fun env => addToCompletionBlackList env name
29-
modifyEnv fun env => addProtected env name
31+
let useLinear ←
32+
if backwards.linearNoConfusionType.get (← getOptions) then
33+
NoConfusionLinear.deps.allM (hasConst · (skipRealize := true))
34+
else
35+
pure false
36+
37+
if useLinear then
38+
NoConfusionLinear.mkWithCtorType declName
39+
NoConfusionLinear.mkWithCtor declName
40+
NoConfusionLinear.mkNoConfusionTypeLinear declName
41+
else
42+
let name := Name.mkStr declName "noConfusionType"
43+
let decl ← ofExceptKernelException (mkNoConfusionTypeCoreImp (← getEnv) declName)
44+
addDecl decl
45+
setReducibleAttribute name
46+
modifyEnv fun env => addToCompletionBlackList env name
47+
modifyEnv fun env => addProtected env name
3048

3149
let name := Name.mkStr declName "noConfusion"
3250
let decl ← ofExceptKernelException (mkNoConfusionCoreImp (← getEnv) declName)
Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
/-
2+
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
3+
Released under Apache 2.0 license as described in the file LICENSE.
4+
Authors: Joachim Breitner
5+
-/
6+
prelude
7+
import Lean.AddDecl
8+
import Lean.Meta.AppBuilder
9+
import Lean.Meta.CompletionName
10+
11+
/-!
12+
This module produces a construction for the `noConfusionType` that is linear in size in the number of
13+
constructors of the inductive type. This is in contrast to the previous construction (definde in
14+
`no_confusion.cpp`), that is quadratic in size due to nested `.brecOn` applications.
15+
16+
We still use the old construction when processing the prelude, for the few inductives that we need
17+
until below (`Nat`, `Bool`, `Decidable`).
18+
19+
The main trick is to use a `withCtor` helper that is like a match with one constructor pattern and
20+
one catch-all pattern, and thus linear in size. And the helper itself is a single function
21+
definition, rather than one for each constructor, using a `withCtorType` helper in the function.
22+
23+
See the `linearNoConfusion.lean` test for exemplary output of this translation (checked to be
24+
up-to-date).
25+
26+
The `withCtor` functions could be generally useful, but for that they should probably eliminate into
27+
`Sort _` rather than `Type _`, and then writing the `withCtorType` function runs into universe level
28+
confusion, which may be solvable if the kernel had more complete univere level normalization.
29+
Until then we put these helper in the `noConfusionType` namespace to indicate that they are not
30+
general purpose.
31+
32+
This module is written in a rather manual style, constructing the `Expr` directly. It's best
33+
read with the expected output to the side.
34+
-/
35+
36+
namespace Lean.NoConfusionLinear
37+
38+
open Meta
39+
40+
/--
41+
List of constants that the linear `noConfusionType` construction depends on.
42+
-/
43+
def deps : Array Lean.Name :=
44+
#[ ``Nat.lt, ``cond, ``Nat, ``PUnit, ``Eq, ``Not, ``dite, ``Nat.decEq, ``Nat.blt ]
45+
46+
def mkNatLookupTable (n : Expr) (es : Array Expr) (default : Expr) : MetaM Expr := do
47+
let type ← inferType default
48+
let u ← getLevel type
49+
let rec go (start stop : Nat) (hstart : start < stop := by omega) (hstop : stop ≤ es.size := by omega) : MetaM Expr := do
50+
if h : start + 1 = stop then
51+
return es[start]
52+
else
53+
let mid := (start + stop) / 2
54+
let low ← go start mid
55+
let high ← go mid stop
56+
return mkApp4 (mkConst ``cond [u]) type (mkApp2 (mkConst ``Nat.blt) n (mkRawNatLit mid)) low high
57+
if h : es.size = 0 then
58+
pure default
59+
else
60+
go 0 es.size
61+
62+
def mkWithCtorTypeName (indName : Name) : Name :=
63+
Name.str indName "noConfusionType" |>.str "withCtorType"
64+
65+
def mkWithCtorName (indName : Name) : Name :=
66+
Name.str indName "noConfusionType" |>.str "withCtor"
67+
68+
def mkNoConfusionTypeName (indName : Name) : Name :=
69+
Name.str indName "noConfusionType"
70+
71+
def mkWithCtorType (indName : Name) : MetaM Unit := do
72+
let ConstantInfo.inductInfo info ← getConstInfo indName | unreachable!
73+
let casesOnName := mkCasesOnName indName
74+
let casesOnInfo ← getConstVal casesOnName
75+
let v::us := casesOnInfo.levelParams.map mkLevelParam | panic! "unexpected universe levels on `casesOn`"
76+
let indTyCon := mkConst indName us
77+
let indTyKind ← inferType indTyCon
78+
let indLevel ← getLevel indTyKind
79+
let e ← forallBoundedTelescope indTyKind info.numParams fun xs _ => do
80+
withLocalDeclD `P (mkSort v.succ) fun P => do
81+
withLocalDeclD `ctorIdx (mkConst ``Nat) fun ctorIdx => do
82+
let default ← mkArrow (mkConst ``PUnit [indLevel]) P
83+
let es ← info.ctors.toArray.mapM fun ctorName => do
84+
let ctor := mkAppN (mkConst ctorName us) xs
85+
let ctorType ← inferType ctor
86+
let argType ← forallTelescope ctorType fun ys _ =>
87+
mkForallFVars ys P
88+
mkArrow (mkConst ``PUnit [indLevel]) argType
89+
let e ← mkNatLookupTable ctorIdx es default
90+
mkLambdaFVars ((xs.push P).push ctorIdx) e
91+
92+
let declName := mkWithCtorTypeName indName
93+
addAndCompile (.defnDecl (← mkDefinitionValInferrringUnsafe
94+
(name := declName)
95+
(levelParams := casesOnInfo.levelParams)
96+
(type := (← inferType e))
97+
(value := e)
98+
(hints := ReducibilityHints.abbrev)
99+
))
100+
modifyEnv fun env => addToCompletionBlackList env declName
101+
modifyEnv fun env => addProtected env declName
102+
setReducibleAttribute declName
103+
104+
def mkWithCtor (indName : Name) : MetaM Unit := do
105+
let ConstantInfo.inductInfo info ← getConstInfo indName | unreachable!
106+
let withCtorTypeName := mkWithCtorTypeName indName
107+
let casesOnName := mkCasesOnName indName
108+
let casesOnInfo ← getConstVal casesOnName
109+
let v::us := casesOnInfo.levelParams.map mkLevelParam | panic! "unexpected universe levels on `casesOn`"
110+
let indTyCon := mkConst indName us
111+
let indTyKind ← inferType indTyCon
112+
let indLevel ← getLevel indTyKind
113+
let e ← forallBoundedTelescope indTyKind info.numParams fun xs t => do
114+
withLocalDeclD `P (mkSort v.succ) fun P => do
115+
withLocalDeclD `ctorIdx (mkConst ``Nat) fun ctorIdx => do
116+
let withCtorTypeNameApp := mkAppN (mkConst withCtorTypeName (v :: us)) (xs.push P)
117+
let kType := mkApp withCtorTypeNameApp ctorIdx
118+
withLocalDeclD `k kType fun k =>
119+
withLocalDeclD `k' P fun k' =>
120+
forallBoundedTelescope t info.numIndices fun ys t' => do
121+
let t' ← whnfD t'
122+
assert! t'.isSort
123+
withLocalDeclD `x (mkAppN indTyCon (xs ++ ys)) fun x => do
124+
let e := mkConst (mkCasesOnName indName) (v.succ :: us)
125+
let e := mkAppN e xs
126+
let motive ← mkLambdaFVars (ys.push x) P
127+
let e := mkApp e motive
128+
let e := mkAppN e ys
129+
let e := mkApp e x
130+
let alts ← info.ctors.toArray.mapIdxM fun i ctorName => do
131+
let ctor := mkAppN (mkConst ctorName us) xs
132+
let ctorType ← inferType ctor
133+
forallTelescope ctorType fun zs _ => do
134+
let heq := mkApp3 (mkConst ``Eq [1]) (mkConst ``Nat) ctorIdx (mkRawNatLit i)
135+
let «then» ← withLocalDeclD `h heq fun h => do
136+
let e ← mkEqNDRec (motive := withCtorTypeNameApp) k h
137+
let e := mkApp e (mkConst ``PUnit.unit [indLevel])
138+
let e := mkAppN e zs
139+
-- ``Eq.ndrec
140+
mkLambdaFVars #[h] e
141+
let «else» ← withLocalDeclD `h (mkNot heq) fun h =>
142+
mkLambdaFVars #[h] k'
143+
let alt := mkApp5 (mkConst ``dite [v.succ])
144+
P heq (mkApp2 (mkConst ``Nat.decEq) ctorIdx (mkRawNatLit i))
145+
«then» «else»
146+
mkLambdaFVars zs alt
147+
let e := mkAppN e alts
148+
mkLambdaFVars (xs ++ #[P, ctorIdx, k, k'] ++ ys ++ #[x]) e
149+
150+
let declName := mkWithCtorName indName
151+
-- not compiled to avoid old code generator bug #1774
152+
addDecl (.defnDecl (← mkDefinitionValInferrringUnsafe
153+
(name := declName)
154+
(levelParams := casesOnInfo.levelParams)
155+
(type := (← inferType e))
156+
(value := e)
157+
(hints := ReducibilityHints.abbrev)
158+
))
159+
modifyEnv fun env => addToCompletionBlackList env declName
160+
modifyEnv fun env => addProtected env declName
161+
setReducibleAttribute declName
162+
163+
def mkNoConfusionTypeLinear (indName : Name) : MetaM Unit := do
164+
let declName := mkNoConfusionTypeName indName
165+
let ConstantInfo.inductInfo info ← getConstInfo indName | unreachable!
166+
let casesOnName := mkCasesOnName indName
167+
let casesOnInfo ← getConstVal casesOnName
168+
let v::us := casesOnInfo.levelParams.map mkLevelParam | panic! "unexpected universe levels on `casesOn`"
169+
let e := mkConst casesOnName (v.succ::us)
170+
let t ← inferType e
171+
let e ← forallBoundedTelescope t info.numParams fun xs t => do
172+
let e := mkAppN e xs
173+
let PType := mkSort v
174+
withLocalDeclD `P PType fun P => do
175+
let motive ← forallTelescope (← whnfD t).bindingDomain! fun ys _ =>
176+
mkLambdaFVars ys PType
177+
let t ← instantiateForall t #[motive]
178+
let e := mkApp e motive
179+
forallBoundedTelescope t info.numIndices fun ys t => do
180+
let e := mkAppN e ys
181+
let xType := mkAppN (mkConst indName us) (xs ++ ys)
182+
withLocalDeclD `x1 xType fun x1 => do
183+
withLocalDeclD `x2 xType fun x2 => do
184+
let t ← instantiateForall t #[x1]
185+
let e := mkApp e x1
186+
forallBoundedTelescope t info.numCtors fun alts _ => do
187+
let alts' ← alts.mapIdxM fun i alt => do
188+
let altType ← inferType alt
189+
forallTelescope altType fun zs1 _ => do
190+
let alt := mkConst (mkWithCtorName indName) (v :: us)
191+
let alt := mkAppN alt xs
192+
let alt := mkApp alt PType
193+
let alt := mkApp alt (mkRawNatLit i)
194+
let k ← forallTelescopeReducing (← inferType alt).bindingDomain! fun zs2 _ => do
195+
let eqs ← (Array.zip zs1 zs2[1:]).filterMapM fun (z1,z2) => do
196+
if (← isProof z1) then
197+
return none
198+
else
199+
return some (← mkEqHEq z1 z2)
200+
let k ← mkArrowN eqs P
201+
let k ← mkArrow k P
202+
mkLambdaFVars zs2 k
203+
let alt := mkApp alt k
204+
let alt := mkApp alt P
205+
let alt := mkAppN alt ys
206+
let alt := mkApp alt x2
207+
mkLambdaFVars zs1 alt
208+
let e := mkAppN e alts'
209+
let e ← mkLambdaFVars #[x1, x2] e
210+
let e ← mkLambdaFVars #[P] e
211+
let e ← mkLambdaFVars ys e
212+
let e ← mkLambdaFVars xs e
213+
pure e
214+
215+
addDecl (.defnDecl (← mkDefinitionValInferrringUnsafe
216+
(name := declName)
217+
(levelParams := casesOnInfo.levelParams)
218+
(type := (← inferType e))
219+
(value := e)
220+
(hints := ReducibilityHints.abbrev)
221+
))
222+
modifyEnv fun env => addToCompletionBlackList env declName
223+
modifyEnv fun env => addProtected env declName
224+
setReducibleAttribute declName
225+
226+
end Lean.NoConfusionLinear

tests/lean/run/815.lean

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
1-
def is_smooth {α β} (f : α → β) : Prop := sorry
1+
axiom testSorry : α
2+
3+
def is_smooth {α β} (f : α → β) : Prop := testSorry
24

35
class IsSmooth {α β} (f : α → β) : Prop where
46
(proof : is_smooth f)
57

6-
instance identity : IsSmooth fun a : α => a := sorry
7-
instance const (b : β) : IsSmooth fun a : α => b := sorry
8-
instance swap (f : α → β → γ) [∀ a, IsSmooth (f a)] : IsSmooth (λ b a => f a b) := sorry
9-
instance parm (f : α → β → γ) [IsSmooth f] (b : β) : IsSmooth (λ a => f a b) := sorry
10-
instance comp (f : β → γ) (g : α → β) [IsSmooth f] [IsSmooth g] : IsSmooth (fun a => f (g a)) := sorry
11-
instance diag (f : β → δ → γ) (g : α → β) (h : α → δ) [IsSmooth f] [∀ b, IsSmooth (f b)] [IsSmooth g] [IsSmooth h] : IsSmooth (λ a => f (g a) (h a)) := sorry
8+
instance identity : IsSmooth fun a : α => a := testSorry
9+
instance const (b : β) : IsSmooth fun a : α => b := testSorry
10+
instance swap (f : α → β → γ) [∀ a, IsSmooth (f a)] : IsSmooth (λ b a => f a b) := testSorry
11+
instance parm (f : α → β → γ) [IsSmooth f] (b : β) : IsSmooth (λ a => f a b) := testSorry
12+
instance comp (f : β → γ) (g : α → β) [IsSmooth f] [IsSmooth g] : IsSmooth (fun a => f (g a)) := testSorry
13+
instance diag (f : β → δ → γ) (g : α → β) (h : α → δ) [IsSmooth f] [∀ b, IsSmooth (f b)] [IsSmooth g] [IsSmooth h] : IsSmooth (λ a => f (g a) (h a)) := testSorry
1214

1315
example (f : β → δ → γ) [IsSmooth f] (g : α → β) [IsSmooth g] (d : δ) : IsSmooth (λ a => f (g a) d) := by infer_instance
1416
example (f : β → δ → γ) [IsSmooth f] (g : α → β) [IsSmooth g] : IsSmooth (λ a d => f (g a) d) := by infer_instance

tests/lean/run/grind_pre.lean

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,8 @@ bs : List Point
108108
b₂ : Nat
109109
b₃ : Int
110110
head_eq : a₁ = b₁
111-
x_eq : a₂ = b₂
112-
y_eq : a₃ = b₃
111+
h_1 : a₂ = b₂
112+
h_2 : a₃ = b₃
113113
tail_eq_1 : as = bs
114114
⊢ False
115115
[grind] Goal diagnostics
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
/-!
2+
This test tests and also explains the noConfusionType construction.
3+
4+
It contains copies of the definitions of the constructions, for manual experimentation with
5+
the code, and uses `#guard_msgs` and `rfl` to compare them to the generated ones.
6+
7+
This also serves as documentation to the `NoConfusionLinear` module, so do not delete or remove
8+
from this file without updating that module's docstring.
9+
-/
10+
11+
inductive Vec.{u} (α : Type) : Nat → Type u where
12+
| nil : Vec α 0
13+
| cons {n} : α → Vec α n → Vec α (n + 1)
14+
15+
@[reducible] protected def Vec.noConfusionType.withCtorType'.{u_1, u} :
16+
TypeType u_1 → Nat → Type (max (u + 1) u_1) := fun α P ctorIdx =>
17+
bif Nat.blt ctorIdx 1
18+
then PUnit.{u + 2} → P
19+
else PUnit.{u + 2} → {n : Nat} → α → Vec.{u} α n → P
20+
21+
/--
22+
info: @[reducible] protected def Vec.noConfusionType.withCtorType.{u_1, u} : Type → Type u_1 → Nat → Type (max (u + 1) u_1) :=
23+
fun α P ctorIdx => bif ctorIdx.blt 1 then PUnit → P else PUnit → {n : Nat} → α → Vec α n → P
24+
-/
25+
#guard_msgs in
26+
#print Vec.noConfusionType.withCtorType
27+
28+
example : @Vec.noConfusionType.withCtorType.{u_1,u} = @Vec.noConfusionType.withCtorType'.{u_1,u} := rfl
29+
30+
@[reducible] protected noncomputable def Vec.noConfusionType.withCtor'.{u_1, u} : (α : Type) →
31+
(P : Type u_1) → (ctorIdx : Nat) → Vec.noConfusionType.withCtorType' α P ctorIdx → P → (a : Nat) → Vec.{u} α a → P :=
32+
fun _α _P ctorIdx k k' _a x =>
33+
Vec.casesOn x
34+
(if h : ctorIdx = 0 then Eq.ndrec k h PUnit.unit else k')
35+
(fun a a_1 => if h : ctorIdx = 1 then Eq.ndrec k h PUnit.unit a a_1 else k')
36+
37+
/--
38+
info: @[reducible] protected def Vec.noConfusionType.withCtor.{u_1, u} : (α : Type) →
39+
(P : Type u_1) → (ctorIdx : Nat) → Vec.noConfusionType.withCtorType α P ctorIdx → P → (a : Nat) → Vec α a → P :=
40+
fun α P ctorIdx k k' a x =>
41+
Vec.casesOn x (if h : ctorIdx = 0 then (h ▸ k) PUnit.unit else k') fun {n} a a_1 =>
42+
if h : ctorIdx = 1 then (h ▸ k) PUnit.unit a a_1 else k'
43+
-/
44+
#guard_msgs in
45+
#print Vec.noConfusionType.withCtor
46+
47+
example : @Vec.noConfusionType.withCtor.{u_1,u} = @Vec.noConfusionType.withCtor'.{u_1,u} := rfl
48+
49+
@[reducible] protected def Vec.noConfusionType'.{u_1, u} : {α : Type} →
50+
{a : Nat} → Sort u_1 → Vec.{u} α a → Vec α a → Sort u_1 :=
51+
fun {α} {a} P x1 x2 =>
52+
Vec.casesOn x1
53+
(Vec.noConfusionType.withCtor' α (Sort u_1) 0 (fun _x => P → P) P a x2)
54+
(fun {n} a_1 a_2 => Vec.noConfusionType.withCtor' α (Sort u_1) 1 (fun _x {n_1} a a_3 => (n = n_1 → a_1 = a → HEq a_2 a_3 → P) → P) P a x2)
55+
56+
/--
57+
info: @[reducible] protected def Vec.noConfusionType.{u_1, u} : {α : Type} →
58+
{a : Nat} → Sort u_1 → Vec α a → Vec α a → Sort u_1 :=
59+
fun {α} {a} P x1 x2 =>
60+
Vec.casesOn x1 (Vec.noConfusionType.withCtor α (Sort u_1) 0 (fun x => P → P) P a x2) fun {n} a_1 a_2 =>
61+
Vec.noConfusionType.withCtor α (Sort u_1) 1 (fun x {n_1} a a_3 => (n = n_1 → a_1 = a → HEq a_2 a_3 → P) → P) P a x2
62+
-/
63+
#guard_msgs in
64+
#print Vec.noConfusionType
65+
66+
example : @Vec.noConfusionType.{u_1,u} = @Vec.noConfusionType'.{u_1,u} := rfl
67+
68+
/-
69+
run_meta do
70+
let mut i := 0
71+
for (n, _c) in (← getEnv).constants do
72+
if let .str indName "noConfusion" := n then
73+
let ConstantInfo.inductInfo _ ← getConstInfo indName | continue
74+
logInfo m!"Looking at {.ofConstName indName}"
75+
mkToCtorIdx' indName
76+
mkWithCtorType indName
77+
mkWithCtor indName
78+
mkNoConfusionType' indName
79+
i := i + 1
80+
if i > 10 then
81+
return
82+
-/
83+
84+
-- inductive Enum.{u} : Type u where | a | b
85+
-- set_option pp.universes true in
86+
-- #print noConfusionTypeEnum

0 commit comments

Comments
 (0)