Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
4ec7cc1
Update to `numpy` 2.x series
FlorianDeconinck Mar 2, 2026
eacd405
Remove `index_tricks` import (now deeper into impl)
FlorianDeconinck Mar 2, 2026
565489e
Clean up lint around `type` in Buffer.py
FlorianDeconinck Mar 2, 2026
4f6cdc8
Update to `safe_mpi_allocate` type hint
FlorianDeconinck Mar 2, 2026
84a8ba6
Update `dtype` in halo update transformer
FlorianDeconinck Mar 2, 2026
154bf41
Use `npt.DTypeLike` for better support
FlorianDeconinck Mar 2, 2026
ec4e441
Remove `NumpyModule` type wrapper and replace `types.Module`
FlorianDeconinck Mar 2, 2026
3c040f2
Restore `Allocator` in types
FlorianDeconinck Mar 2, 2026
578a85d
Move `npt.DTypeLike` to cover more ground
FlorianDeconinck Mar 2, 2026
22daec4
Missing commit
FlorianDeconinck Mar 2, 2026
8de9610
type ignore a mypy mistake + restore MPI GPU test
FlorianDeconinck Mar 2, 2026
c46609f
Go back to list of str for `op_flags`
FlorianDeconinck Mar 2, 2026
c6d5d41
Narrow type ignore
FlorianDeconinck Mar 2, 2026
431ba78
[TMP] CI using branches on the downstream repositories
FlorianDeconinck Mar 3, 2026
f900310
Remove `type: ignore`
FlorianDeconinck Mar 3, 2026
0362aa2
`no_type_check` the entire thing
FlorianDeconinck Mar 3, 2026
7d52d45
Merge branch 'develop' into update/numpy_2x
FlorianDeconinck Mar 4, 2026
ef22394
Add `to_xarray` API to State
FlorianDeconinck Mar 5, 2026
4a1b4c7
Merge branch 'minor/State_API_to_xarray' into update/numpy_2x
FlorianDeconinck Mar 5, 2026
6a87b9f
Merge branch 'develop' into update/numpy_2x
FlorianDeconinck Mar 6, 2026
d2a320a
Merge branch 'develop' into update/numpy_2x
FlorianDeconinck Mar 12, 2026
13ce3cb
Merge branch 'develop' into update/numpy_2x
FlorianDeconinck Mar 18, 2026
e6dd3ba
Merge branch 'develop' into update/numpy_2x
twicki Apr 1, 2026
46c378c
update pyfv3
twicki Apr 1, 2026
827369f
Clean up left out bad `gt4py_backend` fixture (HOW DID EVER RUN?!)
FlorianDeconinck Apr 1, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/fv3_translate_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ on:

jobs:
fv3_translate_tests:
uses: NOAA-GFDL/pyFV3/.github/workflows/translate.yaml@develop
uses: twicki/pyFV3/.github/workflows/translate.yaml@update/numpy_2x
with:
component_trigger: true
component_name: NDSL
2 changes: 1 addition & 1 deletion .github/workflows/pace_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ on:

jobs:
pace_main_tests:
uses: NOAA-GFDL/pace/.github/workflows/main_unit_tests.yaml@develop
uses: floriandeconinck/pace/.github/workflows/main_unit_tests.yaml@update/numpy_2x
with:
component_trigger: true
component_name: NDSL
2 changes: 1 addition & 1 deletion .github/workflows/shield_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ on:

jobs:
shield_translate_tests:
uses: NOAA-GFDL/pySHiELD/.github/workflows/translate.yaml@develop
uses: floriandeconinck/pySHiELD/.github/workflows/translate.yaml@update/numpy_2x
with:
component_trigger: true
component_name: NDSL
19 changes: 10 additions & 9 deletions ndsl/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

import contextlib
from collections.abc import Callable, Generator, Iterable
from typing import Any

import numpy as np
from numpy.lib.index_tricks import IndexExpression
import numpy.typing as npt

from ndsl.performance.timer import NullTimer, Timer
from ndsl.types import Allocator
Expand All @@ -16,7 +17,7 @@
)


BufferKey = tuple[Callable, Iterable[int], type]
BufferKey = tuple[Callable, Iterable[int], npt.DTypeLike]
BUFFER_CACHE: dict[BufferKey, list["Buffer"]] = {}


