Skip to content

Commit 76955c8

Browse files
committed
Fix coeff-space outputs and add tests
1 parent 97343d0 commit 76955c8

File tree

3 files changed

+33
-14
lines changed

3 files changed

+33
-14
lines changed

dedalus/core/distributor.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,12 @@ def local_chunks(self, domain, scales, rank=None, broadcast=False):
320320
local_chunks.append(np.arange(start, end))
321321
return tuple(local_chunks)
322322

323+
def global_elements(self, domain, scales):
324+
"""Global element indices by axis."""
325+
global_shape = self.global_shape(domain, scales)
326+
indices = [np.arange(n) for n in global_shape]
327+
return tuple(indices)
328+
323329
def local_elements(self, domain, scales, rank=None, broadcast=False):
324330
"""Local element indices by axis."""
325331
chunk_shape = self.chunk_shape(domain)
@@ -345,12 +351,7 @@ def valid_elements(self, tensorsig, domain, scales, rank=None, broadcast=False):
345351
valid &= basis.valid_elements(tensorsig, grid_space[basis_axes], elements[basis_axes])
346352
return valid
347353

348-
@CachedMethod
349-
def local_group_arrays(self, domain, scales, rank=None, broadcast=False):
350-
"""Dense array of local groups (first axis)."""
351-
# Make dense array of local elements
352-
elements = self.local_elements(domain, scales, rank=rank, broadcast=broadcast)
353-
elements = np.array(np.meshgrid(*elements, indexing='ij'))
354+
def _group_arrays(self, elements, domain):
354355
# Convert to groups basis-by-basis
355356
grid_space = self.grid_space
356357
groups = np.zeros_like(elements)
@@ -360,6 +361,22 @@ def local_group_arrays(self, domain, scales, rank=None, broadcast=False):
360361
groups[basis_axes] = basis.elements_to_groups(grid_space[basis_axes], elements[basis_axes])
361362
return groups
362363

364+
@CachedMethod
365+
def local_group_arrays(self, domain, scales, rank=None, broadcast=False):
366+
"""Dense array of local groups (first axis)."""
367+
# Make dense array of local elements
368+
elements = self.local_elements(domain, scales, rank=rank, broadcast=broadcast)
369+
elements = np.array(np.meshgrid(*elements, indexing='ij'))
370+
return self._group_arrays(elements, domain)
371+
372+
@CachedMethod
373+
def global_group_arrays(self, domain, scales):
374+
"""Dense array of local groups (first axis)."""
375+
# Make dense array of local elements
376+
elements = self.global_elements(domain, scales)
377+
elements = np.array(np.meshgrid(*elements, indexing='ij'))
378+
return self._group_arrays(elements, domain)
379+
363380
@CachedMethod
364381
def local_groupsets(self, group_coupling, domain, scales, rank=None, broadcast=False):
365382
local_groupsets = self.local_group_arrays(domain, scales, rank=rank, broadcast=broadcast).astype(object)

dedalus/core/evaluator.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -647,10 +647,9 @@ def dset_metadata(self, task, task_num, dset, scale_group, gnc_shape, gnc_start,
647647
else:
648648
sn = 'k' + basis.coordsystem.coords[subaxis].name
649649
if virtual_file:
650-
data = basis.global_elements()[subaxis].ravel()
650+
data = layout.global_group_arrays(op.domain, scales)[subaxis]
651651
else:
652-
data = basis.local_elements()[subaxis].ravel()
653-
652+
data = layout.local_group_arrays(op.domain, scales)[subaxis]
654653

655654
if self.dist.comm_cart.rank == 0:
656655
scale_hash = hashlib.sha1(data).hexdigest()

dedalus/tests/test_output.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
@pytest.mark.parametrize('dealias', [1, 3/2])
1414
@pytest.mark.parametrize('output_scales', [1, 3/2, 2,
1515
pytest.param(1/2, marks=pytest.mark.xfail(reason="evaluator not copying correctly for scales < 1"))])
16-
def test_cartesian_output(dtype, dealias, output_scales):
16+
@pytest.mark.parametrize('output_layout', ['g', 'c'])
17+
def test_cartesian_output(dtype, dealias, output_scales, output_layout):
1718
Nx = Ny = Nz = 16
1819
Lx = Ly = Lz = 2 * np.pi
1920
# Bases
@@ -41,17 +42,18 @@ def test_cartesian_output(dtype, dealias, output_scales):
4142
tempdir = pathlib.Path(tempdir).stem
4243
output = solver.evaluator.add_file_handler(tempdir, iter=1)
4344
for task in tasks:
44-
output.add_task(task, layout='g', name=str(task), scales=output_scales)
45+
output.add_task(task, layout=output_layout, name=str(task), scales=output_scales)
4546
solver.evaluator.evaluate_handlers([output])
47+
output.process_virtual_file()
4648
# Check solution
4749
#post.merge_process_files('test_output')
4850
errors = []
49-
with h5py.File(f'{tempdir}/{tempdir}_s1/{tempdir}_s1_p0.h5', mode='r') as file:
51+
with h5py.File(f'{tempdir}/{tempdir}_s1.h5', mode='r') as file:
5052
for task in tasks:
5153
task_saved = file['tasks'][str(task)][-1]
5254
task = task.evaluate()
5355
task.change_scales(output_scales)
54-
errors.append(np.max(np.abs(task['g'] - task_saved)))
56+
errors.append(np.max(np.abs(task[output_layout] - task_saved)))
5557
assert np.allclose(errors, 0)
5658

5759

@@ -107,10 +109,11 @@ def test_spherical_output(Nphi, Ntheta, Nr, k, dealias, dtype, basis, output_sca
107109
for task in tasks:
108110
output.add_task(task, layout='g', name=str(task), scales=output_scales)
109111
solver.evaluator.evaluate_handlers([output])
112+
output.process_virtual_file()
110113
# Check solution
111114
#post.merge_process_files('test_output')
112115
errors = []
113-
with h5py.File(f'{tempdir}/{tempdir}_s1/{tempdir}_s1_p0.h5', mode='r') as file:
116+
with h5py.File(f'{tempdir}/{tempdir}_s1.h5', mode='r') as file:
114117
for task in tasks:
115118
task_saved = file['tasks'][str(task)][-1]
116119
task = task.evaluate()

0 commit comments

Comments
 (0)