@@ -322,3 +322,105 @@ def test_classifications(self, predicted_vs_expected) -> None:
322322 assert classifications ["scores" ] == sorted (
323323 classifications ["scores" ], reverse = True
324324 )
325+
326+ def test_target_species_batched_vs_non_batched (
327+ self , model_name : str , tmp_path
328+ ) -> None :
329+ """Test that target_species_txt works consistently
330+ with batch and non-batch predict."""
331+
332+ # Create a temporary target species file with a subset of species
333+ target_species_file = tmp_path / "target_species.txt"
334+ target_species = [
335+ AFRICAN_ELEPHANT ,
336+ DOMESTIC_DOG ,
337+ HUMAN ,
338+ BLANK ,
339+ ]
340+ target_species_file .write_text ("\n " .join (target_species ) + "\n " )
341+
342+ # Create a classifier with target_species_txt
343+ classifier_with_targets = SpeciesNetClassifier (
344+ model_name ,
345+ target_species_txt = str (target_species_file ),
346+ )
347+
348+ # Test images with various species
349+ test_cases = [
350+ ("test_data/african_elephants.jpg" , [BBox (0.7041 , 0.4765 , 0.1108 , 0.125 )]),
351+ ("test_data/domestic_dog.jpg" , [BBox (0.2377 , 0.08398 , 0.5161 , 0.6497 )]),
352+ ("test_data/human.jpg" , [BBox (0.7115 , 0.4976 , 0.0664 , 0.2424 )]),
353+ ("test_data/blank.jpg" , []),
354+ ]
355+
356+ # Preprocess all images
357+ filepaths = []
358+ preprocessed_imgs = []
359+ for filepath , bboxes in test_cases :
360+ img = classifier_with_targets .preprocess (
361+ load_rgb_image (filepath ), bboxes = bboxes
362+ )
363+ filepaths .append (filepath )
364+ preprocessed_imgs .append (img )
365+
366+ # Test 1: Non-batched prediction (batch_size=1)
367+ non_batched_predictions = []
368+ for filepath , img in zip (filepaths , preprocessed_imgs ):
369+ prediction = classifier_with_targets .predict (filepath , img )
370+ non_batched_predictions .append (prediction )
371+
372+ # Test 2: Batched prediction (batch_size>1)
373+ batched_predictions = classifier_with_targets .batch_predict (
374+ filepaths , preprocessed_imgs
375+ )
376+
377+ # Verify that both approaches produce identical results
378+ assert len (non_batched_predictions ) == len (batched_predictions )
379+
380+ for i , (non_batched , batched ) in enumerate (
381+ zip (non_batched_predictions , batched_predictions )
382+ ):
383+ # Check that both have target_logits
384+ assert "target_logits" in non_batched ["classifications" ]
385+ assert "target_logits" in batched ["classifications" ]
386+
387+ # Check that target_classes are present and identical
388+ assert "target_classes" in non_batched ["classifications" ]
389+ assert "target_classes" in batched ["classifications" ]
390+ assert (
391+ non_batched ["classifications" ]["target_classes" ]
392+ == batched ["classifications" ]["target_classes" ]
393+ )
394+
395+ # Check that target_logits are identical
396+ non_batched_logits = non_batched ["classifications" ]["target_logits" ]
397+ batched_logits = batched ["classifications" ]["target_logits" ]
398+
399+ assert len (non_batched_logits ) == len (batched_logits )
400+ assert len (non_batched_logits ) == len (target_species )
401+
402+ # Use np.allclose for floating point comparison
403+ # Note: Using relaxed tolerances to account for minor numerical differences
404+ # in batched vs non-batched processing (e.g., from fp32 operations).
405+ # If this test fails with larger differences, it indicates a bug where
406+ # batched and non-batched predictions produce different results.
407+ np .testing .assert_allclose (
408+ non_batched_logits ,
409+ batched_logits ,
410+ rtol = 1e-3 , # 0.1% relative tolerance
411+ atol = 1e-3 , # 0.001 absolute tolerance
412+ err_msg = f"Target logits mismatch for image { i } ({ filepaths [i ]} )" ,
413+ )
414+
415+ # Also verify that regular classifications match
416+ assert (
417+ non_batched ["classifications" ]["classes" ]
418+ == batched ["classifications" ]["classes" ]
419+ )
420+ np .testing .assert_allclose (
421+ non_batched ["classifications" ]["scores" ],
422+ batched ["classifications" ]["scores" ],
423+ rtol = 1e-3 , # 0.1% relative tolerance
424+ atol = 1e-5 , # 0.00001 absolute tolerance
425+ err_msg = f"Scores mismatch for image { i } ({ filepaths [i ]} )" ,
426+ )
0 commit comments