Skip to content

Commit 24687eb

Browse files
committed
add ranking_metrics to metrax
1 parent 75d6f7e commit 24687eb

File tree

3 files changed

+209
-4
lines changed

3 files changed

+209
-4
lines changed

src/metrax/__init__.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,23 @@
2121
from metrax.nlp_metrics import (
2222
Perplexity,
2323
)
24+
from metrax.ranking_metrics import (
25+
AveragePrecisionAtK,
26+
)
2427
from metrax.regression_metrics import (
2528
MSE,
2629
RMSE,
2730
RSQUARED,
2831
)
2932

3033
# Public API of the metrax package, sorted alphabetically
# (all-uppercase names first, then CamelCase).
__all__ = [
    "AUCPR",
    "AUCROC",
    "AveragePrecisionAtK",
    "MSE",
    "Perplexity",
    "Precision",
    "Recall",
    "RMSE",
    "RSQUARED",
]

src/metrax/ranking_metrics.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""A collection of different metrics for ranking models."""
16+
17+
from clu import metrics as clu_metrics
18+
import flax
19+
import jax
20+
import jax.numpy as jnp
21+
22+
23+
def _divide_no_nan(x: jax.Array, y: jax.Array) -> jax.Array:
24+
"""Computes a safe divide which returns 0 if the y is zero."""
25+
return jnp.where(y != 0, jnp.divide(x, y), 0.0)
26+
27+
28+
@flax.struct.dataclass
class AveragePrecisionAtK(clu_metrics.Average):
  r"""Computes AP@k (average precision at k) metrics in JAX.

  Average precision at k (AP@k) is a metric used to evaluate the performance of
  ranking models. It measures the sum of precision at k where the item at
  the kth rank is relevant, divided by the total number of relevant items.

  Given the top :math:`K` recommendations, AP@K is calculated as:

  .. math::
      AP@K = \frac{1}{r}\sum_{k=1}^{K} P(k) \cdot rel(k)

  where :math:`r` is the total number of relevant items, :math:`P(k)` is the
  precision of the top :math:`k` recommendations, and

  .. math::
      rel(k) =
      \begin{cases}
      1 & \text{if the item at rank } k \text{ is relevant} \\
      0 & \text{otherwise}
      \end{cases}
  """

  @classmethod
  def average_precision_at_ks(
      cls, predictions: jax.Array, labels: jax.Array, ks: jax.Array
  ) -> jax.Array:
    """Computes AP@k (average precision at k) metrics for each of k in ks.

    Args:
      predictions: A floating point 2D vector representing the prediction
        generated from the model. The shape should be (batch_size, vocab_size).
      labels: A multi-hot encoding of the true label. The shape should be
        (batch_size, vocab_size).
      ks: A 1D vector of integers representing the k's to compute the MAP@k
        metrics. The shape should be (|ks|).

    Returns:
      Rank-2 tensor of shape [batch, |ks|] containing AP@k metrics.
    """
    # Indices of the max(ks) highest-scoring items per row, best first.
    top_k_indices = jnp.argsort(-predictions, axis=1)[:, : jnp.max(ks)]
    # Binarize the labels: any value >= 1 counts as relevant.
    labels = jnp.array(labels >= 1, dtype=jnp.float32)
    total_relevant = labels.sum(axis=1)

    def compute_ap_at_k_single(relevant_labels, total_relevant, ks):
      # Per-rank precision@i at every rank i where the item is relevant,
      # and 0 at ranks holding irrelevant items.
      cumulative_precision = jnp.where(
          relevant_labels,
          _divide_no_nan(
              jnp.cumsum(relevant_labels),
              jnp.arange(1, len(relevant_labels) + 1),
          ),
          0,
      )
      # AP@k = (sum of precision at relevant ranks within the top k) / r,
      # with 0 when the row has no relevant items at all.
      return jnp.array([
          _divide_no_nan(jnp.sum(cumulative_precision[:k]), total_relevant)
          for k in ks
      ])

    # Vectorize over the batch dimension; `ks` is shared across rows.
    vmap_compute_ap_at_k = jax.vmap(
        compute_ap_at_k_single, in_axes=(0, 0, None), out_axes=0
    )

    ap_at_ks = vmap_compute_ap_at_k(
        jnp.take_along_axis(labels, top_k_indices, axis=1), total_relevant, ks
    )
    return ap_at_ks

  @classmethod
  def from_model_output(
      cls,
      predictions: jax.Array,
      labels: jax.Array,
      ks: jax.Array,
  ) -> 'AveragePrecisionAtK':
    """Updates the metric.

    Args:
      predictions: A floating point 2D vector representing the prediction
        generated from the model. The shape should be (batch_size, vocab_size).
      labels: A multi-hot encoding of the true label. The shape should be
        (batch_size, vocab_size).
      ks: A 1D vector of integers representing the k's to compute the MAP@k
        metrics. The shape should be (|ks|).

    Returns:
      The AveragePrecisionAtK metric. The shape should be (|ks|).

    Raises:
      ValueError: If type of `labels` is wrong or the shapes of `predictions`
        and `labels` are incompatible.
    """
    ap_at_ks = cls.average_precision_at_ks(predictions, labels, ks)
    return cls(
        # Per-k sums over the batch; `Average.compute()` divides by `count`
        # to yield the mean AP@k across examples.
        total=ap_at_ks.sum(axis=0),
        # One observation per example in the batch. (Previously computed by
        # summing a freshly allocated ones array of shape (batch, 1).)
        count=jnp.array(labels.shape[0], dtype=jnp.float32),
    )

src/metrax/ranking_metrics_test.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for metrax ranking metrics."""
16+
17+
from absl.testing import absltest
18+
from absl.testing import parameterized
19+
import jax.numpy as jnp
20+
import metrax
21+
import numpy as np
22+
23+
# Seed once so every fixture below is reproducible; the draw order matters,
# so do not reorder the array constructions.
np.random.seed(42)
BATCH_SIZE = 4
VOCAB_SIZE = 8
# Multi-hot relevance labels, shape (BATCH_SIZE, VOCAB_SIZE).
OUTPUT_LABELS = np.random.randint(
    0,
    2,
    size=(BATCH_SIZE, VOCAB_SIZE),
).astype(np.float32)
# Model scores in [0, 1), shape (BATCH_SIZE, VOCAB_SIZE).
OUTPUT_PREDS = np.random.uniform(size=(BATCH_SIZE, VOCAB_SIZE)).astype(
    np.float32
)
# Degenerate case: vocabulary of size one.
OUTPUT_LABELS_VS1 = np.random.randint(
    0,
    2,
    size=(BATCH_SIZE, 1),
).astype(np.float32)
OUTPUT_PREDS_VS1 = np.random.uniform(size=(BATCH_SIZE, 1)).astype(np.float32)
# TODO(jiwonshin): Replace with keras metric once it is available in OSS.
# Golden mean-AP values for ks = [1, 2, 3, 4, 5, 6], precomputed with Keras.
MAP_FROM_KERAS = np.array([
    0.2083333432674408,
    0.4791666865348816,
    0.4791666865348816,
    0.5416666865348816,
    0.574999988079071,
    0.637499988079071,
])
MAP_FROM_KERAS_VS1 = np.array([0.75, 0.75, 0.75, 0.75, 0.75, 0.75])
50+
51+
52+
class RankingMetricsTest(parameterized.TestCase):
  """Unit tests for metrax ranking metrics."""

  @parameterized.named_parameters(
      ('basic', OUTPUT_LABELS, OUTPUT_PREDS, MAP_FROM_KERAS),
      (
          'vocab_size_one',
          OUTPUT_LABELS_VS1,
          OUTPUT_PREDS_VS1,
          MAP_FROM_KERAS_VS1,
      ),
  )
  def test_averageprecisionatk(self, labels, predictions, expected):
    """Checks `AveragePrecisionAtK` against golden values from Keras."""
    top_ks = jnp.array([1, 2, 3, 4, 5, 6])
    metric = metrax.AveragePrecisionAtK.from_model_output(
        predictions=predictions,
        labels=labels,
        ks=top_ks,
    )
    actual = metric.compute()
    np.testing.assert_allclose(actual, expected, rtol=1e-05, atol=1e-05)
78+
79+
80+
# Run the test suite when executed directly (absltest handles flag parsing).
if __name__ == '__main__':
  absltest.main()

0 commit comments

Comments
 (0)