@@ -54,9 +54,9 @@ def average_precision_at_ks(
5454 metrics. The shape should be (|ks|).
5555
5656 Returns:
57- Rank-2 tensor of shape [ batch, |ks|] containing AP@k metrics.
57+ Rank-2 tensor of shape ( batch, |ks|) containing AP@k metrics.
5858 """
59- sorted_indices = jnp .argsort (- predictions , axis = 1 )
59+ indices_by_rank = jnp .argsort (- predictions , axis = 1 )
6060 labels = jnp .array (labels >= 1 , dtype = jnp .float32 )
6161 total_relevant = labels .sum (axis = 1 )
6262
@@ -88,7 +88,7 @@ def compute_ap_at_k_single(relevant_labels, total_relevant, ks):
8888 )
8989
9090 ap_at_ks = vmap_compute_ap_at_k (
91- jnp .take_along_axis (labels , sorted_indices , axis = 1 ), total_relevant , ks
91+ jnp .take_along_axis (labels , indices_by_rank , axis = 1 ), total_relevant , ks
9292 )
9393 return ap_at_ks
9494
@@ -117,8 +117,79 @@ def from_model_output(
117117 and `labels` are incompatible.
118118 """
119119 ap_at_ks = cls .average_precision_at_ks (predictions , labels , ks )
120- count = jnp .ones (( labels .shape [0 ], 1 ) , dtype = jnp .float32 )
120+ num_examples = jnp .array ( labels .shape [0 ], dtype = jnp .float32 )
121121 return cls (
122122 total = ap_at_ks .sum (axis = 0 ),
123- count = count . sum () ,
123+ count = num_examples ,
124124 )
125+
126+
127+ @flax .struct .dataclass
128+ class PrecisionAtK (base .Average ):
129+ r"""Computes P@k (precision at k) metrics in JAX.
130+
131+ Precision at k (P@k) is a metric that measures the proportion of
132+ relevant items found in the top k recommendations.
133+
134+ Given the top :math:`K` recommendations, P@K is calculated as:
135+
136+ .. math::
137+ Precision@K = \frac{\text{Number of relevant items in top K}}{K}
138+ """
139+
140+ @classmethod
141+ def precision_at_ks (
142+ cls , predictions : jax .Array , labels : jax .Array , ks : jax .Array
143+ ) -> jax .Array :
144+ """Computes P@k (precision at k) metrics for each of k in ks.
145+
146+ Args:
147+ predictions: A floating point 2D array representing the prediction
148+ scores from the model. Higher scores indicate higher relevance. The
149+ shape should be (batch_size, vocab_size).
150+ labels: A multi-hot encoding (0 or 1) of the true labels. The shape should
151+ be (batch_size, vocab_size).
152+ ks: A 1D array of integers representing the k's to compute the P@k
153+ metrics. The shape should be (|ks|).
154+
155+ Returns:
156+ A rank-2 array of shape (batch_size, |ks|) containing P@k metrics.
157+ """
158+ labels = jnp .array (labels >= 1 , dtype = jnp .float32 )
159+ indices_by_rank = jnp .argsort (- predictions , axis = 1 )
160+ labels_by_rank = jnp .take_along_axis (labels , indices_by_rank , axis = 1 )
161+ relevant_by_rank = jnp .cumsum (labels_by_rank , axis = 1 )
162+
163+ vocab_size = predictions .shape [1 ]
164+ relevant_at_k = relevant_by_rank [:, jnp .minimum (ks - 1 , vocab_size - 1 )]
165+ total_at_k = jnp .minimum (ks , vocab_size )
166+ return base .divide_no_nan (relevant_at_k , total_at_k )
167+
168+ @classmethod
169+ def from_model_output (
170+ cls ,
171+ predictions : jax .Array ,
172+ labels : jax .Array ,
173+ ks : jax .Array ,
174+ ) -> 'PrecisionAtK' :
175+ """Creates a PrecisionAtK metric instance from model output.
176+
177+ This computes the P@k for each example in the batch and then aggregates
178+ them (sum of P@k values and count of examples) to be averaged later by
179+ calling .compute() on the returned metric object.
180+
181+ Args:
182+ predictions: A floating point 2D array representing the prediction
183+ scores from the model. The shape should be (batch_size, vocab_size).
184+ labels: A multi-hot encoding (0 or 1) of the true labels. The shape should
185+ be (batch_size, vocab_size).
186+ ks: A 1D array of integers representing the k's to compute the P@k
187+ metrics. The shape should be (|ks|).
188+
189+ Returns:
190+ The PrecisionAtK metric object. The `total` field will have shape (|ks|),
191+ and `count` will be a scalar.
192+ """
193+ p_at_ks = cls .precision_at_ks (predictions , labels , ks )
194+ num_examples = jnp .array (labels .shape [0 ], dtype = jnp .float32 )
195+ return cls (total = p_at_ks .sum (axis = 0 ), count = num_examples )
0 commit comments