@@ -46,7 +46,10 @@ class OctreeSubset(YTSelectionContainer, abc.ABC):
4646
4747 def __init__ (self , base_region , domain , ds , num_zones = 2 , num_ghost_zones = 0 ):
4848 super ().__init__ (ds , None )
49- self ._num_zones = num_zones
49+ if hasattr (num_zones , "__len__" ):
50+ self ._num_zones = np .array (num_zones , dtype = "int64" )
51+ else :
52+ self ._num_zones = np .array ([num_zones , num_zones , num_zones ], dtype = "int64" )
5053 self ._num_ghost_zones = num_ghost_zones
5154 self .domain = domain
5255 self .domain_id = domain .domain_id
@@ -80,23 +83,28 @@ def __getitem__(self, key):
8083
8184 @property
8285 def nz (self ):
83- return self ._num_zones + 2 * self ._num_ghost_zones
86+ nz = self ._num_zones + 2 * self ._num_ghost_zones
87+ if hasattr (nz , "__len__" ):
88+ return nz
89+ return np .array ([nz , nz , nz ], dtype = "int64" )
8490
8591 def get_bbox (self ):
8692 return self .base_region .get_bbox ()
8793
8894 def _reshape_vals (self , arr ):
8995 nz = self .nz
96+ nzx , nzy , nzz = nz [0 ], nz [1 ], nz [2 ]
97+ nzones = nzx * nzy * nzz
9098 if len (arr .shape ) <= 2 :
91- n_oct = arr .shape [0 ] // ( nz ** 3 )
99+ n_oct = arr .shape [0 ] // nzones
92100 elif arr .shape [- 1 ] == 3 :
93101 n_oct = arr .shape [- 2 ]
94102 else :
95103 n_oct = arr .shape [- 1 ]
96- if arr .size == nz * nz * nz * n_oct :
97- new_shape = (nz , nz , nz , n_oct )
98- elif arr .size == nz * nz * nz * n_oct * 3 :
99- new_shape = (nz , nz , nz , n_oct , 3 )
104+ if arr .size == nzones * n_oct :
105+ new_shape = (nzx , nzy , nzz , n_oct )
106+ elif arr .size == nzones * n_oct * 3 :
107+ new_shape = (nzx , nzy , nzz , n_oct , 3 )
100108 else :
101109 raise RuntimeError
102110 # Note that if arr is already F-contiguous, this *shouldn't* copy the
@@ -172,7 +180,7 @@ def deposit(self, positions, fields=None, method=None, kernel_name="cubic"):
172180 if cls is None :
173181 raise YTParticleDepositionNotImplemented (method )
174182 nz = self .nz
175- nvals = (nz , nz , nz , (self .domain_ind >= 0 ).sum ())
183+ nvals = (int ( nz [ 0 ]), int ( nz [ 1 ]), int ( nz [ 2 ]) , (self .domain_ind >= 0 ).sum ())
176184 if np .max (self .domain_ind ) >= nvals [- 1 ]:
177185 print (
178186 f"nocts, domain_ind >= 0, max { self .oct_handler .nocts } { nvals [- 1 ]} { np .max (self .domain_ind )} "
@@ -335,7 +343,7 @@ def smooth(
335343 [1 , 1 , 1 ],
336344 self .ds .domain_left_edge ,
337345 self .ds .domain_right_edge ,
338- num_zones = self ._nz ,
346+ num_zones = self ._num_zones ,
339347 )
340348 # This should ensure we get everything within one neighbor of home.
341349 particle_octree .n_ref = nneighbors * 2
@@ -354,7 +362,7 @@ def smooth(
354362 raise YTParticleDepositionNotImplemented (method )
355363 nz = self .nz
356364 mdom_ind = self .domain_ind
357- nvals = (nz , nz , nz , (mdom_ind >= 0 ).sum ())
365+ nvals = (int ( nz [ 0 ]), int ( nz [ 1 ]), int ( nz [ 2 ]) , (mdom_ind >= 0 ).sum ())
358366 op = cls (nvals , len (fields ), nneighbors , kernel_name )
359367 op .initialize ()
360368 mylog .debug (
@@ -455,7 +463,7 @@ def particle_operation(
455463 raise YTParticleDepositionNotImplemented (method )
456464 nz = self .nz
457465 mdom_ind = self .domain_ind
458- nvals = (nz , nz , nz , (mdom_ind >= 0 ).sum ())
466+ nvals = (int ( nz [ 0 ]), int ( nz [ 1 ]), int ( nz [ 2 ]) , (mdom_ind >= 0 ).sum ())
459467 op = cls (nvals , len (fields ), nneighbors , kernel_name )
460468 op .initialize ()
461469 mylog .debug (
@@ -548,7 +556,7 @@ def __init__(self, ind, block_slice):
548556 self .ind = ind
549557 self .block_slice = block_slice
550558 nz = self .block_slice .octree_subset .nz
551- self .ActiveDimensions = np .array ([nz , nz , nz ], dtype = "int64" )
559+ self .ActiveDimensions = np .array ([nz [ 0 ] , nz [ 1 ] , nz [ 2 ] ], dtype = "int64" )
552560 self .ds = block_slice .ds
553561
554562 def __getitem__ (self , key ):
0 commit comments