Conversation

@C-you-know

Adds weights parameter to quantile, percentile, nanquantile, and nanpercentile functions for weighted calculations.
Added tests to verify the implementation and ensured existing functionality remains unchanged.
Documentation will be added pending code review.
Happy to make changes and iterate based on feedback!
Fixes #32647

@gemini-code-assist
Contributor

Summary of Changes

Hello @C-you-know, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the statistical capabilities of JAX's NumPy API by introducing a weights parameter to the quantile, percentile, nanquantile, and nanpercentile functions. This allows users to perform weighted calculations, which is essential for scenarios where data points have varying importance. The implementation includes a new inverted_cdf method designed to handle these weighted computations, alongside robust validation for weights and comprehensive test coverage.

Highlights

  • Weighted Quantile/Percentile Support: The quantile, percentile, nanquantile, and nanpercentile functions now accept an optional weights parameter, enabling weighted calculations.
  • New Calculation Method: A new method called inverted_cdf has been introduced, which is specifically designed to handle weighted quantile calculations. Other methods do not currently support weights.
  • Robust Weight Handling: The implementation includes comprehensive logic for validating weights (e.g., non-complex types, shape matching), promoting dtypes, and correctly managing NaNs and invalid weights during computation.
  • Test Coverage: New tests have been added to verify the correctness of the weighted quantile and percentile functions across various scenarios, ensuring existing functionality remains stable.
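For intuition, the inverted-CDF rule the highlights describe can be sketched in plain NumPy. This is an illustrative stand-alone helper, not the PR's lax-based implementation (which additionally handles axes, dtype promotion, and NaN/invalid-weight cases): sort the samples, accumulate normalized weights, and return the smallest sample whose cumulative weight reaches q.

```python
import numpy as np

def weighted_quantile_icdf(a, q, weights):
    """Inverted-CDF quantile sketch: smallest sample whose cumulative
    normalized weight reaches q. Illustrative only."""
    order = np.argsort(a)
    a_sorted = np.asarray(a)[order]
    w_sorted = np.asarray(weights, dtype=float)[order]
    cdf = np.cumsum(w_sorted) / np.sum(w_sorted)
    # side="left" picks the first index with cdf[i] >= q
    return a_sorted[np.searchsorted(cdf, q, side="left")]

# Uniform weights reduce to the ordinary inverted-CDF quantile:
print(weighted_quantile_icdf([1, 2, 3, 4], 0.5, [1, 1, 1, 1]))   # 2
# A heavily weighted sample pulls the quantile toward itself:
print(weighted_quantile_icdf([1, 2, 3, 4], 0.5, [1, 1, 1, 10]))  # 4
```

This also shows why only inverted_cdf supports weights here: the method is defined directly in terms of the weighted CDF, whereas interpolating methods like "linear" have no single agreed-upon weighted generalization.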


Contributor

@gemini-code-assist bot left a comment


Code Review

This pull request introduces a weights parameter to the quantile, percentile, nanquantile, and nanpercentile functions, enabling weighted calculations. The implementation adds a new inverted_cdf method to support this functionality and includes corresponding tests. While the changes are generally well-implemented, there's a critical issue with mismatched arguments in the internal _quantile function calls and a minor issue with an f-string in an error message. Addressing these will improve the correctness and robustness of the new feature.

@C-you-know
Author

[Screenshot: local test run, 2025-12-22]

Passing all test cases!
Should I add more test cases for edge cases? Let me know if there are specific scenarios you'd like me to cover!

@C-you-know
Author

Gentle ping @jakevdp 🙂

Collaborator

@jakevdp left a comment


Looks really great on a first read-through, thanks!

A couple small comments below. I'll trigger the tests to see if they reveal any issues with the implementation.

Comment on lines 2664 to 2673
curr_out_axis = q_ndim
for i in range(a.ndim):
  if i == axis:
    index.append(idx)
    if keepdims:
      curr_out_axis += 1
  else:
    index.append(lax.broadcasted_iota(np.int32, out_shape, curr_out_axis))
    curr_out_axis += 1
result = a[tuple(index)]
Collaborator


This seems like a very inefficient way to compute the result. I think you could use empty slices in place of the broadcasted iotas to do this more simply.

Author


I tried using empty slices here but it does not work cleanly because idx already encodes the non-axis dimensions. Combining it with slice(None) causes those dimensions to be included twice under advanced indexing. I used take_along_axis to keep the indexing explicit and avoid that duplication
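For illustration, the take_along_axis pattern described in this reply looks roughly like the following in NumPy (jnp.take_along_axis has matching semantics; the arrays here are made up): one gathered index per output lane along `axis`, broadcast against the other dimensions, instead of building full coordinate grids with iotas.

```python
import numpy as np

# Hypothetical data: pick one row per column along axis 0.
a = np.array([[10, 40, 30],
              [20, 50, 60]])
idx = np.array([1, 0, 1])  # row index chosen for each column
axis = 0

# take_along_axis expects idx broadcast to a.ndim, size 1 on `axis`:
gathered = np.take_along_axis(a, idx[np.newaxis, :], axis=axis)
print(gathered)                # [[20 40 60]]
print(gathered.squeeze(axis))  # [20 40 60]
```

The indices encode only the gathered axis, so the other dimensions are never duplicated the way they would be under combined advanced indexing.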

@jakevdp self-assigned this Jan 5, 2026
@C-you-know
Author

Hi! I looked into the failing CI cases.

The failures in testWeightedQuantile10 happen during the NumPy reference computation rather than in the JAX implementation itself.

np.nanpercentile emits a RuntimeWarning: All-NaN slice encountered on some randomized inputs; JAX returns the same NaN result but doesn’t emit a warning.

_CheckAgainstNumpy treats the NumPy warning as a test error, so the comparison never reaches the output check.
I believe the implementation is correct, and the issue is how the test handles NumPy warnings. Would you prefer that I update the test to suppress NumPy warnings in the reference path (consistent with other nan* reducer tests), or would you like me to handle this differently?

@jakevdp
Collaborator

jakevdp commented Jan 7, 2026

JAX tests are deliberately configured to error when there are unexpected warnings. If you want to explicitly ignore a warning within a test, you can use the jtu.ignore_warning utility.
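For context, the warning in question is easy to reproduce and silence with the standard warnings machinery. The sketch below uses plain NumPy and the stdlib; inside a JAX test the suggested jtu.ignore_warning decorator (e.g. applied with category=RuntimeWarning and a message pattern like "All-NaN slice encountered") plays the same role.

```python
import warnings
import numpy as np

# An all-NaN slice makes np.nanpercentile emit a RuntimeWarning
# ("All-NaN slice encountered") while still returning NaN.
data = np.array([np.nan, np.nan])

with warnings.catch_warnings():
    warnings.simplefilter("ignore", RuntimeWarning)
    ref = np.nanpercentile(data, 50)

print(np.isnan(ref))  # True
```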

@C-you-know force-pushed the feature-weighted-quantile branch from 088e9c2 to fe49c09 on January 26, 2026, 18:18
@C-you-know
Author

Hi @jakevdp,

I have rebased the PR onto the latest upstream, fixed the failing test cases and made a small naming cleanup in the implementation. All tests pass locally now.

Thanks for taking another look.

Collaborator

@jakevdp left a comment


Looking good! A few more changes below.

@api.jit(static_argnames=('axis', 'overwrite_input', 'keepdims', 'method'))
def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None,
             weights: ArrayLike | None = None, out: None = None,
             overwrite_input: bool = False, method: str = "linear",
Collaborator


weights should be a keyword-only parameter, as in np.quantile. Same for the other exported functions below.

Author


Ok. Updated.

  raise ValueError("method can only be 'linear', 'lower', 'higher', 'midpoint', 'nearest' or 'inverted_cdf'")
if weights is not None:
  weights = ensure_arraylike("_quantile", weights)
  weights = lax.asarray(weights)
Collaborator


You can remove this; ensure_arraylike already takes care of array conversion. If you move ensure_arraylike to the calling function, then you could make the annotation Array | None and then this wouldn't be required.

Author


I wanted to clarify this point based on the current structure of quantile and the related nanquantile, percentile and nanpercentile functions. The calling functions already apply ensure_arraylike to a and q but still pass both through lax.asarray before calling _quantile. Since ensure_arraylike also performs array conversion, I was not sure whether this duplication is intentional for consistency.

if method not in ["linear", "lower", "higher", "midpoint", "nearest", "inverted_cdf"]:
  raise ValueError("method can only be 'linear', 'lower', 'higher', 'midpoint', 'nearest' or 'inverted_cdf'")
if weights is not None:
  weights = ensure_arraylike("_quantile", weights)
Collaborator


The name here is used to generate a user-visible message, and _quantile is an internal name.

Instead, we should pass weights to ensure_arraylike in the calling functions (alternatively we could modify _quantile to take the name of the calling function as a parameter, but that seems less clean).

Author


Updated accordingly

if weights.shape != a.shape:
  if axis is None:
    raise ValueError("Weights shape must match 'a' shape when axis is None.")
  ax_tuple = (axis,) if isinstance(axis, int) else tuple(axis)
Collaborator


You can use util.canonicalize_axis_tuple for this.
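For readers unfamiliar with the helper: util.canonicalize_axis_tuple is internal to JAX. Roughly, it behaves like the following sketch, accepting an int or a sequence of ints and normalizing negative axes against ndim (the real helper also rejects duplicate axes and uses JAX's own error types):

```python
# Illustrative approximation of JAX's internal axis canonicalization.
def canonicalize_axis_tuple(axis, ndim):
    axes = (axis,) if isinstance(axis, int) else tuple(axis)
    canonical = []
    for ax in axes:
        if not -ndim <= ax < ndim:
            raise ValueError(
                f"axis {ax} is out of bounds for an array of dimension {ndim}")
        canonical.append(ax % ndim)  # map negative axes to positive
    return tuple(canonical)

print(canonicalize_axis_tuple(-1, 3))       # (2,)
print(canonicalize_axis_tuple((0, -2), 4))  # (0, 2)
```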

Author


Updated accordingly

keepdim = [1] * a.ndim
a = a.ravel()
if weights is not None:
  weights = weights.ravel()
Collaborator


two-space indentation please. Also in a few places below.

Author


Sorry. Updated the code to follow two space indentation.

  result = lax.mul(lax.add(low_value, high_value), lax._const(low_value, 0.5))
elif method == "inverted_cdf":
  if weights is None:
    weights = lax.full(a.shape, 1.0, dtype=a.dtype)
Collaborator


lax.full_like(a, 1.0) is a bit more concise, and will propagate sharding information when relevant.

Author


Updated

else:
  valid_counts = lax.full_like(total_weight, a_shape[axis], dtype=q.dtype)
limit = lax.sub(valid_counts, lax._const(valid_counts, 1))
max_idx = lax.convert_element_type(limit, np.int32)
Collaborator


Let's use dtypes.default_int_dtype() here.

Author


OK. Updated.

max_idx = lax.convert_element_type(limit, np.int32)
max_idx_f = lax.expand_dims(max_idx, tuple(range(q_ndim)))
max_idx_f = lax.convert_element_type(max_idx_f, idx.dtype)
idx = lax.max(lax._const(idx, 0), lax.min(idx, max_idx_f))
Collaborator


You could use lax.clamp here.

Author


I considered using lax.clamp here but it would require dynamically broadcasting max_idx_f to match the shape of idx. I kept the current approach to avoid introducing that additional broadcasting logic. I am happy to switch to lax.clamp if you would prefer that for consistency
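As an aside, the max/min composition kept in the PR is exactly a clamp; a NumPy sketch of the equivalence (illustrative arrays, not the PR's), where np.clip broadcasts the per-lane bound that, per the reply above, lax.clamp would need pre-broadcast:

```python
import numpy as np

idx = np.array([[-3, 1, 7],
                [ 2, 9, 0]])
hi = np.array([5, 5, 5])  # per-column upper bound

# max(0, min(idx, hi)) clamps idx into [0, hi] element-wise:
clamped = np.maximum(0, np.minimum(idx, hi))
print(clamped)                                       # [[0 1 5] [2 5 0]]
print(np.array_equal(clamped, np.clip(idx, 0, hi)))  # True
```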

  result = indexing.take_along_axis(a, idx_transposed, axis=axis)
  result = lax.squeeze(result, (axis,))
else:
  perm = (list(range(q_ndim, q_ndim + axis)) +
Collaborator


Generalized unpacking would be a bit clearer here I think:

perm = [*range(q_ndim, q_ndim + axis),
        *range(q_ndim),
        *range(q_ndim + axis, idx_take.ndim)]

Author


OK. Updated.

@C-you-know
Author

I have addressed the review comments and left inline replies where relevant. Please let me know if I missed anything

@C-you-know force-pushed the feature-weighted-quantile branch from a9742ce to 8723c42 on January 27, 2026, 17:41