@@ -403,8 +403,7 @@ mutual
403403 trace[Compiler.specialize.candidate] "{e.toExpr}, {specEntry}"
404404 let paramsInfo := specEntry.paramsInfo
405405 let (argMask, params, decls) ← Collector.collect paramsInfo args
406- let targetArgs := argMask.filterMap id
407- let keyBody := .const declName us targetArgs
406+ let keyBody := .const declName us (argMask.filterMap id)
408407 let (key, levelParamsNew) ← mkKey params decls keyBody
409408 trace[Compiler.specialize.candidate] "key: {key}"
410409 assert! !key.hasLooseBVars
@@ -421,6 +420,11 @@ mutual
421420 | .type .. | .erased => return false
422421 | .fvar fvar => do
423422 if let some param ← findParam? fvar then
423+ /-
424+ For now we only allow recursive specialization on non class parameters, reason:
425+ We can encounter situations where we repeatedly re-abstract over type classes
426+ recursively and would end up in a loop because of that.
427+ -/
424428 return (param.type matches .forallE ..) && !(← isArrowClass? param.type).isSome
425429 else
426430 return false
@@ -500,7 +504,7 @@ def updateLocalSpecParamInfo : SpecializeM Unit := do
500504 for entry in infos do
501505 if let some mask := (← get).parentMasks[entry.declName]? then
502506 let maskInfo info :=
503- mask.zipWith info (f := fun b i => if i.causesSpecialization && !b then .other else i)
507+ mask.zipWith info (f := fun b i => if !b && i.causesSpecialization then .other else i)
504508 let entry := { entry with paramsInfo := maskInfo entry.paramsInfo }
505509 modify fun s => {
506510 s with
@@ -509,24 +513,22 @@ def updateLocalSpecParamInfo : SpecializeM Unit := do
509513
510514 trace[Compiler.specialize.step] m!"Info for next round: {(← get).localSpecParamInfo.toList}"
511515
512- def endOfLoop : SpecializeM Unit := do
513- for (declName, paramsInfo) in (← get).localSpecParamInfo do
514- if paramsInfo.any SpecParamInfo.causesSpecialization then
515- trace[Compiler.specialize.info] "{declName} {paramsInfo}"
516- modifyEnv fun env => specExtension.addEntry env {
517- declName,
518- paramsInfo,
519- alreadySpecialized := true
520- }
521-
522516partial def loop (round : Nat := 0 ) : SpecializeM Unit := do
523517 let targets ← modifyGet (fun s => (s.workingDecls, { s with workingDecls := #[] }))
518+ let limit := (← getConfig).maxRecSpecialize
524519 if targets.isEmpty then
525520 trace[Compiler.specialize.step] m!"Termination after {round} rounds"
526- endOfLoop
521+ for (declName, paramsInfo) in (← get).localSpecParamInfo do
522+ if paramsInfo.any SpecParamInfo.causesSpecialization then
523+ trace[Compiler.specialize.info] "{declName} {paramsInfo}"
524+ modifyEnv fun env => specExtension.addEntry env {
525+ declName,
526+ paramsInfo,
527+ alreadySpecialized := true
528+ }
527529 return ()
528- else if round > (← getConfig).maxRecSpecialize then
529- throwError "Lost in specialization"
530+ else if round > limit then
531+ throwError m! "Exceeded recursive specialization limit ({limit}), consider increasing it with `set_option compiler.maxRecSpecialize {limit}` "
530532
531533 trace[Compiler.specialize.step] m!"Round: {round}"
532534 for decl in targets do
0 commit comments