Expand All @@ -41,7 +42,7 @@ def __init__(self, key: BufferKey, array: np.ndarray):

@classmethod
def pop_from_cache(
cls, allocator: Allocator, shape: Iterable[int], dtype: type
cls, allocator: Allocator, shape: Iterable[int], dtype: npt.DTypeLike
) -> Buffer:
"""Retrieve or insert then retrieve of buffer from cache.

Expand Down Expand Up @@ -78,8 +79,8 @@ def finalize_memory_transfer(self) -> None:
def assign_to(
self,
destination_array: np.ndarray,
buffer_slice: IndexExpression = np.index_exp[:],
buffer_reshape: IndexExpression = None,
buffer_slice: Any = np.index_exp[:],
buffer_reshape: Any | None = None,
) -> None:
"""Assign internal array to destination_array.

Expand All @@ -95,7 +96,7 @@ def assign_to(
)

def assign_from(
self, source_array: np.ndarray, buffer_slice: IndexExpression = np.index_exp[:]
self, source_array: np.ndarray, buffer_slice: Any = np.index_exp[:]
) -> None:
"""Assign source_array to internal array.

Expand All @@ -107,7 +108,7 @@ def assign_from(

@contextlib.contextmanager
def array_buffer(
allocator: Allocator, shape: Iterable[int], dtype: type
allocator: Allocator, shape: Iterable[int], dtype: npt.DTypeLike
) -> Generator[Buffer, Buffer, None]:
"""
A context manager providing a contiguous array, which may be re-used between calls.
Expand All @@ -132,7 +133,7 @@ def send_buffer(
allocator: Callable,
array: np.ndarray,
timer: Timer | None = None,
) -> np.ndarray:
) -> Generator[np.ndarray]:
"""A context manager ensuring that `array` is contiguous in a context where it is
being sent as data, copying into a recycled buffer array if necessary.

Expand Down Expand Up @@ -166,7 +167,7 @@ def recv_buffer(
allocator: Callable,
array: np.ndarray,
timer: Timer | None = None,
) -> np.ndarray:
) -> Generator[np.ndarray]:
"""A context manager ensuring that array is contiguous in a context where it is
being used to receive data, using a recycled buffer array and then copying the
result into array if necessary.
Expand Down
12 changes: 6 additions & 6 deletions ndsl/comm/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import abc
from collections.abc import Mapping, Sequence
from types import ModuleType
from typing import Any, Self, cast

import numpy as np
Expand All @@ -16,7 +17,6 @@
from ndsl.optional_imports import cupy
from ndsl.performance.timer import NullTimer, Timer
from ndsl.quantity import Quantity, QuantityHaloSpec, QuantityMetadata
from ndsl.types import NumpyModule


def to_numpy(array, dtype=None) -> np.ndarray: # type: ignore[no-untyped-def]
Expand Down Expand Up @@ -83,7 +83,7 @@ def size(self) -> int:
"""Total number of ranks in this communicator"""
return self.comm.Get_size()

