Skip to content

Commit a94c2dd

Browse files
committed
extensions have to do nearly everything anyways; might as well be explicit about it
1 parent 2769a54 commit a94c2dd

8 files changed

+38
-68
lines changed

mpisppy/cylinders/hub.py

+1-52
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import logging
1212
import mpisppy.log
1313

14-
from mpisppy.cylinders.spcommunicator import RecvArray, SendArray, SPCommunicator
14+
from mpisppy.cylinders.spcommunicator import RecvArray, SPCommunicator
1515
from math import inf
1616

1717
from mpisppy import global_toc
@@ -32,9 +32,6 @@ class Hub(SPCommunicator):
3232
_hub_algo_best_bound_provider = False
3333

3434
def __init__(self, spbase_object, fullcomm, strata_comm, cylinder_comm, communicators, options=None):
35-
# The extensions will be registered in SPCommunicator.__init__
36-
self.extension_recv = set()
37-
3835
super().__init__(spbase_object, fullcomm, strata_comm, cylinder_comm, communicators, options=options)
3936

4037
logger.debug(f"Built the hub object on global rank {fullcomm.Get_rank()}")
@@ -80,51 +77,6 @@ def current_iteration(self):
8077
def main(self):
8178
pass
8279

83-
def register_extension_recv_field(self, field: Field, strata_rank: int, buf_len: int = -1) -> RecvArray:
84-
"""
85-
Register an extensions interest in the given field from the given spoke. The hub
86-
is then responsible for updating this field into a local buffer prior to the call
87-
to the extension sync_with_spokes function.
88-
"""
89-
key = self._make_key(field, strata_rank)
90-
if key not in self.receive_buffers:
91-
# if it is not already registered, we need to update the local buffer
92-
self.extension_recv.add(key)
93-
## End if
94-
ra = self.register_recv_field(field, strata_rank, buf_len)
95-
return ra
96-
97-
def register_extension_send_field(self, field: Field, buf_len: int) -> SendArray:
98-
"""
99-
Register a field with the hub that an extension will be making available to spokes. Returns a
100-
buffer that is usable for sending the desired values. The extension is responsible for calling
101-
the hub publish_extension_field when ready to send the values. Returns a SendArray to use
102-
to publish values to spokes. Meant to be called within the extension function
103-
`register_send_fields`.
104-
"""
105-
return self.register_send_field(field, buf_len)
106-
107-
def is_send_field_registered(self, field: Field) -> bool:
108-
return field in self.send_buffers
109-
110-
def extension_send_field(self, field: Field, buf: SendArray):
111-
"""
112-
Send the data in the SendArray `buf` which stores the Field `field`. This will make
113-
the data available to the spokes in this strata.
114-
"""
115-
return self.put_send_buffer(buf, field)
116-
117-
def sync_extension_fields(self):
118-
"""
119-
Update all registered extension fields. Safe to call even when there are no extension fields.
120-
"""
121-
for key in self.extension_recv:
122-
ext_buf = self.receive_buffers[key]
123-
(field, srank) = self._split_key(key)
124-
ext_buf._is_new = self.get_receive_buffer(ext_buf, field, srank)
125-
## End for
126-
return
127-
12880
def clear_latest_chars(self):
12981
self.latest_ib_char = None
13082
self.latest_ob_char = None
@@ -371,7 +323,6 @@ def sync(self):
371323
self.receive_outerbounds()
372324
self.receive_innerbounds()
373325
if self.opt.extensions is not None:
374-
self.sync_extension_fields()
375326
self.opt.extobject.sync_with_spokes()
376327

377328
def sync_with_spokes(self):
@@ -384,7 +335,6 @@ def sync_bounds(self):
384335

385336
def sync_extensions(self):
386337
if self.opt.extensions is not None:
387-
self.sync_extension_fields()
388338
self.opt.extobject.sync_with_spokes()
389339

