Skip to content

Commit 8d8fd07

Browse files
authored
fix: increase precision of new compiler's noncomputable check (#8675)
This PR increases the precision of the new compiler's non computable check, particularly around irrelevant uses of `noncomputable` defs in applications. There are no tests included because they don't pass with the old compiler. They are on the new compiler's branch and they will be merged when it is enabled.
1 parent 4abc443 commit 8d8fd07

File tree

1 file changed

+25
-14
lines changed

1 file changed

+25
-14
lines changed

src/Lean/Compiler/LCNF/ToMono.lean

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -31,26 +31,38 @@ def checkFVarUse (fvarId : FVarId) : ToMonoM Unit := do
3131
if let some declName := (← get).noncomputableVars.get? fvarId then
3232
throwError f!"failed to compile definition, consider marking it as 'noncomputable' because it depends on '{declName}', which is 'noncomputable'"
3333

34-
def argToMono (arg : Arg) : ToMonoM Arg := do
34+
def checkFVarUseDeferred (resultFVar fvarId : FVarId) : ToMonoM Unit := do
35+
if let some declName := (← get).noncomputableVars.get? fvarId then
36+
modify fun s => { s with noncomputableVars := s.noncomputableVars.insert resultFVar declName }
37+
38+
@[inline]
39+
def argToMonoBase (check : FVarId → ToMonoM Unit) (arg : Arg) : ToMonoM Arg := do
3540
match arg with
3641
| .erased | .type .. => return .erased
3742
| .fvar fvarId =>
3843
if (← get).typeParams.contains fvarId then
3944
return .erased
4045
else
41-
checkFVarUse fvarId
46+
check fvarId
4247
return arg
4348

44-
def ctorAppToMono (ctorInfo : ConstructorVal) (args : Array Arg) : ToMonoM LetValue := do
45-
let argsNew : Array Arg ← args[:ctorInfo.numParams].toArray.mapM fun arg => do
49+
def argToMono (arg : Arg) : ToMonoM Arg := argToMonoBase checkFVarUse arg
50+
51+
def argToMonoDeferredCheck (resultFVar : FVarId) (arg : Arg) : ToMonoM Arg :=
52+
argToMonoBase (checkFVarUseDeferred resultFVar) arg
53+
54+
def ctorAppToMono (resultFVar : FVarId) (ctorInfo : ConstructorVal) (args : Array Arg)
55+
: ToMonoM LetValue := do
56+
let argsNewParams : Array Arg ← args[:ctorInfo.numParams].toArray.mapM fun arg => do
4657
-- We only preserve constructor parameters that are types
4758
match arg with
4859
| .type type => return .type (← toMonoType type)
4960
| .fvar .. | .erased => return .erased
50-
let argsNew := argsNew ++ (← args[ctorInfo.numParams:].toArray.mapM argToMono)
61+
let argsNewFields ← args[ctorInfo.numParams:].toArray.mapM (argToMonoDeferredCheck resultFVar)
62+
let argsNew := argsNewParams ++ argsNewFields
5163
return .const ctorInfo.name [] argsNew
5264

53-
partial def LetValue.toMono (e : LetValue) (fvarId : FVarId) : ToMonoM LetValue := do
65+
partial def LetValue.toMono (e : LetValue) (resultFVar : FVarId) : ToMonoM LetValue := do
5466
match e with
5567
| .erased | .lit .. => return e
5668
| .const declName _ args =>
@@ -63,26 +75,25 @@ partial def LetValue.toMono (e : LetValue) (fvarId : FVarId) : ToMonoM LetValue
6375
-- and Bool have the same runtime representation.
6476
return args[1]!.toLetValue
6577
else if let some e' ← isTrivialConstructorApp? declName args then
66-
e'.toMono fvarId
78+
e'.toMono resultFVar
6779
else if let some (.ctorInfo ctorInfo) := (← getEnv).find? declName then
68-
ctorAppToMono ctorInfo args
80+
ctorAppToMono resultFVar ctorInfo args
6981
else
7082
let env ← getEnv
7183
if isNoncomputable env declName && !(isExtern env declName) then
72-
modify fun s => { s with noncomputableVars := s.noncomputableVars.insert fvarId declName }
73-
return .const declName [] (← args.mapM argToMono)
84+
modify fun s => { s with noncomputableVars := s.noncomputableVars.insert resultFVar declName }
85+
return .const declName [] (← args.mapM (argToMonoDeferredCheck resultFVar))
7486
| .fvar fvarId args =>
7587
if (← get).typeParams.contains fvarId then
7688
return .erased
7789
else
78-
checkFVarUse fvarId
79-
return .fvar fvarId (← args.mapM argToMono)
90+
checkFVarUseDeferred resultFVar fvarId
91+
return .fvar fvarId (← args.mapM (argToMonoDeferredCheck resultFVar))
8092
| .proj structName fieldIdx baseFVar =>
8193
if (← get).typeParams.contains baseFVar then
8294
return .erased
8395
else
84-
if let some declName := (← get).noncomputableVars.get? baseFVar then
85-
modify fun s => { s with noncomputableVars := s.noncomputableVars.insert fvarId declName }
96+
checkFVarUseDeferred resultFVar baseFVar
8697
if let some info ← hasTrivialStructure? structName then
8798
if info.fieldIdx == fieldIdx then
8899
return .fvar baseFVar #[]

0 commit comments

Comments
 (0)