|
14 | 14 |
|
15 | 15 | """A collection of different metrics for ranking models.""" |
16 | 16 |
|
| 17 | +import abc |
17 | 18 | import flax |
18 | 19 | import jax |
19 | 20 | import jax.numpy as jnp |
|
22 | 23 |
|
23 | 24 | @flax.struct.dataclass |
24 | 25 | class AveragePrecisionAtK(base.Average): |
25 | | - r"""Computes AP@k (average precision at k) metrics in JAX. |
| 26 | + r"""Computes AP@k (average precision at k) metrics. |
26 | 27 |
|
27 | 28 | Average precision at k (AP@k) is a metric used to evaluate the performance of |
28 | 29 | ranking models. It measures the sum of precision at k where the item at |
@@ -125,71 +126,200 @@ def from_model_output( |
125 | 126 |
|
126 | 127 |
|
@flax.struct.dataclass
class TopKRankingMetric(base.Average, abc.ABC):
  """Abstract base class for Top-K ranking metrics like Precision@k and Recall@k.

  Provides the shared machinery for metrics that judge the quality of the top
  k entries of a ranked list. A concrete subclass supplies the actual per-k
  computation (e.g. precision, recall) by overriding `_calculate_metric_at_ks`;
  ranking, relevance counting, and batch aggregation are handled here.

  `from_model_output` is the factory entry point: it evaluates the subclass
  metric over a batch and folds the results into the `total`/`count` pair that
  `base.Average` later turns into a mean via `.compute()`.
  """

  @staticmethod
  def _get_relevant_at_k(
      predictions: jax.Array, labels: jax.Array, ks: jax.Array
  ) -> jax.Array:
    """Counts the relevant items within the top-k predictions for each k.

    Args:
      predictions: A floating point 2D array of model scores, shape
        (batch_size, vocab_size); higher scores indicate higher relevance.
      labels: A multi-hot encoding (0 or 1, or counts) of the true labels,
        shape (batch_size, vocab_size). Any value >= 1 counts as relevant.
      ks: A 1D integer array of cut-off points, shape (|ks|).

    Returns:
      A (batch_size, |ks|) array whose element [i, j] is the number of
      relevant items among the top ks[j] recommendations of example i.
    """
    # Binarize the labels: any count >= 1 is treated as a relevant item.
    relevance = jnp.array(labels >= 1, dtype=jnp.float32)
    # Order each row's labels by descending prediction score.
    ranking = jnp.argsort(-predictions, axis=1)
    ranked_relevance = jnp.take_along_axis(relevance, ranking, axis=1)
    # Running count of relevant items encountered as rank grows.
    cumulative_hits = jnp.cumsum(ranked_relevance, axis=1)

    # Clamp each k to the vocabulary size so out-of-range cut-offs are safe.
    vocab_size = predictions.shape[1]
    cutoff_indices = jnp.minimum(ks - 1, vocab_size - 1)
    return cumulative_hits[:, cutoff_indices]

  @classmethod
  @abc.abstractmethod
  def _calculate_metric_at_ks(
      cls, predictions: jax.Array, labels: jax.Array, ks: jax.Array
  ) -> jax.Array:
    """Computes the per-example metric value (e.g. P@k, R@k) for each k.

    Concrete subclasses (e.g. PrecisionAtK, RecallAtK) implement this to turn
    predictions, labels, and the cut-off points into metric values.

    Args:
      predictions: A floating point 2D array of model scores.
      labels: A multi-hot encoding of the true labels.
      ks: A 1D integer array of cut-off points.

    Returns:
      A rank-2 array of shape (batch_size, |ks|) with the metric value for
      each example in the batch and each specified k.
    """
    raise NotImplementedError('Subclasses must implement this method.')

  @classmethod
  def from_model_output(
      cls,
      predictions: jax.Array,
      labels: jax.Array,
      ks: jax.Array,
  ) -> 'TopKRankingMetric':
    """Creates a metric instance from a batch of model output.

    Runs the subclass-specific `_calculate_metric_at_ks` over every example,
    then stores the column-wise sum together with the example count so that
    `.compute()` (inherited from `base.Average`) yields the batch mean per k.

    Args:
      predictions: A floating point 2D array of model scores, shape
        (batch_size, vocab_size).
      labels: A multi-hot encoding (0 or 1, or counts) of the true labels,
        shape (batch_size, vocab_size).
      ks: A 1D integer array of cut-off points, shape (|ks|).

    Returns:
      An instance of the calling class (e.g. PrecisionAtK, RecallAtK) whose
      `total` field has shape (|ks|) — the per-k sum of metric values over
      the batch — and whose `count` is a scalar equal to the batch size.
    """
    per_example = cls._calculate_metric_at_ks(predictions, labels, ks)
    batch_size = jnp.array(labels.shape[0], dtype=jnp.float32)
    return cls(
        total=per_example.sum(axis=0),
        count=batch_size,
    )
| 235 | + |
| 236 | + |
@flax.struct.dataclass
class PrecisionAtK(TopKRankingMetric):
  r"""Computes P@k (precision at k) metrics.

  Precision at k (P@k) measures the proportion of the top k recommendations
  that are relevant. It answers the question:
  "Out of the K items recommended, how many are actually relevant?"

  Given the top :math:`K` recommendations, P@K is calculated as:

  .. math::
    Precision@K = \frac{\text{Number of relevant items in top K}}{K}
  """

  @classmethod
  def _calculate_metric_at_ks(
      cls, predictions: jax.Array, labels: jax.Array, ks: jax.Array
  ) -> jax.Array:
    """Computes per-example P@k for every k in `ks`.

    Delegates relevant-item counting to the base class's `_get_relevant_at_k`
    helper, then divides by k — clamped to the vocabulary size so that
    cut-offs larger than the item set do not deflate the score.

    Args:
      predictions: A floating point 2D array of model scores, shape
        (batch_size, vocab_size).
      labels: A multi-hot encoding (0 or 1, or counts) of the true labels,
        shape (batch_size, vocab_size).
      ks: A 1D integer array of cut-off points, shape (|ks|).

    Returns:
      A rank-2 array of shape (batch_size, |ks|) containing P@k metrics
      for each example and each k.
    """
    hits = cls._get_relevant_at_k(predictions, labels, ks)
    # Denominator is k itself, capped by how many items exist at all.
    vocab_size = labels.shape[1]
    cutoffs = jnp.minimum(ks.astype(jnp.float32), vocab_size)
    return base.divide_no_nan(hits, cutoffs[jnp.newaxis, :])
| 278 | + |
| 279 | + |
@flax.struct.dataclass
class RecallAtK(TopKRankingMetric):
  r"""Computes R@k (recall at k) metrics.

  Recall at k (R@k) is a metric that measures the proportion of
  relevant items that are found in the top k recommendations, out of the
  total number of relevant items for a given user/query. It answers the
  question:
  "Out of all the items that are truly relevant, how many did we find in the top
  K?"

  Given the top :math:`K` recommendations, R@K is calculated as:

  .. math::
    Recall@K = \frac{\text{Number of relevant items in top K}}{\text{Total
    number of relevant items}}
  """

  @classmethod
  def _calculate_metric_at_ks(
      cls, predictions: jax.Array, labels: jax.Array, ks: jax.Array
  ) -> jax.Array:
    """Computes R@k (recall at k) metrics for each of k in ks for each example.

    This method implements the core logic for calculating Recall@k. It uses
    the `_get_relevant_at_k` helper from the base class to count the relevant
    items retrieved within each cut-off, then divides by the total number of
    relevant items per example. Examples with no relevant items yield 0
    (rather than NaN) via `base.divide_no_nan`.

    Args:
      predictions: A floating point 2D array representing the prediction scores
        from the model. The shape should be (batch_size, vocab_size).
      labels: A multi-hot encoding (0 or 1, or counts) of the true labels. The
        shape should be (batch_size, vocab_size).
      ks: A 1D array of integers representing the k's to compute the R@k
        metrics. The shape should be (|ks|).

    Returns:
      A rank-2 array of shape (batch_size, |ks|) containing R@k metrics
      for each example and each k.
    """
    relevant_at_k = cls._get_relevant_at_k(predictions, labels, ks)
    # Binarize labels the same way `_get_relevant_at_k` does (value >= 1 is
    # relevant) so numerator and denominator agree on what counts as a hit.
    total_relevant = jnp.sum(jnp.array(labels >= 1, dtype=jnp.float32), axis=1)
    return base.divide_no_nan(relevant_at_k, total_relevant[:, jnp.newaxis])
0 commit comments