Skip to content

Commit 75d6f7e

Browse files
committed
split metrics_test into module tests
1 parent 6fb1bf3 commit 75d6f7e

File tree

4 files changed

+545
-463
lines changed

4 files changed

+545
-463
lines changed
Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for metrax classaification metrics."""
16+
17+
from absl.testing import absltest
18+
from absl.testing import parameterized
19+
import jax.numpy as jnp
20+
import keras
21+
import metrax
22+
import numpy as np
23+
24+
# Seed numpy so the randomly generated fixtures below are reproducible
# across test runs.
np.random.seed(42)
BATCHES = 4
BATCH_SIZE = 8
# Binary ground-truth labels in {0.0, 1.0}, shape (BATCHES, BATCH_SIZE).
OUTPUT_LABELS = np.random.randint(
    0,
    2,
    size=(BATCHES, BATCH_SIZE),
).astype(np.float32)
# Uniform [0, 1) prediction scores, plus casts to the dtypes under test.
OUTPUT_PREDS = np.random.uniform(size=(BATCHES, BATCH_SIZE))
OUTPUT_PREDS_F16 = OUTPUT_PREDS.astype(jnp.float16)
OUTPUT_PREDS_F32 = OUTPUT_PREDS.astype(jnp.float32)
OUTPUT_PREDS_BF16 = OUTPUT_PREDS.astype(jnp.bfloat16)
# Batch-size-one variants exercise the degenerate (BATCHES, 1) shape.
OUTPUT_LABELS_BS1 = np.random.randint(
    0,
    2,
    size=(BATCHES, 1),
).astype(np.float32)
OUTPUT_PREDS_BS1 = np.random.uniform(size=(BATCHES, 1)).astype(np.float32)
# Per-example weights; only the first two examples of each batch contribute
# (weights 0.5 and 1), the rest are zeroed out.
SAMPLE_WEIGHTS = np.tile(
    [0.5, 1, 0, 0, 0, 0, 0, 0],
    (BATCHES, 1),
).astype(np.float32)
47+
48+
class ClassificationMetricsTest(parameterized.TestCase):
  """Validates metrax classification metrics against Keras references.

  Each value-test accumulates per-batch metric updates with `merge` and
  compares the final `compute()` result against the equivalent Keras metric
  fed the same data.
  """

  def test_precision_empty(self):
    """Tests the `empty` method of the `Precision` class."""
    m = metrax.Precision.empty()
    self.assertEqual(m.true_positives, jnp.array(0, jnp.float32))
    self.assertEqual(m.false_positives, jnp.array(0, jnp.float32))

  def test_recall_empty(self):
    """Tests the `empty` method of the `Recall` class."""
    m = metrax.Recall.empty()
    self.assertEqual(m.true_positives, jnp.array(0, jnp.float32))
    self.assertEqual(m.false_negatives, jnp.array(0, jnp.float32))

  def test_aucpr_empty(self):
    """Tests the `empty` method of the `AUCPR` class."""
    m = metrax.AUCPR.empty()
    self.assertEqual(m.true_positives, jnp.array(0, jnp.float32))
    self.assertEqual(m.false_positives, jnp.array(0, jnp.float32))
    self.assertEqual(m.false_negatives, jnp.array(0, jnp.float32))
    self.assertEqual(m.num_thresholds, 0)

  def test_aucroc_empty(self):
    """Tests the `empty` method of the `AUCROC` class."""
    m = metrax.AUCROC.empty()
    self.assertEqual(m.true_positives, jnp.array(0, jnp.float32))
    self.assertEqual(m.true_negatives, jnp.array(0, jnp.float32))
    self.assertEqual(m.false_positives, jnp.array(0, jnp.float32))
    self.assertEqual(m.false_negatives, jnp.array(0, jnp.float32))
    self.assertEqual(m.num_thresholds, 0)

  @parameterized.named_parameters(
      ('basic_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, 0.5),
      ('high_threshold_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, 0.7),
      ('low_threshold_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, 0.1),
      ('basic_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, 0.5),
      ('high_threshold_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, 0.7),
      ('low_threshold_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, 0.1),
      ('basic_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16, 0.5),
      ('high_threshold_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16, 0.7),
      ('low_threshold_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16, 0.1),
      ('batch_size_one', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, 0.5),
  )
  def test_precision(self, y_true, y_pred, threshold):
    """Test that `Precision` metric computes correct values."""
    # Capture the prediction dtype before binarization: the tolerance below
    # must depend on the precision of the incoming scores.
    pred_dtype = y_pred.dtype
    y_true = y_true.reshape((-1,))
    y_pred = jnp.where(y_pred.reshape((-1,)) >= threshold, 1, 0)
    keras_precision = keras.metrics.Precision(thresholds=threshold)
    keras_precision.update_state(y_true, y_pred)
    expected = keras_precision.result()

    metric = None
    for logits, labels in zip(y_pred, y_true):
      update = metrax.Precision.from_model_output(
          predictions=logits,
          labels=labels,
          threshold=threshold,
      )
      metric = update if metric is None else metric.merge(update)

    # Use lower tolerance for lower precision dtypes. NOTE: this previously
    # keyed off `y_true.dtype`, which is always float32 here, so the relaxed
    # f16/bf16 tolerance never took effect; key off the prediction dtype.
    rtol = 1e-2 if pred_dtype in (jnp.float16, jnp.bfloat16) else 1e-5
    atol = 1e-2 if pred_dtype in (jnp.float16, jnp.bfloat16) else 1e-5
    np.testing.assert_allclose(
        metric.compute(),
        expected,
        rtol=rtol,
        atol=atol,
    )

  @parameterized.named_parameters(
      ('basic', OUTPUT_LABELS, OUTPUT_PREDS, 0.5),
      ('high_threshold', OUTPUT_LABELS, OUTPUT_PREDS, 0.7),
      ('low_threshold', OUTPUT_LABELS, OUTPUT_PREDS, 0.1),
      ('basic_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, 0.5),
      ('high_threshold_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, 0.7),
      ('low_threshold_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, 0.1),
      ('basic_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16, 0.5),
      ('high_threshold_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16, 0.7),
      ('low_threshold_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16, 0.1),
      ('batch_size_one', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, 0.5),
  )
  def test_recall(self, y_true, y_pred, threshold):
    """Test that `Recall` metric computes correct values."""
    y_true = y_true.reshape((-1,))
    # Binarize scores; both metrax and Keras then see identical 0/1 inputs,
    # so an exact comparison (default tolerances) is appropriate below.
    y_pred = jnp.where(y_pred.reshape((-1,)) >= threshold, 1, 0)
    keras_recall = keras.metrics.Recall(thresholds=threshold)
    keras_recall.update_state(y_true, y_pred)
    expected = keras_recall.result()

    metric = None
    for logits, labels in zip(y_pred, y_true):
      update = metrax.Recall.from_model_output(
          predictions=logits,
          labels=labels,
          threshold=threshold,
      )
      metric = update if metric is None else metric.merge(update)

    np.testing.assert_allclose(
        metric.compute(),
        expected,
    )

  @parameterized.product(
      inputs=(
          (OUTPUT_LABELS, OUTPUT_PREDS, None),
          (OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, None),
          (OUTPUT_LABELS, OUTPUT_PREDS, SAMPLE_WEIGHTS),
      ),
      dtype=(
          jnp.float16,
          jnp.float32,
          jnp.bfloat16,
      ),
  )
  def test_aucpr(self, inputs, dtype):
    """Test that `AUC-PR` Metric computes correct values."""
    y_true, y_pred, sample_weights = inputs
    y_true = y_true.astype(dtype)
    y_pred = y_pred.astype(dtype)
    if sample_weights is None:
      # No weighting requested: weight every example equally.
      sample_weights = np.ones_like(y_true)

    # Accumulate metrax updates batch by batch via merge().
    metric = None
    for labels, logits, weights in zip(y_true, y_pred, sample_weights):
      update = metrax.AUCPR.from_model_output(
          predictions=logits,
          labels=labels,
          sample_weights=weights,
      )
      metric = update if metric is None else metric.merge(update)

    # Feed the same batches to the Keras reference implementation.
    keras_aucpr = keras.metrics.AUC(curve='PR')
    for labels, logits, weights in zip(y_true, y_pred, sample_weights):
      keras_aucpr.update_state(labels, logits, sample_weight=weights)
    expected = keras_aucpr.result()
    np.testing.assert_allclose(
        metric.compute(),
        expected,
        # Use lower tolerance for lower precision dtypes.
        rtol=1e-2 if y_true.dtype in (jnp.float16, jnp.bfloat16) else 1e-5,
        atol=1e-2 if y_true.dtype in (jnp.float16, jnp.bfloat16) else 1e-5,
    )

  @parameterized.product(
      inputs=(
          (OUTPUT_LABELS, OUTPUT_PREDS, None),
          (OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, None),
          (OUTPUT_LABELS, OUTPUT_PREDS, SAMPLE_WEIGHTS),
      ),
      dtype=(
          jnp.float16,
          jnp.float32,
          jnp.bfloat16,
      ),
  )
  def test_aucroc(self, inputs, dtype):
    """Test that `AUC-ROC` Metric computes correct values."""
    y_true, y_pred, sample_weights = inputs
    y_true = y_true.astype(dtype)
    y_pred = y_pred.astype(dtype)
    if sample_weights is None:
      # No weighting requested: weight every example equally.
      sample_weights = np.ones_like(y_true)

    # Accumulate metrax updates batch by batch via merge().
    metric = None
    for labels, logits, weights in zip(y_true, y_pred, sample_weights):
      update = metrax.AUCROC.from_model_output(
          predictions=logits,
          labels=labels,
          sample_weights=weights,
      )
      metric = update if metric is None else metric.merge(update)

    # Feed the same batches to the Keras reference implementation.
    keras_aucroc = keras.metrics.AUC(curve='ROC')
    for labels, logits, weights in zip(y_true, y_pred, sample_weights):
      keras_aucroc.update_state(labels, logits, sample_weight=weights)
    expected = keras_aucroc.result()
    np.testing.assert_allclose(
        metric.compute(),
        expected,
        # Use lower tolerance for lower precision dtypes.
        rtol=1e-2 if y_true.dtype in (jnp.float16, jnp.bfloat16) else 1e-7,
        atol=1e-2 if y_true.dtype in (jnp.float16, jnp.bfloat16) else 1e-7,
    )
235+
if __name__ == '__main__':
  absltest.main()

0 commit comments

Comments
 (0)