@@ -31,26 +31,38 @@ def checkFVarUse (fvarId : FVarId) : ToMonoM Unit := do
3131 if let some declName := (← get).noncomputableVars.get? fvarId then
3232 throwError f!"failed to compile definition, consider marking it as 'noncomputable' because it depends on '{declName}', which is 'noncomputable'"
3333
34- def argToMono (arg : Arg) : ToMonoM Arg := do
34+ def checkFVarUseDeferred (resultFVar fvarId : FVarId) : ToMonoM Unit := do
35+ if let some declName := (← get).noncomputableVars.get? fvarId then
36+ modify fun s => { s with noncomputableVars := s.noncomputableVars.insert resultFVar declName }
37+
38+ @[inline]
39+ def argToMonoBase (check : FVarId → ToMonoM Unit) (arg : Arg) : ToMonoM Arg := do
3540 match arg with
3641 | .erased | .type .. => return .erased
3742 | .fvar fvarId =>
3843 if (← get).typeParams.contains fvarId then
3944 return .erased
4045 else
41- checkFVarUse fvarId
46+ check fvarId
4247 return arg
4348
44- def ctorAppToMono (ctorInfo : ConstructorVal) (args : Array Arg) : ToMonoM LetValue := do
45- let argsNew : Array Arg ← args[:ctorInfo.numParams].toArray.mapM fun arg => do
49+ def argToMono (arg : Arg) : ToMonoM Arg := argToMonoBase checkFVarUse arg
50+
51+ def argToMonoDeferredCheck (resultFVar : FVarId) (arg : Arg) : ToMonoM Arg :=
52+ argToMonoBase (checkFVarUseDeferred resultFVar) arg
53+
54+ def ctorAppToMono (resultFVar : FVarId) (ctorInfo : ConstructorVal) (args : Array Arg)
55+ : ToMonoM LetValue := do
56+ let argsNewParams : Array Arg ← args[:ctorInfo.numParams].toArray.mapM fun arg => do
4657 -- We only preserve constructor parameters that are types
4758 match arg with
4859 | .type type => return .type (← toMonoType type)
4960 | .fvar .. | .erased => return .erased
50- let argsNew := argsNew ++ (← args[ctorInfo.numParams:].toArray.mapM argToMono)
61+ let argsNewFields ← args[ctorInfo.numParams:].toArray.mapM (argToMonoDeferredCheck resultFVar)
62+ let argsNew := argsNewParams ++ argsNewFields
5163 return .const ctorInfo.name [] argsNew
5264
53- partial def LetValue.toMono (e : LetValue) (fvarId : FVarId) : ToMonoM LetValue := do
65+ partial def LetValue.toMono (e : LetValue) (resultFVar : FVarId) : ToMonoM LetValue := do
5466 match e with
5567 | .erased | .lit .. => return e
5668 | .const declName _ args =>
@@ -63,26 +75,25 @@ partial def LetValue.toMono (e : LetValue) (fvarId : FVarId) : ToMonoM LetValue
6375 -- and Bool have the same runtime representation.
6476 return args[1 ]!.toLetValue
6577 else if let some e' ← isTrivialConstructorApp? declName args then
66- e'.toMono fvarId
78+ e'.toMono resultFVar
6779 else if let some (.ctorInfo ctorInfo) := (← getEnv).find? declName then
68- ctorAppToMono ctorInfo args
80+ ctorAppToMono resultFVar ctorInfo args
6981 else
7082 let env ← getEnv
7183 if isNoncomputable env declName && !(isExtern env declName) then
72- modify fun s => { s with noncomputableVars := s.noncomputableVars.insert fvarId declName }
73- return .const declName [] (← args.mapM argToMono )
84+ modify fun s => { s with noncomputableVars := s.noncomputableVars.insert resultFVar declName }
85+ return .const declName [] (← args.mapM (argToMonoDeferredCheck resultFVar) )
7486 | .fvar fvarId args =>
7587 if (← get).typeParams.contains fvarId then
7688 return .erased
7789 else
78- checkFVarUse fvarId
79- return .fvar fvarId (← args.mapM argToMono )
90+ checkFVarUseDeferred resultFVar fvarId
91+ return .fvar fvarId (← args.mapM (argToMonoDeferredCheck resultFVar) )
8092 | .proj structName fieldIdx baseFVar =>
8193 if (← get).typeParams.contains baseFVar then
8294 return .erased
8395 else
84- if let some declName := (← get).noncomputableVars.get? baseFVar then
85- modify fun s => { s with noncomputableVars := s.noncomputableVars.insert fvarId declName }
96+ checkFVarUseDeferred resultFVar baseFVar
8697 if let some info ← hasTrivialStructure? structName then
8798 if info.fieldIdx == fieldIdx then
8899 return .fvar baseFVar #[]
0 commit comments