88import h5py
99import numpy as np
1010
11+ from yt .data_objects .index_subobjects .octree_subset import OctreeSubset
1112from yt .data_objects .static_output import Dataset
1213from yt .geometry .geometry_handler import YTDataChunk
1314from yt .geometry .oct_container import OctreeContainer
2324class DyabloOctreeIndex (OctreeIndex ):
2425 """Octree Index for Dyablo data with leaf-only blocks."""
2526
27+ domain_id = 1 # Dyablo uses a single domain
28+
2629 def __init__ (self , ds , dataset_type ):
2730 self .dataset_type = dataset_type
2831 self .dataset = ds
@@ -55,7 +58,8 @@ def _initialize_oct_handler(self):
5558 if n_cells % cells_per_block != 0 :
5659 mylog .warning (
5760 f"Number of cells ({ n_cells } ) is not divisible by "
58- f"cells per block ({ cells_per_block } ). Some cells will be ignored."
61+ f"cells per block ({ cells_per_block } ). "
62+ "Some cells will be ignored."
5963 )
6064
6165 # NOTE: we're storing an octree of blocks, so we insert
@@ -82,7 +86,8 @@ def _initialize_oct_handler(self):
8286 block_corners = coordinates [block_connectivity , :]
8387
8488 # Compute block bounds from all cell corners
85- all_corners = block_corners .reshape (- 1 , 3 ) # (cells_per_block * 8, 3)
89+ # (cells_per_block * 8, 3)
90+ all_corners = block_corners .reshape (- 1 , 3 )
8691 block_min = np .min (all_corners , axis = 0 )
8792 block_max = np .max (all_corners , axis = 0 )
8893 block_width = block_max - block_min
@@ -104,7 +109,8 @@ def _initialize_oct_handler(self):
104109 ] # (cells_per_block, 8, 3)
105110
106111 # Compute block center: mean of all cell centers
107- cell_centers = np .mean (block_corners , axis = 1 ) # (cells_per_block, 3)
112+ # (cells_per_block, 3)
113+ cell_centers = np .mean (block_corners , axis = 1 )
108114 block_centers [block_id ] = np .mean (cell_centers , axis = 0 )
109115
110116 # Compute block size and level
@@ -116,9 +122,8 @@ def _initialize_oct_handler(self):
116122 # Level is determined by block size relative to coarsest block
117123 max_dim_width = np .max (block_width )
118124 refinement_ratio = max_block_width / max_dim_width
119- block_levels [block_id ] = np .round (np .log2 (refinement_ratio )).astype (
120- np .uint64
121- )
125+ level = np .round (np .log2 (refinement_ratio )).astype (np .uint64 )
126+ block_levels [block_id ] = level
122127
123128 # Count number of blocks - note that a block at level l
124129 # requires its parent blocks at levels < l to be present
@@ -213,18 +218,175 @@ def _chunk_io(self, dobj, cache=True, local_only=False):
213218 yield YTDataChunk (dobj , "io" , [subset ], None , cache = cache )
214219
215220
216- class DyabloOctreeSubset :
217- """Octree subset for Dyablo data."""
221+ class DyabloOctreeSubset (OctreeSubset ):
222+ """Octree subset for Dyablo data with NxMxL blocks.
223+
224+ Unlike standard octrees where each oct contains 2x2x2 cells,
225+ Dyablo uses blocks as octs, where each block contains NxMxL cells.
226+ This class overrides coordinate methods to properly handle this
227+ non-standard structure.
228+ """
229+
230+ _domain_offset = 1
231+ _block_order = "C"
232+
233+ def __init__ (self , base_region , domain , ds ):
234+ # Get block size for num_zones parameter
235+ N , M , L = ds .block_size
236+ # Use the maximum dimension for num_zones
237+ num_zones = max (N , M , L )
238+
239+ super ().__init__ (base_region , domain , ds , num_zones = num_zones )
240+
241+ self ._current_particle_type = "io"
242+ self ._current_fluid_type = self .ds .default_fluid_type
243+
244+ @property
245+ def oct_handler (self ):
246+ return self .domain .oct_handler
247+
248+ def _get_selected_cell_info (self , selector ):
249+ """
250+ Get information about selected cells using two-stage selection.
251+
252+ This method:
253+ 1. Uses octree to select blocks via domain_ind
254+ 2. Computes cell centers from connectivity data
255+ 3. Applies selector at cell level to filter cells
256+
257+ Parameters
258+ ----------
259+ selector : SelectorObject
260+ Selector object for filtering cells
261+
262+ Returns
263+ -------
264+ block_inds : ndarray
265+ Array of selected block indices
266+ (oct indices where domain_ind != -1)
267+ cell_centers : ndarray
268+ Array of cell center positions (n_cells_in_blocks, 3)
269+ mask : ndarray
270+ Boolean mask for cells passing selector (n_cells_in_blocks,)
271+ """
272+ # Stage 1: Use octree to get selected block indices
273+ domain_ind_array = self .oct_handler .domain_ind (selector )
274+ block_inds = np .where (domain_ind_array >= 0 )[0 ]
275+
276+ if len (block_inds ) == 0 :
277+ empty_centers = np .array ([]).reshape (0 , 3 )
278+ empty_mask = np .array ([], dtype = bool )
279+ return block_inds , empty_centers , empty_mask
280+
281+ # Stage 2: Compute cell centers from connectivity
282+ block_size = self .ds .block_size
283+ coordinates = self .domain .coordinates
284+ connectivity = self .domain .connectivity .reshape (- 1 , * block_size , 8 )
218285
219- def __init__ (self , base_region , index , ds ):
220- self .base_region = base_region
221- self .index = index
222- self .ds = ds
223- self .oct_handler = index .oct_handler
224- self .domain_id = 1 # Dyablo uses domain 1
286+ # Get connectivity for selected blocks (n_blocks, N, M, L, 8)
287+ selected_block_connectivity = connectivity [block_inds ]
288+ # Reshape to (n_blocks * N * M * L, 8)
289+ selected_cell_connectivity = selected_block_connectivity .reshape (- 1 , 8 )
225290
226- def __getitem__ (self , key ):
227- return self .base_region [key ]
291+ # Compute cell centers: average of 8 corner coordinates
292+ cell_centers = np .mean (
293+ coordinates [selected_cell_connectivity ], axis = 1
294+ ) # (n_cells, 3)
295+
296+ # Stage 3: Apply selector at cell level
297+ radius = 0
298+ mask = selector .select_points (* cell_centers .T , radius )
299+
300+ mylog .debug (
301+ "Selected %d cells from %d blocks (%d total cells checked)" ,
302+ mask .sum (),
303+ len (block_inds ),
304+ len (mask ),
305+ )
306+
307+ return block_inds , cell_centers , mask
308+
309+ def select_icoords (self , dobj ):
310+ """
311+ Return integer coordinates of selected cells.
312+
313+ Integer coordinates represent the cell index on a uniform grid
314+ at the appropriate refinement level.
315+ """
316+ block_inds , cell_centers , mask = self ._get_selected_cell_info (dobj .selector )
317+
318+ if len (block_inds ) == 0 :
319+ return np .array ([], dtype = "int64" ).reshape (0 , 3 )
320+
321+ # Calculate integer coordinates from cell positions
322+ # For uniform grid (level 0), cell width is domain_width / total_cells
323+ block_size = self .ds .block_size
324+ N , M , L = block_size
325+
326+ # Domain dimensions represent number of blocks in each direction
327+ domain_dims = self .ds .domain_dimensions
328+
329+ # Total number of cells in each direction
330+ total_cells = domain_dims * np .array ([N , M , L ])
331+
332+ # Convert positions to integer grid indices
333+ domain_width = self .ds .domain_right_edge - self .ds .domain_left_edge
334+ cell_width = domain_width / total_cells
335+ rel_pos = cell_centers - self .ds .domain_left_edge
336+ icoords = (rel_pos / cell_width ).astype (np .int64 )
337+
338+ # Clip to valid range [0, total_cells-1] to handle cells at boundaries
339+ for i in range (3 ):
340+ icoords [:, i ] = np .clip (icoords [:, i ], 0 , total_cells [i ] - 1 )
341+
342+ block_inds , cell_centers , mask = self ._get_selected_cell_info (dobj .selector )
343+
344+ if len (block_inds ) == 0 :
345+ return self .ds .arr (np .array ([]).reshape (0 , 3 ), "code_length" )
346+
347+ return self .ds .arr (cell_centers [mask ], "code_length" )
348+
349+ def select_fwidth (self , dobj ):
350+ """
351+ Return cell widths of selected cells.
352+ """
353+ block_inds , cell_centers , mask = self ._get_selected_cell_info (dobj .selector )
354+
355+ if len (block_inds ) == 0 :
356+ return self .ds .arr (np .array ([]).reshape (0 , 3 ), "code_length" )
357+
358+ # Compute cell widths from connectivity corner positions
359+ block_size = self .ds .block_size
360+ coordinates = self .domain .coordinates
361+ connectivity = self .domain .connectivity .reshape (- 1 , * block_size , 8 )
362+
363+ # Get connectivity for selected blocks
364+ selected_block_connectivity = connectivity [block_inds ]
365+ selected_cell_connectivity = selected_block_connectivity .reshape (- 1 , 8 )
366+
367+ # Width in each dimension: max - min of corner coordinates
368+ # (n_cells, 8, 3)
369+ cell_coords = coordinates [selected_cell_connectivity ]
370+ # (n_cells, 3)
371+ cell_widths = cell_coords .max (axis = 1 ) - cell_coords .min (axis = 1 )
372+
373+ return self .ds .arr (cell_widths [mask ], "code_length" )
374+
375+ def select_ires (self , dobj ):
376+ """
377+ Return refinement level (resolution) of selected cells.
378+
379+ For Dyablo uniform grid data, all cells are at level 0.
380+ """
381+ block_inds , cell_centers , mask = self ._get_selected_cell_info (dobj .selector )
382+
383+ if len (block_inds ) == 0 :
384+ return np .array ([], dtype = "int64" )
385+
386+ # For uniform grid, all blocks are at level 0
387+ # Could use self.domain.block_levels if AMR support is added later
388+ n_selected = mask .sum ()
389+ return np .zeros (n_selected , dtype = "int64" )
228390
229391 def fill (self , fields , selector ):
230392 """
@@ -246,9 +408,8 @@ def fill(self, fields, selector):
246408 dict
247409 Dictionary mapping (ftype, fname) to selected data arrays
248410 """
249- # Get the indices of the blocks to read from
250- block_inds = self .oct_handler .domain_ind (selector )
251- # Get the file indices (block indices) that contain selected cells
411+ # Use helper to get selected blocks and cell mask
412+ block_inds , cell_centers , mask = self ._get_selected_cell_info (selector )
252413
253414 if len (block_inds ) == 0 :
254415 return {field : np .array ([]) for field in fields }
@@ -265,29 +426,14 @@ def fill(self, fields, selector):
265426 # Read only the cells we need
266427 full_data = f [field_path ][:].reshape (- 1 , * block_size )
267428
268- # Only keep those octs we selected
429+ # Only keep those blocks we selected
269430 data_subset = full_data [block_inds ].flatten ()
270- result [ftype , fname ] = data_subset
431+
432+ # Apply cell-level mask
433+ result [ftype , fname ] = data_subset [mask ]
271434 else :
272435 result [ftype , fname ] = np .array ([])
273436
274- # Also read coordinates for selector filtering
275- coordinates = self .ds ._coordinates
276- connectivity = self .ds ._connectivity .reshape (- 1 , * block_size , 8 )
277- selected_block_connectivity = connectivity [block_inds ].reshape (- 1 , 8 )
278-
279- # Compute cell centers
280- cell_centers = np .mean (
281- coordinates [selected_block_connectivity ], axis = 1
282- ) # (n_selected_cells, 3)
283-
284- # Apply selector to filter cells
285- radius = 0
286- mask = selector .select_points (* cell_centers .T , radius )
287- mylog .debug ("Selected %d cells within %d blocks" , mask .sum (), len (block_inds ))
288- for key in result :
289- result [key ] = result [key ][mask ]
290-
291437 return result
292438
293439
0 commit comments