Skip to content

Commit e4679cf

Browse files
committed
Implement fcoords etc. using subclass of OctreeSubset
1 parent 7023682 commit e4679cf

1 file changed

Lines changed: 184 additions & 38 deletions

File tree

yt/frontends/dyablo/data_structures.py

Lines changed: 184 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import h5py
99
import numpy as np
1010

11+
from yt.data_objects.index_subobjects.octree_subset import OctreeSubset
1112
from yt.data_objects.static_output import Dataset
1213
from yt.geometry.geometry_handler import YTDataChunk
1314
from yt.geometry.oct_container import OctreeContainer
@@ -23,6 +24,8 @@
2324
class DyabloOctreeIndex(OctreeIndex):
2425
"""Octree Index for Dyablo data with leaf-only blocks."""
2526

27+
domain_id = 1 # Dyablo uses a single domain
28+
2629
def __init__(self, ds, dataset_type):
2730
self.dataset_type = dataset_type
2831
self.dataset = ds
@@ -55,7 +58,8 @@ def _initialize_oct_handler(self):
5558
if n_cells % cells_per_block != 0:
5659
mylog.warning(
5760
f"Number of cells ({n_cells}) is not divisible by "
58-
f"cells per block ({cells_per_block}). Some cells will be ignored."
61+
f"cells per block ({cells_per_block}). "
62+
"Some cells will be ignored."
5963
)
6064

6165
# NOTE: we're storing an octree of blocks, so we insert
@@ -82,7 +86,8 @@ def _initialize_oct_handler(self):
8286
block_corners = coordinates[block_connectivity, :]
8387

8488
# Compute block bounds from all cell corners
85-
all_corners = block_corners.reshape(-1, 3) # (cells_per_block * 8, 3)
89+
# (cells_per_block * 8, 3)
90+
all_corners = block_corners.reshape(-1, 3)
8691
block_min = np.min(all_corners, axis=0)
8792
block_max = np.max(all_corners, axis=0)
8893
block_width = block_max - block_min
@@ -104,7 +109,8 @@ def _initialize_oct_handler(self):
104109
] # (cells_per_block, 8, 3)
105110

106111
# Compute block center: mean of all cell centers
107-
cell_centers = np.mean(block_corners, axis=1) # (cells_per_block, 3)
112+
# (cells_per_block, 3)
113+
cell_centers = np.mean(block_corners, axis=1)
108114
block_centers[block_id] = np.mean(cell_centers, axis=0)
109115

110116
# Compute block size and level
@@ -116,9 +122,8 @@ def _initialize_oct_handler(self):
116122
# Level is determined by block size relative to coarsest block
117123
max_dim_width = np.max(block_width)
118124
refinement_ratio = max_block_width / max_dim_width
119-
block_levels[block_id] = np.round(np.log2(refinement_ratio)).astype(
120-
np.uint64
121-
)
125+
level = np.round(np.log2(refinement_ratio)).astype(np.uint64)
126+
block_levels[block_id] = level
122127

123128
# Count number of blocks - note that a block at level l
124129
# requires its parent blocks at levels < l to be present
@@ -213,18 +218,175 @@ def _chunk_io(self, dobj, cache=True, local_only=False):
213218
yield YTDataChunk(dobj, "io", [subset], None, cache=cache)
214219

215220

216-
class DyabloOctreeSubset:
217-
"""Octree subset for Dyablo data."""
221+
class DyabloOctreeSubset(OctreeSubset):
222+
"""Octree subset for Dyablo data with NxMxL blocks.
223+
224+
Unlike standard octrees where each oct contains 2x2x2 cells,
225+
Dyablo uses blocks as octs, where each block contains NxMxL cells.
226+
This class overrides coordinate methods to properly handle this
227+
non-standard structure.
228+
"""
229+
230+
_domain_offset = 1
231+
_block_order = "C"
232+
233+
def __init__(self, base_region, domain, ds):
    """Initialise the subset, sizing each oct from the dataset's block shape.

    Parameters
    ----------
    base_region, domain, ds
        Forwarded unchanged to the parent constructor. ``ds.block_size``
        must unpack into exactly three per-axis cell counts.
    """
    cells_x, cells_y, cells_z = ds.block_size
    # The parent takes a single zone count, so the longest block edge
    # (in cells) is the value handed over as ``num_zones``.
    longest_edge = max(cells_x, cells_y, cells_z)
    super().__init__(base_region, domain, ds, num_zones=longest_edge)

    self._current_particle_type = "io"
    self._current_fluid_type = self.ds.default_fluid_type
