Skip to content

Commit 0e680a9

Browse files
committed
promote Spoke.update_receive_buffers to SPCommuicator
1 parent 0c3a202 commit 0e680a9

File tree

3 files changed

+29
-47
lines changed

3 files changed

+29
-47
lines changed

mpisppy/cylinders/hub.py

+20-29
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,6 @@ def __init__(self, spbase_object, fullcomm, strata_comm, cylinder_comm, communic
5353
def setup_hub(self):
5454
pass
5555

56-
@abc.abstractmethod
57-
def sync(self):
58-
""" To be called within the whichever optimization algorithm
59-
is being run on the hub (e.g. PH)
60-
"""
61-
pass
62-
6356
@abc.abstractmethod
6457
def is_converged(self):
6558
""" The hub has the ability to halt the optimization algorithm on the
@@ -167,37 +160,32 @@ def determine_termination(self):
167160
return abs_gap_satisfied or rel_gap_satisfied or max_stalled_satisfied
168161

169162
def hub_finalize(self):
170-
self.receive_outerbounds()
171-
self.receive_innerbounds()
163+
self.update_receive_buffers()
164+
self.update_outerbounds()
165+
self.update_innerbounds()
172166

173167
if self.global_rank == 0:
174168
self.print_init = True
175169
global_toc("Statistics at termination", True)
176170
self.screen_trace()
177171

178-
def receive_innerbounds(self):
179-
""" Get inner bounds from inner bound spokes
180-
NOTE: Does not check if there _are_ innerbound spokes
181-
(but should be harmless to call if there are none)
172+
def update_innerbounds(self):
173+
""" Update the inner bounds after receiving them from the spokes
182174
"""
183-
logging.debug("Hub is trying to receive from InnerBounds")
175+
logging.debug("Hub is trying to update from InnerBounds")
184176
for idx, cls, recv_buf in self.receive_field_spcomms[Field.OBJECTIVE_INNER_BOUND]:
185-
is_new = self.get_receive_buffer(recv_buf, Field.OBJECTIVE_INNER_BOUND, idx)
186-
if is_new:
177+
if recv_buf.is_new():
187178
bound = recv_buf[0]
188179
logging.debug("!! new InnerBound to opt {}".format(bound))
189180
self.BestInnerBound = self.InnerBoundUpdate(bound, cls, idx)
190181
logging.debug("ph back from InnerBounds")
191182

192-
def receive_outerbounds(self):
193-
""" Get outer bounds from outer bound spokes
194-
NOTE: Does not check if there _are_ outerbound spokes
195-
(but should be harmless to call if there are none)
183+
def update_outerbounds(self):
184+
""" Update the outer bounds after receiving them from the spokes
196185
"""
197-
logging.debug("Hub is trying to receive from OuterBounds")
186+
logging.debug("Hub is trying to update from OuterBounds")
198187
for idx, cls, recv_buf in self.receive_field_spcomms[Field.OBJECTIVE_OUTER_BOUND]:
199-
is_new = self.get_receive_buffer(recv_buf, Field.OBJECTIVE_OUTER_BOUND, idx)
200-
if is_new:
188+
if recv_buf.is_new():
201189
bound = recv_buf[0]
202190
logging.debug("!! new OuterBound to opt {}".format(bound))
203191
self.BestOuterBound = self.OuterBoundUpdate(bound, cls, idx)
@@ -320,17 +308,19 @@ def sync(self):
320308
self.send_ws()
321309
self.send_nonants()
322310
self.send_boundsout()
323-
self.receive_outerbounds()
324-
self.receive_innerbounds()
311+
self.update_receive_buffers()
312+
self.update_outerbounds()
313+
self.update_innerbounds()
325314
if self.opt.extensions is not None:
326315
self.opt.extobject.sync_with_spokes()
327316

328317
def sync_with_spokes(self):
329318
self.sync()
330319

331320
def sync_bounds(self):
332-
self.receive_outerbounds()
333-
self.receive_innerbounds()
321+
self.update_receive_buffers()
322+
self.update_outerbounds()
323+
self.update_innerbounds()
334324
self.send_boundsout()
335325

336326
def sync_extensions(self):
@@ -431,8 +421,9 @@ def sync(self, send_nonants=True):
431421
"""
432422
if send_nonants:
433423
self.send_nonants()
434-
self.receive_outerbounds()
435-
self.receive_innerbounds()
424+
self.update_receive_buffers()
425+
self.update_innerbounds()
426+
self.update_outerbounds()
436427
# in case LShaped ever gets extensions
437428
if getattr(self.opt, "extensions", None) is not None:
438429
self.opt.extobject.sync_with_spokes()

mpisppy/cylinders/spcommunicator.py

+7-10
Original file line numberDiff line numberDiff line change
@@ -240,16 +240,6 @@ def main(self):
240240
"""
241241
pass
242242

243-
def sync(self):
244-
""" Every hub/spoke may have a sync function
245-
"""
246-
pass
247-
248-
def is_converged(self):
249-
""" Every hub/spoke may have a is_converged function
250-
"""
251-
return False
252-
253243
def finalize(self):
254244
""" Every hub/spoke may have a finalize function,
255245
which does some final calculations/flushing to
@@ -381,3 +371,10 @@ def get_receive_buffer(self,
381371
else:
382372
buf._is_new = False
383373
return False
374+
375+
def update_receive_buffers(self):
376+
for (key, recv_buf) in self.receive_buffers.items():
377+
field, rank = self._split_key(key)
378+
self.get_receive_buffer(recv_buf, field, rank)
379+
## End for
380+
return

mpisppy/cylinders/spoke.py

+2-8
Original file line numberDiff line numberDiff line change
@@ -38,17 +38,11 @@ def main(self):
3838
The main call for the Spoke. Derived classe
3939
should call the got_kill_signal method
4040
regularly to ensure all ranks terminate
41-
with the Hub.
41+
with the Hub, and to receive new data
42+
from other cylinders.
4243
"""
4344
pass
4445

45-
def update_receive_buffers(self):
46-
for (key, recv_buf) in self.receive_buffers.items():
47-
field, rank = self._split_key(key)
48-
self.get_receive_buffer(recv_buf, field, rank)
49-
## End for
50-
return
51-
5246

5347
class _BoundSpoke(Spoke):
5448
""" A base class for bound spokes

0 commit comments

Comments
 (0)