Skip to content

Commit a5799f5

Browse files
authored
migrate metrics_test.py from sklearn to keras (#16)
1 parent f80bf7d commit a5799f5

File tree

1 file changed

+87
-145
lines changed

1 file changed

+87
-145
lines changed

src/metrax/metrics_test.py

Lines changed: 87 additions & 145 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,13 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
"""Tests for metrax.metrax."""
15+
"""Tests for metrax metrics."""
1616

1717
from absl.testing import absltest
1818
from absl.testing import parameterized
1919
import jax
2020
import jax.numpy as jnp
21+
import keras
2122
import keras_hub
2223
import metrax
2324
import numpy as np
@@ -47,78 +48,6 @@
4748

4849
class MetricsTest(parameterized.TestCase):
4950

50-
def setUp(self):
51-
super().setUp()
52-
53-
# TODO(jeffcarp): Merge these into generated fixtures.
54-
self.model_outputs = (
55-
dict(
56-
logits=jnp.array(
57-
[0.34, 0.89, 0.12, 0.67, 0.98, 0.23, 0.56, 0.71, 0.45, 0.08]
58-
),
59-
labels=jnp.array([1, 0, 1, 1, 0, 0, 1, 0, 1, 1]),
60-
),
61-
dict(
62-
logits=jnp.array(
63-
[0.23, 0.89, 0.57, 0.11, 0.99, 0.38, 0.76, 0.05, 0.62, 0.44]
64-
),
65-
labels=jnp.array([0, 0, 1, 0, 1, 1, 0, 1, 0, 0]),
66-
),
67-
dict(
68-
logits=jnp.array(
69-
[0.67, 0.21, 0.95, 0.03, 0.88, 0.51, 0.34, 0.79, 0.15, 0.42]
70-
),
71-
labels=jnp.array([1, 1, 0, 1, 0, 1, 1, 0, 0, 1]),
72-
),
73-
dict(
74-
logits=jnp.array(
75-
[0.91, 0.37, 0.18, 0.75, 0.59, 0.02, 0.83, 0.26, 0.64, 0.48]
76-
),
77-
labels=jnp.array([0, 1, 1, 0, 0, 1, 0, 1, 1, 0]),
78-
),
79-
)
80-
self.model_outputs_batch_size_one = (
81-
dict(
82-
logits=jnp.array([[0.32]]),
83-
labels=jnp.array([1]),
84-
),
85-
dict(
86-
logits=jnp.array([[0.74]]),
87-
labels=jnp.array([1]),
88-
),
89-
dict(
90-
logits=jnp.array([[0.86]]),
91-
labels=jnp.array([1]),
92-
),
93-
dict(
94-
logits=jnp.array([[0.21]]),
95-
labels=jnp.array([1]),
96-
),
97-
)
98-
self.sample_weights = jnp.array([0.5, 1, 0, 0, 0, 0, 0, 0, 0, 0])
99-
100-
def compute_aucpr(self, model_outputs, sample_weights=None):
101-
metric = None
102-
for model_output in model_outputs:
103-
update = metrax.AUCPR.from_model_output(
104-
predictions=model_output.get('logits'),
105-
labels=model_output.get('labels'),
106-
sample_weights=sample_weights,
107-
)
108-
metric = update if metric is None else metric.merge(update)
109-
return metric.compute()
110-
111-
def compute_aucroc(self, model_outputs, sample_weights=None):
112-
metric = None
113-
for model_output in model_outputs:
114-
update = metrax.AUCROC.from_model_output(
115-
predictions=model_output.get('logits'),
116-
labels=model_output.get('labels'),
117-
sample_weights=sample_weights,
118-
)
119-
metric = update if metric is None else metric.merge(update)
120-
return metric.compute()
121-
12251
def test_mse_empty(self):
12352
"""Tests the `empty` method of the `MSE` class."""
12453
m = metrax.MSE.empty()
@@ -197,10 +126,13 @@ def sharded_r2(logits, labels):
197126
metric = jax.jit(sharded_r2)(y_pred, y_true)
198127
metric = metric.reduce()
199128

200-
expected = sklearn_metrics.r2_score(
201-
y_true.flatten(),
202-
y_pred.flatten(),
203-
)
129+
keras_r2 = keras.metrics.R2Score()
130+
for labels, logits in zip(y_true, y_pred):
131+
keras_r2.update_state(
132+
labels[:, jnp.newaxis],
133+
logits[:, jnp.newaxis],
134+
)
135+
expected = keras_r2.result()
204136
np.testing.assert_allclose(
205137
metric.compute(),
206138
expected,
@@ -215,10 +147,12 @@ def sharded_r2(logits, labels):
215147
('batch_size_one', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, 0.5),
216148
)
217149
def test_precision(self, y_true, y_pred, threshold):
218-
"""Test that Precision metric computes correct values."""
150+
"""Test that `Precision` metric computes correct values."""
219151
y_true = y_true.reshape((-1,))
220152
y_pred = jnp.where(y_pred.reshape((-1,)) >= threshold, 1, 0)
221-
expected = sklearn_metrics.precision_score(y_true, y_pred)
153+
keras_precision = keras.metrics.Precision(thresholds=threshold)
154+
keras_precision.update_state(y_true, y_pred)
155+
expected = keras_precision.result()
222156

223157
metric = None
224158
for logits, labels in zip(y_pred, y_true):
@@ -241,10 +175,12 @@ def test_precision(self, y_true, y_pred, threshold):
241175
('batch_size_one', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, 0.5),
242176
)
243177
def test_recall(self, y_true, y_pred, threshold):
244-
"""Test that Recall metric computes correct values."""
178+
"""Test that `Recall` metric computes correct values."""
245179
y_true = y_true.reshape((-1,))
246180
y_pred = jnp.where(y_pred.reshape((-1,)) >= threshold, 1, 0)
247-
expected = sklearn_metrics.recall_score(y_true, y_pred)
181+
keras_recall = keras.metrics.Recall(thresholds=threshold)
182+
keras_recall.update_state(y_true, y_pred)
183+
expected = keras_recall.result()
248184

