Skip to content

Commit 97c5a22

Browse files
committed
✅ Update type annotations and tests for legendre_p and assoc_legendre_p functions
1 parent 96d6d42 commit 97c5a22

2 files changed

Lines changed: 60 additions & 9 deletions

File tree

scipy-stubs/special/_multiufuncs.pyi

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -100,11 +100,7 @@ class _AssocLegendreP(Protocol):
100100

101101
@type_check_only
102102
class _AssocLegendrePAll(Protocol):
103-
@overload # z: 0-d float
104-
def __call__(
105-
self, /, n: int, m: int, z: onp.ToFloat, *, branch_cut: _Branch = 2, norm: bool = False, diff_n: _Dn = 0
106-
) -> _Float3D: ...
107-
@overload # z: >=0-d float
103+
@overload # z: float
108104
def __call__(
109105
self, /, n: int, m: int, z: _ToFloat_D, *, branch_cut: _Branch_D = 2, norm: bool = False, diff_n: _Dn = 0
110106
) -> _Float3_D: ...
@@ -212,6 +208,49 @@ class MultiUFunc(Generic[_UFuncT_co]): # undocumented
212208
force_complex_output: bool = False,
213209
**default_kwargs: object,
214210
) -> None: ...
211+
@overload
212+
def __call__(self: MultiUFunc[_LegendreP], /, n: int, z: onp.ToFloat, *, diff_n: _Dn = 0) -> _Float1D: ...
213+
@overload
214+
def __call__(self: MultiUFunc[_LegendreP], /, n: int, z: onp.ToFloatND, *, diff_n: _Dn = 0) -> _Float2_D: ...
215+
@overload
216+
def __call__(self: MultiUFunc[_SphLegendrePAll], /, n: int, m: int, theta: onp.ToFloat, *, diff_n: _Dn = 0) -> _Float3D: ...
217+
@overload
218+
def __call__(self: MultiUFunc[_SphLegendrePAll], /, n: int, m: int, theta: _ToFloat_D, *, diff_n: _Dn = 0) -> _Float3_D: ...
219+
@overload
220+
def __call__(self: MultiUFunc[_SphLegendrePAll], *, n: int, m: int, theta: onp.ToFloat, diff_n: _Dn = 0) -> _Float3D: ...
221+
@overload
222+
def __call__(self: MultiUFunc[_SphLegendrePAll], *, n: int, m: int, theta: _ToFloat_D, diff_n: _Dn = 0) -> _Float3_D: ...
223+
@overload
224+
def __call__(self: MultiUFunc[_SphLegendreP], /, n: int, m: int, theta: onp.ToFloat, *, diff_n: _Dn = 0) -> _Float1D: ...
225+
@overload
226+
def __call__(self: MultiUFunc[_SphLegendreP], /, n: int, m: int, theta: onp.ToFloatND, *, diff_n: _Dn = 0) -> _Float2_D: ...
227+
@overload
228+
def __call__(self: MultiUFunc[_SphLegendreP], *, n: int, m: int, theta: onp.ToFloat, diff_n: _Dn = 0) -> _Float1D: ...
229+
@overload
230+
def __call__(self: MultiUFunc[_SphLegendreP], *, n: int, m: int, theta: onp.ToFloatND, diff_n: _Dn = 0) -> _Float2_D: ...
231+
@overload
232+
def __call__(
233+
self: MultiUFunc[_SphHarmY],
234+
/,
235+
n: int,
236+
m: int,
237+
theta: _ToFloat_D,
238+
phi: onp.ToFloatND,
239+
*,
240+
diff_n: _D0 = 0,
241+
) -> _Complex1_D: ...
242+
@overload
243+
def __call__(
244+
self: MultiUFunc[_SphHarmYAll],
245+
/,
246+
n: int,
247+
m: int,
248+
theta: _ToFloat_D,
249+
phi: onp.ToFloatND,
250+
*,
251+
diff_n: _D0 = 0,
252+
) -> _Complex3_D: ...
253+
@overload
215254
def __call__(self: MultiUFunc[Callable[_Tss, _RT]], /, *args: _Tss.args, **kwargs: _Tss.kwargs) -> _RT: ...
216255

217256
###

tests/special/test_multiufuncs.pyi

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,23 +15,29 @@ from scipy.special import (
1515
)
1616

1717
_Float1_D: TypeAlias = onp.Array[onp.AtLeast1D[Any], np.float64]
18+
_Float2_D: TypeAlias = onp.Array[onp.AtLeast2D[Any], np.float64]
1819
_Float3_D: TypeAlias = onp.Array[onp.AtLeast3D[Any], np.float64]
1920
_Complex0D: TypeAlias = onp.Array0D[np.complex128]
2021
_Complex2D: TypeAlias = onp.Array2D[np.complex128]
22+
_Complex1_D: TypeAlias = onp.Array[onp.AtLeast1D[Any], np.complex128]
23+
_Complex3_D: TypeAlias = onp.Array[onp.AtLeast3D[Any], np.complex128]
2124

2225
_i64_1d: onp.Array1D[np.int64]
2326
_f64_1d: onp.Array1D[np.float64]
27+
_f64_2d: onp.Array2D[np.float64]
2428

2529
# legendre_p
2630
assert_type(legendre_p(1, 1.0), onp.Array1D[np.float64])
2731
assert_type(legendre_p(1, np.float32(1.0)), onp.Array1D[np.float64])
32+
assert_type(legendre_p(1, _f64_2d), _Float2_D)
2833
assert_type(legendre_p(1, 1.0, diff_n=True), onp.Array1D[np.float64])
2934
assert_type(legendre_p(1, 1.0, diff_n=1), onp.Array1D[np.float64])
3035

