@@ -168,8 +168,8 @@ The keys never contain free variables or loose bound variables.
168168
169169/--
170170Given the specialization mask `paramsInfo` and the arguments `args`,
171- collect their dependencies, and return an array `mask` of size `paramsInfo .size` s.t.
172- - `mask[i] = some args[i]` if `paramsInfo[i] != .other`
171+ collect their dependencies, and return an array `mask` of size `args .size` s.t.
172+ - `mask[i] = some args[i]` if `paramsInfo[i]? != some .other`
173173- `mask[i] = none`, otherwise.
174174 That is, `mask` contains only the arguments that are contributing to the code specialization.
175175We use this information to compute a "key" to uniquely identify the code specialization, and
@@ -185,7 +185,9 @@ def collect (paramsInfo : Array SpecParamInfo) (args : Array Arg) : SpecializeM
185185 !ctx.ground.contains fvarId
186186 Closure.run (inScope := ctx.scope.contains) (abstract := abstract) do
187187 let mut argMask := #[]
188- for paramInfo in paramsInfo, arg in args do
188+ for i in *...args.size do
189+ let paramInfo := paramsInfo[i]?.getD .fixedHO
190+ let arg := args[i]!
189191 match paramInfo with
190192 | .other =>
191193 argMask := argMask.push none
@@ -200,7 +202,9 @@ end Collector
200202Return `true` if it is worth using arguments `args` for specialization given the parameter specialization information.
201203-/
202204def shouldSpecialize (paramsInfo : Array SpecParamInfo) (args : Array Arg) : SpecializeM Bool := do
203- for paramInfo in paramsInfo, arg in args do
205+ for i in *...args.size do
206+ let arg := args[i]!
207+ let paramInfo := paramsInfo[i]?.getD .fixedHO -- .fixedHO might be too aggressive
204208 match paramInfo with
205209 | .other => pure ()
206210 | .fixedNeutral => pure () -- If we want to monomorphize types such as `Array`, we need to change here
@@ -267,21 +271,54 @@ where
267271 let .code code := decl.value | panic! " can only specialize decls with code"
268272 let mut params ← params.mapM internalizeParam
269273 let decls ← decls.mapM internalizeCodeDecl
270- for param in decl.params, arg in argMask do
274+ let mut bodyType := decl.type.instantiateLevelParamsNoCache decl.levelParams us
275+ for arg in argMask, param in decl.params do
276+ let .forallE _ d b _ := bodyType.headBeta
277+ | panic! " has param of type {param.type}, but bodyType {bodyType} was not a forall"
271278 if let some arg := arg then
272279 let arg ← normArg arg
273280 modify fun s => s.insert param.fvarId arg
281+ bodyType := b.instantiate1 arg.toExpr
274282 else
275283 -- Keep the parameter
276- let param := { param with type := param.type.instantiateLevelParamsNoCache decl.levelParams us }
277- params := params.push (← internalizeParam param)
278- for param in decl.params[argMask.size...*] do
279- let param := { param with type := param.type.instantiateLevelParamsNoCache decl.levelParams us }
280- params := params.push (← internalizeParam param)
284+ let param ← internalizeParam { param with type := d }
285+ params := params.push param
286+ bodyType := b.instantiate1 (.fvar param.fvarId)
287+ let extraParams := decl.params[argMask.size...*] -- non-empty if undersaturated app
288+ let extraMask := argMask[decl.params.size...*] -- non-empty if oversaturated app
289+ -- Add extraneous parameters to decl
290+ for param in extraParams do
291+ let .forallE _ d b _ := bodyType.headBeta
292+ | panic! " has param of type {param.type}, but bodyType {bodyType} was not a forall"
293+ -- Keep the parameter
294+ let param ← internalizeParam { param with type := d }
295+ params := params.push param
296+ bodyType := b.instantiate1 (.fvar param.fvarId)
281297 let code := code.instantiateValueLevelParams decl.levelParams us
282298 let code ← internalizeCode code
283299 let code := attachCodeDecls decls code
284- let type ← code.inferType
300+ -- Eta-expand to accomodate extraneous args (cf. `etaExpandCore`)
301+ let code ←
302+ if extraMask.size = 0 then
303+ pure code
304+ else
305+ let mut extraArgs := #[]
306+ for arg in extraMask do
307+ let .forallE _ d b _ := bodyType.headBeta
308+ | panic! " oversaturated arg mask but decl.type was not a forall"
309+ if let some arg := arg then
310+ let arg ← normArg arg
311+ extraArgs := extraArgs.push arg
312+ bodyType := b.instantiate1 arg.toExpr
313+ else
314+ let p ← mkAuxParam d
315+ params := params.push p
316+ extraArgs := extraArgs.push (.fvar p.fvarId)
317+ bodyType := b.instantiate1 (.fvar p.fvarId)
318+ code.bind fun fvarId => do
319+ let auxDecl ← mkAuxLetDecl (.fvar fvarId extraArgs)
320+ return .let auxDecl (.return auxDecl.fvarId)
321+ let type := bodyType
285322 let type ← mkForallParams params type
286323 let value := .code code
287324 let safe := decl.safe
@@ -298,7 +335,7 @@ def getRemainingArgs (paramsInfo : Array SpecParamInfo) (args : Array Arg) : Arr
298335 for info in paramsInfo, arg in args do
299336 if info matches .other then
300337 result := result.push arg
301- return result ++ args[paramsInfo.size...*]
338+ return result -- ++ args[ paramsInfo.size...* ]
302339
303340def paramsToGroundVars (params : Array Param) : CompilerM FVarIdSet :=
304341 params.foldlM (init := {}) fun r p => do
0 commit comments