Skip to content

Commit d512137

Browse files
committed
automatically register send fields based on class attributes
1 parent c70d11f commit d512137

7 files changed

+116
-101
lines changed

mpisppy/cylinders/cross_scen_spoke.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,18 @@
1111
from mpisppy import MPI
1212
from mpisppy.utils.lshaped_cuts import LShapedCutGenerator
1313
from mpisppy.cylinders.spwindow import Field
14+
from mpisppy.cylinders.spoke import Spoke
1415

1516
import numpy as np
1617
import pyomo.environ as pyo
17-
import mpisppy.cylinders.spoke as spoke
1818

19-
class CrossScenarioCutSpoke(spoke.Spoke):
19+
class CrossScenarioCutSpoke(Spoke):
20+
21+
22+
send_fields = (*Spoke.send_fields, Field.CROSS_SCENARIO_CUT)
23+
receive_fields = (*Spoke.receive_fields, Field.NONANT, Field.CROSS_SCENARIO_COST)
24+
optional_receive_fields = (*Spoke.optional_receive_fields, )
25+
2026
def __init__(self, spbase_object, fullcomm, strata_comm, cylinder_comm, options=None):
2127
super().__init__(spbase_object, fullcomm, strata_comm, cylinder_comm, options=options)
2228

@@ -37,15 +43,13 @@ def register_send_fields(self) -> None:
3743
(self.nonant_per_scen, remainder) = divmod(vbuflen, local_scen_count)
3844
assert(remainder == 0)
3945

40-
## the _locals will also have the kill signal
4146
self.all_nonant_len = vbuflen
4247
self.all_eta_len = nscen*local_scen_count
4348

4449
self.all_nonants = self.register_recv_field(Field.NONANT, 0, vbuflen)
4550
self.all_etas = self.register_recv_field(Field.CROSS_SCENARIO_COST, 0, nscen * nscen)
4651

47-
self.all_coefs = self.register_send_field(Field.CROSS_SCENARIO_CUT,
48-
nscen*(self.nonant_per_scen + 1 + 1))
52+
self.all_coefs = self.send_buffers[Field.CROSS_SCENARIO_CUT]
4953

5054
return
5155

@@ -303,7 +307,6 @@ def main(self):
303307

304308
# main loop
305309
while not (self.got_kill_signal()):
306-
# if self._new_locals:
307310
if self.all_nonants.is_new() and self.all_etas.is_new():
308311
self.make_cut()
309312
## End if

mpisppy/cylinders/hub.py

+32-56
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@
3030

3131
class Hub(SPCommunicator):
3232

33+
send_fields = (*SPCommunicator.send_fields, Field.SHUTDOWN, Field.BEST_OBJECTIVE_BOUNDS,)
34+
receive_fields = (*SPCommunicator.receive_fields, )
35+
optional_receive_fields = (*SPCommunicator.optional_receive_fields, Field.OBJECTIVE_INNER_BOUND, Field.OBJECTIVE_OUTER_BOUND, )
36+
3337
_hub_algo_best_bound_provider = False
3438

