Skip to content

Commit 65c5fc5

Browse files
hyunn9973copybara-github
authored andcommitted
Fix a bug on regression_metrics_test.
PiperOrigin-RevId: 858065972
1 parent a579122 commit 65c5fc5

File tree

1 file changed

+53
-4
lines changed

1 file changed

+53
-4
lines changed

src/metrax/regression_metrics_test.py

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,49 @@
5252

5353
class RegressionMetricsTest(parameterized.TestCase):
5454

55-
def test_multiple_devices(self):
56-
"""Test that metrax metrics work across multiple devices using R2 as an example."""
55+
def test_multiple_devices_jit(self):
56+
"""Test that metrax metrics work across multiple devices using jit and jax.Array."""
57+
# 1. Define the hardware mesh.
58+
devices = jax.devices()
59+
mesh = jax.sharding.Mesh(devices, ('data',))
60+
61+
# 2. Define the sharding strategy (shard the first dimension across 'data'
62+
# axis).
63+
sharding = jax.sharding.NamedSharding(
64+
mesh, jax.sharding.PartitionSpec('data', None)
65+
)
66+
67+
# 3. Shard the global data across devices.
68+
y_pred = jax.device_put(OUTPUT_PREDS, sharding)
69+
y_true = jax.device_put(OUTPUT_LABELS, sharding)
70+
71+
# 4. Use jax.jit for the computation.
72+
# In SPMD mode, RSQUARED.from_model_output will perform global
73+
# reductions (sums) automatically across the shards.
74+
@jax.jit
75+
def compute_metric(logits, labels):
76+
return metrax.RSQUARED.from_model_output(logits, labels)
77+
78+
metric = compute_metric(y_pred, y_true)
5779

80+
# 5. Verify against reference (Keras reference remains the same).
81+
keras_r2 = keras.metrics.R2Score()
82+
for labels, logits in zip(OUTPUT_LABELS, OUTPUT_PREDS):
83+
keras_r2.update_state(
84+
labels[:, jnp.newaxis],
85+
logits[:, jnp.newaxis],
86+
)
87+
expected = keras_r2.result()
88+
89+
np.testing.assert_allclose(
90+
metric.compute(),
91+
expected,
92+
rtol=1e-05,
93+
atol=1e-05,
94+
)
95+
96+
def test_multiple_devices_pmap(self):
97+
"""Test that metrax metrics work across multiple devices using R2 as an example."""
5898
def create_r2(logits, labels):
5999
"""Creates a metrax RSQUARED metric given logits and labels."""
60100
return metrax.RSQUARED.from_model_output(logits, labels)
@@ -72,8 +112,17 @@ def sharded_r2(logits, labels):
72112

73113
y_pred = OUTPUT_PREDS
74114
y_true = OUTPUT_LABELS
75-
metric = jax.jit(sharded_r2)(y_pred, y_true)
76-
metric = metric.reduce()
115+
116+
# Calculate sharded R2 across devices.
117+
metric_sharded = sharded_r2(y_pred, y_true)
118+
119+
# Move the metric results from devices to the host.
120+
cpu_device = jax.devices('cpu')[0]
121+
metric_on_host = jax.tree_util.tree_map(
122+
lambda x: jax.device_put(x, cpu_device),
123+
metric_sharded,
124+
)
125+
metric = metric_on_host.reduce()
77126

78127
keras_r2 = keras.metrics.R2Score()
79128
for labels, logits in zip(y_true, y_pred):

0 commit comments

Comments
 (0)