Skip to content

Commit 11e166f

Browse files
authored
Merge pull request #5391 from cphyc/feature/support-non-cubic-zones
Support (N, M, L) per-dimension zones in octree containers
2 parents fcfc44e + e569ea9 commit 11e166f

13 files changed

Lines changed: 180 additions & 112 deletions

File tree

yt/data_objects/index_subobjects/octree_subset.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,10 @@ class OctreeSubset(YTSelectionContainer, abc.ABC):
4646

4747
def __init__(self, base_region, domain, ds, num_zones=2, num_ghost_zones=0):
4848
super().__init__(ds, None)
49-
self._num_zones = num_zones
49+
if hasattr(num_zones, "__len__"):
50+
self._num_zones = np.array(num_zones, dtype="int64")
51+
else:
52+
self._num_zones = np.array([num_zones, num_zones, num_zones], dtype="int64")
5053
self._num_ghost_zones = num_ghost_zones
5154
self.domain = domain
5255
self.domain_id = domain.domain_id
@@ -80,23 +83,28 @@ def __getitem__(self, key):
8083

8184
@property
8285
def nz(self):
83-
return self._num_zones + 2 * self._num_ghost_zones
86+
nz = self._num_zones + 2 * self._num_ghost_zones
87+
if hasattr(nz, "__len__"):
88+
return nz
89+
return np.array([nz, nz, nz], dtype="int64")
8490

8591
def get_bbox(self):
8692
return self.base_region.get_bbox()
8793

