Skip to content
Closed
121 changes: 65 additions & 56 deletions scipy-stubs/stats/_multivariate.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@

_ScalarT = TypeVar("_ScalarT", bound=np.generic, default=np.float64)
_ScalarT_co = TypeVar("_ScalarT_co", bound=np.generic, default=np.float64, covariant=True)
_K = TypeVar("_K", bound=int, default=int)
_M = TypeVar("_M", bound=int, default=int)
_N = TypeVar("_N", bound=int, default=int)
_P = TypeVar("_P", bound=int, default=int)
_D = TypeVar("_D", bound=int, default=int)
Comment on lines +34 to +38

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NumPy does not support using subtypes of int in shape-types, so let's just stick with plain int as well.


# TODO(@jorenham): rename as {}T_co
_RVG_co = TypeVar("_RVG_co", bound=multi_rv_generic, default=multi_rv_generic, covariant=True)
Expand Down Expand Up @@ -178,9 +183,7 @@
allow_singular: bool = False,
) -> multivariate_normal_frozen: ...

# TODO(@jorenham): Generic shape-type for mean and cov, so that we can determine whether the methods return scalars or arrays.
# https://github.com/scipy/scipy-stubs/issues/406
class multivariate_normal_frozen(multi_rv_frozen[multivariate_normal_gen]):
class multivariate_normal_frozen(multi_rv_frozen[multivariate_normal_gen], Generic[_K]):

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm I just checked, and apparently multivariate_normal only supports 1d mean and 2d covariance, unlike the issue suggest. So there's no reason to make this generic after all.

# pyrefly: ignore [bad-override]
__class_getitem__: ClassVar[None] = None # type:ignore[assignment] # pyright:ignore[reportIncompatibleMethodOverride]

Expand All @@ -190,17 +193,17 @@
abseps: Final[float]
releps: Final[float]
cov_object: Final[Covariance]
mean: onp.Array1D[np.float64]
mean: onp.Array[tuple[_K], np.float64]

@property
def cov(self, /) -> onp.Array2D[np.float64]: ...
def cov(self, /) -> onp.Array[tuple[_K, _K], np.float64]: ...

#
def __init__(
self,
/,
mean: onp.ToFloat1D | None = None,
cov: _AnyCov = 1,
mean: onp.Array[tuple[_K], npc.floating] | None = None,
cov: Covariance | onp.ToFloat | onp.Array[tuple[_K, _K], npc.floating] = 1,
allow_singular: bool = False,
seed: onp.random.ToRNG | None = None,
maxpts: onp.ToJustInt | None = None,
Expand All @@ -227,15 +230,15 @@

#
def entropy(self, /) -> np.float64: ...
def marginal(self, dimensions: int | onp.ToInt1D) -> multivariate_normal_frozen: ...
def marginal(self, dimensions: int | onp.ToInt1D) -> multivariate_normal_frozen[int]: ...

class matrix_normal_gen(multi_rv_generic):
def __call__(
self,
/,
mean: onp.ToFloat2D | None = None,
rowcov: onp.ToFloat2D | onp.ToFloat = 1,
colcov: onp.ToFloat2D | onp.ToFloat = 1,
rowcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D = 1,
colcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D = 1,
Comment on lines +240 to +241

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's separate this fix and the shape-typing improvements. It'd otherwise be confusing if a PR is both a fix and a feature, and would fit in the changelog format.

seed: onp.random.ToRNG | None = None,
) -> matrix_normal_frozen: ...

Expand All @@ -245,16 +248,16 @@
/,
X: onp.ToFloatND,
mean: onp.ToFloat2D | None = None,
rowcov: onp.ToFloat2D | onp.ToFloat = 1,
colcov: onp.ToFloat2D | onp.ToFloat = 1,
rowcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D = 1,
colcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D = 1,
) -> _ScalarOrArray_f8: ...
def pdf(
self,
/,
X: onp.ToFloatND,
mean: onp.ToFloat2D | None = None,
rowcov: onp.ToFloat2D | onp.ToFloat = 1,
colcov: onp.ToFloat2D | onp.ToFloat = 1,
rowcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D = 1,
colcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D = 1,
) -> _ScalarOrArray_f8: ...

