Add generic shape typing for multivariate frozen distributions in stats._multivariate #1549

Closed
Aniketsy wants to merge 7 commits into scipy:master from Aniketsy:fix-406

@Aniketsy
Contributor

@Aniketsy Aniketsy commented Apr 14, 2026

Fixes #406

@Aniketsy Aniketsy marked this pull request as draft April 14, 2026 10:15
@github-actions

github-actions Bot commented Apr 14, 2026

mypy_primer results

✅ No ecosystem changes detected

@Aniketsy
Contributor Author

Ahhh 😭, I did run these locally:

aniket@DESKTOP-074O80J:/mnt/d/scipy-stubs/scipy-stubs$ uv run mypy scipy-stubs
Success: no issues found in 589 source files
aniket@DESKTOP-074O80J:/mnt/d/scipy-stubs/scipy-stubs$ uv run basedpyright
0 errors, 0 warnings, 0 notes

I'll fix these CI failures shortly.

@jorenham
Member

uvx tox p might help prevent some CI failures

@Aniketsy Aniketsy marked this pull request as ready for review April 18, 2026 11:47
@Aniketsy
Contributor Author

@jorenham could you please review this when you get a chance and share some pointers?
Also, I have one small suggestion: could we set up autofix and pre-commit for ruff? I ran ruff check locally and it passed, but here I got CI failures for ruff. (Not sure if we should implement this 🤷)

I took help from an LLM and, as per its suggestion, added comments (# pyrefly: ignore [inconsistent-inheritance]). Please let me know if these should be reverted, as we're getting CI failures. Thanks!

@jorenham
Member

@jorenham could you please review this when you get a chance and share some pointers?

I'll have a look soon

Also, I have one small suggestion: could we set up autofix and pre-commit for ruff? I ran ruff check locally and it passed, but here I got CI failures for ruff. (Not sure if we should implement this 🤷)

That's exactly what we use lefthook for :)
https://github.com/scipy/scipy-stubs/blob/master/CONTRIBUTING.md#lefthook

Member

@jorenham jorenham left a comment

Hmm, this isn't quite what I had in mind. Apologies for not explaining it better in the issue.

It's probably easiest to close this PR and start over, separating the issue you found in #406 (comment) from the shape-typing improvements. And it's probably easiest to implement the shape-typing improvements one distribution at a time.
How about I open a PR myself for one of those distributions so you understand what I had in mind, and leave the rest for you?

Comment on lines +34 to +38
_K = TypeVar("_K", bound=int, default=int)
_M = TypeVar("_M", bound=int, default=int)
_N = TypeVar("_N", bound=int, default=int)
_P = TypeVar("_P", bound=int, default=int)
_D = TypeVar("_D", bound=int, default=int)
Member

NumPy does not support using subtypes of int in shape-types, so let's just stick with plain int as well.

# TODO(@jorenham): Generic shape-type for mean and cov, so that we can determine whether the methods return scalars or arrays.
# https://github.com/scipy/scipy-stubs/issues/406
-class multivariate_normal_frozen(multi_rv_frozen[multivariate_normal_gen]):
+class multivariate_normal_frozen(multi_rv_frozen[multivariate_normal_gen], Generic[_K]):
Member

Hmm, I just checked, and apparently multivariate_normal only supports a 1d mean and a 2d covariance, unlike what the issue suggests. So there's no reason to make this generic after all.

Comment on lines +240 to +241
rowcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D = 1,
colcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D = 1,
Member

Let's separate this fix and the shape-typing improvements. It'd otherwise be confusing if a PR is both a fix and a feature, and it wouldn't fit in the changelog format.

def entropy(self, /, rowcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D = 1, colcov: onp.ToFloat | onp.ToFloat1D | onp.ToFloat2D = 1) -> np.float64: ...

-class matrix_normal_frozen(multi_rv_frozen[matrix_normal_gen]):
+class matrix_normal_frozen(multi_rv_frozen[matrix_normal_gen], Generic[_M, _N]):
Member

Well, apparently I was wrong here too. Maybe something changed, but with the latest scipy it looks like matrix_normal only accepts a 2d mean, rowcov and colcov. So sorry for the confusion, but here, too, there is no need to make the shape-type generic.

#
-class multinomial_frozen(multi_rv_frozen[multinomial_gen]):
-    def __init__(self, /, n: onp.ToJustIntND, p: onp.ToJustFloatND, seed: onp.random.ToRNG | None = None) -> None: ...
+class multinomial_frozen(multi_rv_frozen[multinomial_gen], Generic[_K]):
Member

Ok so this should indeed be generic on its shape-type:

In [27]: multinomial(8, [0.3, 0.2, 0.5]).mean()
Out[27]: array([2.4, 1.6, 4. ])

In [28]: multinomial([8,9], [0.3, 0.2, 0.5]).mean()
Out[28]: 
array([[2.4, 1.6, 4. ],
       [2.7, 1.8, 4.5]])

In [29]: multinomial([[8,9],[6,5]], [0.3, 0.2, 0.5]).mean()
Out[29]: 
array([[[2.4, 1.6, 4. ],
        [2.7, 1.8, 4.5]],

       [[1.8, 1.2, 3. ],
        [1.5, 1. , 2.5]]])

However, the generic type parameter should be

_ShapeT_co = TypeVar("_ShapeT_co", bound=tuple[int, ...], default=tuple[Any, ...], covariant=True)

matching the shape-type of .mean()


@type_check_only
class _group_rv_frozen_mixin(Generic[_ScalarT_co]):
    __class_getitem__: ClassVar[None] = None
Member

why did you remove this?

Contributor Author

Because it is set to None, which disables the generic. Please correct me if I misunderstood here.
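
For context, a small runtime sketch of what assigning None to __class_getitem__ does in plain Python: subscripting the class raises TypeError, even though the class still inherits from Generic. The class and helper names are hypothetical, and how type checkers interpret this declaration inside a stub file is a separate question that the snippet does not settle.

```python
from typing import ClassVar, Generic, TypeVar

_T = TypeVar("_T")


class Subscriptable(Generic[_T]):
    """Ordinary generic class: `Subscriptable[int]` works at runtime."""


class NotSubscriptable(Generic[_T]):
    """Still a Generic subclass, but runtime subscripting is switched off."""

    __class_getitem__: ClassVar[None] = None  # type: ignore[assignment]


def can_subscript(cls: type) -> bool:
    """Report whether `cls[int]` succeeds at runtime."""
    try:
        cls[int]  # type: ignore[index]
    except TypeError:
        return False
    return True


print(can_subscript(Subscriptable))     # True
print(can_subscript(NotSubscriptable))  # False
```

One possible reading, stated here as an assumption, is that the stub line records that the runtime class cannot be subscripted while the Generic base keeps it parameterizable for type checkers.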

@type_check_only
class _group_rv_frozen_mixin(Generic[_ScalarT_co]):
    __class_getitem__: ClassVar[None] = None

Member

please avoid unnecessary formatting changes like this one and the ones below

@Aniketsy
Contributor Author

How about I open a PR myself for one of those distributions so you understand what I had in mind, and leave the rest for you?

Sure, please go for it, and then I can follow you.

It's probably easiest to close this PR and start over, separating the issue you found in #406 (comment),

Okay, then I'll close this one and open a separate PR for the issue I found, and then follow up after your PR.

Thanks for taking the time to review, and apologies for creating extra work.

@jorenham
Member

Sure, please go for it, and then I can follow you.

I'll get myself something to eat and then start working on it 👌

@Aniketsy
Contributor Author

I'll get myself something to eat and then start working on it 👌

Sure 😊, please ping me once it's done; I'd love to see and understand it. Thanks! Closing this PR.

Development

Successfully merging this pull request may close these issues.

stats._multivariate: generic shape-types for certain *_frozen types

2 participants