Skip to content

Commit 8da455b

Browse files
authored
add precisionAtK to metrax (#72)
* add precisionAtK to metrax * add back lines that were deleted by accident * add changes for ranking_metrics_test * match Keras behavior for invalid ks * drop "JAX" from JAX arrays
1 parent 66396c1 commit 8da455b

File tree

7 files changed

+124
-6
lines changed

7 files changed

+124
-6
lines changed

src/metrax/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
MSE = regression_metrics.MSE
2727
Perplexity = nlp_metrics.Perplexity
2828
Precision = classification_metrics.Precision
29+
PrecisionAtK = ranking_metrics.PrecisionAtK
2930
RMSE = regression_metrics.RMSE
3031
RSQUARED = regression_metrics.RSQUARED
3132
Recall = classification_metrics.Recall
@@ -43,6 +44,7 @@
4344
"MSE",
4445
"Perplexity",
4546
"Precision",
47+
"PrecisionAtK",
4648
"RMSE",
4749
"RSQUARED",
4850
"Recall",

src/metrax/metrax_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,15 @@ class MetraxTest(parameterized.TestCase):
8484
metrax.Precision,
8585
{'predictions': OUTPUT_LABELS, 'labels': OUTPUT_PREDS},
8686
),
87+
(
88+
'precisionAtK',
89+
metrax.PrecisionAtK,
90+
{
91+
'predictions': OUTPUT_LABELS,
92+
'labels': OUTPUT_PREDS,
93+
'ks': np.array([3]),
94+
},
95+
),
8796
(
8897
'rmse',
8998
metrax.RMSE,

src/metrax/nnx/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
MSE = nnx_metrics.MSE
2323
Perplexity = nnx_metrics.Perplexity
2424
Precision = nnx_metrics.Precision
25+
PrecisionAtK = nnx_metrics.PrecisionAtK
2526
RMSE = nnx_metrics.RMSE
2627
RSQUARED = nnx_metrics.RSQUARED
2728
Recall = nnx_metrics.Recall
@@ -39,6 +40,7 @@
3940
"MSE",
4041
"Perplexity",
4142
"Precision",
43+
"PrecisionAtK",
4244
"RMSE",
4345
"RSQUARED",
4446
"Recall",

src/metrax/nnx/nnx_metrics.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,13 @@ def __init__(self):
7474
super().__init__(metrax.Precision)
7575

7676

77+
class PrecisionAtK(NnxWrapper):
  """NNX wrapper exposing the Metrax ``PrecisionAtK`` ranking metric."""

  def __init__(self):
    # Delegate all state handling to the generic NNX wrapper, parameterized
    # with the underlying Metrax metric class.
    super().__init__(metrax.PrecisionAtK)
82+
83+
7784
class Recall(NnxWrapper):
7885
"""An NNX class for the Metrax metric Recall."""
7986

src/metrax/nnx/nnx_metrics_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def test_nnx_metrics_exists(self):
3636
key for key, metric in inspect.getmembers(metrax.nnx)
3737
if inspect.isclass(metric) and issubclass(metric, nnx.Metric)
3838
]
39-
self.assertGreater(len(metrax_metric_keys), 0)
39+
self.assertNotEmpty(metrax_metric_keys)
4040
self.assertSameElements(metrax_metric_keys, metrax_nnx_metric_keys)
4141

4242

src/metrax/ranking_metrics.py

