Skip to content

Commit ebbf8be

Browse files
committed
OptimizationProblem dataclass and updated plans
1 parent 2ca4e7a commit ebbf8be

File tree

4 files changed

+136
-70
lines changed

4 files changed

+136
-70
lines changed

src/blop/ax/agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
from ax.generation_strategy.generation_strategy import GenerationStrategy
1111
from bluesky.protocols import Readable
1212

13-
from ..evaluation import default_evaluation_function
1413
from ..dofs import DOF, DOFConstraint
14+
from ..evaluation import default_evaluation_function
1515
from ..objectives import Objective
1616
from .adapters import configure_metrics, configure_objectives, configure_parameters
1717

src/blop/evaluation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from ax.api.types import TOutcome
22
from tiled.client.container import Container
33

4-
from .objectives import Objective
54
from .data_access import TiledDataAccess
5+
from .objectives import Objective
66

77

88
def default_evaluation_function(

src/blop/plans.py

Lines changed: 95 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,46 @@
11
import functools
2-
from collections import defaultdict
2+
import warnings
33
from collections.abc import Callable, Generator, Mapping, Sequence
4-
from typing import Any
4+
from typing import Any, cast
55

66
import bluesky.plan_stubs as bps
77
import bluesky.plans as bp
8-
from ax.api.types import TParameterization, TParameterValue
9-
from bluesky.protocols import Movable, Readable, Reading
10-
from bluesky.run_engine import Msg
11-
from bluesky.utils import MsgGenerator, plan
8+
from ax.api.types import TParameterization
9+
from bluesky.protocols import Movable, NamedMovable, Readable, Reading
10+
from bluesky.utils import Msg, MsgGenerator, plan
1211
from ophyd import Signal # type: ignore[import-untyped]
1312

1413
from .dofs import DOF
15-
from .ax.agent import Agent
14+
from .protocols import OptimizationProblem
1615

1716

18-
def _unpack_parameters(dofs: dict[str, DOF], parameterizations: list[TParameterization]) -> list[Movable | TParameterValue]:
19-
"""Unpack the parameterizations into Bluesky plan arguments."""
20-
unpacked_dict = defaultdict(list)
21-
for parameterization in parameterizations:
22-
for dof_name in dofs.keys():
23-
if dof_name in parameterization:
24-
unpacked_dict[dof_name].append(parameterization[dof_name])
25-
else:
26-
raise ValueError(f"Parameter {dof_name} not found in parameterization. Parameterization: {parameterization}")
27-
17+
def _unpack_for_list_scan(movables: Mapping[NamedMovable, Sequence[Any]]) -> list[NamedMovable | Any]:
18+
"""Unpack the movables and inputs into Bluesky list_scan plan arguments."""
2819
unpacked_list = []
29-
for dof_name, values in unpacked_dict.items():
30-
unpacked_list.append(dofs[dof_name].movable)
20+
for movable, values in movables.items():
21+
unpacked_list.append(movable)
3122
unpacked_list.append(values)
3223

3324
return unpacked_list
3425

3526

3627
@plan
37-
def acquire(
28+
def default_acquire(
29+
movables: Mapping[NamedMovable, Sequence[Any]],
3830
readables: Sequence[Readable],
39-
dofs: dict[str, DOF],
40-
trials: dict[int, TParameterization],
31+
*,
4132
per_step: bp.PerStep | None = None,
4233
**kwargs: Any,
4334
) -> MsgGenerator[str]:
4435
"""
45-
A plan to acquire data for optimization.
36+
A default plan to acquire data for optimization. Simply a list scan.
4637
4738
Parameters
4839
----------
40+
movables: Mapping[NamedMovable, Sequence[Any]]
41+
The movables to move and the inputs to move them to.
4942
readables: Sequence[Readable]
5043
The readables to trigger and read.
51-
dofs: dict[str, DOF]
52-
A dictionary mapping DOF names to DOFs.
5344
trials: dict[int, TParameterization]
5445
A dictionary mapping trial indices to their suggested parameterizations. Typically only a single trial is provided.
5546
per_step: bp.PerStep | None = None
@@ -66,69 +57,81 @@ def acquire(
6657
--------
6758
bluesky.plans.list_scan : The Bluesky plan to acquire data.
6859
"""
69-
plan_args = _unpack_parameters(dofs, trials.values())
60+
plan_args = _unpack_for_list_scan(movables)
7061
return (
62+
# TODO: fix argument type in bluesky.plans.list_scan
7163
yield from bp.list_scan(
72-
readables, *plan_args, md={"ax_trial_indices": list(trials.keys())}, per_step=per_step, **kwargs
64+
readables,
65+
*plan_args,
66+
per_step=per_step,
67+
**kwargs, # type: ignore[arg-type]
7368
)
7469
)
7570

7671

7772
@plan
7873
def optimize_step(
79-
generator: Agent,
80-
n: int = 1,
81-
acquisition_plan: Callable[[], MsgGenerator[None]] | None = None,
74+
optimization_problem: OptimizationProblem,
75+
n_points: int = 1,
76+
*args: Any,
77+
**kwargs: Any,
8278
) -> MsgGenerator[None]:
8379
"""
8480
A single step of the optimization loop.
8581
8682
Parameters
8783
----------
88-
generator : Agent
89-
The generator to optimize with.
90-
n : int, optional
91-
The number of trials to suggest.
92-
acquisition_plan : Callable[[], MsgGenerator[None]] | None, optional
93-
The acquisition plan to use to acquire data. If not provided, the default acquisition plan will be used.
84+
optimization_problem : OptimizationProblem
85+
The optimization problem to solve.
86+
n_points : int, optional
87+
The number of points to suggest.
9488
"""
95-
if acquisition_plan is None:
96-
acquisition_plan = acquire
97-
trials = generator.suggest(n)
98-
data = yield from acquisition_plan(generator.readables, generator.dofs, trials)
99-
outcomes = generator.evaluate(trials, data)
89+
if optimization_problem.acquisition_plan is None:
90+
acquisition_plan = default_acquire
91+
else:
92+
acquisition_plan = optimization_problem.acquisition_plan
93+
generator = optimization_problem.generator
94+
movables = optimization_problem.movables
95+
suggestions = generator.suggest(n_points)
96+
movables_and_inputs = {movable: [suggestion[movable.name] for suggestion in suggestions] for movable in movables}
97+
uid = yield from acquisition_plan(movables_and_inputs, optimization_problem.readables, *args, **kwargs)
98+
outcomes = optimization_problem.evaluation_function(uid)
10099
generator.ingest(outcomes)
101100

102101

103102
@plan
104103
def optimize(
105-
generator: Agent,
104+
optimization_problem: OptimizationProblem,
106105
iterations: int = 1,
107-
n: int = 1,
108-
acquisition_plan: Callable[[], MsgGenerator[None]] | None = None,
106+
n_points: int = 1,
107+
*args: Any,
108+
**kwargs: Any,
109109
) -> MsgGenerator[None]:
110110
"""
111111
A plan to optimize the generator.
112112
113113
Parameters
114114
----------
115-
generator : Agent
116-
The generator to optimize with.
115+
optimization_problem : OptimizationProblem
116+
The optimization problem to solve.
117117
iterations : int, optional
118118
The number of optimization iterations to run.
119-
n : int, optional
120-
The number of trials to suggest per iteration.
121-
acquisition_plan : Callable[[], MsgGenerator[None]] | None, optional
122-
The acquisition plan to use to acquire data. If not provided, the default acquisition plan will be used.
119+
n_points : int, optional
120+
The number of points to suggest per iteration.
123121
"""
124122

125123
for _ in range(iterations):
126-
yield from optimize_step(generator, n, acquisition_plan)
124+
yield from optimize_step(optimization_problem, n_points, *args, **kwargs)
127125

128126

129127
@plan
130128
def list_scan_with_delay(*args: Any, delay: float = 0, **kwargs: Any) -> Generator[Msg, None, str]:
131-
"Accepts all the normal 'scan' parameters, plus an optional delay."
129+
"""
130+
Accepts all the normal 'scan' parameters, plus an optional delay.
131+
132+
.. deprecated:: v0.8.2
133+
This plan is deprecated and will be removed in Blop v1.0.0. See documentation how-to-guides for more information.
134+
"""
132135

133136
def one_nd_step_with_delay(
134137
detectors: Sequence[Signal], step: Mapping[Movable, Any], pos_cache: Mapping[Movable, Any]
@@ -149,6 +152,11 @@ def default_acquisition_plan(
149152
dofs: Sequence[DOF], inputs: Mapping[str, Sequence[Any]], dets: Sequence[Signal], **kwargs: Any
150153
) -> Generator[Msg, None, str]:
151154
"""
155+
Default acquisition plan.
156+
157+
.. deprecated:: v0.8.2
158+
This plan is deprecated and will be removed in Blop v1.0.0. See documentation how-to-guides for more information.
159+
152160
Parameters
153161
----------
154162
x : list of DOFs or DOFList
@@ -158,6 +166,11 @@ def default_acquisition_plan(
158166
dets: list
159167
A list of detectors to trigger
160168
"""
169+
warnings.warn(
170+
"This plan is deprecated and will be removed in Blop v1.0.0. See documentation how-to-guides for more information.",
171+
DeprecationWarning,
172+
stacklevel=2,
173+
)
161174
delay = kwargs.get("delay", 0)
162175
args = []
163176
for dof in dofs:
@@ -177,6 +190,11 @@ def read(readables: Sequence[Readable], **kwargs: Any) -> MsgGenerator[dict[str,
177190
----------
178191
readables : Sequence[Readable]
179192
The readables to read.
193+
194+
Returns
195+
-------
196+
dict[str, Any]
197+
A dictionary of the readable names and their current values.
180198
"""
181199
results = {}
182200
for readable in readables:
@@ -281,18 +299,16 @@ def acquire_with_background(
281299
per_step_background_read : The per-step plan to take background readings.
282300
"""
283301
per_step = per_step_background_read(block_beam, unblock_beam)
284-
return (yield from acquire(readables, dofs, trials, per_step=per_step, **kwargs))
302+
return (yield from default_acquire(readables, dofs, trials, per_step=per_step, **kwargs))
285303

286304

287305
def acquire_baseline(
288-
generator: Agent,
289-
parameterization: TParameterization | None = None,
290-
arm_name: str | None = None,
291-
acquisition_plan: Callable[[Sequence[Readable], dict[str, DOF], dict[int, TParameterization], Any], MsgGenerator[str]] | None = None,
306+
optimization_problem: OptimizationProblem,
307+
parameterization: dict[str, Any] | None = None,
292308
**kwargs: Any,
293309
) -> MsgGenerator[None]:
294310
"""
295-
Acquire a baseline reading.
311+
Acquire a baseline reading. Useful for relative outcome constraints.
296312
297313
Parameters
298314
----------
@@ -306,10 +322,26 @@ def acquire_baseline(
306322
The per-step plan to execute for each step of the scan.
307323
**kwargs: Any
308324
Additional keyword arguments to pass to the acquire plan.
325+
326+
See Also
327+
--------
328+
default_acquire : The default plan to acquire data.
309329
"""
330+
movables = optimization_problem.movables
310331
if parameterization is None:
311-
parameterization = yield from read([dof.movable for dof in generator.dofs.values()])
312-
trials = generator.attach_baseline(parameters=parameterization, arm_name=arm_name)
313-
uid = yield from acquisition_plan(generator.readables, generator.dofs, trials, **kwargs)
314-
outcomes = generator.evaluation_function(trials, uid, **generator.evaluation_kwargs)
315-
generator.ingest(trials, outcomes)
332+
if all(isinstance(movable, Readable) for movable in movables):
333+
parameterization = yield from read(cast(Sequence[Readable], movables))
334+
else:
335+
raise ValueError(
336+
"All movables must also implement the Readable protocol to acquire a baseline from current positions."
337+
)
338+
generator = optimization_problem.generator
339+
if optimization_problem.acquisition_plan is None:
340+
acquisition_plan = default_acquire
341+
else:
342+
acquisition_plan = optimization_problem.acquisition_plan
343+
movables_and_inputs = {movable: parameterization[movable.name] for movable in movables}
344+
uid = yield from acquisition_plan(movables_and_inputs, optimization_problem.readables, **kwargs)
345+
outcome = optimization_problem.evaluation_function(uid)[0]
346+
outcome.update({"_id": 0})
347+
generator.ingest([outcome])

src/blop/protocols.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
1-
from collections.abc import Sequence
2-
from typing import Any, Mapping, Protocol
1+
from collections.abc import Mapping, Sequence
2+
from dataclasses import dataclass
3+
from typing import Any, Protocol, runtime_checkable
34

45
from bluesky.protocols import NamedMovable, Readable
56
from bluesky.utils import MsgGenerator, plan
67

78

8-
class Agent(Protocol):
9+
@runtime_checkable
10+
class Generator(Protocol):
11+
"""
12+
A minimal generator interface for optimization.
13+
"""
914

1015
def suggest(self, num_points: int | None = None) -> list[dict]:
1116
"""
@@ -38,6 +43,7 @@ def ingest(self, points: list[dict]) -> None:
3843
...
3944

4045

46+
@runtime_checkable
4147
class EvaluationFunction(Protocol):
4248
def __call__(self, uid: str, *args: Any, **kwargs: Any) -> list[dict]:
4349
"""
@@ -56,6 +62,7 @@ def __call__(self, uid: str, *args: Any, **kwargs: Any) -> list[dict]:
5662
...
5763

5864

65+
@runtime_checkable
5966
class AcquisitionPlan(Protocol):
6067
@plan
6168
def __call__(
@@ -67,7 +74,7 @@ def __call__(
6774
) -> MsgGenerator[str]:
6875
"""
6976
Acquire data for optimization.
70-
77+
7178
This should be a Bluesky plan that moves the movables to each of their suggested positions
7279
and acquires data from the readables.
7380
@@ -83,4 +90,31 @@ def __call__(
8390
str
8491
The unique identifier of the Bluesky run.
8592
"""
86-
...
93+
...
94+
95+
96+
@dataclass(frozen=True)
97+
class OptimizationProblem:
98+
"""
99+
An optimization problem to solve. Immutable once initialized.
100+
101+
Attributes
102+
----------
103+
generator: Generator
104+
Suggests points to evaluate and ingests outcomes to inform the optimization.
105+
movables: Sequence[NamedMovable]
106+
Objects that can be moved to control the beamline using the Bluesky RunEngine.
107+
A subset of the movables' names must match the names of suggested parameterizations.
108+
readables: Sequence[Readable]
109+
Objects that can be read to acquire data from the beamline using the Bluesky RunEngine.
110+
evaluation_function: EvaluationFunction
111+
A callable to evaluate data from a Bluesky run and produce outcomes.
112+
acquisition_plan: AcquisitionPlan, optional
113+
A Bluesky plan to acquire data from the beamline. If not provided, a default plan will be used.
114+
"""
115+
116+
generator: Generator
117+
movables: Sequence[NamedMovable]
118+
readables: Sequence[Readable]
119+
evaluation_function: EvaluationFunction
120+
acquisition_plan: AcquisitionPlan | None = None

0 commit comments

Comments
 (0)