Skip to content

Commit cae0c2a

Browse files
authored
add recallAtK to metrax (#73)
* add recallAtK to metrax
* add back test files
* remove "in JAX" from docstrings
1 parent 8da455b commit cae0c2a

File tree

6 files changed

+224
-39
lines changed

6 files changed

+224
-39
lines changed

src/metrax/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
RMSE = regression_metrics.RMSE
3131
RSQUARED = regression_metrics.RSQUARED
3232
Recall = classification_metrics.Recall
33+
RecallAtK = ranking_metrics.RecallAtK
3334
RougeL = nlp_metrics.RougeL
3435
RougeN = nlp_metrics.RougeN
3536
WER = nlp_metrics.WER
@@ -48,6 +49,7 @@
4849
"RMSE",
4950
"RSQUARED",
5051
"Recall",
52+
"RecallAtK",
5153
"RougeL",
5254
"RougeN",
5355
"WER",

src/metrax/metrax_test.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
size=(BATCHES, BATCH_SIZE),
3030
).astype(np.float32)
3131
OUTPUT_PREDS = np.random.uniform(size=(BATCHES, BATCH_SIZE))
32+
KS = np.array([3])
3233

3334
STRING_PREDS = [
3435
'the cat sat on the mat',
@@ -66,7 +67,7 @@ class MetraxTest(parameterized.TestCase):
6667
{
6768
'predictions': OUTPUT_LABELS,
6869
'labels': OUTPUT_PREDS,
69-
'ks': np.array([3]),
70+
'ks': KS,
7071
},
7172
),
7273
(
@@ -90,7 +91,7 @@ class MetraxTest(parameterized.TestCase):
9091
{
9192
'predictions': OUTPUT_LABELS,
9293
'labels': OUTPUT_PREDS,
93-
'ks': np.array([3]),
94+
'ks': KS,
9495
},
9596
),
9697
(
@@ -108,6 +109,15 @@ class MetraxTest(parameterized.TestCase):
108109
metrax.Recall,
109110
{'predictions': OUTPUT_LABELS, 'labels': OUTPUT_PREDS},
110111
),
112+
(
113+
'recallAtK',
114+
metrax.RecallAtK,
115+
{
116+
'predictions': OUTPUT_LABELS,
117+
'labels': OUTPUT_PREDS,
118+
'ks': KS,
119+
},
120+
),
111121
)
112122
def test_metrics_jittable(self, metric, kwargs):
113123
"""Tests that jitted metrax metric yields the same result as non-jitted metric."""

src/metrax/nnx/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
RMSE = nnx_metrics.RMSE
2727
RSQUARED = nnx_metrics.RSQUARED
2828
Recall = nnx_metrics.Recall
29+
RecallAtK = nnx_metrics.RecallAtK
2930
RougeL = nnx_metrics.RougeL
3031
RougeN = nnx_metrics.RougeN
3132
WER = nnx_metrics.WER
@@ -44,6 +45,7 @@
4445
"RMSE",
4546
"RSQUARED",
4647
"Recall",
48+
"RecallAtK",
4749
"RougeL",
4850
"RougeN",
4951
"WER",

src/metrax/nnx/nnx_metrics.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,13 @@ def __init__(self):
8888
super().__init__(metrax.Recall)
8989

9090

91+
class RecallAtK(NnxWrapper):
92+
"""An NNX class for the Metrax metric RecallAtK."""
93+
94+
def __init__(self):
95+
super().__init__(metrax.RecallAtK)
96+
97+
9198
class RMSE(NnxWrapper):
9299
"""An NNX class for the Metrax metric RMSE."""
93100

src/metrax/ranking_metrics.py

Lines changed: 167 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
"""A collection of different metrics for ranking models."""
1616

17+
import abc
1718
import flax
1819
import jax
1920
import jax.numpy as jnp
@@ -22,7 +23,7 @@
2223