3539
def __init__(self, spbase_object, fullcomm, strata_comm, cylinder_comm, communicators, options=None):
@@ -85,7 +89,7 @@ def register_extension_recv_field(self, field: Field, strata_rank: int, buf_len:
8589
to the extension sync_with_spokes function.
8690
"""
8791
key = self._make_key(field, strata_rank)
88-
if key not in self._locals:
92+
if key not in self.receive_buffers:
8993
# if it is not already registered, we need to update the local buffer
9094
self.extension_recv.add(key)
9195
## End if
@@ -103,7 +107,7 @@ def register_extension_send_field(self, field: Field, buf_len: int) -> SendArray
103107
return self.register_send_field(field, buf_len)
104108

105109
def is_send_field_registered(self, field: Field) -> bool:
106-
return field in self._sends
110+
return field in self.send_buffers
107111

108112
def extension_send_field(self, field: Field, buf: SendArray):
109113
"""
@@ -117,7 +121,7 @@ def sync_extension_fields(self):
117121
Update all registered extension fields. Safe to call even when there are no extension fields.
118122
"""
119123
for key in self.extension_recv:
120-
ext_buf = self._locals[key]
124+
ext_buf = self.receive_buffers[key]
121125
(field, srank) = self._split_key(key)
122126
ext_buf._is_new = self.hub_from_spoke(ext_buf, srank, field)
123127
## End for
@@ -233,7 +237,7 @@ def receive_innerbounds(self):
233237
logging.debug("Hub is trying to receive from InnerBounds")
234238
for idx in self.innerbound_spoke_indices:
235239
key = self._make_key(Field.OBJECTIVE_INNER_BOUND, idx)
236-
recv_buf = self._locals[key]
240+
recv_buf = self.receive_buffers[key]
237241
is_new = self.hub_from_spoke(recv_buf, idx, Field.OBJECTIVE_INNER_BOUND)
238242
if is_new:
239243
bound = recv_buf[0]
@@ -249,7 +253,7 @@ def receive_outerbounds(self):
249253
logging.debug("Hub is trying to receive from OuterBounds")
250254
for idx in self.outerbound_spoke_indices:
251255
key = self._make_key(Field.OBJECTIVE_OUTER_BOUND, idx)
252-
recv_buf = self._locals[key]
256+
recv_buf = self.receive_buffers[key]
253257
is_new = self.hub_from_spoke(recv_buf, idx, Field.OBJECTIVE_OUTER_BOUND)
254258
if is_new:
255259
bound = recv_buf[0]
@@ -320,18 +324,18 @@ def initialize_inner_bound_buffers(self):
320324
def _populate_boundsout_cache(self, buf):
321325
""" Populate a given buffer with the current bounds
322326
"""
323-
buf[-3] = self.BestOuterBound
324-
buf[-2] = self.BestInnerBound
327+
buf[0] = self.BestOuterBound
328+
buf[1] = self.BestInnerBound
325329

326330
def send_boundsout(self):
327331
""" Send bounds to the appropriate spokes
328332
This is called only for spokes which are bounds only.
329333
w and nonant spokes are passed bounds through the w and nonant buffers
330334
"""
331-
my_bounds = self.boundsout_send_buffer
335+
my_bounds = self.send_buffers[Field.BEST_OBJECTIVE_BOUNDS]
332336
self._populate_boundsout_cache(my_bounds.array())
333337
logging.debug("hub is sending bounds={}".format(my_bounds))
334-
self.hub_to_spoke(my_bounds, Field.OBJECTIVE_BOUNDS)
338+
self.hub_to_spoke(my_bounds, Field.BEST_OBJECTIVE_BOUNDS)
335339
return
336340

337341
def initialize_spoke_indices(self):
@@ -392,45 +396,7 @@ def initialize_spoke_indices(self):
392396

393397

394398
def register_send_fields(self):
395-
396-
self.shutdown = self.register_send_field(Field.SHUTDOWN, 1)
397-
398-
required_fields = set()
399-
for i, spoke in enumerate(self.communicators):
400-
if i == self.strata_rank:
401-
continue
402-
spoke_class = spoke["spcomm_class"]
403-
if hasattr(spoke_class, "converger_spoke_types"):
404-
for cst in spoke_class.converger_spoke_types:
405-
if cst == ConvergerSpokeType.W_GETTER:
406-
required_fields.add(Field.DUALS)
407-
elif cst == ConvergerSpokeType.NONANT_GETTER:
408-
required_fields.add(Field.NONANT)
409-
elif cst == ConvergerSpokeType.INNER_BOUND or cst == ConvergerSpokeType.OUTER_BOUND:
410-
required_fields.add(Field.OBJECTIVE_BOUNDS)
411-
else:
412-
pass # Intentional no-op
413-
## End if
414-
## End for
415-
else:
416-
# Intentional no-op. Non-converger spokes need to register any needed
417-
# fields separately. See the functions `register_extension_recv_field`
418-
# and `register_extension_send_field`.
419-
pass
420-
## End if
421-
## End for
422-
423-
n_nonants = 0
424-
for s in self.opt.local_scenarios.values():
425-
n_nonants += len(s._mpisppy_data.nonant_indices)
426-
## End for
427-
428-
if Field.DUALS in required_fields:
429-
self.w_send_buffer = self.register_send_field(Field.DUALS, n_nonants)
430-
if Field.NONANT in required_fields:
431-
self.nonant_send_buffer = self.register_send_field(Field.NONANT, n_nonants)
432-
if Field.OBJECTIVE_BOUNDS in required_fields:
433-
self.boundsout_send_buffer = self.register_send_field(Field.OBJECTIVE_BOUNDS, 2)
399+
super().register_send_fields()
434400

435401
# Not all opt classes may have extensions
436402
if getattr(self.opt, "extensions", None) is not None:
@@ -439,7 +405,6 @@ def register_send_fields(self):
439405
return
440406

441407

442-
443408
def hub_to_spoke(self, buf: SendArray, field: Field):
444409
""" Put the specified values into the specified locally-owned buffer
445410
for the spoke to pick up.
@@ -534,13 +499,17 @@ def send_terminate(self):
534499
buffer, so every spoke will see it simultaneously.
535500
processes (don't need to call them one at a time).
536501
"""
537-
shutdown = self.shutdown
538-
shutdown[0] = 1.0
539-
self.hub_to_spoke(shutdown, Field.SHUTDOWN)
502+
self.send_buffers[Field.SHUTDOWN][0] = 1.0
503+
self.hub_to_spoke(self.send_buffers[Field.SHUTDOWN], Field.SHUTDOWN)
540504
return
541505

542506

543507
class PHHub(Hub):
508+
509+
send_fields = (*Hub.send_fields, Field.NONANT, Field.DUALS)
510+
receive_fields = (*Hub.receive_fields,)
511+
optional_receive_fields = (*Hub.optional_receive_fields,)
512+
544513
def setup_hub(self):
545514
""" Must be called after make_windows(), so that
546515
the hub knows the sizes of all the spokes windows
@@ -673,8 +642,7 @@ def send_nonants(self):
673642
"""
674643
self.opt._save_nonants()
675644
ci = 0 ## index to self.nonant_send_buffer
676-
# my_nonants = self._sends[Field.NONANT]
677-
nonant_send_buffer = self.nonant_send_buffer
645+
nonant_send_buffer = self.send_buffers[Field.NONANT]
678646
for k, s in self.opt.local_scenarios.items():
679647
for xvar in s._mpisppy_data.nonant_indices.values():
680648
nonant_send_buffer[ci] = xvar._value
@@ -690,7 +658,7 @@ def send_ws(self):
690658
""" Send dual weights to the appropriate spokes
691659
"""
692660
# NOTE: my_ws.array() and self.w_send_buffer should be the same array.
693-
my_ws = self._sends[Field.DUALS]
661+
my_ws = self.send_buffers[Field.DUALS]
694662
self.opt._populate_W_cache(my_ws.array(), padding=1)
695663
logging.debug("hub is sending Ws={}".format(my_ws.array()))
696664

@@ -701,6 +669,10 @@ def send_ws(self):
701669

702670
class LShapedHub(Hub):
703671

672+
send_fields = (*Hub.send_fields, Field.NONANT,)
673+
receive_fields = (*Hub.receive_fields,)
674+
optional_receive_fields = (*Hub.optional_receive_fields,)
675+
704676
def setup_hub(self):
705677
""" Must be called after make_windows(), so that
706678
the hub knows the sizes of all the spokes windows
@@ -781,7 +753,7 @@ def send_nonants(self):
781753
TODO: Will likely fail with bundling
782754
"""
783755
ci = 0 ## index to self.nonant_send_buffer
784-
nonant_send_buffer = self.nonant_send_buffer
756+
nonant_send_buffer = self.send_buffers[Field.NONANT]
785757
for k, s in self.opt.local_scenarios.items():
786758
nonant_to_root_var_map = s._mpisppy_model.subproblem_to_root_vars_map
787759
for xvar in s._mpisppy_data.nonant_indices.values():
@@ -797,6 +769,8 @@ def send_nonants(self):
797769

798770
class SubgradientHub(PHHub):
799771

772+
# send / receive fields are same as PHHub
773+
800774
_hub_algo_best_bound_provider = True
801775

802776
def main(self):
@@ -806,6 +780,8 @@ def main(self):
806780

807781
class APHHub(PHHub):
808782

783+
# send / receive fields are same as PHHub
784+
809785
def main(self):
810786
""" SPComm gets attached by self.__init___; holding APH harmless """
811787
logger.critical("aph debug main in hub.py")

mpisppy/cylinders/reduced_costs_spoke.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515

1616
class ReducedCostsSpoke(LagrangianOuterBound):
1717

18+
send_fields = (*LagrangianOuterBound.send_fields, Field.EXPECTED_REDUCED_COST, Field.SCENARIO_REDUCED_COST ,)
19+
receive_fields = (*LagrangianOuterBound.receive_fields,)
20+
optional_receive_fields = (*LagrangianOuterBound.optional_receive_fields,)
21+
1822
converger_spoke_char = 'R'
1923

2024
def __init__(self, *args, **kwargs):
@@ -54,28 +58,25 @@ def register_send_fields(self) -> None:
5458
scenario_buffer_len += len(s._mpisppy_data.nonant_indices)
5559
self._scenario_rc_buffer = np.zeros(scenario_buffer_len)
5660

57-
self.register_send_field(Field.EXPECTED_REDUCED_COST, self.nonant_length)
58-
self.register_send_field(Field.SCENARIO_REDUCED_COST, scenario_buffer_len)
59-
6061
return
6162

6263
@property
6364
def rc_global(self):
64-
return self._sends[Field.EXPECTED_REDUCED_COST].value_array()
65+
return self.send_buffers[Field.EXPECTED_REDUCED_COST].value_array()
6566

6667
@rc_global.setter
6768
def rc_global(self, vals):
68-
arr = self._sends[Field.EXPECTED_REDUCED_COST].value_array()
69+
arr = self.send_buffers[Field.EXPECTED_REDUCED_COST].value_array()
6970
arr[:] = vals
7071
return
7172

7273
@property
7374
def rc_scenario(self):
74-
return self._sends[Field.SCENARIO_REDUCED_COST].value_array()
75+
return self.send_buffers[Field.SCENARIO_REDUCED_COST].value_array()
7576

7677
@rc_scenario.setter
7778
def rc_scenario(self, vals):
78-
arr = self._sends[Field.SCENARIO_REDUCED_COST].value_array()
79+
arr = self.send_buffers[Field.SCENARIO_REDUCED_COST].value_array()
7980
arr[:] = vals
8081
return
8182

0 commit comments

Comments
 (0)