Skip to content

Commit ad1f6b9

Browse files
committed
automatically register send fields based on class attributes
1 parent da2e00b commit ad1f6b9

7 files changed

+114
-101
lines changed

mpisppy/cylinders/cross_scen_spoke.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,16 @@
1111
from mpisppy import MPI
1212
from mpisppy.utils.lshaped_cuts import LShapedCutGenerator
1313
from mpisppy.cylinders.spwindow import Field
14+
from mpisppy.cylinders.spoke import Spoke
1415

1516
import numpy as np
1617
import pyomo.environ as pyo
17-
import mpisppy.cylinders.spoke as spoke
1818

19-
class CrossScenarioCutSpoke(spoke.Spoke):
19+
class CrossScenarioCutSpoke(Spoke):
20+
21+
send_fields = (*Spoke.send_fields, Field.CROSS_SCENARIO_CUT)
22+
receive_fields = (*Spoke.receive_fields, Field.NONANT, Field.CROSS_SCENARIO_COST)
23+
optional_receive_fields = (*Spoke.optional_receive_fields, )
2024

2125
def register_send_fields(self) -> None:
2226

@@ -35,15 +39,13 @@ def register_send_fields(self) -> None:
3539
(self.nonant_per_scen, remainder) = divmod(vbuflen, local_scen_count)
3640
assert(remainder == 0)
3741

38-
## the _locals will also have the kill signal
3942
self.all_nonant_len = vbuflen
4043
self.all_eta_len = nscen*local_scen_count
4144

4245
self.all_nonants = self.register_recv_field(Field.NONANT, 0, vbuflen)
4346
self.all_etas = self.register_recv_field(Field.CROSS_SCENARIO_COST, 0, nscen * nscen)
4447

45-
self.all_coefs = self.register_send_field(Field.CROSS_SCENARIO_CUT,
46-
nscen*(self.nonant_per_scen + 1 + 1))
48+
self.all_coefs = self.send_buffers[Field.CROSS_SCENARIO_CUT]
4749

4850
return
4951

@@ -301,7 +303,6 @@ def main(self):
301303

302304
# main loop
303305
while not (self.got_kill_signal()):
304-
# if self._new_locals:
305306
if self.all_nonants.is_new() and self.all_etas.is_new():
306307
self.make_cut()
307308
## End if

mpisppy/cylinders/hub.py

+32-56
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@
3030

3131
class Hub(SPCommunicator):
3232

33+
send_fields = (*SPCommunicator.send_fields, Field.SHUTDOWN, Field.BEST_OBJECTIVE_BOUNDS,)
34+
receive_fields = (*SPCommunicator.receive_fields, )
35+
optional_receive_fields = (*SPCommunicator.optional_receive_fields, Field.OBJECTIVE_INNER_BOUND, Field.OBJECTIVE_OUTER_BOUND, )
36+
3337
_hub_algo_best_bound_provider = False
3438

