Skip to content

Commit 688469b

Browse files
committed
ENH: ensure yt.utils.funcs.ensure_numpy_array consistently returns a copy
1 parent e962538 commit 688469b

File tree

2 files changed

+33
-9
lines changed

2 files changed

+33
-9
lines changed

yt/_maintenance/numpy2_compat.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,40 @@
11
# avoid deprecation warnings in numpy >= 2.0
2+
from dataclasses import dataclass
3+
from importlib.metadata import version
4+
from typing import Literal, TypeAlias
25

36
import numpy as np
7+
from packaging.version import Version
48

59
if hasattr(np, "trapezoid"):
610
# np.trapz is deprecated in numpy 2.0 in favor of np.trapezoid
711
trapezoid = np.trapezoid
812
else:
913
trapezoid = np.trapz # type: ignore # noqa: NPY201
14+
15+
16+
NUMPY_VERSION = Version(version("numpy"))
17+
18+
# list numpy functions that gained a `copy` keyword argument at
19+
# some point after our oldest supported numpy version.
20+
CopyFunc: TypeAlias = Literal["asarray"]
21+
22+
MIN_NUMPY_VERSION: dict[CopyFunc, Version] = {
23+
"asarray": Version("2.0.0"),
24+
}
25+
26+
27+
@dataclass(frozen=True, slots=True)
28+
class CopyKwarg:
29+
value: bool | None
30+
31+
def get(self, f: CopyFunc, /) -> dict[Literal["copy"], bool | None]:
32+
if NUMPY_VERSION >= MIN_NUMPY_VERSION[f]:
33+
return {"copy": self.value}
34+
else:
35+
return {}
36+
37+
38+
COPY_NONE = CopyKwarg(None)
39+
COPY_TRUE = CopyKwarg(True)
40+
COPY_FALSE = CopyKwarg(False)

yt/funcs.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
from yt._maintenance.deprecation import issue_deprecation_warning
2626
from yt._maintenance.ipython_compat import IS_IPYTHON
27+
from yt._maintenance.numpy2_compat import COPY_TRUE
2728
from yt.config import ytcfg
2829
from yt.units import YTArray, YTQuantity
2930
from yt.utilities.exceptions import YTFieldNotFound, YTInvalidWidthError
@@ -86,15 +87,7 @@ def ensure_numpy_array(obj):
8687
This function ensures that *obj* is a numpy array. Typically used to
8788
convert scalar, list or tuple argument passed to functions using Cython.
8889
"""
89-
if isinstance(obj, np.ndarray):
90-
if obj.shape == ():
91-
return np.array([obj])
92-
# We cast to ndarray to catch ndarray subclasses
93-
return np.array(obj)
94-
elif isinstance(obj, (list, tuple)):
95-
return np.asarray(obj)
96-
else:
97-
return np.asarray([obj])
90+
return np.atleast_1d(np.asarray(obj, **COPY_TRUE.get("asarray")))
9891

9992

10093
def read_struct(f, fmt):

0 commit comments

Comments
 (0)