Lines changed: 76 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,9 @@ def average_precision_at_ks(
5454
metrics. The shape should be (|ks|).
5555
5656
Returns:
57-
Rank-2 tensor of shape [batch, |ks|] containing AP@k metrics.
57+
Rank-2 tensor of shape (batch, |ks|) containing AP@k metrics.
5858
"""
59-
sorted_indices = jnp.argsort(-predictions, axis=1)
59+
indices_by_rank = jnp.argsort(-predictions, axis=1)
6060
labels = jnp.array(labels >= 1, dtype=jnp.float32)
6161
total_relevant = labels.sum(axis=1)
6262

@@ -88,7 +88,7 @@ def compute_ap_at_k_single(relevant_labels, total_relevant, ks):
8888
)
8989

9090
ap_at_ks = vmap_compute_ap_at_k(
91-
jnp.take_along_axis(labels, sorted_indices, axis=1), total_relevant, ks
91+
jnp.take_along_axis(labels, indices_by_rank, axis=1), total_relevant, ks
9292
)
9393
return ap_at_ks
9494

@@ -117,8 +117,79 @@ def from_model_output(
117117
and `labels` are incompatible.
118118
"""
119119
ap_at_ks = cls.average_precision_at_ks(predictions, labels, ks)
120-
count = jnp.ones((labels.shape[0], 1), dtype=jnp.float32)
120+
num_examples = jnp.array(labels.shape[0], dtype=jnp.float32)
121121
return cls(
122122
total=ap_at_ks.sum(axis=0),
123-
count=count.sum(),
123+
count=num_examples,
124124
)
125+
126+
127+
@flax.struct.dataclass
class PrecisionAtK(base.Average):
  r"""Precision-at-k (P@k) metric implemented with JAX.

  P@k measures, for each example, the fraction of the top-:math:`K`
  ranked items that are actually relevant:

  .. math::
      Precision@K = \frac{\text{Number of relevant items in top K}}{K}

  Per-example P@k values are summed into ``total`` and averaged over the
  example ``count`` when ``.compute()`` is called.
  """

  @classmethod
  def precision_at_ks(
      cls, predictions: jax.Array, labels: jax.Array, ks: jax.Array
  ) -> jax.Array:
    """Computes P@k for every k in `ks`, per batch example.

    Args:
      predictions: Float 2D array of model scores, shape
        (batch_size, vocab_size). Higher scores mean higher relevance.
      labels: Multi-hot (0/1) ground-truth relevance, shape
        (batch_size, vocab_size).
      ks: 1D integer array of cutoffs, shape (|ks|).

    Returns:
      A (batch_size, |ks|) array of P@k values.
    """
    # Binarize: any label value >= 1 counts as relevant.
    relevance = jnp.array(labels >= 1, dtype=jnp.float32)
    # Order each row's relevance flags from best- to worst-scored item.
    ranking = jnp.argsort(-predictions, axis=1)
    ranked_relevance = jnp.take_along_axis(relevance, ranking, axis=1)
    # hits[i, j] = number of relevant items among the top j+1 of example i.
    hits = jnp.cumsum(ranked_relevance, axis=1)

    num_items = predictions.shape[1]
    # Clamp oversized k's to the list length (Keras-compatible behavior):
    # such k's see the whole list and divide by its full length.
    hits_at_k = hits[:, jnp.minimum(ks - 1, num_items - 1)]
    denominators = jnp.minimum(ks, num_items)
    # divide_no_nan maps a zero denominator (k == 0) to a P@k of 0.
    return base.divide_no_nan(hits_at_k, denominators)

  @classmethod
  def from_model_output(
      cls,
      predictions: jax.Array,
      labels: jax.Array,
      ks: jax.Array,
  ) -> 'PrecisionAtK':
    """Builds a PrecisionAtK metric instance from one batch of model output.

    Per-example P@k values are summed (shape (|ks|)) and the example count
    recorded, so that `.compute()` later yields the batch-averaged P@k.

    Args:
      predictions: Float 2D array of model scores,
        shape (batch_size, vocab_size).
      labels: Multi-hot (0/1) ground-truth relevance,
        shape (batch_size, vocab_size).
      ks: 1D integer array of cutoffs, shape (|ks|).

    Returns:
      A PrecisionAtK object whose `total` has shape (|ks|) and whose
      `count` is a scalar.
    """
    per_example = cls.precision_at_ks(predictions, labels, ks)
    batch_size = jnp.array(labels.shape[0], dtype=jnp.float32)
    return cls(total=per_example.sum(axis=0), count=batch_size)

src/metrax/ranking_metrics_test.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@
4848
0.637499988079071,
4949
])
5050
MAP_FROM_KERAS_VS1 = np.array([0.75, 0.75, 0.75, 0.75, 0.75, 0.75])
51+
P_FROM_KERAS = np.array([0.75, 0.875, 0.58333337306976320, 0.5625, 0.5, 0.5])
52+
P_FROM_KERAS_VS1 = np.array([0.75, 0.75, 0.75, 0.75, 0.75, 0.75])
5153

5254

5355
class RankingMetricsTest(parameterized.TestCase):
@@ -89,6 +91,31 @@ def test_averageprecisionatk(self, y_true, y_pred, map_from_keras, jitted):
8991
atol=1e-05,
9092
)
9193

94+
@parameterized.named_parameters(
95+
('basic', OUTPUT_LABELS, OUTPUT_PREDS, P_FROM_KERAS),
96+
(
97+
'vocab_size_one',
98+
OUTPUT_LABELS_VS1,
99+
OUTPUT_PREDS_VS1,
100+
P_FROM_KERAS_VS1,
101+
),
102+
)
103+
def test_precisionatk(self, y_true, y_pred, map_from_keras):
104+
"""Test that `PrecisionAtK` Metric computes correct values."""
105+
ks = jnp.array([1, 2, 3, 4, 5, 6])
106+
metric = metrax.PrecisionAtK.from_model_output(
107+
predictions=y_pred,
108+
labels=y_true,
109+
ks=ks,
110+
)
111+
112+
np.testing.assert_allclose(
113+
metric.compute(),
114+
map_from_keras,
115+
rtol=1e-05,
116+
atol=1e-05,
117+
)
118+
92119

93120
if __name__ == '__main__':
94121
absltest.main()

0 commit comments

Comments
 (0)