Skip to content

Commit 6581150

Browse files
committed
add nnx wrappers for metrax metrics
1 parent 8bb69ae commit 6581150

File tree

4 files changed

+161
-5
lines changed

4 files changed

+161
-5
lines changed

src/metrax/nnx/__init__.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,30 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from metrax.nnx.nnx_wrapper import (
16-
NnxWrapper,
15+
from metrax.nnx.nnx_metrics import (
16+
AUCPR,
17+
AUCROC,
18+
Average,
19+
AveragePrecisionAtK,
20+
MSE,
21+
Perplexity,
22+
Precision,
23+
RMSE,
24+
RSQUARED,
25+
Recall,
26+
WER,
1727
)
1828

1929
__all__ = [
20-
"NnxWrapper",
30+
"AUCPR",
31+
"AUCROC",
32+
"Average",
33+
"AveragePrecisionAtK",
34+
"MSE",
35+
"Perplexity",
36+
"Precision",
37+
"Recall",
38+
"RMSE",
39+
"RSQUARED",
40+
"WER",
2141
]

src/metrax/nnx/nnx_metric.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import metrax
16+
from metrax.nnx import nnx_wrapper
17+
18+
NnxWrapper = nnx_wrapper.NnxWrapper
19+
20+
21+
class AUCPR(NnxWrapper):
22+
"""An NNX class for the Metrax metric AUCPR."""
23+
24+
def __init__(self):
25+
super().__init__(metrax.AUCPR)
26+
27+
28+
class AUCROC(NnxWrapper):
29+
"""An NNX class for the Metrax metric AUCROC."""
30+
31+
def __init__(self):
32+
super().__init__(metrax.AUCROC)
33+
34+
35+
class Average(NnxWrapper):
36+
"""An NNX class for the Metrax metric Average."""
37+
38+
def __init__(self):
39+
super().__init__(metrax.Average)
40+
41+
42+
class AveragePrecisionAtK(NnxWrapper):
43+
"""An NNX class for the Metrax metric AveragePrecisionAtK."""
44+
45+
def __init__(self):
46+
super().__init__(metrax.AveragePrecisionAtK)
47+
48+
49+
class MSE(NnxWrapper):
50+
"""An NNX class for the Metrax metric MSE."""
51+
52+
def __init__(self):
53+
super().__init__(metrax.MSE)
54+
55+
56+
class Perplexity(NnxWrapper):
57+
"""An NNX class for the Metrax metric Perplexity."""
58+
59+
def __init__(self):
60+
super().__init__(metrax.Perplexity)
61+
62+
63+
class Precision(NnxWrapper):
64+
"""An NNX class for the Metrax metric Precision."""
65+
66+
def __init__(self):
67+
super().__init__(metrax.Precision)
68+
69+
70+
class Recall(NnxWrapper):
71+
"""An NNX class for the Metrax metric Recall."""
72+
73+
def __init__(self):
74+
super().__init__(metrax.Recall)
75+
76+
77+
class RMSE(NnxWrapper):
78+
"""An NNX class for the Metrax metric RMSE."""
79+
80+
def __init__(self):
81+
super().__init__(metrax.RMSE)
82+
83+
84+
class RSQUARED(NnxWrapper):
85+
"""An NNX class for the Metrax metric RSQUARED."""
86+
87+
def __init__(self):
88+
super().__init__(metrax.RSQUARED)
89+
90+
91+
class WER(NnxWrapper):
92+
"""An NNX class for the Metrax metric WER."""
93+
94+
def __init__(self):
95+
super().__init__(metrax.WER)

src/metrax/nnx/nnx_metric_test.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for metrax NNX metrics."""
16+
17+
import importlib
18+
import inspect
19+
import pkgutil
20+
from absl.testing import absltest
21+
from absl.testing import parameterized
22+
from clu import metrics as clu_metrics
23+
import metrax
24+
import metrax.nnx
25+
26+
27+
class NnxMetricsTest(parameterized.TestCase):
28+
29+
def test_nnx_metrics_exists(self):
30+
"""Tests that every metrax CLU metric has an NNX counterpart."""
31+
metrax_metrics = metrax.nnx.__all__
32+
for _, module_name, _ in pkgutil.iter_modules(metrax.__path__):
33+
full_module_name = f"{metrax.__name__}.{module_name}"
34+
module = importlib.import_module(full_module_name)
35+
for name, obj in inspect.getmembers(module):
36+
if inspect.isclass(obj) and issubclass(obj, clu_metrics.Metric):
37+
self.assertIn(name, metrax_metrics)
38+
39+
40+
if __name__ == "__main__":
41+
absltest.main()

src/metrax/nnx/nnx_wrapper_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ class NnxWrapperTest(parameterized.TestCase):
5050

5151
def test_reset(self):
5252
"""Tests the `reset` method of the `NnxWrapper` class."""
53-
nnx_metric = metrax.nnx.NnxWrapper(metrax.MSE)
53+
nnx_metric = metrax.nnx.MSE()
5454
self.assertEqual(nnx_metric.clu_metric.total, jnp.array(0, jnp.float32))
5555
self.assertEqual(nnx_metric.clu_metric.count, jnp.array(0, jnp.int32))
5656
nnx_metric.update(
@@ -78,7 +78,7 @@ def test_metric_update_and_compute(self, y_true, y_pred, sample_weights):
7878
if sample_weights is None:
7979
sample_weights = np.ones_like(y_true)
8080

81-
nnx_metric = metrax.nnx.NnxWrapper(metrax.MSE)
81+
nnx_metric = metrax.nnx.MSE()
8282
for labels, logits, weights in zip(y_true, y_pred, sample_weights):
8383
nnx_metric.update(
8484
predictions=logits,

0 commit comments

Comments
 (0)