Skip to content

Commit f079045

Browse files
committed
add multi device test
1 parent ec41518 commit f079045

File tree

1 file changed

+39
-0
lines changed

1 file changed

+39
-0
lines changed

src/metrax/metrics_test.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@
1616

1717
from absl.testing import absltest
1818
from absl.testing import parameterized
19+
import jax
1920
import jax.numpy as jnp
2021
import keras_hub
2122
import metrax
2223
import numpy as np
24+
import os
2325
from sklearn import metrics as sklearn_metrics
2426

2527
np.random.seed(42)
@@ -172,6 +174,40 @@ def test_perplexity_empty(self):
172174
self.assertEqual(m.aggregate_crossentropy, jnp.array(0, jnp.float32))
173175
self.assertEqual(m.num_samples, jnp.array(0, jnp.float32))
174176

177+
def test_multiple_devices(self):
178+
"""Test that metrax metrics work across multiple devices using R2 as an example."""
179+
180+
def create_r2(logits, labels):
181+
"""Creates a metrax R2 metric given logits and labels."""
182+
return metrax.RSQUARED.from_model_output(logits, labels)
183+
184+
def sharded_r2(logits, labels):
185+
"""Calculates sharded MSE across devices."""
186+
num_devices = jax.device_count()
187+
188+
shard_size = logits.shape[0] // num_devices
189+
sharded_logits = logits.reshape(num_devices, shard_size, logits.shape[-1])
190+
sharded_labels = labels.reshape(num_devices, shard_size, labels.shape[-1])
191+
192+
r2_for_devices = jax.pmap(create_r2)(sharded_logits, sharded_labels)
193+
return r2_for_devices
194+
195+
y_pred = OUTPUT_PREDS
196+
y_true = OUTPUT_LABELS
197+
metric = jax.jit(sharded_r2)(y_pred, y_true)
198+
metric = metric.reduce()
199+
200+
expected = sklearn_metrics.r2_score(
201+
y_true.flatten(),
202+
y_pred.flatten(),
203+
)
204+
np.testing.assert_allclose(
205+
metric.compute(),
206+
expected,
207+
rtol=1e-05,
208+
atol=1e-05,
209+
)
210+
175211
@parameterized.named_parameters(
176212
('basic', OUTPUT_LABELS, OUTPUT_PREDS, 0.5),
177213
('high_threshold', OUTPUT_LABELS, OUTPUT_PREDS, 0.7),
@@ -415,4 +451,7 @@ def test_perplexity(self, y_true, y_pred, sample_weights):
415451

416452

417453
if __name__ == '__main__':
454+
os.environ['XLA_FLAGS'] = (
455+
'--xla_force_host_platform_device_count=4' # Use 4 CPU devices
456+
)
418457
absltest.main()

0 commit comments

Comments
 (0)