# If `size > 1` the output is 3-D, otherwise 2-D.
Expand All @@ -263,8 +266,8 @@
self,
/,
mean: onp.ToFloat2D | None = None,
rowcov: onp.ToFloat2D | onp.ToFloat = 1,
colcov: onp.ToFloat2D | onp.ToFloat = 1,
rowcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D = 1,
colcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D = 1,
size: Literal[1] = 1,
random_state: onp.random.ToRNG | None = None,
) -> onp.Array2D[np.float64]: ...
Expand All @@ -273,8 +276,8 @@
self,
/,
mean: onp.ToFloat2D | None,
rowcov: onp.ToFloat2D | onp.ToFloat,
colcov: onp.ToFloat2D | onp.ToFloat,
rowcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D,
colcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D,
size: int,
random_state: onp.random.ToRNG | None = None,
) -> _Array2ND[np.float64]: ...
Expand All @@ -283,29 +286,30 @@
self,
/,
mean: onp.ToFloat2D | None = None,
rowcov: onp.ToFloat2D | onp.ToFloat = 1,
colcov: onp.ToFloat2D | onp.ToFloat = 1,
rowcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D = 1,
colcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D = 1,
*,
size: int,
random_state: onp.random.ToRNG | None = None,
) -> _Array2ND[np.float64]: ...

#

Check failure on line 296 in scipy-stubs/stats/_multivariate.pyi

View workflow job for this annotation

GitHub Actions / lint

ruff (E501)

scipy-stubs/stats/_multivariate.pyi:296:131: E501 Line too long (157 > 130)
def entropy(self, /, rowcov: _AnyCov = 1, colcov: _AnyCov = 1) -> np.float64: ...
def entropy(self, /, rowcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D = 1, colcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D = 1) -> np.float64: ...

class matrix_normal_frozen(multi_rv_frozen[matrix_normal_gen]):
class matrix_normal_frozen(multi_rv_frozen[matrix_normal_gen], Generic[_M, _N]):

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, apparently I was wrong here too. Maybe something changed, but with the latest scipy it looks like matrix_normal only accepts 2d mean, rowcov and colcov. So sorry for the confusion, but here too is there no need to make the shape-type generic.

# pyrefly: ignore [bad-override]
__class_getitem__: ClassVar[None] = None # type:ignore[assignment] # pyright:ignore[reportIncompatibleMethodOverride]

mean: onp.Array[tuple[_M, _N], np.float64]
rowpsd: Final[_PSD]
colpsd: Final[_PSD]

def __init__(
self,
/,
mean: onp.ToFloat2D | None = None,
rowcov: onp.ToFloat2D | onp.ToFloat = 1,
colcov: onp.ToFloat2D | onp.ToFloat = 1,
mean: onp.Array[tuple[_M, _N], npc.floating] | None = None,
rowcov: onp.ToFloat | onp.Array[tuple[_M], npc.floating] | onp.Array[tuple[_M, _M], npc.floating] = 1,
colcov: onp.ToFloat | onp.Array[tuple[_N], npc.floating] | onp.Array[tuple[_N, _N], npc.floating] = 1,
seed: onp.random.ToRNG | None = None,
) -> None: ...
def logpdf(self, /, X: onp.ToFloatND) -> _ScalarOrArray_f8: ...
Expand Down Expand Up @@ -597,21 +601,23 @@
) -> _Array2ND: ...

#
class multinomial_frozen(multi_rv_frozen[multinomial_gen]):
def __init__(self, /, n: onp.ToJustIntND, p: onp.ToJustFloatND, seed: onp.random.ToRNG | None = None) -> None: ...
class multinomial_frozen(multi_rv_frozen[multinomial_gen], Generic[_K]):

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok so this should indeed be generic on its shape-type:

In [27]: multinomial(8, [0.3, 0.2, 0.5]).mean()
Out[27]: array([2.4, 1.6, 4. ])

In [28]: multinomial([8,9], [0.3, 0.2, 0.5]).mean()
Out[28]: 
array([[2.4, 1.6, 4. ],
       [2.7, 1.8, 4.5]])

In [29]: multinomial([[8,9],[6,5]], [0.3, 0.2, 0.5]).mean()
Out[29]: 
array([[[2.4, 1.6, 4. ],
        [2.7, 1.8, 4.5]],

       [[1.8, 1.2, 3. ],
        [1.5, 1. , 2.5]]])

However, the generic type parameter should be

_ShapeT_co = TypeVar("`_ShapeT_co", bound=tuple[int, ...], default=tuple[Any, ...], covariant=True)

matching the shape-type of .mean()

def __init__(
self, /, n: onp.ToJustInt, p: onp.Array[tuple[_K], npc.floating], seed: onp.random.ToRNG | None = None
) -> None: ...

#
def logpmf(self, /, x: onp.ToFloatND) -> _ScalarOrArray_f8: ...
def pmf(self, /, x: onp.ToFloatND) -> _ScalarOrArray_f8: ...

#
def mean(self, /) -> _Array1ND: ...
def cov(self, /) -> _Array2ND: ...
def mean(self, /) -> onp.Array[tuple[_K, *tuple[Any, ...]], np.float64]: ...
def cov(self, /) -> onp.Array[tuple[_K, _K, *tuple[Any, ...]], np.float64]: ... # pyright: ignore[reportInvalidTypeForm]
def entropy(self, /) -> _ScalarOrArray_f8: ...

