@@ -275,3 +275,78 @@ def compute(self) -> jax.Array:
275275 mean = base .divide_no_nan (self .total , self .count )
276276 sst = self .sum_of_squared_label - self .count * jnp .power (mean , 2 )
277277 return 1 - base .divide_no_nan (self .sum_of_squared_error , sst )
278+
279+
@flax.struct.dataclass
class SpearmanRankCorrelation(clu_metrics.Metric):
  r"""Computes the Spearman rank correlation coefficient.

  The Spearman rank correlation coefficient measures the monotonic relationship
  between two variables. It is defined as the Pearson correlation coefficient
  between the ranked variables, which is exactly what `compute()` evaluates:

  .. math::
      \rho = \frac{\sum_i (r_i - \bar{r})(s_i - \bar{s})}
                  {\sqrt{\sum_i (r_i - \bar{r})^2 \sum_i (s_i - \bar{s})^2}}

  where:
    - :math:`r_i` is the rank of the i-th prediction
    - :math:`s_i` is the rank of the i-th label
    - :math:`\bar{r}`, :math:`\bar{s}` are the mean ranks

  When no two observations share a rank (no ties), this is equivalent to the
  well-known closed form

  .. math::
      \rho = 1 - \frac{6 \sum d_i^2}{n(n^2 - 1)}

  where:
    - :math:`d_i` is the difference between the ranks of each observation
    - :math:`n` is the number of observations

  With ties the closed form is no longer exact; the Pearson-on-ranks form used
  here remains valid. Tied values are ranked by the default method of
  `jax.scipy.stats.rankdata`.

  This implementation accumulates all `predictions` and `labels` to compute the
  exact ranks upon calling `compute()`.

  .. warning::
      For very large datasets, this may lead to Out-of-Memory (OOM) errors.

  Attributes:
    predictions: Accumulated predictions, flattened to one dimension.
    labels: Accumulated labels, flattened to one dimension.
  """

  predictions: jax.Array
  labels: jax.Array

  @classmethod
  def empty(cls) -> 'SpearmanRankCorrelation':
    """Returns a metric that holds no observations (empty float32 arrays)."""
    return cls(
        predictions=jnp.array([], jnp.float32),
        labels=jnp.array([], jnp.float32),
    )

  @classmethod
  def from_model_output(
      cls,
      predictions: jax.Array,
      labels: jax.Array,
      **kwargs,
  ) -> 'SpearmanRankCorrelation':
    """Creates a metric from one batch of model outputs.

    Args:
      predictions: Predicted values of any shape; flattened before storing.
      labels: Ground-truth values with the same number of elements as
        `predictions`; flattened before storing.
      **kwargs: Ignored. Accepted so callers can pass extra model outputs.

    Returns:
      A `SpearmanRankCorrelation` holding the flattened batch.

    Raises:
      ValueError: If `predictions` and `labels` do not contain the same number
        of elements.
    """
    del kwargs
    # Array shapes are static, so this check is a plain Python comparison even
    # when the inputs are traced. Failing here prevents a size-1 input from
    # being silently broadcast against the other array in `compute()`.
    if predictions.size != labels.size:
      raise ValueError(
          'predictions and labels must contain the same number of elements, '
          f'got shapes {predictions.shape} and {labels.shape}.'
      )
    return cls(
        predictions=predictions.flatten(),
        labels=labels.flatten(),
    )

  def merge(
      self, other: 'SpearmanRankCorrelation'
  ) -> 'SpearmanRankCorrelation':
    """Returns a new metric holding the observations of `self` then `other`.

    Args:
      other: Another `SpearmanRankCorrelation` whose observations are appended.

    Returns:
      A new metric whose `predictions` and `labels` are the concatenation of
      the two operands' arrays. The stored arrays grow with every merge.
    """
    return type(self)(
        predictions=jnp.concatenate([self.predictions, other.predictions]),
        labels=jnp.concatenate([self.labels, other.labels]),
    )

  def compute(self) -> jax.Array:
    """Computes the Spearman rank correlation over all accumulated data.

    Returns:
      A scalar array. NaN (float32) when no observations have been
      accumulated; otherwise the Pearson correlation of the ranked
      predictions and ranked labels.
    """
    if self.predictions.size == 0:
      return jnp.array(jnp.nan, jnp.float32)

    # Ranks are computed over the full accumulated arrays, which is why the
    # raw values (rather than running sums) have to be stored.
    rank_preds = jax.scipy.stats.rankdata(self.predictions)
    rank_labels = jax.scipy.stats.rankdata(self.labels)

    def pearson_correlation(x, y):
      """Pearson correlation coefficient of two equally sized 1-D arrays."""
      mu_x = jnp.mean(x)
      mu_y = jnp.mean(y)
      xm, ym = x - mu_x, y - mu_y
      r_num = jnp.sum(xm * ym)
      r_den = jnp.sqrt(jnp.sum(xm**2) * jnp.sum(ym**2))
      # NOTE(review): `base.divide_no_nan` presumably yields 0 instead of NaN
      # when the denominator is 0 (e.g. all ranks equal) -- confirm against
      # its definition.
      return base.divide_no_nan(r_num, r_den)

    return pearson_correlation(rank_preds, rank_labels)
0 commit comments