Skip to content

Commit 4d3b668

Browse files
committed
unifying register_receive_fields
1 parent c547777 commit 4d3b668

File tree

4 files changed

+33
-27
lines changed

4 files changed

+33
-27
lines changed

mpisppy/cylinders/cross_scen_spoke.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,16 @@ def register_send_fields(self) -> None:
4141
self.all_nonant_len = vbuflen
4242
self.all_eta_len = nscen*local_scen_count
4343

44-
self.all_nonants = self.register_recv_field(Field.NONANT, 0, vbuflen)
45-
self.all_etas = self.register_recv_field(Field.CROSS_SCENARIO_COST, 0, nscen * nscen)
4644

4745
self.all_coefs = self.send_buffers[Field.CROSS_SCENARIO_CUT]
4846

4947
return
5048

49+
def register_receive_fields(self):
50+
super().register_receive_fields()
51+
self.all_nonants = self.register_recv_field(Field.NONANT, 0)
52+
self.all_etas = self.register_recv_field(Field.CROSS_SCENARIO_COST, 0)
53+
5154
def prep_cs_cuts(self):
5255
# create a map scenario -> index, this index is used for various lists containing scenario dependent info.
5356
self.scenario_to_index = { scen : indx for indx, scen in enumerate(self.opt.all_scenario_names) }

mpisppy/cylinders/hub.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,11 @@ class Hub(SPCommunicator):
3535
_hub_algo_best_bound_provider = False
3636

3737
def __init__(self, spbase_object, fullcomm, strata_comm, cylinder_comm, communicators, options=None):
38+
# The extensions will be registered in SPCommunicator.__init__
39+
self.extension_recv = set()
40+
3841
super().__init__(spbase_object, fullcomm, strata_comm, cylinder_comm, communicators, options=options)
42+
3943
logger.debug(f"Built the hub object on global rank {fullcomm.Get_rank()}")
4044
# for logging
4145
self.print_init = True
@@ -47,8 +51,6 @@ def __init__(self, spbase_object, fullcomm, strata_comm, cylinder_comm, communic
4751
self.stalled_iter_cnt = 0
4852
self.last_gap = float('inf') # abs_gap tracker
4953

50-
self.extension_recv = set()
51-
5254
self.initialize_bound_values()
5355

5456
return

mpisppy/cylinders/spcommunicator.py

+19-4
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import numpy as np
2323
import abc
2424
import time
25-
import itertools
2625

2726
from mpisppy.cylinders.spwindow import Field, FieldLengths, SPWindow
2827

@@ -191,25 +190,40 @@ def _build_window_spec(self) -> dict[Field, int]:
191190

192191
def _exchange_send_fields(self) -> None:
193192
""" Do an all-to-all so we know what the other communicators are sending """
194-
self.send_fields_by_rank = self.strata_comm.allgather(tuple(self.send_buffers.keys()))
193+
send_buffers = tuple((k, buff._length) for k, buff in self.send_buffers.items())
194+
self.send_fields_lengths_by_rank = self.strata_comm.allgather(send_buffers)
195+
196+
self.send_fields_by_rank = {}
195197

196198
self.available_receive_fields = {}
197-
for rank, fields in enumerate(self.send_fields_by_rank):
199+
for rank, fields_lengths in enumerate(self.send_fields_lengths_by_rank):
198200
if rank == self.strata_rank:
199201
continue
200-
for f in fields:
202+
self.send_fields_by_rank[rank] = []
203+
for f, length in fields_lengths:
201204
if f not in self.available_receive_fields:
202205
self.available_receive_fields[f] = []
203206
self.available_receive_fields[f].append(rank)
207+
self.send_fields_by_rank[rank].append(f)
208+
209+
# print(f"{self.__class__.__name__}: {self.available_receive_fields=}")
204210

205211
def register_recv_field(self, field: Field, origin: int, length: int = -1) -> RecvArray:
212+
# print(f"{self.__class__.__name__}.register_recv_field, {field=}, {origin=}")
206213
key = self._make_key(field, origin)
207214
if length == -1:
208215
length = self._field_lengths[field]
209216
if key in self.receive_buffers:
210217
my_fa = self.receive_buffers[key]
211218
assert(length + 1 == np.size(my_fa.array()))
212219
else:
220+
available_fields_from_origin = self.send_fields_lengths_by_rank[origin]
221+
for _field, _length in available_fields_from_origin:
222+
if field == _field:
223+
assert length == _length
224+
break
225+
else: # couldn't find field!
226+
raise RuntimeError(f"Couldn't find {field=} from {origin=}")
213227
my_fa = RecvArray(length)
214228
self.receive_buffers[key] = my_fa
215229
## End if
@@ -284,6 +298,7 @@ def register_send_fields(self) -> None:
284298
self.register_send_field(field)
285299

286300
def register_receive_fields(self) -> None:
301+
# print(f"{self.__class__.__name__}: {self.receive_fields=}")
287302
for field in self.receive_fields:
288303
self.receive_field_spcomms[field] = []
289304
for strata_rank, comm in enumerate(self.communicators):

mpisppy/cylinders/spoke.py

+5-19
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,6 @@ def __init__(self, spbase_object, fullcomm, strata_comm, cylinder_comm, communic
2929

3030
self.last_call_to_got_kill_signal = time.time()
3131

32-
# All spokes need the SHUTDOWN field to know when to terminate. Just
33-
# register that here.
34-
self.shutdown = self.register_recv_field(Field.SHUTDOWN, 0, 1)
35-
3632
return
3733

3834
def spoke_to_hub(self, buf: SendArray, field: Field):
@@ -95,7 +91,7 @@ def _spoke_from_hub(self,
9591
def _got_kill_signal(self):
9692
shutdown_buf = self.receive_buffers[self._make_key(Field.SHUTDOWN, 0)]
9793
if shutdown_buf.is_new():
98-
shutdown = (self.shutdown[0] == 1.0)
94+
shutdown = (shutdown_buf[0] == 1.0)
9995
else:
10096
shutdown = False
10197
## End if
@@ -158,9 +154,12 @@ def __init__(self, spbase_object, fullcomm, strata_comm, cylinder_comm, options=
158154
def register_send_fields(self) -> None:
159155
super().register_send_fields()
160156
self._bound = self.send_buffers[self.bound_type()]
161-
self._hub_bounds = self.register_recv_field(Field.BEST_OBJECTIVE_BOUNDS, 0, 2)
162157
return
163158

159+
def register_receive_fields(self) -> None:
160+
super().register_receive_fields()
161+
self._hub_bounds = self.register_recv_field(Field.BEST_OBJECTIVE_BOUNDS, 0, 2)
162+
164163
@abc.abstractmethod
165164
def bound_type(self) -> Field:
166165
pass
@@ -203,19 +202,6 @@ def nonant_len_type(self) -> Field:
203202
# TODO: Make this a static method?
204203
pass
205204

206-
def register_send_fields(self) -> None:
207-
208-
super().register_send_fields()
209-
210-
vbuflen = 0
211-
for s in self.opt.local_scenarios.values():
212-
vbuflen += len(s._mpisppy_data.nonant_indices)
213-
## End for
214-
215-
self.register_recv_field(self.nonant_len_type(), 0, vbuflen)
216-
217-
return
218-
219205

220206
class InnerBoundSpoke(_BoundSpoke):
221207
""" For Spokes that provide an inner bound through self.bound to the

0 commit comments

Comments
 (0)