feat: Add weights parameter to quantile and percentile functions (#32647) #34084
base: main
Summary of Changes
Hello @C-you-know, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request significantly enhances the statistical capabilities of JAX's NumPy API by introducing a
Code Review
This pull request introduces a weights parameter to the quantile, percentile, nanquantile, and nanpercentile functions, enabling weighted calculations. The implementation adds a new inverted_cdf method to support this functionality and includes corresponding tests. While the changes are generally well-implemented, there's a critical issue with mismatched arguments in the internal _quantile function calls and a minor issue with an f-string in an error message. Addressing these will improve the correctness and robustness of the new feature.
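For readers unfamiliar with the method, here is a minimal sketch of a weighted inverted-CDF quantile on a 1-D input: it returns the smallest sample whose cumulative weight reaches the fraction q of the total weight. This is not the PR's implementation. The function name `weighted_inverted_cdf` is hypothetical, and the sketch ignores NaNs, axes, and batched q.

```python
import jax.numpy as jnp

def weighted_inverted_cdf(a, q, weights):
  """Illustrative 1-D weighted quantile (hypothetical helper, not a JAX API)."""
  a = jnp.asarray(a)
  weights = jnp.asarray(weights, dtype=jnp.float32)
  order = jnp.argsort(a)               # sort samples, carry weights along
  a_sorted = a[order]
  cum_w = jnp.cumsum(weights[order])   # unnormalized weighted CDF
  target = q * cum_w[-1]               # weight mass that must be reached
  idx = jnp.searchsorted(cum_w, target, side="left")
  idx = jnp.clip(idx, 0, a.shape[0] - 1)
  return a_sorted[idx]

data = jnp.array([4.0, 1.0, 3.0, 2.0])
uniform = weighted_inverted_cdf(data, 0.5, jnp.ones(4))
# Putting most of the weight on the sample 4.0 pulls the median up to it.
skewed = weighted_inverted_cdf(data, 0.5, jnp.array([5.0, 1.0, 1.0, 1.0]))
```

With uniform weights the result is the ordinary inverted-CDF median of the data; the skewed weights move it to the heavily weighted sample.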
Gentle ping @jakevdp 🙂
jakevdp
left a comment
Looks really great on a first read-through, thanks!
A couple small comments below. I'll trigger the tests to see if they reveal any issues with the implementation.
jax/_src/numpy/reductions.py
Outdated
curr_out_axis = q_ndim
for i in range(a.ndim):
  if i == axis:
    index.append(idx)
    if keepdims:
      curr_out_axis += 1
  else:
    index.append(lax.broadcasted_iota(np.int32, out_shape, curr_out_axis))
    curr_out_axis += 1
result = a[tuple(index)]
This seems like a very inefficient way to compute the result. I think you could use empty slices in place of the broadcasted iotas to do this more simply.
I tried using empty slices here, but it does not work cleanly because idx already encodes the non-axis dimensions. Combining it with slice(None) causes those dimensions to be included twice under advanced indexing. I used take_along_axis to keep the indexing explicit and avoid that duplication.
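The duplication described in this reply can be seen on a small array. The sketch below uses made-up shapes rather than the ones in the PR: `idx` holds one index per row, so it already carries the row dimension.

```python
import jax.numpy as jnp

a = jnp.array([[10, 20, 30],
               [40, 50, 60]])
# One index per row, selecting along axis=1; shape (2, 1).
idx = jnp.array([[2], [0]])

# take_along_axis pairs each row of `idx` with the same row of `a`.
picked = jnp.take_along_axis(a, idx, axis=1)

# Mixing a full slice with the index array keeps the row dimension from the
# slice and then adds the row dimension of `idx` a second time.
mixed = a[:, idx]
```

`picked` keeps one value per row, while `mixed` gains an extra axis because every row of `a` is indexed with every row of `idx`.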
Hi! I looked into the failing CI cases. The failures in testWeightedQuantile10 happen during the NumPy reference computation rather than in the JAX implementation itself. np.nanpercentile emits a RuntimeWarning: All-NaN slice encountered on some randomized inputs; JAX returns the same NaN result but doesn’t emit a warning. _CheckAgainstNumpy treats the NumPy warning as a test error, so the comparison never reaches the output check.
JAX tests are deliberately configured to error when there are unexpected warnings. If you want to explicitly ignore a warning within a test, you can use the
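The warning in question is easy to reproduce outside the test suite. The sketch below records it and then suppresses it using only the standard-library `warnings` module. The helper the reviewer names is cut off above, so this is a generic Python mechanism and not necessarily the one meant in the comment.

```python
import warnings
import numpy as np

all_nan = np.array([np.nan, np.nan])

# Record warnings instead of printing them, so we can inspect what NumPy emits.
with warnings.catch_warnings(record=True) as caught:
  warnings.simplefilter("always")
  reference = np.nanpercentile(all_nan, 50)
saw_runtime_warning = any(
    issubclass(w.category, RuntimeWarning) for w in caught)

# Generic suppression for a block of code where the warning is expected.
with warnings.catch_warnings():
  warnings.simplefilter("ignore", category=RuntimeWarning)
  quiet_reference = np.nanpercentile(all_nan, 50)
```

In both cases NumPy still returns NaN; only the handling of the warning differs.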
088e9c2 to fe49c09
Hi @jakevdp, I have rebased the PR onto the latest upstream, fixed the failing test cases, and made a small naming cleanup in the implementation. All tests pass locally now. Thanks for taking another look.
jakevdp
left a comment
Looking good! A few more changes below.
jax/_src/numpy/reductions.py
Outdated
  @api.jit(static_argnames=('axis', 'overwrite_input', 'keepdims', 'method'))
  def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None,
-              out: None = None, overwrite_input: bool = False, method: str = "linear",
+              weights: ArrayLike | None = None, out: None = None, overwrite_input: bool = False, method: str = "linear",
weights should be a keyword-only parameter, as in np.quantile. Same for the other exported functions below.
Ok. Updated.
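As a reminder of what keyword-only means in practice, here is a small sketch with a bare `*` in the signature. `quantile_sketch` is a hypothetical stand-in that only echoes its arguments; it is not the function under review.

```python
def quantile_sketch(a, q, axis=None, method="linear", *, weights=None):
  """Hypothetical stand-in that only demonstrates the signature shape."""
  return {"axis": axis, "method": method, "weights": weights}

# Passing weights by keyword works.
ok = quantile_sketch([1, 2, 3], 0.5, weights=[1, 1, 2])

# Passing it positionally is rejected by Python itself.
try:
  quantile_sketch([1, 2, 3], 0.5, None, "linear", [1, 1, 2])
  positional_rejected = False
except TypeError:
  positional_rejected = True
```

Parameters after the `*` can never be bound positionally, so new arguments added there cannot silently shift the meaning of existing positional calls.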
  raise ValueError("method can only be 'linear', 'lower', 'higher', 'midpoint', 'nearest' or 'inverted_cdf'")
if weights is not None:
  weights = ensure_arraylike("_quantile", weights)
  weights = lax.asarray(weights)
You can remove this; ensure_arraylike already takes care of array conversion. If you move ensure_arraylike to the calling function, then you could make the annotation Array | None and then this wouldn't be required.
I wanted to clarify this point based on the current structure of quantile and the related nanquantile, percentile, and nanpercentile functions. The calling functions already apply ensure_arraylike to a and q but still pass both through lax.asarray before calling _quantile. Since ensure_arraylike also performs array conversion, I was not sure whether this duplication is intentional for consistency.
jax/_src/numpy/reductions.py
Outdated
if method not in ["linear", "lower", "higher", "midpoint", "nearest", "inverted_cdf"]:
  raise ValueError("method can only be 'linear', 'lower', 'higher', 'midpoint', 'nearest' or 'inverted_cdf'")
if weights is not None:
  weights = ensure_arraylike("_quantile", weights)
The name here is used to generate a user-visible message, and _quantile is an internal name.
Instead, we should pass weights to ensure_arraylike in the calling functions (alternatively we could modify _quantile to take the name of the calling function as a parameter, but that seems less clean).
Updated accordingly
jax/_src/numpy/reductions.py
Outdated
if weights.shape != a.shape:
  if axis is None:
    raise ValueError("Weights shape must match 'a' shape when axis is None.")
  ax_tuple = (axis,) if isinstance(axis, int) else tuple(axis)
You can use util.canonicalize_axis_tuple for this.
Updated accordingly
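To illustrate why a canonicalization helper is preferable to the hand-written conversion, here is a sketch of what such a helper typically does beyond wrapping an int in a tuple. `canonicalize_axes` is a hypothetical pure-Python function written for this note; it is not the JAX utility named in the review, and its error messages are made up.

```python
def canonicalize_axes(axis, ndim):
  """Hypothetical helper: int or sequence of ints -> tuple of non-negative axes."""
  axes = (axis,) if isinstance(axis, int) else tuple(axis)
  out = []
  for ax in axes:
    if not -ndim <= ax < ndim:
      raise ValueError(f"axis {ax} is out of bounds for array of dimension {ndim}")
    out.append(ax % ndim)  # map negative axes to their non-negative equivalent
  if len(set(out)) != len(out):
    raise ValueError(f"repeated axis in {axes}")
  return tuple(out)

single = canonicalize_axes(-1, 3)
several = canonicalize_axes((0, -1), 3)
```

The one-line conversion in the diff would leave `-1` as is, so later comparisons against non-negative axis numbers could miss it.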
jax/_src/numpy/reductions.py
Outdated
keepdim = [1] * a.ndim
a = a.ravel()
if weights is not None:
    weights = weights.ravel()
two-space indentation please. Also in a few places below.
Sorry. Updated the code to follow two-space indentation.
jax/_src/numpy/reductions.py
Outdated
  result = lax.mul(lax.add(low_value, high_value), lax._const(low_value, 0.5))
elif method == "inverted_cdf":
  if weights is None:
    weights = lax.full(a.shape, 1.0, dtype=a.dtype)
lax.full_like(a, 1.0) is a bit more concise, and will propagate sharding information when relevant.
Updated
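A quick side-by-side of the two spellings, on a made-up array, shows that they produce the same values, shape, and dtype. The sharding behavior mentioned in the review is not exercised here.

```python
import jax.numpy as jnp
from jax import lax

a = jnp.arange(6, dtype=jnp.float32).reshape(2, 3)

# Explicit shape and dtype, as in the original diff.
ones_full = lax.full(a.shape, 1.0, dtype=a.dtype)
# Shape and dtype are taken from `a`.
ones_like = lax.full_like(a, 1.0)
```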
jax/_src/numpy/reductions.py
Outdated
else:
  valid_counts = lax.full_like(total_weight, a_shape[axis], dtype=q.dtype)
limit = lax.sub(valid_counts, lax._const(valid_counts, 1))
max_idx = lax.convert_element_type(limit, np.int32)
Let's use dtypes.default_int_dtype() here.
OK. Updated.
max_idx = lax.convert_element_type(limit, np.int32)
max_idx_f = lax.expand_dims(max_idx, tuple(range(q_ndim)))
max_idx_f = lax.convert_element_type(max_idx_f, idx.dtype)
idx = lax.max(lax._const(idx, 0), lax.min(idx, max_idx_f))
You could use lax.clamp here.
I considered using lax.clamp here, but it would require dynamically broadcasting max_idx_f to match the shape of idx. I kept the current approach to avoid introducing that additional broadcasting logic. I am happy to switch to lax.clamp if you would prefer that for consistency.
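The trade-off in this exchange can be sketched on toy data. The shapes and values below are invented for illustration, and the sketch assumes, as the reply suggests, that `lax.clamp` wants bounds that are either scalars or already have the operand's shape.

```python
import jax.numpy as jnp
from jax import lax

idx = jnp.array([[-1, 0, 5],
                 [3, 1, 2]], dtype=jnp.int32)
# Per-row upper bound with shape (2, 1), which differs from idx's shape (2, 3).
max_idx = jnp.array([[2], [1]], dtype=jnp.int32)

# The max/min pattern from the diff, written with broadcasting jnp functions.
via_min_max = jnp.maximum(0, jnp.minimum(idx, max_idx))

# lax.clamp(min, x, max): broadcast the upper bound to idx's shape first.
lower = jnp.zeros((), dtype=idx.dtype)
upper = jnp.broadcast_to(max_idx, idx.shape)
via_clamp = lax.clamp(lower, idx, upper)
```

Both forms give the same clipped indices; the `lax.clamp` version is a single operation but needs the explicit `broadcast_to` step.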
jax/_src/numpy/reductions.py
Outdated
  result = indexing.take_along_axis(a, idx_transposed, axis=axis)
  result = lax.squeeze(result, (axis,))
else:
  perm = (list(range(q_ndim, q_ndim + axis)) +
Generalized unpacking would be a bit clearer here I think:
perm = [*range(q_ndim, q_ndim + axis),
        *range(q_ndim),
        *range(q_ndim + axis, idx_take.ndim)]
OK. Updated.
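For comparison, here are the two spellings next to each other on illustrative values. The concatenation form is a guess at the full original expression, which is cut off in the diff above, and `total_ndim` stands in for `idx_take.ndim`.

```python
q_ndim, axis, total_ndim = 2, 1, 5  # illustrative values, not from the PR

# Explicit list concatenation (assumed shape of the original expression).
perm_concat = (list(range(q_ndim, q_ndim + axis)) +
               list(range(q_ndim)) +
               list(range(q_ndim + axis, total_ndim)))

# Generalized unpacking, as suggested in the review.
perm_unpacked = [*range(q_ndim, q_ndim + axis),
                 *range(q_ndim),
                 *range(q_ndim + axis, total_ndim)]
```

The two lists are identical; the unpacked form avoids the intermediate `list(...)` calls and the trailing `+` operators.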
I have addressed the review comments and left inline replies where relevant. Please let me know if I missed anything.
a9742ce to 8723c42

Adds a weights parameter to the quantile, percentile, nanquantile, and nanpercentile functions for weighted calculations.
Added tests to verify the implementation and ensured existing functionality remains unchanged.
Documentation will be added pending code review.
Happy to make changes and iterate based on feedback!
Fixes #32647