249185
metric = None
250186
for logits, labels in zip(y_pred, y_true):
@@ -260,64 +196,64 @@ def test_recall(self, y_true, y_pred, threshold):
260196
expected,
261197
)
262198

263-
def test_aucpr(self):
264-
"""Test that AUC-PR Metric computes correct values."""
265-
np.testing.assert_allclose(
266-
self.compute_aucpr(self.model_outputs),
267-
jnp.array(0.41513795, dtype=jnp.float32),
268-
)
199+
@parameterized.named_parameters(
200+
('basic', OUTPUT_LABELS, OUTPUT_PREDS, None),
201+
('batch_size_one', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, None),
202+
('weighted', OUTPUT_LABELS, OUTPUT_PREDS, SAMPLE_WEIGHTS),
203+
)
204+
def test_aucpr(self, y_true, y_pred, sample_weights):
205+
"""Test that `AUC-PR` Metric computes correct values."""
206+
if sample_weights is None:
207+
sample_weights = np.ones_like(y_true)
269208

270-
def test_aucpr_with_sample_weight(self):
271-
"""Test that AUC-PR Metric computes correct values when using sample weights."""
272-
np.testing.assert_allclose(
273-
self.compute_aucpr(self.model_outputs, self.sample_weights),
274-
jnp.array(0.32785615, dtype=jnp.float32),
275-
)
209+
metric = None
210+
for labels, logits, weights in zip(y_true, y_pred, sample_weights):
211+
update = metrax.AUCPR.from_model_output(
212+
predictions=logits,
213+
labels=labels,
214+
sample_weights=weights,
215+
)
216+
metric = update if metric is None else metric.merge(update)
276217