390340
def sync_nonants(self):
@@ -485,7 +435,6 @@ def sync(self, send_nonants=True):
485435
self.receive_innerbounds()
486436
# in case LShaped ever gets extensions
487437
if getattr(self.opt, "extensions", None) is not None:
488-
self.sync_extension_fields()
489438
self.opt.extobject.sync_with_spokes()
490439

491440
def is_converged(self):

mpisppy/cylinders/spcommunicator.py

+3
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,9 @@ def _make_windows(self) -> None:
287287

288288
return
289289

290+
def is_send_field_registered(self, field: Field) -> bool:
291+
return field in self.send_buffers
292+
290293
def register_send_fields(self) -> None:
291294
for field in self.send_fields:
292295
self.register_send_field(field)

mpisppy/extensions/coeff_rho.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
# full copyright and license information.
88
###############################################################################
99

10+
from mpisppy import global_toc
1011
import mpisppy.extensions.extension
1112

1213
from mpisppy.utils.sputils import nonant_cost_coeffs
@@ -36,5 +37,4 @@ def post_iter0(self):
3637
# nv = s._mpisppy_data.nonant_indices[ndn_i] # var_data object
3738
# print(ndn_i,nv.getname(),cc[ndn_i],rho._value)
3839

39-
if self.ph.cylinder_rank == 0:
40-
print("Rho values updated by CoeffRho Extension")
40+
global_toc("Rho values updated by CoeffRho Extension", self.ph.cylinder_rank == 0)

mpisppy/extensions/cross_scen_extension.py

+12-7
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,11 @@ def _check_bound(self):
126126
cached_ph_obj[k].activate()
127127

128128
def get_from_cross_cuts(self):
129+
self.opt.spcomm.get_receive_buffer(
130+
self.cuts,
131+
Field.CROSS_SCENARIO_CUT,
132+
self.cross_scenario_index,
133+
)
129134
if self.cuts.is_new():
130135
self.make_cuts(self.cuts.array())
131136

@@ -144,7 +149,7 @@ def send_to_cross_cuts(self):
144149
## End for
145150
## End for
146151

147-
self.opt.spcomm.extension_send_field(Field.NONANT, all_nonants)
152+
self.opt.spcomm.put_send_buffer(all_nonants, Field.NONANT)
148153

149154
## End if
150155

@@ -156,7 +161,7 @@ def send_to_cross_cuts(self):
156161
all_etas[ci] = s._mpisppy_model.eta[sn]._value
157162
ci += 1
158163

159-
self.opt.spcomm.extension_send_field(Field.CROSS_SCENARIO_COST, all_etas)
164+
self.opt.spcomm.put_send_buffer(all_etas, Field.CROSS_SCENARIO_COST)
160165

161166
return
162167