def _maybe_force_cpu(self, module: NumpyModule) -> NumpyModule:
def _maybe_force_cpu(self, module: ModuleType) -> ModuleType:
"""
Get a numpy-like module depending on configuration and
Quantity original allocator.
Expand Down Expand Up @@ -223,7 +223,7 @@ def _get_gather_recv_quantity(
) -> Quantity:
"""Initialize a Quantity for use when receiving global data during gather"""
recv_quantity = Quantity(
send_metadata.np.zeros(global_extent, dtype=send_metadata.dtype), # type: ignore
send_metadata.np.zeros(global_extent, dtype=send_metadata.dtype),
dims=send_metadata.dims,
units=send_metadata.units,
origin=tuple([0 for dim in send_metadata.dims]),
Expand All @@ -238,7 +238,7 @@ def _get_scatter_recv_quantity(
) -> Quantity:
"""Initialize a Quantity for use when receiving subtile data during scatter"""
recv_quantity = Quantity(
send_metadata.np.zeros(shape, dtype=send_metadata.dtype), # type: ignore
send_metadata.np.zeros(shape, dtype=send_metadata.dtype),
dims=send_metadata.dims,
units=send_metadata.units,
backend=send_metadata.backend,
Expand Down Expand Up @@ -837,7 +837,7 @@ def _get_gather_recv_quantity(
# needs to change the quantity dimensions since we add a "tile" dimension,
# unlike for tile scatter/gather which retains the same dimensions
recv_quantity = Quantity(
metadata.np.zeros(global_extent, dtype=metadata.dtype), # type: ignore
metadata.np.zeros(global_extent, dtype=metadata.dtype),
dims=(constants.TILE_DIM,) + metadata.dims,
units=metadata.units,
origin=(0,) + tuple([0 for dim in metadata.dims]),
Expand All @@ -859,7 +859,7 @@ def _get_scatter_recv_quantity(
# needs to change the quantity dimensions since we remove a "tile" dimension,
# unlike for tile scatter/gather which retains the same dimensions
recv_quantity = Quantity(
metadata.np.zeros(shape, dtype=metadata.dtype), # type: ignore
metadata.np.zeros(shape, dtype=metadata.dtype),
dims=metadata.dims[1:],
units=metadata.units,
backend=metadata.backend,
Expand Down
3 changes: 2 additions & 1 deletion ndsl/dsl/typing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import TypeAlias

import numpy as np
import numpy.typing as npt
from gt4py.cartesian import gtscript

from ndsl.dsl import NDSL_GLOBAL_PRECISION
Expand Down Expand Up @@ -110,7 +111,7 @@ def cast_to_index3d(val: tuple[int, ...]) -> Index3D:
return val


def is_float(dtype: type) -> bool:
def is_float(dtype: npt.DTypeLike) -> bool:
"""Expected floating point type"""
return dtype in [
Float,
Expand Down
21 changes: 13 additions & 8 deletions ndsl/halo/data_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum
from types import ModuleType
from typing import no_type_check
from uuid import UUID, uuid1

import numpy as np
Expand All @@ -22,7 +24,6 @@
from ndsl.halo.rotate import rotate_scalar_data, rotate_vector_data
from ndsl.optional_imports import cupy as cp
from ndsl.quantity import Quantity, QuantityHaloSpec
from ndsl.types import NumpyModule
from ndsl.utils import device_synchronize


Expand Down Expand Up @@ -53,7 +54,11 @@ def _push_stream(stream: "cp.cuda.Stream") -> None:
INDICES_CACHE: dict[str, "cp.ndarray"] = {}


def _build_flatten_indices( # type: ignore[no-untyped-def]
# `array_value[...] = xxx` fails mypy because of bad inference
# of the type. We can't `type: ignore` it, because mypy also flags
# the ignore as unneeded (but if removed, the check fails again...)
@no_type_check
def _build_flatten_indices(
key,
shape,
slices: tuple[slice, ...],
Expand Down Expand Up @@ -186,7 +191,7 @@ class HaloDataTransformer(abc.ABC):

def __init__(
self,
np_module: NumpyModule,
np_module: ModuleType,
exchange_descriptors_x: Sequence[HaloExchangeSpec],
exchange_descriptors_y: Sequence[HaloExchangeSpec] | None = None,
) -> None:
Expand Down Expand Up @@ -237,7 +242,7 @@ def finalize(self) -> None:

@staticmethod
def get(
np_module: NumpyModule,
np_module: ModuleType,
exchange_descriptors_x: Sequence[HaloExchangeSpec],
exchange_descriptors_y: Sequence[HaloExchangeSpec] | None = None,
) -> HaloDataTransformer:
Expand Down Expand Up @@ -308,7 +313,7 @@ def _compile(self) -> None:

# Compute required size
buffer_size = 0
dtype = None
dtype = np.float32 # default that will be overridden or not used
for edge_x in self._infos_x:
buffer_size += edge_x.pack_buffer_size
dtype = edge_x.specification.dtype
Expand All @@ -320,12 +325,12 @@ def _compile(self) -> None:
self._pack_buffer = Buffer.pop_from_cache(
self._np_module.zeros,
(buffer_size,),
dtype, # type: ignore[arg-type]
dtype,
)
self._unpack_buffer = Buffer.pop_from_cache(
self._np_module.zeros,
(buffer_size,),
dtype, # type: ignore[arg-type]
dtype,
)

def ready(self) -> bool:
Expand Down Expand Up @@ -589,7 +594,7 @@ class _CuKernelArgs:

def __init__(
self,
np_module: NumpyModule,
np_module: ModuleType,
exchange_descriptors_x: Sequence[HaloExchangeSpec],
exchange_descriptors_y: Sequence[HaloExchangeSpec] | None = None,
) -> None:
Expand Down
9 changes: 5 additions & 4 deletions ndsl/halo/updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from collections import defaultdict
from collections.abc import Iterable, Mapping
from types import ModuleType
from typing import TYPE_CHECKING

import numpy as np
Expand All @@ -14,7 +15,7 @@
from ndsl.halo.rotate import rotate_scalar_data
from ndsl.performance.timer import NullTimer, Timer
from ndsl.quantity import Quantity, QuantityHaloSpec
from ndsl.types import AsyncRequest, NumpyModule
from ndsl.types import AsyncRequest
from ndsl.utils import device_synchronize


Expand Down Expand Up @@ -95,7 +96,7 @@ def __del__(self) -> None:
def from_scalar_specifications(
cls,
comm: Communicator,
numpy_like_module: NumpyModule,
numpy_like_module: ModuleType,
specifications: Iterable[QuantityHaloSpec],
boundaries: Iterable[Boundary],
tag: int,
Expand Down Expand Up @@ -147,7 +148,7 @@ def from_scalar_specifications(
def from_vector_specifications(
cls,
comm: Communicator,
numpy_like_module: NumpyModule,
numpy_like_module: ModuleType,
specifications_x: Iterable[QuantityHaloSpec],
specifications_y: Iterable[QuantityHaloSpec],
boundaries: Iterable[Boundary],
Expand Down Expand Up @@ -475,7 +476,7 @@ def _Isend_vector_shared_boundary(
]
return send_requests

def _maybe_force_cpu(self, module: NumpyModule) -> NumpyModule:
def _maybe_force_cpu(self, module: ModuleType) -> ModuleType:
"""
Get a numpy-like module depending on configuration and
Quantity original allocator.
Expand Down
9 changes: 5 additions & 4 deletions ndsl/quantity/metadata.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from __future__ import annotations