243+
244+
@property
def oct_handler(self):
    """Return the oct handler held by this subset's ``domain``."""
    handler = self.domain.oct_handler
    return handler
247+
248+
def _get_selected_cell_info(self, selector):
    """
    Get information about selected cells using two-stage selection.

    This method:
    1. Uses the octree to select blocks via ``domain_ind``
    2. Computes cell centers from connectivity data
    3. Applies the selector at cell level to filter cells

    Parameters
    ----------
    selector : SelectorObject
        Selector object for filtering cells

    Returns
    -------
    block_inds : ndarray
        Array of selected block indices
        (oct indices where domain_ind != -1)
    cell_centers : ndarray
        Array of cell center positions (n_cells_in_blocks, 3)
    mask : ndarray
        Boolean mask for cells passing selector (n_cells_in_blocks,).
        Always an array, even when no cell passes the selector.
    """
    # Stage 1: Use octree to get selected block indices
    domain_ind_array = self.oct_handler.domain_ind(selector)
    block_inds = np.where(domain_ind_array >= 0)[0]

    if len(block_inds) == 0:
        empty_centers = np.array([]).reshape(0, 3)
        empty_mask = np.array([], dtype=bool)
        return block_inds, empty_centers, empty_mask

    # Stage 2: Compute cell centers from connectivity
    block_size = self.ds.block_size
    coordinates = self.domain.coordinates
    # (n_blocks_total, N, M, L, 8): eight corner-vertex ids per cell
    connectivity = self.domain.connectivity.reshape(-1, *block_size, 8)

    # Keep only the selected blocks and flatten to (n_cells, 8)
    selected_cell_connectivity = connectivity[block_inds].reshape(-1, 8)

    # Cell center = average of its 8 corner coordinates -> (n_cells, 3)
    cell_centers = np.mean(coordinates[selected_cell_connectivity], axis=1)

    # Stage 3: Apply selector at cell level
    radius = 0
    mask = selector.select_points(*cell_centers.T, radius)
    if mask is None:
        # yt's SelectorObject.select_points returns None (not an all-False
        # array) when no point is selected; normalise it so that callers
        # can always index with the mask and the log call below is safe.
        mask = np.zeros(len(cell_centers), dtype=bool)

    mylog.debug(
        "Selected %d cells from %d blocks (%d total cells checked)",
        mask.sum(),
        len(block_inds),
        len(mask),
    )

    return block_inds, cell_centers, mask
308+
309+
def select_icoords(self, dobj):
    """
    Return integer coordinates of selected cells.

    Integer coordinates are the cell indices on a uniform grid whose
    resolution is ``domain_dimensions * block_size`` (i.e. every cell is
    treated as living on the level-0 grid).

    Parameters
    ----------
    dobj : data object
        Object whose ``selector`` attribute picks the cells.

    Returns
    -------
    ndarray
        Integer array of shape (n_selected_cells, 3).
    """
    block_inds, cell_centers, mask = self._get_selected_cell_info(dobj.selector)

    if len(block_inds) == 0:
        return np.array([], dtype="int64").reshape(0, 3)

    # Domain dimensions count blocks per axis; multiply by the cells per
    # block to get the total number of cells per axis.
    total_cells = self.ds.domain_dimensions * np.array(self.ds.block_size)

    # Convert positions to integer grid indices
    domain_width = self.ds.domain_right_edge - self.ds.domain_left_edge
    cell_width = domain_width / total_cells
    rel_pos = cell_centers - self.ds.domain_left_edge
    icoords = (rel_pos / cell_width).astype(np.int64)

    # Clip to [0, total_cells - 1] so that centers sitting on the upper
    # domain boundary do not index one past the end.
    icoords = np.clip(icoords, 0, total_cells - 1)

    # Only the cells that passed the cell-level selection are reported.
    return icoords[mask]

