-
Notifications
You must be signed in to change notification settings - Fork 0
feat: Optimisation #139
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
feat: Optimisation #139
Changes from 27 commits
Commits
Show all changes
61 commits
Select commit
Hold shift + click to select a range
a7d59aa
Add ray tune, optuna
toby-coleman 12e2a64
Schema for optimisation
toby-coleman 0e52b32
Missing docstring
toby-coleman bb77a2d
Reorder attributes
toby-coleman 865c914
Start implementation of Tuner
toby-coleman 24adb5a
Partial implementation
toby-coleman 7bcdc9a
Build parameters
toby-coleman f24caf5
Update build_parameter
toby-coleman a8feb30
Merge remote-tracking branch 'origin/main' into feat/optimisation
toby-coleman cd08f0e
Update pyproject/lock file
toby-coleman 4bae1a5
Upgrade lockfile
toby-coleman 9cca2be
Initial override work
toby-coleman 7e2edff
Change to BaseFieldSpec
toby-coleman a7ea30d
FieldSpec no longer ABC
toby-coleman 58647e4
Update objective schemas
toby-coleman 5f5c07e
Revert changes to schemas
toby-coleman 7d24e62
Tune implementation
toby-coleman f40bb02
Add some logging
toby-coleman 509e603
Rename object -> object_type
toby-coleman 59e8267
Change default on objective spec
toby-coleman 9ca2d2a
Merge remote-tracking branch 'origin/main' into feat/optimisation
toby-coleman 9b5d790
Add optuna dependency
toby-coleman 35b5377
Basic test passing
toby-coleman e91e098
Make run sync
toby-coleman 61fac58
Refactor async runner
toby-coleman b161ab6
Fail test on trial failures
toby-coleman 2cec8b2
Test optimisation result, component registry issues
toby-coleman a5cf152
Add optuna to test requirements
toby-coleman 3d149f9
Add ray tune to test requirements
toby-coleman f525aec
Test both directions
toby-coleman ced66ec
Improve ComponentRegistry handling
toby-coleman 7ead8c1
Delete unused code
toby-coleman c168041
Test tune one ray - requires namespace on RayProcess
toby-coleman 3cdbaab
Add no cover
toby-coleman a5ccb43
More no cover
toby-coleman 14b0c60
Merge remote-tracking branch 'origin/main' into feat/optimisation
toby-coleman dd6a857
Upgrade ray version
toby-coleman fd5ce8b
Recreate lock file
toby-coleman 8edec11
Try again
toby-coleman 804d04c
Reinstate version from main
toby-coleman 9fe2d99
Try again
toby-coleman 92834f0
Update type ignore
toby-coleman 6b6beb1
Test multi-objective case
toby-coleman 5ea8b28
Update test
toby-coleman 9dc7d57
Unit tests for Tuner
toby-coleman 3995607
List of results for multi-objective
toby-coleman 710a98b
Additional tests for schemas
toby-coleman 045467b
Fix type
toby-coleman 772e771
Improve coverage
toby-coleman 0e2aa56
Add CLI command
toby-coleman 04cd715
Fix typing issue
toby-coleman 1a87923
Support load of config from file
toby-coleman de62c93
Nocover on field_type
toby-coleman 500aae6
Update namespace
toby-coleman 875fd8c
Fix error in type
toby-coleman 76eaaaf
Simplify import
toby-coleman b7c7dd8
Comment on default concurrency
toby-coleman 397592d
Fix typo
toby-coleman ab96436
Apply suggestions from code review
toby-coleman 7bd476b
Revert change
toby-coleman f1424ae
Increase samples on test
toby-coleman File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,11 +1,48 @@ | ||
| """Common classes for Plugboard schemas.""" | ||
|
|
||
| from abc import ABC | ||
| import typing as _t | ||
|
|
||
| from pydantic import BaseModel, ConfigDict | ||
|
|
||
|
|
||
| class PlugboardBaseModel(BaseModel, ABC): | ||
| """Custom base model for Plugboard schemas.""" | ||
|
|
||
| model_config = ConfigDict(extra="forbid", populate_by_name=True, use_enum_values=True) | ||
| model_config = ConfigDict( | ||
| extra="forbid", populate_by_name=True, use_enum_values=True, validate_assignment=True | ||
| ) | ||
|
|
||
| def override(self, location: str, value: _t.Any) -> None: | ||
| """Set the value of the attribute at the namespaces `location`.""" | ||
| # if "." not in location: | ||
| # # Set the field on this object | ||
| # setattr(self, location, value) | ||
| # return | ||
| # field_name, sub_location = location.split(".", 1) | ||
| # override_field = getattr(self, field_name) | ||
| # if isinstance(override_field, dict) and "." not in sub_location: | ||
| # # Handle the case where we need to change item in a dict | ||
| # override_field[sub_location] = value | ||
| # return | ||
| # # Otherwise recursively override the sub-location | ||
| # if isinstance(override_field, PlugboardBaseModel): | ||
| # match = override_field | ||
| # elif isinstance(override_field, dict): | ||
| # match = override_field[sub_location] | ||
| # elif isinstance(override_field, list): | ||
| # # Match the list item by name | ||
| # try: | ||
| # match = next( | ||
| # obj | ||
| # for obj in override_field | ||
| # # Match either name or args.name | ||
| # if getattr(obj, "name") == sub_location | ||
| # or getattr(getattr(obj, "args"), "name") == sub_location | ||
| # ) | ||
| # except (StopIteration, AttributeError): | ||
| # raise ValueError(f"Cannot find item named {sub_location} in {override_field}") | ||
| # if isinstance(match, PlugboardBaseModel): | ||
| # # Recursively override the sub-location | ||
| # match.override(sub_location, value) | ||
| # raise ValueError(f"Cannot override {sub_location} on {override_field}") |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,170 @@ | ||
| """Provides the `TuneSpec` class for configuring optimisation jobs.""" | ||
|
|
||
| from abc import ABC | ||
| import typing as _t | ||
|
|
||
| from pydantic import Field, PositiveInt, model_validator | ||
|
|
||
| from plugboard.schemas._common import PlugboardBaseModel | ||
|
|
||
|
|
||
| class OptunaSpec(PlugboardBaseModel): | ||
| """Specification for the Optuna configuration. | ||
|
|
||
| See: https://docs.ray.io/en/latest/tune/api/doc/ray.tune.search.optuna.OptunaSearch.html | ||
| and https://optuna.readthedocs.io/en/stable/reference/index.html for more information on the | ||
| Optuna configuration. | ||
|
|
||
| Attributes: | ||
| type: The algorithm type to load. | ||
| study_name: Optional; The name of the study. | ||
| storage: Optional; The storage URI to save the optimisation results to. | ||
| """ | ||
|
|
||
| type: _t.Literal["ray.tune.search.optuna.OptunaSearch"] = "ray.tune.search.optuna.OptunaSearch" | ||
| study_name: str | None = None | ||
| storage: str | None = None | ||
|
|
||
|
|
||
| class BaseFieldSpec(PlugboardBaseModel, ABC): | ||
| """Base class for specifying fields within a Plugboard [`Process`][plugboard.process.Process]. | ||
|
|
||
| These fields may be used as adjustable parameter inputs or as an optimisation objective. | ||
|
|
||
| Attributes: | ||
| object_type: The type of object on which the field is defined. Defaults to "component". | ||
| object_name: The name of the object on which the field is defined. | ||
| field_type: The type of field. This can be "arg", "initial_value", or "field". | ||
| field_name: The name of the field. | ||
| """ | ||
|
|
||
| object_type: _t.Literal["component"] = Field("component", exclude=True) | ||
| object_name: str = Field(..., exclude=True) | ||
| field_type: _t.Literal["arg", "initial_value", "field"] = Field(..., exclude=True) | ||
| field_name: str = Field(..., exclude=True) | ||
|
|
||
| @property | ||
| def full_name(self) -> str: | ||
| """Returns the full name of the field, including the object name and field name.""" | ||
| return f"{self.object_name}.{self.field_name}" | ||
|
|
||
|
|
||
| class ObjectiveSpec(BaseFieldSpec): | ||
| """Specification for an objective field.""" | ||
|
|
||
| @model_validator(mode="before") | ||
| @classmethod | ||
| def _fill_defaults(cls, data: dict[str, _t.Any]) -> dict[str, _t.Any]: | ||
| if "field_type" not in data: | ||
| data["field_type"] = "field" | ||
| if data["field_type"] != "field": | ||
| raise ValueError("The field type must be 'field' for an objective specification.") | ||
| return data | ||
|
|
||
|
|
||
| class FloatParameterSpec(BaseFieldSpec): | ||
| """Specification for a uniform float parameter. | ||
|
|
||
| See: https://docs.ray.io/en/latest/tune/api/search_space.html. | ||
|
|
||
| Attributes: | ||
| type: The type of the parameter. | ||
| lower: The lower bound of the parameter. | ||
| upper: The upper bound of the parameter. | ||
| """ | ||
|
|
||
| type: _t.Literal["ray.tune.uniform"] = "ray.tune.uniform" | ||
| lower: float | ||
| upper: float | ||
|
|
||
|
|
||
| class IntParameterSpec(BaseFieldSpec): | ||
| """Specification for a uniform integer parameter. | ||
|
|
||
| See: https://docs.ray.io/en/latest/tune/api/search_space.html. | ||
|
|
||
| Attributes: | ||
| type: The type of the parameter. | ||
| lower: The lower bound of the parameter. | ||
| upper: The upper bound of the parameter. | ||
| """ | ||
|
|
||
| type: _t.Literal["ray.tune.randint"] = "ray.tune.randint" | ||
| lower: int | ||
| upper: int | ||
|
|
||
|
|
||
| class CategoricalParameterSpec(BaseFieldSpec): | ||
| """Specification for a categorical parameter. | ||
|
|
||
| See: https://docs.ray.io/en/latest/tune/api/search_space.html. | ||
|
|
||
| Attributes: | ||
| type: The type of the parameter. | ||
| categories: The categories of the parameter. | ||
| """ | ||
|
|
||
| type: _t.Literal["ray.tune.choice"] = "ray.tune.choice" | ||
| categories: list[_t.Any] | ||
|
|
||
|
|
||
| ParameterSpec = _t.Union[ | ||
| FloatParameterSpec, | ||
| IntParameterSpec, | ||
| CategoricalParameterSpec, | ||
| ] | ||
|
|
||
| Direction = _t.Literal["min", "max"] | ||
|
|
||
|
|
||
| class TuneArgsDict(_t.TypedDict): | ||
| """`TypedDict` of the [`Tuner`][plugboard.tune.Tuner] constructor arguments.""" | ||
|
|
||
| objective: str | list[str] | ||
| parameters: list[ParameterSpec] | ||
| num_samples: int | ||
| mode: _t.NotRequired[Direction | list[list[Direction]]] | ||
| max_concurrent: _t.NotRequired[int | None] | ||
| algorithm: OptunaSpec | ||
|
|
||
|
|
||
| class TuneArgsSpec(PlugboardBaseModel): | ||
| """Specification of the arguments for the `Tune` class. | ||
|
|
||
| Attributes: | ||
| objective: The location of the objective(s) to optimise for in the `Process`. | ||
| parameters: The parameters to optimise over. | ||
| num_samples: The number of samples to draw during the optimisation. | ||
| mode: The mode of optimisation. For multi-objective optimisation, this should be a list | ||
| containing a direction for each objective. | ||
| max_concurrent: The maximum number of concurrent trials. | ||
| algorithm: The algorithm to use for the optimisation. | ||
| """ | ||
|
|
||
| objective: ObjectiveSpec | list[ObjectiveSpec] | ||
| parameters: list[ParameterSpec] = Field(min_length=1) | ||
| num_samples: PositiveInt | ||
| mode: Direction | list[list[Direction]] = "max" | ||
| max_concurrent: PositiveInt | None = None | ||
| algorithm: _t.Union[OptunaSpec] = Field(OptunaSpec(), discriminator="type") | ||
|
|
||
| @model_validator(mode="after") | ||
| def _validate_model(self: _t.Self) -> _t.Self: | ||
| if isinstance(self.mode, list): | ||
| if not isinstance(self.objective, list): | ||
| raise ValueError( | ||
| "In multi-objective optimisation, both `mode` and `objective` must be lists." | ||
| ) | ||
| if len(self.mode) != len(self.objective): | ||
| raise ValueError("The length of `mode` must match the length of `objective`.") | ||
| return self | ||
|
|
||
|
|
||
| class TuneSpec(PlugboardBaseModel): | ||
| """Configuration for an optimisation job. | ||
|
|
||
| Attributes: | ||
| args: The arguments for the `Tune` job. | ||
| """ | ||
|
|
||
| args: TuneArgsSpec | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,6 @@ | ||
| """Tune submodule for configuring optimisation jobs.""" | ||
|
|
||
| from plugboard.tune.tune import Tuner | ||
|
|
||
|
|
||
| __all__ = ["Tuner"] |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.