Skip to content

Commit f8ea337

Browse files
Target logits bug fix (#45)
* target logits bug fix * updated formatting * Added test to evaluate batched vs non-batched outputs
1 parent bfb9533 commit f8ea337

File tree

2 files changed

+105
-3
lines changed

2 files changed

+105
-3
lines changed

speciesnet/classifier.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -240,8 +240,8 @@ def batch_predict(
240240
scores = torch.softmax(logits, dim=-1)
241241
scores, indices = torch.topk(scores, k=5, dim=-1)
242242

243-
for filepath, scores_arr, indices_arr in zip(
244-
inference_filepaths, scores.numpy(), indices.numpy()
243+
for file_idx, (filepath, scores_arr, indices_arr) in enumerate(
244+
zip(inference_filepaths, scores.numpy(), indices.numpy())
245245
):
246246

247247
predictions[filepath] = {
@@ -257,7 +257,7 @@ def batch_predict(
257257
{
258258
"target_classes": self.target_labels,
259259
"target_logits": [
260-
float(logits[0][idx]) for idx in self.target_idx
260+
float(logits[file_idx][idx]) for idx in self.target_idx
261261
],
262262
}
263263
)

speciesnet/classifier_test.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,3 +322,105 @@ def test_classifications(self, predicted_vs_expected) -> None:
322322
assert classifications["scores"] == sorted(
323323
classifications["scores"], reverse=True
324324
)
325+
326+
def test_target_species_batched_vs_non_batched(
327+
self, model_name: str, tmp_path
328+
) -> None:
329+
"""Test that target_species_txt works consistently
330+
with batch and non-batch predict."""
331+
332+
# Create a temporary target species file with a subset of species
333+
target_species_file = tmp_path / "target_species.txt"
334+
target_species = [
335+
AFRICAN_ELEPHANT,
336+
DOMESTIC_DOG,
337+
HUMAN,
338+
BLANK,
339+
]
340+
target_species_file.write_text("\n".join(target_species) + "\n")
341+
342+
# Create a classifier with target_species_txt
343+
classifier_with_targets = SpeciesNetClassifier(
344+
model_name,
345+
target_species_txt=str(target_species_file),
346+
)
347+
348+
# Test images with various species
349+
test_cases = [
350+
("test_data/african_elephants.jpg", [BBox(0.7041, 0.4765, 0.1108, 0.125)]),
351+
("test_data/domestic_dog.jpg", [BBox(0.2377, 0.08398, 0.5161, 0.6497)]),
352+
("test_data/human.jpg", [BBox(0.7115, 0.4976, 0.0664, 0.2424)]),
353+
("test_data/blank.jpg", []),
354+
]
355+
356+
# Preprocess all images
357+
filepaths = []
358+
preprocessed_imgs = []
359+
for filepath, bboxes in test_cases:
360+
img = classifier_with_targets.preprocess(
361+
load_rgb_image(filepath), bboxes=bboxes
362+
)
363+
filepaths.append(filepath)
364+
preprocessed_imgs.append(img)
365+
366+
# Test 1: Non-batched prediction (batch_size=1)
367+
non_batched_predictions = []
368+
for filepath, img in zip(filepaths, preprocessed_imgs):
369+
prediction = classifier_with_targets.predict(filepath, img)
370+
non_batched_predictions.append(prediction)
371+
372+
# Test 2: Batched prediction (batch_size>1)
373+
batched_predictions = classifier_with_targets.batch_predict(
374+
filepaths, preprocessed_imgs
375+
)
376+
377+
# Verify that both approaches produce identical results
378+
assert len(non_batched_predictions) == len(batched_predictions)
379+
380+
for i, (non_batched, batched) in enumerate(
381+
zip(non_batched_predictions, batched_predictions)
382+
):
383+
# Check that both have target_logits
384+
assert "target_logits" in non_batched["classifications"]
385+
assert "target_logits" in batched["classifications"]
386+
387+
# Check that target_classes are present and identical
388+
assert "target_classes" in non_batched["classifications"]
389+
assert "target_classes" in batched["classifications"]
390+
assert (
391+
non_batched["classifications"]["target_classes"]
392+
== batched["classifications"]["target_classes"]
393+
)
394+
395+
# Check that target_logits are identical
396+
non_batched_logits = non_batched["classifications"]["target_logits"]
397+
batched_logits = batched["classifications"]["target_logits"]
398+
399+
assert len(non_batched_logits) == len(batched_logits)
400+
assert len(non_batched_logits) == len(target_species)
401+
402+
# Use np.allclose for floating point comparison
403+
# Note: Using relaxed tolerances to account for minor numerical differences
404+
# in batched vs non-batched processing (e.g., from fp32 operations).
405+
# If this test fails with larger differences, it indicates a bug where
406+
# batched and non-batched predictions produce different results.
407+
np.testing.assert_allclose(
408+
non_batched_logits,
409+
batched_logits,
410+
rtol=1e-3, # 0.1% relative tolerance
411+
atol=1e-3, # 0.001 absolute tolerance
412+
err_msg=f"Target logits mismatch for image {i} ({filepaths[i]})",
413+
)
414+
415+
# Also verify that regular classifications match
416+
assert (
417+
non_batched["classifications"]["classes"]
418+
== batched["classifications"]["classes"]
419+
)
420+
np.testing.assert_allclose(
421+
non_batched["classifications"]["scores"],
422+
batched["classifications"]["scores"],
423+
rtol=1e-3, # 0.1% relative tolerance
424+
atol=1e-5, # 0.00001 absolute tolerance
425+
err_msg=f"Scores mismatch for image {i} ({filepaths[i]})",
426+
)

0 commit comments

Comments
 (0)