Skip to content

Commit 3b40682

Browse files
authored
perf: handle per-constructor noConfusion in toLCNF (#11566)
This PR lets the compiler treat per-constructor `noConfusion` like the general one, and moves some more logic closer to no confusion generation.
1 parent 06037ad commit 3b40682

File tree

5 files changed

+99
-71
lines changed

5 files changed

+99
-71
lines changed

src/Lean/AuxRecursor.lean

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,31 @@ public def isSparseCasesOn (env : Environment) (declName : Name) : Bool :=
6262
public def isCasesOnLike (env : Environment) (declName : Name) : Bool :=
6363
isCasesOnRecursor env declName || isSparseCasesOn env declName
6464

65-
builtin_initialize noConfusionExt : TagDeclarationExtension ← mkTagDeclarationExtension
65+
/--
66+
Shape information for no confusion lemmas.
67+
The `arity` does not include the final major argument (which is not there when the constructors differ)
68+
The regular no confusion lemma marks the lhs and rhs arguments for the compiler to look at and
69+
find the number of fields.
70+
The per-constructor no confusion lemmas know the number of (non-prop) fields statically.
71+
-/
72+
inductive NoConfusionInfo where
73+
| regular (arity : Nat) (lhs : Nat) (rhs : Nat)
74+
| perCtor (arity : Nat) (fields : Nat)
75+
deriving Inhabited
76+
77+
def NoConfusionInfo.arity : NoConfusionInfo → Nat
78+
| .regular arity _ _ => arity
79+
| .perCtor arity _ => arity
6680

67-
def markNoConfusion (env : Environment) (n : Name) : Environment :=
68-
noConfusionExt.tag env n
81+
builtin_initialize noConfusionExt : MapDeclarationExtension NoConfusionInfo ← mkMapDeclarationExtension (asyncMode := .mainOnly)
82+
83+
def markNoConfusion (env : Environment) (n : Name) (info : NoConfusionInfo) : Environment :=
84+
noConfusionExt.insert env n info
6985

7086
def isNoConfusion (env : Environment) (n : Name) : Bool :=
71-
noConfusionExt.isTagged env n
87+
noConfusionExt.contains env n
88+
89+
def getNoConfusionInfo (env : Environment) (n : Name) : NoConfusionInfo :=
90+
(noConfusionExt.find? env n).get!
7291

7392
end Lean

src/Lean/Compiler/LCNF/ToLCNF.lean

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -664,32 +664,38 @@ where
664664

665665
visitNoConfusion (e : Expr) : M Arg := do
666666
let .const declName _ := e.getAppFn | unreachable!
667+
let info := getNoConfusionInfo (← getEnv) declName
667668
let typeName := declName.getPrefix
668-
let .inductInfo inductVal ← getConstInfo typeName | unreachable!
669-
let arity := inductVal.numParams + 1 /- motive -/ + 3*(inductVal.numIndices + 1) /- lhs/rhs and equalities -/
670-
etaIfUnderApplied e arity do
669+
etaIfUnderApplied e info.arity do
671670
let args := e.getAppArgs
672-
let lhs ← liftMetaM do Meta.whnf args[inductVal.numParams + 1 + inductVal.numIndices]!
673-
let rhs ← liftMetaM do Meta.whnf args[inductVal.numParams + 1 + inductVal.numIndices + 1 + inductVal.numIndices]!
674-
let lhs ← liftMetaM lhs.toCtorIfLit
675-
let rhs ← liftMetaM rhs.toCtorIfLit
676-
match (← liftMetaM <| Meta.isConstructorApp? lhs), (← liftMetaM <| Meta.isConstructorApp? rhs) with
677-
| some lhsCtorVal, some rhsCtorVal =>
678-
if lhsCtorVal.name == rhsCtorVal.name then
679-
etaIfUnderApplied e (arity+1) do
680-
let major := args[arity]!
671+
let visitMajor (numNonPropFields : Nat) := do
672+
etaIfUnderApplied e (info.arity+1) do
673+
let major := args[info.arity]!
674+
let major ← expandNoConfusionMajor major numNonPropFields
675+
let major := mkAppN major args[(info.arity+1)...*]
676+
visit major
677+
678+
match info with
679+
| .regular _ lhsPos rhsPos =>
680+
let lhs ← liftMetaM do Meta.whnf args[lhsPos]!
681+
let rhs ← liftMetaM do Meta.whnf args[rhsPos]!
682+
let lhs ← liftMetaM lhs.toCtorIfLit
683+
let rhs ← liftMetaM rhs.toCtorIfLit
684+
match (← liftMetaM <| Meta.isConstructorApp? lhs), (← liftMetaM <| Meta.isConstructorApp? rhs) with
685+
| some lhsCtorVal, some rhsCtorVal =>
686+
if lhsCtorVal.name == rhsCtorVal.name then
681687
let numNonPropFields ← liftMetaM <| Meta.forallTelescope lhsCtorVal.type fun params _ =>
682688
params[lhsCtorVal.numParams...*].foldlM (init := 0) fun n param => do
683689
let type ← param.fvarId!.getType
684690
return if !(← Meta.isProp type) then n + 1 else n
685-
let major ← expandNoConfusionMajor major numNonPropFields
686-
let major := mkAppN major args[(arity+1)...*]
687-
visit major
688-
else
689-
let type ← toLCNFType (← liftMetaM <| Meta.inferType e)
690-
mkUnreachable type
691-
| _, _ =>
692-
throwError "code generator failed, unsupported occurrence of `{.ofConstName declName}`"
691+
visitMajor numNonPropFields
692+
else
693+
let type ← toLCNFType (← liftMetaM <| Meta.inferType e)
694+
mkUnreachable type
695+
| _, _ =>
696+
throwError "code generator failed, unsupported occurrence of `{.ofConstName declName}`"
697+
| .perCtor _ numNonPropFields =>
698+
visitMajor numNonPropFields
693699

694700
expandNoConfusionMajor (major : Expr) (numFields : Nat) : M Expr := do
695701
match numFields with

src/Lean/Meta/Constructions/NoConfusion.lean

Lines changed: 48 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,10 @@ def mkNoConfusionCoreImp (indName : Name) : MetaM Unit := do
257257
(value := e)
258258
(hints := ReducibilityHints.abbrev)))
259259
setReducibleAttribute declName
260-
modifyEnv fun env => markNoConfusion env declName
260+
let arity := info.numParams + 1 + 3 * (info.numIndices + 1)
261+
let lhsPos := info.numParams + 1 + info.numIndices
262+
let rhsPos := info.numParams + 1 + info.numIndices + 1 + info.numIndices
263+
modifyEnv fun env => markNoConfusion env declName (.regular arity lhsPos rhsPos)
261264
modifyEnv fun env => addProtected env declName
262265