277-
def test_aucpr_with_batch_size_one(self):
278-
"""Test that AUC-PR Metric computes correct values with batch size one."""
218+
keras_aucpr = keras.metrics.AUC(curve='PR')
219+
for labels, logits, weights in zip(y_true, y_pred, sample_weights):
220+
keras_aucpr.update_state(labels, logits, sample_weight=weights)
221+
expected = keras_aucpr.result()
279222
np.testing.assert_allclose(
280-
self.compute_aucpr(self.model_outputs_batch_size_one),
281-
jnp.array(1.0, dtype=jnp.float32),
223+
metric.compute(),
224+
expected,
225+
rtol=1e-07,
226+
atol=1e-07,
282227
)
283228

284-
def test_aucroc(self):
285-
"""Test that AUC-ROC Metric computes correct values."""
286-
# Concatenate logits and labels
287-
all_logits = jnp.concatenate(
288-
[model_output['logits'] for model_output in self.model_outputs]
289-
)
290-
all_labels = jnp.concatenate(
291-
[model_output['labels'] for model_output in self.model_outputs]
292-
)
293-
np.testing.assert_allclose(
294-
self.compute_aucroc(self.model_outputs),
295-
sklearn_metrics.roc_auc_score(all_labels, all_logits),
296-
)
229+
@parameterized.named_parameters(
230+
('basic', OUTPUT_LABELS, OUTPUT_PREDS, None),
231+
('batch_size_one', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, None),
232+
('weighted', OUTPUT_LABELS, OUTPUT_PREDS, SAMPLE_WEIGHTS),
233+
)
234+
def test_aucroc(self, y_true, y_pred, sample_weights):
235+
"""Test that `AUC-ROC` Metric computes correct values."""
236+
if sample_weights is None:
237+
sample_weights = np.ones_like(y_true)
297238

298-
def test_aucroc_with_sample_weight(self):
299-
"""Test that AUC-ROC Metric computes correct values when using sample weights."""
300-
# Concatenate logits and labels
301-
all_logits = jnp.concatenate(
302-
[model_output['logits'] for model_output in self.model_outputs]
303-
)
304-
all_labels = jnp.concatenate(
305-
[model_output['labels'] for model_output in self.model_outputs]
306-
)
307-
sample_weights = jnp.concatenate(
308-
[self.sample_weights] * len(self.model_outputs)
309-
)
239+
metric = None
240+
for labels, logits, weights in zip(y_true, y_pred, sample_weights):
241+
update = metrax.AUCROC.from_model_output(
242+
predictions=logits,
243+
labels=labels,
244+
sample_weights=weights,
245+
)
246+
metric = update if metric is None else metric.merge(update)
247+
248+
keras_aucroc = keras.metrics.AUC(curve='ROC')
249+
for labels, logits, weights in zip(y_true, y_pred, sample_weights):
250+
keras_aucroc.update_state(labels, logits, sample_weight=weights)
251+
expected = keras_aucroc.result()
310252
np.testing.assert_allclose(
311-
jnp.array(
312-
self.compute_aucroc(self.model_outputs, self.sample_weights),
313-
dtype=jnp.float16,
314-
),
315-
jnp.array(
316-
sklearn_metrics.roc_auc_score(
317-
all_labels, all_logits, sample_weight=sample_weights
318-
),
319-
dtype=jnp.float16,
320-
),
253+
metric.compute(),
254+
expected,
255+
rtol=1e-07,
256+
atol=1e-07,
321257
)
322258

323259
@parameterized.named_parameters(
@@ -326,6 +262,7 @@ def test_aucroc_with_sample_weight(self):
326262
('weighted', OUTPUT_LABELS, OUTPUT_PREDS, SAMPLE_WEIGHTS),
327263
)
328264
def test_mse(self, y_true, y_pred, sample_weights):
265+
"""Test that `MSE` Metric computes correct values."""
329266
if sample_weights is None:
330267
sample_weights = np.ones_like(y_true)
331268

