From 19c602c3dd9473e2ef58345b352237217f68944f Mon Sep 17 00:00:00 2001 From: Aniket Singh Yadav Date: Tue, 14 Apr 2026 10:03:17 +0000 Subject: [PATCH 1/7] add generic shape types for multivariate frozen distributions --- scipy-stubs/stats/_multivariate.pyi | 97 ++++++++++++++++------------- 1 file changed, 52 insertions(+), 45 deletions(-) diff --git a/scipy-stubs/stats/_multivariate.pyi b/scipy-stubs/stats/_multivariate.pyi index 58ad233ed..a12494b36 100644 --- a/scipy-stubs/stats/_multivariate.pyi +++ b/scipy-stubs/stats/_multivariate.pyi @@ -31,6 +31,11 @@ __all__ = [ _ScalarT = TypeVar("_ScalarT", bound=np.generic, default=np.float64) _ScalarT_co = TypeVar("_ScalarT_co", bound=np.generic, default=np.float64, covariant=True) +_K = TypeVar("_K", bound=int, default=int) +_M = TypeVar("_M", bound=int, default=int) +_N = TypeVar("_N", bound=int, default=int) +_P = TypeVar("_P", bound=int, default=int) +_D = TypeVar("_D", bound=int, default=int) # TODO(@jorenham): rename as {}T_co _RVG_co = TypeVar("_RVG_co", bound=multi_rv_generic, default=multi_rv_generic, covariant=True) @@ -180,27 +185,24 @@ class multivariate_normal_gen(multi_rv_generic): # TODO(@jorenham): Generic shape-type for mean and cov, so that we can determine whether the methods return scalars or arrays. # https://github.com/scipy/scipy-stubs/issues/406 -class multivariate_normal_frozen(multi_rv_frozen[multivariate_normal_gen]): - # pyrefly: ignore [bad-override] - __class_getitem__: ClassVar[None] = None # type:ignore[assignment] # pyright:ignore[reportIncompatibleMethodOverride] - +class multivariate_normal_frozen(multi_rv_frozen[multivariate_normal_gen], Generic[_K]): dim: Final[int] allow_singular: Final[bool] maxpts: Final[int] abseps: Final[float] releps: Final[float] cov_object: Final[Covariance] - mean: onp.Array1D[np.float64] + mean: onp.Array[tuple[_K], np.float64] @property - def cov(self, /) -> onp.Array2D[np.float64]: ... + def cov(self, /) -> onp.Array[tuple[_K, _K], np.float64]: ... # def __init__( self, /, - mean: onp.ToFloat1D | None = None, - cov: _AnyCov = 1, + mean: onp.Array[tuple[_K], npc.floating] | None = None, + cov: Covariance | onp.ToFloat | onp.Array[tuple[_K, _K], npc.floating] = 1, allow_singular: bool = False, seed: onp.random.ToRNG | None = None, maxpts: onp.ToJustInt | None = None, @@ -227,7 +229,7 @@ class multivariate_normal_frozen(multi_rv_frozen[multivariate_normal_gen]): # def entropy(self, /) -> np.float64: ... - def marginal(self, dimensions: int | onp.ToInt1D) -> multivariate_normal_frozen: ... + def marginal(self, dimensions: int | onp.ToInt1D) -> multivariate_normal_frozen[int]: ... class matrix_normal_gen(multi_rv_generic): def __call__( @@ -293,19 +295,17 @@ class matrix_normal_gen(multi_rv_generic): # def entropy(self, /, rowcov: _AnyCov = 1, colcov: _AnyCov = 1) -> np.float64: ... -class matrix_normal_frozen(multi_rv_frozen[matrix_normal_gen]): - # pyrefly: ignore [bad-override] - __class_getitem__: ClassVar[None] = None # type:ignore[assignment] # pyright:ignore[reportIncompatibleMethodOverride] - +class matrix_normal_frozen(multi_rv_frozen[matrix_normal_gen], Generic[_M, _N]): + mean: Final[onp.Array[tuple[_M, _N], np.float64]] rowpsd: Final[_PSD] colpsd: Final[_PSD] def __init__( self, /, - mean: onp.ToFloat2D | None = None, - rowcov: onp.ToFloat2D | onp.ToFloat = 1, - colcov: onp.ToFloat2D | onp.ToFloat = 1, + mean: onp.Array[tuple[_M, _N], npc.floating] | None = None, + rowcov: onp.ToFloat | onp.Array[tuple[_M], npc.floating] | onp.Array[tuple[_M, _M], npc.floating] = 1, + colcov: onp.ToFloat | onp.Array[tuple[_N], npc.floating] | onp.Array[tuple[_N, _N], npc.floating] = 1, seed: onp.random.ToRNG | None = None, ) -> None: ... def logpdf(self, /, X: onp.ToFloatND) -> _ScalarOrArray_f8: ... @@ -597,21 +597,23 @@ class multinomial_gen(multi_rv_generic): ) -> _Array2ND: ... # -class multinomial_frozen(multi_rv_frozen[multinomial_gen]): - def __init__(self, /, n: onp.ToJustIntND, p: onp.ToJustFloatND, seed: onp.random.ToRNG | None = None) -> None: ... +class multinomial_frozen(multi_rv_frozen[multinomial_gen], Generic[_K]): + def __init__( + self, /, n: onp.ToJustInt, p: onp.Array[tuple[_K], npc.floating], seed: onp.random.ToRNG | None = None + ) -> None: ... # def logpmf(self, /, x: onp.ToFloatND) -> _ScalarOrArray_f8: ... def pmf(self, /, x: onp.ToFloatND) -> _ScalarOrArray_f8: ... # - def mean(self, /) -> _Array1ND: ... - def cov(self, /) -> _Array2ND: ... + def mean(self, /) -> onp.Array[tuple[_K, *tuple[Any, ...]], np.float64]: ... + def cov(self, /) -> onp.Array[tuple[_K, _K, *tuple[Any, ...]], np.float64]: ... # pyright: ignore[reportInvalidTypeForm] def entropy(self, /) -> _ScalarOrArray_f8: ... # @overload - def rvs(self, /, size: tuple[()], random_state: onp.random.ToRNG | None = None) -> _Array1ND: ... + def rvs(self, /, size: tuple[()], random_state: onp.random.ToRNG | None = None) -> onp.Array[tuple[_K], np.float64]: ... @overload def rvs(self, /, size: onp.AtLeast1D | int = 1, random_state: onp.random.ToRNG | None = None) -> _Array2ND: ... @@ -669,16 +671,18 @@ class uniform_direction_gen(multi_rv_generic): @overload def rvs(self, /, dim: int, size: onp.AtLeast2D, random_state: onp.random.ToRNG | None = None) -> _Array3ND[np.float64]: ... -class uniform_direction_frozen(multi_rv_frozen[uniform_direction_gen]): +class uniform_direction_frozen(multi_rv_frozen[uniform_direction_gen], Generic[_D]): dim: Final[int] def __init__(self, /, dim: int | None = None, seed: onp.random.ToRNG | None = None) -> None: ... # @overload - def rvs(self, /, size: None = None, random_state: onp.random.ToRNG | None = None) -> onp.Array1D[np.float64]: ... + def rvs(self, /, size: None = None, random_state: onp.random.ToRNG | None = None) -> onp.Array[tuple[_D], np.float64]: ... @overload - def rvs(self, /, size: int | tuple[int], random_state: onp.random.ToRNG | None = None) -> onp.Array2D[np.float64]: ... + def rvs( + self, /, size: int | tuple[int], random_state: onp.random.ToRNG | None = None + ) -> onp.Array[tuple[int, _D], np.float64]: ... @overload def rvs(self, /, size: onp.AtLeast2D, random_state: onp.random.ToRNG | None = None) -> _Array2ND[np.float64]: ... @@ -827,21 +831,18 @@ class multivariate_t_gen(multi_rv_generic): allow_singular: bool = False, ) -> multivariate_t_frozen: ... -class multivariate_t_frozen(multi_rv_frozen[multivariate_t_gen]): - # pyrefly: ignore [bad-override] - __class_getitem__: ClassVar[None] = None # type:ignore[assignment] # pyright:ignore[reportIncompatibleMethodOverride] - +class multivariate_t_frozen(multi_rv_frozen[multivariate_t_gen], Generic[_P]): dim: Final[int] df: Final[int] - loc: Final[onp.Array1D[np.float64]] - shape: Final[onp.Array2D[np.float64]] + loc: Final[onp.Array[tuple[_P], np.float64]] + shape: Final[onp.Array[tuple[_P, _P], np.float64]] shape_info: Final[_PSD] def __init__( self, /, - loc: onp.ToFloat1D | None = None, - shape: onp.ToFloat | onp.ToFloat2D = 1, + loc: onp.Array[tuple[_P], npc.floating] | None = None, + shape: onp.ToFloat | onp.Array[tuple[_P, _P], npc.floating] = 1, df: int = 1, allow_singular: bool = False, seed: onp.random.ToRNG | None = None, @@ -872,7 +873,7 @@ class multivariate_t_frozen(multi_rv_frozen[multivariate_t_gen]): def rvs(self, /, size: onp.AtLeast2D, random_state: onp.random.ToRNG | None = None) -> _Array3ND: ... # - def marginal(self, dimensions: int | onp.ToInt1D) -> multivariate_t_frozen: ... + def marginal(self, dimensions: int | onp.ToInt1D) -> multivariate_t_frozen[int]: ... # NOTE: `m` and `n` are broadcastable (but doing so will break `.rvs()` at runtime...) class multivariate_hypergeom_gen(multi_rv_generic): @@ -898,15 +899,17 @@ class multivariate_hypergeom_gen(multi_rv_generic): random_state: onp.random.ToRNG | None = None, ) -> _Array2ND: ... -class multivariate_hypergeom_frozen(multi_rv_frozen[multivariate_hypergeom_gen]): - def __init__(self, /, m: onp.ToIntND, n: onp.ToJustInt | onp.ToJustIntND, seed: onp.random.ToRNG | None = None) -> None: ... +class multivariate_hypergeom_frozen(multi_rv_frozen[multivariate_hypergeom_gen], Generic[_K]): + def __init__( + self, /, m: onp.Array[tuple[_K], npc.integer], n: onp.ToJustInt, seed: onp.random.ToRNG | None = None + ) -> None: ... def logpmf(self, /, x: onp.ToFloatND) -> _ScalarOrArray_f8: ... def pmf(self, /, x: onp.ToFloatND) -> _ScalarOrArray_f8: ... - def mean(self, /) -> _Array1ND: ... - def var(self, /) -> _Array1ND: ... - def cov(self, /) -> _Array2ND: ... + def mean(self, /) -> onp.Array[tuple[_K, *tuple[Any, ...]], np.float64]: ... + def var(self, /) -> onp.Array[tuple[_K, *tuple[Any, ...]], np.float64]: ... + def cov(self, /) -> onp.Array[tuple[_K, _K, *tuple[Any, ...]], np.float64]: ... # pyright: ignore[reportInvalidTypeForm] @overload - def rvs(self, /, size: tuple[()], random_state: onp.random.ToRNG | None = None) -> _Array1ND: ... + def rvs(self, /, size: tuple[()], random_state: onp.random.ToRNG | None = None) -> onp.Array[tuple[_K], np.float64]: ... @overload def rvs(self, /, size: int | tuple[int] = 1, random_state: onp.random.ToRNG | None = None) -> _Array2ND: ... @@ -996,11 +999,13 @@ class dirichlet_multinomial_gen(multi_rv_generic): def var(self, /, alpha: onp.ToFloatND, n: onp.ToJustIntND) -> _Array1ND: ... def cov(self, /, alpha: onp.ToFloatND, n: onp.ToJustIntND) -> _Array2ND: ... -class dirichlet_multinomial_frozen(multi_rv_frozen[dirichlet_multinomial_gen]): - alpha: _Array1ND - n: _Array1ND[np.int_] # broadcasted against alpha +class dirichlet_multinomial_frozen(multi_rv_frozen[dirichlet_multinomial_gen], Generic[_K]): + alpha: onp.Array[tuple[_K], np.float64] + n: onp.ArrayND[np.int_] - def __init__(self, /, alpha: onp.ToFloatND, n: onp.ToJustIntND, seed: onp.random.ToRNG | None = None) -> None: ... + def __init__( + self, /, alpha: onp.Array[tuple[_K], npc.floating], n: onp.ToJustIntND, seed: onp.random.ToRNG | None = None + ) -> None: ... def logpmf(self, /, x: onp.ToIntND) -> _ScalarOrArray_f8: ... def pmf(self, /, x: onp.ToIntND) -> _ScalarOrArray_f8: ... def mean(self, /) -> _Array1ND: ... @@ -1024,8 +1029,10 @@ class vonmises_fisher_gen(multi_rv_generic): ) -> _Array2ND: ... def fit(self, /, x: onp.ToFloatND) -> tuple[onp.Array1D[np.float64], float]: ... -class vonmises_fisher_frozen(multi_rv_frozen[vonmises_fisher_gen]): - def __init__(self, /, mu: onp.ToFloat1D | None = None, kappa: int = 1, seed: onp.random.ToRNG | None = None) -> None: ... +class vonmises_fisher_frozen(multi_rv_frozen[vonmises_fisher_gen], Generic[_D]): + def __init__( + self, /, mu: onp.Array[tuple[_D], npc.floating] | None = None, kappa: int = 1, seed: onp.random.ToRNG | None = None + ) -> None: ... def logpdf(self, /, x: onp.ToFloatND) -> _ScalarOrArray_f8: ... def pdf(self, /, x: onp.ToFloatND) -> _ScalarOrArray_f8: ... def entropy(self, /) -> np.float64: ... From 0da08db33265fb97d0c193d8c8a027f503639d90 Mon Sep 17 00:00:00 2001 From: Aniket Singh Yadav Date: Tue, 14 Apr 2026 10:26:48 +0000 Subject: [PATCH 2/7] add generic shape types for multivariate frozen distributions --- scipy-stubs/stats/_multivariate.pyi | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scipy-stubs/stats/_multivariate.pyi b/scipy-stubs/stats/_multivariate.pyi index a12494b36..71fb46cc9 100644 --- a/scipy-stubs/stats/_multivariate.pyi +++ b/scipy-stubs/stats/_multivariate.pyi @@ -296,7 +296,7 @@ class matrix_normal_gen(multi_rv_generic): def entropy(self, /, rowcov: _AnyCov = 1, colcov: _AnyCov = 1) -> np.float64: ... class matrix_normal_frozen(multi_rv_frozen[matrix_normal_gen], Generic[_M, _N]): - mean: Final[onp.Array[tuple[_M, _N], np.float64]] + mean: onp.Array[tuple[_M, _N], np.float64] rowpsd: Final[_PSD] colpsd: Final[_PSD] @@ -834,8 +834,8 @@ class multivariate_t_gen(multi_rv_generic): class multivariate_t_frozen(multi_rv_frozen[multivariate_t_gen], Generic[_P]): dim: Final[int] df: Final[int] - loc: Final[onp.Array[tuple[_P], np.float64]] - shape: Final[onp.Array[tuple[_P, _P], np.float64]] + loc: onp.Array[tuple[_P], np.float64] + shape: onp.Array[tuple[_P, _P], np.float64] shape_info: Final[_PSD] def __init__( From 40502e2d2c230d377a3c9144482ef131348bcd8b Mon Sep 17 00:00:00 2001 From: Aniket Singh Yadav Date: Tue, 14 Apr 2026 11:13:20 +0000 Subject: [PATCH 3/7] add generic shape types for multivariate frozen distributions --- scipy-stubs/stats/_multivariate.pyi | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/scipy-stubs/stats/_multivariate.pyi b/scipy-stubs/stats/_multivariate.pyi index 71fb46cc9..51d02f4c8 100644 --- a/scipy-stubs/stats/_multivariate.pyi +++ b/scipy-stubs/stats/_multivariate.pyi @@ -1,4 +1,3 @@ -import types from typing import Any, ClassVar, Final, Generic, Literal, SupportsIndex, TypeAlias, overload, type_check_only from typing_extensions import TypeVar, override @@ -68,9 +67,6 @@ class multi_rv_generic(rng_mixin): def _get_random_state(self, /, random_state: onp.random.ToRNG | None) -> onp.random.RNG: ... class multi_rv_frozen(rng_mixin, Generic[_RVG_co]): - @classmethod - def __class_getitem__(cls, arg: object, /) -> types.GenericAlias: ... - _dist: _RVG_co class multivariate_normal_gen(multi_rv_generic): @@ -183,8 +179,6 @@ class multivariate_normal_gen(multi_rv_generic): allow_singular: bool = False, ) -> multivariate_normal_frozen: ... -# TODO(@jorenham): Generic shape-type for mean and cov, so that we can determine whether the methods return scalars or arrays. -# https://github.com/scipy/scipy-stubs/issues/406 class multivariate_normal_frozen(multi_rv_frozen[multivariate_normal_gen], Generic[_K]): dim: Final[int] allow_singular: Final[bool] @@ -438,7 +432,7 @@ class dirichlet_gen(multi_rv_generic): class dirichlet_frozen(multi_rv_frozen[dirichlet_gen]): # pyrefly: ignore [bad-override] - __class_getitem__: ClassVar[None] = None # type:ignore[assignment] # pyright:ignore[reportIncompatibleMethodOverride] + __class_getitem__: ClassVar[None] = None alpha: Final[onp.Array1D[_Scalar_uif]] @@ -503,7 +497,7 @@ class wishart_gen(multi_rv_generic): class wishart_frozen(multi_rv_frozen[wishart_gen]): # pyrefly: ignore [bad-override] - __class_getitem__: ClassVar[None] = None # type:ignore[assignment] # pyright:ignore[reportIncompatibleMethodOverride] + __class_getitem__: ClassVar[None] = None dim: Final[int] df: Final[onp.ToFloat] @@ -546,7 +540,7 @@ class invwishart_gen(wishart_gen): class invwishart_frozen(multi_rv_frozen[invwishart_gen]): # pyrefly: ignore [bad-override] - __class_getitem__: ClassVar[None] = None # type:ignore[assignment] # pyright:ignore[reportIncompatibleMethodOverride] + __class_getitem__: ClassVar[None] = None def __init__(self, /, df: onp.ToFloat, scale: _ToFloatMax2D, seed: onp.random.ToRNG | None = None) -> None: ... @@ -648,15 +642,15 @@ class _group_rv_frozen_mixin(Generic[_ScalarT_co]): class special_ortho_group_gen(_group_rv_gen_mixin[special_ortho_group_frozen], multi_rv_generic): ... # pyrefly: ignore [inconsistent-inheritance] -class special_ortho_group_frozen(_group_rv_frozen_mixin, multi_rv_frozen[special_ortho_group_gen]): ... # type: ignore[misc] +class special_ortho_group_frozen(_group_rv_frozen_mixin, multi_rv_frozen[special_ortho_group_gen]): ... class ortho_group_gen(_group_rv_gen_mixin[ortho_group_frozen], multi_rv_generic): ... # pyrefly: ignore [inconsistent-inheritance] -class ortho_group_frozen(_group_rv_frozen_mixin, multi_rv_frozen[ortho_group_gen]): ... # type: ignore[misc] +class ortho_group_frozen(_group_rv_frozen_mixin, multi_rv_frozen[ortho_group_gen]): ... class unitary_group_gen(_group_rv_gen_mixin[unitary_group_frozen, np.complex128], multi_rv_generic): ... # pyrefly: ignore [inconsistent-inheritance] -class unitary_group_frozen(_group_rv_frozen_mixin[np.complex128], multi_rv_frozen[unitary_group_gen]): ... # type: ignore[misc] +class unitary_group_frozen(_group_rv_frozen_mixin[np.complex128], multi_rv_frozen[unitary_group_gen]): ... class uniform_direction_gen(multi_rv_generic): def __call__(self, /, dim: int | None = None, seed: onp.random.ToRNG | None = None) -> uniform_direction_frozen: ... @@ -708,7 +702,7 @@ class random_correlation_gen(multi_rv_generic): class random_correlation_frozen(multi_rv_frozen[random_correlation_gen]): # pyrefly: ignore [bad-override] - __class_getitem__: ClassVar[None] = None # type:ignore[assignment] # pyright:ignore[reportIncompatibleMethodOverride] + __class_getitem__: ClassVar[None] = None eigs: Final[onp.Array1D[np.float64]] tol: Final[float] @@ -964,7 +958,7 @@ class random_table_gen(multi_rv_generic): class random_table_frozen(multi_rv_frozen[random_table_gen]): # pyrefly: ignore [bad-override] - __class_getitem__: ClassVar[None] = None # type:ignore[assignment] # pyright:ignore[reportIncompatibleMethodOverride] + __class_getitem__: ClassVar[None] = None def __init__(self, /, row: onp.ToJustIntND, col: onp.ToJustIntND, *, seed: onp.random.ToRNG | None = None) -> None: ... From dfa3a0c905163ef9d19d81add19f96168a952fd4 Mon Sep 17 00:00:00 2001 From: Aniket Singh Yadav Date: Wed, 15 Apr 2026 17:18:43 +0000 Subject: [PATCH 4/7] add generic shape types for multivariate frozen distributions --- scipy-stubs/stats/_multivariate.pyi | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/scipy-stubs/stats/_multivariate.pyi b/scipy-stubs/stats/_multivariate.pyi index 51d02f4c8..fe97d8fd1 100644 --- a/scipy-stubs/stats/_multivariate.pyi +++ b/scipy-stubs/stats/_multivariate.pyi @@ -431,7 +431,6 @@ class dirichlet_gen(multi_rv_generic): ) -> _Array3ND: ... class dirichlet_frozen(multi_rv_frozen[dirichlet_gen]): - # pyrefly: ignore [bad-override] __class_getitem__: ClassVar[None] = None alpha: Final[onp.Array1D[_Scalar_uif]] @@ -496,7 +495,6 @@ class wishart_gen(multi_rv_generic): ) -> _Array2ND[np.float64] | np.float64: ... class wishart_frozen(multi_rv_frozen[wishart_gen]): - # pyrefly: ignore [bad-override] __class_getitem__: ClassVar[None] = None dim: Final[int] @@ -539,7 +537,6 @@ class invwishart_gen(wishart_gen): def var(self, /, df: onp.ToFloat, scale: _ToFloatMax2D) -> np.float64 | None: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride] class invwishart_frozen(multi_rv_frozen[invwishart_gen]): - # pyrefly: ignore [bad-override] __class_getitem__: ClassVar[None] = None def __init__(self, /, df: onp.ToFloat, scale: _ToFloatMax2D, seed: onp.random.ToRNG | None = None) -> None: ... @@ -640,16 +637,10 @@ class _group_rv_frozen_mixin(Generic[_ScalarT_co]): def rvs(self, /, size: int, random_state: onp.random.ToRNG | None = None) -> _Array2ND[_ScalarT_co]: ... class special_ortho_group_gen(_group_rv_gen_mixin[special_ortho_group_frozen], multi_rv_generic): ... - -# pyrefly: ignore [inconsistent-inheritance] class special_ortho_group_frozen(_group_rv_frozen_mixin, multi_rv_frozen[special_ortho_group_gen]): ... class ortho_group_gen(_group_rv_gen_mixin[ortho_group_frozen], multi_rv_generic): ... - -# pyrefly: ignore [inconsistent-inheritance] class ortho_group_frozen(_group_rv_frozen_mixin, multi_rv_frozen[ortho_group_gen]): ... class unitary_group_gen(_group_rv_gen_mixin[unitary_group_frozen, np.complex128], multi_rv_generic): ... - -# pyrefly: ignore [inconsistent-inheritance] class unitary_group_frozen(_group_rv_frozen_mixin[np.complex128], multi_rv_frozen[unitary_group_gen]): ... class uniform_direction_gen(multi_rv_generic): @@ -701,7 +692,6 @@ class random_correlation_gen(multi_rv_generic): ) -> onp.Array2D[np.float64]: ... class random_correlation_frozen(multi_rv_frozen[random_correlation_gen]): - # pyrefly: ignore [bad-override] __class_getitem__: ClassVar[None] = None eigs: Final[onp.Array1D[np.float64]] @@ -957,7 +947,6 @@ class random_table_gen(multi_rv_generic): ) -> _Array3ND[np.float64]: ... class random_table_frozen(multi_rv_frozen[random_table_gen]): - # pyrefly: ignore [bad-override] __class_getitem__: ClassVar[None] = None def __init__(self, /, row: onp.ToJustIntND, col: onp.ToJustIntND, *, seed: onp.random.ToRNG | None = None) -> None: ... From b6cdc14db345cff1b6f1fadb1cc0a1d844933628 Mon Sep 17 00:00:00 2001 From: Aniket Singh Yadav Date: Wed, 15 Apr 2026 17:44:29 +0000 Subject: [PATCH 5/7] revert __class_getitem__ --- scipy-stubs/stats/_multivariate.pyi | 34 ++++++++++++++++++++++------- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/scipy-stubs/stats/_multivariate.pyi b/scipy-stubs/stats/_multivariate.pyi index fe97d8fd1..7b7d22a79 100644 --- a/scipy-stubs/stats/_multivariate.pyi +++ b/scipy-stubs/stats/_multivariate.pyi @@ -1,3 +1,4 @@ +import types from typing import Any, ClassVar, Final, Generic, Literal, SupportsIndex, TypeAlias, overload, type_check_only from typing_extensions import TypeVar, override @@ -67,6 +68,9 @@ class multi_rv_generic(rng_mixin): def _get_random_state(self, /, random_state: onp.random.ToRNG | None) -> onp.random.RNG: ... class multi_rv_frozen(rng_mixin, Generic[_RVG_co]): + @classmethod + def __class_getitem__(cls, arg: object, /) -> types.GenericAlias: ... + _dist: _RVG_co class multivariate_normal_gen(multi_rv_generic): @@ -180,6 +184,9 @@ class multivariate_normal_gen(multi_rv_generic): ) -> multivariate_normal_frozen: ... class multivariate_normal_frozen(multi_rv_frozen[multivariate_normal_gen], Generic[_K]): + # pyrefly: ignore [bad-override] + __class_getitem__: ClassVar[None] = None # type:ignore[assignment] # pyright:ignore[reportIncompatibleMethodOverride] + dim: Final[int] allow_singular: Final[bool] maxpts: Final[int] @@ -290,6 +297,9 @@ class matrix_normal_gen(multi_rv_generic): def entropy(self, /, rowcov: _AnyCov = 1, colcov: _AnyCov = 1) -> np.float64: ... class matrix_normal_frozen(multi_rv_frozen[matrix_normal_gen], Generic[_M, _N]): + # pyrefly: ignore [bad-override] + __class_getitem__: ClassVar[None] = None # type:ignore[assignment] # pyright:ignore[reportIncompatibleMethodOverride] + mean: onp.Array[tuple[_M, _N], np.float64] rowpsd: Final[_PSD] colpsd: Final[_PSD] @@ -431,7 +441,8 @@ class dirichlet_gen(multi_rv_generic): ) -> _Array3ND: ... class dirichlet_frozen(multi_rv_frozen[dirichlet_gen]): - __class_getitem__: ClassVar[None] = None + # pyrefly: ignore [bad-override] + __class_getitem__: ClassVar[None] = None # type:ignore[assignment] # pyright:ignore[reportIncompatibleMethodOverride] alpha: Final[onp.Array1D[_Scalar_uif]] @@ -495,7 +506,8 @@ class wishart_gen(multi_rv_generic): ) -> _Array2ND[np.float64] | np.float64: ... class wishart_frozen(multi_rv_frozen[wishart_gen]): - __class_getitem__: ClassVar[None] = None + # pyrefly: ignore [bad-override] + __class_getitem__: ClassVar[None] = None # type:ignore[assignment] # pyright:ignore[reportIncompatibleMethodOverride] dim: Final[int] df: Final[onp.ToFloat] @@ -537,7 +549,8 @@ class invwishart_gen(wishart_gen): def var(self, /, df: onp.ToFloat, scale: _ToFloatMax2D) -> np.float64 | None: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride] class invwishart_frozen(multi_rv_frozen[invwishart_gen]): - __class_getitem__: ClassVar[None] = None + # pyrefly: ignore [bad-override] + __class_getitem__: ClassVar[None] = None # type:ignore[assignment] # pyright:ignore[reportIncompatibleMethodOverride] def __init__(self, /, df: onp.ToFloat, scale: _ToFloatMax2D, seed: onp.random.ToRNG | None = None) -> None: ... @@ -637,11 +650,11 @@ class _group_rv_frozen_mixin(Generic[_ScalarT_co]): def rvs(self, /, size: int, random_state: onp.random.ToRNG | None = None) -> _Array2ND[_ScalarT_co]: ... class special_ortho_group_gen(_group_rv_gen_mixin[special_ortho_group_frozen], multi_rv_generic): ... -class special_ortho_group_frozen(_group_rv_frozen_mixin, multi_rv_frozen[special_ortho_group_gen]): ... +class special_ortho_group_frozen(_group_rv_frozen_mixin, multi_rv_frozen[special_ortho_group_gen]): ... # type: ignore[misc] class ortho_group_gen(_group_rv_gen_mixin[ortho_group_frozen], multi_rv_generic): ... -class ortho_group_frozen(_group_rv_frozen_mixin, multi_rv_frozen[ortho_group_gen]): ... +class ortho_group_frozen(_group_rv_frozen_mixin, multi_rv_frozen[ortho_group_gen]): ... # type: ignore[misc] class unitary_group_gen(_group_rv_gen_mixin[unitary_group_frozen, np.complex128], multi_rv_generic): ... -class unitary_group_frozen(_group_rv_frozen_mixin[np.complex128], multi_rv_frozen[unitary_group_gen]): ... +class unitary_group_frozen(_group_rv_frozen_mixin[np.complex128], multi_rv_frozen[unitary_group_gen]): ... # type: ignore[misc] class uniform_direction_gen(multi_rv_generic): def __call__(self, /, dim: int | None = None, seed: onp.random.ToRNG | None = None) -> uniform_direction_frozen: ... @@ -692,7 +705,8 @@ class random_correlation_gen(multi_rv_generic): ) -> onp.Array2D[np.float64]: ... class random_correlation_frozen(multi_rv_frozen[random_correlation_gen]): - __class_getitem__: ClassVar[None] = None + # pyrefly: ignore [bad-override] + __class_getitem__: ClassVar[None] = None # type:ignore[assignment] # pyright:ignore[reportIncompatibleMethodOverride] eigs: Final[onp.Array1D[np.float64]] tol: Final[float] @@ -816,6 +830,9 @@ class multivariate_t_gen(multi_rv_generic): ) -> multivariate_t_frozen: ... class multivariate_t_frozen(multi_rv_frozen[multivariate_t_gen], Generic[_P]): + # pyrefly: ignore [bad-override] + __class_getitem__: ClassVar[None] = None # type:ignore[assignment] # pyright:ignore[reportIncompatibleMethodOverride] + dim: Final[int] df: Final[int] loc: onp.Array[tuple[_P], np.float64] @@ -947,7 +964,8 @@ class random_table_gen(multi_rv_generic): ) -> _Array3ND[np.float64]: ... class random_table_frozen(multi_rv_frozen[random_table_gen]): - __class_getitem__: ClassVar[None] = None + # pyrefly: ignore [bad-override] + __class_getitem__: ClassVar[None] = None # type:ignore[assignment] # pyright:ignore[reportIncompatibleMethodOverride] def __init__(self, /, row: onp.ToJustIntND, col: onp.ToJustIntND, *, seed: onp.random.ToRNG | None = None) -> None: ... From 194981a9c78262af674e46af924b70a6045e604c Mon Sep 17 00:00:00 2001 From: Aniket Singh Yadav Date: Fri, 17 Apr 2026 15:42:36 +0000 Subject: [PATCH 6/7] add generic shape types for multivariate frozen distributions --- scipy-stubs/stats/_multivariate.pyi | 31 +++++++++++++++++------------ 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/scipy-stubs/stats/_multivariate.pyi b/scipy-stubs/stats/_multivariate.pyi index 7b7d22a79..d39cfc05c 100644 --- a/scipy-stubs/stats/_multivariate.pyi +++ b/scipy-stubs/stats/_multivariate.pyi @@ -237,8 +237,8 @@ class matrix_normal_gen(multi_rv_generic): self, /, mean: onp.ToFloat2D | None = None, - rowcov: onp.ToFloat2D | onp.ToFloat = 1, - colcov: onp.ToFloat2D | onp.ToFloat = 1, + rowcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D = 1, + colcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D = 1, seed: onp.random.ToRNG | None = None, ) -> matrix_normal_frozen: ... @@ -248,16 +248,16 @@ class matrix_normal_gen(multi_rv_generic): /, X: onp.ToFloatND, mean: onp.ToFloat2D | None = None, - rowcov: onp.ToFloat2D | onp.ToFloat = 1, - colcov: onp.ToFloat2D | onp.ToFloat = 1, + rowcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D = 1, + colcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D = 1, ) -> _ScalarOrArray_f8: ... def pdf( self, /, X: onp.ToFloatND, mean: onp.ToFloat2D | None = None, - rowcov: onp.ToFloat2D | onp.ToFloat = 1, - colcov: onp.ToFloat2D | onp.ToFloat = 1, + rowcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D = 1, + colcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D = 1, ) -> _ScalarOrArray_f8: ... # If `size > 1` the output is 3-D, otherwise 2-D. @@ -266,8 +266,8 @@ class matrix_normal_gen(multi_rv_generic): self, /, mean: onp.ToFloat2D | None = None, - rowcov: onp.ToFloat2D | onp.ToFloat = 1, - colcov: onp.ToFloat2D | onp.ToFloat = 1, + rowcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D = 1, + colcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D = 1, size: Literal[1] = 1, random_state: onp.random.ToRNG | None = None, ) -> onp.Array2D[np.float64]: ... @@ -276,8 +276,8 @@ class matrix_normal_gen(multi_rv_generic): self, /, mean: onp.ToFloat2D | None, - rowcov: onp.ToFloat2D | onp.ToFloat, - colcov: onp.ToFloat2D | onp.ToFloat, + rowcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D, + colcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D, size: int, random_state: onp.random.ToRNG | None = None, ) -> _Array2ND[np.float64]: ... @@ -286,15 +286,20 @@ class matrix_normal_gen(multi_rv_generic): self, /, mean: onp.ToFloat2D | None = None, - rowcov: onp.ToFloat2D | onp.ToFloat = 1, - colcov: onp.ToFloat2D | onp.ToFloat = 1, + rowcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D = 1, + colcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D = 1, *, size: int, random_state: onp.random.ToRNG | None = None, ) -> _Array2ND[np.float64]: ... # - def entropy(self, /, rowcov: _AnyCov = 1, colcov: _AnyCov = 1) -> np.float64: ... + def entropy( + self, + /, + rowcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D = 1, + colcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D = 1, + ) -> np.float64: ... class matrix_normal_frozen(multi_rv_frozen[matrix_normal_gen], Generic[_M, _N]): # pyrefly: ignore [bad-override] From 44641bf77473636a2d0e5e55555a2a18637fa6fb Mon Sep 17 00:00:00 2001 From: Aniket Singh Yadav Date: Sat, 18 Apr 2026 11:26:08 +0000 Subject: [PATCH 7/7] remove disable generic and add comments --- scipy-stubs/stats/_multivariate.pyi | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/scipy-stubs/stats/_multivariate.pyi b/scipy-stubs/stats/_multivariate.pyi index d39cfc05c..ec548091d 100644 --- a/scipy-stubs/stats/_multivariate.pyi +++ b/scipy-stubs/stats/_multivariate.pyi @@ -294,12 +294,7 @@ class matrix_normal_gen(multi_rv_generic): ) -> _Array2ND[np.float64]: ... # - def entropy( - self, - /, - rowcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D = 1, - colcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D = 1, - ) -> np.float64: ... + def entropy(self, /, rowcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D = 1, colcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D = 1) -> np.float64: ... class matrix_normal_frozen(multi_rv_frozen[matrix_normal_gen], Generic[_M, _N]): # pyrefly: ignore [bad-override] @@ -641,8 +636,6 @@ class _group_rv_gen_mixin(Generic[_RVF_co, _ScalarT_co]): @type_check_only class _group_rv_frozen_mixin(Generic[_ScalarT_co]): - __class_getitem__: ClassVar[None] = None - dim: Final[int] # NOTE: Contrary to what the `dim` default suggests, it is required. @@ -655,10 +648,13 @@ class _group_rv_frozen_mixin(Generic[_ScalarT_co]): def rvs(self, /, size: int, random_state: onp.random.ToRNG | None = None) -> _Array2ND[_ScalarT_co]: ... class special_ortho_group_gen(_group_rv_gen_mixin[special_ortho_group_frozen], multi_rv_generic): ... +# pyrefly: ignore [inconsistent-inheritance] class special_ortho_group_frozen(_group_rv_frozen_mixin, multi_rv_frozen[special_ortho_group_gen]): ... # type: ignore[misc] class ortho_group_gen(_group_rv_gen_mixin[ortho_group_frozen], multi_rv_generic): ... +# pyrefly: ignore [inconsistent-inheritance] class ortho_group_frozen(_group_rv_frozen_mixin, multi_rv_frozen[ortho_group_gen]): ... # type: ignore[misc] class unitary_group_gen(_group_rv_gen_mixin[unitary_group_frozen, np.complex128], multi_rv_generic): ... +# pyrefly: ignore [inconsistent-inheritance] class unitary_group_frozen(_group_rv_frozen_mixin[np.complex128], multi_rv_frozen[unitary_group_gen]): ... # type: ignore[misc] class uniform_direction_gen(multi_rv_generic):