Skip to content

Commit c90299a

Browse files
committed
Started separating plans from Agent
1 parent 8f034da commit c90299a

File tree

2 files changed

+120
-59
lines changed

2 files changed

+120
-59
lines changed

src/blop/ax/agent.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,11 +310,15 @@ def learn(self, iterations: int = 1, n: int = 1) -> Generator[dict[int, TOutcome
310310
The number of trials to run per iteration. Higher values can lead to more efficient data acquisition,
311311
but slower optimization progress.
312312
313+
.. deprecated:: v0.8.2
314+
Use blop.plans.optimize instead.
315+
313316
Returns
314317
-------
315318
Generator[dict[int, TOutcome], None, None]
316319
A generator that yields the outcomes of the trials.
317320
"""
321+
warnings.warn("'learn' is deprecated. Use blop.plans.optimize instead.", DeprecationWarning, stacklevel=2)
318322
for _ in range(iterations):
319323
trials = self.get_next_trials(n)
320324
data = yield from self.acquire(trials)
@@ -327,6 +331,9 @@ def acquire(
327331
Acquire data given a set of trials. Deploys the trials in a single Bluesky run and
328332
returns the outcomes of the trials computed by the digestion function.
329333
334+
.. deprecated:: v0.8.2
335+
Use blop.plans.acquire instead.
336+
330337
Parameters
331338
----------
332339
trials : dict[int, TParameterization]
@@ -341,6 +348,7 @@ def acquire(
341348
--------
342349
blop.plans.acquire : The Bluesky plan to acquire data.
343350
"""
351+
warnings.warn("'acquire' is deprecated. Use blop.plans.optimize_step instead.", DeprecationWarning, stacklevel=2)
344352
uid = yield from acquire(self.readables, self.dofs, trials, per_step=per_step)
345353
results = self.data_access.get_data(uid)
346354
return {trial_index: self.digestion(trial_index, results, **self.digestion_kwargs) for trial_index in trials.keys()}

src/blop/plans.py

Lines changed: 112 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,118 @@
1212
from ophyd import Signal # type: ignore[import-untyped]
1313

1414
from .dofs import DOF
15+
from .ax.agent import Agent
16+
17+
18+
def _unpack_parameters(dofs: dict[str, DOF], parameterizations: list[TParameterization]) -> list[Movable | TParameterValue]:
19+
"""Unpack the parameterizations into Bluesky plan arguments."""
20+
unpacked_dict = defaultdict(list)
21+
for parameterization in parameterizations:
22+
for dof_name in dofs.keys():
23+
if dof_name in parameterization:
24+
unpacked_dict[dof_name].append(parameterization[dof_name])
25+
else:
26+
raise ValueError(f"Parameter {dof_name} not found in parameterization. Parameterization: {parameterization}")
27+
28+
unpacked_list = []
29+
for dof_name, values in unpacked_dict.items():
30+
unpacked_list.append(dofs[dof_name].movable)
31+
unpacked_list.append(values)
32+
33+
return unpacked_list
34+
35+
36+
@plan
37+
def acquire(
38+
readables: Sequence[Readable],
39+
dofs: dict[str, DOF],
40+
trials: dict[int, TParameterization],
41+
per_step: bp.PerStep | None = None,
42+
**kwargs: Any,
43+
) -> MsgGenerator[str]:
44+
"""
45+
A plan to acquire data for optimization.
46+
47+
Parameters
48+
----------
49+
readables: Sequence[Readable]
50+
The readables to trigger and read.
51+
dofs: dict[str, DOF]
52+
A dictionary mapping DOF names to DOFs.
53+
trials: dict[int, TParameterization]
54+
A dictionary mapping trial indices to their suggested parameterizations. Typically only a single trial is provided.
55+
per_step: bp.PerStep | None = None
56+
The plan to execute for each step of the scan.
57+
**kwargs: Any
58+
Additional keyword arguments to pass to the list_scan plan.
59+
60+
Returns
61+
-------
62+
str
63+
The UID of the Bluesky run.
64+
65+
See Also
66+
--------
67+
bluesky.plans.list_scan : The Bluesky plan to acquire data.
68+
"""
69+
plan_args = _unpack_parameters(dofs, trials.values())
70+
return (
71+
yield from bp.list_scan(
72+
readables, *plan_args, md={"ax_trial_indices": list(trials.keys())}, per_step=per_step, **kwargs
73+
)
74+
)
75+
76+
77+
@plan
78+
def optimize_step(
79+
generator: Agent,
80+
n: int = 1,
81+
acquisition_plan: Callable[[], MsgGenerator[None]] | None = None,
82+
) -> MsgGenerator[None]:
83+
"""
84+
A single step of the optimization loop.
85+
86+
Parameters
87+
----------
88+
generator : Agent
89+
The generator to optimize with.
90+
n : int, optional
91+
The number of trials to suggest.
92+
acquisition_plan : Callable[[], MsgGenerator[None]] | None, optional
93+
The acquisition plan to use to acquire data. If not provided, the default acquisition plan will be used.
94+
"""
95+
if acquisition_plan is None:
96+
acquisition_plan = acquire
97+
trials = generator.suggest(n)
98+
data = yield from acquisition_plan(generator.readables, generator.dofs, trials)
99+
outcomes = generator.evaluate(trials, data)
100+
generator.ingest(outcomes)
101+
102+
103+
@plan
104+
def optimize(
105+
generator: Agent,
106+
iterations: int = 1,
107+
n: int = 1,
108+
acquisition_plan: Callable[[], MsgGenerator[None]] | None = None,
109+
) -> MsgGenerator[None]:
110+
"""
111+
A plan to optimize the generator.
112+
113+
Parameters
114+
----------
115+
generator : Agent
116+
The generator to optimize with.
117+
iterations : int, optional
118+
The number of optimization iterations to run.
119+
n : int, optional
120+
The number of trials to suggest per iteration.
121+
acquisition_plan : Callable[[], MsgGenerator[None]] | None, optional
122+
The acquisition plan to use to acquire data. If not provided, the default acquisition plan will be used.
123+
"""
124+
125+
for _ in range(iterations):
126+
yield from optimize_step(generator, n, acquisition_plan)
15127

16128

17129
@plan
@@ -131,65 +243,6 @@ def per_step_background_read(
131243
return functools.partial(bps.one_nd_step, take_reading=take_reading)
132244

133245

134-
def _unpack_parameters(dofs: dict[str, DOF], parameterizations: list[TParameterization]) -> list[Movable | TParameterValue]:
135-
"""Unpack the parameterizations into Bluesky plan arguments."""
136-
unpacked_dict = defaultdict(list)
137-
for parameterization in parameterizations:
138-
for dof_name in dofs.keys():
139-
if dof_name in parameterization:
140-
unpacked_dict[dof_name].append(parameterization[dof_name])
141-
else:
142-
raise ValueError(f"Parameter {dof_name} not found in parameterization. Parameterization: {parameterization}")
143-
144-
unpacked_list = []
145-
for dof_name, values in unpacked_dict.items():
146-
unpacked_list.append(dofs[dof_name].movable)
147-
unpacked_list.append(values)
148-
149-
return unpacked_list
150-
151-
152-
@plan
153-
def acquire(
154-
readables: Sequence[Readable],
155-
dofs: dict[str, DOF],
156-
trials: dict[int, TParameterization],
157-
per_step: bp.PerStep | None = None,
158-
**kwargs: Any,
159-
) -> MsgGenerator[str]:
160-
"""
161-
A plan to acquire data for optimization.
162-
163-
Parameters
164-
----------
165-
readables: Sequence[Readable]
166-
The readables to trigger and read.
167-
dofs: dict[str, DOF]
168-
A dictionary mapping DOF names to DOFs.
169-
trials: dict[int, TParameterization]
170-
A dictionary mapping trial indices to their suggested parameterizations. Typically only a single trial is provided.
171-
per_step: bp.PerStep | None = None
172-
The plan to execute for each step of the scan.
173-
**kwargs: Any
174-
Additional keyword arguments to pass to the list_scan plan.
175-
176-
Returns
177-
-------
178-
str
179-
The UID of the Bluesky run.
180-
181-
See Also
182-
--------
183-
bluesky.plans.list_scan : The Bluesky plan to acquire data.
184-
"""
185-
plan_args = _unpack_parameters(dofs, trials.values())
186-
return (
187-
yield from bp.list_scan(
188-
readables, *plan_args, md={"ax_trial_indices": list(trials.keys())}, per_step=per_step, **kwargs
189-
)
190-
)
191-
192-
193246
@plan
194247
def acquire_with_background(
195248
readables: Sequence[Readable],

0 commit comments

Comments
 (0)