@@ -2375,8 +2375,8 @@ def cumulative_prod(
23752375@export
23762376@api .jit (static_argnames = ('axis' , 'overwrite_input' , 'keepdims' , 'method' ))
23772377def quantile (a : ArrayLike , q : ArrayLike , axis : int | tuple [int , ...] | None = None ,
2378- weights : ArrayLike | None = None , out : None = None , overwrite_input : bool = False , method : str = "linear" ,
2379- keepdims : bool = False ) -> Array :
2378+ out : None = None , overwrite_input : bool = False , method : str = "linear" ,
2379+ keepdims : bool = False , * , weights : ArrayLike | None = None ) -> Array :
23802380 """Compute the quantile of the data along the specified axis.
23812381
23822382 JAX implementation of :func:`numpy.quantile`.
@@ -2426,6 +2426,8 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = No
24262426 Array(3., dtype=float32)
24272427 """
24282428 a , q = ensure_arraylike ("quantile" , a , q )
2429+ if weights is not None :
2430+ weights = ensure_arraylike ("quantile" , weights )
24292431 if overwrite_input or out is not None :
24302432 raise ValueError ("jax.numpy.quantile does not support overwrite_input=True "
24312433 "or out != None" )
@@ -2435,8 +2437,8 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = No
24352437@export
24362438@api .jit (static_argnames = ('axis' , 'overwrite_input' , 'keepdims' , 'method' ))
24372439def nanquantile (a : ArrayLike , q : ArrayLike , axis : int | tuple [int , ...] | None = None ,
2438- weights : ArrayLike | None = None , out : None = None , overwrite_input : bool = False , method : str = "linear" ,
2439- keepdims : bool = False ) -> Array :
2440+ out : None = None , overwrite_input : bool = False , method : str = "linear" ,
2441+ keepdims : bool = False , * , weights : ArrayLike | None = None ) -> Array :
24402442 """Compute the quantile of the data along the specified axis, ignoring NaNs.
24412443
24422444 JAX implementation of :func:`numpy.nanquantile`.
@@ -2487,18 +2489,19 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None =
24872489 Array(4.0, dtype=float32)
24882490 """
24892491 a , q = ensure_arraylike ("nanquantile" , a , q )
2492+ if weights is not None :
2493+ weights = ensure_arraylike ("nanquantile" , weights )
24902494 if overwrite_input or out is not None :
24912495 msg = ("jax.numpy.nanquantile does not support overwrite_input=True or "
24922496 "out != None" )
24932497 raise ValueError (msg )
24942498 return _quantile (lax .asarray (a ), lax .asarray (q ), axis , method , keepdims , True , weights )
24952499
24962500def _quantile (a : Array , q : Array , axis : int | tuple [int , ...] | None ,
2497- method : str , keepdims : bool , squash_nans : bool , weights : ArrayLike | None = None ) -> Array :
2501+ method : str , keepdims : bool , squash_nans : bool , weights : Array | None = None ) -> Array :
24982502 if method not in ["linear" , "lower" , "higher" , "midpoint" , "nearest" , "inverted_cdf" ]:
24992503 raise ValueError ("method can only be 'linear', 'lower', 'higher', 'midpoint', 'nearest' or 'inverted_cdf'" )
25002504 if weights is not None :
2501- weights = ensure_arraylike ("_quantile" , weights )
25022505 weights = lax .asarray (weights )
25032506 if dtypes .issubdtype (weights .dtype , np .complexfloating ):
25042507 raise ValueError ("Weights cannot be complex types." )
@@ -2508,8 +2511,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
25082511 if weights .shape != a .shape :
25092512 if axis is None :
25102513 raise ValueError ("Weights shape must match 'a' shape when axis is None." )
2511- ax_tuple = (axis ,) if isinstance (axis , int ) else tuple (axis )
2512- ax_tuple = tuple (canonicalize_axis (ax , a .ndim ) for ax in ax_tuple )
2514+ ax_tuple = canonicalize_axis_tuple (axis , a .ndim )
25132515 if weights .shape != tuple (a .shape [ax ] for ax in ax_tuple ):
25142516 raise ValueError (f"Weights shape { weights .shape } must match reduction axes "
25152517 f"{ tuple (a .shape [ax ] for ax in ax_tuple )} " )
@@ -2524,7 +2526,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
25242526 keepdim = [1 ] * a .ndim
25252527 a = a .ravel ()
25262528 if weights is not None :
2527- weights = weights .ravel ()
2529+ weights = weights .ravel ()
25282530 axis = 0
25292531 elif isinstance (axis , tuple ):
25302532 keepdim = list (a .shape )
@@ -2544,7 +2546,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
25442546 touch_shape = tuple (x for idx ,x in enumerate (a .shape ) if idx in axis )
25452547 a = lax .reshape (a , do_not_touch_shape + (math .prod (touch_shape ),), dimensions )
25462548 if weights is not None :
2547- weights = lax .reshape (weights , do_not_touch_shape + (math .prod (touch_shape ),), dimensions )
2549+ weights = lax .reshape (weights , do_not_touch_shape + (math .prod (touch_shape ),), dimensions )
25482550 axis = canonicalize_axis (- 1 , a .ndim )
25492551 else :
25502552 axis = canonicalize_axis (axis , a .ndim )
@@ -2559,9 +2561,9 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
25592561 if squash_nans :
25602562 a = _where (lax ._isnan (a ), np .nan , a ) # Ensure nans are positive so they sort to the end.
25612563 if weights is not None :
2562- a , weights = lax .sort_key_val (a , weights , dimension = axis )
2564+ a , weights = lax .sort_key_val (a , weights , dimension = axis )
25632565 else :
2564- a = lax .sort (a , dimension = axis )
2566+ a = lax .sort (a , dimension = axis )
25652567 counts = sum (lax .bitwise_not (lax ._isnan (a )), axis = axis , dtype = q .dtype , keepdims = keepdims )
25662568 shape_after_reduction = counts .shape
25672569 q = lax .expand_dims (
@@ -2591,9 +2593,9 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
25912593 with config .debug_nans (False ):
25922594 a = _where (any (lax ._isnan (a ), axis = axis , keepdims = True ), np .nan , a )
25932595 if weights is not None :
2594- a , weights = lax .sort_key_val (a , weights , dimension = axis )
2596+ a , weights = lax .sort_key_val (a , weights , dimension = axis )
25952597 else :
2596- a = lax .sort (a , dimension = axis )
2598+ a = lax .sort (a , dimension = axis )
25972599 n = lax .convert_element_type (a_shape [axis ], lax ._dtype (q ))
25982600 q = lax .mul (q , n - 1 )
25992601 low = lax .floor (q )
@@ -2638,7 +2640,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
26382640 result = lax .mul (lax .add (low_value , high_value ), lax ._const (low_value , 0.5 ))
26392641 elif method == "inverted_cdf" :
26402642 if weights is None :
2641- weights = lax .full ( a . shape , 1.0 , dtype = a . dtype )
2643+ weights = lax .full_like ( a , 1.0 )
26422644 zeros = lax .full_like (weights , 0 )
26432645 bad_weights = lax .bitwise_or (lax .lt (weights , zeros ), lax ._isnan (weights ))
26442646 nan_data = lax ._isnan (a )
@@ -2651,13 +2653,13 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
26512653 target_w_aligned = target_w if keepdims else lax .expand_dims (target_w , (axis + q_ndim ,))
26522654 cw_f = lax .expand_dims (cum_weights , tuple (range (q_ndim )))
26532655 is_less = lax .lt (cw_f , target_w_aligned )
2654- idx = sum (lax .convert_element_type (is_less , np . int32 ), axis = axis + q_ndim , keepdims = keepdims )
2656+ idx = sum (lax .convert_element_type (is_less , dtypes . default_int_dtype () ), axis = axis + q_ndim , keepdims = keepdims )
26552657 if squash_nans :
26562658 valid_counts = sum (lax .bitwise_not (nan_data ), axis = axis , dtype = q .dtype , keepdims = keepdims )
26572659 else :
26582660 valid_counts = lax .full_like (total_weight , a_shape [axis ], dtype = q .dtype )
26592661 limit = lax .sub (valid_counts , lax ._const (valid_counts , 1 ))
2660- max_idx = lax .convert_element_type (limit , np . int32 )
2662+ max_idx = lax .convert_element_type (limit , dtypes . default_int_dtype () )
26612663 max_idx_f = lax .expand_dims (max_idx , tuple (range (q_ndim )))
26622664 max_idx_f = lax .convert_element_type (max_idx_f , idx .dtype )
26632665 idx = lax .max (lax ._const (idx , 0 ), lax .min (idx , max_idx_f ))
@@ -2670,9 +2672,9 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
26702672 result = indexing .take_along_axis (a , idx_transposed , axis = axis )
26712673 result = lax .squeeze (result , (axis ,))
26722674 else :
2673- perm = ( list ( range (q_ndim , q_ndim + axis )) +
2674- list ( range (q_ndim )) +
2675- list ( range (q_ndim + axis , idx_take .ndim )))
2675+ perm = [ * range (q_ndim , q_ndim + axis ),
2676+ * range (q_ndim ),
2677+ * range (q_ndim + axis , idx_take .ndim )]
26762678 idx_transposed = lax .transpose (idx_take , perm )
26772679 result = indexing .take_along_axis (a , idx_transposed , axis = axis )
26782680 inv_perm = [perm .index (i ) for i in range (len (perm ))]
@@ -2701,7 +2703,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
27012703def percentile (a : ArrayLike , q : ArrayLike ,
27022704 axis : int | tuple [int , ...] | None = None ,
27032705 out : None = None , overwrite_input : bool = False , method : str = "linear" ,
2704- weights : ArrayLike | None = None , keepdims : bool = False ) -> Array :
2706+ keepdims : bool = False , * , weights : ArrayLike | None = None ) -> Array :
27052707 """Compute the percentile of the data along the specified axis.
27062708
27072709 JAX implementation of :func:`numpy.percentile`.
@@ -2751,6 +2753,8 @@ def percentile(a: ArrayLike, q: ArrayLike,
27512753 Array(3., dtype=float32)
27522754 """
27532755 a , q = ensure_arraylike ("percentile" , a , q )
2756+ if weights is not None :
2757+ weights = ensure_arraylike ("percentile" , weights )
27542758 q , = promote_dtypes_inexact (q )
27552759 return quantile (a , q / 100 , axis = axis , out = out , overwrite_input = overwrite_input ,
27562760 method = method , weights = weights , keepdims = keepdims )
@@ -2761,7 +2765,7 @@ def percentile(a: ArrayLike, q: ArrayLike,
27612765def nanpercentile (a : ArrayLike , q : ArrayLike ,
27622766 axis : int | tuple [int , ...] | None = None ,
27632767 out : None = None , overwrite_input : bool = False , method : str = "linear" ,
2764- weights : ArrayLike | None = None , keepdims : bool = False ) -> Array :
2768+ keepdims : bool = False , * , weights : ArrayLike | None = None ) -> Array :
27652769 """Compute the percentile of the data along the specified axis, ignoring NaN values.
27662770
27672771 JAX implementation of :func:`numpy.nanpercentile`.
@@ -2813,6 +2817,8 @@ def nanpercentile(a: ArrayLike, q: ArrayLike,
28132817 Array(4.0, dtype=float32)
28142818 """
28152819 a , q = ensure_arraylike ("nanpercentile" , a , q )
2820+ if weights is not None :
2821+ weights = ensure_arraylike ("nanpercentile" , weights )
28162822 q , = promote_dtypes_inexact (q )
28172823 q = q / 100
28182824 return nanquantile (a , q , axis = axis , out = out , overwrite_input = overwrite_input ,
0 commit comments