Description
jax.scipy.stats.rankdata supports nan_policy='propagate', but its behavior differs from that of scipy.stats.rankdata. Specifically, JAX's implementation adopts the convention that each NaN is a distinct value greater than infinity, whereas SciPy's implementation treats NaNs as values for which a correct rank cannot be determined, so all ranks are NaN. [1][2]
import numpy as np
import jax.numpy as jnp
from jax.scipy.stats import rankdata as jax_rankdata
from scipy.stats import rankdata as scipy_rankdata
x = np.asarray([1, 2, 3, 4, 4, np.nan, np.inf, np.nan])
res = jax_rankdata(jnp.asarray(x), nan_policy='propagate') # [1. 2. 3. 4.5 4.5 7. 6. 8. ]
ref = scipy_rankdata(x, nan_policy='propagate') # [nan nan nan nan nan nan nan nan]

System info (python version, jaxlib version, accelerator, etc.)
jax: 0.7.2
jaxlib: 0.7.2
numpy: 2.0.2
python: 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='13af8e90dccd', release='6.6.105+', version='#1 SMP Thu Oct 2 10:42:05 UTC 2025', machine='x86_64')
BTW, thanks for addressing the mode issue gh-34486 (and hopefully this one). I'm in the process of getting JAX with JIT working throughout scipy.stats (see e.g. scipy/scipy#24405, scipy/scipy#24411), and rankdata and mode are two functions for which our vectorized, array API implementations don't readily lend themselves to JIT. rankdata is particularly important, since it's used in many hypothesis tests. With those functions delegated to JAX, I think most of scipy.stats that we translate to the array API will eventually work with JAX JIT. (Exposing the beta, binomial, F, and t distribution functions in jax.scipy.special would be helpful. If adding them there is undesirable, we could special-case the delegations from scipy.special to the jax.scipy.stats distributions, but I thought adding them to jax.scipy.special might be useful in its own right.)
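For illustration, here is a minimal sketch of how a delegating wrapper could reproduce SciPy's nan_policy='propagate' convention on top of the JAX function in the meantime. The helper name rankdata_scipy_like is hypothetical, and the sketch assumes the input is ranked as a flattened 1-D array (no axis argument). It is not how either library implements the behavior.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import rankdata as jax_rankdata


def rankdata_scipy_like(x, method="average"):
    """Hypothetical helper: all ranks become NaN if any input is NaN."""
    # Assumption: rank the flattened input; axis handling is left out.
    x = jnp.ravel(jnp.asarray(x))
    ranks = jax_rankdata(x, method=method, nan_policy="propagate")
    # jnp.where keeps the function traceable under jax.jit; integer
    # ranks (non-"average" methods) are promoted to float by the NaN.
    return jnp.where(jnp.isnan(x).any(), jnp.nan, ranks)


x = jnp.asarray([1, 2, 3, 4, 4, jnp.nan, jnp.inf, jnp.nan])
res = jax.jit(rankdata_scipy_like)(x)  # every entry is NaN, since x contains NaN
```

Because the NaN check is expressed with jnp.where rather than a Python branch, the wrapper stays compatible with JIT, at the cost of always computing the ranks even when they are discarded.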
Footnotes
[1] The "correct" behavior of nan_policy="propagate" may be debatable, but this behavior was discussed in "ENH: stats: Add nan_policy optional argument for stats.rankdata" (scipy/scipy#16140). It was deemed to be a case in which "Propagate the nan value to the output" is not the same as "just execute the function without checking for nan" (see "A Design Specification for nan_policy").
[2] If the user wishes to treat NaNs as values that cannot be compared with numbers and omit them from the rank calculation, there is nan_policy='omit'.