diff --git a/yt/frontends/amrvac/data_structures.py b/yt/frontends/amrvac/data_structures.py index 7fb9787e84d..dfed3efe9c2 100644 --- a/yt/frontends/amrvac/data_structures.py +++ b/yt/frontends/amrvac/data_structures.py @@ -15,7 +15,7 @@ from more_itertools import always_iterable from yt.config import ytcfg -from yt.data_objects.index_subobjects.grid_patch import AMRGridPatch +from yt.data_objects.index_subobjects.stretched_grid import StretchedGrid from yt.data_objects.static_output import Dataset from yt.funcs import mylog, setdefaultattr from yt.geometry.api import Geometry @@ -52,17 +52,18 @@ def _parse_geometry(geometry_tag: str) -> Geometry: return Geometry(geometry_str.lower()) -class AMRVACGrid(AMRGridPatch): +class AMRVACGrid(StretchedGrid): """A class to populate AMRVACHierarchy.grids, setting parent/children relations.""" _id_offset = 0 - def __init__(self, id, index, level): + def __init__(self, id, cell_widths, filename, index, level, dims): # should use yt's convention (start from 0) - super().__init__(id, filename=index.index_filename, index=index) + super().__init__(id=id, filename=filename, index=index, cell_widths=cell_widths) self.Parent = None self.Children = [] self.Level = level + self.ActiveDimensions = dims def get_global_startindex(self): """Refresh and retrieve the starting index for each dimension at current level. @@ -142,6 +143,40 @@ def _parse_index(self): dim = self.dataset.dimensionality self.grids = np.empty(self.num_grids, dtype="object") + meshlist = self.ds.namelist["meshlist"] + if (stretch_dim := meshlist.get("stretch_dim")) is not None: + assert isinstance(stretch_dim, list) + assert len(stretch_dim) >= self.ds.dimensionality + stretch_baselevel = meshlist.get("qstretch_baselevel") + if "qstretch_baselevel" not in meshlist: + # compute default values dynamically, just as done in AMRVAC + stretched_dims = [bool(k) for k in stretch_dim] + assert sum(stretched_dims) == 1 # exactly one stretched direction + stretched_dim = stretched_dims.index(True) + _sbl = [ + 1.0, + ] * self.ds.dimensionality + _sbl[stretched_dim] = ( + meshlist[f"xprobmax{stretched_dim + 1}"] + / meshlist[f"xprobmin{stretched_dim + 1}"] + ) ** (1.0 / meshlist[f"domain_nx{stretched_dim + 1}"]) + stretch_baselevel = tuple(_sbl) + elif isinstance(stretch_baselevel := meshlist["qstretch_baselevel"], list): + assert len(stretch_baselevel) >= self.ds.dimensionality + stretch_baselevel = ( + float(b) for b in stretch_baselevel[: self.ds.dimensionality] + ) + else: + assert isinstance(stretch_baselevel, float | int) + stretched_dims = [bool(k) for k in stretch_dim] + assert sum(stretched_dims) == 1 # exactly one stretched direction + stretched_dim = stretched_dims.index(True) + _sbl = [ + 1.0, + ] * self.ds.dimensionality + _sbl[stretched_dim] = stretch_baselevel + stretch_baselevel = tuple(_sbl) + for igrid, (ytlevel, morton_index) in enumerate( zip(ytlevels, morton_indices, strict=True) ): @@ -152,7 +187,14 @@ def _parse_index(self): self.grid_left_edge[igrid, :dim] = left_edge self.grid_right_edge[igrid, :dim] = left_edge + block_nx * dx self.grid_dimensions[igrid, :dim] = block_nx - self.grids[igrid] = self.grid(igrid, self, ytlevels[igrid]) + self.grids[igrid] = self.grid( + id=igrid, + index=self, + level=ytlevels[igrid], + filename=self.index_filename, + cell_widths=_cell_widths, + dims=self.grid_dimensions[igrid], + ) def _populate_grid_objects(self): # required method