Skip to content
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
61 commits
Select commit Hold shift + click to select a range
a7d59aa
Add ray tune, optuna
toby-coleman Apr 5, 2025
12e2a64
Schema for optimisation
toby-coleman Apr 6, 2025
0e52b32
Missing docstring
toby-coleman Apr 6, 2025
bb77a2d
Reorder attributes
toby-coleman Apr 6, 2025
865c914
Start implementation of Tuner
toby-coleman Apr 6, 2025
24adb5a
Partial implementation
toby-coleman Apr 8, 2025
7bcdc9a
Build parameters
toby-coleman Apr 8, 2025
f24caf5
Update build_parameter
toby-coleman Apr 8, 2025
a8feb30
Merge remote-tracking branch 'origin/main' into feat/optimisation
toby-coleman Apr 13, 2025
cd08f0e
Update pyproject/lock file
toby-coleman Apr 13, 2025
4bae1a5
Upgrade lockfile
toby-coleman Apr 13, 2025
9cca2be
Initial override work
toby-coleman Apr 15, 2025
7e2edff
Change to BaseFieldSpec
toby-coleman Apr 27, 2025
a7ea30d
FieldSpec no longer ABC
toby-coleman Apr 27, 2025
58647e4
Update objective schemas
toby-coleman Apr 29, 2025
5f5c07e
Revert changes to schemas
toby-coleman Apr 29, 2025
7d24e62
Tune implementation
toby-coleman May 5, 2025
f40bb02
Add some logging
toby-coleman May 5, 2025
509e603
Rename object -> object_type
toby-coleman May 5, 2025
59e8267
Change default on objective spec
toby-coleman May 5, 2025
9ca2d2a
Merge remote-tracking branch 'origin/main' into feat/optimisation
toby-coleman May 5, 2025
9b5d790
Add optuna dependency
toby-coleman May 5, 2025
35b5377
Basic test passing
toby-coleman May 5, 2025
e91e098
Make run sync
toby-coleman May 5, 2025
61fac58
Refactor async runner
toby-coleman May 5, 2025
b161ab6
Fail test on trial failures
toby-coleman May 5, 2025
2cec8b2
Test optimisation result, component registry issues
toby-coleman May 5, 2025
a5cf152
Add optuna to test requirements
toby-coleman May 5, 2025
3d149f9
Add ray tune to test requirements
toby-coleman May 5, 2025
f525aec
Test both directions
toby-coleman May 5, 2025
ced66ec
Improve ComponentRegistry handling
toby-coleman May 5, 2025
7ead8c1
Delete unused code
toby-coleman May 5, 2025
c168041
Test tune one ray - requires namespace on RayProcess
toby-coleman May 17, 2025
3cdbaab
Add no cover
toby-coleman May 17, 2025
a5ccb43
More no cover
toby-coleman May 17, 2025
14b0c60
Merge remote-tracking branch 'origin/main' into feat/optimisation
toby-coleman May 17, 2025
dd6a857
Upgrade ray version
toby-coleman May 17, 2025
fd5ce8b
Recreate lock file
toby-coleman May 17, 2025
8edec11
Try again
toby-coleman May 17, 2025
804d04c
Reinstate version from main
toby-coleman May 17, 2025
9fe2d99
Try again
toby-coleman May 17, 2025
92834f0
Update type ignore
toby-coleman May 17, 2025
6b6beb1
Test multi-objective case
toby-coleman May 18, 2025
5ea8b28
Update test
toby-coleman May 19, 2025
9dc7d57
Unit tests for Tuner
toby-coleman May 19, 2025
3995607
List of results for multi-objective
toby-coleman May 19, 2025
710a98b
Additional tests for schemas
toby-coleman May 26, 2025
045467b
Fix type
toby-coleman May 26, 2025
772e771
Improve coverage
toby-coleman May 26, 2025
0e2aa56
Add CLI command
toby-coleman May 26, 2025
04cd715
Fix typing issue
toby-coleman May 26, 2025
1a87923
Support load of config from file
toby-coleman May 26, 2025
de62c93
Nocover on field_type
toby-coleman May 26, 2025
500aae6
Update namespace
toby-coleman May 28, 2025
875fd8c
Fix error in type
toby-coleman Jun 3, 2025
76eaaaf
Simplify import
toby-coleman Jun 3, 2025
b7c7dd8
Comment on default concurrency
toby-coleman Jun 3, 2025
397592d
Fix typo
toby-coleman Jun 3, 2025
ab96436
Apply suggestions from code review
toby-coleman Jun 8, 2025
7bd476b
Revert change
toby-coleman Jun 8, 2025
f1424ae
Increase samples on test
toby-coleman Jun 9, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions plugboard/schemas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,15 @@
from .io import IODirection
from .process import ProcessArgsDict, ProcessArgsSpec, ProcessSpec
from .state import StateBackendArgsDict, StateBackendArgsSpec, StateBackendSpec
from .tune import (
Direction,
ObjectiveSpec,
OptunaSpec,
ParameterSpec,
TuneArgsDict,
TuneArgsSpec,
TuneSpec,
)


