-
-
Notifications
You must be signed in to change notification settings - Fork 37
Add generic shape typing for multivariate frozen distributions in stats._multivariate #1549
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
19c602c
0da08db
40502e2
dfa3a0c
b6cdc14
194981a
44641bf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -31,6 +31,11 @@ | |
|
|
||
| _ScalarT = TypeVar("_ScalarT", bound=np.generic, default=np.float64) | ||
| _ScalarT_co = TypeVar("_ScalarT_co", bound=np.generic, default=np.float64, covariant=True) | ||
| _K = TypeVar("_K", bound=int, default=int) | ||
| _M = TypeVar("_M", bound=int, default=int) | ||
| _N = TypeVar("_N", bound=int, default=int) | ||
| _P = TypeVar("_P", bound=int, default=int) | ||
| _D = TypeVar("_D", bound=int, default=int) | ||
|
|
||
| # TODO(@jorenham): rename as {}T_co | ||
| _RVG_co = TypeVar("_RVG_co", bound=multi_rv_generic, default=multi_rv_generic, covariant=True) | ||
|
|
@@ -178,9 +183,7 @@ | |
| allow_singular: bool = False, | ||
| ) -> multivariate_normal_frozen: ... | ||
|
|
||
| # TODO(@jorenham): Generic shape-type for mean and cov, so that we can determine whether the methods return scalars or arrays. | ||
| # https://github.com/scipy/scipy-stubs/issues/406 | ||
| class multivariate_normal_frozen(multi_rv_frozen[multivariate_normal_gen]): | ||
| class multivariate_normal_frozen(multi_rv_frozen[multivariate_normal_gen], Generic[_K]): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm I just checked, and apparently |
||
| # pyrefly: ignore [bad-override] | ||
| __class_getitem__: ClassVar[None] = None # type:ignore[assignment] # pyright:ignore[reportIncompatibleMethodOverride] | ||
|
|
||
|
|
@@ -190,17 +193,17 @@ | |
| abseps: Final[float] | ||
| releps: Final[float] | ||
| cov_object: Final[Covariance] | ||
| mean: onp.Array1D[np.float64] | ||
| mean: onp.Array[tuple[_K], np.float64] | ||
|
|
||
| @property | ||
| def cov(self, /) -> onp.Array2D[np.float64]: ... | ||
| def cov(self, /) -> onp.Array[tuple[_K, _K], np.float64]: ... | ||
|
|
||
| # | ||
| def __init__( | ||
| self, | ||
| /, | ||
| mean: onp.ToFloat1D | None = None, | ||
| cov: _AnyCov = 1, | ||
| mean: onp.Array[tuple[_K], npc.floating] | None = None, | ||
| cov: Covariance | onp.ToFloat | onp.Array[tuple[_K, _K], npc.floating] = 1, | ||
| allow_singular: bool = False, | ||
| seed: onp.random.ToRNG | None = None, | ||
| maxpts: onp.ToJustInt | None = None, | ||
|
|
@@ -227,15 +230,15 @@ | |
|
|
||
| # | ||
| def entropy(self, /) -> np.float64: ... | ||
| def marginal(self, dimensions: int | onp.ToInt1D) -> multivariate_normal_frozen: ... | ||
| def marginal(self, dimensions: int | onp.ToInt1D) -> multivariate_normal_frozen[int]: ... | ||
|
|
||
| class matrix_normal_gen(multi_rv_generic): | ||
| def __call__( | ||
| self, | ||
| /, | ||
| mean: onp.ToFloat2D | None = None, | ||
| rowcov: onp.ToFloat2D | onp.ToFloat = 1, | ||
| colcov: onp.ToFloat2D | onp.ToFloat = 1, | ||
| rowcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D = 1, | ||
| colcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D = 1, | ||
|
Comment on lines
+240
to
+241
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's separate this fix and the shape-typing improvements. It'd otherwise be confusing if a PR is both a fix and a feature, and would fit in the changelog format. |
||
| seed: onp.random.ToRNG | None = None, | ||
| ) -> matrix_normal_frozen: ... | ||
|
|
||
|
|
@@ -245,16 +248,16 @@ | |
| /, | ||
| X: onp.ToFloatND, | ||
| mean: onp.ToFloat2D | None = None, | ||
| rowcov: onp.ToFloat2D | onp.ToFloat = 1, | ||
| colcov: onp.ToFloat2D | onp.ToFloat = 1, | ||
| rowcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D = 1, | ||
| colcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D = 1, | ||
| ) -> _ScalarOrArray_f8: ... | ||
| def pdf( | ||
| self, | ||
| /, | ||
| X: onp.ToFloatND, | ||
| mean: onp.ToFloat2D | None = None, | ||
| rowcov: onp.ToFloat2D | onp.ToFloat = 1, | ||
| colcov: onp.ToFloat2D | onp.ToFloat = 1, | ||
| rowcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D = 1, | ||
| colcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D = 1, | ||
| ) -> _ScalarOrArray_f8: ... | ||
|
|
||
| # If `size > 1` the output is 3-D, otherwise 2-D. | ||
|
|
@@ -263,8 +266,8 @@ | |
| self, | ||
| /, | ||
| mean: onp.ToFloat2D | None = None, | ||
| rowcov: onp.ToFloat2D | onp.ToFloat = 1, | ||
| colcov: onp.ToFloat2D | onp.ToFloat = 1, | ||
| rowcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D = 1, | ||
| colcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D = 1, | ||
| size: Literal[1] = 1, | ||
| random_state: onp.random.ToRNG | None = None, | ||
| ) -> onp.Array2D[np.float64]: ... | ||
|
|
@@ -273,8 +276,8 @@ | |
| self, | ||
| /, | ||
| mean: onp.ToFloat2D | None, | ||
| rowcov: onp.ToFloat2D | onp.ToFloat, | ||
| colcov: onp.ToFloat2D | onp.ToFloat, | ||
| rowcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D, | ||
| colcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D, | ||
| size: int, | ||
| random_state: onp.random.ToRNG | None = None, | ||
| ) -> _Array2ND[np.float64]: ... | ||
|
|
@@ -283,29 +286,30 @@ | |
| self, | ||
| /, | ||
| mean: onp.ToFloat2D | None = None, | ||
| rowcov: onp.ToFloat2D | onp.ToFloat = 1, | ||
| colcov: onp.ToFloat2D | onp.ToFloat = 1, | ||
| rowcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D = 1, | ||
| colcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D = 1, | ||
| *, | ||
| size: int, | ||
| random_state: onp.random.ToRNG | None = None, | ||
| ) -> _Array2ND[np.float64]: ... | ||
|
|
||
| # | ||
| def entropy(self, /, rowcov: _AnyCov = 1, colcov: _AnyCov = 1) -> np.float64: ... | ||
| def entropy(self, /, rowcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D = 1, colcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D = 1) -> np.float64: ... | ||
|
|
||
| class matrix_normal_frozen(multi_rv_frozen[matrix_normal_gen]): | ||
| class matrix_normal_frozen(multi_rv_frozen[matrix_normal_gen], Generic[_M, _N]): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Well, apparently I was wrong here too. Maybe something changed, but with the latest scipy it looks like |
||
| # pyrefly: ignore [bad-override] | ||
| __class_getitem__: ClassVar[None] = None # type:ignore[assignment] # pyright:ignore[reportIncompatibleMethodOverride] | ||
|
|
||
| mean: onp.Array[tuple[_M, _N], np.float64] | ||
| rowpsd: Final[_PSD] | ||
| colpsd: Final[_PSD] | ||
|
|
||
| def __init__( | ||
| self, | ||
| /, | ||
| mean: onp.ToFloat2D | None = None, | ||
| rowcov: onp.ToFloat2D | onp.ToFloat = 1, | ||
| colcov: onp.ToFloat2D | onp.ToFloat = 1, | ||
| mean: onp.Array[tuple[_M, _N], npc.floating] | None = None, | ||
| rowcov: onp.ToFloat | onp.Array[tuple[_M], npc.floating] | onp.Array[tuple[_M, _M], npc.floating] = 1, | ||
| colcov: onp.ToFloat | onp.Array[tuple[_N], npc.floating] | onp.Array[tuple[_N, _N], npc.floating] = 1, | ||
| seed: onp.random.ToRNG | None = None, | ||
| ) -> None: ... | ||
| def logpdf(self, /, X: onp.ToFloatND) -> _ScalarOrArray_f8: ... | ||
|
|
@@ -597,21 +601,23 @@ | |
| ) -> _Array2ND: ... | ||
|
|
||
| # | ||
| class multinomial_frozen(multi_rv_frozen[multinomial_gen]): | ||
| def __init__(self, /, n: onp.ToJustIntND, p: onp.ToJustFloatND, seed: onp.random.ToRNG | None = None) -> None: ... | ||
| class multinomial_frozen(multi_rv_frozen[multinomial_gen], Generic[_K]): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok so this should indeed be generic on its shape-type: However, the generic type parameter should be _ShapeT_co = TypeVar("`_ShapeT_co", bound=tuple[int, ...], default=tuple[Any, ...], covariant=True)matching the shape-type of |
||
| def __init__( | ||
| self, /, n: onp.ToJustInt, p: onp.Array[tuple[_K], npc.floating], seed: onp.random.ToRNG | None = None | ||
| ) -> None: ... | ||
|
|
||
| # | ||
| def logpmf(self, /, x: onp.ToFloatND) -> _ScalarOrArray_f8: ... | ||
| def pmf(self, /, x: onp.ToFloatND) -> _ScalarOrArray_f8: ... | ||
|
|
||
| # | ||
| def mean(self, /) -> _Array1ND: ... | ||
| def cov(self, /) -> _Array2ND: ... | ||
| def mean(self, /) -> onp.Array[tuple[_K, *tuple[Any, ...]], np.float64]: ... | ||
| def cov(self, /) -> onp.Array[tuple[_K, _K, *tuple[Any, ...]], np.float64]: ... # pyright: ignore[reportInvalidTypeForm] | ||
| def entropy(self, /) -> _ScalarOrArray_f8: ... | ||
|
|
||
| # | ||
| @overload | ||
| def rvs(self, /, size: tuple[()], random_state: onp.random.ToRNG | None = None) -> _Array1ND: ... | ||
| def rvs(self, /, size: tuple[()], random_state: onp.random.ToRNG | None = None) -> onp.Array[tuple[_K], np.float64]: ... | ||
| @overload | ||
| def rvs(self, /, size: onp.AtLeast1D | int = 1, random_state: onp.random.ToRNG | None = None) -> _Array2ND: ... | ||
|
|
||
|
|
@@ -630,8 +636,6 @@ | |
|
|
||
| @type_check_only | ||
| class _group_rv_frozen_mixin(Generic[_ScalarT_co]): | ||
| __class_getitem__: ClassVar[None] = None | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why did you remove this?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. as it sets to None -disables generic , please correct me if i misunderstood here. |
||
|
|
||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please avoid unnecessary formatting changes like this one and the ones below |
||
| dim: Final[int] | ||
|
|
||
| # NOTE: Contrary to what the `dim` default suggests, it is required. | ||
|
|
@@ -644,15 +648,12 @@ | |
| def rvs(self, /, size: int, random_state: onp.random.ToRNG | None = None) -> _Array2ND[_ScalarT_co]: ... | ||
|
|
||
| class special_ortho_group_gen(_group_rv_gen_mixin[special_ortho_group_frozen], multi_rv_generic): ... | ||
|
|
||
| # pyrefly: ignore [inconsistent-inheritance] | ||
| class special_ortho_group_frozen(_group_rv_frozen_mixin, multi_rv_frozen[special_ortho_group_gen]): ... # type: ignore[misc] | ||
| class ortho_group_gen(_group_rv_gen_mixin[ortho_group_frozen], multi_rv_generic): ... | ||
|
|
||
| # pyrefly: ignore [inconsistent-inheritance] | ||
| class ortho_group_frozen(_group_rv_frozen_mixin, multi_rv_frozen[ortho_group_gen]): ... # type: ignore[misc] | ||
| class unitary_group_gen(_group_rv_gen_mixin[unitary_group_frozen, np.complex128], multi_rv_generic): ... | ||
|
|
||
| # pyrefly: ignore [inconsistent-inheritance] | ||
| class unitary_group_frozen(_group_rv_frozen_mixin[np.complex128], multi_rv_frozen[unitary_group_gen]): ... # type: ignore[misc] | ||
|
|
||
|
|
@@ -669,16 +670,18 @@ | |
| @overload | ||
| def rvs(self, /, dim: int, size: onp.AtLeast2D, random_state: onp.random.ToRNG | None = None) -> _Array3ND[np.float64]: ... | ||
|
|
||
| class uniform_direction_frozen(multi_rv_frozen[uniform_direction_gen]): | ||
| class uniform_direction_frozen(multi_rv_frozen[uniform_direction_gen], Generic[_D]): | ||
| dim: Final[int] | ||
|
|
||
| def __init__(self, /, dim: int | None = None, seed: onp.random.ToRNG | None = None) -> None: ... | ||
|
|
||
| # | ||
| @overload | ||
| def rvs(self, /, size: None = None, random_state: onp.random.ToRNG | None = None) -> onp.Array1D[np.float64]: ... | ||
| def rvs(self, /, size: None = None, random_state: onp.random.ToRNG | None = None) -> onp.Array[tuple[_D], np.float64]: ... | ||
| @overload | ||
| def rvs(self, /, size: int | tuple[int], random_state: onp.random.ToRNG | None = None) -> onp.Array2D[np.float64]: ... | ||
| def rvs( | ||
| self, /, size: int | tuple[int], random_state: onp.random.ToRNG | None = None | ||
| ) -> onp.Array[tuple[int, _D], np.float64]: ... | ||
| @overload | ||
| def rvs(self, /, size: onp.AtLeast2D, random_state: onp.random.ToRNG | None = None) -> _Array2ND[np.float64]: ... | ||
|
|
||
|
|
@@ -827,21 +830,21 @@ | |
| allow_singular: bool = False, | ||
| ) -> multivariate_t_frozen: ... | ||
|
|
||
| class multivariate_t_frozen(multi_rv_frozen[multivariate_t_gen]): | ||
| class multivariate_t_frozen(multi_rv_frozen[multivariate_t_gen], Generic[_P]): | ||
| # pyrefly: ignore [bad-override] | ||
| __class_getitem__: ClassVar[None] = None # type:ignore[assignment] # pyright:ignore[reportIncompatibleMethodOverride] | ||
|
|
||
| dim: Final[int] | ||
| df: Final[int] | ||
| loc: Final[onp.Array1D[np.float64]] | ||
| shape: Final[onp.Array2D[np.float64]] | ||
| loc: onp.Array[tuple[_P], np.float64] | ||
| shape: onp.Array[tuple[_P, _P], np.float64] | ||
| shape_info: Final[_PSD] | ||
|
|
||
| def __init__( | ||
| self, | ||
| /, | ||
| loc: onp.ToFloat1D | None = None, | ||
| shape: onp.ToFloat | onp.ToFloat2D = 1, | ||
| loc: onp.Array[tuple[_P], npc.floating] | None = None, | ||
| shape: onp.ToFloat | onp.Array[tuple[_P, _P], npc.floating] = 1, | ||
| df: int = 1, | ||
| allow_singular: bool = False, | ||
| seed: onp.random.ToRNG | None = None, | ||
|
|
@@ -872,7 +875,7 @@ | |
| def rvs(self, /, size: onp.AtLeast2D, random_state: onp.random.ToRNG | None = None) -> _Array3ND: ... | ||
|
|
||
| # | ||
| def marginal(self, dimensions: int | onp.ToInt1D) -> multivariate_t_frozen: ... | ||
| def marginal(self, dimensions: int | onp.ToInt1D) -> multivariate_t_frozen[int]: ... | ||
|
|
||
| # NOTE: `m` and `n` are broadcastable (but doing so will break `.rvs()` at runtime...) | ||
| class multivariate_hypergeom_gen(multi_rv_generic): | ||
|
|
@@ -898,15 +901,17 @@ | |
| random_state: onp.random.ToRNG | None = None, | ||
| ) -> _Array2ND: ... | ||
|
|
||
| class multivariate_hypergeom_frozen(multi_rv_frozen[multivariate_hypergeom_gen]): | ||
| def __init__(self, /, m: onp.ToIntND, n: onp.ToJustInt | onp.ToJustIntND, seed: onp.random.ToRNG | None = None) -> None: ... | ||
| class multivariate_hypergeom_frozen(multi_rv_frozen[multivariate_hypergeom_gen], Generic[_K]): | ||
| def __init__( | ||
| self, /, m: onp.Array[tuple[_K], npc.integer], n: onp.ToJustInt, seed: onp.random.ToRNG | None = None | ||
| ) -> None: ... | ||
| def logpmf(self, /, x: onp.ToFloatND) -> _ScalarOrArray_f8: ... | ||
| def pmf(self, /, x: onp.ToFloatND) -> _ScalarOrArray_f8: ... | ||
| def mean(self, /) -> _Array1ND: ... | ||
| def var(self, /) -> _Array1ND: ... | ||
| def cov(self, /) -> _Array2ND: ... | ||
| def mean(self, /) -> onp.Array[tuple[_K, *tuple[Any, ...]], np.float64]: ... | ||
| def var(self, /) -> onp.Array[tuple[_K, *tuple[Any, ...]], np.float64]: ... | ||
| def cov(self, /) -> onp.Array[tuple[_K, _K, *tuple[Any, ...]], np.float64]: ... # pyright: ignore[reportInvalidTypeForm] | ||
| @overload | ||
| def rvs(self, /, size: tuple[()], random_state: onp.random.ToRNG | None = None) -> _Array1ND: ... | ||
| def rvs(self, /, size: tuple[()], random_state: onp.random.ToRNG | None = None) -> onp.Array[tuple[_K], np.float64]: ... | ||
| @overload | ||
| def rvs(self, /, size: int | tuple[int] = 1, random_state: onp.random.ToRNG | None = None) -> _Array2ND: ... | ||
|
|
||
|
|
@@ -996,11 +1001,13 @@ | |
| def var(self, /, alpha: onp.ToFloatND, n: onp.ToJustIntND) -> _Array1ND: ... | ||
| def cov(self, /, alpha: onp.ToFloatND, n: onp.ToJustIntND) -> _Array2ND: ... | ||
|
|
||
| class dirichlet_multinomial_frozen(multi_rv_frozen[dirichlet_multinomial_gen]): | ||
| alpha: _Array1ND | ||
| n: _Array1ND[np.int_] # broadcasted against alpha | ||
| class dirichlet_multinomial_frozen(multi_rv_frozen[dirichlet_multinomial_gen], Generic[_K]): | ||
| alpha: onp.Array[tuple[_K], np.float64] | ||
| n: onp.ArrayND[np.int_] | ||
|
|
||
| def __init__(self, /, alpha: onp.ToFloatND, n: onp.ToJustIntND, seed: onp.random.ToRNG | None = None) -> None: ... | ||
| def __init__( | ||
| self, /, alpha: onp.Array[tuple[_K], npc.floating], n: onp.ToJustIntND, seed: onp.random.ToRNG | None = None | ||
| ) -> None: ... | ||
| def logpmf(self, /, x: onp.ToIntND) -> _ScalarOrArray_f8: ... | ||
| def pmf(self, /, x: onp.ToIntND) -> _ScalarOrArray_f8: ... | ||
| def mean(self, /) -> _Array1ND: ... | ||
|
|
@@ -1024,8 +1031,10 @@ | |
| ) -> _Array2ND: ... | ||
| def fit(self, /, x: onp.ToFloatND) -> tuple[onp.Array1D[np.float64], float]: ... | ||
|
|
||
| class vonmises_fisher_frozen(multi_rv_frozen[vonmises_fisher_gen]): | ||
| def __init__(self, /, mu: onp.ToFloat1D | None = None, kappa: int = 1, seed: onp.random.ToRNG | None = None) -> None: ... | ||
| class vonmises_fisher_frozen(multi_rv_frozen[vonmises_fisher_gen], Generic[_D]): | ||
| def __init__( | ||
| self, /, mu: onp.Array[tuple[_D], npc.floating] | None = None, kappa: int = 1, seed: onp.random.ToRNG | None = None | ||
| ) -> None: ... | ||
| def logpdf(self, /, x: onp.ToFloatND) -> _ScalarOrArray_f8: ... | ||
| def pdf(self, /, x: onp.ToFloatND) -> _ScalarOrArray_f8: ... | ||
| def entropy(self, /) -> np.float64: ... | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NumPy does not support using subtypes of
intin shape-types, so let's just stick with plainintas well.