Skip to content

Commit 7a33788

Browse files
committed
more generic receive logic
1 parent dd91f14 commit 7a33788

File tree

5 files changed

+88
-173
lines changed

5 files changed

+88
-173
lines changed

mpisppy/cylinders/hub.py

+20-114
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from mpisppy import MPI
1717
from mpisppy.cylinders.spcommunicator import RecvArray, SendArray, SPCommunicator
1818
from math import inf
19-
from mpisppy.cylinders.spoke import ConvergerSpokeType
2019

2120
from mpisppy import global_toc
2221

@@ -51,6 +50,8 @@ def __init__(self, spbase_object, fullcomm, strata_comm, cylinder_comm, communic
5150

5251
self.extension_recv = set()
5352

53+
self.initialize_bound_values()
54+
5455
return
5556

5657
@abc.abstractmethod
@@ -233,14 +234,12 @@ def receive_innerbounds(self):
233234
(but should be harmless to call if there are none)
234235
"""
235236
logging.debug("Hub is trying to receive from InnerBounds")
236-
for idx in self.innerbound_spoke_indices:
237-
key = self._make_key(Field.OBJECTIVE_INNER_BOUND, idx)
238-
recv_buf = self.receive_buffers[key]
237+
for idx, cls, recv_buf in self.receive_field_spcomms[Field.OBJECTIVE_INNER_BOUND]:
239238
is_new = self.hub_from_spoke(recv_buf, idx, Field.OBJECTIVE_INNER_BOUND)
240239
if is_new:
241240
bound = recv_buf[0]
242241
logging.debug("!! new InnerBound to opt {}".format(bound))
243-
self.BestInnerBound = self.InnerBoundUpdate(bound, idx)
242+
self.BestInnerBound = self.InnerBoundUpdate(bound, cls, idx)
244243
logging.debug("ph back from InnerBounds")
245244

246245
def receive_outerbounds(self):
@@ -249,37 +248,35 @@ def receive_outerbounds(self):
249248
(but should be harmless to call if there are none)
250249
"""
251250
logging.debug("Hub is trying to receive from OuterBounds")
252-
for idx in self.outerbound_spoke_indices:
253-
key = self._make_key(Field.OBJECTIVE_OUTER_BOUND, idx)
254-
recv_buf = self.receive_buffers[key]
251+
for idx, cls, recv_buf in self.receive_field_spcomms[Field.OBJECTIVE_OUTER_BOUND]:
255252
is_new = self.hub_from_spoke(recv_buf, idx, Field.OBJECTIVE_OUTER_BOUND)
256253
if is_new:
257254
bound = recv_buf[0]
258255
logging.debug("!! new OuterBound to opt {}".format(bound))
259-
self.BestOuterBound = self.OuterBoundUpdate(bound, idx)
256+
self.BestOuterBound = self.OuterBoundUpdate(bound, cls, idx)
260257
logging.debug("ph back from OuterBounds")
261258

262-
def OuterBoundUpdate(self, new_bound, idx=None, char='*'):
259+
def OuterBoundUpdate(self, new_bound, cls=None, idx=None, char='*'):
263260
current_bound = self.BestOuterBound
264261
if self._outer_bound_update(new_bound, current_bound):
265-
if idx is None:
262+
if cls is None:
266263
self.latest_ob_char = char
267264
self.last_ob_idx = 0
268265
else:
269-
self.latest_ob_char = self.outerbound_spoke_chars[idx]
266+
self.latest_ib_char = cls.converger_spoke_char
270267
self.last_ob_idx = idx
271268
return new_bound
272269
else:
273270
return current_bound
274271

275-
def InnerBoundUpdate(self, new_bound, idx=None, char='*'):
272+
def InnerBoundUpdate(self, new_bound, cls=None, idx=None, char='*'):
276273
current_bound = self.BestInnerBound
277274
if self._inner_bound_update(new_bound, current_bound):
278-
if idx is None:
275+
if cls is None:
279276
self.latest_ib_char = char
280277
self.last_ib_idx = 0
281278
else:
282-
self.latest_ib_char = self.innerbound_spoke_chars[idx]
279+
self.latest_ib_char = cls.converger_spoke_char
283280
self.last_ib_idx = idx
284281
return new_bound
285282
else:
@@ -297,28 +294,6 @@ def initialize_bound_values(self):
297294
self._inner_bound_update = lambda new, old : (new > old)
298295
self._outer_bound_update = lambda new, old : (new < old)
299296

300-
def initialize_outer_bound_buffers(self):
301-
""" Initialize outer bound receive buffers
302-
"""
303-
self.outerbound_receive_buffers = dict()
304-
for idx in self.outerbound_spoke_indices:
305-
self.outerbound_receive_buffers[idx] = self.register_recv_field(
306-
Field.OBJECTIVE_OUTER_BOUND, idx, 1,
307-
)
308-
## End for
309-
return
310-
311-
def initialize_inner_bound_buffers(self):
312-
""" Initialize inner bound receive buffers
313-
"""
314-
self.innerbound_receive_buffers = dict()
315-
for idx in self.innerbound_spoke_indices:
316-
self.innerbound_receive_buffers[idx] = self.register_recv_field(
317-
Field.OBJECTIVE_INNER_BOUND, idx, 1
318-
)
319-
## End for
320-
return
321-
322297
def _populate_boundsout_cache(self, buf):
323298
""" Populate a given buffer with the current bounds
324299
"""
@@ -327,62 +302,26 @@ def _populate_boundsout_cache(self, buf):
327302

328303
def send_boundsout(self):
329304
""" Send bounds to the appropriate spokes
330-
This is called only for spokes which are bounds only.
331-
w and nonant spokes are passed bounds through the w and nonant buffers
332305
"""
333306
my_bounds = self.send_buffers[Field.BEST_OBJECTIVE_BOUNDS]
334307
self._populate_boundsout_cache(my_bounds.array())
335308
logging.debug("hub is sending bounds={}".format(my_bounds))
336309
self.hub_to_spoke(my_bounds, Field.BEST_OBJECTIVE_BOUNDS)
337310
return
338311

339-
def initialize_spoke_indices(self):
312+
def register_receive_fields(self):
340313
""" Figure out what types of spokes we have,
341314
and sort them into the appropriate classes.
342315
343316
Note:
344317
Some spokes may be multiple types (e.g. outerbound and nonant),
345318
though not all combinations are supported.
346319
"""
347-
self.outerbound_spoke_indices = set()
348-
self.innerbound_spoke_indices = set()
349-
self.nonant_spoke_indices = set()
350-
self.w_spoke_indices = set()
351-
352-
self.outerbound_spoke_chars = dict()
353-
self.innerbound_spoke_chars = dict()
354-
355-
for (i, spoke) in enumerate(self.communicators):
356-
if i == self.strata_rank:
357-
continue
358-
spoke_class = spoke["spcomm_class"]
359-
if hasattr(spoke_class, "converger_spoke_types"):
360-
for cst in spoke_class.converger_spoke_types:
361-
if cst == ConvergerSpokeType.OUTER_BOUND:
362-
self.outerbound_spoke_indices.add(i)
363-
self.outerbound_spoke_chars[i] = spoke_class.converger_spoke_char
364-
elif cst == ConvergerSpokeType.INNER_BOUND:
365-
self.innerbound_spoke_indices.add(i)
366-
self.innerbound_spoke_chars[i] = spoke_class.converger_spoke_char
367-
elif cst == ConvergerSpokeType.W_GETTER:
368-
self.w_spoke_indices.add(i)
369-
elif cst == ConvergerSpokeType.NONANT_GETTER:
370-
self.nonant_spoke_indices.add(i)
371-
else:
372-
raise RuntimeError(f"Unrecognized converger_spoke_type {cst}")
373-
374-
else: ##this isn't necessarily wrong, i.e., cut generators
375-
logger.debug(f"Spoke class {spoke_class} not recognized by hub")
376-
377-
# all _BoundSpoke spokes get hub bounds so we determine which spokes
378-
# are "bounds only"
379-
self.bounds_only_indices = \
380-
(self.outerbound_spoke_indices | self.innerbound_spoke_indices) - \
381-
(self.w_spoke_indices | self.nonant_spoke_indices)
320+
super().register_receive_fields()
382321

383322
# Not all opt classes may have extensions
384323
if getattr(self.opt, "extensions", None) is not None:
385-
self.opt.extobject.initialize_spoke_indices()
324+
self.opt.extobject.register_receive_fields()
386325

387326
return
388327

@@ -511,31 +450,14 @@ def setup_hub(self):
511450
"Cannot call setup_hub before memory windows are constructed"
512451
)
513452

