Skip to content

Commit aa9431f

Browse files
committed
add base_metrics to metrax
1 parent 7c875d7 commit aa9431f

File tree

7 files changed

+329
-121
lines changed

7 files changed

+329
-121
lines changed

src/metrax/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from metrax.base_metrics import (
16+
Average,
17+
)
1518
from metrax.classification_metrics import (
1619
AUCPR,
1720
AUCROC,
@@ -34,6 +37,7 @@
3437
__all__ = [
3538
"AUCPR",
3639
"AUCROC",
40+
"Average",
3741
"AveragePrecisionAtK",
3842
"MSE",
3943
"Perplexity",

src/metrax/base_metrics.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""A collection of base metrics for metrax."""
16+
17+
from clu import metrics as clu_metrics
18+
import flax
19+
import jax
20+
import jax.numpy as jnp
21+
22+
23+
def divide_no_nan(x: jax.Array, y: jax.Array) -> jax.Array:
24+
"""Computes a safe divide which returns 0 if the y is zero."""
25+
return jnp.where(y != 0, jnp.divide(x, y), 0.0)
26+
27+
28+
@flax.struct.dataclass
29+
class Average(clu_metrics.Average):
30+
r"""Average Metric inherits clu.metrics.Average and performs safe division."""
31+
32+
@classmethod
33+
def from_model_output(
34+
cls,
35+
values: jax.Array,
36+
sample_weights: jax.Array | None = None,
37+
) -> 'Average':
38+
"""Updates the metric.
39+
40+
Args:
41+
values: A floating point 1D vector representing the values. The shape
42+
should be (batch_size,).
43+
sample_weights: An optional floating point 1D vector representing the
44+
weight of each sample. The shape should be (batch_size,).
45+
46+
Returns:
47+
Updated Average metric.
48+
"""
49+
total = values
50+
count = jnp.ones_like(values, dtype=values.dtype)
51+
if sample_weights is not None:
52+
total = values * sample_weights
53+
count = count * sample_weights
54+
return cls(
55+
total=total.sum(),
56+
count=count.sum(),
57+
)
58+
59+
def compute(self) -> jax.Array:
60+
return divide_no_nan(self.total, self.count)

src/metrax/base_metrics_test.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for metrax base metrics."""
16+
17+
from absl.testing import absltest
18+
from absl.testing import parameterized
19+
import jax.numpy as jnp
20+
import keras
21+
import metrax
22+
from metrax import base_metrics
23+
import numpy as np
24+
25+
np.random.seed(42)
26+
BATCHES = 4
27+
BATCH_SIZE = 8
28+
OUTPUT = np.random.uniform(size=(BATCHES, BATCH_SIZE))
29+
OUTPUT_F16 = OUTPUT.astype(jnp.float16)
30+
OUTPUT_F32 = OUTPUT.astype(jnp.float32)
31+
OUTPUT_BF16 = OUTPUT.astype(jnp.bfloat16)
32+
OUTPUT_BS1 = np.random.uniform(size=(BATCHES, 1)).astype(jnp.float32)
33+
SAMPLE_WEIGHTS = np.tile(
34+
[0.5, 1, 0, 0, 0, 0, 0, 0],
35+
(BATCHES, 1),
36+
)
37+
38+
39+
class BaseMetricsTest(parameterized.TestCase):
40+
41+
def test_basic_division(self):
42+
x = jnp.array([10.0, 20.0, 30.0])
43+
y = jnp.array([2.0, 4.0, 5.0])
44+
expected = jnp.array([5.0, 5.0, 6.0])
45+
result = base_metrics.divide_no_nan(x, y)
46+
self.assertTrue(jnp.array_equal(result, expected))
47+
48+
def test_division_by_zero(self):
49+
x = jnp.array([10.0, 20.0, 30.0])
50+
y = jnp.array([2.0, 0.0, 5.0])
51+
expected = jnp.array([5.0, 0.0, 6.0])
52+
result = base_metrics.divide_no_nan(x, y)
53+
self.assertTrue(jnp.array_equal(result, expected))
54+
55+
def test_all_zeros_denominator(self):
56+
x = jnp.array([10.0, 20.0, 30.0])
57+
y = jnp.array([0.0, 0.0, 0.0])
58+
expected = jnp.array([0.0, 0.0, 0.0])
59+
result = base_metrics.divide_no_nan(x, y)
60+
self.assertTrue(jnp.array_equal(result, expected))
61+
62+
def test_all_zeros_numerator(self):
63+
x = jnp.array([0.0, 0.0, 0.0])
64+
y = jnp.array([2.0, 4.0, 5.0])
65+
expected = jnp.array([0.0, 0.0, 0.0])
66+
result = base_metrics.divide_no_nan(x, y)
67+
self.assertTrue(jnp.array_equal(result, expected))
68+
69+
def test_mixed_zeros(self):
70+
x = jnp.array([10.0, 0.0, 30.0, 0.0])
71+
y = jnp.array([2.0, 0.0, 5.0, 4.0])
72+
expected = jnp.array([5.0, 0.0, 6.0, 0.0])
73+
result = base_metrics.divide_no_nan(x, y)
74+
self.assertTrue(jnp.array_equal(result, expected))
75+
76+
def test_scalar_inputs(self):
77+
x = jnp.array(10.0)
78+
y = jnp.array(2.0)
79+
expected = jnp.array(5.0)
80+
result = base_metrics.divide_no_nan(x, y)
81+
self.assertTrue(jnp.array_equal(result, expected))
82+
83+
def test_scalar_denominator_zero(self):
84+
x = jnp.array(10.0)
85+
y = jnp.array(0.0)
86+
expected = jnp.array(0.0)
87+
result = base_metrics.divide_no_nan(x, y)
88+
self.assertTrue(jnp.array_equal(result, expected))
89+
90+
def test_negative_values(self):
91+
x = jnp.array([-10.0, 20.0, -30.0])
92+
y = jnp.array([2.0, -4.0, 5.0])
93+
expected = jnp.array([-5.0, -5.0, -6.0])
94+
result = base_metrics.divide_no_nan(x, y)
95+
self.assertTrue(jnp.array_equal(result, expected))
96+
97+
def test_negative_and_zero_values(self):
98+
x = jnp.array([-10.0, 20.0, -30.0, 10.0])
99+
y = jnp.array([2.0, -4.0, 0.0, 0.0])
100+
expected = jnp.array([-5.0, -5.0, 0.0, 0.0])
101+
result = base_metrics.divide_no_nan(x, y)
102+
self.assertTrue(jnp.array_equal(result, expected))
103+
104+
@parameterized.named_parameters(
105+
('basic_f16', OUTPUT_F16, None),
106+
('basic_f32', OUTPUT_F32, None),
107+
('basic_bf16', OUTPUT_BF16, None),
108+
('batch_size_one', OUTPUT_BS1, None),
109+
('weighted_f16', OUTPUT_F16, SAMPLE_WEIGHTS),
110+
('weighted_f32', OUTPUT_F32, SAMPLE_WEIGHTS),
111+
('weighted_bf16', OUTPUT_BF16, SAMPLE_WEIGHTS),
112+
)
113+
def test_average(self, values, sample_weights):
114+
"""Test that `Average` metric computes correct values."""
115+
if sample_weights is None:
116+
sample_weights = jnp.ones_like(values)
117+
sample_weights = jnp.array(sample_weights, dtype=values.dtype)
118+
metric = metrax.Average.from_model_output(
119+
values=values,
120+
sample_weights=sample_weights,
121+
)
122+
123+
keras_mean = keras.metrics.Mean(dtype=values.dtype)
124+
keras_mean.update_state(values, sample_weights)
125+
keras_metrics = keras_mean.result()
126+
keras_metrics = jnp.array(keras_metrics, dtype=values.dtype)
127+
128+
# Use lower tolerance for lower precision dtypes.
129+
rtol = 1e-2 if values.dtype in (jnp.float16, jnp.bfloat16) else 1e-05
130+
atol = 1e-2 if values.dtype in (jnp.float16, jnp.bfloat16) else 1e-05
131+
np.testing.assert_allclose(
132+
metric.compute(),
133+
keras_metrics,
134+
rtol=rtol,
135+
atol=atol,
136+
)
137+
138+
139+
if __name__ == '__main__':
140+
absltest.main()

src/metrax/classification_metrics.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import flax
1919
import jax
2020
import jax.numpy as jnp
21+
from metrax import base_metrics
2122

2223

2324
def _default_threshold(num_thresholds: int) -> jax.Array:
@@ -42,11 +43,6 @@ def _default_threshold(num_thresholds: int) -> jax.Array:
4243
return thresholds
4344

4445

45-
def _divide_no_nan(x: jax.Array, y: jax.Array) -> jax.Array:
46-
"""Computes a safe divide which returns 0 if the y is zero."""
47-
return jnp.where(y != 0, jnp.divide(x, y), 0.0)
48-
49-
5046
@flax.struct.dataclass
5147
class Precision(clu_metrics.Metric):
5248
r"""Computes precision for binary classification given `predictions` and `labels`.
@@ -116,7 +112,7 @@ def merge(self, other: 'Precision') -> 'Precision':
116112
)
117113

118114
def compute(self) -> jax.Array:
119-
return _divide_no_nan(
115+
return base_metrics.divide_no_nan(
120116
self.true_positives, (self.true_positives + self.false_positives)
121117
)
122118

@@ -187,7 +183,7 @@ def merge(self, other: 'Recall') -> 'Recall':
187183
)
188184

189185
def compute(self) -> jax.Array:
190-
return _divide_no_nan(
186+
return base_metrics.divide_no_nan(
191187
self.true_positives, (self.true_positives + self.false_negatives)
192188
)
193189

@@ -365,20 +361,20 @@ def interpolate_pr_auc(self) -> jax.Array:
365361
)
366362
p = self.true_positives + self.false_positives
367363
dp = p[: self.num_thresholds - 1] - p[1:]
368-
prec_slope = _divide_no_nan(dtp, jnp.maximum(dp, 0))
364+
prec_slope = base_metrics.divide_no_nan(dtp, jnp.maximum(dp, 0))
369365
intercept = self.true_positives[1:] - prec_slope * p[1:]
370366

371367
# recall_relative_ratio
372368
safe_p_ratio = jnp.where(
373369
jnp.multiply(p[: self.num_thresholds - 1] > 0, p[1:] > 0),
374-
_divide_no_nan(
370+
base_metrics.divide_no_nan(
375371
p[: self.num_thresholds - 1],
376372
jnp.maximum(p[1:], 0),
377373
),
378374
jnp.ones_like(p[1:]),
379375
)
380376
# pr_auc_increment
381-
pr_auc_increment = _divide_no_nan(
377+
pr_auc_increment = base_metrics.divide_no_nan(
382378
prec_slope * (dtp + intercept * jnp.log(safe_p_ratio)),
383379
jnp.maximum(self.true_positives[1:] + self.false_negatives[1:], 0),
384380
)
@@ -506,10 +502,10 @@ def merge(self, other: 'AUCROC') -> 'AUCROC':
506502
)
507503

508504
def compute(self) -> jax.Array:
509-
tp_rate = _divide_no_nan(
505+
tp_rate = base_metrics.divide_no_nan(
510506
self.true_positives, self.true_positives + self.false_negatives
511507
)
512-
fp_rate = _divide_no_nan(
508+
fp_rate = base_metrics.divide_no_nan(
513509
self.false_positives, self.false_positives + self.true_negatives
514510
)
515511
# Threshold goes from 0 to 1, so trapezoid is negative.

0 commit comments

Comments
 (0)