@@ -255,13 +260,13 @@ def register_send_fields(self):
255260
if spcomm.is_send_field_registered(Field.NONANT):
256261
self.send_nonants = False
257262
else:
258-
self.all_nonants = spcomm.register_extension_send_field(
263+
self.all_nonants = spcomm.register_send_field(
259264
Field.NONANT,
260265
local_scen_count * self.opt.nonant_length
261266
)
262267
self.send_nonants = True
263268
## End if-else
264-
self.all_etas = spcomm.register_extension_send_field(
269+
self.all_etas = spcomm.register_send_field(
265270
Field.CROSS_SCENARIO_COST,
266271
nscen * nscen,
267272
)
@@ -282,11 +287,11 @@ def register_receive_fields(self):
282287
spcomm = self.opt.spcomm
283288
cross_scenario_cut_ranks = spcomm.fields_to_ranks[Field.CROSS_SCENARIO_CUT]
284289
assert len(cross_scenario_cut_ranks) == 1
285-
index = cross_scenario_cut_ranks[0]
290+
self.cross_scenario_index = cross_scenario_cut_ranks[0]
286291

287-
self.cuts = spcomm.register_extension_recv_field(
292+
self.cuts = spcomm.register_recv_field(
288293
Field.CROSS_SCENARIO_CUT,
289-
index,
294+
self.cross_scenario_index,
290295
)
291296

292297
def sync_with_spokes(self):

mpisppy/extensions/extension.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,7 @@ def setup_hub(self):
3434
def register_send_fields(self):
3535
'''
3636
Method called by the Hub SPCommunicator to get any fields that the extension
37-
will make available to spokes. Use hub function `register_extension_send_field`
38-
to register a field.
37+
will make available to spokes.
3938
'''
4039
return
4140

mpisppy/extensions/reduced_costs_fixer.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -101,11 +101,11 @@ def register_receive_fields(self):
101101

102102
self.reduced_costs_spoke_index = index
103103

104-
self.reduced_cost_buf = spcomm.register_extension_recv_field(
104+
self.reduced_cost_buf = spcomm.register_recv_field(
105105
Field.EXPECTED_REDUCED_COST,
106106
self.reduced_costs_spoke_index,
107107
)
108-
self.outer_bound_buf = spcomm.register_extension_recv_field(
108+
self.outer_bound_buf = spcomm.register_recv_field(
109109
Field.OBJECTIVE_OUTER_BOUND,
110110
self.reduced_costs_spoke_index,
111111
)
@@ -116,6 +116,16 @@ def register_receive_fields(self):
116116
def sync_with_spokes(self, pre_iter0 = False):
117117
# TODO: if we calculate the new bounds in the spoke we don't need to check if the buffers
118118
# have the same ID
119+
self.opt.spcomm.get_receive_buffer(
120+
self.reduced_cost_buf,
121+
Field.EXPECTED_REDUCED_COST,
122+
self.reduced_costs_spoke_index,
123+
)
124+
self.opt.spcomm.get_receive_buffer(
125+
self.outer_bound_buf,
126+
Field.OBJECTIVE_OUTER_BOUND,
127+
self.reduced_costs_spoke_index,
128+
)
119129
if self.reduced_cost_buf.is_new() and self.reduced_cost_buf.id() == self.outer_bound_buf.id():
120130
reduced_costs = self.reduced_cost_buf.value_array()
121131
this_outer_bound = self.outer_bound_buf.value_array()[0]

mpisppy/extensions/reduced_costs_rho.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,17 @@ def register_receive_fields(self):
4848
assert len(reduced_cost_ranks) == 1
4949
self.reduced_costs_spoke_index = reduced_cost_ranks[0]
5050

51-
self.scenario_reduced_cost_buf = spcomm.register_extension_recv_field(
51+
self.scenario_reduced_cost_buf = spcomm.register_recv_field(
5252
Field.SCENARIO_REDUCED_COST,
5353
self.reduced_costs_spoke_index,
5454
)
5555

5656
def sync_with_spokes(self):
57+
self.opt.spcomm.get_receive_buffer(
58+
self.scenario_reduced_cost_buf,
59+
Field.SCENARIO_REDUCED_COST,
60+
self.reduced_costs_spoke_index,
61+
)
5762
if self.scenario_reduced_cost_buf.is_new():
5863
self._scenario_rc_buffer[:] = self.scenario_reduced_cost_buf.value_array()
5964
# print(f"In ReducedCostsRho; {self._scenario_rc_buffer=}")

mpisppy/extensions/sensi_rho.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,7 @@ def compute_and_update_rho(self):
5454
# nv = s._mpisppy_data.nonant_indices[ndn_i] # var_data object
5555
# print(f"{s.name=}, {nv.name=}, {rho.value=}")
5656

57-
if ph.cylinder_rank == 0:
58-
print(f"Rho values updated by {self.__class__.__name__} Extension")
57+
global_toc(f"Rho values updated by {self.__class__.__name__} Extension", ph.cylinder_rank == 0)
5958

6059
def miditer(self):
6160
self.update_caches()

0 commit comments

Comments
 (0)