@@ -338,6 +275,8 @@ def test_mse(self, y_true, y_pred, sample_weights):
338275
)
339276
metric = update if metric is None else metric.merge(update)
340277

278+
# TODO(jiwonshin): Use `keras.metrics.MeanSquaredError` once it supports
279+
# sample weights.
341280
expected = sklearn_metrics.mean_squared_error(
342281
y_true.flatten(),
343282
y_pred.flatten(),
@@ -356,6 +295,7 @@ def test_mse(self, y_true, y_pred, sample_weights):
356295
('weighted', OUTPUT_LABELS, OUTPUT_PREDS, SAMPLE_WEIGHTS),
357296
)
358297
def test_rmse(self, y_true, y_pred, sample_weights):
298+
"""Test that `RMSE` Metric computes correct values."""
359299
if sample_weights is None:
360300
sample_weights = np.ones_like(y_true)
361301

@@ -368,13 +308,10 @@ def test_rmse(self, y_true, y_pred, sample_weights):
368308
)
369309
metric = update if metric is None else metric.merge(update)
370310

371-
# `sklearn_metrics.root_mean_squared_error` is not available.
372-
expected = jnp.sqrt(
373-
jnp.average(
374-
jnp.square(y_pred.flatten() - y_true.flatten()),
375-
weights=sample_weights.flatten(),
376-
),
377-
)
311+
keras_rmse = keras.metrics.RootMeanSquaredError()
312+
for labels, logits, weights in zip(y_true, y_pred, sample_weights):
313+
keras_rmse.update_state(labels, logits, sample_weight=weights)
314+
expected = keras_rmse.result()
378315
np.testing.assert_allclose(
379316
metric.compute(),
380317
expected,
@@ -388,6 +325,7 @@ def test_rmse(self, y_true, y_pred, sample_weights):
388325
('weighted', OUTPUT_LABELS, OUTPUT_PREDS, SAMPLE_WEIGHTS),
389326
)
390327
def test_rsquared(self, y_true, y_pred, sample_weights):
328+
"""Test that `RSQUARED` Metric computes correct values."""
391329
if sample_weights is None:
392330
sample_weights = np.ones_like(y_true)
393331

@@ -400,11 +338,14 @@ def test_rsquared(self, y_true, y_pred, sample_weights):
400338
)
401339
metric = update if metric is None else metric.merge(update)
402340

403-
expected = sklearn_metrics.r2_score(
404-
y_true.flatten(),
405-
y_pred.flatten(),
406-
sample_weight=sample_weights.flatten(),
407-
)
341+
keras_r2 = keras.metrics.R2Score()
342+
for labels, logits, weights in zip(y_true, y_pred, sample_weights):
343+
keras_r2.update_state(
344+
labels[:, jnp.newaxis],
345+
logits[:, jnp.newaxis],
346+
sample_weight=weights[:, jnp.newaxis],
347+
)
348+
expected = keras_r2.result()
408349
np.testing.assert_allclose(
409350
metric.compute(),
410351
expected,
@@ -427,6 +368,7 @@ def test_rsquared(self, y_true, y_pred, sample_weights):
427368
),
428369
)
429370
def test_perplexity(self, y_true, y_pred, sample_weights):
371+
"""Test that `Perplexity` Metric computes correct values."""
430372
keras_metric = keras_hub.metrics.Perplexity()
431373
metrax_metric = None
432374
for index, (labels, logits) in enumerate(zip(y_true, y_pred)):
@@ -454,4 +396,4 @@ def test_perplexity(self, y_true, y_pred, sample_weights):
454396
os.environ['XLA_FLAGS'] = (
455397
'--xla_force_host_platform_device_count=4' # Use 4 CPU devices
456398
)
457-
absltest.main()
399+
absltest.main()

0 commit comments

Comments (0)