@@ -144,16 +144,16 @@ def precision_at_ks(
144144 """Computes P@k (precision at k) metrics for each of k in ks.
145145
146146 Args:
147- predictions: A floating point 2D JAX array representing the prediction
147+ predictions: A floating point 2D array representing the prediction
148148 scores from the model. Higher scores indicate higher relevance. The
149149 shape should be (batch_size, vocab_size).
150150 labels: A multi-hot encoding (0 or 1) of the true labels. The shape should
151151 be (batch_size, vocab_size).
152- ks: A 1D JAX array of integers representing the k's to compute the P@k
152+ ks: A 1D array of integers representing the k's to compute the P@k
153153 metrics. The shape should be (|ks|).
154154
155155 Returns:
156- A rank-2 JAX array of shape (batch_size, |ks|) containing P@k metrics.
156+ A rank-2 array of shape (batch_size, |ks|) containing P@k metrics.
157157 """
158158 labels = jnp .array (labels >= 1 , dtype = jnp .float32 )
159159 indices_by_rank = jnp .argsort (- predictions , axis = 1 )
@@ -179,11 +179,11 @@ def from_model_output(
179179 calling .compute() on the returned metric object.
180180
181181 Args:
182- predictions: A floating point 2D JAX array representing the prediction
182+ predictions: A floating point 2D array representing the prediction
183183 scores from the model. The shape should be (batch_size, vocab_size).
184184 labels: A multi-hot encoding (0 or 1) of the true labels. The shape should
185185 be (batch_size, vocab_size).
186- ks: A 1D JAX array of integers representing the k's to compute the P@k
186+ ks: A 1D array of integers representing the k's to compute the P@k
187187 metrics. The shape should be (|ks|).
188188
189189 Returns:
0 commit comments