__all__ = [
Expand All @@ -33,13 +42,20 @@
"ConnectorMode",
"ConnectorSocket",
"ConnectorSpec",
"Direction",
"Entity",
"IODirection",
"ObjectiveSpec",
"OptunaSpec",
"ParameterSpec",
"ProcessConfigSpec",
"ProcessSpec",
"ProcessArgsDict",
"ProcessArgsSpec",
"StateBackendSpec",
"StateBackendArgsDict",
"StateBackendArgsSpec",
"TuneArgsDict",
"TuneArgsSpec",
"TuneSpec",
]
39 changes: 38 additions & 1 deletion plugboard/schemas/_common.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,48 @@
"""Common classes for Plugboard schemas."""

from abc import ABC
import typing as _t

from pydantic import BaseModel, ConfigDict


class PlugboardBaseModel(BaseModel, ABC):
"""Custom base model for Plugboard schemas."""

model_config = ConfigDict(extra="forbid", populate_by_name=True, use_enum_values=True)
model_config = ConfigDict(
extra="forbid", populate_by_name=True, use_enum_values=True, validate_assignment=True
)

def override(self, location: str, value: _t.Any) -> None:
"""Set the value of the attribute at the namespaces `location`."""
# if "." not in location:
# # Set the field on this object
# setattr(self, location, value)
# return
# field_name, sub_location = location.split(".", 1)
# override_field = getattr(self, field_name)
# if isinstance(override_field, dict) and "." not in sub_location:
# # Handle the case where we need to change item in a dict
# override_field[sub_location] = value
# return
# # Otherwise recursively override the sub-location
# if isinstance(override_field, PlugboardBaseModel):
# match = override_field
# elif isinstance(override_field, dict):
# match = override_field[sub_location]
# elif isinstance(override_field, list):
# # Match the list item by name
# try:
# match = next(
# obj
# for obj in override_field
# # Match either name or args.name
# if getattr(obj, "name") == sub_location
# or getattr(getattr(obj, "args"), "name") == sub_location
# )
# except (StopIteration, AttributeError):
# raise ValueError(f"Cannot find item named {sub_location} in {override_field}")
# if isinstance(match, PlugboardBaseModel):
# # Recursively override the sub-location
# match.override(sub_location, value)
# raise ValueError(f"Cannot override {sub_location} on {override_field}")
3 changes: 3 additions & 0 deletions plugboard/schemas/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,19 @@

from plugboard.schemas._common import PlugboardBaseModel
from .process import ProcessSpec
from .tune import TuneSpec


class ProcessConfigSpec(PlugboardBaseModel):
"""A `ProcessSpec` within a Plugboard configuration.

Attributes:
process: A `ProcessSpec` that specifies the process.
tune: Optional; A `TuneSpec` that specifies an optimisation configuration.
"""

process: ProcessSpec
tune: TuneSpec | None = None


class ConfigSpec(PlugboardBaseModel):
Expand Down
170 changes: 170 additions & 0 deletions plugboard/schemas/tune.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
"""Provides the `TuneSpec` class for configuring optimisation jobs."""

from abc import ABC
import typing as _t

from pydantic import Field, PositiveInt, model_validator

from plugboard.schemas._common import PlugboardBaseModel


class OptunaSpec(PlugboardBaseModel):
"""Specification for the Optuna configuration.

See: https://docs.ray.io/en/latest/tune/api/doc/ray.tune.search.optuna.OptunaSearch.html
and https://optuna.readthedocs.io/en/stable/reference/index.html for more information on the
Optuna configuration.

Attributes:
type: The algorithm type to load.
study_name: Optional; The name of the study.
storage: Optional; The storage URI to save the optimisation results to.
"""

type: _t.Literal["ray.tune.search.optuna.OptunaSearch"] = "ray.tune.search.optuna.OptunaSearch"
study_name: str | None = None
storage: str | None = None


