@@ -320,6 +320,12 @@ def local_chunks(self, domain, scales, rank=None, broadcast=False):
320
320
local_chunks .append (np .arange (start , end ))
321
321
return tuple (local_chunks )
322
322
323
+ def global_elements (self , domain , scales ):
324
+ """Global element indices by axis."""
325
+ global_shape = self .global_shape (domain , scales )
326
+ indices = [np .arange (n ) for n in global_shape ]
327
+ return tuple (indices )
328
+
323
329
def local_elements (self , domain , scales , rank = None , broadcast = False ):
324
330
"""Local element indices by axis."""
325
331
chunk_shape = self .chunk_shape (domain )
@@ -345,12 +351,7 @@ def valid_elements(self, tensorsig, domain, scales, rank=None, broadcast=False):
345
351
valid &= basis .valid_elements (tensorsig , grid_space [basis_axes ], elements [basis_axes ])
346
352
return valid
347
353
348
- @CachedMethod
349
- def local_group_arrays (self , domain , scales , rank = None , broadcast = False ):
350
- """Dense array of local groups (first axis)."""
351
- # Make dense array of local elements
352
- elements = self .local_elements (domain , scales , rank = rank , broadcast = broadcast )
353
- elements = np .array (np .meshgrid (* elements , indexing = 'ij' ))
354
+ def _group_arrays (self , elements , domain ):
354
355
# Convert to groups basis-by-basis
355
356
grid_space = self .grid_space
356
357
groups = np .zeros_like (elements )
@@ -360,6 +361,22 @@ def local_group_arrays(self, domain, scales, rank=None, broadcast=False):
360
361
groups [basis_axes ] = basis .elements_to_groups (grid_space [basis_axes ], elements [basis_axes ])
361
362
return groups
362
363
364
+ @CachedMethod
365
+ def local_group_arrays (self , domain , scales , rank = None , broadcast = False ):
366
+ """Dense array of local groups (first axis)."""
367
+ # Make dense array of local elements
368
+ elements = self .local_elements (domain , scales , rank = rank , broadcast = broadcast )
369
+ elements = np .array (np .meshgrid (* elements , indexing = 'ij' ))
370
+ return self ._group_arrays (elements , domain )
371
+
372
+ @CachedMethod
373
+ def global_group_arrays (self , domain , scales ):
374
+ """Dense array of local groups (first axis)."""
375
+ # Make dense array of local elements
376
+ elements = self .global_elements (domain , scales )
377
+ elements = np .array (np .meshgrid (* elements , indexing = 'ij' ))
378
+ return self ._group_arrays (elements , domain )
379
+
363
380
@CachedMethod
364
381
def local_groupsets (self , group_coupling , domain , scales , rank = None , broadcast = False ):
365
382
local_groupsets = self .local_group_arrays (domain , scales , rank = rank , broadcast = broadcast ).astype (object )
0 commit comments