#
@overload
def rvs(self, /, size: tuple[()], random_state: onp.random.ToRNG | None = None) -> _Array1ND: ...
def rvs(self, /, size: tuple[()], random_state: onp.random.ToRNG | None = None) -> onp.Array[tuple[_K], np.float64]: ...
@overload
def rvs(self, /, size: onp.AtLeast1D | int = 1, random_state: onp.random.ToRNG | None = None) -> _Array2ND: ...

Expand All @@ -630,8 +636,6 @@

@type_check_only
class _group_rv_frozen_mixin(Generic[_ScalarT_co]):
__class_getitem__: ClassVar[None] = None

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why did you remove this?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as it sets to None -disables generic , please correct me if i misunderstood here.


Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please avoid unnecessary formatting changes like this one and the ones below

dim: Final[int]

# NOTE: Contrary to what the `dim` default suggests, it is required.
Expand All @@ -644,15 +648,12 @@
def rvs(self, /, size: int, random_state: onp.random.ToRNG | None = None) -> _Array2ND[_ScalarT_co]: ...

class special_ortho_group_gen(_group_rv_gen_mixin[special_ortho_group_frozen], multi_rv_generic): ...

# pyrefly: ignore [inconsistent-inheritance]
class special_ortho_group_frozen(_group_rv_frozen_mixin, multi_rv_frozen[special_ortho_group_gen]): ... # type: ignore[misc]
class ortho_group_gen(_group_rv_gen_mixin[ortho_group_frozen], multi_rv_generic): ...

# pyrefly: ignore [inconsistent-inheritance]
class ortho_group_frozen(_group_rv_frozen_mixin, multi_rv_frozen[ortho_group_gen]): ... # type: ignore[misc]
class unitary_group_gen(_group_rv_gen_mixin[unitary_group_frozen, np.complex128], multi_rv_generic): ...

# pyrefly: ignore [inconsistent-inheritance]
class unitary_group_frozen(_group_rv_frozen_mixin[np.complex128], multi_rv_frozen[unitary_group_gen]): ... # type: ignore[misc]

Expand All @@ -669,16 +670,18 @@
@overload
def rvs(self, /, dim: int, size: onp.AtLeast2D, random_state: onp.random.ToRNG | None = None) -> _Array3ND[np.float64]: ...

class uniform_direction_frozen(multi_rv_frozen[uniform_direction_gen]):
class uniform_direction_frozen(multi_rv_frozen[uniform_direction_gen], Generic[_D]):
dim: Final[int]

def __init__(self, /, dim: int | None = None, seed: onp.random.ToRNG | None = None) -> None: ...

#
@overload
def rvs(self, /, size: None = None, random_state: onp.random.ToRNG | None = None) -> onp.Array1D[np.float64]: ...
def rvs(self, /, size: None = None, random_state: onp.random.ToRNG | None = None) -> onp.Array[tuple[_D], np.float64]: ...
@overload
def rvs(self, /, size: int | tuple[int], random_state: onp.random.ToRNG | None = None) -> onp.Array2D[np.float64]: ...
def rvs(
self, /, size: int | tuple[int], random_state: onp.random.ToRNG | None = None
) -> onp.Array[tuple[int, _D], np.float64]: ...
@overload
def rvs(self, /, size: onp.AtLeast2D, random_state: onp.random.ToRNG | None = None) -> _Array2ND[np.float64]: ...

Expand Down Expand Up @@ -827,21 +830,21 @@
allow_singular: bool = False,
) -> multivariate_t_frozen: ...

class multivariate_t_frozen(multi_rv_frozen[multivariate_t_gen]):
class multivariate_t_frozen(multi_rv_frozen[multivariate_t_gen], Generic[_P]):
# pyrefly: ignore [bad-override]
__class_getitem__: ClassVar[None] = None # type:ignore[assignment] # pyright:ignore[reportIncompatibleMethodOverride]

dim: Final[int]
df: Final[int]
loc: Final[onp.Array1D[np.float64]]
shape: Final[onp.Array2D[np.float64]]
loc: onp.Array[tuple[_P], np.float64]
shape: onp.Array[tuple[_P, _P], np.float64]
shape_info: Final[_PSD]

