Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions src/metrax/metrics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@

from absl.testing import absltest
from absl.testing import parameterized
import jax
import jax.numpy as jnp
import keras_hub
import metrax
import numpy as np
import os
from sklearn import metrics as sklearn_metrics

np.random.seed(42)
Expand Down Expand Up @@ -172,6 +174,40 @@ def test_perplexity_empty(self):
self.assertEqual(m.aggregate_crossentropy, jnp.array(0, jnp.float32))
self.assertEqual(m.num_samples, jnp.array(0, jnp.float32))

def test_multiple_devices(self):
"""Test that metrax metrics work across multiple devices using R2 as an example."""

def create_r2(logits, labels):
"""Creates a metrax R2 metric given logits and labels."""
return metrax.RSQUARED.from_model_output(logits, labels)

def sharded_r2(logits, labels):
"""Calculates sharded MSE across devices."""
num_devices = jax.device_count()

shard_size = logits.shape[0] // num_devices
sharded_logits = logits.reshape(num_devices, shard_size, logits.shape[-1])
sharded_labels = labels.reshape(num_devices, shard_size, labels.shape[-1])

r2_for_devices = jax.pmap(create_r2)(sharded_logits, sharded_labels)
return r2_for_devices

y_pred = OUTPUT_PREDS
y_true = OUTPUT_LABELS
metric = jax.jit(sharded_r2)(y_pred, y_true)
metric = metric.reduce()

expected = sklearn_metrics.r2_score(
y_true.flatten(),
y_pred.flatten(),
)
np.testing.assert_allclose(
metric.compute(),
expected,
rtol=1e-05,
atol=1e-05,
)

@parameterized.named_parameters(
('basic', OUTPUT_LABELS, OUTPUT_PREDS, 0.5),
('high_threshold', OUTPUT_LABELS, OUTPUT_PREDS, 0.7),
Expand Down Expand Up @@ -415,4 +451,7 @@ def test_perplexity(self, y_true, y_pred, sample_weights):


if __name__ == '__main__':
os.environ['XLA_FLAGS'] = (
'--xla_force_host_platform_device_count=4' # Use 4 CPU devices
)
absltest.main()