Add generic shape typing for multivariate frozen distributions in stats._multivariate#1549
Add generic shape typing for multivariate frozen distributions in stats._multivariate#1549Aniketsy wants to merge 7 commits into
Conversation
|
|
ahhh 😭 , I did run locally i'll fix these ci fails-- shortly. |
|
|
|
@jorenham could you please review this, when you get chance and share some pointers. I took help from llm and as per suggestion added comments ( |
I'll have a look soon
That's exactly what we use lefthook for :) |
jorenham
left a comment
There was a problem hiding this comment.
Hmm, this isn't quite what I had in mind. Apologies for not explaining it better in the issue.
It's probably easiest to close this PR and start over, separating the issue you found in #406 (comment), and the shape-typing improvements. And it's probably easiest to do implement the shape-typing improvements one distribution at a time.
How about I open a PR myself for one of those distributions so you understand what I had mind, and leave the rest for you?
| _K = TypeVar("_K", bound=int, default=int) | ||
| _M = TypeVar("_M", bound=int, default=int) | ||
| _N = TypeVar("_N", bound=int, default=int) | ||
| _P = TypeVar("_P", bound=int, default=int) | ||
| _D = TypeVar("_D", bound=int, default=int) |
There was a problem hiding this comment.
NumPy does not support using subtypes of int in shape-types, so let's just stick with plain int as well.
| # TODO(@jorenham): Generic shape-type for mean and cov, so that we can determine whether the methods return scalars or arrays. | ||
| # https://github.com/scipy/scipy-stubs/issues/406 | ||
| class multivariate_normal_frozen(multi_rv_frozen[multivariate_normal_gen]): | ||
| class multivariate_normal_frozen(multi_rv_frozen[multivariate_normal_gen], Generic[_K]): |
There was a problem hiding this comment.
Hmm I just checked, and apparently multivariate_normal only supports 1d mean and 2d covariance, unlike the issue suggest. So there's no reason to make this generic after all.
| rowcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D = 1, | ||
| colcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D = 1, |
There was a problem hiding this comment.
Let's separate this fix and the shape-typing improvements. It'd otherwise be confusing if a PR is both a fix and a feature, and would fit in the changelog format.
| def entropy(self, /, rowcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D = 1, colcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D = 1) -> np.float64: ... | ||
|
|
||
| class matrix_normal_frozen(multi_rv_frozen[matrix_normal_gen]): | ||
| class matrix_normal_frozen(multi_rv_frozen[matrix_normal_gen], Generic[_M, _N]): |
There was a problem hiding this comment.
Well, apparently I was wrong here too. Maybe something changed, but with the latest scipy it looks like matrix_normal only accepts 2d mean, rowcov and colcov. So sorry for the confusion, but here too is there no need to make the shape-type generic.
| # | ||
| class multinomial_frozen(multi_rv_frozen[multinomial_gen]): | ||
| def __init__(self, /, n: onp.ToJustIntND, p: onp.ToJustFloatND, seed: onp.random.ToRNG | None = None) -> None: ... | ||
| class multinomial_frozen(multi_rv_frozen[multinomial_gen], Generic[_K]): |
There was a problem hiding this comment.
Ok so this should indeed be generic on its shape-type:
In [27]: multinomial(8, [0.3, 0.2, 0.5]).mean()
Out[27]: array([2.4, 1.6, 4. ])
In [28]: multinomial([8,9], [0.3, 0.2, 0.5]).mean()
Out[28]:
array([[2.4, 1.6, 4. ],
[2.7, 1.8, 4.5]])
In [29]: multinomial([[8,9],[6,5]], [0.3, 0.2, 0.5]).mean()
Out[29]:
array([[[2.4, 1.6, 4. ],
[2.7, 1.8, 4.5]],
[[1.8, 1.2, 3. ],
[1.5, 1. , 2.5]]])
However, the generic type parameter should be
_ShapeT_co = TypeVar("`_ShapeT_co", bound=tuple[int, ...], default=tuple[Any, ...], covariant=True)matching the shape-type of .mean()
|
|
||
| @type_check_only | ||
| class _group_rv_frozen_mixin(Generic[_ScalarT_co]): | ||
| __class_getitem__: ClassVar[None] = None |
There was a problem hiding this comment.
as it sets to None -disables generic , please correct me if i misunderstood here.
| @type_check_only | ||
| class _group_rv_frozen_mixin(Generic[_ScalarT_co]): | ||
| __class_getitem__: ClassVar[None] = None | ||
|
|
There was a problem hiding this comment.
please avoid unnecessary formatting changes like this one and the ones below
sure, please go for it, and then i can follow you ..
okay , then i'll close this one and open a separate PR for the issue i found and then follow up after your PR. Thanks! for giving your time in review and apologies for creating extra work . |
I'll get myself something to eat and then start working on it 👌 |
sure 😊, please ping me also once its done , I'd love to see and understand . thanks ! .... closing this PR |
Fixes #406