def __init__(
self,
/,
loc: onp.ToFloat1D | None = None,
shape: onp.ToFloat | onp.ToFloat2D = 1,
loc: onp.Array[tuple[_P], npc.floating] | None = None,
shape: onp.ToFloat | onp.Array[tuple[_P, _P], npc.floating] = 1,
df: int = 1,
allow_singular: bool = False,
seed: onp.random.ToRNG | None = None,
Expand Down Expand Up @@ -872,7 +875,7 @@
def rvs(self, /, size: onp.AtLeast2D, random_state: onp.random.ToRNG | None = None) -> _Array3ND: ...

#
def marginal(self, dimensions: int | onp.ToInt1D) -> multivariate_t_frozen: ...
def marginal(self, dimensions: int | onp.ToInt1D) -> multivariate_t_frozen[int]: ...

# NOTE: `m` and `n` are broadcastable (but doing so will break `.rvs()` at runtime...)
class multivariate_hypergeom_gen(multi_rv_generic):
Expand All @@ -898,15 +901,17 @@
random_state: onp.random.ToRNG | None = None,
) -> _Array2ND: ...

class multivariate_hypergeom_frozen(multi_rv_frozen[multivariate_hypergeom_gen]):
def __init__(self, /, m: onp.ToIntND, n: onp.ToJustInt | onp.ToJustIntND, seed: onp.random.ToRNG | None = None) -> None: ...
class multivariate_hypergeom_frozen(multi_rv_frozen[multivariate_hypergeom_gen], Generic[_K]):
def __init__(
self, /, m: onp.Array[tuple[_K], npc.integer], n: onp.ToJustInt, seed: onp.random.ToRNG | None = None
) -> None: ...
def logpmf(self, /, x: onp.ToFloatND) -> _ScalarOrArray_f8: ...
def pmf(self, /, x: onp.ToFloatND) -> _ScalarOrArray_f8: ...
def mean(self, /) -> _Array1ND: ...
def var(self, /) -> _Array1ND: ...
def cov(self, /) -> _Array2ND: ...
def mean(self, /) -> onp.Array[tuple[_K, *tuple[Any, ...]], np.float64]: ...
def var(self, /) -> onp.Array[tuple[_K, *tuple[Any, ...]], np.float64]: ...
def cov(self, /) -> onp.Array[tuple[_K, _K, *tuple[Any, ...]], np.float64]: ... # pyright: ignore[reportInvalidTypeForm]
@overload
def rvs(self, /, size: tuple[()], random_state: onp.random.ToRNG | None = None) -> _Array1ND: ...
def rvs(self, /, size: tuple[()], random_state: onp.random.ToRNG | None = None) -> onp.Array[tuple[_K], np.float64]: ...
@overload
def rvs(self, /, size: int | tuple[int] = 1, random_state: onp.random.ToRNG | None = None) -> _Array2ND: ...

Expand Down Expand Up @@ -996,11 +1001,13 @@
def var(self, /, alpha: onp.ToFloatND, n: onp.ToJustIntND) -> _Array1ND: ...
def cov(self, /, alpha: onp.ToFloatND, n: onp.ToJustIntND) -> _Array2ND: ...

class dirichlet_multinomial_frozen(multi_rv_frozen[dirichlet_multinomial_gen]):
alpha: _Array1ND
n: _Array1ND[np.int_] # broadcasted against alpha
class dirichlet_multinomial_frozen(multi_rv_frozen[dirichlet_multinomial_gen], Generic[_K]):
alpha: onp.Array[tuple[_K], np.float64]
n: onp.ArrayND[np.int_]

def __init__(self, /, alpha: onp.ToFloatND, n: onp.ToJustIntND, seed: onp.random.ToRNG | None = None) -> None: ...
def __init__(
self, /, alpha: onp.Array[tuple[_K], npc.floating], n: onp.ToJustIntND, seed: onp.random.ToRNG | None = None
) -> None: ...
def logpmf(self, /, x: onp.ToIntND) -> _ScalarOrArray_f8: ...
def pmf(self, /, x: onp.ToIntND) -> _ScalarOrArray_f8: ...
def mean(self, /) -> _Array1ND: ...
Expand All @@ -1024,8 +1031,10 @@
) -> _Array2ND: ...
def fit(self, /, x: onp.ToFloatND) -> tuple[onp.Array1D[np.float64], float]: ...

class vonmises_fisher_frozen(multi_rv_frozen[vonmises_fisher_gen]):
def __init__(self, /, mu: onp.ToFloat1D | None = None, kappa: int = 1, seed: onp.random.ToRNG | None = None) -> None: ...
class vonmises_fisher_frozen(multi_rv_frozen[vonmises_fisher_gen], Generic[_D]):
def __init__(
self, /, mu: onp.Array[tuple[_D], npc.floating] | None = None, kappa: int = 1, seed: onp.random.ToRNG | None = None
) -> None: ...
def logpdf(self, /, x: onp.ToFloatND) -> _ScalarOrArray_f8: ...
def pdf(self, /, x: onp.ToFloatND) -> _ScalarOrArray_f8: ...
def entropy(self, /) -> np.float64: ...
Expand Down
Loading