Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 37 additions & 41 deletions scipy-stubs/special/_multiufuncs.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# pyright: reportIncompatibleMethodOverride=false

from collections.abc import Callable, Iterable
from typing import Any, Final, Generic, Literal as L, ParamSpec, Protocol, TypeAlias, overload, type_check_only
from typing import Any, Final, Literal as L, TypeAlias, overload, type_check_only
from typing_extensions import TypeVar, override

import numpy as np
Expand All @@ -18,8 +20,6 @@ __all__ = [
"sph_legendre_p_all",
]

_Tss = ParamSpec("_Tss")
_RT = TypeVar("_RT")
_UFuncT_co = TypeVar("_UFuncT_co", bound=Callable[..., object], default=Callable[..., Any], covariant=True)

_Complex: TypeAlias = np.complex64 | np.complex128 # `clongdouble` isn't supported
Expand Down Expand Up @@ -63,10 +63,27 @@ _D1: TypeAlias = L[True, 1]
_D2: TypeAlias = L[2]
_Dn: TypeAlias = L[0, 1, 2] | bool | np.bool_

###
class MultiUFunc: # undocumented
@property
@override
# pyrefly: ignore [bad-override]
def __doc__(self, /) -> str | None: ... # type: ignore[override] # pyright: ignore[reportIncompatibleVariableOverride]

#
def __init__(
self,
/,
ufunc_or_ufuncs: _UFuncT_co | Iterable[_UFuncT_co],
name: str | None = None,
doc: str | None = None,
*,
force_complex_output: bool = False,
**default_kwargs: object,
) -> None: ...
def __call__(self, /, *args: Any, **kwargs: Any) -> Any: ...

@type_check_only
class _LegendreP(Protocol):
class _LegendreP(MultiUFunc):
@overload # 0-d, 0-d
def __call__(self, /, n: int, z: onp.ToFloat, *, diff_n: _Dn = 0) -> _Float1D: ...
@overload # 0-d, >0-d
Expand All @@ -75,7 +92,7 @@ class _LegendreP(Protocol):
def __call__(self, /, n: onp.ToIntND, z: _ToFloat_D, *, diff_n: _Dn = 0) -> _Float2_D: ...

@type_check_only
class _LegendrePAll(Protocol):
class _LegendrePAll(MultiUFunc):
@overload # float
def __call__(self, /, n: int, z: _ToFloat_D, *, diff_n: _Dn = 0) -> _Float3_D: ...
@overload # complex
Expand All @@ -84,7 +101,7 @@ class _LegendrePAll(Protocol):
def __call__(self, /, n: int, z: _ToComplex_D, *, diff_n: _Dn = 0) -> _Float3_D | _Complex3_D: ...

@type_check_only
class _AssocLegendreP(Protocol):
class _AssocLegendreP(MultiUFunc):
@overload # float
def __call__(
self, /, n: _ToInt_D, m: _ToInt_D, z: _ToFloat_D, *, branch_cut: _Branch_D = 2, norm: bool = False, diff_n: _Dn = 0
Expand All @@ -99,7 +116,7 @@ class _AssocLegendreP(Protocol):
) -> _Float1_D | _Complex1_D: ...

@type_check_only
class _AssocLegendrePAll(Protocol):
class _AssocLegendrePAll(MultiUFunc):
@overload # z: 0-d float
def __call__(
self, /, n: int, m: int, z: onp.ToFloat, *, branch_cut: _Branch = 2, norm: bool = False, diff_n: _Dn = 0
Expand All @@ -122,7 +139,7 @@ class _AssocLegendrePAll(Protocol):
) -> _Float3_D | _Complex3_D: ...

@type_check_only
class _SphLegendreP(Protocol):
class _SphLegendreP(MultiUFunc):
@overload # 0-d, 0-d, 0-d
def __call__(self, /, n: int, m: int, theta: onp.ToFloat, *, diff_n: _Dn = 0) -> _Float1D: ...
@overload # >=0-d, >=0-d, >0-d
Expand All @@ -133,14 +150,14 @@ class _SphLegendreP(Protocol):
def __call__(self, /, n: onp.ToIntND, m: _ToInt_D, theta: _ToFloat_D, *, diff_n: _Dn = 0) -> _Float2_D: ...

@type_check_only
class _SphLegendrePAll(Protocol):
class _SphLegendrePAll(MultiUFunc):
@overload # 0-d, 0-d, 0-d
def __call__(self, /, n: int, m: int, theta: onp.ToFloat, *, diff_n: _Dn = 0) -> _Float3D: ...
@overload # 0-d, 0-d, >=0-d
def __call__(self, /, n: int, m: int, theta: _ToFloat_D, *, diff_n: _Dn = 0) -> _Float3_D: ...

