1+ import warnings
12from collections .abc import Callable , Sequence
23from functools import wraps
34from typing import Any
67import numpy .typing as npt
78from gt4py import storage as gt_storage
89
10+ from ndsl import xumpy
911from ndsl .config .backend import Backend
1012from ndsl .constants import N_HALO_DEFAULT
11- from ndsl .dsl .typing import DTypes , Float
13+ from ndsl .dsl .typing import Float
1214from ndsl .logging import ndsl_log
1315from ndsl .optional_imports import cupy as cp
1416
@@ -49,19 +51,16 @@ def wrapper(*args, **kwargs) -> Any:
4951 return inner
5052
5153
52- def _mask_to_dimensions (
53- mask : tuple [bool , ...], shape : Sequence [int ]
54- ) -> list [str | int ]:
54+ def _mask_to_dimensions (mask : tuple [bool , ...], shape : Sequence [int ]) -> list [str ]:
5555 assert len (mask ) >= 3
56- dimensions : list [str | int ] = []
56+ dimensions : list [str ] = []
5757 for i , axis in enumerate (("I" , "J" , "K" )):
5858 if mask [i ]:
5959 dimensions .append (axis )
6060 if len (mask ) > 3 :
6161 for i in range (3 , len (mask )):
62- dimensions .append (str (shape [i ]))
63- offset = int (sum (mask ))
64- dimensions .extend (shape [offset :])
62+ if mask [i ]:
63+ dimensions .append (str (shape [i ]))
6564 return dimensions
6665
6766
@@ -86,7 +85,7 @@ def make_storage_data(
8685 origin : tuple [int , ...] = origin ,
8786 * ,
8887 backend : Backend ,
89- dtype : DTypes = Float ,
88+ dtype : npt . DTypeLike = Float ,
9089 mask : tuple [bool , ...] | None = None ,
9190 start : tuple [int , ...] = (0 , 0 , 0 ),
9291 dummy : tuple [int , ...] | None = None ,
@@ -205,12 +204,12 @@ def _make_storage_data_1d(
205204 axis : int = 2 ,
206205 read_only : bool = True ,
207206 * ,
208- dtype : DTypes = Float ,
207+ dtype : npt . DTypeLike = Float ,
209208 backend : Backend ,
210209) -> npt .NDArray :
211210 # axis refers to a repeated axis, dummy refers to a singleton axis
212211 axis = min (axis , len (shape ) - 1 )
213- buffer = zeros (shape [axis ], dtype = dtype , backend = backend )
212+ buffer = xumpy . zeros (shape [axis ], backend , dtype )
214213 if dummy :
215214 axis = list (set ((0 , 1 , 2 )).difference (dummy ))[0 ]
216215
@@ -242,7 +241,7 @@ def _make_storage_data_2d(
242241 axis : int = 2 ,
243242 read_only : bool = True ,
244243 * ,
245- dtype : DTypes = Float ,
244+ dtype : npt . DTypeLike = Float ,
246245 backend : Backend ,
247246) -> npt .NDArray :
248247 # axis refers to which axis should be repeated (when making a full 3d data),
@@ -256,7 +255,7 @@ def _make_storage_data_2d(
256255
257256 start1 , start2 = start [0 :2 ]
258257 size1 , size2 = data .shape
259- buffer = zeros (shape2d , dtype = dtype , backend = backend )
258+ buffer = xumpy . zeros (shape2d , backend , dtype )
260259 buffer [start1 : start1 + size1 , start2 : start2 + size2 ] = asarray (
261260 data , type (buffer )
262261 )
@@ -276,12 +275,12 @@ def _make_storage_data_3d(
276275 shape : tuple [int , ...],
277276 start : tuple [int , ...] = (0 , 0 , 0 ),
278277 * ,
279- dtype : DTypes = Float ,
278+ dtype : npt . DTypeLike = Float ,
280279 backend : Backend ,
281280) -> npt .NDArray :
282281 istart , jstart , kstart = start
283282 isize , jsize , ksize = data .shape
284- buffer = zeros (shape , dtype = dtype , backend = backend )
283+ buffer = xumpy . zeros (shape , backend , dtype )
285284 buffer [
286285 istart : istart + isize ,
287286 jstart : jstart + jsize ,
@@ -295,12 +294,12 @@ def _make_storage_data_Nd(
295294 shape : tuple [int , ...],
296295 start : tuple [int , ...] | None = None ,
297296 * ,
298- dtype : DTypes = Float ,
297+ dtype : npt . DTypeLike = Float ,
299298 backend : Backend ,
300299) -> npt .NDArray :
301300 if start is None :
302301 start = tuple ([0 ] * data .ndim )
303- buffer = zeros (shape , dtype = dtype , backend = backend )
302+ buffer = xumpy . zeros (shape , backend , dtype )
304303 idx = tuple ([slice (start [i ], start [i ] + data .shape [i ]) for i in range (len (start ))])
305304 buffer [idx ] = asarray (data , type (buffer ))
306305 return buffer
@@ -311,7 +310,7 @@ def make_storage_from_shape(
311310 origin : tuple [int , ...] = origin ,
312311 * ,
313312 backend : Backend ,
314- dtype : DTypes = Float ,
313+ dtype : npt . DTypeLike = Float ,
315314 mask : tuple [bool , ...] | None = None ,
316315) -> npt .NDArray :
317316 """Create a new gt4py storage of a given shape filled with zeros.
@@ -333,12 +332,16 @@ def make_storage_from_shape(
333332 )
334333 3) q_out = utils.make_storage_from_shape(q_in.shape, origin,)
335334 """
336- if not mask :
335+ if mask is None :
337336 n_dims = len (shape )
338337 if n_dims == 1 :
339338 mask = (False , False , True ) # Assume 1D is a k-field
339+ elif n_dims == 2 :
340+ mask = (True , True , False ) # Assume 2D is an ij-field
341+ elif n_dims < 3 :
342+ raise NotImplementedError (f"Unexpected number of dimensions { n_dims } ." )
340343 else :
341- mask = ( n_dims * (True ,)) + (( 3 - n_dims ) * ( False ,) )
344+ mask = n_dims * (True ,)
342345 storage = gt_storage .zeros (
343346 shape ,
344347 dtype ,
@@ -359,7 +362,7 @@ def make_storage_dict(
359362 axis : int = 2 ,
360363 * ,
361364 backend : Backend ,
362- dtype : DTypes = Float ,
365+ dtype : npt . DTypeLike = Float ,
363366) -> dict [str , npt .NDArray ]:
364367 assert names is not None , "for 4d variable storages, specify a list of names"
365368 if shape is None :
@@ -447,9 +450,12 @@ def asarray(array, to_type=np.ndarray, dtype=None, order=None):
447450
448451
449452def zeros (shape , dtype = Float , * , backend : Backend ):
450- storage_type = cp .ndarray if backend .is_gpu_backend () else np .ndarray
451- xp = cp if cp and storage_type is cp .ndarray else np
452- return xp .zeros (shape , dtype = dtype )
453+ warnings .warn (
454+ "gt4py_utils.zeros() is deprecated. Use `zeros()` from `ndsl.xumpy` instead." ,
455+ category = DeprecationWarning ,
456+ stacklevel = 2 ,
457+ )
458+ return xumpy .zeros (shape , backend , dtype )
453459
454460
455461def sum (array , axis = None , dtype = Float , out = None , keepdims = False ):
0 commit comments