Skip to content

Commit da721f1

Browse files
authored
refactor: use xumpy for allocation in gt4py_utils (#388)
1 parent 4f5315e commit da721f1

File tree

1 file changed

+30
-24
lines changed

1 file changed

+30
-24
lines changed

ndsl/dsl/gt4py_utils.py

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from collections.abc import Callable, Sequence
23
from functools import wraps
34
from typing import Any
@@ -6,9 +7,10 @@
67
import numpy.typing as npt
78
from gt4py import storage as gt_storage
89

10+
from ndsl import xumpy
911
from ndsl.config.backend import Backend
1012
from ndsl.constants import N_HALO_DEFAULT
11-
from ndsl.dsl.typing import DTypes, Float
13+
from ndsl.dsl.typing import Float
1214
from ndsl.logging import ndsl_log
1315
from ndsl.optional_imports import cupy as cp
1416

@@ -49,19 +51,16 @@ def wrapper(*args, **kwargs) -> Any:
4951
return inner
5052

5153

52-
def _mask_to_dimensions(
53-
mask: tuple[bool, ...], shape: Sequence[int]
54-
) -> list[str | int]:
54+
def _mask_to_dimensions(mask: tuple[bool, ...], shape: Sequence[int]) -> list[str]:
5555
assert len(mask) >= 3
56-
dimensions: list[str | int] = []
56+
dimensions: list[str] = []
5757
for i, axis in enumerate(("I", "J", "K")):
5858
if mask[i]:
5959
dimensions.append(axis)
6060
if len(mask) > 3:
6161
for i in range(3, len(mask)):
62-
dimensions.append(str(shape[i]))
63-
offset = int(sum(mask))
64-
dimensions.extend(shape[offset:])
62+
if mask[i]:
63+
dimensions.append(str(shape[i]))
6564
return dimensions
6665

6766

@@ -86,7 +85,7 @@ def make_storage_data(
8685
origin: tuple[int, ...] = origin,
8786
*,
8887
backend: Backend,
89-
dtype: DTypes = Float,
88+
dtype: npt.DTypeLike = Float,
9089
mask: tuple[bool, ...] | None = None,
9190
start: tuple[int, ...] = (0, 0, 0),
9291
dummy: tuple[int, ...] | None = None,
@@ -205,12 +204,12 @@ def _make_storage_data_1d(
205204
axis: int = 2,
206205
read_only: bool = True,
207206
*,
208-
dtype: DTypes = Float,
207+
dtype: npt.DTypeLike = Float,
209208
backend: Backend,
210209
) -> npt.NDArray:
211210
# axis refers to a repeated axis, dummy refers to a singleton axis
212211
axis = min(axis, len(shape) - 1)
213-
buffer = zeros(shape[axis], dtype=dtype, backend=backend)
212+
buffer = xumpy.zeros(shape[axis], backend, dtype)
214213
if dummy:
215214
axis = list(set((0, 1, 2)).difference(dummy))[0]
216215

@@ -242,7 +241,7 @@ def _make_storage_data_2d(
242241
axis: int = 2,
243242
read_only: bool = True,
244243
*,
245-
dtype: DTypes = Float,
244+
dtype: npt.DTypeLike = Float,
246245
backend: Backend,
247246
) -> npt.NDArray:
248247
# axis refers to which axis should be repeated (when making a full 3d data),
@@ -256,7 +255,7 @@ def _make_storage_data_2d(
256255

257256
start1, start2 = start[0:2]
258257
size1, size2 = data.shape
259-
buffer = zeros(shape2d, dtype=dtype, backend=backend)
258+
buffer = xumpy.zeros(shape2d, backend, dtype)
260259
buffer[start1 : start1 + size1, start2 : start2 + size2] = asarray(
261260
data, type(buffer)
262261
)
@@ -276,12 +275,12 @@ def _make_storage_data_3d(
276275
shape: tuple[int, ...],
277276
start: tuple[int, ...] = (0, 0, 0),
278277
*,
279-
dtype: DTypes = Float,
278+
dtype: npt.DTypeLike = Float,
280279
backend: Backend,
281280
) -> npt.NDArray:
282281
istart, jstart, kstart = start
283282
isize, jsize, ksize = data.shape
284-
buffer = zeros(shape, dtype=dtype, backend=backend)
283+
buffer = xumpy.zeros(shape, backend, dtype)
285284
buffer[
286285
istart : istart + isize,
287286
jstart : jstart + jsize,
@@ -295,12 +294,12 @@ def _make_storage_data_Nd(
295294
shape: tuple[int, ...],
296295
start: tuple[int, ...] | None = None,
297296
*,
298-
dtype: DTypes = Float,
297+
dtype: npt.DTypeLike = Float,
299298
backend: Backend,
300299
) -> npt.NDArray:
301300
if start is None:
302301
start = tuple([0] * data.ndim)
303-
buffer = zeros(shape, dtype=dtype, backend=backend)
302+
buffer = xumpy.zeros(shape, backend, dtype)
304303
idx = tuple([slice(start[i], start[i] + data.shape[i]) for i in range(len(start))])
305304
buffer[idx] = asarray(data, type(buffer))
306305
return buffer
@@ -311,7 +310,7 @@ def make_storage_from_shape(
311310
origin: tuple[int, ...] = origin,
312311
*,
313312
backend: Backend,
314-
dtype: DTypes = Float,
313+
dtype: npt.DTypeLike = Float,
315314
mask: tuple[bool, ...] | None = None,
316315
) -> npt.NDArray:
317316
"""Create a new gt4py storage of a given shape filled with zeros.
@@ -333,12 +332,16 @@ def make_storage_from_shape(
333332
)
334333
3) q_out = utils.make_storage_from_shape(q_in.shape, origin,)
335334
"""
336-
if not mask:
335+
if mask is None:
337336
n_dims = len(shape)
338337
if n_dims == 1:
339338
mask = (False, False, True) # Assume 1D is a k-field
339+
elif n_dims == 2:
340+
mask = (True, True, False) # Assume 2D is an ij-field
341+
elif n_dims < 3:
342+
raise NotImplementedError(f"Unexpected number of dimensions {n_dims}.")
340343
else:
341-
mask = (n_dims * (True,)) + ((3 - n_dims) * (False,))
344+
mask = n_dims * (True,)
342345
storage = gt_storage.zeros(
343346
shape,
344347
dtype,
@@ -359,7 +362,7 @@ def make_storage_dict(
359362
axis: int = 2,
360363
*,
361364
backend: Backend,
362-
dtype: DTypes = Float,
365+
dtype: npt.DTypeLike = Float,
363366
) -> dict[str, npt.NDArray]:
364367
assert names is not None, "for 4d variable storages, specify a list of names"
365368
if shape is None:
@@ -447,9 +450,12 @@ def asarray(array, to_type=np.ndarray, dtype=None, order=None):
447450

448451

449452
def zeros(shape, dtype=Float, *, backend: Backend):
450-
storage_type = cp.ndarray if backend.is_gpu_backend() else np.ndarray
451-
xp = cp if cp and storage_type is cp.ndarray else np
452-
return xp.zeros(shape, dtype=dtype)
453+
warnings.warn(
454+
"gt4py_utils.zeros() is deprecated. Use `zeros()` from `ndsl.xumpy` instead.",
455+
category=DeprecationWarning,
456+
stacklevel=2,
457+
)
458+
return xumpy.zeros(shape, backend, dtype)
453459

454460

455461
def sum(array, axis=None, dtype=Float, out=None, keepdims=False):

0 commit comments

Comments
 (0)