@type_check_only
class _SphHarmY(Protocol):
class _SphHarmY(MultiUFunc):
@overload # 0-d, 0-d, 0-d, 0-d, diff_n == 0
def __call__(self, /, n: int, m: int, theta: onp.ToFloat, phi: onp.ToFloat, *, diff_n: _D0 = 0) -> _Complex0D: ...
@overload # >=0-d, >=0-d, >=0-d, > 0-d, diff_n == 0
Expand Down Expand Up @@ -173,7 +190,7 @@ class _SphHarmY(Protocol):
def __call__(self, /, n: onp.ToIntND, m: _ToInt_D, theta: _ToFloat_D, phi: _ToFloat_D, *, diff_n: _D2) -> _Complex123_D: ...

@type_check_only
class _SphHarmYAll(Protocol):
class _SphHarmYAll(MultiUFunc):
@overload # theta: 0-d, phi: 0-d, diff_n == 0
def __call__(self, /, n: int, m: int, theta: onp.ToFloat, phi: onp.ToFloat, *, diff_n: _D0 = 0) -> _Complex2D: ...
@overload # theta: >=0-d, phi: >0-d, diff_n == 0
Expand All @@ -195,35 +212,14 @@ class _SphHarmYAll(Protocol):

###

class MultiUFunc(Generic[_UFuncT_co]): # undocumented
@property
@override
# pyrefly: ignore [bad-override]
def __doc__(self, /) -> str | None: ... # type: ignore[override] # pyright: ignore[reportIncompatibleVariableOverride]

#
def __init__(
self,
/,
ufunc_or_ufuncs: _UFuncT_co | Iterable[_UFuncT_co],
name: str | None = None,
doc: str | None = None,
*,
force_complex_output: bool = False,
**default_kwargs: object,
) -> None: ...
def __call__(self: MultiUFunc[Callable[_Tss, _RT]], /, *args: _Tss.args, **kwargs: _Tss.kwargs) -> _RT: ...

###

legendre_p: Final[MultiUFunc[_LegendreP]] = ...
legendre_p_all: Final[MultiUFunc[_LegendrePAll]] = ...
legendre_p: Final[_LegendreP] = ...
legendre_p_all: Final[_LegendrePAll] = ...

assoc_legendre_p: Final[MultiUFunc[_AssocLegendreP]] = ...
assoc_legendre_p_all: Final[MultiUFunc[_AssocLegendrePAll]] = ...
assoc_legendre_p: Final[_AssocLegendreP] = ...
assoc_legendre_p_all: Final[_AssocLegendrePAll] = ...

sph_legendre_p: Final[MultiUFunc[_SphLegendreP]] = ...
sph_legendre_p_all: Final[MultiUFunc[_SphLegendrePAll]] = ...
sph_legendre_p: Final[_SphLegendreP] = ...
sph_legendre_p_all: Final[_SphLegendrePAll] = ...

sph_harm_y: Final[MultiUFunc[_SphHarmY]] = ...
sph_harm_y_all: Final[MultiUFunc[_SphHarmYAll]] = ...
sph_harm_y: Final[_SphHarmY] = ...
sph_harm_y_all: Final[_SphHarmYAll] = ...
20 changes: 10 additions & 10 deletions tests/special/test_multiufuncs.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ _f64_2d: onp.Array2D[np.float64]
# legendre_p
assert_type(legendre_p(1, 1.0), onp.Array1D[np.float64])
assert_type(legendre_p(1, np.float32(1.0)), onp.Array1D[np.float64])
assert_type(legendre_p(1, _f64_2d), _Float2_D) # type: ignore[assert-type,arg-type] # pyright: ignore[reportAssertTypeFailure, reportArgumentType] # pyrefly: ignore [assert-type, bad-argument-type] # TODO: fix MultiUFunc array overloads
assert_type(legendre_p(1, _f64_2d), _Float2_D)
assert_type(legendre_p(1, 1.0, diff_n=True), onp.Array1D[np.float64])
assert_type(legendre_p(1, 1.0, diff_n=1), onp.Array1D[np.float64])

