Skip to content

Commit 6efbb36

Browse files
target logits bug fix
1 parent bfb9533 commit 6efbb36

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

speciesnet/classifier.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -240,9 +240,9 @@ def batch_predict(
240240
scores = torch.softmax(logits, dim=-1)
241241
scores, indices = torch.topk(scores, k=5, dim=-1)
242242

243-
for filepath, scores_arr, indices_arr in zip(
243+
for file_idx, (filepath, scores_arr, indices_arr) in enumerate(zip(
244244
inference_filepaths, scores.numpy(), indices.numpy()
245-
):
245+
)):
246246

247247
predictions[filepath] = {
248248
"filepath": filepath,
@@ -257,7 +257,7 @@ def batch_predict(
257257
{
258258
"target_classes": self.target_labels,
259259
"target_logits": [
260-
float(logits[0][idx]) for idx in self.target_idx
260+
float(logits[file_idx][idx]) for idx in self.target_idx
261261
],
262262
}
263263
)

0 commit comments

Comments
 (0)