263266
/--
@@ -295,48 +298,47 @@ def mkNoConfusionCtors (declName : Name) : MetaM Unit := do
295298
for ctor in indVal.ctors do
296299
let ctorInfo ← getConstInfoCtor ctor
297300
if ctorInfo.numFields > 0 then
298-
let e ←
299-
forallBoundedTelescope ctorInfo.type ctorInfo.numParams fun xs t => do
300-
withLocalDeclD `P (.sort v) fun P =>
301-
forallBoundedTelescope t ctorInfo.numFields fun fields1 _ => do
302-
forallBoundedTelescope t ctorInfo.numFields fun fields2 _ => do
303-
withPrimedNames fields2 do
304-
withImplicitBinderInfos (xs ++ #[P] ++ fields1 ++ fields2) do
305-
let ctor1 := mkAppN (mkConst ctor us) (xs ++ fields1)
306-
let ctor2 := mkAppN (mkConst ctor us) (xs ++ fields2)
307-
let is1 := (← whnf (← inferType ctor1)).getAppArgsN indVal.numIndices
308-
let is2 := (← whnf (← inferType ctor2)).getAppArgsN indVal.numIndices
309-
withNeededEqTelescope (is1.push ctor1) (is2.push ctor2) fun eqvs eqs => do
310-
-- When the kernel checks this definition, it will perform the potentially expensive
311-
-- computation that `noConfusionType h` is equal to `$kType → P`
312-
let kType ← mkNoConfusionCtorArg ctor P
313-
let kType := kType.beta (xs ++ fields1 ++ fields2)
314-
withLocalDeclD `k kType fun k => do
315-
let mut e := mkConst noConfusionName (v :: us)
316-
e := mkAppN e (xs ++ #[P] ++ is1 ++ #[ctor1] ++ is2 ++ #[ctor2])
317-
-- eqs may have more Eq rather than HEq than expected by `noConfusion`
318-
for eq in eqs do
319-
let needsHEq := (← whnfForall (← inferType e)).bindingDomain!.isHEq
320-
if needsHEq && (← inferType eq).isEq then
321-
e := mkApp e (← mkHEqOfEq eq)
322-
else
323-
e := mkApp e eq
324-
e := mkApp e k
325-
e ← mkExpectedTypeHint e P
326-
mkLambdaFVars (xs ++ #[P] ++ fields1 ++ fields2 ++ eqvs ++ #[k]) e
327-
let name := ctor.str "noConfusion"
328-
addDecl (.defnDecl (← mkDefinitionValInferringUnsafe
329-
(name := name)
330-
(levelParams := recInfo.levelParams)
331-
(type := (← inferType e))
332-
(value := e)
333-
(hints := ReducibilityHints.abbrev)
334-
))
335-
setReducibleAttribute name
336-
-- The compiler has special support for `noConfusion`. So lets mark this as
337-
-- macroInline to not generate code for all these extra definitions, and instead
338-
-- let the compiler unfold this to then put the custom code there
339-
setInlineAttribute name (kind := .macroInline)
301+
forallBoundedTelescope ctorInfo.type ctorInfo.numParams fun xs t => do
302+
withLocalDeclD `P (.sort v) fun P =>
303+
forallBoundedTelescope t ctorInfo.numFields fun fields1 _ => do
304+
forallBoundedTelescope t ctorInfo.numFields fun fields2 _ => do
305+
withPrimedNames fields2 do
306+
withImplicitBinderInfos (xs ++ #[P] ++ fields1 ++ fields2) do
307+
let ctor1 := mkAppN (mkConst ctor us) (xs ++ fields1)
308+
let ctor2 := mkAppN (mkConst ctor us) (xs ++ fields2)
309+
let is1 := (← whnf (← inferType ctor1)).getAppArgsN indVal.numIndices
310+
let is2 := (← whnf (← inferType ctor2)).getAppArgsN indVal.numIndices
311+
withNeededEqTelescope (is1.push ctor1) (is2.push ctor2) fun eqvs eqs => do
312+
-- When the kernel checks this definition, it will perform the potentially expensive
313+
-- computation that `noConfusionType h` is equal to `$kType → P`
314+
let kType ← mkNoConfusionCtorArg ctor P
315+
let kType := kType.beta (xs ++ fields1 ++ fields2)
316+
withLocalDeclD `k kType fun k => do
317+
let mut e := mkConst noConfusionName (v :: us)
318+
e := mkAppN e (xs ++ #[P] ++ is1 ++ #[ctor1] ++ is2 ++ #[ctor2])
319+
-- eqs may have more Eq rather than HEq than expected by `noConfusion`
320+
for eq in eqs do
321+
let needsHEq := (← whnfForall (← inferType e)).bindingDomain!.isHEq
322+
if needsHEq && (← inferType eq).isEq then
323+
e := mkApp e (← mkHEqOfEq eq)
324+
else
325+
e := mkApp e eq
326+
e := mkApp e k
327+
e ← mkExpectedTypeHint e P
328+
e ← mkLambdaFVars (xs ++ #[P] ++ fields1 ++ fields2 ++ eqvs ++ #[k]) e
329+
330+
let name := ctor.str "noConfusion"
331+
addDecl (.defnDecl (← mkDefinitionValInferringUnsafe
332+
(name := name)
333+
(levelParams := recInfo.levelParams)
334+
(type := (← inferType e))
335+
(value := e)
336+
(hints := ReducibilityHints.abbrev)
337+
))
338+
setReducibleAttribute name
339+
let arity := ctorInfo.numParams + 1 + 2 * ctorInfo.numFields + indVal.numIndices + 1
340+
let fields := kType.getNumHeadForalls
341+
modifyEnv fun env => markNoConfusion env name (.perCtor arity fields)
340342

341343

342344
def mkNoConfusionCore (declName : Name) : MetaM Unit := do
@@ -375,7 +377,7 @@ where
375377
let ctorIdx := mkConst (mkCtorIdxName enumName) us
376378
mkLambdaFVars #[P, x, y] (← mkAppM ``noConfusionTypeEnum #[ctorIdx, P, x, y])
377379
let declName := Name.mkStr enumName "noConfusionType"
378-
addAndCompile <| Declaration.defnDecl {
380+
addDecl <| Declaration.defnDecl {
379381
name := declName
380382
levelParams := v :: info.levelParams
381383
type := declType
@@ -404,7 +406,7 @@ where
404406
else
405407
mkAppOptM ``noConfusionEnum #[none, none, none, ctorIdx, P, x, y, h]
406408
let declName := Name.mkStr enumName "noConfusion"
407-
addAndCompile <| Declaration.defnDecl {
409+
addDecl <| Declaration.defnDecl {
408410
name := declName
409411
levelParams := v :: info.levelParams
410412
type := declType
@@ -413,7 +415,7 @@ where
413415
hints := ReducibilityHints.abbrev
414416
}
415417
setReducibleAttribute declName
416-
modifyEnv fun env => markNoConfusion env declName
418+
modifyEnv fun env => markNoConfusion env declName (.regular 4 1 2)
417419

418420
public def mkNoConfusion (declName : Name) : MetaM Unit := do
419421
withTraceNode `Meta.mkNoConfusion (fun _ => return m!"{declName}") do

stage0/src/stdlib_flags.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#include "util/options.h"
22

3+
// please update me
4+
35
namespace lean {
46
options get_default_options() {
57
options opts;
Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,2 @@
1-
librarySearch.lean:270:0-270:7: warning: declaration uses 'sorry'
21
librarySearch.lean:375:0-375:7: warning: declaration uses 'sorry'
32
librarySearch.lean:385:0-385:7: warning: declaration uses 'sorry'

0 commit comments

Comments
 (0)