|
| 1 | +# Copyright 2024 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +"""A collection of different metrics for ranking models.""" |
| 16 | + |
| 17 | +from clu import metrics as clu_metrics |
| 18 | +import flax |
| 19 | +import jax |
| 20 | +import jax.numpy as jnp |
| 21 | + |
| 22 | + |
| 23 | +def _divide_no_nan(x: jax.Array, y: jax.Array) -> jax.Array: |
| 24 | + """Computes a safe divide which returns 0 if the y is zero.""" |
| 25 | + return jnp.where(y != 0, jnp.divide(x, y), 0.0) |
| 26 | + |
| 27 | + |
| 28 | +@flax.struct.dataclass |
| 29 | +class AveragePrecisionAtK(clu_metrics.Average): |
| 30 | + r"""Computes AP@k (average precision at k) metrics in JAX. |
| 31 | +
|
| 32 | + Average precision at k (AP@k) is a metric used to evaluate the performance of |
| 33 | + ranking models. It measures the sum of precision at k where the item at |
| 34 | + the kth rank is relevant, divided by the total number of relevant items. |
| 35 | +
|
| 36 | + Given the top :math:`K` recommendations, AP@K is calculated as: |
| 37 | +
|
| 38 | + .. math:: |
| 39 | + AP@K = frac{1}{r}\sum_{k=1}^{K} \Precision@k * \rel(k) |
| 40 | + rel(k) = |
| 41 | + \begin{cases} |
| 42 | + 1 & \text{if the item at rank } k \text{ is relevant} \\ |
| 43 | + 0 & \text{otherwise} |
| 44 | + \end{cases} |
| 45 | + """ |
| 46 | + |
| 47 | + @classmethod |
| 48 | + def average_precision_at_ks( |
| 49 | + cls, predictions: jax.Array, labels: jax.Array, ks: jax.Array |
| 50 | + ): |
| 51 | + """Computes AP@k (average precision at k) metrics for each of k in ks. |
| 52 | +
|
| 53 | + Args: |
| 54 | + predictions: A floating point 2D vector representing the prediction |
| 55 | + generated from the model. The shape should be (batch_size, vocab_size). |
| 56 | + labels: A multi-hot encoding of the true label. The shape should be |
| 57 | + (batch_size, vocab_size). |
| 58 | + ks: A 1D vector of integers representing the k's to compute the MAP@k |
| 59 | + metrics. The shape should be (|ks|). |
| 60 | +
|
| 61 | + Returns: |
| 62 | + Rank-2 tensor of shape [batch, |ks|] containing AP@k metrics. |
| 63 | + """ |
| 64 | + top_k_indices = jnp.argsort(-predictions, axis=1)[:, : jnp.max(ks)] |
| 65 | + labels = jnp.array(labels >= 1, dtype=jnp.float32) |
| 66 | + total_relevant = labels.sum(axis=1) |
| 67 | + |
| 68 | + def compute_ap_at_k_single(relevant_labels, total_relevant, ks): |
| 69 | + cumulative_precision = jnp.where( |
| 70 | + relevant_labels, |
| 71 | + _divide_no_nan( |
| 72 | + jnp.cumsum(relevant_labels), |
| 73 | + jnp.arange(1, len(relevant_labels) + 1), |
| 74 | + ), |
| 75 | + 0, |
| 76 | + ) |
| 77 | + return jnp.array([ |
| 78 | + _divide_no_nan(jnp.sum(cumulative_precision[:k]), total_relevant) |
| 79 | + for k in ks |
| 80 | + ]) |
| 81 | + |
| 82 | + vmap_compute_ap_at_k = jax.vmap( |
| 83 | + compute_ap_at_k_single, in_axes=(0, 0, None), out_axes=0 |
| 84 | + ) |
| 85 | + |
| 86 | + ap_at_ks = vmap_compute_ap_at_k( |
| 87 | + jnp.take_along_axis(labels, top_k_indices, axis=1), total_relevant, ks |
| 88 | + ) |
| 89 | + return ap_at_ks |
| 90 | + |
| 91 | + @classmethod |
| 92 | + def from_model_output( |
| 93 | + cls, |
| 94 | + predictions: jax.Array, |
| 95 | + labels: jax.Array, |
| 96 | + ks: jax.Array, |
| 97 | + ) -> 'AveragePrecisionAtK': |
| 98 | + """Updates the metric. |
| 99 | +
|
| 100 | + Args: |
| 101 | + predictions: A floating point 2D vector representing the prediction |
| 102 | + generated from the model. The shape should be (batch_size, vocab_size). |
| 103 | + labels: A multi-hot encoding of the true label. The shape should be |
| 104 | + (batch_size, vocab_size). |
| 105 | + ks: A 1D vector of integers representing the k's to compute the MAP@k |
| 106 | + metrics. The shape should be (|ks|). |
| 107 | +
|
| 108 | + Returns: |
| 109 | + The AveragePrecisionAtK metric. The shape should be (|ks|). |
| 110 | +
|
| 111 | + Raises: |
| 112 | + ValueError: If type of `labels` is wrong or the shapes of `predictions` |
| 113 | + and `labels` are incompatible. |
| 114 | + """ |
| 115 | + ap_at_ks = cls.average_precision_at_ks(predictions, labels, ks) |
| 116 | + count = jnp.ones((labels.shape[0], 1), dtype=jnp.float32) |
| 117 | + return cls( |
| 118 | + total=ap_at_ks.sum(axis=0), |
| 119 | + count=count.sum(), |
| 120 | + ) |
0 commit comments