Skip to content

Commit cf2d51e

Browse files
committed
Support (N, M, L) per-dimension zones in octree containers
1 parent 1e94e33 commit cf2d51e

13 files changed

Lines changed: 129 additions & 78 deletions

File tree

yt/data_objects/index_subobjects/octree_subset.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,10 @@ class OctreeSubset(YTSelectionContainer, abc.ABC):
4646

4747
def __init__(self, base_region, domain, ds, num_zones=2, num_ghost_zones=0):
4848
super().__init__(ds, None)
49-
self._num_zones = num_zones
49+
if hasattr(num_zones, "__len__"):
50+
self._num_zones = np.array(num_zones, dtype="int64")
51+
else:
52+
self._num_zones = np.array([num_zones, num_zones, num_zones], dtype="int64")
5053
self._num_ghost_zones = num_ghost_zones
5154
self.domain = domain
5255
self.domain_id = domain.domain_id
@@ -80,23 +83,28 @@ def __getitem__(self, key):
8083

8184
@property
8285
def nz(self):
83-
return self._num_zones + 2 * self._num_ghost_zones
86+
nz = self._num_zones + 2 * self._num_ghost_zones
87+
if hasattr(nz, "__len__"):
88+
return nz
89+
return np.array([nz, nz, nz], dtype="int64")
8490

8591
def get_bbox(self):
8692
return self.base_region.get_bbox()
8793