8894
def _reshape_vals(self, arr):
8995
nz = self.nz
96+
nzx, nzy, nzz = nz[0], nz[1], nz[2]
97+
nzones = nzx * nzy * nzz
9098
if len(arr.shape) <= 2:
91-
n_oct = arr.shape[0] // (nz**3)
99+
n_oct = arr.shape[0] // nzones
92100
elif arr.shape[-1] == 3:
93101
n_oct = arr.shape[-2]
94102
else:
95103
n_oct = arr.shape[-1]
96-
if arr.size == nz * nz * nz * n_oct:
97-
new_shape = (nz, nz, nz, n_oct)
98-
elif arr.size == nz * nz * nz * n_oct * 3:
99-
new_shape = (nz, nz, nz, n_oct, 3)
104+
if arr.size == nzones * n_oct:
105+
new_shape = (nzx, nzy, nzz, n_oct)
106+
elif arr.size == nzones * n_oct * 3:
107+
new_shape = (nzx, nzy, nzz, n_oct, 3)
100108
else:
101109
raise RuntimeError
102110
# Note that if arr is already F-contiguous, this *shouldn't* copy the
@@ -172,7 +180,7 @@ def deposit(self, positions, fields=None, method=None, kernel_name="cubic"):
172180
if cls is None:
173181
raise YTParticleDepositionNotImplemented(method)
174182
nz = self.nz
175-
nvals = (nz, nz, nz, (self.domain_ind >= 0).sum())
183+
nvals = (int(nz[0]), int(nz[1]), int(nz[2]), (self.domain_ind >= 0).sum())
176184
if np.max(self.domain_ind) >= nvals[-1]:
177185
print(
178186
f"nocts, domain_ind >= 0, max {self.oct_handler.nocts} {nvals[-1]} {np.max(self.domain_ind)}"
@@ -335,7 +343,7 @@ def smooth(
335343
[1, 1, 1],
336344
self.ds.domain_left_edge,
337345
self.ds.domain_right_edge,
338-
num_zones=self._nz,
346+
num_zones=self._num_zones,
339347
)
340348
# This should ensure we get everything within one neighbor of home.
341349
particle_octree.n_ref = nneighbors * 2
@@ -354,7 +362,7 @@ def smooth(
354362
raise YTParticleDepositionNotImplemented(method)
355363
nz = self.nz
356364
mdom_ind = self.domain_ind
357-
nvals = (nz, nz, nz, (mdom_ind >= 0).sum())
365+
nvals = (int(nz[0]), int(nz[1]), int(nz[2]), (mdom_ind >= 0).sum())
358366
op = cls(nvals, len(fields), nneighbors, kernel_name)
359367
op.initialize()
360368
mylog.debug(
@@ -455,7 +463,7 @@ def particle_operation(
455463
raise YTParticleDepositionNotImplemented(method)
456464
nz = self.nz
457465
mdom_ind = self.domain_ind
458-
nvals = (nz, nz, nz, (mdom_ind >= 0).sum())
466+
nvals = (int(nz[0]), int(nz[1]), int(nz[2]), (mdom_ind >= 0).sum())
459467
op = cls(nvals, len(fields), nneighbors, kernel_name)
460468
op.initialize()
461469
mylog.debug(
@@ -548,7 +556,7 @@ def __init__(self, ind, block_slice):
548556
self.ind = ind
549557
self.block_slice = block_slice
550558
nz = self.block_slice.octree_subset.nz
551-
self.ActiveDimensions = np.array([nz, nz, nz], dtype="int64")
559+
self.ActiveDimensions = np.array([nz[0], nz[1], nz[2]], dtype="int64")
552560
self.ds = block_slice.ds
553561

554562
def __getitem__(self, key):

yt/data_objects/tests/test_octree.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
from numpy.testing import assert_almost_equal, assert_equal
33

4+
from yt.geometry.oct_container import OctreeContainer
45
from yt.testing import fake_sph_grid_ds
56

67
n_ref = 4
@@ -118,3 +119,26 @@ def test_octree_properties():
118119
refined = octree["index", "refined"]
119120
refined_ans = np.array([True] + [False] * 7 + [True] + [False] * 8, dtype=np.bool_)
120121
assert_equal(refined, refined_ans)
122+
123+
124+
def test_num_zones_tuple():
125+
"""
126+
Test that OctreeContainer accepts num_zones as a scalar or a tuple (N, M, L).
127+
Both should correctly set per-dimension zone counts.
128+
"""
129+
# Scalar: all dimensions equal
130+
oct_scalar = OctreeContainer(
131+
[1, 1, 1], [0.0, 0.0, 0.0], [1.0, 1.0, 1.0], num_zones=2
132+
)
133+
# Tuple: potentially different per-dimension
134+
oct_tuple = OctreeContainer(
135+
[1, 1, 1], [0.0, 0.0, 0.0], [1.0, 1.0, 1.0], num_zones=(2, 2, 2)
136+
)
137+
# Non-uniform tuple
138+
oct_nonuniform = OctreeContainer(
139+
[1, 1, 1], [0.0, 0.0, 0.0], [1.0, 1.0, 1.0], num_zones=(2, 3, 4)
140+
)
141+
# Verify that creating these containers doesn't raise exceptions
142+
assert oct_scalar is not None
143+
assert oct_tuple is not None
144+
assert oct_nonuniform is not None

yt/frontends/artio/data_structures.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def deposit(self, positions, fields=None, method=None, kernel_name="cubic"):
128128
if cls is None:
129129
raise YTParticleDepositionNotImplemented(method)
130130
nz = self.nz
131-
nvals = (nz, nz, nz, self.ires.size)
131+
nvals = (int(nz[0]), int(nz[1]), int(nz[2]), self.ires.size)
132132
# We allocate number of zones, not number of octs
133133
op = cls(nvals, kernel_name)
134134
op.initialize()
@@ -241,6 +241,7 @@ def _identify_base_chunk(self, dobj):
241241
sfc_start = getattr(dobj, "sfc_start", None)
242242
sfc_end = getattr(dobj, "sfc_end", None)
243243
nz = getattr(dobj, "_num_zones", 0)
244+
nz_scalar = int(np.asarray(nz).flat[0]) if hasattr(nz, "__len__") else nz
244245
if all_data:
245246
mylog.debug("Selecting entire artio domain")
246247
list_sfc_ranges = self.ds._handle.root_sfc_ranges_all(
@@ -271,7 +272,7 @@ def _identify_base_chunk(self, dobj):
271272
)
272273
range_handler.construct_mesh()
273274
self.range_handlers[start, end] = range_handler
274-
if nz != 2:
275+
if nz_scalar != 2:
275276
ci.append(
276277
ARTIORootMeshSubset(
277278
base_region,
@@ -281,7 +282,7 @@ def _identify_base_chunk(self, dobj):
281282
self.ds,
282283
)
283284
)
284-
if nz != 1 and range_handler.total_octs > 0:
285+
if nz_scalar != 1 and range_handler.total_octs > 0:
285286
ci.append(
286287
ARTIOOctreeSubset(
287288
base_region,

yt/frontends/ramses/data_structures.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -468,8 +468,8 @@ def _fill_with_ghostzones(
468468
fields = [f for ft, f in fields]
469469
tr = {}
470470

471-
cell_count = (
472-
selector.count_octs(self.oct_handler, self.domain_id) * self.nz**ndim
471+
cell_count = selector.count_octs(self.oct_handler, self.domain_id) * int(
472+
np.prod(self.nz[:ndim])
473473
)
474474

475475
# Initializing data container
@@ -522,7 +522,7 @@ def fwidth(self):
522522
# new_fwidth contains the fwidth of the oct+ghost zones
523523
# this is a constant array in each oct, so we simply copy
524524
# the oct value using numpy fancy-indexing
525-
new_fwidth = np.zeros((n_oct, self.nz**3, 3), dtype=fwidth.dtype)
525+
new_fwidth = np.zeros((n_oct, int(np.prod(self.nz)), 3), dtype=fwidth.dtype)
526526
new_fwidth[:, :, :] = fwidth[:, 0:1, :]
527527
fwidth = new_fwidth.reshape(-1, 3)
528528
return fwidth
@@ -540,7 +540,7 @@ def fcoords(self):
540540
self.selector, self._num_ghost_zones
541541
)
542542

543-
N_per_oct = self.nz**3
543+
N_per_oct = int(np.prod(self.nz))
544544
oct_inds = oct_inds.reshape(-1, N_per_oct)
545545
cell_inds = cell_inds.reshape(-1, N_per_oct)
546546

yt/frontends/stream/data_structures.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -832,8 +832,8 @@ def _fill_no_ghostzones(self, content, dest, selector, offset):
832832
def _fill_with_ghostzones(self, content, dest, selector, offset):
833833
oct_handler = self.oct_handler
834834
ndim = self.ds.dimensionality
835-
cell_count = (
836-
selector.count_octs(self.oct_handler, self.domain_id) * self.nz**ndim
835+
cell_count = selector.count_octs(self.oct_handler, self.domain_id) * int(
836+
np.prod(self.nz[:ndim])
837837
)
838838

839839
gz_cache = getattr(self, "_ghost_zone_cache", None)

yt/geometry/_selection_routines/selector_object.pxi

Lines changed: 44 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ cdef class SelectorObject:
159159
visitor.pos[1] = (visitor.pos[1] >> 1)
160160
visitor.pos[2] = (visitor.pos[2] >> 1)
161161
visitor.level -= 1
162-
elif this_level == 1 and visitor.nz > 1:
162+
elif this_level == 1 and (visitor.nz[0] > 1 or visitor.nz[1] > 1 or visitor.nz[2] > 1):
163163
visitor.global_index += increment
164164
increment = 0
165165
self.visit_oct_cells(root, ch, spos, sdds,
@@ -178,10 +178,22 @@ cdef class SelectorObject:
178178
cdef void visit_oct_cells(self, Oct *root, Oct *ch,
179179
np.float64_t spos[3], np.float64_t sdds[3],
180180
OctVisitor visitor, int i, int j, int k):
181-
# We can short-circuit the whole process if data.nz == 2.
182-
# This saves us some funny-business.
181+
"""Visit the cells in this oct.
182+
183+
Parameters
184+
----------
185+
root: The oct whose cells we are visiting.
186+
ch: The child oct, if it exists.
187+
spos: The position of a potential cell center, assuming that the
188+
oct contains 8 cells.
189+
sdds: The cell size, assuming that the oct contains 8 cells.
190+
visitor: The visitor object that is visiting the cells.
191+
i, j, k: The indices of the cell within the oct.
192+
"""
183193
cdef int selected
184-
if visitor.nz == 2:
194+
# If visitor.nz is 2 in all dimensions, then the passed spos and sdds
195+
# are correct and we just need to call `select_cell` on them.
196+
if visitor.nz[0] == 2 and visitor.nz[1] == 2 and visitor.nz[2] == 2:
185197
selected = self.select_cell(spos, sdds)
186198
if ch != NULL:
187199
selected *= self.overlap_cells
@@ -191,34 +203,42 @@ cdef class SelectorObject:
191203
visitor.ind[2] = k
192204
visitor.visit(root, selected)
193205
return
194-
# Okay, now that we've got that out of the way, we have to do some
195-
# other checks here. In this case, spos[] is the position of the
196-
# center of a *possible* oct child, which means it is the center of a
197-
# cluster of cells. That cluster might have 1, 8, 64, ... cells in it.
198-
# But, we can figure it out by calculating the cell dds.
206+
# Otherwise, we have to do some work to figure out where the cell centers are.
207+
# We assign integer index ranges to each octant using half-open bounds.
199208
cdef np.float64_t dds[3]
200209
cdef np.float64_t pos[3]
210+
cdef np.float64_t full_left[3]
201211
cdef int ci, cj, ck
202-
cdef int nr = (visitor.nz >> 1)
212+
cdef int start[3]
213+
cdef int end[3]
214+
cdef int split
215+
cdef int oct_ind[3]
216+
oct_ind[0] = i
217+
oct_ind[1] = j
218+
oct_ind[2] = k
203219
for ci in range(3):
204-
dds[ci] = sdds[ci] / nr
205-
# Boot strap at the first index.
206-
pos[0] = (spos[0] - sdds[0]/2.0) + dds[0] * 0.5
207-
for ci in range(nr):
208-
pos[1] = (spos[1] - sdds[1]/2.0) + dds[1] * 0.5
209-
for cj in range(nr):
210-
pos[2] = (spos[2] - sdds[2]/2.0) + dds[2] * 0.5
211-
for ck in range(nr):
220+
dds[ci] = (2.0 * sdds[ci]) / visitor.nz[ci]
221+
full_left[ci] = (spos[ci] - sdds[ci] / 2.0) - oct_ind[ci] * sdds[ci]
222+
split = visitor.nz[ci] // 2
223+
if oct_ind[ci] == 0:
224+
start[ci] = 0
225+
end[ci] = split
226+
else:
227+
start[ci] = split
228+
end[ci] = visitor.nz[ci]
229+
for ci in range(start[0], end[0]):
230+
pos[0] = full_left[0] + (ci + 0.5) * dds[0]
231+
for cj in range(start[1], end[1]):
232+
pos[1] = full_left[1] + (cj + 0.5) * dds[1]
233+
for ck in range(start[2], end[2]):
234+
pos[2] = full_left[2] + (ck + 0.5) * dds[2]
212235
selected = self.select_cell(pos, dds)
213236
if ch != NULL:
214237
selected *= self.overlap_cells
215-
visitor.ind[0] = ci + i * nr
216-
visitor.ind[1] = cj + j * nr
217-
visitor.ind[2] = ck + k * nr
238+
visitor.ind[0] = ci
239+
visitor.ind[1] = cj
240+
visitor.ind[2] = ck
218241
visitor.visit(root, selected)
219-
pos[2] += dds[2]
220-
pos[1] += dds[1]
221-
pos[0] += dds[0]
222242

223243
@cython.boundscheck(False)
224244
@cython.wraparound(False)

yt/geometry/oct_container.pxd

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ cdef class OctreeContainer:
5757
cdef int partial_coverage
5858
cdef int level_offset
5959
cdef int nn[3]
60-
cdef np.uint8_t nz
60+
cdef np.uint8_t nz[3]
6161
cdef np.float64_t DLE[3]
6262
cdef np.float64_t DRE[3]
6363
cdef public np.int64_t nocts
@@ -86,7 +86,7 @@ cdef class OctreeContainer:
8686
self,
8787
const int level,
8888
const np.uint8_t[::1] level_inds,
89-
const np.uint8_t[::1] cell_inds,
89+
const np.uint32_t[::1] cell_inds,
9090
const np.int64_t[::1] file_inds,
9191
dict dest_fields,
9292
dict source_fields,
@@ -96,7 +96,7 @@ cdef class OctreeContainer:
9696
self,
9797
const int level,
9898
const np.uint8_t[::1] level_inds,
99-
const np.uint8_t[::1] cell_inds,
99+
const np.uint32_t[::1] cell_inds,
100100
const np.int64_t[::1] file_inds,
101101
const np.int32_t[::1] domain_inds,
102102
dict dest_fields,

0 commit comments

Comments
 (0)