def select_fcoords(self, dobj):
    """
    Return floating-point center positions of selected cells.

    Parameters
    ----------
    dobj : data object
        Object whose ``selector`` attribute picks the cells.

    Returns
    -------
    array with ``code_length`` units, shape (n_selected_cells, 3)
    """
    block_inds, cell_centers, mask = self._get_selected_cell_info(dobj.selector)

    if len(block_inds) == 0:
        return self.ds.arr(np.array([]).reshape(0, 3), "code_length")

    return self.ds.arr(cell_centers[mask], "code_length")
348+
349+
def select_fwidth(self, dobj):
    """
    Return the per-axis widths of the selected cells.

    Each width is the extent (max minus min) of the cell's eight corner
    coordinates along that axis; the result carries ``code_length`` units
    and has shape (n_selected_cells, 3).
    """
    block_inds, _, mask = self._get_selected_cell_info(dobj.selector)

    if len(block_inds) == 0:
        return self.ds.arr(np.array([]).reshape(0, 3), "code_length")

    # Corner-vertex ids grouped per block, then narrowed to the selected
    # blocks and flattened to one row of 8 ids per cell.
    per_block = self.domain.connectivity.reshape(-1, *self.ds.block_size, 8)
    corner_ids = per_block[block_inds].reshape(-1, 8)

    # (n_cells, 8, 3) corner positions -> (n_cells, 3) extents
    corners = self.domain.coordinates[corner_ids]
    extents = corners.max(axis=1) - corners.min(axis=1)

    return self.ds.arr(extents[mask], "code_length")
374+
375+
def select_ires(self, dobj):
    """
    Return the refinement level of each selected cell.

    Every selected cell is reported at level 0 (uniform-grid assumption),
    so the result is an int64 array of zeros with one entry per cell that
    passed the cell-level selection.
    """
    block_inds, _, mask = self._get_selected_cell_info(dobj.selector)

    if len(block_inds) == 0:
        return np.array([], dtype="int64")

    # Level is hard-wired to 0; per-block levels would be needed here if
    # refined (AMR) data were to be supported.
    return np.zeros(mask.sum(), dtype="int64")
228390

229391
def fill(self, fields, selector):
230392
"""
@@ -246,9 +408,8 @@ def fill(self, fields, selector):
246408
dict
247409
Dictionary mapping (ftype, fname) to selected data arrays
248410
"""
249-
# Get the indices of the blocks to read from
250-
block_inds = self.oct_handler.domain_ind(selector)
251-
# Get the file indices (block indices) that contain selected cells
411+
# Use helper to get selected blocks and cell mask
412+
block_inds, cell_centers, mask = self._get_selected_cell_info(selector)
252413

253414
if len(block_inds) == 0:
254415
return {field: np.array([]) for field in fields}
@@ -265,29 +426,14 @@ def fill(self, fields, selector):
265426
# Read only the cells we need
266427
full_data = f[field_path][:].reshape(-1, *block_size)
267428

268-
# Only keep those octs we selected
429+
# Only keep those blocks we selected
269430
data_subset = full_data[block_inds].flatten()
270-
result[ftype, fname] = data_subset
431+
432+
# Apply cell-level mask
433+
result[ftype, fname] = data_subset[mask]
271434
else:
272435
result[ftype, fname] = np.array([])
273436

274-
# Also read coordinates for selector filtering
275-
coordinates = self.ds._coordinates
276-
connectivity = self.ds._connectivity.reshape(-1, *block_size, 8)
277-
selected_block_connectivity = connectivity[block_inds].reshape(-1, 8)
278-
279-
# Compute cell centers
280-
cell_centers = np.mean(
281-
coordinates[selected_block_connectivity], axis=1
282-
) # (n_selected_cells, 3)
283-
284-
# Apply selector to filter cells
285-
radius = 0
286-
mask = selector.select_points(*cell_centers.T, radius)
287-
mylog.debug("Selected %d cells within %d blocks", mask.sum(), len(block_inds))
288-
for key in result:
289-
result[key] = result[key][mask]
290-
291437
return result
292438

293439

0 commit comments

Comments
 (0)