514-
self.initialize_spoke_indices()
515-
self.initialize_bound_values()
516-
517-
self.initialize_outer_bound_buffers()
518-
self.initialize_inner_bound_buffers()
519-
520-
## Do some checking for things we currently don't support
521-
if len(self.outerbound_spoke_indices & self.innerbound_spoke_indices) > 0:
522-
raise RuntimeError(
523-
"A Spoke providing both inner and outer "
524-
"bounds is currently unsupported"
525-
)
526-
if len(self.w_spoke_indices & self.nonant_spoke_indices) > 0:
527-
raise RuntimeError(
528-
"A Spoke needing both Ws and nonants is currently unsupported"
529-
)
530-
531453
## Generate some warnings if nothing is giving bounds
532-
if not self.outerbound_spoke_indices:
454+
if not self.receive_field_spcomms[Field.OBJECTIVE_OUTER_BOUND]:
533455
logger.warn(
534456
"No OuterBound Spokes defined, this converger "
535457
"will not cause the hub to terminate"
536458
)
537459

538-
if not self.innerbound_spoke_indices:
460+
if not self.receive_field_spcomms[Field.OBJECTIVE_INNER_BOUND]:
539461
logger.warn(
540462
"No InnerBound Spokes defined, this converger "
541463
"will not cause the hub to terminate"
@@ -578,7 +500,7 @@ def is_converged(self):
578500
if self.opt.best_bound_obj_val is not None:
579501
self.BestOuterBound = self.OuterBoundUpdate(self.opt.best_bound_obj_val)
580502

581-
if not self.innerbound_spoke_indices:
503+
if not self.receive_field_spcomms[Field.OBJECTIVE_INNER_BOUND]:
582504
if self.opt._PHIter == 1:
583505
logger.warning(
584506
"PHHub cannot compute convergence without "
@@ -591,7 +513,7 @@ def is_converged(self):
591513

592514
return False
593515

594-
if not self.outerbound_spoke_indices:
516+
if not self.receive_field_spcomms[Field.OBJECTIVE_OUTER_BOUND]:
595517
if self.opt._PHIter == 1 and not self._hub_algo_best_bound_provider:
596518
global_toc(
597519
"Without outer bound spokes, no progress "
@@ -660,24 +582,8 @@ def setup_hub(self):
660582
"Cannot call setup_hub before memory windows are constructed"
661583
)
662584

663-
self.initialize_spoke_indices()
664-
self.initialize_bound_values()
665-
666-
self.initialize_outer_bound_buffers()
667-
self.initialize_inner_bound_buffers()
668-
669-
## Do some checking for things we currently
670-
## do not support
671-
if self.w_spoke_indices:
672-
raise RuntimeError("LShaped hub does not compute dual weights (Ws)")
673-
if len(self.outerbound_spoke_indices & self.innerbound_spoke_indices) > 0:
674-
raise RuntimeError(
675-
"A Spoke providing both inner and outer "
676-
"bounds is currently unsupported"
677-
)
678-
679585
## Generate some warnings if nothing is giving bounds
680-
if not self.innerbound_spoke_indices:
586+
if not self.receive_field_spcomms[Field.OBJECTIVE_INNER_BOUND]:
681587
logger.warn(
682588
"No InnerBound Spokes defined, this converger "
683589
"will not cause the hub to terminate"

mpisppy/cylinders/spcommunicator.py

+23-4
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import numpy as np
2323
import abc
2424
import time
25+
import itertools
2526

2627
from mpisppy.cylinders.spwindow import Field, FieldLengths, SPWindow
2728

@@ -138,8 +139,10 @@ def __init__(self, spbase_object, fullcomm, strata_comm, cylinder_comm, communic
138139
self.options = options
139140

140141
# Common fields for spokes and hubs
141-
self.receive_buffers = dict()
142-
self.send_buffers = dict()
142+
self.receive_buffers = {}
143+
self.send_buffers = {}
144+
# key: Field, value: list of (strata_rank, SPComm) with that Field
145+
self.receive_field_spcomms = {}
143146

144147
# setup FieldLengths which calculates
145148
# the length of each buffer type based
@@ -151,6 +154,11 @@ def __init__(self, spbase_object, fullcomm, strata_comm, cylinder_comm, communic
151154
self.opt.spcomm = self
152155

153156
self.register_send_fields()
157+
# TODO: here we can have a dynamic exchange of the send fields
158+
# so we can do error checking (all-to-all in send fields)
159+
self.register_receive_fields()
160+
161+
# TODO: check that we have something in receive_field_spcomms??
154162

155163
return
156164

@@ -259,6 +267,17 @@ def make_windows(self) -> None:
259267
return
260268

261269
def register_send_fields(self) -> None:
262-
self.send_buffers = {}
263270
for field in self.send_fields:
264-
self.send_buffers[field] = self.register_send_field(field)
271+
self.register_send_field(field)
272+
273+
def register_receive_fields(self) -> None:
274+
for field in itertools.chain(self.receive_fields, self.optional_receive_fields):
275+
for strata_rank, comm in enumerate(self.communicators):
276+
if strata_rank == self.strata_rank:
277+
continue
278+
cls = comm["spcomm_class"]
279+
if field in cls.send_fields:
280+
buff = self.register_recv_field(field, strata_rank)
281+
if field not in self.receive_field_spcomms:
282+
self.receive_field_spcomms[field] = []
283+
self.receive_field_spcomms[field].append((strata_rank, cls, buff))

mpisppy/extensions/cross_scen_extension.py

+11-24
Original file line numberDiff line numberDiff line change
@@ -127,12 +127,6 @@ def _check_bound(self):
127127
cached_ph_obj[k].activate()
128128

129129
def get_from_cross_cuts(self):
130-
# spcomm = self.opt.spcomm
131-
# idx = self.cut_gen_spoke_index
132-
# receive_buffer = np.empty(spcomm.remote_lengths[idx - 1] + 1, dtype="d") # Must be doubles
133-
# is_new = spcomm.hub_from_spoke(receive_buffer, idx)
134-
# if is_new:
135-
# self.make_cuts(receive_buffer)
136130
if self.cuts.is_new():
137131
self.make_cuts(self.cuts.array())
138132

@@ -275,8 +269,6 @@ def register_send_fields(self):
275269
return
276270

277271
def setup_hub(self):
278-
# idx = self.cut_gen_spoke_index
279-
# self.all_nonants_and_etas = np.zeros(self.opt.spcomm.local_lengths[idx - 1] + 1)
280272

281273
self.nonant_len = self.opt.nonant_length
282274

@@ -287,22 +279,17 @@ def setup_hub(self):
287279
# helping the extension track cuts
288280
self.new_cuts = False
289281

290-
def initialize_spoke_indices(self):
291-
for (i, spoke) in enumerate(self.opt.spcomm.communicators):
292-
if spoke["spcomm_class"] == CrossScenarioCutSpoke:
293-
self.cut_gen_spoke_index = i
294-
## End if
295-
## End for
296-
297-
if hasattr(self, "cut_gen_spoke_index"):
298-
spcomm = self.opt.spcomm
299-
nscen = len(self.opt.all_scenario_names)
300-
self.cuts = spcomm.register_extension_recv_field(
301-
Field.CROSS_SCENARIO_CUT,
302-
self.cut_gen_spoke_index,
303-
nscen*(self.opt.nonant_length + 1 + 1)
304-
)
305-
## End if
282+
def register_receive_fields(self):
283+
spcomm = self.opt.spcomm
284+
spcomms_cross_scenario_cut = spcomm.receive_field_spcomms[Field.CROSS_SCENARIO_CUT]
285+
assert len(spcomms_cross_scenario_cut) == 1
286+
index, cls = spcomms_cross_scenario_cut[0]
287+
assert cls is CrossScenarioCutSpoke
288+
289+
self.cuts = spcomm.register_extension_recv_field(
290+
Field.CROSS_SCENARIO_CUT,
291+
index,
292+
)
306293

307294
def sync_with_spokes(self):
308295
self.send_to_cross_cuts()

0 commit comments

Comments
 (0)