5252
5353class RegressionMetricsTest (parameterized .TestCase ):
5454
55- def test_multiple_devices (self ):
56- """Test that metrax metrics work across multiple devices using R2 as an example."""
55+ def test_multiple_devices_jit (self ):
56+ """Test that metrax metrics work across multiple devices using jit and jax.Array."""
57+ # 1. Define the hardware mesh.
58+ devices = jax .devices ()
59+ mesh = jax .sharding .Mesh (devices , ('data' ,))
60+
61+ # 2. Define the sharding strategy (shard the first dimension across 'data'
62+ # axis).
63+ sharding = jax .sharding .NamedSharding (
64+ mesh , jax .sharding .PartitionSpec ('data' , None )
65+ )
66+
67+ # 3. Shard the global data across devices.
68+ y_pred = jax .device_put (OUTPUT_PREDS , sharding )
69+ y_true = jax .device_put (OUTPUT_LABELS , sharding )
70+
71+ # 4. Use jax.jit for the computation.
72+ # In SPMD mode, RSQUARED.from_model_output will perform global
73+ # reductions (sums) automatically across the shards.
74+ @jax .jit
75+ def compute_metric (logits , labels ):
76+ return metrax .RSQUARED .from_model_output (logits , labels )
77+
78+ metric = compute_metric (y_pred , y_true )
5779
80+ # 5. Verify against reference (Keras reference remains the same).
81+ keras_r2 = keras .metrics .R2Score ()
82+ for labels , logits in zip (OUTPUT_LABELS , OUTPUT_PREDS ):
83+ keras_r2 .update_state (
84+ labels [:, jnp .newaxis ],
85+ logits [:, jnp .newaxis ],
86+ )
87+ expected = keras_r2 .result ()
88+
89+ np .testing .assert_allclose (
90+ metric .compute (),
91+ expected ,
92+ rtol = 1e-05 ,
93+ atol = 1e-05 ,
94+ )
95+
96+ def test_multiple_devices_pmap (self ):
97+ """Test that metrax metrics work across multiple devices using R2 as an example."""
5898 def create_r2 (logits , labels ):
5999 """Creates a metrax RSQUARED metric given logits and labels."""
60100 return metrax .RSQUARED .from_model_output (logits , labels )
@@ -72,8 +112,17 @@ def sharded_r2(logits, labels):
72112
73113 y_pred = OUTPUT_PREDS
74114 y_true = OUTPUT_LABELS
75- metric = jax .jit (sharded_r2 )(y_pred , y_true )
76- metric = metric .reduce ()
115+
116+ # Calculate sharded R2 across devices.
117+ metric_sharded = sharded_r2 (y_pred , y_true )
118+
119+ # Move the metric results from devices to the host.
120+ cpu_device = jax .devices ('cpu' )[0 ]
121+ metric_on_host = jax .tree_util .tree_map (
122+ lambda x : jax .device_put (x , cpu_device ),
123+ metric_sharded ,
124+ )
125+ metric = metric_on_host .reduce ()
77126
78127 keras_r2 = keras .metrics .R2Score ()
79128 for labels , logits in zip (y_true , y_pred ):
0 commit comments