|
21 | 21 | from metrax import base |
22 | 22 |
|
23 | 23 |
|
| 24 | +@flax.struct.dataclass |
| 25 | +class DCGAtK(base.Average): |
| 26 | + r"""Computes DCG@k (Discounted Cumulative Gain at k) metrics. |
| 27 | +
|
| 28 | + This implementation calculates DCG@k based on the principle: |
| 29 | + $DCG@k(y, s) = \sum_{i | \text{rank}(s_i) \le k} \text{gain}(y_i) \times |
| 30 | + \text{rank\_discount}(\text{rank}(s_i))$ |
| 31 | + where $y_i$ is the label of item $i$, $s_i$ is its score, |
| 32 | + and $\text{rank}(s_i)$ is the 1-based rank of item $i$ based on its score. |
| 33 | +
|
| 34 | + The gain is $gain(y_i) = 2^{y_i} - 1$. |
| 35 | + The rank_discount is $1 / \log_2(\text{rank} + 1)$. |
| 36 | + """ |
| 37 | + |
| 38 | + @classmethod |
| 39 | + def _calculate_dcg_at_ks( |
| 40 | + cls, |
| 41 | + predictions: jax.Array, |
| 42 | + labels: jax.Array, |
| 43 | + ks: jax.Array, |
| 44 | + ) -> jax.Array: |
| 45 | + """Computes DCG@k for each example and for each k, using 'exp2' gain. |
| 46 | +
|
| 47 | + This function is JIT-compiled. The gain calculation is fixed to 'exp2'. |
| 48 | + It uses jax.vmap to compute DCG for multiple k values efficiently. |
| 49 | +
|
| 50 | + Args: |
| 51 | + predictions: A floating point 2D array (batch_size, vocab_size) |
| 52 | + representing prediction scores. Higher scores mean higher rank. |
| 53 | + labels: A 2D array (batch_size, vocab_size) of graded relevance scores. |
| 54 | + ks: A 1D array of integers representing the k values for which DCG is |
| 55 | + computed (e.g., jnp.array([1, 5, 10])). Shape: (num_ks,). |
| 56 | +
|
| 57 | + Returns: |
| 58 | + A 2D array (batch_size, num_ks) containing DCG@k values. |
| 59 | + """ |
| 60 | + gains = jnp.power(2.0, labels.astype(jnp.float32)) - 1.0 |
| 61 | + score_ranks = jnp.argsort(jnp.argsort(-predictions, axis=1), axis=1) + 1 |
| 62 | + score_rank_discounts = 1.0 / jnp.log2(score_ranks.astype(jnp.float32) + 1.0) |
| 63 | + item_contributions = gains * score_rank_discounts |
| 64 | + |
| 65 | + def _compute_dcg_at_k(k, current_item_contributions, current_score_ranks): |
| 66 | + """Computes DCG for a single k value across all examples in a batch. |
| 67 | +
|
| 68 | + Args: |
| 69 | + k: A scalar JAX array representing the single 'k' (top-k) value for |
| 70 | + which DCG is to be computed. |
| 71 | + current_item_contributions: A 2D JAX array containing the pre-calculated |
| 72 | + contribution (gain * discount) for each item in each example of the |
| 73 | + batch. The shape should be (batch_size, vocab_size). |
| 74 | + current_score_ranks: A 2D JAX array containing the 1-based rank for each |
| 75 | + item in each example of the batch. The shape should be (batch_size, |
| 76 | + vocab_size). |
| 77 | +
|
| 78 | + Returns: |
| 79 | + A 1D JAX array containing the DCG@k for each example in the batch. |
| 80 | + The shape should be (batch_size, ). |
| 81 | + """ |
| 82 | + mask_for_k = current_score_ranks <= k |
| 83 | + dcg_at_k = jnp.sum(current_item_contributions * mask_for_k, axis=1) |
| 84 | + return dcg_at_k |
| 85 | + |
| 86 | + dcg_at_ks = jax.vmap( |
| 87 | + _compute_dcg_at_k, |
| 88 | + in_axes=(0, None, None), |
| 89 | + out_axes=1, # Place the mapped axis(from ks) as the second axis |
| 90 | + )(ks, item_contributions, score_ranks) |
| 91 | + |
| 92 | + return dcg_at_ks |
| 93 | + |
| 94 | + @classmethod |
| 95 | + def from_model_output( |
| 96 | + cls, |
| 97 | + predictions: jax.Array, |
| 98 | + labels: jax.Array, |
| 99 | + ks: jax.Array, |
| 100 | + ) -> 'DCGAtK': |
| 101 | + """Creates a DCGAtK metric instance from model output.""" |
| 102 | + dcg_at_ks = cls._calculate_dcg_at_ks(predictions, labels, ks) |
| 103 | + num_examples = jnp.array(labels.shape[0], dtype=jnp.float32) |
| 104 | + return cls( |
| 105 | + total=jnp.sum(dcg_at_ks, axis=0), |
| 106 | + count=num_examples, |
| 107 | + ) |
| 108 | + |
| 109 | + |
24 | 110 | @flax.struct.dataclass |
25 | 111 | class AveragePrecisionAtK(base.Average): |
26 | 112 | r"""Computes AP@k (average precision at k) metrics. |
@@ -151,10 +237,10 @@ def _get_relevant_at_k( |
151 | 237 | predictions: A floating point 2D array representing the prediction scores |
152 | 238 | from the model. Higher scores indicate higher relevance. The shape |
153 | 239 | should be (batch_size, vocab_size). |
154 | | - labels: A multi-hot encoding (0 or 1, or counts) of the true labels. The |
155 | | - shape should be (batch_size, vocab_size). |
156 | | - ks: A 1D array of integers representing the k's (cut-off points) for which |
157 | | - to compute metrics. The shape should be (|ks|). |
| 240 | + labels: A multi-hot encoding (0 or 1) of the true labels. The shape should |
| 241 | + be (batch_size, vocab_size). |
| 242 | + ks: A 1D array of integers representing the k's to compute the P@k |
| 243 | + metrics. The shape should be (|ks|). |
158 | 244 |
|
159 | 245 | Returns: |
160 | 246 | relevant_at_k: A 2D array of shape (batch_size, |ks|). Each element [i, j] |
@@ -279,7 +365,7 @@ class to get the number of relevant items at each k, and then divides |
279 | 365 |
|
280 | 366 | @flax.struct.dataclass |
281 | 367 | class RecallAtK(TopKRankingMetric): |
282 | | - r"""Computes R@k (recall at k) metrics in JAX. |
| 368 | + r"""Computes R@k (recall at k) metrics. |
283 | 369 |
|
284 | 370 | Recall at k (R@k) is a metric that measures the proportion of |
285 | 371 | relevant items that are found in the top k recommendations, out of the |
|
0 commit comments