
jax.scipy.stats.rankdata nan_policy='propagate' inconsistent with SciPy's #34490

@mdhaber

Description

jax.scipy.stats.rankdata supports nan_policy='propagate', but its behavior differs from that of scipy.stats.rankdata. Specifically, JAX's implementation treats each NaN as a distinct value greater than infinity, whereas SciPy's implementation treats NaNs as values for which a correct rank cannot be determined, so all ranks are NaN (footnotes 1, 2).

import numpy as np
import jax.numpy as jnp
from jax.scipy.stats import rankdata as jax_rankdata
from scipy.stats import rankdata as scipy_rankdata

x = np.asarray([1, 2, 3, 4, 4, np.nan, np.inf, np.nan])
res = jax_rankdata(jnp.asarray(x), nan_policy='propagate')  # [1.  2.  3.  4.5 4.5 7.  6.  8. ]
ref = scipy_rankdata(x, nan_policy='propagate')  # [nan nan nan nan nan nan nan nan]
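To make the two conventions concrete without depending on either library's internals, here is a minimal NumPy sketch. The helper names rank_average and rank_propagate are hypothetical (not part of NumPy, SciPy, or JAX), and the sketch assumes 1-D input and the default 'average' tie method. A plain sort-based ranking reproduces the JAX output above as a side effect: np.argsort places NaNs after +inf, and because NaN != NaN, each NaN forms its own tie group. The SciPy-style 'propagate' convention is a separate check layered on top.

```python
import numpy as np


def rank_average(x):
    """Average-method ranks of a 1-D array (tied values share the mean rank).

    NaNs are not special-cased: np.argsort places them after +inf, and since
    NaN != NaN, every NaN starts its own tie group. This yields the
    "each NaN is a distinct value greater than infinity" convention.
    """
    x = np.asarray(x, dtype=float)
    sorter = np.argsort(x, kind="stable")
    inv = np.empty(sorter.size, dtype=np.intp)
    inv[sorter] = np.arange(sorter.size)  # position of each element in sorted order
    xs = x[sorter]
    # True wherever a new tie group starts in the sorted array
    new_group = np.concatenate(([True], xs[1:] != xs[:-1]))
    dense = np.cumsum(new_group)[inv]  # 1-based tie-group id of each element
    # boundaries[k] = number of sorted elements before tie group k + 1
    boundaries = np.concatenate((np.flatnonzero(new_group), [new_group.size]))
    # mean of the 1-based positions boundaries[g-1]+1 .. boundaries[g] of group g
    return 0.5 * (boundaries[dense] + boundaries[dense - 1] + 1)


def rank_propagate(x):
    """Hypothetical SciPy-style nan_policy='propagate': any NaN makes every rank NaN."""
    x = np.asarray(x, dtype=float)
    if np.isnan(x).any():
        return np.full(x.shape, np.nan)
    return rank_average(x)


x = np.asarray([1, 2, 3, 4, 4, np.nan, np.inf, np.nan])
print(rank_average(x))    # [1.  2.  3.  4.5 4.5 7.  6.  8. ]
print(rank_propagate(x))  # [nan nan nan nan nan nan nan nan]
```

This is not a claim about how JAX computes its result; it only shows that the "distinct NaN greater than inf" outcome is what an unchecked sort-based ranking produces, which is why the design specification distinguishes 'propagate' from "just execute the function without checking for nan".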

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.7.2
jaxlib: 0.7.2
numpy:  2.0.2
python: 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='13af8e90dccd', release='6.6.105+', version='#1 SMP Thu Oct  2 10:42:05 UTC 2025', machine='x86_64')

BTW, thanks for addressing the mode issue gh-34486 (and hopefully this one). I'm in the process of getting JAX with JIT working throughout scipy.stats (see, e.g., scipy/scipy#24405 and scipy/scipy#24411), and rankdata and mode are two functions for which our vectorized, array API implementations don't readily lend themselves to JIT. rankdata is particularly important because many hypothesis tests use it. With those functions delegated to JAX, I think most of the parts of scipy.stats that we translate to the array API will eventually work with JAX JIT.

(It would also help to expose the beta, binomial, F, t, and $\chi^2$ distribution functions and their complements via jax.scipy.special. If adding them there is undesirable, we could special-case the delegations from scipy.special to the jax.scipy.stats distributions, but I thought adding them to jax.scipy.special might be useful in its own right.)
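Regarding JIT compatibility of rankdata: one piece that is straightforward to express without data-dependent Python control flow is the SciPy-style 'propagate' step itself. The sketch below assumes only an array-API-like namespace `xp` providing `isnan`, `any`, `where`, and a `nan` constant. It is tested with NumPy here; jax.numpy is assumed to behave the same way. The helper name `propagate_nan` is hypothetical and is not part of SciPy or JAX.

```python
import numpy as np


def propagate_nan(ranks, x, xp=np):
    """Hypothetical branch-free nan_policy='propagate' step along the last axis.

    Replaces every rank in a slice with NaN if the corresponding slice of `x`
    contains a NaN. No Python `if` inspects array values, so a tracing JIT
    compiler does not need concrete data to follow this code.
    """
    has_nan = xp.any(xp.isnan(x), axis=-1, keepdims=True)  # one flag per slice
    return xp.where(has_nan, xp.nan, ranks)  # flag broadcasts over the slice


x = np.array([[1.0, np.nan, 2.0], [3.0, 1.0, 2.0]])
ranks = np.array([[1.0, 3.0, 2.0], [3.0, 1.0, 2.0]])  # e.g. from a NaN-as-largest ranker
out = propagate_nan(ranks, x)  # first row becomes all NaN, second row is unchanged
```

Because the NaN check is a reduction followed by `where`, batched inputs are handled by broadcasting rather than by looping over slices.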

Footnotes

  1. The "correct" behavior of nan_policy="propagate" may be debatable, but this behavior was discussed in scipy/scipy#16140 ("ENH: stats: Add nan_policy optional argument for stats.rankdata"). It was deemed a case in which "Propagate the nan value to the output" is not the same as "just execute the function without checking for nan" (see "A Design Specification for nan_policy").

  2. If the user wishes to treat NaNs as values that cannot be compared with numbers and omit them from the rank calculation, there is nan_policy='omit'.
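For comparison with footnote 2, here is a minimal sketch of 'omit' semantics. It assumes, per our reading of SciPy's documentation, that NaN positions are reported as NaN in the output while the remaining values are ranked among themselves. The name rank_omit is hypothetical, and the O(n**2) pairwise comparison is chosen for brevity rather than speed.

```python
import numpy as np


def rank_omit(x):
    """Hypothetical sketch of nan_policy='omit' for a 1-D array.

    NaNs are excluded from the ranking and reported as NaN; the other values
    receive average-method ranks among themselves.
    """
    x = np.asarray(x, dtype=float)
    out = np.full(x.shape, np.nan)
    mask = ~np.isnan(x)
    v = x[mask]
    # less[i] = number of values strictly smaller than v[i]
    less = (v[None, :] < v[:, None]).sum(axis=1)
    # equal[i] = number of values equal to v[i], including v[i] itself
    equal = (v[None, :] == v[:, None]).sum(axis=1)
    # average of the 1-based ranks less+1 .. less+equal
    out[mask] = less + (equal + 1) / 2
    return out


x = np.asarray([1, 2, 3, 4, 4, np.nan, np.inf, np.nan])
r = rank_omit(x)  # ranks 1, 2, 3, 4.5, 4.5, nan, 6, nan
```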

Labels: bug (Something isn't working)
