Skip to content

Commit fd5b87d

Browse files
committed
feat: first safe version
1 parent 6f353de commit fd5b87d

File tree

1 file changed

+56
-13
lines changed

1 file changed

+56
-13
lines changed

src/Lean/Compiler/LCNF/Specialize.lean

Lines changed: 56 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ public import Lean.Compiler.LCNF.MonadScope
1414
public import Lean.Compiler.LCNF.Closure
1515
public import Lean.Compiler.LCNF.FVarUtil
1616
import all Lean.Compiler.LCNF.ToExpr
17+
import Std.Data.Iterators
1718

1819
public section
1920

@@ -203,19 +204,46 @@ end Collector
203204
/--
204205
Return `true` if it is worth using arguments `args` for specialization given the parameter specialization information.
205206
-/
206-
def shouldSpecialize (paramsInfo : Array SpecParamInfo) (args : Array Arg) : SpecializeM Bool := do
207+
def shouldSpecialize (declName : Name) (paramsInfo : Array SpecParamInfo) (args : Array Arg) :
208+
SpecializeM Bool := do
209+
let hoCheck :=
210+
if (← get).localSpecParamInfo.contains declName then
211+
fun arg => do
212+
/-
213+
If we have `f p` where `p` is a param it makes no sense to specialize as we will just
214+
close over `p` again and will have made no progress.
215+
216+
The reason for doing this only for declarations which have `localSpecParamInfo` (i.e. have
217+
already been specialised themselves) is, that we *must* always specialize declarations that
218+
are marked with `@[specialize]`. This is because the specializer will not specialize their
219+
bodies because it waits for the bodies to be specialized at the call site. This is for example
220+
important in the following situation:
221+
```
222+
@[specialize]
223+
def test (f : ... -> ...) :=
224+
...
225+
HashMap.get? inst1 inst2 xs ys
226+
```
227+
Here the call to `HashMap.get?` will not be specialized unless `test` is specialized. Thus,
228+
even when `f` is just going to be re-abstracted, it makes sense to specialize a call to `test`
229+
that closes over parameters, in order to optimize the `HashMap` invocation.
230+
231+
We thought about lifting this restriction and instead always specializing `@[specialize]`
232+
decls twice, once at their definition site and once at their call site. However, almost all
233+
`@[specialize]` function declarations will indeed get specialized properly. Hence keeping
234+
the first version around is likely a waste of space.
235+
-/
236+
match arg with
237+
| .erased | .type .. => return false
238+
| .fvar fvar => return (← findParam? fvar).isNone
239+
else
240+
fun _ => pure true
207241
for paramInfo in paramsInfo, arg in args do
208242
match paramInfo with
209243
| .other => pure ()
210244
| .fixedNeutral => pure () -- If we want to monomorphize types such as `Array`, we need to change here
211245
| .fixedInst | .user => if (← isGround arg) then return true
212-
| .fixedHO =>
213-
match arg with
214-
| .erased | .type .. => pure ()
215-
| .fvar fvar =>
216-
-- If we have `f p` where `p` is a param it makes no sense to specialize as we will just
217-
-- close over `p` again and will have made no progress.
218-
if (← findParam? fvar).isNone then return true
246+
| .fixedHO => if ← hoCheck arg then return true
219247

220248
return false
221249

@@ -341,12 +369,13 @@ mutual
341369
if args.isEmpty then return none
342370
if (← Meta.isInstance declName) then return none
343371
let some paramsInfo ← getSpecParamInfo? declName | return none
344-
unless (← shouldSpecialize paramsInfo args) do return none
372+
unless (← shouldSpecialize declName paramsInfo args) do return none
345373
let some decl ← getDecl? declName | return none
346374
let .code _ := decl.value | return none
347375
trace[Compiler.specialize.candidate] "{e.toExpr}, {paramsInfo}"
348376
let (argMask, params, decls) ← Collector.collect paramsInfo args
349-
let keyBody := .const declName us (argMask.filterMap id)
377+
let targetArgs := argMask.filterMap id
378+
let keyBody := .const declName us targetArgs
350379
let (key, levelParamsNew) ← mkKey params decls keyBody
351380
trace[Compiler.specialize.candidate] "key: {key}"
352381
assert! !key.hasLooseBVars
@@ -358,18 +387,33 @@ mutual
358387
return some (.const declName usNew argsNew)
359388
else
360389
let specDecl ← mkSpecDecl decl us argMask params decls levelParamsNew
390+
let targetParams : Std.HashSet Arg ←
391+
args.iterM SpecializeM
392+
|>.zip (paramsInfo.iterM _)
393+
|>.foldM (init := {}) fun acc (arg, info) => do
394+
match info with
395+
| .fixedInst | .fixedNeutral | .other => return acc
396+
| .fixedHO | .user =>
397+
match arg with
398+
| .type .. | .erased => return acc
399+
| .fvar fvar =>
400+
if (← findParam? fvar).isSome then
401+
return acc.insert arg
402+
else
403+
return acc
404+
let parentMask := argsNew.map targetParams.contains
361405
trace[Compiler.specialize.step] "new: {specDecl.name}: {← ppDecl specDecl}"
362406
cacheSpec key specDecl.name
363407
specDecl.saveBase
364408
let specDecl ← specDecl.etaExpand
365409
specDecl.saveBase
366410
let specDecl ← specDecl.simp {}
367411
let specDecl ← specDecl.simp { etaPoly := true, inlinePartial := true, implementedBy := true }
412+
368413
modify fun s => {
369414
s with
370415
workingDecls := s.workingDecls.push specDecl,
371-
-- TODO: correct mask
372-
parentMask := s.parentMask.insert specDecl.name (Array.replicate specDecl.params.size true)
416+
parentMask := s.parentMask.insert specDecl.name parentMask
373417
}
374418
return some (.const specDecl.name usNew argsNew)
375419

@@ -406,7 +450,6 @@ end
406450

407451
def specializeDecl (decl : Decl) : SpecializeM (Decl × Bool) := do
408452
trace[Compiler.specialize.step] m!"Working {decl.name}"
409-
-- TODO: conside a different heuristic here
410453
if (← decl.isTemplateLike) then
411454
return (decl, false)
412455
else

0 commit comments

Comments
 (0)