3539
def __init__(self, spbase_object, fullcomm, strata_comm, cylinder_comm, communicators, options=None):
@@ -85,7 +89,7 @@ def register_extension_recv_field(self, field: Field, strata_rank: int, buf_len:
8589
to the extension sync_with_spokes function.
8690
"""
8791
key = self._make_key(field, strata_rank)
88-
if key not in self._locals:
92+
if key not in self.receive_buffers:
8993
# if it is not already registered, we need to update the local buffer
9094
self.extension_recv.add(key)
9195
## End if
@@ -103,7 +107,7 @@ def register_extension_send_field(self, field: Field, buf_len: int) -> SendArray
103107
return self.register_send_field(field, buf_len)
104108

105109
def is_send_field_registered(self, field: Field) -> bool:
106-
return field in self._sends
110+
return field in self.send_buffers
107111

108112
def extension_send_field(self, field: Field, buf: SendArray):
109113
"""
@@ -117,7 +121,7 @@ def sync_extension_fields(self):
117121
Update all registered extension fields. Safe to call even when there are no extension fields.
118122
"""
119123
for key in self.extension_recv:
120-
ext_buf = self._locals[key]
124+
ext_buf = self.receive_buffers[key]
121125
(field, srank) = self._split_key(key)
122126
ext_buf._is_new = self.hub_from_spoke(ext_buf, srank, field)
123127
## End for
@@ -233,7 +237,7 @@ def receive_innerbounds(self):
233237
logging.debug("Hub is trying to receive from InnerBounds")
234238
for idx in self.innerbound_spoke_indices:
235239
key = self._make_key(Field.OBJECTIVE_INNER_BOUND, idx)
236-
recv_buf = self._locals[key]
240+
recv_buf = self.receive_buffers[key]
237241
is_new = self.hub_from_spoke(recv_buf, idx, Field.OBJECTIVE_INNER_BOUND)
238242
if is_new:
239243
bound = recv_buf[0]
@@ -249,7 +253,7 @@ def receive_outerbounds(self):
249253
logging.debug("Hub is trying to receive from OuterBounds")
250254
for idx in self.outerbound_spoke_indices:
251255
key = self._make_key(Field.OBJECTIVE_OUTER_BOUND, idx)
252-
recv_buf = self._locals[key]
256+
recv_buf = self.receive_buffers[key]
253257
is_new = self.hub_from_spoke(recv_buf, idx, Field.OBJECTIVE_OUTER_BOUND)
254258
if is_new:
255259
bound = recv_buf[0]
@@ -320,18 +324,18 @@ def initialize_inner_bound_buffers(self):
320324
def _populate_boundsout_cache(self, buf):
321325
""" Populate a given buffer with the current bounds
322326
"""
323-
buf[-3] = self.BestOuterBound
324-
buf[-2] = self.BestInnerBound
327+
buf[0] = self.BestOuterBound
328+
buf[1] = self.BestInnerBound
325329

326330
def send_boundsout(self):
327331
""" Send bounds to the appropriate spokes
328332
This is called only for spokes which are bounds only.
329333
w and nonant spokes are passed bounds through the w and nonant buffers
330334
"""
331-
my_bounds = self.boundsout_send_buffer
335+
my_bounds = self.send_buffers[Field.BEST_OBJECTIVE_BOUNDS]
332336
self._populate_boundsout_cache(my_bounds.array())
333337
logging.debug("hub is sending bounds={}".format(my_bounds))
334-
self.hub_to_spoke(my_bounds, Field.OBJECTIVE_BOUNDS)
338+
self.hub_to_spoke(my_bounds, Field.BEST_OBJECTIVE_BOUNDS)
335339
return
336340

337341
def initialize_spoke_indices(self):
@@ -392,45 +396,7 @@ def initialize_spoke_indices(self):
392396

393397

394398
def register_send_fields(self):
395-
396-
self.shutdown = self.register_send_field(Field.SHUTDOWN, 1)
397-
398-
required_fields = set()
399-
for i, spoke in enumerate(self.communicators):
400-
if i == self.strata_rank:
401-
continue
402-
spoke_class = spoke["spcomm_class"]
403-
if hasattr(spoke_class, "converger_spoke_types"):
404-
for cst in spoke_class.converger_spoke_types:
405-
if cst == ConvergerSpokeType.W_GETTER:
406-
required_fields.add(Field.DUALS)
407-
elif cst == ConvergerSpokeType.NONANT_GETTER:
408-
required_fields.add(Field.NONANT)
409-
elif cst == ConvergerSpokeType.INNER_BOUND or cst == ConvergerSpokeType.OUTER_BOUND:
410-
required_fields.add(Field.OBJECTIVE_BOUNDS)
411-
else:
412-
pass # Intentional no-op
413-
## End if
414-
## End for
415-
else:
416-
# Intentional no-op. Non-converger spokes need to register any needed
417-
# fields separately. See the functions `register_extension_recv_field`
418-
# and `register_extension_send_field`.
419-
pass
420-
## End if
421-
## End for
422-
423-
n_nonants = 0
424-
for s in self.opt.local_scenarios.values():
425-
n_nonants += len(s._mpisppy_data.nonant_indices)
426-
## End for
427-
428-
if Field.DUALS in required_fields:
429-
self.w_send_buffer = self.register_send_field(Field.DUALS, n_nonants)
430-
if Field.NONANT in required_fields:
431-
self.nonant_send_buffer = self.register_send_field(Field.NONANT, n_nonants)
432-
if Field.OBJECTIVE_BOUNDS in required_fields:
433-
self.boundsout_send_buffer = self.register_send_field(Field.OBJECTIVE_BOUNDS, 2)
399+
super().register_send_fields()
434400

435401
# Not all opt classes may have extensions
436402
if getattr(self.opt, "extensions", None) is not None:
@@ -439,7 +405,6 @@ def register_send_fields(self):
439405
return
440406

441407

442-
443408
def hub_to_spoke(self, buf: SendArray, field: Field):
444409
""" Put the specified values into the specified locally-owned buffer
445410
for the spoke to pick up.
@@ -534,13 +499,17 @@ def send_terminate(self):
534499
buffer, so every spoke will see it simultaneously.
535500
processes (don't need to call them one at a time).
536501
"""
537-
shutdown = self.shutdown
538-
shutdown[0] = 1.0
539-
self.hub_to_spoke(shutdown, Field.SHUTDOWN)
502+
self.send_buffers[Field.SHUTDOWN][0] = 1.0
503+
self.hub_to_spoke(self.send_buffers[Field.SHUTDOWN], Field.SHUTDOWN)
540504
return
541505

542506

543507
class PHHub(Hub):
508+
509+
send_fields = (*Hub.send_fields, Field.NONANT, Field.DUALS)
510+
receive_fields = (*Hub.receive_fields,)
511+
optional_receive_fields = (*Hub.optional_receive_fields,)
512+
544513
def setup_hub(self):
545514
""" Must be called after make_windows(), so that
546515
the hub knows the sizes of all the spokes windows
@@ -673,8 +642,7 @@ def send_nonants(self):
673642
"""
674643
self.opt._save_nonants()
675644
ci = 0 ## index to self.nonant_send_buffer
676-
# my_nonants = self._sends[Field.NONANT]
677-
nonant_send_buffer = self.nonant_send_buffer
645+
nonant_send_buffer = self.send_buffers[Field.NONANT]
678646
for k, s in self.opt.local_scenarios.items():
679647
for xvar in s._mpisppy_data.nonant_indices.values():
680648
nonant_send_buffer[ci] = xvar._value
@@ -690,7 +658,7 @@ def send_ws(self):
690658
""" Send dual weights to the appropriate spokes
691659
"""
692660
# NOTE: my_ws.array() and self.w_send_buffer should be the same array.
693-
my_ws = self._sends[Field.DUALS]
661+
my_ws = self.send_buffers[Field.DUALS]
694662
self.opt._populate_W_cache(my_ws.array(), padding=1)
695663
logging.debug("hub is sending Ws={}".format(my_ws.array()))
696664

@@ -701,6 +669,10 @@ def send_ws(self):
701669

702670
class LShapedHub(Hub):
703671

672+
send_fields = (*Hub.send_fields, Field.NONANT,)
673+
receive_fields = (*Hub.receive_fields,)
674+
optional_receive_fields = (*Hub.optional_receive_fields,)
675+
704676
def setup_hub(self):
705677
""" Must be called after make_windows(), so that
706678
the hub knows the sizes of all the spokes windows
@@ -781,7 +753,7 @@ def send_nonants(self):
781753
TODO: Will likely fail with bundling
782754
"""
783755
ci = 0 ## index to self.nonant_send_buffer
784-
nonant_send_buffer = self.nonant_send_buffer
756+
nonant_send_buffer = self.send_buffers[Field.NONANT]
785757
for k, s in self.opt.local_scenarios.items():
786758
nonant_to_root_var_map = s._mpisppy_model.subproblem_to_root_vars_map
787759
for xvar in s._mpisppy_data.nonant_indices.values():
@@ -797,6 +769,8 @@ def send_nonants(self):
797769

798770
class SubgradientHub(PHHub):
799771

772+
# send / receive fields are same as PHHub
773+
800774
_hub_algo_best_bound_provider = True
801775

802776
def main(self):
@@ -806,6 +780,8 @@ def main(self):
806780

807781
class APHHub(PHHub):
808782

783+
# send / receive fields are same as PHHub
784+
809785
def main(self):
810786
""" SPComm gets attached by self.__init___; holding APH harmless """
811787
logger.critical("aph debug main in hub.py")

mpisppy/cylinders/reduced_costs_spoke.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515

1616
class ReducedCostsSpoke(LagrangianOuterBound):
1717

18+
send_fields = (*LagrangianOuterBound.send_fields, Field.EXPECTED_REDUCED_COST, Field.SCENARIO_REDUCED_COST ,)
19+
receive_fields = (*LagrangianOuterBound.receive_fields,)
20+
optional_receive_fields = (*LagrangianOuterBound.optional_receive_fields,)
21+
1822
converger_spoke_char = 'R'
1923

2024
def __init__(self, *args, **kwargs):
@@ -54,28 +58,25 @@ def register_send_fields(self) -> None:
5458
scenario_buffer_len += len(s._mpisppy_data.nonant_indices)
5559
self._scenario_rc_buffer = np.zeros(scenario_buffer_len)
5660

57-
self.register_send_field(Field.EXPECTED_REDUCED_COST, self.nonant_length)
58-
self.register_send_field(Field.SCENARIO_REDUCED_COST, scenario_buffer_len)
59-
6061
return
6162

6263
@property
6364
def rc_global(self):
64-
return self._sends[Field.EXPECTED_REDUCED_COST].value_array()
65+
return self.send_buffers[Field.EXPECTED_REDUCED_COST].value_array()
6566

6667
@rc_global.setter
6768
def rc_global(self, vals):
68-
arr = self._sends[Field.EXPECTED_REDUCED_COST].value_array()
69+
arr = self.send_buffers[Field.EXPECTED_REDUCED_COST].value_array()
6970
arr[:] = vals
7071
return
7172

7273
@property
7374
def rc_scenario(self):
74-
return self._sends[Field.SCENARIO_REDUCED_COST].value_array()
75+
return self.send_buffers[Field.SCENARIO_REDUCED_COST].value_array()
7576

7677
@rc_scenario.setter
7778
def rc_scenario(self, vals):
78-
arr = self._sends[Field.SCENARIO_REDUCED_COST].value_array()
79+
arr = self.send_buffers[Field.SCENARIO_REDUCED_COST].value_array()
7980
arr[:] = vals
8081
return
8182

mpisppy/cylinders/spcommunicator.py

+18-18
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import abc
2424
import time
2525

26-
from mpisppy.cylinders.spwindow import Field, SPWindow
26+
from mpisppy.cylinders.spwindow import Field, FieldLengths, SPWindow
2727

2828
def communicator_array(size):
2929
arr = np.empty(size+1, dtype='d')
@@ -138,26 +138,26 @@ def __init__(self, spbase_object, fullcomm, strata_comm, cylinder_comm, communic
138138
self.options = options
139139

140140
# Common fields for spokes and hubs
141-
self._locals = dict()
142-
self._sends = dict()
141+
self.receive_buffers = dict()
142+
self.send_buffers = dict()
143143

144144
# setup FieldLengths which calculates
145145
# the length of each buffer type based
146146
# on the problem data
147-
# self._field_lengths = FieldLengths(self.opt)
147+
self._field_lengths = FieldLengths(self.opt)
148148

149149
# attach the SPCommunicator to
150150
# the SPBase object
151151
self.opt.spcomm = self
152152

153-
# self.register_send_fields()
153+
self.register_send_fields()
154154

155155
return
156156

157157
def _make_key(self, field: Field, origin: int):
158158
"""
159159
Given a field and an origin (i.e. a strata_rank), generate a key for indexing
160-
into the self._locals dictionary and getting the corresponding RecvArray.
160+
into the self.receive_buffers dictionary and getting the corresponding RecvArray.
161161
162162
Undone by `_split_key`. Currently, the key is simply a Tuple[field, origin].
163163
"""
@@ -175,9 +175,8 @@ def _split_key(self, key) -> tuple[Field, int]:
175175
def _build_window_spec(self) -> dict[Field, int]:
176176
""" Build dict with fields and lengths needed for local MPI window
177177
"""
178-
self.register_send_fields()
179178
window_spec = dict()
180-
for (field,buf) in self._sends.items():
179+
for (field,buf) in self.send_buffers.items():
181180
window_spec[field] = np.size(buf.array())
182181
## End for
183182
return window_spec
@@ -186,28 +185,28 @@ def register_recv_field(self, field: Field, origin: int, length: int = -1) -> Re
186185
key = self._make_key(field, origin)
187186
if length == -1:
188187
length = self._field_lengths[field]
189-
if key in self._locals:
190-
my_fa = self._locals[key]
188+
if key in self.receive_buffers:
189+
my_fa = self.receive_buffers[key]
191190
assert(length + 1 == np.size(my_fa.array()))
192191
else:
193192
my_fa = RecvArray(length)
194-
self._locals[key] = my_fa
193+
self.receive_buffers[key] = my_fa
195194
## End if
196195
return my_fa
197196

198197
def register_send_field(self, field: Field, length: int = -1) -> SendArray:
199-
assert field not in self._sends, "Field {} is already registered".format(field)
198+
assert field not in self.send_buffers, "Field {} is already registered".format(field)
200199
if length == -1:
201200
length = self._field_lengths[field]
202-
# if field in self._sends:
203-
# my_fa = self._sends[field]
201+
# if field in self.send_buffers:
202+
# my_fa = self.send_buffers[field]
204203
# assert(length + 1 == np.size(my_fa.array()))
205204
# else:
206205
# my_fa = SendArray(length)
207-
# self._sends[field] = my_fa
206+
# self.send_buffers[field] = my_fa
208207
# ## End if else
209208
my_fa = SendArray(length)
210-
self._sends[field] = my_fa
209+
self.send_buffers[field] = my_fa
211210
return my_fa
212211

213212
@abc.abstractmethod
@@ -259,6 +258,7 @@ def make_windows(self) -> None:
259258

260259
return
261260

262-
@abc.abstractmethod
263261
def register_send_fields(self) -> None:
264-
pass
262+
self.send_buffers = {}
263+
for field in self.send_fields:
264+
self.send_buffers[field] = self.register_send_field(field)

0 commit comments

Comments
 (0)