class BaseFieldSpec(PlugboardBaseModel, ABC):
"""Base class for specifying fields within a Plugboard [`Process`][plugboard.process.Process].

These fields may be used as adjustable parameter inputs or as an optimisation objective.

Attributes:
object_type: The type of object on which the field is defined. Defaults to "component".
object_name: The name of the object on which the field is defined.
field_type: The type of field. This can be "arg", "initial_value", or "field".
field_name: The name of the field.
"""

object_type: _t.Literal["component"] = Field("component", exclude=True)
object_name: str = Field(..., exclude=True)
field_type: _t.Literal["arg", "initial_value", "field"] = Field(..., exclude=True)
field_name: str = Field(..., exclude=True)

@property
def full_name(self) -> str:
"""Returns the full name of the field, including the object name and field name."""
return f"{self.object_name}.{self.field_name}"


class ObjectiveSpec(BaseFieldSpec):
"""Specification for an objective field."""

@model_validator(mode="before")
@classmethod
def _fill_defaults(cls, data: dict[str, _t.Any]) -> dict[str, _t.Any]:
if "field_type" not in data:
data["field_type"] = "field"
if data["field_type"] != "field":
raise ValueError("The field type must be 'field' for an objective specification.")
return data


class FloatParameterSpec(BaseFieldSpec):
"""Specification for a uniform float parameter.

See: https://docs.ray.io/en/latest/tune/api/search_space.html.

Attributes:
type: The type of the parameter.
lower: The lower bound of the parameter.
upper: The upper bound of the parameter.
"""

type: _t.Literal["ray.tune.uniform"] = "ray.tune.uniform"
lower: float
upper: float


class IntParameterSpec(BaseFieldSpec):
"""Specification for a uniform integer parameter.

See: https://docs.ray.io/en/latest/tune/api/search_space.html.

Attributes:
type: The type of the parameter.
lower: The lower bound of the parameter.
upper: The upper bound of the parameter.
"""

type: _t.Literal["ray.tune.randint"] = "ray.tune.randint"
lower: int
upper: int


class CategoricalParameterSpec(BaseFieldSpec):
"""Specification for a categorical parameter.

See: https://docs.ray.io/en/latest/tune/api/search_space.html.

Attributes:
type: The type of the parameter.
categories: The categories of the parameter.
"""

type: _t.Literal["ray.tune.choice"] = "ray.tune.choice"
categories: list[_t.Any]


ParameterSpec = _t.Union[
FloatParameterSpec,
IntParameterSpec,
CategoricalParameterSpec,
]

Direction = _t.Literal["min", "max"]


class TuneArgsDict(_t.TypedDict):
"""`TypedDict` of the [`Tuner`][plugboard.tune.Tuner] constructor arguments."""

objective: str | list[str]
parameters: list[ParameterSpec]
num_samples: int
mode: _t.NotRequired[Direction | list[list[Direction]]]
Comment thread
toby-coleman marked this conversation as resolved.
Outdated
max_concurrent: _t.NotRequired[int | None]
algorithm: OptunaSpec


class TuneArgsSpec(PlugboardBaseModel):
"""Specification of the arguments for the `Tune` class.

Attributes:
objective: The location of the objective(s) to optimise for in the `Process`.
parameters: The parameters to optimise over.
num_samples: The number of samples to draw during the optimisation.
mode: The mode of optimisation. For multi-objective optimisation, this should be a list
containing a direction for each objective.
max_concurrent: The maximum number of concurrent trials.
algorithm: The algorithm to use for the optimisation.
"""

objective: ObjectiveSpec | list[ObjectiveSpec]
parameters: list[ParameterSpec] = Field(min_length=1)
num_samples: PositiveInt
mode: Direction | list[list[Direction]] = "max"
max_concurrent: PositiveInt | None = None
algorithm: _t.Union[OptunaSpec] = Field(OptunaSpec(), discriminator="type")

@model_validator(mode="after")
def _validate_model(self: _t.Self) -> _t.Self:
if isinstance(self.mode, list):
if not isinstance(self.objective, list):
raise ValueError(
"In multi-objective optimisation, both `mode` and `objective` must be lists."
)
if len(self.mode) != len(self.objective):
raise ValueError("The length of `mode` must match the length of `objective`.")
return self


class TuneSpec(PlugboardBaseModel):
"""Configuration for an optimisation job.

Attributes:
args: The arguments for the `Tune` job.
"""

args: TuneArgsSpec
6 changes: 6 additions & 0 deletions plugboard/tune/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""Tune submodule for configuring optimisation jobs."""

from plugboard.tune.tune import Tuner


__all__ = ["Tuner"]
Loading