|
10 | 10 | import numpy as np
|
11 | 11 | from mpisppy import global_toc
|
12 | 12 | from mpisppy.extensions.sensi_rho import _SensiRhoBase
|
13 |
| -from mpisppy.cylinders.reduced_costs_spoke import ReducedCostsSpoke |
| 13 | + |
| 14 | +from mpisppy.cylinders.spwindow import Field |
14 | 15 |
|
15 | 16 | class ReducedCostsRho(_SensiRhoBase):
|
16 | 17 | """
|
@@ -41,24 +42,20 @@ def __init__(self, ph, comm=None):
|
41 | 42 | self._last_serial_number = -1
|
42 | 43 | self.reduced_costs_spoke_index = None
|
43 | 44 |
|
44 |
| - def initialize_spoke_indices(self): |
45 |
| - for (i, spoke) in enumerate(self.opt.spcomm.spokes): |
46 |
| - if spoke["spoke_class"] == ReducedCostsSpoke: |
47 |
| - self.reduced_costs_spoke_index = i + 1 |
48 |
| - if self.reduced_costs_spoke_index is None: |
49 |
| - raise RuntimeError("ReducedCostsRho requires a ReducedCostsSpoke for calculations") |
50 |
| - |
51 |
| - def _get_serial_number(self): |
52 |
| - return int(round(self.opt.spcomm.outerbound_receive_buffers[self.reduced_costs_spoke_index][-1])) |
| 45 | + def register_receive_fields(self): |
| 46 | + spcomm = self.opt.spcomm |
| 47 | + reduced_cost_ranks = spcomm.fields_to_ranks[Field.SCENARIO_REDUCED_COST] |
| 48 | + assert len(reduced_cost_ranks) == 1 |
| 49 | + self.reduced_costs_spoke_index = reduced_cost_ranks[0] |
53 | 50 |
|
54 |
| - def _get_reduced_costs_from_spoke(self): |
55 |
| - return self.opt.spcomm.outerbound_receive_buffers[self.reduced_costs_spoke_index][1+self.nonant_length:1+self.nonant_length+len(self._scenario_rc_buffer)] |
| 51 | + self.scenario_reduced_cost_buf = spcomm.register_extension_recv_field( |
| 52 | + Field.SCENARIO_REDUCED_COST, |
| 53 | + self.reduced_costs_spoke_index, |
| 54 | + ) |
56 | 55 |
|
57 | 56 | def sync_with_spokes(self):
|
58 |
| - serial_number = self._get_serial_number() |
59 |
| - if serial_number > self._last_serial_number: |
60 |
| - self._last_serial_number = serial_number |
61 |
| - self._scenario_rc_buffer[:] = self._get_reduced_costs_from_spoke() |
| 57 | + if self.scenario_reduced_cost_buf.is_new(): |
| 58 | + self._scenario_rc_buffer[:] = self.scenario_reduced_cost_buf.value_array() |
62 | 59 | # print(f"In ReducedCostsRho; {self._scenario_rc_buffer=}")
|
63 | 60 | else:
|
64 | 61 | if self.opt.cylinder_rank == 0 and self.verbose:
|
@@ -89,8 +86,8 @@ def post_iter0_after_sync(self):
|
89 | 86 | global_toc("Using reduced cost rho setter")
|
90 | 87 | self.update_caches()
|
91 | 88 | # wait until the spoke has data
|
92 |
| - if self._get_serial_number() == 0: |
93 |
| - while not self.ph.spcomm.hub_from_spoke(self.opt.spcomm.outerbound_receive_buffers[self.reduced_costs_spoke_index], self.reduced_costs_spoke_index): |
| 89 | + if self.scenario_reduced_cost_buf.id() == 0: |
| 90 | + while not self.ph.spcomm.get_receive_buffer(self.scenario_reduced_cost_buf, Field.SCENARIO_REDUCED_COST, self.reduced_costs_spoke_index): |
94 | 91 | continue
|
95 | 92 | self.sync_with_spokes()
|
96 | 93 | self.compute_and_update_rho()
|
0 commit comments