Skip to content

Commit 6506745

Browse files
Merge pull request #139 from eamag:main
PiperOrigin-RevId: 859822315
2 parents 65c5fc5 + 991d58e commit 6506745

File tree

6 files changed

+125
-3
lines changed

6 files changed

+125
-3
lines changed

.github/workflows/ci.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@ jobs:
44
ruff:
55
runs-on: ubuntu-latest
66
steps:
7-
- uses: actions/checkout@v4
7+
- uses: actions/checkout@v6
88
- name: Lint
99
uses: astral-sh/ruff-action@v2
1010
test:
1111
runs-on: ubuntu-latest
1212
steps:
13-
- uses: actions/checkout@v4
13+
- uses: actions/checkout@v6
1414
- name: Set up Python 3.12
15-
uses: actions/setup-python@v5
15+
uses: actions/setup-python@v6
1616
with:
1717
python-version: 3.12
1818
- name: Install dependencies

src/metrax/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
RougeL = nlp_metrics.RougeL
4646
RougeN = nlp_metrics.RougeN
4747
SNR = audio_metrics.SNR
48+
SpearmanRankCorrelation = regression_metrics.SpearmanRankCorrelation
4849
SSIM = image_metrics.SSIM
4950
WER = nlp_metrics.WER
5051

@@ -70,6 +71,7 @@
7071
"PSNR",
7172
"RMSE",
7273
"RSQUARED",
74+
"SpearmanRankCorrelation",
7375
"Recall",
7476
"RecallAtK",
7577
"RougeL",

src/metrax/nnx/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
RougeL = nnx_metrics.RougeL
4040
RougeN = nnx_metrics.RougeN
4141
SNR = nnx_metrics.SNR
42+
SpearmanRankCorrelation = nnx_metrics.SpearmanRankCorrelation
4243
SSIM = nnx_metrics.SSIM
4344
WER = nnx_metrics.WER
4445

@@ -68,6 +69,7 @@
6869
"RougeL",
6970
"RougeN",
7071
"SNR",
72+
"SpearmanRankCorrelation",
7173
"SSIM",
7274
"WER",
7375
]

src/metrax/nnx/nnx_metrics.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,13 @@ def __init__(self):
191191
super().__init__(metrax.SNR)
192192

193193

194+
class SpearmanRankCorrelation(NnxWrapper):
  """NNX wrapper around the Metrax ``SpearmanRankCorrelation`` metric."""

  def __init__(self) -> None:
    # Delegate to the generic wrapper, binding the clu-style metric class
    # it should manage (same pattern as the surrounding wrappers, e.g. SNR).
    super().__init__(metrax.SpearmanRankCorrelation)
199+
200+
194201
class SSIM(NnxWrapper):
195202
"""An NNX class for the Metrax metric SSIM."""
196203