2324
@flax.struct.dataclass
2425
class AveragePrecisionAtK(base.Average):
25-
r"""Computes AP@k (average precision at k) metrics in JAX.
26+
r"""Computes AP@k (average precision at k) metrics.
2627
2728
Average precision at k (AP@k) is a metric used to evaluate the performance of
2829
ranking models. It measures the sum of precision at k where the item at
@@ -125,71 +126,200 @@ def from_model_output(
125126

126127

127128
@flax.struct.dataclass
128-
class PrecisionAtK(base.Average):
129-
r"""Computes P@k (precision at k) metrics in JAX.
129+
class TopKRankingMetric(base.Average, abc.ABC):
130+
"""Abstract base class for Top-K ranking metrics like Precision@k and Recall@k.
130131
131-
Precision at k (P@k) is a metric that measures the proportion of
132-
relevant items found in the top k recommendations.
133-
134-
Given the top :math:`K` recommendations, P@K is calculated as:
132+
This class provides common functionality for calculating metrics that evaluate
133+
the quality of the top k items in a ranked list. Subclasses must implement
134+
the `_calculate_metric_at_ks` method to define the specific metric
135+
computation (e.g., precision, recall).
135136
136-
.. math::
137-
Precision@K = \frac{\text{Number of relevant items in top K}}{K}
137+
The `from_model_output` method is a factory method that computes the metric
138+
values for a batch of predictions and labels, and aggregates them.
138139
"""
139140

140-
@classmethod
141-
def precision_at_ks(
142-
cls, predictions: jax.Array, labels: jax.Array, ks: jax.Array
141+
@staticmethod
142+
def _get_relevant_at_k(
143+
predictions: jax.Array, labels: jax.Array, ks: jax.Array
143144
) -> jax.Array:
144-
"""Computes P@k (precision at k) metrics for each of k in ks.
145+
"""Computes the number of relevant items at each k.
146+
147+
This static method processes predictions and labels to determine the
148+
number of relevant items at specified k-values.
145149
146150
Args:
147-
predictions: A floating point 2D array representing the prediction
148-
scores from the model. Higher scores indicate higher relevance. The
151+
predictions: A floating point 2D array representing the prediction scores
152+
from the model. Higher scores indicate higher relevance. The shape
153+
should be (batch_size, vocab_size).
154+
labels: A multi-hot encoding (0 or 1, or counts) of the true labels. The
149155
shape should be (batch_size, vocab_size).
150-
labels: A multi-hot encoding (0 or 1) of the true labels. The shape should
151-
be (batch_size, vocab_size).
152-
ks: A 1D array of integers representing the k's to compute the P@k
153-
metrics. The shape should be (|ks|).
156+
ks: A 1D array of integers representing the k's (cut-off points) for which
157+
to compute metrics. The shape should be (|ks|).
154158
155159
Returns:
156-
A rank-2 array of shape (batch_size, |ks|) containing P@k metrics.
160+
relevant_at_k: A 2D array of shape (batch_size, |ks|). Each element [i, j]
161+
is the number of relevant items among the top ks[j] recommendations for
162+
the i-th example in the batch.
157163
"""
158164
labels = jnp.array(labels >= 1, dtype=jnp.float32)
159165
indices_by_rank = jnp.argsort(-predictions, axis=1)
160166
labels_by_rank = jnp.take_along_axis(labels, indices_by_rank, axis=1)
161167
relevant_by_rank = jnp.cumsum(labels_by_rank, axis=1)
162168

163169
vocab_size = predictions.shape[1]
164-
relevant_at_k = relevant_by_rank[:, jnp.minimum(ks - 1, vocab_size - 1)]
165-
total_at_k = jnp.minimum(ks, vocab_size)
166-
return base.divide_no_nan(relevant_at_k, total_at_k)
170+
k_indices = jnp.minimum(ks - 1, vocab_size - 1)
171+
relevant_at_k = relevant_by_rank[:, k_indices]
172+
173+
return relevant_at_k
174+
175+
@classmethod
176+
@abc.abstractmethod
177+
def _calculate_metric_at_ks(
178+
cls, predictions: jax.Array, labels: jax.Array, ks: jax.Array
179+
) -> jax.Array:
180+
"""Computes the specific metric (e.g., P@k, R@k) values per example for each k.
181+
182+
This method must be implemented by concrete subclasses (e.g., PrecisionAtK,
183+
RecallAtK) to define the actual calculation of the metric based on
184+
predictions, labels, and k-values.
185+
186+
Args:
187+
predictions: A floating point 2D array representing the prediction scores
188+
from the model.
189+
labels: A multi-hot encoding of the true labels.
190+
ks: A 1D array of integers representing the k's.
191+
192+
Returns:
193+
A rank-2 array of shape (batch_size, |ks|) containing the metric
194+
values for each example in the batch and each specified k.
195+
"""
196+
raise NotImplementedError('Subclasses must implement this method.')
167197

168198
@classmethod
169199
def from_model_output(
170200
cls,
171201
predictions: jax.Array,
172202
labels: jax.Array,
173203
ks: jax.Array,
174-
) -> 'PrecisionAtK':
175-
"""Creates a PrecisionAtK metric instance from model output.
204+
) -> 'TopKRankingMetric':
205+
"""Creates a metric instance from model output.
176206
177-
This computes the P@k for each example in the batch and then aggregates
178-
them (sum of P@k values and count of examples) to be averaged later by
179-
calling .compute() on the returned metric object.
207+
This class method computes the specific ranking metric (defined by the
208+
subclass's implementation of `_calculate_metric_at_ks`) for each example
209+
in the batch. It then aggregates these values (sum of metric values and
210+
count of examples) into a metric object. This object can later be used
211+
to compute the mean metric value (e.g., Mean Precision@k) by calling
212+
its `.compute()` method (inherited from `base.Average`).
180213
181214
Args:
182-
predictions: A floating point 2D array representing the prediction
183-
scores from the model. The shape should be (batch_size, vocab_size).
184-
labels: A multi-hot encoding (0 or 1) of the true labels. The shape should
185-
be (batch_size, vocab_size).
215+
predictions: A floating point 2D array representing the prediction scores
216+
from the model. The shape should be (batch_size, vocab_size).
217+
labels: A multi-hot encoding (0 or 1, or counts) of the true labels. The
218+
shape should be (batch_size, vocab_size).
219+
ks: A 1D array of integers representing the k's to compute the metrics.
220+
The shape should be (|ks|).
221+
222+
Returns:
223+
An instance of the calling class (e.g., PrecisionAtK, RecallAtK)
224+
with `total` and `count` fields populated. The `total` field will
225+
have shape (|ks|), representing the sum of metric values for each k
226+
across the batch, and `count` will be a scalar representing the
227+
number of examples in the batch.
228+
"""
229+
metric_at_ks = cls._calculate_metric_at_ks(predictions, labels, ks)
230+
num_examples = jnp.array(labels.shape[0], dtype=jnp.float32)
231+
return cls(
232+
total=metric_at_ks.sum(axis=0),
233+
count=num_examples,
234+
)
235+
236+
237+
@flax.struct.dataclass
238+
class PrecisionAtK(TopKRankingMetric):
239+
r"""Computes P@k (precision at k) metrics.
240+
241+
Precision at k (P@k) is a metric that measures the proportion of
242+
relevant items found in the top k recommendations. It answers the question:
243+
"Out of the K items recommended, how many are actually relevant?"
244+
245+
Given the top :math:`K` recommendations, P@K is calculated as:
246+
247+
.. math::
248+
Precision@K = \frac{\text{Number of relevant items in top K}}{K}
249+
"""
250+
251+
@classmethod
252+
def _calculate_metric_at_ks(
253+
cls, predictions: jax.Array, labels: jax.Array, ks: jax.Array
254+
) -> jax.Array:
255+
"""Computes P@k (precision at k) metrics for each of k in ks for each example.
256+
257+
This method implements the core logic for calculating Precision@k.
258+
It utilizes the `_get_relevant_at_k` helper from the base
259+
class to get the number of relevant items at each k, and then divides
260+
by k (clamped by vocabulary size) to get the precision.
261+
262+
Args:
263+
predictions: A floating point 2D array representing the prediction scores
264+
from the model. The shape should be (batch_size, vocab_size).
265+
labels: A multi-hot encoding (0 or 1, or counts) of the true labels. The
266+
shape should be (batch_size, vocab_size).
186267
ks: A 1D array of integers representing the k's to compute the P@k
187268
metrics. The shape should be (|ks|).
188269
189270
Returns:
190-
The PrecisionAtK metric object. The `total` field will have shape (|ks|),
191-
and `count` will be a scalar.
271+
A rank-2 array of shape (batch_size, |ks|) containing P@k metrics
272+
for each example and each k.
192273
"""
193-
p_at_ks = cls.precision_at_ks(predictions, labels, ks)
194-
num_examples = jnp.array(labels.shape[0], dtype=jnp.float32)
195-
return cls(total=p_at_ks.sum(axis=0), count=num_examples)
274+
relevant_at_k = cls._get_relevant_at_k(predictions, labels, ks)
275+
vocab_size = labels.shape[1]
276+
denominator_p_at_k = jnp.minimum(ks.astype(jnp.float32), vocab_size)
277+
return base.divide_no_nan(relevant_at_k, denominator_p_at_k[jnp.newaxis, :])
278+
279+
280+
@flax.struct.dataclass
281+
class RecallAtK(TopKRankingMetric):
282+
r"""Computes R@k (recall at k) metrics.
283+
284+
Recall at k (R@k) is a metric that measures the proportion of
285+
relevant items that are found in the top k recommendations, out of the
286+
total number of relevant items for a given user/query. It answers the
287+
question:
288+
"Out of all the items that are truly relevant, how many did we find in the top
289+
K?"
290+
291+
Given the top :math:`K` recommendations, R@K is calculated as:
292+
293+
.. math::
294+
Recall@K = \frac{\text{Number of relevant items in top K}}{\text{Total
295+
number of relevant items}}
296+
"""
297+
298+
@classmethod
299+
def _calculate_metric_at_ks(
300+
cls, predictions: jax.Array, labels: jax.Array, ks: jax.Array
301+
) -> jax.Array:
302+
"""Computes R@k (recall at k) metrics for each of k in ks for each example.
303+
304+
This method implements the core logic for calculating Recall@k.
305+
It utilizes the `_get_relevant_at_k` helper from the base
306+
class to get the number of relevant items at each k and the binarized
307+
labels.
308+
The number of relevant items at k is then divided by the total number of
309+
relevant items for that example to get the recall.
310+
311+
Args:
312+
predictions: A floating point 2D array representing the prediction scores
313+
from the model. The shape should be (batch_size, vocab_size).
314+
labels: A multi-hot encoding (0 or 1, or counts) of the true labels. The
315+
shape should be (batch_size, vocab_size).
316+
ks: A 1D array of integers representing the k's to compute the R@k
317+
metrics. The shape should be (|ks|).
318+
319+
Returns:
320+
A rank-2 array of shape (batch_size, |ks|) containing R@k metrics
321+
for each example and each k.
322+
"""
323+
relevant_at_k = cls._get_relevant_at_k(predictions, labels, ks)
324+
total_relevant = jnp.sum(jnp.array(labels >= 1, dtype=jnp.float32), axis=1)
325+
return base.divide_no_nan(relevant_at_k, total_relevant[:, jnp.newaxis])

src/metrax/ranking_metrics_test.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,15 @@
5050
MAP_FROM_KERAS_VS1 = np.array([0.75, 0.75, 0.75, 0.75, 0.75, 0.75])
5151
P_FROM_KERAS = np.array([0.75, 0.875, 0.58333337306976320, 0.5625, 0.5, 0.5])
5252
P_FROM_KERAS_VS1 = np.array([0.75, 0.75, 0.75, 0.75, 0.75, 0.75])
53+
R_FROM_KERAS = np.array([
54+
0.2083333432674408,
55+
0.5416666865348816,
56+
0.5416666865348816,
57+
0.625,
58+
0.6666666865348816,
59+
0.75,
60+
])
61+
R_FROM_KERAS_VS1 = np.array([0.75, 0.75, 0.75, 0.75, 0.75, 0.75])
5362

5463

5564
class RankingMetricsTest(parameterized.TestCase):
@@ -116,6 +125,31 @@ def test_precisionatk(self, y_true, y_pred, map_from_keras):
116125
atol=1e-05,
117126
)
118127

128+
@parameterized.named_parameters(
129+
('basic', OUTPUT_LABELS, OUTPUT_PREDS, R_FROM_KERAS),
130+
(
131+
'vocab_size_one',
132+
OUTPUT_LABELS_VS1,
133+
OUTPUT_PREDS_VS1,
134+
R_FROM_KERAS_VS1,
135+
),
136+
)
137+
def test_recallatk(self, y_true, y_pred, map_from_keras):
138+
"""Test that `RecallAtK` Metric computes correct values."""
139+
ks = jnp.array([1, 2, 3, 4, 5, 6])
140+
metric = metrax.RecallAtK.from_model_output(
141+
predictions=y_pred,
142+
labels=y_true,
143+
ks=ks,
144+
)
145+
146+
np.testing.assert_allclose(
147+
metric.compute(),
148+
map_from_keras,
149+
rtol=1e-05,
150+
atol=1e-05,
151+
)
152+
119153

120154
if __name__ == '__main__':
121155
absltest.main()

0 commit comments

Comments
 (0)