Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/Lean/Compiler/LCNF/ConfigOptions.lean
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ structure ConfigOptions where
Cache closed terms and evaluate them at initialization time.
-/
extractClosed : Bool := true
/--
Maximum number of times a definition tagged with `@[specialize]` can be recursively specialized
before generating an error during compilation.
-/
maxRecSpecialize : Nat := 64
deriving Inhabited

register_builtin_option compiler.small : Nat := {
Expand Down Expand Up @@ -66,12 +71,18 @@ register_builtin_option compiler.extract_closed : Bool := {
descr := "(compiler) enable/disable closed term caching"
}

register_builtin_option compiler.maxRecSpecialize : Nat := {
defValue := 64
descr := "(compiler) maximum number of times a definition tagged with `@[specialize]` can be recursively specialized before generating an error during compilation."
}

def toConfigOptions (opts : Options) : ConfigOptions := {
smallThreshold := compiler.small.get opts
maxRecInline := compiler.maxRecInline.get opts
maxRecInlineIfReduce := compiler.maxRecInlineIfReduce.get opts
checkTypes := compiler.checkTypes.get opts
extractClosed := compiler.extract_closed.get opts
maxRecSpecialize := compiler.maxRecSpecialize.get opts
}

end Lean.Compiler.LCNF
1 change: 1 addition & 0 deletions src/Lean/Compiler/LCNF/Passes.lean
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ public import Lean.Compiler.LCNF.ElimDeadBranches
public import Lean.Compiler.LCNF.StructProjCases
public import Lean.Compiler.LCNF.ExtractClosed
public import Lean.Compiler.LCNF.Visibility
public import Lean.Compiler.LCNF.Simp

public section

Expand Down
102 changes: 74 additions & 28 deletions src/Lean/Compiler/LCNF/SpecInfo.lean
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,15 @@ inductive SpecParamInfo where
| other
deriving Inhabited, Repr

namespace SpecParamInfo

@[inline]
def causesSpecialization : SpecParamInfo → Bool
| .fixedInst | .fixedHO | .user => true
| .fixedNeutral | .other => false

end SpecParamInfo

instance : ToMessageData SpecParamInfo where
toMessageData
| .fixedInst => "I"
Expand All @@ -53,20 +62,35 @@ instance : ToMessageData SpecParamInfo where
| .user => "U"
| .other => "O"

structure SpecState where
specInfo : PHashMap Name (Array SpecParamInfo) := {}
deriving Inhabited

structure SpecEntry where
/--
The name of the declaration.
-/
declName : Name
/--
Information about which parameters of the declaration qualify for specialization.
-/
paramsInfo : Array SpecParamInfo
/--
True if `declName` was already specialized before. This is relevant because we specialize
declarations that have already been specialized less aggressively than declarations that have not.
-/
alreadySpecialized : Bool
deriving Inhabited

instance : ToMessageData SpecEntry where
toMessageData := fun { declName, paramsInfo, alreadySpecialized } =>
m!"{declName}, alreadySpecialized? {alreadySpecialized}, info: {paramsInfo}"

structure SpecState where
specInfo : PHashMap Name SpecEntry := {}
deriving Inhabited

namespace SpecState

def addEntry (s : SpecState) (e : SpecEntry) : SpecState :=
match s with
| { specInfo } => { specInfo := specInfo.insert e.declName e.paramsInfo }
| { specInfo } => { specInfo := specInfo.insert e.declName e }

end SpecState

Expand All @@ -77,7 +101,7 @@ private abbrev sortEntries (entries : Array SpecEntry) : Array SpecEntry :=
entries.qsort declLt

private abbrev findAtSorted? (entries : Array SpecEntry) (declName : Name) : Option SpecEntry :=
entries.binSearch { declName, paramsInfo := #[] } declLt
entries.binSearch { declName, paramsInfo := #[], alreadySpecialized := false } declLt

/--
Extension for storing `SpecParamInfo` for declarations being compiled.
Expand Down Expand Up @@ -136,20 +160,23 @@ See comment at `.fixedNeutral`.
private def hasFwdDeps (decl : Decl) (paramsInfo : Array SpecParamInfo) (j : Nat) : Bool := Id.run do
let param := decl.params[j]!
for h : k in (j+1)...decl.params.size do
if paramsInfo[k]! matches .user | .fixedHO | .fixedInst then
if paramsInfo[k]!.causesSpecialization then
let param' := decl.params[k]
if param'.type.containsFVar param.fvarId then
return true
return false

/--
Save parameter information for `decls`.

Remark: this function, similarly to `mkFixedArgMap`,
assumes that if a function `f` was declared in a mutual block, then `decls`
contains all (computationally relevant) functions in the mutual block.
Compute specialization information for `decls`. We assume that `decls` contains a full SCC of
computationally relevant declarations. Furthermore this function takes:
- `autoSpecialize` which determines whether we apply "automated" specialization to a decl, that is
whether we automatically specialize for all fixedHO parameters. It receives both the name and
the array of arguments mentioned in `@[specialize]` if any.
- `alreadySpecialized` which is a mask that says whether a decl is already a specialized declaration
itself.
-/
def saveSpecParamInfo (decls : Array Decl) : CompilerM Unit := do
def computeSpecEntries (decls : Array Decl) (autoSpecialize : Name → Option (Array Nat) → Bool)
(alreadySpecialized : Array Bool) : CompilerM (Array SpecEntry) := do
let mut declsInfo := #[]
for decl in decls do
if hasNospecializeAttribute (← getEnv) decl.name then
Expand Down Expand Up @@ -178,20 +205,20 @@ def saveSpecParamInfo (decls : Array Decl) : CompilerM Unit := do
specify which arguments must be specialized besides instances. In this case, we try to specialize
any "fixed higher-order argument"
-/
else if specArgs? == some #[] && param.type matches .forallE .. then
else if autoSpecialize decl.name specArgs? && param.type matches .forallE .. then
pure .fixedHO
else
pure .other
paramsInfo := paramsInfo.push info
pure ()
declsInfo := declsInfo.push paramsInfo
if declsInfo.any fun paramsInfo => paramsInfo.any (· matches .user | .fixedInst | .fixedHO) then
if declsInfo.any fun paramsInfo => paramsInfo.any SpecParamInfo.causesSpecialization then
let m := mkFixedParamsMap decls
let mut entries := Array.emptyWithCapacity decls.size
for hi : i in *...decls.size do
let decl := decls[i]
let mut paramsInfo := declsInfo[i]!
let some mask := m.find? decl.name | unreachable!
trace[Compiler.specialize.info] "{decl.name} {mask}"
paramsInfo := Array.zipWith (as := paramsInfo) (bs := mask) fun info fixed =>
if fixed || info matches .user then
info
Expand All @@ -201,24 +228,43 @@ def saveSpecParamInfo (decls : Array Decl) : CompilerM Unit := do
let mut info := paramsInfo[j]!
if info matches .fixedNeutral && !hasFwdDeps decl paramsInfo j then
paramsInfo := paramsInfo.set! j .other
if paramsInfo.any fun info => info matches .fixedInst | .fixedHO | .user then
trace[Compiler.specialize.info] "{decl.name} {paramsInfo}"
modifyEnv fun env => specExtension.addEntry env { declName := decl.name, paramsInfo }
entries := entries.push {
declName := decl.name,
paramsInfo,
alreadySpecialized := alreadySpecialized[i]!
}
return entries
else
return decls.mapIdx fun i decl => {
declName := decl.name,
paramsInfo := Array.replicate decl.params.size .other
alreadySpecialized := alreadySpecialized[i]!
}

def getSpecParamInfoCore? (env : Environment) (declName : Name) : Option (Array SpecParamInfo) :=
/--
Compute and save specialization information for `decls`. Assumes that `decls` is an SCC of user
defined declarations.
-/
def saveSpecEntries (decls : Array Decl) : CompilerM Unit := do
let entries ← computeSpecEntries
decls
(fun _ specArgs? => specArgs? == some #[])
(Array.replicate decls.size false)
for entry in entries do
if entry.paramsInfo.any SpecParamInfo.causesSpecialization then
trace[Compiler.specialize.info] "{entry.declName} {entry.paramsInfo}"
modifyEnv fun env => specExtension.addEntry env entry

def getSpecEntryCore? (env : Environment) (declName : Name) : Option SpecEntry :=
match env.getModuleIdxFor? declName with
| some modIdx =>
if let some entry := findAtSorted? (specExtension.getModuleEntries env modIdx) declName then
some entry.paramsInfo
else
none
| some modIdx => findAtSorted? (specExtension.getModuleEntries env modIdx) declName
| none => (specExtension.getState env).specInfo.find? declName

def getSpecParamInfo? [Monad m] [MonadEnv m] (declName : Name) : m (Option (Array SpecParamInfo)) :=
return getSpecParamInfoCore? (← getEnv) declName
def getSpecEntry? [Monad m] [MonadEnv m] (declName : Name) : m (Option SpecEntry) :=
return getSpecEntryCore? (← getEnv) declName

def isSpecCandidate [Monad m] [MonadEnv m] (declName : Name) : m Bool := do
return getSpecParamInfoCore? (← getEnv) declName |>.isSome
return getSpecEntryCore? (← getEnv) declName |>.isSome

builtin_initialize
registerTraceClass `Compiler.specialize.info
Expand Down
Loading
Loading