Skip to content

Commit 4e2085e

Browse files
committed
remove hub_from_spoke
1 parent c4b4efc commit 4e2085e

File tree

3 files changed

+74
-69
lines changed

3 files changed

+74
-69
lines changed

Diff for: mpisppy/cylinders/hub.py

+13-68
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,10 @@
77
# full copyright and license information.
88
###############################################################################
99

10-
import numpy as np
1110
import abc
1211
import logging
1312
import mpisppy.log
14-
from mpisppy.opt.aph import APH
1513

16-
from mpisppy import MPI
1714
from mpisppy.cylinders.spcommunicator import RecvArray, SendArray, SPCommunicator
1815
from math import inf
1916

@@ -83,7 +80,6 @@ def current_iteration(self):
8380
def main(self):
8481
pass
8582

86-
8783
def register_extension_recv_field(self, field: Field, strata_rank: int, buf_len: int = -1) -> RecvArray:
8884
"""
8985
Register an extensions interest in the given field from the given spoke. The hub
@@ -125,15 +121,14 @@ def sync_extension_fields(self):
125121
for key in self.extension_recv:
126122
ext_buf = self.receive_buffers[key]
127123
(field, srank) = self._split_key(key)
128-
ext_buf._is_new = self.hub_from_spoke(ext_buf, srank, field)
124+
ext_buf._is_new = self.get_receive_buffer(ext_buf, field, srank)
129125
## End for
130126
return
131127

132128
def clear_latest_chars(self):
133129
self.latest_ib_char = None
134130
self.latest_ob_char = None
135131

136-
137132
def compute_gaps(self):
138133
""" Compute the current absolute and relative gaps,
139134
using the current self.BestInnerBound and self.BestOuterBound
@@ -157,7 +152,6 @@ def compute_gaps(self):
157152
rel_gap = float("inf")
158153
return abs_gap, rel_gap
159154

160-
161155
def get_update_string(self):
162156
if self.latest_ib_char is None and \
163157
self.latest_ob_char is None:
@@ -236,7 +230,7 @@ def receive_innerbounds(self):
236230
"""
237231
logging.debug("Hub is trying to receive from InnerBounds")
238232
for idx, cls, recv_buf in self.receive_field_spcomms[Field.OBJECTIVE_INNER_BOUND]:
239-
is_new = self.hub_from_spoke(recv_buf, idx, Field.OBJECTIVE_INNER_BOUND)
233+
is_new = self.get_receive_buffer(recv_buf, Field.OBJECTIVE_INNER_BOUND, idx)
240234
if is_new:
241235
bound = recv_buf[0]
242236
logging.debug("!! new InnerBound to opt {}".format(bound))
@@ -250,7 +244,7 @@ def receive_outerbounds(self):
250244
"""
251245
logging.debug("Hub is trying to receive from OuterBounds")
252246
for idx, cls, recv_buf in self.receive_field_spcomms[Field.OBJECTIVE_OUTER_BOUND]:
253-
is_new = self.hub_from_spoke(recv_buf, idx, Field.OBJECTIVE_OUTER_BOUND)
247+
is_new = self.get_receive_buffer(recv_buf, Field.OBJECTIVE_OUTER_BOUND, idx)
254248
if is_new:
255249
bound = recv_buf[0]
256250
logging.debug("!! new OuterBound to opt {}".format(bound))
@@ -264,7 +258,7 @@ def OuterBoundUpdate(self, new_bound, cls=None, idx=None, char='*'):
264258
self.latest_ob_char = char
265259
self.last_ob_idx = 0
266260
else:
267-
self.latest_ib_char = cls.converger_spoke_char
261+
self.latest_ob_char = cls.converger_spoke_char
268262
self.last_ob_idx = idx
269263
return new_bound
270264
else:
@@ -326,7 +320,6 @@ def register_receive_fields(self):
326320

327321
return
328322

329-
330323
def register_send_fields(self):
331324
super().register_send_fields()
332325

@@ -336,63 +329,6 @@ def register_send_fields(self):
336329

337330
return
338331

