1616
1717from absl .testing import absltest
1818from absl .testing import parameterized
19+ import jax
1920import jax .numpy as jnp
2021import keras_hub
2122import metrax
2223import numpy as np
24+ import os
2325from sklearn import metrics as sklearn_metrics
2426
2527np .random .seed (42 )
@@ -172,6 +174,40 @@ def test_perplexity_empty(self):
172174 self .assertEqual (m .aggregate_crossentropy , jnp .array (0 , jnp .float32 ))
173175 self .assertEqual (m .num_samples , jnp .array (0 , jnp .float32 ))
174176
def test_multiple_devices(self):
  """Tests that per-device metrax metrics reduce to the global metric value.

  R2 is used as the example metric: one metric state is built per shard of
  the batch with `jax.pmap`, the per-shard states are merged with `reduce()`,
  and the result is compared against scikit-learn's `r2_score` computed on
  the flattened, unsharded arrays.
  """

  def create_r2(logits, labels):
    """Creates a metrax R2 metric state from one shard of logits and labels."""
    return metrax.RSQUARED.from_model_output(logits, labels)

  def sharded_r2(logits, labels):
    """Builds one R2 metric state per shard of the leading (batch) axis."""
    batch_size = logits.shape[0]
    # pmap maps over *local* devices and needs equally sized shards, so use
    # the largest local device count that divides the batch evenly instead
    # of failing in `reshape` when the batch is not a multiple of the
    # device count (or is smaller than it).
    max_shards = min(jax.local_device_count(), batch_size)
    num_shards = next(
        (d for d in range(max_shards, 0, -1) if batch_size % d == 0), 1
    )
    shard_size = batch_size // num_shards

    sharded_logits = logits.reshape(num_shards, shard_size, logits.shape[-1])
    sharded_labels = labels.reshape(num_shards, shard_size, labels.shape[-1])

    return jax.pmap(create_r2)(sharded_logits, sharded_labels)

  y_pred = OUTPUT_PREDS
  y_true = OUTPUT_LABELS
  # Call the pmapped function directly: wrapping a pmap in jax.jit
  # (jit-of-pmap) makes JAX gather the sharded data onto a single device
  # and emits a warning, which defeats the purpose of this test.
  metric = sharded_r2(y_pred, y_true).reduce()

  expected = sklearn_metrics.r2_score(
      y_true.flatten(),
      y_pred.flatten(),
  )
  np.testing.assert_allclose(
      metric.compute(),
      expected,
      rtol=1e-05,
      atol=1e-05,
  )
210+
175211 @parameterized .named_parameters (
176212 ('basic' , OUTPUT_LABELS , OUTPUT_PREDS , 0.5 ),
177213 ('high_threshold' , OUTPUT_LABELS , OUTPUT_PREDS , 0.7 ),
@@ -415,4 +451,7 @@ def test_perplexity(self, y_true, y_pred, sample_weights):
415451
416452
if __name__ == '__main__':
  # Expose 4 virtual CPU devices so the multi-device test actually shards.
  # Append to any XLA_FLAGS already present rather than overwriting them, and
  # respect an explicitly requested host device count.
  _existing_xla_flags = os.environ.get('XLA_FLAGS', '')
  if 'xla_force_host_platform_device_count' not in _existing_xla_flags:
    os.environ['XLA_FLAGS'] = (
        f'{_existing_xla_flags} --xla_force_host_platform_device_count=4'
    ).strip()
  absltest.main()
0 commit comments