8894
def _reshape_vals(self, arr):
8995
nz = self.nz
96+
nzx, nzy, nzz = nz[0], nz[1], nz[2]
97+
nzones = nzx * nzy * nzz
9098
if len(arr.shape) <= 2:
91-
n_oct = arr.shape[0] // (nz**3)
99+
n_oct = arr.shape[0] // nzones
92100
elif arr.shape[-1] == 3:
93101
n_oct = arr.shape[-2]
94102
else:
95103
n_oct = arr.shape[-1]
96-
if arr.size == nz * nz * nz * n_oct:
97-
new_shape = (nz, nz, nz, n_oct)
98-
elif arr.size == nz * nz * nz * n_oct * 3:
99-
new_shape = (nz, nz, nz, n_oct, 3)
104+
if arr.size == nzones * n_oct:
105+
new_shape = (nzx, nzy, nzz, n_oct)
106+
elif arr.size == nzones * n_oct * 3:
107+
new_shape = (nzx, nzy, nzz, n_oct, 3)
100108
else:
101109
raise RuntimeError
102110
# Note that if arr is already F-contiguous, this *shouldn't* copy the
@@ -172,7 +180,7 @@ def deposit(self, positions, fields=None, method=None, kernel_name="cubic"):
172180
if cls is None:
173181
raise YTParticleDepositionNotImplemented(method)
174182
nz = self.nz
175-
nvals = (nz, nz, nz, (self.domain_ind >= 0).sum())
183+
nvals = (int(nz[0]), int(nz[1]), int(nz[2]), (self.domain_ind >= 0).sum())
176184
if np.max(self.domain_ind) >= nvals[-1]:
177185
print(
178186
f"nocts, domain_ind >= 0, max {self.oct_handler.nocts} {nvals[-1]} {np.max(self.domain_ind)}"
@@ -335,7 +343,7 @@ def smooth(
335343
[1, 1, 1],
336344
self.ds.domain_left_edge,
337345
self.ds.domain_right_edge,
338-
num_zones=self._nz,
346+
num_zones=self._num_zones,
339347
)
340348
# This should ensure we get everything within one neighbor of home.
341349
particle_octree.n_ref = nneighbors * 2
@@ -354,7 +362,7 @@ def smooth(
354362
raise YTParticleDepositionNotImplemented(method)
355363
nz = self.nz
356364
mdom_ind = self.domain_ind
357-
nvals = (nz, nz, nz, (mdom_ind >= 0).sum())
365+
nvals = (int(nz[0]), int(nz[1]), int(nz[2]), (mdom_ind >= 0).sum())
358366
op = cls(nvals, len(fields), nneighbors, kernel_name)
359367
op.initialize()
360368
mylog.debug(
@@ -455,7 +463,7 @@ def particle_operation(
455463
raise YTParticleDepositionNotImplemented(method)
456464
nz = self.nz
457465
mdom_ind = self.domain_ind
458-
nvals = (nz, nz, nz, (mdom_ind >= 0).sum())
466+
nvals = (int(nz[0]), int(nz[1]), int(nz[2]), (mdom_ind >= 0).sum())
459467
op = cls(nvals, len(fields), nneighbors, kernel_name)
460468
op.initialize()
461469
mylog.debug(
@@ -548,7 +556,7 @@ def __init__(self, ind, block_slice):
548556
self.ind = ind
549557
self.block_slice = block_slice
550558
nz = self.block_slice.octree_subset.nz
551-
self.ActiveDimensions = np.array([nz, nz, nz], dtype="int64")
559+
self.ActiveDimensions = np.array([nz[0], nz[1], nz[2]], dtype="int64")
552560
self.ds = block_slice.ds
553561

554562
def __getitem__(self, key):

yt/data_objects/tests/test_octree.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
from numpy.testing import assert_almost_equal, assert_equal
33

4+
from yt.geometry.oct_container import OctreeContainer
45
from yt.testing import fake_sph_grid_ds
56

67
n_ref = 4
@@ -118,3 +119,26 @@ def test_octree_properties():
118119
refined = octree["index", "refined"]
119120
refined_ans = np.array([True] + [False] * 7 + [True] + [False] * 8, dtype=np.bool_)
120121
assert_equal(refined, refined_ans)
122+
123+
124+
def test_num_zones_tuple():
125+
"""
126+
Test that OctreeContainer accepts num_zones as a scalar or a tuple (N, M, L).
127+
Both should correctly set per-dimension zone counts.
128+
"""
129+
# Scalar: all dimensions equal
130+
oct_scalar = OctreeContainer(
131+
[1, 1, 1], [0.0, 0.0, 0.0], [1.0, 1.0, 1.0], num_zones=2
132+
)
133+
# Tuple: potentially different per-dimension
134+
oct_tuple = OctreeContainer(
135+
[1, 1, 1], [0.0, 0.0, 0.0], [1.0, 1.0, 1.0], num_zones=(2, 2, 2)
136+
)
137+
# Non-uniform tuple
138+
oct_nonuniform = OctreeContainer(
139+
[1, 1, 1], [0.0, 0.0, 0.0], [1.0, 1.0, 1.0], num_zones=(2, 3, 4)
140+
)
141+
# Verify that creating these containers doesn't raise exceptions
142+
assert oct_scalar is not None
143+
assert oct_tuple is not None
144+
assert oct_nonuniform is not None

yt/frontends/artio/data_structures.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def deposit(self, positions, fields=None, method=None, kernel_name="cubic"):
128128
if cls is None:
129129
raise YTParticleDepositionNotImplemented(method)
130130
nz = self.nz
131-
nvals = (nz, nz, nz, self.ires.size)
131+
nvals = (int(nz[0]), int(nz[1]), int(nz[2]), self.ires.size)
132132
# We allocate number of zones, not number of octs
133133
op = cls(nvals, kernel_name)
134134
op.initialize()
@@ -241,6 +241,7 @@ def _identify_base_chunk(self, dobj):
241241
sfc_start = getattr(dobj, "sfc_start", None)
242242
sfc_end = getattr(dobj, "sfc_end", None)
243243
nz = getattr(dobj, "_num_zones", 0)
244+
nz_scalar = int(np.asarray(nz).flat[0]) if hasattr(nz, "__len__") else nz
244245
if all_data:
245246
mylog.debug("Selecting entire artio domain")
246247
list_sfc_ranges = self.ds._handle.root_sfc_ranges_all(
@@ -271,7 +272,7 @@ def _identify_base_chunk(self, dobj):
271272
)
272273
range_handler.construct_mesh()
273274
self.range_handlers[start, end] = range_handler
274-
if nz != 2:
275+
if nz_scalar != 2:
275276
ci.append(
276277
ARTIORootMeshSubset(
277278
base_region,
@@ -281,7 +282,7 @@ def _identify_base_chunk(self, dobj):
281282
self.ds,
282283
)
283284
)
284-
if nz != 1 and range_handler.total_octs > 0:
285+
if nz_scalar != 1 and range_handler.total_octs > 0:
285286
ci.append(
286287
ARTIOOctreeSubset(
287288
base_region,

yt/frontends/ramses/data_structures.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -467,8 +467,8 @@ def _fill_with_ghostzones(
467467
fields = [f for ft, f in fields]
468468
tr = {}
469469

470-
cell_count = (
471-
selector.count_octs(self.oct_handler, self.domain_id) * self.nz**ndim
470+
cell_count = selector.count_octs(self.oct_handler, self.domain_id) * int(
471+
np.prod(self.nz[:ndim])
472472
)
473473

474474
# Initializing data container
@@ -520,7 +520,7 @@ def fwidth(self):
520520
# new_fwidth contains the fwidth of the oct+ghost zones
521521
# this is a constant array in each oct, so we simply copy
522522
# the oct value using numpy fancy-indexing
523-
new_fwidth = np.zeros((n_oct, self.nz**3, 3), dtype=fwidth.dtype)
523+
new_fwidth = np.zeros((n_oct, int(np.prod(self.nz)), 3), dtype=fwidth.dtype)
524524
new_fwidth[:, :, :] = fwidth[:, 0:1, :]
525525
fwidth = new_fwidth.reshape(-1, 3)
526526
return fwidth
@@ -538,7 +538,7 @@ def fcoords(self):
538538
self.selector, self._num_ghost_zones
539539
)
540540

541-
N_per_oct = self.nz**3
541+
N_per_oct = int(np.prod(self.nz))
542542
oct_inds = oct_inds.reshape(-1, N_per_oct)
543543
cell_inds = cell_inds.reshape(-1, N_per_oct)
544544

yt/frontends/stream/data_structures.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -832,8 +832,8 @@ def _fill_no_ghostzones(self, content, dest, selector, offset):
832832
def _fill_with_ghostzones(self, content, dest, selector, offset):
833833
oct_handler = self.oct_handler
834834
ndim = self.ds.dimensionality
835-
cell_count = (
836-
selector.count_octs(self.oct_handler, self.domain_id) * self.nz**ndim
835+
cell_count = selector.count_octs(self.oct_handler, self.domain_id) * int(
836+
np.prod(self.nz[:ndim])
837837
)
838838

839839
gz_cache = getattr(self, "_ghost_zone_cache", None)

yt/geometry/_selection_routines/selector_object.pxi

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ cdef class SelectorObject:
159159
visitor.pos[1] = (visitor.pos[1] >> 1)
160160
visitor.pos[2] = (visitor.pos[2] >> 1)
161161
visitor.level -= 1
162-
elif this_level == 1 and visitor.nz > 1:
162+
elif this_level == 1 and (visitor.nz[0] > 1 or visitor.nz[1] > 1 or visitor.nz[2] > 1):
163163
visitor.global_index += increment
164164
increment = 0
165165
self.visit_oct_cells(root, ch, spos, sdds,
@@ -178,10 +178,10 @@ cdef class SelectorObject:
178178
cdef void visit_oct_cells(self, Oct *root, Oct *ch,
179179
np.float64_t spos[3], np.float64_t sdds[3],
180180
OctVisitor visitor, int i, int j, int k):
181-
# We can short-circuit the whole process if data.nz == 2.
181+
# We can short-circuit the whole process if data.nz == 2 in all dims.
182182
# This saves us some funny-business.
183183
cdef int selected
184-
if visitor.nz == 2:
184+
if visitor.nz[0] == 2 and visitor.nz[1] == 2 and visitor.nz[2] == 2:
185185
selected = self.select_cell(spos, sdds)
186186
if ch != NULL:
187187
selected *= self.overlap_cells
@@ -199,22 +199,25 @@ cdef class SelectorObject:
199199
cdef np.float64_t dds[3]
200200
cdef np.float64_t pos[3]
201201
cdef int ci, cj, ck
202-
cdef int nr = (visitor.nz >> 1)
202+
cdef int nr[3]
203+
nr[0] = visitor.nz[0] >> 1
204+
nr[1] = visitor.nz[1] >> 1
205+
nr[2] = visitor.nz[2] >> 1
203206
for ci in range(3):
204-
dds[ci] = sdds[ci] / nr
207+
dds[ci] = sdds[ci] / nr[ci]
205208
# Boot strap at the first index.
206209
pos[0] = (spos[0] - sdds[0]/2.0) + dds[0] * 0.5
207-
for ci in range(nr):
210+
for ci in range(nr[0]):
208211
pos[1] = (spos[1] - sdds[1]/2.0) + dds[1] * 0.5
209-
for cj in range(nr):
212+
for cj in range(nr[1]):
210213
pos[2] = (spos[2] - sdds[2]/2.0) + dds[2] * 0.5
211-
for ck in range(nr):
214+
for ck in range(nr[2]):
212215
selected = self.select_cell(pos, dds)
213216
if ch != NULL:
214217
selected *= self.overlap_cells
215-
visitor.ind[0] = ci + i * nr
216-
visitor.ind[1] = cj + j * nr
217-
visitor.ind[2] = ck + k * nr
218+
visitor.ind[0] = ci + i * nr[0]
219+
visitor.ind[1] = cj + j * nr[1]
220+
visitor.ind[2] = ck + k * nr[2]
218221
visitor.visit(root, selected)
219222
pos[2] += dds[2]
220223
pos[1] += dds[1]

yt/geometry/oct_container.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ cdef class OctreeContainer:
5757
cdef int partial_coverage
5858
cdef int level_offset
5959
cdef int nn[3]
60-
cdef np.uint8_t nz
60+
cdef np.uint8_t nz[3]
6161
cdef np.float64_t DLE[3]
6262
cdef np.float64_t DRE[3]
6363
cdef public np.int64_t nocts

0 commit comments

Comments
 (0)