339-
def hub_from_spoke(self,
340-
buf: RecvArray,
341-
spoke_num: int,
342-
field: Field,
343-
):
344-
""" spoke_num is the rank in the strata_comm, so it is 1-based not 0-based
345-
346-
Returns:
347-
is_new (bool): Indicates whether the "gotten" values are new,
348-
based on the write_id.
349-
"""
350-
buf._is_new = self._hub_from_spoke(buf.array(), spoke_num, field, buf.id())
351-
if buf.is_new():
352-
buf._pull_id()
353-
return buf.is_new()
354-
355-
def _hub_from_spoke(self,
356-
values: np.typing.NDArray,
357-
spoke_num: int,
358-
field: Field,
359-
last_write_id: int,
360-
):
361-
""" spoke_num is the rank in the strata_comm, so it is 1-based not 0-based
362-
363-
Returns:
364-
is_new (bool): Indicates whether the "gotten" values are new,
365-
based on the write_id.
366-
"""
367-
# so the window in each rank gets read at approximately the same time,
368-
# and so has the same write_id
369-
if not isinstance(self.opt, APH):
370-
self.cylinder_comm.Barrier()
371-
## End if
372-
self.window.get(values, spoke_num, field)
373-
374-
if isinstance(self.opt, APH):
375-
# # reverting part of changes from Ben getting rid of spoke sleep DLW jan 2023
376-
if values[-1] > last_write_id:
377-
return True
378-
else:
379-
new_id = int(values[-1])
380-
local_val = np.array((new_id,), 'i')
381-
sum_ids = np.zeros(1, 'i')
382-
self.cylinder_comm.Allreduce((local_val, MPI.INT),
383-
(sum_ids, MPI.INT),
384-
op=MPI.SUM)
385-
if new_id != sum_ids[0] / self.cylinder_comm.size:
386-
return False
387-
## End if
388-
if new_id > last_write_id or new_id < 0:
389-
return True
390-
## End if
391-
## End if
392-
393-
return False
394-
395-
396332
def send_terminate(self):
397333
""" Send an array of zeros with a -1 appended to the
398334
end to indicate termination. This function puts to the local
@@ -614,6 +550,15 @@ def main(self):
614550
logger.critical("aph debug main in hub.py")
615551
self.opt.APH_main(spcomm=self, finalize=False)
616552

553+
# overwrite the default behavior of this method for APH
554+
def get_receive_buffer(self,
555+
buf: RecvArray,
556+
field: Field,
557+
origin: int = -1,
558+
synchronize: bool = False,
559+
):
560+
return super().get_receive_buffer(buf, field, origin, synchronize)
561+
617562
def finalize(self):
618563
""" does PH.post_loops, returns Eobj """
619564
# NOTE: APH_main does NOT pass in extensions

Diff for: mpisppy/cylinders/spcommunicator.py

+60
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import abc
2424
import time
2525

26+
from mpisppy import MPI
2627
from mpisppy.cylinders.spwindow import Field, FieldLengths, SPWindow
2728

2829
def communicator_array(size):
@@ -309,3 +310,62 @@ def put_send_buffer(self, buf: SendArray, field: Field):
309310
buf._next_write_id()
310311
self.window.put(buf.array(), field)
311312
return
313+
314+
def get_receive_buffer(self,
315+
buf: RecvArray,
316+
field: Field,
317+
origin: int = -1,
318+
synchronize: bool = True,
319+
):
320+
""" Gets the specified values from another cylinder and copies them into
321+
the specified locally-owned buffer. Updates the write_id in the locally-
322+
owned buffer, if appropriate.
323+
324+
Args:
325+
buf (RecvArray) : Buffer to put the data in
326+
field (Field) : The source field
327+
origin (:obj:`int`, optional) : The rank on strata_comm to get the data.
328+
If not provided (or -1), will attempt to infer a unique origin. If
329+
no unique origin is found, will raise an error. Default: -1.
330+
synchronize (:obj:`bool`, optional) : If True, will only report
331+
updated data if the write_ids are the same across the cylinder_comm
332+
are identical. Default: True.
333+
334+
Returns:
335+
is_new (bool): Indicates whether the "gotten" values are new,
336+
based on the write_id.
337+
"""
338+
if not synchronize:
339+
self.cylinder_comm.Barrier()
340+
341+
if origin == -1:
342+
origin = self.fields_to_ranks[field][0]
343+
if len(self.fields_to_ranks[field]) > 1:
344+
raise RuntimeError(f"Non-unique origin for {field=}. Possible "
345+
f"origins are {self.fields_to_ranks[field]=}.")
346+
347+
last_id = buf.id()
348+
349+
self.window.get(buf.array(), origin, field)
350+
351+
new_id = int(buf.array()[-1])
352+
if synchronize:
353+
local_val = np.array((new_id,), 'i')
354+
sum_ids = np.zeros(1, 'i')
355+
self.cylinder_comm.Allreduce((local_val, MPI.INT),
356+
(sum_ids, MPI.INT),
357+
op=MPI.SUM)
358+
if new_id != sum_ids[0] / self.cylinder_comm.size:
359+
buf._is_new = False
360+
return False
361+
362+
else:
363+
if new_id <= last_id:
364+
buf._is_new = False
365+
return False
366+
367+
# in either case, now we have new data
368+
buf._is_new = True
369+
buf._pull_id()
370+
371+
return True

Diff for: mpisppy/cylinders/spoke.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import math
1515

1616
from mpisppy import MPI
17-
from mpisppy.cylinders.spcommunicator import RecvArray, SendArray, SPCommunicator
17+
from mpisppy.cylinders.spcommunicator import RecvArray, SPCommunicator
1818
from mpisppy.cylinders.spwindow import Field
1919

2020

0 commit comments

Comments
 (0)