|
12 | 12 | from ophyd import Signal # type: ignore[import-untyped] |
13 | 13 |
|
14 | 14 | from .dofs import DOF |
| 15 | +from .ax.agent import Agent |
| 16 | + |
| 17 | + |
| 18 | +def _unpack_parameters(dofs: dict[str, DOF], parameterizations: list[TParameterization]) -> list[Movable | TParameterValue]: |
| 19 | + """Unpack the parameterizations into Bluesky plan arguments.""" |
| 20 | + unpacked_dict = defaultdict(list) |
| 21 | + for parameterization in parameterizations: |
| 22 | + for dof_name in dofs.keys(): |
| 23 | + if dof_name in parameterization: |
| 24 | + unpacked_dict[dof_name].append(parameterization[dof_name]) |
| 25 | + else: |
| 26 | + raise ValueError(f"Parameter {dof_name} not found in parameterization. Parameterization: {parameterization}") |
| 27 | + |
| 28 | + unpacked_list = [] |
| 29 | + for dof_name, values in unpacked_dict.items(): |
| 30 | + unpacked_list.append(dofs[dof_name].movable) |
| 31 | + unpacked_list.append(values) |
| 32 | + |
| 33 | + return unpacked_list |
| 34 | + |
| 35 | + |
| 36 | +@plan |
| 37 | +def acquire( |
| 38 | + readables: Sequence[Readable], |
| 39 | + dofs: dict[str, DOF], |
| 40 | + trials: dict[int, TParameterization], |
| 41 | + per_step: bp.PerStep | None = None, |
| 42 | + **kwargs: Any, |
| 43 | +) -> MsgGenerator[str]: |
| 44 | + """ |
| 45 | + A plan to acquire data for optimization. |
| 46 | +
|
| 47 | + Parameters |
| 48 | + ---------- |
| 49 | + readables: Sequence[Readable] |
| 50 | + The readables to trigger and read. |
| 51 | + dofs: dict[str, DOF] |
| 52 | + A dictionary mapping DOF names to DOFs. |
| 53 | + trials: dict[int, TParameterization] |
| 54 | + A dictionary mapping trial indices to their suggested parameterizations. Typically only a single trial is provided. |
| 55 | + per_step: bp.PerStep | None = None |
| 56 | + The plan to execute for each step of the scan. |
| 57 | + **kwargs: Any |
| 58 | + Additional keyword arguments to pass to the list_scan plan. |
| 59 | +
|
| 60 | + Returns |
| 61 | + ------- |
| 62 | + str |
| 63 | + The UID of the Bluesky run. |
| 64 | +
|
| 65 | + See Also |
| 66 | + -------- |
| 67 | + bluesky.plans.list_scan : The Bluesky plan to acquire data. |
| 68 | + """ |
| 69 | + plan_args = _unpack_parameters(dofs, trials.values()) |
| 70 | + return ( |
| 71 | + yield from bp.list_scan( |
| 72 | + readables, *plan_args, md={"ax_trial_indices": list(trials.keys())}, per_step=per_step, **kwargs |
| 73 | + ) |
| 74 | + ) |
| 75 | + |
| 76 | + |
| 77 | +@plan |
| 78 | +def optimize_step( |
| 79 | + generator: Agent, |
| 80 | + n: int = 1, |
| 81 | + acquisition_plan: Callable[[], MsgGenerator[None]] | None = None, |
| 82 | +) -> MsgGenerator[None]: |
| 83 | + """ |
| 84 | + A single step of the optimization loop. |
| 85 | +
|
| 86 | + Parameters |
| 87 | + ---------- |
| 88 | + generator : Agent |
| 89 | + The generator to optimize with. |
| 90 | + n : int, optional |
| 91 | + The number of trials to suggest. |
| 92 | + acquisition_plan : Callable[[], MsgGenerator[None]] | None, optional |
| 93 | + The acquisition plan to use to acquire data. If not provided, the default acquisition plan will be used. |
| 94 | + """ |
| 95 | + if acquisition_plan is None: |
| 96 | + acquisition_plan = acquire |
| 97 | + trials = generator.suggest(n) |
| 98 | + data = yield from acquisition_plan(generator.readables, generator.dofs, trials) |
| 99 | + outcomes = generator.evaluate(trials, data) |
| 100 | + generator.ingest(outcomes) |
| 101 | + |
| 102 | + |
| 103 | +@plan |
| 104 | +def optimize( |
| 105 | + generator: Agent, |
| 106 | + iterations: int = 1, |
| 107 | + n: int = 1, |
| 108 | + acquisition_plan: Callable[[], MsgGenerator[None]] | None = None, |
| 109 | +) -> MsgGenerator[None]: |
| 110 | + """ |
| 111 | + A plan to optimize the generator. |
| 112 | +
|
| 113 | + Parameters |
| 114 | + ---------- |
| 115 | + generator : Agent |
| 116 | + The generator to optimize with. |
| 117 | + iterations : int, optional |
| 118 | + The number of optimization iterations to run. |
| 119 | + n : int, optional |
| 120 | + The number of trials to suggest per iteration. |
| 121 | + acquisition_plan : Callable[[], MsgGenerator[None]] | None, optional |
| 122 | + The acquisition plan to use to acquire data. If not provided, the default acquisition plan will be used. |
| 123 | + """ |
| 124 | + |
| 125 | + for _ in range(iterations): |
| 126 | + yield from optimize_step(generator, n, acquisition_plan) |
15 | 127 |
|
16 | 128 |
|
17 | 129 | @plan |
@@ -131,65 +243,6 @@ def per_step_background_read( |
131 | 243 | return functools.partial(bps.one_nd_step, take_reading=take_reading) |
132 | 244 |
|
133 | 245 |
|
134 | | -def _unpack_parameters(dofs: dict[str, DOF], parameterizations: list[TParameterization]) -> list[Movable | TParameterValue]: |
135 | | - """Unpack the parameterizations into Bluesky plan arguments.""" |
136 | | - unpacked_dict = defaultdict(list) |
137 | | - for parameterization in parameterizations: |
138 | | - for dof_name in dofs.keys(): |
139 | | - if dof_name in parameterization: |
140 | | - unpacked_dict[dof_name].append(parameterization[dof_name]) |
141 | | - else: |
142 | | - raise ValueError(f"Parameter {dof_name} not found in parameterization. Parameterization: {parameterization}") |
143 | | - |
144 | | - unpacked_list = [] |
145 | | - for dof_name, values in unpacked_dict.items(): |
146 | | - unpacked_list.append(dofs[dof_name].movable) |
147 | | - unpacked_list.append(values) |
148 | | - |
149 | | - return unpacked_list |
150 | | - |
151 | | - |
152 | | -@plan |
153 | | -def acquire( |
154 | | - readables: Sequence[Readable], |
155 | | - dofs: dict[str, DOF], |
156 | | - trials: dict[int, TParameterization], |
157 | | - per_step: bp.PerStep | None = None, |
158 | | - **kwargs: Any, |
159 | | -) -> MsgGenerator[str]: |
160 | | - """ |
161 | | - A plan to acquire data for optimization. |
162 | | -
|
163 | | - Parameters |
164 | | - ---------- |
165 | | - readables: Sequence[Readable] |
166 | | - The readables to trigger and read. |
167 | | - dofs: dict[str, DOF] |
168 | | - A dictionary mapping DOF names to DOFs. |
169 | | - trials: dict[int, TParameterization] |
170 | | - A dictionary mapping trial indices to their suggested parameterizations. Typically only a single trial is provided. |
171 | | - per_step: bp.PerStep | None = None |
172 | | - The plan to execute for each step of the scan. |
173 | | - **kwargs: Any |
174 | | - Additional keyword arguments to pass to the list_scan plan. |
175 | | -
|
176 | | - Returns |
177 | | - ------- |
178 | | - str |
179 | | - The UID of the Bluesky run. |
180 | | -
|
181 | | - See Also |
182 | | - -------- |
183 | | - bluesky.plans.list_scan : The Bluesky plan to acquire data. |
184 | | - """ |
185 | | - plan_args = _unpack_parameters(dofs, trials.values()) |
186 | | - return ( |
187 | | - yield from bp.list_scan( |
188 | | - readables, *plan_args, md={"ax_trial_indices": list(trials.keys())}, per_step=per_step, **kwargs |
189 | | - ) |
190 | | - ) |
191 | | - |
192 | | - |
193 | 246 | @plan |
194 | 247 | def acquire_with_background( |
195 | 248 | readables: Sequence[Readable], |
|
0 commit comments