Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 21 additions & 10 deletions yt/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@
from functools import wraps
from importlib.util import find_spec
from shutil import which
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, TypeVar
from unittest import SkipTest

import matplotlib
import numpy as np
import numpy.typing as npt
from more_itertools import always_iterable
from numpy.random import RandomState
from unyt.exceptions import UnitOperationError
Expand Down Expand Up @@ -91,8 +92,8 @@ def assert_rel_equal(a1, a2, decimals, err_msg="", verbose=True):

# tested: volume integral is 1.
def cubicspline_python(
x: float | np.ndarray,
) -> np.ndarray:
x: float | npt.NDArray[np.floating],
) -> npt.NDArray[np.floating]:
"""
cubic spline SPH kernel function for testing against more
effiecient cython methods
Expand All @@ -118,8 +119,12 @@ def cubicspline_python(


def integrate_kernel(
kernelfunc: Callable[[float], float], b: float, hsml: float
) -> float:
kernelfunc: Callable[
[float | npt.NDArray[np.floating]], float | npt.NDArray[np.floating]
],
b: float | npt.NDArray[np.floating],
hsml: float | npt.NDArray[np.floating],
) -> float | npt.NDArray[np.floating]:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's often much preferable to be strict on the return type. Here I don't see any reason not to be

Suggested change
) -> float | npt.NDArray[np.floating]:
) -> float:

"""
integrates a kernel function over a line passing entirely
through it
Expand Down Expand Up @@ -147,18 +152,24 @@ def integrate_kernel(
dx = np.diff(xe, axis=0)
spv = kernelfunc(np.sqrt(xc**2 + x**2))
integral = np.sum(spv * dx, axis=0)
return pre * integral
result = pre * integral
if isinstance(result, np.floating):
return result.item()
return result
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
result = pre * integral
if isinstance(result, np.floating):
return result.item()
return result
return float(pre * integral)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is a good idea. This function can be called on an arrays of b and hsml values and float(array) will not convert an array to the np.float type. Being strict on return type is fine, but then we should either revert to always returning an array, and changing the dtype, or do that type cast before picking out the single element for a 0d-array.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

... e.g. np.float32(pre * integral) would work for arrays though. We'd need to pick which floating point precision to specify in that case though.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm confused. Why would np.float32(...) work but not float(...) ?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure, but that's what worked and didn't work when I tried it out:

Python 3.13.2 (main, Feb  4 2025, 14:51:09) [Clang 16.0.0 (clang-1600.0.26.6)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import numpy as np
>>> a = np.arange(3)
>>> float(a)
Traceback (most recent call last):
  File "<python-input-2>", line 1, in <module>
    float(a)
    ~~~~~^^^
TypeError: only length-1 arrays can be converted to Python scalars
>>> np.float32(a)
array([0., 1., 2.], dtype=float32)
>>> np.float(a)
Traceback (most recent call last):
  File "<python-input-4>", line 1, in <module>
    np.float(a)
    ^^^^^^^^
  File "/Users/nastasha/code/venvs/ytdev_pixav/lib/python3.13/site-packages/numpy/__init__.py", line 397, in __getattr__
    raise AttributeError(__former_attrs__[attr], name=None)
AttributeError: module 'numpy' has no attribute 'float'.
`np.float` was a deprecated alias for the builtin `float`. To avoid this error in existing code, use `float` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.float64` here.
The aliases was originally deprecated in NumPy 1.20; for more details and guidance see the original release note at:
    https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations



_zeroperiods = np.array([0.0, 0.0, 0.0])


_FloatingT = TypeVar("_FloatingT", bound=np.floating)


def distancematrix(
pos3_i0: np.ndarray,
pos3_i1: np.ndarray,
pos3_i0: npt.NDArray[_FloatingT],
pos3_i1: npt.NDArray[_FloatingT],
periodic: tuple[bool, bool, bool] = (True,) * 3,
periods: np.ndarray = _zeroperiods,
) -> np.ndarray:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually here the output dtype is bound to be the same as pos3_i3, so, instead of np.floating, we should use _FloatingT = TypeVar("_FloatingT", bound=np.floating) here

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, i think i understand this -- using TypeVar("_FloatingT", bound=np.floating) would ensure the input/output precisions match? in that case though, should I use npt.NBitBase for this (https://numpy.org/doc/stable/reference/typing.html#numpy.typing.NBitBase) ?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nevermind about the NBitBase. i'm learning. np.floating is the way i think.

periods: npt.NDArray[_FloatingT] = _zeroperiods,
) -> npt.NDArray[_FloatingT]:
"""
Calculates the distances between two arrays of points.

Expand Down