@@ -42,7 +42,6 @@ class FieldArray:
4242 """
4343
4444 def __init__ (self , length : int ):
45- self ._length = length
4645 self ._array = communicator_array (length )
4746 self ._id = 0
4847 return
@@ -117,8 +116,6 @@ class SPCommunicator:
117116 receive_fields = ()
118117
119118 def __init__ (self , spbase_object , fullcomm , strata_comm , cylinder_comm , communicators , options = None ):
120- # flag for if the windows have been constructed
121- self ._windows_constructed = False
122119 self .fullcomm = fullcomm
123120 self .strata_comm = strata_comm
124121 self .cylinder_comm = cylinder_comm
@@ -152,9 +149,9 @@ def __init__(self, spbase_object, fullcomm, strata_comm, cylinder_comm, communic
152149
153150 self .register_send_fields ()
154151
155- self ._exchange_send_fields ()
156- # TODO: here we can have a dynamic exchange of the send fields
157- # so we can do error checking (all-to-all in send fields)
152+ self ._make_windows ()
153+ self . _create_field_rank_mappings ()
154+
158155 self .register_receive_fields ()
159156
160157 # TODO: check that we have something in receive_field_spcomms??
@@ -188,25 +185,37 @@ def _build_window_spec(self) -> dict[Field, int]:
188185 ## End for
189186 return window_spec
190187
191- def _exchange_send_fields (self ) -> None :
192- """ Do an all-to-all so we know what the other communicators are sending """
193- send_buffers = tuple ((k , buff ._length ) for k , buff in self .send_buffers .items ())
194- self .send_fields_lengths_by_rank = self .strata_comm .allgather (send_buffers )
195-
196- self .send_fields_by_rank = {}
188+ def _create_field_rank_mappings (self ) -> None :
189+ self .fields_to_ranks = {}
190+ self .ranks_to_fields = {}
197191
198- self .available_receive_fields = {}
199- for rank , fields_lengths in enumerate (self .send_fields_lengths_by_rank ):
192+ for rank , buffer_layout in enumerate (self .window .strata_buffer_layouts ):
200193 if rank == self .strata_rank :
201194 continue
202- self .send_fields_by_rank [rank ] = []
203- for f , length in fields_lengths :
204- if f not in self .available_receive_fields :
205- self .available_receive_fields [f ] = []
206- self .available_receive_fields [f ].append (rank )
207- self .send_fields_by_rank [rank ].append (f )
208-
209- # print(f"{self.__class__.__name__}: {self.available_receive_fields=}")
195+ self .ranks_to_fields [rank ] = []
196+ for field in buffer_layout :
197+ if field not in self .fields_to_ranks :
198+ self .fields_to_ranks [field ] = []
199+ self .fields_to_ranks [field ].append (rank )
200+ self .ranks_to_fields [rank ].append (field )
201+
202+ # print(f"{self.__class__.__name__}: {self.fields_to_ranks=}, {self.ranks_to_fields=}")
203+
204+ def _validate_recv_field (self , field : Field , origin : int , length : int ):
205+ remote_buffer_layout = self .window .strata_buffer_layouts [origin ]
206+ if field not in remote_buffer_layout :
207+ raise RuntimeError (f"{ self .__class__ .__name__ } on local { self .strata_rank = } "
208+ f"could not find { field = } on remote rank { origin } with "
209+ f"class { self .communicators [origin ]['spcomm_class' ]} ."
210+ )
211+ _ , remote_length = remote_buffer_layout [field ]
212+ if (length + 1 ) != remote_length :
213+ raise RuntimeError (f"{ self .__class__ .__name__ } on local { self .strata_rank = } "
214+ f"{ field = } has length { length } on local "
215+ f"{ self .strata_rank = } and length { remote_length } "
216+ f"on remote rank { origin } with class "
217+ f"{ self .communicators [origin ]['spcomm_class' ]} ."
218+ )
210219
211220 def register_recv_field (self , field : Field , origin : int , length : int = - 1 ) -> RecvArray :
212221 # print(f"{self.__class__.__name__}.register_recv_field, {field=}, {origin=}")
@@ -217,13 +226,7 @@ def register_recv_field(self, field: Field, origin: int, length: int = -1) -> Re
217226 my_fa = self .receive_buffers [key ]
218227 assert (length + 1 == np .size (my_fa .array ()))
219228 else :
220- available_fields_from_origin = self .send_fields_lengths_by_rank [origin ]
221- for _field , _length in available_fields_from_origin :
222- if field == _field :
223- assert length == _length
224- break
225- else : # couldn't find field!
226- raise RuntimeError (f"Couldn't find { field = } from { origin = } " )
229+ self ._validate_recv_field (field , origin , length )
227230 my_fa = RecvArray (length )
228231 self .receive_buffers [key ] = my_fa
229232 ## End if
@@ -276,20 +279,10 @@ def hub_finalize(self):
276279 def allreduce_or (self , val ):
277280 return self .opt .allreduce_or (val )
278281
279- def free_windows (self ):
280- """
281- """
282- if self ._windows_constructed :
283- self .window .free ()
284- self ._windows_constructed = False
285-
286- def make_windows (self ) -> None :
287- if self ._windows_constructed :
288- return
282+ def _make_windows (self ) -> None :
289283
290284 window_spec = self ._build_window_spec ()
291285 self .window = SPWindow (window_spec , self .strata_comm )
292- self ._windows_constructed = True
293286
294287 return
295288
@@ -305,6 +298,6 @@ def register_receive_fields(self) -> None:
305298 if strata_rank == self .strata_rank :
306299 continue
307300 cls = comm ["spcomm_class" ]
308- if field in self .send_fields_by_rank [strata_rank ]:
301+ if field in self .ranks_to_fields [strata_rank ]:
309302 buff = self .register_recv_field (field , strata_rank )
310303 self .receive_field_spcomms [field ].append ((strata_rank , cls , buff ))
0 commit comments