src/metrax/regression_metrics.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,3 +275,78 @@ def compute(self) -> jax.Array:
275275
mean = base.divide_no_nan(self.total, self.count)
276276
sst = self.sum_of_squared_label - self.count * jnp.power(mean, 2)
277277
return 1 - base.divide_no_nan(self.sum_of_squared_error, sst)
278+
279+
280+
@flax.struct.dataclass
class SpearmanRankCorrelation(clu_metrics.Metric):
  r"""Computes the Spearman rank correlation coefficient.

  Spearman's :math:`\rho` measures the strength of a monotonic relationship
  between two variables; it equals the Pearson correlation computed on the
  ranks of the observations. For tie-free data this is equivalent to

  .. math::
      \rho = 1 - \frac{6 \sum d_i^2}{n(n^2 - 1)}

  where :math:`d_i` is the rank difference of observation :math:`i` and
  :math:`n` is the number of observations. (Ties are handled by
  ``rankdata``'s average-rank convention, so the Pearson-on-ranks form used
  here remains correct with ties.)

  All `predictions` and `labels` are accumulated so that exact ranks can be
  computed in `compute()`.

  .. warning::
      For very large datasets, this may lead to Out-of-Memory (OOM) errors.

  Attributes:
    predictions: Accumulated predictions.
    labels: Accumulated labels.
  """

  predictions: jax.Array
  labels: jax.Array

  @classmethod
  def empty(cls) -> 'SpearmanRankCorrelation':
    """Returns a metric holding no observations."""
    nothing = jnp.array([], jnp.float32)
    return cls(predictions=nothing, labels=nothing)

  @classmethod
  def from_model_output(
      cls,
      predictions: jax.Array,
      labels: jax.Array,
      **kwargs,
  ) -> 'SpearmanRankCorrelation':
    """Builds a metric update from one batch of model output.

    Args:
      predictions: Model predictions of any shape; flattened on entry.
      labels: Ground-truth values of any shape; flattened on entry.
      **kwargs: Ignored, accepted for interface compatibility.

    Returns:
      A `SpearmanRankCorrelation` holding this batch's flattened values.
    """
    del kwargs  # Unused.
    return cls(
        predictions=jnp.ravel(predictions),
        labels=jnp.ravel(labels),
    )

  def merge(
      self, other: 'SpearmanRankCorrelation'
  ) -> 'SpearmanRankCorrelation':
    """Concatenates the accumulated values of `self` and `other`."""
    return type(self)(
        predictions=jnp.concatenate((self.predictions, other.predictions)),
        labels=jnp.concatenate((self.labels, other.labels)),
    )

  def compute(self) -> jax.Array:
    """Returns Spearman's rho over everything accumulated so far.

    Returns:
      A scalar `jax.Array`; NaN when no values have been accumulated.
    """
    # `.size` is static, so this guard is resolved at trace time.
    if self.predictions.size == 0:
      return jnp.array(jnp.nan, jnp.float32)

    pred_ranks = jax.scipy.stats.rankdata(self.predictions)
    label_ranks = jax.scipy.stats.rankdata(self.labels)

    # Pearson correlation of the ranks == Spearman correlation.
    pred_dev = pred_ranks - jnp.mean(pred_ranks)
    label_dev = label_ranks - jnp.mean(label_ranks)
    covariance = jnp.sum(pred_dev * label_dev)
    normalizer = jnp.sqrt(jnp.sum(pred_dev**2) * jnp.sum(label_dev**2))
    # divide_no_nan keeps the degenerate zero-variance case finite (0).
    return base.divide_no_nan(covariance, normalizer)

src/metrax/regression_metrics_test.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,11 @@
2424
import keras
2525
import metrax
2626
import numpy as np
27+
import scipy
2728
from sklearn import metrics as sklearn_metrics
2829

30+
spearmanr = scipy.stats.spearmanr
31+
2932
np.random.seed(42)
3033
BATCHES = 4
3134
BATCH_SIZE = 8
@@ -321,6 +324,39 @@ def test_rsquared(self, y_true, y_pred, sample_weights):
321324
atol=atol,
322325
)
323326

327+
@parameterized.named_parameters(
328+
('basic_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16),
329+
('basic_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32),
330+
('basic_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16),
331+
('batch_size_one', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1),
332+
)
333+
def test_spearman(self, y_true, y_pred):
334+
"""Test that `SpearmanRankCorrelation` Metric computes correct values."""
335+
y_true = y_true.astype(y_pred.dtype)
336+
y_pred = y_pred.astype(y_true.dtype)
337+
338+
metric = None
339+
for labels, logits in zip(y_true, y_pred):
340+
update = metrax.SpearmanRankCorrelation.from_model_output(
341+
predictions=logits,
342+
labels=labels,
343+
)
344+
metric = update if metric is None else metric.merge(update)
345+
346+
expected, _ = spearmanr(
347+
y_true.astype('float32').flatten(),
348+
y_pred.astype('float32').flatten(),
349+
)
350+
# Use lower tolerance for lower precision dtypes.
351+
rtol = 1e-2 if y_true.dtype in (jnp.float16, jnp.bfloat16) else 1e-05
352+
atol = 1e-2 if y_true.dtype in (jnp.float16, jnp.bfloat16) else 1e-05
353+
np.testing.assert_allclose(
354+
metric.compute(),
355+
expected,
356+
rtol=rtol,
357+
atol=atol,
358+
)
359+
324360

325361
if __name__ == '__main__':
326362
os.environ['XLA_FLAGS'] = (

0 commit comments

Comments
 (0)