Expand All @@ -51,36 +51,36 @@ assert_type(assoc_legendre_p(3, 2, 1.0, branch_cut=3, norm=True, diff_n=1), _Flo
assert_type(assoc_legendre_p(3, 2, _f64_1d, branch_cut=_i64_1d, diff_n=2), _Float1_D)

# assoc_legendre_p_all
assert_type(assoc_legendre_p_all(3, 2, 1.0), _Float3_D) # type: ignore[assert-type] # pyright: ignore[reportAssertTypeFailure] # pyrefly: ignore [assert-type] # TODO: fix MultiUFunc array overloads
assert_type(assoc_legendre_p_all(n=3, m=2, z=np.float32(1.0)), _Float3_D) # type: ignore[assert-type] # pyright: ignore[reportAssertTypeFailure] # pyrefly: ignore [assert-type] # TODO: fix MultiUFunc array overloads
assert_type(assoc_legendre_p_all(n=3, m=2, z=_f64_1d), _Float3_D) # type: ignore[assert-type,arg-type] # pyright: ignore[reportAssertTypeFailure, reportArgumentType] # pyrefly: ignore [assert-type, bad-argument-type] # TODO: fix MultiUFunc array overloads
assert_type(assoc_legendre_p_all(3, 2, 1.0, branch_cut=3, norm=True, diff_n=1), _Float3_D) # type: ignore[assert-type] # pyright: ignore[reportAssertTypeFailure] # pyrefly: ignore [assert-type] # TODO: fix MultiUFunc array overloads
assert_type(assoc_legendre_p_all(3, 2, np.float64(1.0), branch_cut=2, diff_n=2), _Float3_D) # type: ignore[assert-type] # pyright: ignore[reportAssertTypeFailure] # pyrefly: ignore [assert-type] # TODO: fix MultiUFunc array overloads
assert_type(assoc_legendre_p_all(3, 2, 1.0), onp.Array3D[np.float64])
assert_type(assoc_legendre_p_all(n=3, m=2, z=np.float32(1.0)), onp.Array3D[np.float64])
assert_type(assoc_legendre_p_all(n=3, m=2, z=_f64_1d), _Float3_D)
assert_type(assoc_legendre_p_all(3, 2, 1.0, branch_cut=3, norm=True, diff_n=1), onp.Array3D[np.float64])
assert_type(assoc_legendre_p_all(3, 2, np.float64(1.0), branch_cut=2, diff_n=2), onp.Array3D[np.float64])

# sph_legendre_p
assert_type(sph_legendre_p(3, 2, 1.0), onp.Array1D[np.float64])
assert_type(sph_legendre_p(n=3, m=2, theta=np.float32(1.0)), onp.Array1D[np.float64])
assert_type(sph_legendre_p(n=3, m=2, theta=_f64_1d), _Float2_D) # type: ignore[assert-type,arg-type] # pyright: ignore[reportAssertTypeFailure, reportArgumentType] # pyrefly: ignore [assert-type, bad-argument-type] # TODO: fix MultiUFunc array overloads
assert_type(sph_legendre_p(n=3, m=2, theta=_f64_1d), _Float2_D)
assert_type(sph_legendre_p(3, 2, 1.0, diff_n=True), onp.Array1D[np.float64])
assert_type(sph_legendre_p(3, 2, 1.0, diff_n=2), onp.Array1D[np.float64])

# sph_legendre_p_all
assert_type(sph_legendre_p_all(3, 2, 1.0), onp.Array3D[np.float64])
assert_type(sph_legendre_p_all(n=3, m=2, theta=np.float32(1.0)), onp.Array3D[np.float64])
assert_type(sph_legendre_p_all(n=3, m=2, theta=_f64_1d), _Float3_D) # type: ignore[assert-type,arg-type] # pyright: ignore[reportAssertTypeFailure, reportArgumentType] # pyrefly: ignore [assert-type, bad-argument-type] # TODO: fix MultiUFunc array overloads
assert_type(sph_legendre_p_all(n=3, m=2, theta=_f64_1d), _Float3_D)
assert_type(sph_legendre_p_all(3, 2, 1.0, diff_n=True), onp.Array3D[np.float64])
assert_type(sph_legendre_p_all(3, 2, 1.0, diff_n=2), onp.Array3D[np.float64])

# sph_harm_y
assert_type(sph_harm_y(3, 2, 1.0, 2.0), _Complex0D)
assert_type(sph_harm_y(n=3, m=2, theta=np.float32(1.0), phi=np.float32(2.0)), _Complex0D)
assert_type(sph_harm_y(3, 2, 1.0, _f64_1d), _Complex1_D) # type: ignore[assert-type,arg-type] # pyright: ignore[reportAssertTypeFailure, reportArgumentType] # pyrefly: ignore [assert-type, bad-argument-type] # TODO: fix MultiUFunc array overloads
assert_type(sph_harm_y(3, 2, 1.0, _f64_1d), _Complex1_D)
assert_type(sph_harm_y(3, 2, 1.0, 2.0, diff_n=False), _Complex0D)
assert_type(sph_harm_y(3, 2, 1.0, 2.0, diff_n=0), _Complex0D)

# sph_harm_y_all
assert_type(sph_harm_y_all(3, 2, 1.0, 2.0), _Complex2D)
assert_type(sph_harm_y_all(n=3, m=2, theta=np.float32(1.0), phi=np.float32(2.0)), _Complex2D)
assert_type(sph_harm_y_all(3, 2, 1.0, _f64_1d), _Complex3_D) # type: ignore[assert-type,arg-type] # pyright: ignore[reportAssertTypeFailure, reportArgumentType] # pyrefly: ignore [assert-type, bad-argument-type] # TODO: fix MultiUFunc array overloads
assert_type(sph_harm_y_all(3, 2, 1.0, _f64_1d), _Complex3_D)
assert_type(sph_harm_y_all(3, 2, 1.0, 2.0, diff_n=False), _Complex2D)
assert_type(sph_harm_y_all(3, 2, 1.0, 2.0, diff_n=0), _Complex2D)