3136
# legendre_p_all
3237
assert_type(legendre_p_all(3, 1.0), _Float3_D)
3338
assert_type(legendre_p_all(n=3, z=np.float32(1.0)), _Float3_D)
3439
assert_type(legendre_p_all(3, _f64_1d), _Float3_D)
40+
assert_type(legendre_p_all(3, _f64_2d), _Float3_D)
3541
assert_type(legendre_p_all(3, 1.0, diff_n=True), _Float3_D)
3642
assert_type(legendre_p_all(3, 1.0, diff_n=2), _Float3_D)
3743

@@ -40,35 +46,41 @@ assert_type(assoc_legendre_p(3, 2, 1.0), _Float1_D)
4046
assert_type(assoc_legendre_p(n=3, m=2, z=np.float32(1.0)), _Float1_D)
4147
assert_type(assoc_legendre_p(_i64_1d, 2, 1.0), _Float1_D)
4248
assert_type(assoc_legendre_p(3, _i64_1d, _f64_1d), _Float1_D)
49+
assert_type(assoc_legendre_p(3, 2, _f64_2d), _Float1_D)
4350
assert_type(assoc_legendre_p(3, 2, 1.0, branch_cut=3, norm=True, diff_n=1), _Float1_D)
4451
assert_type(assoc_legendre_p(3, 2, _f64_1d, branch_cut=_i64_1d, diff_n=2), _Float1_D)
4552

4653
# assoc_legendre_p_all
47-
assert_type(assoc_legendre_p_all(3, 2, 1.0), onp.Array3D[np.float64])
48-
assert_type(assoc_legendre_p_all(n=3, m=2, z=np.float32(1.0)), onp.Array3D[np.float64])
49-
assert_type(assoc_legendre_p_all(3, 2, 1.0, branch_cut=3, norm=True, diff_n=1), onp.Array3D[np.float64])
50-
assert_type(assoc_legendre_p_all(3, 2, np.float64(1.0), branch_cut=2, diff_n=2), onp.Array3D[np.float64])
54+
assert_type(assoc_legendre_p_all(3, 2, 1.0), _Float3_D)
55+
assert_type(assoc_legendre_p_all(n=3, m=2, z=np.float32(1.0)), _Float3_D)
56+
assert_type(assoc_legendre_p_all(n=3, m=2, z=_f64_1d), _Float3_D)
57+
assert_type(assoc_legendre_p_all(3, 2, 1.0, branch_cut=3, norm=True, diff_n=1), _Float3_D)
58+
assert_type(assoc_legendre_p_all(3, 2, np.float64(1.0), branch_cut=2, diff_n=2), _Float3_D)
5159

5260
# sph_legendre_p
5361
assert_type(sph_legendre_p(3, 2, 1.0), onp.Array1D[np.float64])
5462
assert_type(sph_legendre_p(n=3, m=2, theta=np.float32(1.0)), onp.Array1D[np.float64])
63+
assert_type(sph_legendre_p(n=3, m=2, theta=_f64_1d), _Float2_D)
5564
assert_type(sph_legendre_p(3, 2, 1.0, diff_n=True), onp.Array1D[np.float64])
5665
assert_type(sph_legendre_p(3, 2, 1.0, diff_n=2), onp.Array1D[np.float64])
5766

5867
# sph_legendre_p_all
5968
assert_type(sph_legendre_p_all(3, 2, 1.0), onp.Array3D[np.float64])
6069
assert_type(sph_legendre_p_all(n=3, m=2, theta=np.float32(1.0)), onp.Array3D[np.float64])
70+
assert_type(sph_legendre_p_all(n=3, m=2, theta=_f64_1d), _Float3_D)
6171
assert_type(sph_legendre_p_all(3, 2, 1.0, diff_n=True), onp.Array3D[np.float64])
6272
assert_type(sph_legendre_p_all(3, 2, 1.0, diff_n=2), onp.Array3D[np.float64])
6373

6474
# sph_harm_y
6575
assert_type(sph_harm_y(3, 2, 1.0, 2.0), _Complex0D)
6676
assert_type(sph_harm_y(n=3, m=2, theta=np.float32(1.0), phi=np.float32(2.0)), _Complex0D)
77+
assert_type(sph_harm_y(3, 2, 1.0, _f64_1d), _Complex1_D)
6778
assert_type(sph_harm_y(3, 2, 1.0, 2.0, diff_n=False), _Complex0D)
6879
assert_type(sph_harm_y(3, 2, 1.0, 2.0, diff_n=0), _Complex0D)
6980

7081
# sph_harm_y_all
7182
assert_type(sph_harm_y_all(3, 2, 1.0, 2.0), _Complex2D)
7283
assert_type(sph_harm_y_all(n=3, m=2, theta=np.float32(1.0), phi=np.float32(2.0)), _Complex2D)
84+
assert_type(sph_harm_y_all(3, 2, 1.0, _f64_1d), _Complex3_D)
7385
assert_type(sph_harm_y_all(3, 2, 1.0, 2.0, diff_n=False), _Complex2D)
7486
assert_type(sph_harm_y_all(3, 2, 1.0, 2.0, diff_n=0), _Complex2D)

0 commit comments

Comments
 (0)