Skip to content

Commit d451f0a

Browse files
committed
drop JAX from JAX arrays
1 parent af38603 commit d451f0a

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

src/metrax/ranking_metrics.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -144,16 +144,16 @@ def precision_at_ks(
144144
"""Computes P@k (precision at k) metrics for each of k in ks.
145145
146146
Args:
147-
predictions: A floating point 2D JAX array representing the prediction
147+
predictions: A floating point 2D array representing the prediction
148148
scores from the model. Higher scores indicate higher relevance. The
149149
shape should be (batch_size, vocab_size).
150150
labels: A multi-hot encoding (0 or 1) of the true labels. The shape should
151151
be (batch_size, vocab_size).
152-
ks: A 1D JAX array of integers representing the k's to compute the P@k
152+
ks: A 1D array of integers representing the k's to compute the P@k
153153
metrics. The shape should be (|ks|).
154154
155155
Returns:
156-
A rank-2 JAX array of shape (batch_size, |ks|) containing P@k metrics.
156+
A rank-2 array of shape (batch_size, |ks|) containing P@k metrics.
157157
"""
158158
labels = jnp.array(labels >= 1, dtype=jnp.float32)
159159
indices_by_rank = jnp.argsort(-predictions, axis=1)
@@ -179,11 +179,11 @@ def from_model_output(
179179
calling .compute() on the returned metric object.
180180
181181
Args:
182-
predictions: A floating point 2D JAX array representing the prediction
182+
predictions: A floating point 2D array representing the prediction
183183
scores from the model. The shape should be (batch_size, vocab_size).
184184
labels: A multi-hot encoding (0 or 1) of the true labels. The shape should
185185
be (batch_size, vocab_size).
186-
ks: A 1D JAX array of integers representing the k's to compute the P@k
186+
ks: A 1D array of integers representing the k's to compute the P@k
187187
metrics. The shape should be (|ks|).
188188
189189
Returns:

0 commit comments

Comments
 (0)