1+ import warnings
12from unittest .mock import MagicMock , patch
23
34import numpy as np
45from PIL import Image
56
6- from openfoodfacts .ml .image_classification import ImageClassifier , classify_transforms
7+ from openfoodfacts .ml .image_classification import (
8+ ImageClassificationResult ,
9+ ImageClassifier ,
10+ _classify_transform ,
11+ )
712
813
9- class TestClassifyTransforms :
14+ class TestClassifyTransform :
1015 def test_rgb_image (self ):
11- img = Image .new ("RGB" , (300 , 300 ), color = "red" )
12- transformed_img = classify_transforms (img )
13- assert transformed_img .shape == (3 , 224 , 224 )
16+ img = np .array (Image .new ("RGB" , (300 , 300 ), color = "red" ))
17+ with warnings .catch_warnings ():
18+ warnings .filterwarnings ("ignore" , message = "The image is already an RGB" )
19+ transformed_img = _classify_transform (max_size = 224 )(image = img )["image" ]
20+ assert transformed_img .shape == (224 , 224 , 3 )
1421 assert transformed_img .dtype == np .float32
1522
23+ def test_non_square_image_aspect_ratio_lt_1 (self ):
24+ # width=150, height=300
25+ img = np .array (Image .new ("RGB" , (150 , 300 ), color = "red" ))
26+ with warnings .catch_warnings ():
27+ warnings .filterwarnings ("ignore" , message = "The image is already an RGB" )
28+ transformed_img = _classify_transform (max_size = 300 )(image = img )["image" ]
29+ assert transformed_img .shape == (300 , 300 , 3 )
30+ assert transformed_img .dtype == np .float32
31+ # assert that the green and blue channels are zero
32+ assert np .sum (transformed_img [:, :, 1 :3 ]) == 0.0
33+ # image is in HWC
34+ red_channel = transformed_img [:, :, 0 ]
35+ assert np .all (red_channel [:, :75 ] == 0.0 )
36+ assert np .all (red_channel [:, 75 :150 ] == 1.0 )
37+ assert np .all (red_channel [:, 225 :] == 0.0 )
38+
39+ def test_non_square_image_aspect_ratio_gt_1 (self ):
40+ # width=600, height=300
41+ img = np .array (Image .new ("RGB" , (600 , 300 ), color = "red" ))
42+ with warnings .catch_warnings ():
43+ warnings .filterwarnings ("ignore" , message = "The image is already an RGB" )
44+ transformed_img = _classify_transform (max_size = 300 )(image = img )["image" ]
45+ assert transformed_img .shape == (300 , 300 , 3 )
46+ assert transformed_img .dtype == np .float32
47+ # assert that the green and blue channels are zero
48+ assert np .sum (transformed_img [:, :, 1 :3 ]) == 0.0
49+ # image is in HWC
50+ red_channel = transformed_img [:, :, 0 ]
51+ assert np .all (red_channel [:75 , :] == 0.0 )
52+ assert np .all (red_channel [75 :150 , :] == 1.0 )
53+ assert np .all (red_channel [225 :, :] == 0.0 )
54+
1655 def test_non_rgb_image (self ):
17- img = Image .new ("L" , (300 , 300 ), color = "red" )
18- transformed_img = classify_transforms ( img )
19- assert transformed_img .shape == (3 , 224 , 224 )
56+ img = np . array ( Image .new ("L" , (300 , 300 ), color = "red" ) )
57+ transformed_img = _classify_transform ( max_size = 224 )( image = img )[ "image" ]
58+ assert transformed_img .shape == (224 , 224 , 3 )
2059 assert transformed_img .dtype == np .float32
2160
2261 def test_custom_size (self ):
23- img = Image .new ("RGB" , (300 , 300 ), color = "red" )
24- transformed_img = classify_transforms (img , size = 128 )
25- assert transformed_img .shape == (3 , 128 , 128 )
62+ img = np .array (Image .new ("RGB" , (300 , 300 ), color = "red" ))
63+ with warnings .catch_warnings ():
64+ warnings .filterwarnings ("ignore" , message = "The image is already an RGB" )
65+ transformed_img = _classify_transform (max_size = 128 )(image = img )["image" ]
66+ assert transformed_img .shape == (128 , 128 , 3 )
2667 assert transformed_img .dtype == np .float32
2768
2869 def test_custom_mean_std (self ):
29- img = Image .new ("RGB" , (300 , 300 ), color = "red" )
70+ img = np . array ( Image .new ("RGB" , (300 , 300 ), color = "red" ) )
3071 mean = (0.5 , 0.5 , 0.5 )
3172 std = (0.5 , 0.5 , 0.5 )
32- transformed_img = classify_transforms (img , mean = mean , std = std )
33- assert transformed_img .shape == (3 , 224 , 224 )
73+ with warnings .catch_warnings ():
74+ warnings .filterwarnings ("ignore" , message = "The image is already an RGB" )
75+ transformed_img = _classify_transform (max_size = 224 )(
76+ image = img , normalize_mean = mean , normalize_std = std
77+ )["image" ]
78+ assert transformed_img .shape == (224 , 224 , 3 )
3479 assert transformed_img .dtype == np .float32
3580
3681 def test_custom_interpolation (self ):
37- img = Image .new ("RGB" , (300 , 300 ), color = "red" )
38- transformed_img = classify_transforms (
39- img , interpolation = Image .Resampling .NEAREST
40- )
41- assert transformed_img .shape == (3 , 224 , 224 )
42- assert transformed_img .dtype == np .float32
43-
44- def test_custom_crop_fraction (self ):
45- img = Image .new ("RGB" , (300 , 300 ), color = "red" )
46- transformed_img = classify_transforms (img , crop_fraction = 0.8 )
47- assert transformed_img .shape == (3 , 224 , 224 )
82+ img = np .array (Image .new ("RGB" , (300 , 300 ), color = "red" ))
83+ with warnings .catch_warnings ():
84+ warnings .filterwarnings ("ignore" , message = "The image is already an RGB" )
85+ transformed_img = _classify_transform (max_size = 224 )(
86+ image = img , interpolation = Image .Resampling .NEAREST
87+ )["image" ]
88+ assert transformed_img .shape == (224 , 224 , 3 )
4889 assert transformed_img .dtype == np .float32
4990
5091
@@ -55,13 +96,15 @@ def __init__(self, name):
5596
5697class TestImageClassifier :
5798 def test_preprocess_rgb_image (self ):
58- img = Image .new ("RGB" , (300 , 300 ), color = "red" )
99+ img = np . array ( Image .new ("RGB" , (300 , 300 ), color = "red" ) )
59100 classifier = ImageClassifier (
60101 model_name = "test_model" , label_names = ["label1" , "label2" ]
61102 )
62103 preprocessed_img = classifier .preprocess (img )
63104 assert preprocessed_img .shape == (1 , 3 , 224 , 224 )
64105 assert preprocessed_img .dtype == np .float32
106+ assert np .all (preprocessed_img [:, 0 , :, :] == 1.0 ) # red channel
107+ assert np .all (preprocessed_img [:, 1 :, :, :] == 0.0 ) # green and blue channels
65108
66109 def test_postprocess_single_output (self ):
67110 classifier = ImageClassifier (
@@ -115,7 +158,7 @@ def test_postprocess_multiple_raw_output_contents(self):
115158 assert str (e ) == "expected 1 raw output content, got 2"
116159
117160 def test_predict (self ):
118- img = Image .new ("RGB" , (300 , 300 ), color = "red" )
161+ img = np . array ( Image .new ("RGB" , (300 , 300 ), color = "red" ) )
119162 classifier = ImageClassifier (
120163 model_name = "test_model" , label_names = ["label1" , "label2" ]
121164 )
@@ -144,11 +187,22 @@ def test_predict(self):
144187 ):
145188 result = classifier .predict (img , triton_uri )
146189
147- assert len (result ) == 2
148- assert result [0 ][0 ] == "label1"
149- assert np .isclose (result [0 ][1 ], 0.8 )
150- assert result [1 ][0 ] == "label2"
151- assert np .isclose (result [1 ][1 ], 0.2 )
190+ assert isinstance (result , ImageClassificationResult )
191+ predictions = result .predictions
192+ assert len (predictions ) == 2
193+ assert predictions [0 ][0 ] == "label1"
194+ assert np .isclose (predictions [0 ][1 ], 0.8 )
195+ assert predictions [1 ][0 ] == "label2"
196+ assert np .isclose (predictions [1 ][1 ], 0.2 )
197+
198+ assert isinstance (result .metrics , dict )
199+ assert result .metrics .keys () == {
200+ "preprocess_time" ,
201+ "grpc_request_build_time" ,
202+ "triton_inference_time" ,
203+ "postprocess_time" ,
204+ }
205+ assert all (isinstance (value , float ) for value in result .metrics .values ())
152206
153207 classifier .preprocess .assert_called_once_with (img )
154208 grpc_stub .ModelInfer .assert_called_once ()
0 commit comments