|
22 | 22 | import numpy as np
|
23 | 23 | import abc
|
24 | 24 | import time
|
25 |
| -import itertools |
26 | 25 |
|
27 | 26 | from mpisppy.cylinders.spwindow import Field, FieldLengths, SPWindow
|
28 | 27 |
|
@@ -191,25 +190,40 @@ def _build_window_spec(self) -> dict[Field, int]:
|
191 | 190 |
|
192 | 191 | def _exchange_send_fields(self) -> None:
|
193 | 192 | """ Do an all-to-all so we know what the other communicators are sending """
|
194 |
| - self.send_fields_by_rank = self.strata_comm.allgather(tuple(self.send_buffers.keys())) |
| 193 | + send_buffers = tuple((k, buff._length) for k, buff in self.send_buffers.items()) |
| 194 | + self.send_fields_lengths_by_rank = self.strata_comm.allgather(send_buffers) |
| 195 | + |
| 196 | + self.send_fields_by_rank = {} |
195 | 197 |
|
196 | 198 | self.available_receive_fields = {}
|
197 |
| - for rank, fields in enumerate(self.send_fields_by_rank): |
| 199 | + for rank, fields_lengths in enumerate(self.send_fields_lengths_by_rank): |
198 | 200 | if rank == self.strata_rank:
|
199 | 201 | continue
|
200 |
| - for f in fields: |
| 202 | + self.send_fields_by_rank[rank] = [] |
| 203 | + for f, length in fields_lengths: |
201 | 204 | if f not in self.available_receive_fields:
|
202 | 205 | self.available_receive_fields[f] = []
|
203 | 206 | self.available_receive_fields[f].append(rank)
|
| 207 | + self.send_fields_by_rank[rank].append(f) |
| 208 | + |
| 209 | + # print(f"{self.__class__.__name__}: {self.available_receive_fields=}") |
204 | 210 |
|
205 | 211 | def register_recv_field(self, field: Field, origin: int, length: int = -1) -> RecvArray:
|
| 212 | + # print(f"{self.__class__.__name__}.register_recv_field, {field=}, {origin=}") |
206 | 213 | key = self._make_key(field, origin)
|
207 | 214 | if length == -1:
|
208 | 215 | length = self._field_lengths[field]
|
209 | 216 | if key in self.receive_buffers:
|
210 | 217 | my_fa = self.receive_buffers[key]
|
211 | 218 | assert(length + 1 == np.size(my_fa.array()))
|
212 | 219 | else:
|
| 220 | + available_fields_from_origin = self.send_fields_lengths_by_rank[origin] |
| 221 | + for _field, _length in available_fields_from_origin: |
| 222 | + if field == _field: |
| 223 | + assert length == _length |
| 224 | + break |
| 225 | + else: # couldn't find field! |
| 226 | + raise RuntimeError(f"Couldn't find {field=} from {origin=}") |
213 | 227 | my_fa = RecvArray(length)
|
214 | 228 | self.receive_buffers[key] = my_fa
|
215 | 229 | ## End if
|
@@ -284,6 +298,7 @@ def register_send_fields(self) -> None:
|
284 | 298 | self.register_send_field(field)
|
285 | 299 |
|
286 | 300 | def register_receive_fields(self) -> None:
|
| 301 | + # print(f"{self.__class__.__name__}: {self.receive_fields=}") |
287 | 302 | for field in self.receive_fields:
|
288 | 303 | self.receive_field_spcomms[field] = []
|
289 | 304 | for strata_rank, comm in enumerate(self.communicators):
|
|
0 commit comments