@@ -14,6 +14,7 @@ public import Lean.Compiler.LCNF.MonadScope
1414public import Lean.Compiler.LCNF.Closure
1515public import Lean.Compiler.LCNF.FVarUtil
1616import all Lean.Compiler.LCNF.ToExpr
17+ import Std.Data.Iterators
1718
1819public section
1920
@@ -203,19 +204,46 @@ end Collector
203204/--
204205Return `true` if it is worth using arguments `args` for specialization given the parameter specialization information.
205206-/
206- def shouldSpecialize (paramsInfo : Array SpecParamInfo) (args : Array Arg) : SpecializeM Bool := do
207+ def shouldSpecialize (declName : Name) (paramsInfo : Array SpecParamInfo) (args : Array Arg) :
208+ SpecializeM Bool := do
209+ let hoCheck :=
210+ if (← get).localSpecParamInfo.contains declName then
211+ fun arg => do
212+ /-
213+ If we have `f p` where `p` is a param it makes no sense to specialize as we will just
214+ close over `p` again and will have made no progress.
215+
216+ The reason for doing this only for declarations which have `localSpecParamInfo` (i.e. have
217+ already been specialized themselves) is that we *must* always specialize declarations that
218+ are marked with `@[specialize]`. This is because the specializer does not specialize their
219+ bodies; it waits for the bodies to be specialized at the call site. This is, for example,
220+ important in the following situation:
221+ ```
222+ @[specialize]
223+ def test (f : ... -> ...) :=
224+ ...
225+ HashMap.get? inst1 inst2 xs ys
226+ ```
227+ Here the call to `HashMap.get?` will not be specialized unless `test` is specialized. Thus,
228+ even when `f` is just going to be re-abstracted, it makes sense to specialize a call to `test`
229+ that closes over parameters, in order to optimize the `HashMap` invocation.
230+
231+ We thought about lifting this restriction and instead always specializing `@[specialize]`
232+ decls twice, once at their definition site and once at their call site. However, almost all
233+ `@[specialize]` function declarations will indeed get specialized properly. Hence keeping
234+ the first (definition-site) version around is likely a waste of space.
235+ -/
236+ match arg with
237+ | .erased | .type .. => return false
238+ | .fvar fvar => return (← findParam? fvar).isNone
239+ else
240+ fun _ => pure true
207241 for paramInfo in paramsInfo, arg in args do
208242 match paramInfo with
209243 | .other => pure ()
210244 | .fixedNeutral => pure () -- If we want to monomorphize types such as `Array`, we need to change here
211245 | .fixedInst | .user => if (← isGround arg) then return true
212- | .fixedHO =>
213- match arg with
214- | .erased | .type .. => pure ()
215- | .fvar fvar =>
216- -- If we have `f p` where `p` is a param it makes no sense to specialize as we will just
217- -- close over `p` again and will have made no progress.
218- if (← findParam? fvar).isNone then return true
246+ | .fixedHO => if ← hoCheck arg then return true
219247
220248 return false
221249
@@ -341,12 +369,13 @@ mutual
341369 if args.isEmpty then return none
342370 if (← Meta.isInstance declName) then return none
343371 let some paramsInfo ← getSpecParamInfo? declName | return none
344- unless (← shouldSpecialize paramsInfo args) do return none
372+ unless (← shouldSpecialize declName paramsInfo args) do return none
345373 let some decl ← getDecl? declName | return none
346374 let .code _ := decl.value | return none
347375 trace[Compiler.specialize.candidate] " {e.toExpr}, {paramsInfo}"
348376 let (argMask, params, decls) ← Collector.collect paramsInfo args
349- let keyBody := .const declName us (argMask.filterMap id)
377+ let targetArgs := argMask.filterMap id
378+ let keyBody := .const declName us targetArgs
350379 let (key, levelParamsNew) ← mkKey params decls keyBody
351380 trace[Compiler.specialize.candidate] " key: {key}"
352381 assert! !key.hasLooseBVars
@@ -358,18 +387,33 @@ mutual
358387 return some (.const declName usNew argsNew)
359388 else
360389 let specDecl ← mkSpecDecl decl us argMask params decls levelParamsNew
390+ let targetParams : Std.HashSet Arg ←
391+ args.iterM SpecializeM
392+ |>.zip (paramsInfo.iterM _)
393+ |>.foldM (init := {}) fun acc (arg, info) => do
394+ match info with
395+ | .fixedInst | .fixedNeutral | .other => return acc
396+ | .fixedHO | .user =>
397+ match arg with
398+ | .type .. | .erased => return acc
399+ | .fvar fvar =>
400+ if (← findParam? fvar).isSome then
401+ return acc.insert arg
402+ else
403+ return acc
404+ let parentMask := argsNew.map targetParams.contains
361405 trace[Compiler.specialize.step] " new: {specDecl.name}: {← ppDecl specDecl}"
362406 cacheSpec key specDecl.name
363407 specDecl.saveBase
364408 let specDecl ← specDecl.etaExpand
365409 specDecl.saveBase
366410 let specDecl ← specDecl.simp {}
367411 let specDecl ← specDecl.simp { etaPoly := true, inlinePartial := true, implementedBy := true }
412+
368413 modify fun s => {
369414 s with
370415 workingDecls := s.workingDecls.push specDecl,
371- -- TODO: correct mask
372- parentMask := s.parentMask.insert specDecl.name (Array.replicate specDecl.params.size true)
416+ parentMask := s.parentMask.insert specDecl.name parentMask
373417 }
374418 return some (.const specDecl.name usNew argsNew)
375419
406450
407451def specializeDecl (decl : Decl) : SpecializeM (Decl × Bool) := do
408452 trace[Compiler.specialize.step] m!" Working {decl.name}"
409- -- TODO: conside a different heuristic here
410453 if (← decl.isTemplateLike) then
411454 return (decl, false)
412455 else
0 commit comments