import dataclasses
from types import ModuleType
from typing import Any

import numpy as np
import numpy.typing as npt

from ndsl.config.backend import Backend
from ndsl.optional_imports import cupy
from ndsl.types import NumpyModule


if cupy is None:
Expand All @@ -28,7 +29,7 @@ class QuantityMetadata:
"Units of the quantity."
data_type: type
"ndarray-like type used to store the data."
dtype: type
dtype: npt.DTypeLike
"dtype of the data in the ndarray-like object."
backend: Backend
"NDSL backend. Used for performance optimal data allocation."
Expand All @@ -39,7 +40,7 @@ def dim_lengths(self) -> dict[str, int]:
return dict(zip(self.dims, self.extent))

@property
def np(self) -> NumpyModule:
def np(self) -> ModuleType:
"""numpy-like module used to interact with the data."""
if issubclass(self.data_type, cupy.ndarray):
return cupy
Expand Down Expand Up @@ -72,5 +73,5 @@ class QuantityHaloSpec:
origin: tuple[int, ...]
extent: tuple[int, ...]
dims: tuple[str, ...]
numpy_module: NumpyModule
numpy_module: ModuleType
dtype: Any
6 changes: 3 additions & 3 deletions ndsl/quantity/quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import warnings
from collections.abc import Iterable, Sequence
from types import ModuleType
from typing import Any, cast

import dace
Expand All @@ -18,7 +19,6 @@
from ndsl.optional_imports import cupy
from ndsl.quantity.bounds import BoundedArrayView
from ndsl.quantity.metadata import QuantityHaloSpec, QuantityMetadata
from ndsl.types import NumpyModule


if cupy is None:
Expand Down Expand Up @@ -326,7 +326,7 @@ def data_as_xarray(self) -> xr.DataArray:
return xr.DataArray(data, dims=self.dims, attrs=self.attrs)

@property
def np(self) -> NumpyModule:
def np(self) -> ModuleType:
return self.metadata.np

@property
Expand Down Expand Up @@ -408,7 +408,7 @@ def transpose(
target_dims = _collapse_dims(target_dims, self.dims)
transpose_order = [self.dims.index(dim) for dim in target_dims]
transposed = Quantity(
self.np.transpose(self.data, transpose_order), # type: ignore[attr-defined]
self.np.transpose(self.data, transpose_order),
dims=_transpose_sequence(self.dims, transpose_order),
units=self.units,
origin=_transpose_sequence(self.origin, transpose_order),
Expand Down
Loading