Skip to content

Commit 912195c

Browse files
committed
Revert "promote Spoke.update_receive_buffers to SPCommuicator"
This reverts commit 0e680a9.
1 parent 00ab113 commit 912195c

File tree

3 files changed

+43
-30
lines changed

3 files changed

+43
-30
lines changed

mpisppy/cylinders/hub.py

+15-12
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,13 @@ def __init__(self, spbase_object, fullcomm, strata_comm, cylinder_comm, communic
4646
def setup_hub(self):
4747
pass
4848

49+
@abc.abstractmethod
50+
def sync(self):
51+
""" To be called within the whichever optimization algorithm
52+
is being run on the hub (e.g. PH)
53+
"""
54+
pass
55+
4956
@abc.abstractmethod
5057
def is_converged(self):
5158
""" The hub has the ability to halt the optimization algorithm on the
@@ -153,9 +160,8 @@ def determine_termination(self):
153160
return abs_gap_satisfied or rel_gap_satisfied or max_stalled_satisfied
154161

155162
def hub_finalize(self):
156-
self.update_receive_buffers()
157-
self.update_outerbounds()
158-
self.update_innerbounds()
163+
self.receive_outerbounds()
164+
self.receive_innerbounds()
159165

160166
if self.global_rank == 0:
161167
self.print_init = True
@@ -241,9 +247,8 @@ def sync(self):
241247
self.send_ws()
242248
self.send_nonants()
243249
self.send_boundsout()
244-
self.update_receive_buffers()
245-
self.update_outerbounds()
246-
self.update_innerbounds()
250+
self.receive_outerbounds()
251+
self.receive_innerbounds()
247252
self.update_nonant_bounds()
248253
if self.opt.extensions is not None:
249254
self.opt.extobject.sync_with_spokes()
@@ -252,9 +257,8 @@ def sync_with_spokes(self):
252257
self.sync()
253258

254259
def sync_bounds(self):
255-
self.update_receive_buffers()
256-
self.update_outerbounds()
257-
self.update_innerbounds()
260+
self.receive_outerbounds()
261+
self.receive_innerbounds()
258262
self.update_nonant_bounds()
259263
self.send_boundsout()
260264

@@ -356,9 +360,8 @@ def sync(self, send_nonants=True):
356360
"""
357361
if send_nonants:
358362
self.send_nonants()
359-
self.update_receive_buffers()
360-
self.update_innerbounds()
361-
self.update_outerbounds()
363+
self.receive_outerbounds()
364+
self.receive_innerbounds()
362365
self.update_nonant_bounds()
363366
# in case LShaped ever gets extensions
364367
if getattr(self.opt, "extensions", None) is not None:

mpisppy/cylinders/spcommunicator.py

+20-16
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,16 @@ def main(self):
252252
"""
253253
pass
254254

255+
def sync(self):
256+
""" Every hub/spoke may have a sync function
257+
"""
258+
pass
259+
260+
def is_converged(self):
261+
""" Every hub/spoke may have a is_converged function
262+
"""
263+
return False
264+
255265
def finalize(self):
256266
""" Every hub/spoke may have a finalize function,
257267
which does some final calculations/flushing to
@@ -384,13 +394,6 @@ def get_receive_buffer(self,
384394
buf._is_new = False
385395
return False
386396

387-
def update_receive_buffers(self):
388-
for (key, recv_buf) in self.receive_buffers.items():
389-
field, rank = self._split_key(key)
390-
self.get_receive_buffer(recv_buf, field, rank)
391-
## End for
392-
return
393-
394397
def update_nonant_bounds(self):
395398
""" update the bounds on the nonanticipative variables based on
396399
Field.NONANT_LOWER_BOUNDS and Field.NONANT_UPPER_BOUNDS. The lower and
@@ -426,23 +429,25 @@ def update_nonant_bounds(self):
426429
if bounds_modified > 0:
427430
global_toc(f"{self.__class__.__name__}: tightened {int(bounds_modified)} variable bounds", self.cylinder_rank == 0)
428431

429-
def update_innerbounds(self):
430-
""" Update the inner bounds after receiving them from the spokes
432+
def receive_innerbounds(self):
433+
""" Get inner bounds from inner bound providers
431434
"""
432-
logger.debug(f"{self.__class__.__name__} is trying to update from InnerBounds")
435+
logger.debug(f"{self.__class__.__name__} is trying to receive from InnerBounds")
433436
for idx, cls, recv_buf in self.receive_field_spcomms[Field.OBJECTIVE_INNER_BOUND]:
434-
if recv_buf.is_new():
437+
is_new = self.get_receive_buffer(recv_buf, Field.OBJECTIVE_INNER_BOUND, idx)
438+
if is_new:
435439
bound = recv_buf[0]
436440
logger.debug("!! new InnerBound to opt {}".format(bound))
437441
self.BestInnerBound = self.InnerBoundUpdate(bound, cls, idx)
438442
logger.debug(f"{self.__class__.__name__} back from InnerBounds")
439443

440-
def update_outerbounds(self):
441-
""" Update the outer bounds after receiving them from the spokes
444+
def receive_outerbounds(self):
445+
""" Get outer bounds from outer bound providers
442446
"""
443-
logger.debug(f"{self.__class__.__name__} is trying to update from OuterBounds")
447+
logger.debug(f"{self.__class__.__name__} is trying to receive from OuterBounds")
444448
for idx, cls, recv_buf in self.receive_field_spcomms[Field.OBJECTIVE_OUTER_BOUND]:
445-
if recv_buf.is_new():
449+
is_new = self.get_receive_buffer(recv_buf, Field.OBJECTIVE_OUTER_BOUND, idx)
450+
if is_new:
446451
bound = recv_buf[0]
447452
logger.debug("!! new OuterBound to opt {}".format(bound))
448453
self.BestOuterBound = self.OuterBoundUpdate(bound, cls, idx)
@@ -485,4 +490,3 @@ def initialize_bound_values(self):
485490
self.BestOuterBound = inf
486491
self._inner_bound_update = lambda new, old : (new > old)
487492
self._outer_bound_update = lambda new, old : (new < old)
488-

mpisppy/cylinders/spoke.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,17 @@ def main(self):
3535
The main call for the Spoke. Derived classe
3636
should call the got_kill_signal method
3737
regularly to ensure all ranks terminate
38-
with the Hub, and to receive new data
39-
from other cylinders.
38+
with the Hub.
4039
"""
4140
pass
4241

42+
def update_receive_buffers(self):
43+
for (key, recv_buf) in self.receive_buffers.items():
44+
field, rank = self._split_key(key)
45+
self.get_receive_buffer(recv_buf, field, rank)
46+
## End for
47+
return
48+
4349

4450
class _BoundSpoke(Spoke):
4551
""" A base class for bound spokes

0 commit comments

Comments
 (0)