Skip to content

Commit c4b4efc

Browse files
committed
update RC rho for PR #476
1 parent 8675a82 commit c4b4efc

File tree

2 files changed

+16
-19
lines changed

2 files changed

+16
-19
lines changed

Diff for: mpisppy/extensions/reduced_costs_fixer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def iter0_post_solver_creation(self):
8585
if self.opt.cylinder_rank == 0 and self.verbose:
8686
print("Fixing based on reduced costs prior to iteration 0!")
8787
if self.reduced_cost_buf.id() == 0:
88-
while not self.opt.spcomm.hub_from_spoke(self.outer_bound_buf, self.reduced_costs_spoke_index, Field.EXPECTED_REDUCED_COST):
88+
while not self.opt.spcomm.get_receive_buffer(self.outer_bound_buf, Field.EXPECTED_REDUCED_COST, self.reduced_costs_spoke_index):
8989
continue
9090
self.sync_with_spokes(pre_iter0 = True)
9191
self.fix_fraction_target = self._fix_fraction_target_iter0

Diff for: mpisppy/extensions/reduced_costs_rho.py

+15-18
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
import numpy as np
1111
from mpisppy import global_toc
1212
from mpisppy.extensions.sensi_rho import _SensiRhoBase
13-
from mpisppy.cylinders.reduced_costs_spoke import ReducedCostsSpoke
13+
14+
from mpisppy.cylinders.spwindow import Field
1415

1516
class ReducedCostsRho(_SensiRhoBase):
1617
"""
@@ -41,24 +42,20 @@ def __init__(self, ph, comm=None):
4142
self._last_serial_number = -1
4243
self.reduced_costs_spoke_index = None
4344

44-
def initialize_spoke_indices(self):
45-
for (i, spoke) in enumerate(self.opt.spcomm.spokes):
46-
if spoke["spoke_class"] == ReducedCostsSpoke:
47-
self.reduced_costs_spoke_index = i + 1
48-
if self.reduced_costs_spoke_index is None:
49-
raise RuntimeError("ReducedCostsRho requires a ReducedCostsSpoke for calculations")
50-
51-
def _get_serial_number(self):
52-
return int(round(self.opt.spcomm.outerbound_receive_buffers[self.reduced_costs_spoke_index][-1]))
45+
def register_receive_fields(self):
46+
spcomm = self.opt.spcomm
47+
reduced_cost_ranks = spcomm.fields_to_ranks[Field.SCENARIO_REDUCED_COST]
48+
assert len(reduced_cost_ranks) == 1
49+
self.reduced_costs_spoke_index = reduced_cost_ranks[0]
5350

54-
def _get_reduced_costs_from_spoke(self):
55-
return self.opt.spcomm.outerbound_receive_buffers[self.reduced_costs_spoke_index][1+self.nonant_length:1+self.nonant_length+len(self._scenario_rc_buffer)]
51+
self.scenario_reduced_cost_buf = spcomm.register_extension_recv_field(
52+
Field.SCENARIO_REDUCED_COST,
53+
self.reduced_costs_spoke_index,
54+
)
5655

5756
def sync_with_spokes(self):
58-
serial_number = self._get_serial_number()
59-
if serial_number > self._last_serial_number:
60-
self._last_serial_number = serial_number
61-
self._scenario_rc_buffer[:] = self._get_reduced_costs_from_spoke()
57+
if self.scenario_reduced_cost_buf.is_new():
58+
self._scenario_rc_buffer[:] = self.scenario_reduced_cost_buf.value_array()
6259
# print(f"In ReducedCostsRho; {self._scenario_rc_buffer=}")
6360
else:
6461
if self.opt.cylinder_rank == 0 and self.verbose:
@@ -89,8 +86,8 @@ def post_iter0_after_sync(self):
8986
global_toc("Using reduced cost rho setter")
9087
self.update_caches()
9188
# wait until the spoke has data
92-
if self._get_serial_number() == 0:
93-
while not self.ph.spcomm.hub_from_spoke(self.opt.spcomm.outerbound_receive_buffers[self.reduced_costs_spoke_index], self.reduced_costs_spoke_index):
89+
if self.scenario_reduced_cost_buf.id() == 0:
90+
while not self.ph.spcomm.get_receive_buffer(self.scenario_reduced_cost_buf, Field.SCENARIO_REDUCED_COST, self.reduced_costs_spoke_index):
9491
continue
9592
self.sync_with_spokes()
9693
self.compute_and_update_rho()

0 commit comments

Comments
 (0)