Skip to content

Commit ebc75bd

Browse files
committed
created metrax_test which tests metrax metrics are jittable
1 parent c40c5a0 commit ebc75bd

File tree

1 file changed

+128
-0
lines changed

1 file changed

+128
-0
lines changed

src/metrax/metrax_test.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for metrax NNX metrics."""
16+
from absl.testing import absltest
17+
from absl.testing import parameterized
18+
import jax
19+
import metrax
20+
import metrax.nnx
21+
import numpy as np
22+
23+
np.random.seed(42)
24+
BATCHES = 1
25+
BATCH_SIZE = 8
26+
OUTPUT_LABELS = np.random.randint(
27+
0,
28+
2,
29+
size=(BATCHES, BATCH_SIZE),
30+
).astype(np.float32)
31+
OUTPUT_PREDS = np.random.uniform(size=(BATCHES, BATCH_SIZE))
32+
33+
STRING_PREDS = [
34+
'the cat sat on the mat',
35+
'a quick brown fox jumps over the lazy dog',
36+
'hello world',
37+
]
38+
STRING_REFS = [
39+
'the cat sat on the hat',
40+
'the quick brown fox jumps over the lazy dog',
41+
'hello beautiful world',
42+
]
43+
TOKENIZED_PREDS = [sentence.split() for sentence in STRING_PREDS]
44+
TOKENIZED_REFS = [sentence.split() for sentence in STRING_REFS]
45+
46+
47+
class MetraxTest(parameterized.TestCase):
48+
49+
@parameterized.named_parameters(
50+
(
51+
'aucpr',
52+
metrax.AUCPR,
53+
{'predictions': OUTPUT_LABELS, 'labels': OUTPUT_PREDS},
54+
),
55+
(
56+
'aucroc',
57+
metrax.AUCROC,
58+
{'predictions': OUTPUT_LABELS, 'labels': OUTPUT_PREDS},
59+
),
60+
(
61+
'average',
62+
metrax.Average,
63+
{'values': OUTPUT_PREDS},
64+
),
65+
(
66+
'averageprecisionatk',
67+
metrax.AveragePrecisionAtK,
68+
{
69+
'predictions': OUTPUT_LABELS,
70+
'labels': OUTPUT_PREDS,
71+
'ks': np.array([3]),
72+
},
73+
),
74+
(
75+
'mse',
76+
metrax.MSE,
77+
{'predictions': OUTPUT_LABELS, 'labels': OUTPUT_PREDS},
78+
),
79+
(
80+
'perplexity',
81+
metrax.Perplexity,
82+
{'predictions': OUTPUT_LABELS, 'labels': OUTPUT_PREDS},
83+
),
84+
(
85+
'precision',
86+
metrax.Precision,
87+
{'predictions': OUTPUT_LABELS, 'labels': OUTPUT_PREDS},
88+
),
89+
(
90+
'rmse',
91+
metrax.RMSE,
92+
{'predictions': OUTPUT_LABELS, 'labels': OUTPUT_PREDS},
93+
),
94+
(
95+
'rsquared',
96+
metrax.RSQUARED,
97+
{'predictions': OUTPUT_LABELS, 'labels': OUTPUT_PREDS},
98+
),
99+
(
100+
'recall',
101+
metrax.Recall,
102+
{'predictions': OUTPUT_LABELS, 'labels': OUTPUT_PREDS},
103+
),
104+
)
105+
def test_metrics_jittable(self, metric, kwargs):
106+
"""Tests that jitted metrax metric yields the same result as non-jitted metric."""
107+
computed_metric = metric.from_model_output(**kwargs)
108+
jitted_metric = jax.jit(metric.from_model_output)(**kwargs)
109+
np.testing.assert_allclose(
110+
computed_metric.compute(), jitted_metric.compute()
111+
)
112+
113+
@parameterized.named_parameters(
114+
(
115+
'wer',
116+
metrax.WER,
117+
{'predictions': TOKENIZED_PREDS, 'references': TOKENIZED_REFS},
118+
),
119+
)
120+
def test_metrics_not_jittable(self, metric, kwargs):
121+
"""Tests that attempting to jit and call a known non-jittable metric raises an error."""
122+
np.testing.assert_raises(
123+
TypeError, lambda: jax.jit(metric.from_model_output)(**kwargs)
124+
)
125+
126+
127+
if __name__ == '__main__':
128+
absltest.main()

0 commit comments

Comments
 (0)