|
15 | 15 | """ |
16 | 16 | from __future__ import annotations |
17 | 17 |
|
| 18 | +import logging |
18 | 19 | from abc import abstractmethod |
19 | | -from collections.abc import Callable, Iterable, Sequence |
| 20 | +from collections.abc import Callable, Iterable, Iterator, Sequence |
20 | 21 | from datetime import datetime |
| 22 | +from pathlib import Path |
21 | 23 | from typing import ( |
22 | 24 | TYPE_CHECKING, |
23 | 25 | Any, |
24 | 26 | Concatenate, |
25 | 27 | Generic, |
26 | 28 | ParamSpec, |
| 29 | + Protocol, |
27 | 30 | TypeVar, |
28 | 31 | overload, |
29 | 32 | ) |
| 33 | +from typing_extensions import Self |
30 | 34 |
|
31 | 35 | from more_itertools import all_unique |
32 | 36 |
|
|
36 | 40 | from amltk.optimization.metric import Metric |
37 | 41 | from amltk.optimization.trial import Trial |
38 | 42 | from amltk.pipeline import Node |
| 43 | + from amltk.types import Seed |
39 | 44 |
|
40 | 45 | I = TypeVar("I") # noqa: E741 |
41 | 46 | P = ParamSpec("P") |
42 | 47 | ParserOutput = TypeVar("ParserOutput") |
43 | 48 |
|
| 49 | +logger = logging.getLogger(__name__) |
| 50 | + |
44 | 51 |
|
45 | 52 | class Optimizer(Generic[I]): |
46 | 53 | """An optimizer protocol. |
@@ -123,3 +130,91 @@ def preferred_parser( |
123 | 130 |
|
124 | 131 | """ |
125 | 132 | return None |
| 133 | + |
| 134 | + @classmethod |
| 135 | + @abstractmethod |
| 136 | + def create( |
| 137 | + cls, |
| 138 | + *, |
| 139 | + space: Node, |
| 140 | + metrics: Metric | Sequence[Metric], |
| 141 | + bucket: str | Path | PathBucket | None = None, |
| 142 | + seed: Seed | None = None, |
| 143 | + ) -> Self: |
| 144 | + """Create this optimizer. |
| 145 | +
|
| 146 | + !!! note |
| 147 | +
|
| 148 | + Subclasses should override this with more specific configuration |
| 149 | + but these arguments should be all that's necessary to create the optimizer. |
| 150 | +
|
| 151 | + Args: |
| 152 | + space: The space to optimize over. |
| 153 | + bucket: The bucket for where to store things related to the trial. |
| 154 | + metrics: The metrics to optimize. |
| 155 | + seed: The seed to use for the optimizer. |
| 156 | +
|
| 157 | + Returns: |
| 158 | + The optimizer. |
| 159 | + """ |
| 160 | + |
| 161 | + class CreateSignature(Protocol): |
| 162 | + """A Protocol which defines the keywords required to create an |
| 163 | + optimizer with deterministic behavior at a desired location. |
| 164 | +
|
| 165 | + This protocol matches the `Optimizer.create` classmethod, however we also |
| 166 | + allow any function which accepts the keyword arguments to create an |
| 167 | + Optimizer. |
| 168 | + """ |
| 169 | + |
| 170 | + def __call__( |
| 171 | + self, |
| 172 | + *, |
| 173 | + space: Node, |
| 174 | + metrics: Metric | Sequence[Metric], |
| 175 | + bucket: PathBucket | None = None, |
| 176 | + seed: Seed | None = None, |
| 177 | + ) -> Optimizer: |
| 178 | + """A function which creates an optimizer for node.optimize should |
| 179 | + accept the following keyword arguments. |
| 180 | +
|
| 181 | + Args: |
| 182 | + space: The node to optimize |
| 183 | + metrics: The metrics to optimize |
| 184 | + bucket: The bucket to store the results in |
| 185 | + seed: The seed to use for the optimization |
| 186 | + """ |
| 187 | + ... |
| 188 | + |
| 189 | + @classmethod |
| 190 | + def _get_known_importable_optimizer_classes(cls) -> Iterator[type[Optimizer]]: |
| 191 | + """Get all developer known optimizer classes. This is used for defaults. |
| 192 | +
|
| 193 | + Do not rely on this functionality and prefer to give concrete optimizers to |
| 194 | + functionality requiring one. This is intended for convenience of particular |
| 195 | + quickstart methods. |
| 196 | + """ |
| 197 | + # NOTE: We can't use the `Optimizer.__subclasses__` method as the optimizers |
| 198 | + # are not imported by any other module initially and so they do no exist |
| 199 | + # until imported. Hence this manual iteration. For now, we be explicit and |
| 200 | + # only if the optimizer list grows should we consider dynamic importing. |
| 201 | + try: |
| 202 | + from amltk.optimization.optimizers.smac import SMACOptimizer |
| 203 | + |
| 204 | + yield SMACOptimizer |
| 205 | + except ImportError as e: |
| 206 | + logger.debug("Failed to import SMACOptimizer", exc_info=e) |
| 207 | + |
| 208 | + try: |
| 209 | + from amltk.optimization.optimizers.optuna import OptunaOptimizer |
| 210 | + |
| 211 | + yield OptunaOptimizer |
| 212 | + except ImportError as e: |
| 213 | + logger.debug("Failed to import OptunaOptimizer", exc_info=e) |
| 214 | + |
| 215 | + try: |
| 216 | + from amltk.optimization.optimizers.neps import NEPSOptimizer |
| 217 | + |
| 218 | + yield NEPSOptimizer |
| 219 | + except ImportError as e: |
| 220 | + logger.debug("Failed to import NEPSOptimizer", exc_info=e) |
0 commit comments