Skip to content

Commit 8723c42

Browse files
committed
Address review comments
1 parent e5597da commit 8723c42

File tree

1 file changed

+28
-22
lines changed

1 file changed

+28
-22
lines changed

jax/_src/numpy/reductions.py

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2375,8 +2375,8 @@ def cumulative_prod(
23752375
@export
23762376
@api.jit(static_argnames=('axis', 'overwrite_input', 'keepdims', 'method'))
23772377
def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None,
2378-
weights: ArrayLike | None = None, out: None = None, overwrite_input: bool = False, method: str = "linear",
2379-
keepdims: bool = False) -> Array:
2378+
out: None = None, overwrite_input: bool = False, method: str = "linear",
2379+
keepdims: bool = False, *, weights: ArrayLike | None = None) -> Array:
23802380
"""Compute the quantile of the data along the specified axis.
23812381
23822382
JAX implementation of :func:`numpy.quantile`.
@@ -2426,6 +2426,8 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = No
24262426
Array(3., dtype=float32)
24272427
"""
24282428
a, q = ensure_arraylike("quantile", a, q)
2429+
if weights is not None:
2430+
weights = ensure_arraylike("quantile", weights)
24292431
if overwrite_input or out is not None:
24302432
raise ValueError("jax.numpy.quantile does not support overwrite_input=True "
24312433
"or out != None")
@@ -2435,8 +2437,8 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = No
24352437
@export
24362438
@api.jit(static_argnames=('axis', 'overwrite_input', 'keepdims', 'method'))
24372439
def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None,
2438-
weights: ArrayLike | None = None, out: None = None, overwrite_input: bool = False, method: str = "linear",
2439-
keepdims: bool = False) -> Array:
2440+
out: None = None, overwrite_input: bool = False, method: str = "linear",
2441+
keepdims: bool = False, *, weights: ArrayLike | None = None) -> Array:
24402442
"""Compute the quantile of the data along the specified axis, ignoring NaNs.
24412443
24422444
JAX implementation of :func:`numpy.nanquantile`.
@@ -2487,18 +2489,19 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None =
24872489
Array(4.0, dtype=float32)
24882490
"""
24892491
a, q = ensure_arraylike("nanquantile", a, q)
2492+
if weights is not None:
2493+
weights = ensure_arraylike("nanquantile", weights)
24902494
if overwrite_input or out is not None:
24912495
msg = ("jax.numpy.nanquantile does not support overwrite_input=True or "
24922496
"out != None")
24932497
raise ValueError(msg)
24942498
return _quantile(lax.asarray(a), lax.asarray(q), axis, method, keepdims, True, weights)
24952499

