Skip to content

Commit bded378

Browse files
authored
feat(Pipeline): Optimize pipelines directly with optimize() (#230)
1 parent 4198de7 commit bded378

14 files changed

Lines changed: 1108 additions & 14 deletions

File tree

src/amltk/_richutil/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from amltk._richutil.renderable import RichRenderable
22
from amltk._richutil.renderers import Function, rich_make_column_selector
3-
from amltk._richutil.util import df_to_table, richify
3+
from amltk._richutil.util import df_to_table, is_jupyter, richify
44

55
__all__ = [
66
"df_to_table",
77
"richify",
88
"RichRenderable",
99
"Function",
1010
"rich_make_column_selector",
11+
"is_jupyter",
1112
]

src/amltk/_richutil/util.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# where rich not being installed.
44
from __future__ import annotations
55

6+
import os
67
from concurrent.futures import ProcessPoolExecutor
78
from typing import TYPE_CHECKING, Any
89

@@ -70,3 +71,25 @@ def df_to_table(
7071
table.add_row(str(index), *[str(cell) for cell in row])
7172

7273
return table
74+
75+
76+
def is_jupyter() -> bool:
77+
"""Return True if running in a Jupyter environment."""
78+
# https://github.com/Textualize/rich/blob/fd981823644ccf50d685ac9c0cfe8e1e56c9dd35/rich/console.py#L518-L535
79+
try:
80+
get_ipython # type: ignore[name-defined] # noqa: B018
81+
except NameError:
82+
return False
83+
ipython = get_ipython() # type: ignore[name-defined] # noqa: F821
84+
shell = ipython.__class__.__name__
85+
if (
86+
"google.colab" in str(ipython.__class__)
87+
or os.getenv("DATABRICKS_RUNTIME_VERSION")
88+
or shell == "ZMQInteractiveShell"
89+
):
90+
return True # Jupyter notebook or qtconsole
91+
92+
if shell == "TerminalInteractiveShell":
93+
return False # Terminal running IPython
94+
95+
return False # Other type (?)

src/amltk/_util.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from __future__ import annotations
2+
3+
from typing import Any
4+
5+
6+
def threadpoolctl_heuristic(item_contained_in_node: Any | None) -> bool:
7+
"""Heuristic to determine if we should automatically set threadpoolctl.
8+
9+
This is done by detecting if it's a scikit-learn `BaseEstimator` but this may
10+
be extended in the future.
11+
12+
!!! tip
13+
14+
The reason to have this heuristic is that when running scikit-learn, or any
15+
multithreaded model, in parallel, they will over subscribe to threads. This
16+
causes a significant performance hit as most of the time is spent switching
17+
thread contexts instead of work. This can be particularly bad for HPO where
18+
we are evaluating multiple models in parallel on the same system.
19+
20+
The recommened thread count is 1 per core with no additional information to
21+
act upon.
22+
23+
!!! todo
24+
25+
This is potentially not an issue if running on multiple nodes of some cluster,
26+
as they do not share logical cores and hence do not clash.
27+
28+
Args:
29+
item_contained_in_node: The item with which to base the heuristic on.
30+
31+
Returns:
32+
Whether we should automatically set threadpoolctl.
33+
"""
34+
if item_contained_in_node is None or not isinstance(item_contained_in_node, type):
35+
return False
36+
37+
try:
38+
# NOTE: sklearn depends on threadpoolctl so it will be installed.
39+
from sklearn.base import BaseEstimator
40+
41+
return issubclass(item_contained_in_node, BaseEstimator)
42+
except ImportError:
43+
return False

src/amltk/evalutors/__init__.py

Whitespace-only changes.
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
"""Evaluation protocols for how a trial and a pipeline should be evaluated.
2+
3+
TODO: Sorry
4+
"""
5+
from __future__ import annotations
6+
7+
from collections.abc import Callable, Iterable
8+
from typing import TYPE_CHECKING
9+
10+
from amltk.scheduling import Plugin
11+
12+
if TYPE_CHECKING:
13+
from amltk.optimization import Trial
14+
from amltk.pipeline import Node
15+
from amltk.scheduling import Scheduler, Task
16+
17+
18+
class EvaluationProtocol:
19+
"""A protocol for how a trial should be evaluated on a pipeline."""
20+
21+
fn: Callable[[Trial, Node], Trial.Report]
22+
23+
def task(
24+
self,
25+
scheduler: Scheduler,
26+
plugins: Plugin | Iterable[Plugin] | None = None,
27+
) -> Task[[Trial, Node], Trial.Report]:
28+
"""Create a task for this protocol.
29+
30+
Args:
31+
scheduler: The scheduler to use for the task.
32+
plugins: The plugins to use for the task.
33+
34+
Returns:
35+
The created task.
36+
"""
37+
_plugins: tuple[Plugin, ...]
38+
match plugins:
39+
case None:
40+
_plugins = ()
41+
case Plugin():
42+
_plugins = (plugins,)
43+
case Iterable():
44+
_plugins = tuple(plugins)
45+
46+
return scheduler.task(self.fn, plugins=_plugins)
47+
48+
49+
class CustomProtocol(EvaluationProtocol):
50+
"""A custom evaluation protocol based on a user function."""
51+
52+
def __init__(self, fn: Callable[[Trial, Node], Trial.Report]) -> None:
53+
"""Initialize the protocol.
54+
55+
Args:
56+
fn: The function to use for the evaluation.
57+
"""
58+
super().__init__()
59+
self.fn = fn

src/amltk/optimization/history.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def target_function(trial: Trial) -> Trial.Report:
6565
from collections import defaultdict
6666
from collections.abc import Callable, Hashable, Iterable, Iterator
6767
from dataclasses import dataclass, field
68-
from typing import TYPE_CHECKING, Literal, TypeVar
68+
from typing import TYPE_CHECKING, Literal, TypeVar, overload
6969
from typing_extensions import override
7070

7171
import pandas as pd
@@ -527,7 +527,14 @@ def sortby(
527527

528528
return sorted(history.reports, key=sort_key, reverse=reverse)
529529

530-
@override
530+
@overload
531+
def __getitem__(self, key: int | str) -> Trial.Report:
532+
...
533+
534+
@overload
535+
def __getitem__(self, key: slice) -> Trial.Report:
536+
...
537+
531538
def __getitem__( # type: ignore
532539
self,
533540
key: int | str | slice,

src/amltk/optimization/optimizer.py

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,22 @@
1515
"""
1616
from __future__ import annotations
1717

18+
import logging
1819
from abc import abstractmethod
19-
from collections.abc import Callable, Iterable, Sequence
20+
from collections.abc import Callable, Iterable, Iterator, Sequence
2021
from datetime import datetime
22+
from pathlib import Path
2123
from typing import (
2224
TYPE_CHECKING,
2325
Any,
2426
Concatenate,
2527
Generic,
2628
ParamSpec,
29+
Protocol,
2730
TypeVar,
2831
overload,
2932
)
33+
from typing_extensions import Self
3034

3135
from more_itertools import all_unique
3236

@@ -36,11 +40,14 @@
3640
from amltk.optimization.metric import Metric
3741
from amltk.optimization.trial import Trial
3842
from amltk.pipeline import Node
43+
from amltk.types import Seed
3944

4045
I = TypeVar("I") # noqa: E741
4146
P = ParamSpec("P")
4247
ParserOutput = TypeVar("ParserOutput")
4348

49+
logger = logging.getLogger(__name__)
50+
4451

4552
class Optimizer(Generic[I]):
4653
"""An optimizer protocol.
@@ -123,3 +130,91 @@ def preferred_parser(
123130
124131
"""
125132
return None
133+
134+
@classmethod
135+
@abstractmethod
136+
def create(
137+
cls,
138+
*,
139+
space: Node,
140+
metrics: Metric | Sequence[Metric],
141+
bucket: str | Path | PathBucket | None = None,
142+
seed: Seed | None = None,
143+
) -> Self:
144+
"""Create this optimizer.
145+
146+
!!! note
147+
148+
Subclasses should override this with more specific configuration
149+
but these arguments should be all that's necessary to create the optimizer.
150+
151+
Args:
152+
space: The space to optimize over.
153+
bucket: The bucket for where to store things related to the trial.
154+
metrics: The metrics to optimize.
155+
seed: The seed to use for the optimizer.
156+
157+
Returns:
158+
The optimizer.
159+
"""
160+
161+
class CreateSignature(Protocol):
162+
"""A Protocol which defines the keywords required to create an
163+
optimizer with deterministic behavior at a desired location.
164+
165+
This protocol matches the `Optimizer.create` classmethod, however we also
166+
allow any function which accepts the keyword arguments to create an
167+
Optimizer.
168+
"""
169+
170+
def __call__(
171+
self,
172+
*,
173+
space: Node,
174+
metrics: Metric | Sequence[Metric],
175+
bucket: PathBucket | None = None,
176+
seed: Seed | None = None,
177+
) -> Optimizer:
178+
"""A function which creates an optimizer for node.optimize should
179+
accept the following keyword arguments.
180+
181+
Args:
182+
space: The node to optimize
183+
metrics: The metrics to optimize
184+
bucket: The bucket to store the results in
185+
seed: The seed to use for the optimization
186+
"""
187+
...
188+
189+
@classmethod
190+
def _get_known_importable_optimizer_classes(cls) -> Iterator[type[Optimizer]]:
191+
"""Get all developer known optimizer classes. This is used for defaults.
192+
193+
Do not rely on this functionality and prefer to give concrete optimizers to
194+
functionality requiring one. This is intended for convenience of particular
195+
quickstart methods.
196+
"""
197+
# NOTE: We can't use the `Optimizer.__subclasses__` method as the optimizers
198+
# are not imported by any other module initially and so they do no exist
199+
# until imported. Hence this manual iteration. For now, we be explicit and
200+
# only if the optimizer list grows should we consider dynamic importing.
201+
try:
202+
from amltk.optimization.optimizers.smac import SMACOptimizer
203+
204+
yield SMACOptimizer
205+
except ImportError as e:
206+
logger.debug("Failed to import SMACOptimizer", exc_info=e)
207+
208+
try:
209+
from amltk.optimization.optimizers.optuna import OptunaOptimizer
210+
211+
yield OptunaOptimizer
212+
except ImportError as e:
213+
logger.debug("Failed to import OptunaOptimizer", exc_info=e)
214+
215+
try:
216+
from amltk.optimization.optimizers.neps import NEPSOptimizer
217+
218+
yield NEPSOptimizer
219+
except ImportError as e:
220+
logger.debug("Failed to import NEPSOptimizer", exc_info=e)

src/amltk/optimization/optimizers/neps.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ def __init__(
249249
self,
250250
*,
251251
space: SearchSpace,
252-
loss_metric: Metric,
252+
loss_metric: Metric | Sequence[Metric],
253253
cost_metric: Metric | None = None,
254254
optimizer: BaseOptimizer,
255255
working_dir: Path,
@@ -307,7 +307,7 @@ def create( # noqa: PLR0913
307307
| Mapping[str, ConfigurationSpace | Parameter]
308308
| Node
309309
),
310-
metrics: Metric,
310+
metrics: Metric | Sequence[Metric],
311311
cost_metric: Metric | None = None,
312312
bucket: PathBucket | str | Path | None = None,
313313
searcher: str | BaseOptimizer = "default",

0 commit comments

Comments
 (0)