diff --git a/src/metrax/metrics_test.py b/src/metrax/metrics_test.py
index 85f6d08..fb056dc 100644
--- a/src/metrax/metrics_test.py
+++ b/src/metrax/metrics_test.py
@@ -16,10 +16,12 @@
 
 from absl.testing import absltest
 from absl.testing import parameterized
+import jax
 import jax.numpy as jnp
 import keras_hub
 import metrax
 import numpy as np
+import os
 from sklearn import metrics as sklearn_metrics
 
 np.random.seed(42)
@@ -172,6 +174,46 @@ def test_perplexity_empty(self):
     self.assertEqual(m.aggregate_crossentropy, jnp.array(0, jnp.float32))
     self.assertEqual(m.num_samples, jnp.array(0, jnp.float32))
 
+  def test_multiple_devices(self):
+    """Tests that metrics work across multiple devices, using R2 as example.
+
+    The inputs are split evenly along the batch axis into one shard per
+    device, `jax.pmap` builds one RSQUARED metric per shard, and the
+    per-device metrics are merged with `reduce()` before comparing against
+    sklearn's `r2_score` on the unsharded data.
+    """
+    num_devices = jax.device_count()
+    y_pred = OUTPUT_PREDS
+    y_true = OUTPUT_LABELS
+    # The reshape below needs the batch to split evenly across devices.
+    self.assertEqual(y_pred.shape[0] % num_devices, 0)
+
+    def create_r2(logits, labels):
+      """Creates a metrax RSQUARED metric given logits and labels."""
+      return metrax.RSQUARED.from_model_output(logits, labels)
+
+    def sharded_r2(logits, labels):
+      """Builds one RSQUARED metric per device from evenly split inputs."""
+      shard_size = logits.shape[0] // num_devices
+      sharded_logits = logits.reshape(num_devices, shard_size, logits.shape[-1])
+      sharded_labels = labels.reshape(num_devices, shard_size, labels.shape[-1])
+      return jax.pmap(create_r2)(sharded_logits, sharded_labels)
+
+    # Call the pmapped function directly: JAX warns that wrapping a pmap in
+    # an outer jit causes inefficient data movement.
+    metric = sharded_r2(y_pred, y_true).reduce()
+
+    expected = sklearn_metrics.r2_score(
+        y_true.flatten(),
+        y_pred.flatten(),
+    )
+    np.testing.assert_allclose(
+        metric.compute(),
+        expected,
+        rtol=1e-05,
+        atol=1e-05,
+    )
+
   @parameterized.named_parameters(
       ('basic', OUTPUT_LABELS, OUTPUT_PREDS, 0.5),
       ('high_threshold', OUTPUT_LABELS, OUTPUT_PREDS, 0.7),
@@ -415,4 +457,8 @@ def test_perplexity(self, y_true, y_pred, sample_weights):
 
 
 if __name__ == '__main__':
+  # Expose 4 virtual CPU devices so the multi-device test actually shards.
+  # NOTE(review): this only takes effect if it is set before JAX initializes
+  # its backend, and only when this file is run as a script.
+  os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4'
   absltest.main()
\ No newline at end of file