24962500
def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
2497-
method: str, keepdims: bool, squash_nans: bool, weights: ArrayLike | None = None) -> Array:
2501+
method: str, keepdims: bool, squash_nans: bool, weights: Array | None = None) -> Array:
24982502
if method not in ["linear", "lower", "higher", "midpoint", "nearest", "inverted_cdf"]:
24992503
raise ValueError("method can only be 'linear', 'lower', 'higher', 'midpoint', 'nearest' or 'inverted_cdf'")
25002504
if weights is not None:
2501-
weights = ensure_arraylike("_quantile", weights)
25022505
weights = lax.asarray(weights)
25032506
if dtypes.issubdtype(weights.dtype, np.complexfloating):
25042507
raise ValueError("Weights cannot be complex types.")
@@ -2508,8 +2511,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
25082511
if weights.shape != a.shape:
25092512
if axis is None:
25102513
raise ValueError("Weights shape must match 'a' shape when axis is None.")
2511-
ax_tuple = (axis,) if isinstance(axis, int) else tuple(axis)
2512-
ax_tuple = tuple(canonicalize_axis(ax, a.ndim) for ax in ax_tuple)
2514+
ax_tuple = canonicalize_axis_tuple(axis, a.ndim)
25132515
if weights.shape != tuple(a.shape[ax] for ax in ax_tuple):
25142516
raise ValueError(f"Weights shape {weights.shape} must match reduction axes "
25152517
f"{tuple(a.shape[ax] for ax in ax_tuple)}")
@@ -2524,7 +2526,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
25242526
keepdim = [1] * a.ndim
25252527
a = a.ravel()
25262528
if weights is not None:
2527-
weights = weights.ravel()
2529+
weights = weights.ravel()
25282530
axis = 0
25292531
elif isinstance(axis, tuple):
25302532
keepdim = list(a.shape)
@@ -2544,7 +2546,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
25442546
touch_shape = tuple(x for idx,x in enumerate(a.shape) if idx in axis)
25452547
a = lax.reshape(a, do_not_touch_shape + (math.prod(touch_shape),), dimensions)
25462548
if weights is not None:
2547-
weights = lax.reshape(weights, do_not_touch_shape + (math.prod(touch_shape),), dimensions)
2549+
weights = lax.reshape(weights, do_not_touch_shape + (math.prod(touch_shape),), dimensions)
25482550
axis = canonicalize_axis(-1, a.ndim)
25492551
else:
25502552
axis = canonicalize_axis(axis, a.ndim)
@@ -2559,9 +2561,9 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
25592561
if squash_nans:
25602562
a = _where(lax._isnan(a), np.nan, a) # Ensure nans are positive so they sort to the end.
25612563
if weights is not None:
2562-
a, weights = lax.sort_key_val(a, weights, dimension=axis)
2564+
a, weights = lax.sort_key_val(a, weights, dimension=axis)
25632565
else:
2564-
a = lax.sort(a, dimension=axis)
2566+
a = lax.sort(a, dimension=axis)
25652567
counts = sum(lax.bitwise_not(lax._isnan(a)), axis=axis, dtype=q.dtype, keepdims=keepdims)
25662568
shape_after_reduction = counts.shape
25672569
q = lax.expand_dims(
@@ -2591,9 +2593,9 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
25912593
with config.debug_nans(False):
25922594
a = _where(any(lax._isnan(a), axis=axis, keepdims=True), np.nan, a)
25932595
if weights is not None:
2594-
a, weights = lax.sort_key_val(a, weights, dimension=axis)
2596+
a, weights = lax.sort_key_val(a, weights, dimension=axis)
25952597
else:
2596-
a = lax.sort(a, dimension=axis)
2598+
a = lax.sort(a, dimension=axis)
25972599
n = lax.convert_element_type(a_shape[axis], lax._dtype(q))
25982600
q = lax.mul(q, n - 1)
25992601
low = lax.floor(q)
@@ -2638,7 +2640,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
26382640
result = lax.mul(lax.add(low_value, high_value), lax._const(low_value, 0.5))
26392641
elif method == "inverted_cdf":
26402642
if weights is None:
2641-
weights = lax.full(a.shape, 1.0, dtype=a.dtype)
2643+
weights = lax.full_like(a, 1.0)
26422644
zeros = lax.full_like(weights, 0)
26432645
bad_weights = lax.bitwise_or(lax.lt(weights, zeros), lax._isnan(weights))
26442646
nan_data = lax._isnan(a)
@@ -2651,13 +2653,13 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
26512653
target_w_aligned = target_w if keepdims else lax.expand_dims(target_w, (axis + q_ndim,))
26522654
cw_f = lax.expand_dims(cum_weights, tuple(range(q_ndim)))
26532655
is_less = lax.lt(cw_f, target_w_aligned)
2654-
idx = sum(lax.convert_element_type(is_less, np.int32), axis=axis + q_ndim, keepdims=keepdims)
2656+
idx = sum(lax.convert_element_type(is_less, dtypes.default_int_dtype()), axis=axis + q_ndim, keepdims=keepdims)
26552657
if squash_nans:
26562658
valid_counts = sum(lax.bitwise_not(nan_data), axis=axis, dtype=q.dtype, keepdims=keepdims)
26572659
else:
26582660
valid_counts = lax.full_like(total_weight, a_shape[axis], dtype=q.dtype)
26592661
limit = lax.sub(valid_counts, lax._const(valid_counts, 1))
2660-
max_idx = lax.convert_element_type(limit, np.int32)
2662+
max_idx = lax.convert_element_type(limit, dtypes.default_int_dtype())
26612663
max_idx_f = lax.expand_dims(max_idx, tuple(range(q_ndim)))
26622664
max_idx_f = lax.convert_element_type(max_idx_f, idx.dtype)
26632665
idx = lax.max(lax._const(idx, 0), lax.min(idx, max_idx_f))
@@ -2670,9 +2672,9 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
26702672
result = indexing.take_along_axis(a, idx_transposed, axis=axis)
26712673
result = lax.squeeze(result, (axis,))
26722674
else:
2673-
perm = (list(range(q_ndim, q_ndim + axis)) +
2674-
list(range(q_ndim)) +
2675-
list(range(q_ndim + axis, idx_take.ndim)))
2675+
perm = [*range(q_ndim, q_ndim + axis),
2676+
*range(q_ndim),
2677+
*range(q_ndim + axis, idx_take.ndim)]
26762678
idx_transposed = lax.transpose(idx_take, perm)
26772679
result = indexing.take_along_axis(a, idx_transposed, axis=axis)
26782680
inv_perm = [perm.index(i) for i in range(len(perm))]
@@ -2701,7 +2703,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
27012703
def percentile(a: ArrayLike, q: ArrayLike,
27022704
axis: int | tuple[int, ...] | None = None,
27032705
out: None = None, overwrite_input: bool = False, method: str = "linear",
2704-
weights: ArrayLike | None = None, keepdims: bool = False) -> Array:
2706+
keepdims: bool = False, *, weights: ArrayLike | None = None) -> Array:
27052707
"""Compute the percentile of the data along the specified axis.
27062708
27072709
JAX implementation of :func:`numpy.percentile`.
@@ -2751,6 +2753,8 @@ def percentile(a: ArrayLike, q: ArrayLike,
27512753
Array(3., dtype=float32)
27522754
"""
27532755
a, q = ensure_arraylike("percentile", a, q)
2756+
if weights is not None:
2757+
weights = ensure_arraylike("percentile", weights)
27542758
q, = promote_dtypes_inexact(q)
27552759
return quantile(a, q / 100, axis=axis, out=out, overwrite_input=overwrite_input,
27562760
method=method, weights=weights, keepdims=keepdims)
@@ -2761,7 +2765,7 @@ def percentile(a: ArrayLike, q: ArrayLike,
27612765
def nanpercentile(a: ArrayLike, q: ArrayLike,
27622766
axis: int | tuple[int, ...] | None = None,
27632767
out: None = None, overwrite_input: bool = False, method: str = "linear",
2764-
weights: ArrayLike | None = None, keepdims: bool = False) -> Array:
2768+
keepdims: bool = False, *, weights: ArrayLike | None = None) -> Array:
27652769
"""Compute the percentile of the data along the specified axis, ignoring NaN values.
27662770
27672771
JAX implementation of :func:`numpy.nanpercentile`.
@@ -2813,6 +2817,8 @@ def nanpercentile(a: ArrayLike, q: ArrayLike,
28132817
Array(4.0, dtype=float32)
28142818
"""
28152819
a, q = ensure_arraylike("nanpercentile", a, q)
2820+
if weights is not None:
2821+
weights = ensure_arraylike("nanpercentile", weights)
28162822
q, = promote_dtypes_inexact(q)
28172823
q = q / 100
28182824
return nanquantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input,

0 commit comments

Comments
 (0)