Skip to content

Commit 0049be0

Browse files
committed
feat!: improve ImageClassifier
We now use Albumentations for image classification tasks. The ImageClassifier now expects a NumPy array (uint8) as input, instead of a Pillow Image.
1 parent d16bcc4 commit 0049be0

File tree

3 files changed

+162
-132
lines changed

3 files changed

+162
-132
lines changed

src/openfoodfacts/ml/image_classification.py

Lines changed: 75 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -1,88 +1,55 @@
1+
import dataclasses
12
import logging
2-
import math
3-
import time
4-
import typing
5-
from typing import Optional
3+
import warnings
64

5+
import albumentations as A
76
import numpy as np
8-
from PIL import Image, ImageOps
97
from tritonclient.grpc import service_pb2
108

119
from openfoodfacts.ml.triton import (
1210
add_triton_infer_input_tensor,
1311
get_triton_inference_stub,
1412
)
13+
from openfoodfacts.utils import PerfTimer
1514

1615
logger = logging.getLogger(__name__)
1716

1817

19-
def classify_transforms(
20-
img: Image.Image,
21-
size: int = 224,
22-
mean: tuple[float, float, float] = (0.0, 0.0, 0.0),
23-
std: tuple[float, float, float] = (1.0, 1.0, 1.0),
24-
interpolation: Image.Resampling = Image.Resampling.BILINEAR,
25-
crop_fraction: float = 1.0,
26-
) -> np.ndarray:
18+
@dataclasses.dataclass
19+
class ImageClassificationResult:
20+
"""The result of an image classification model.
21+
22+
Attributes:
23+
predictions (list[tuple[str, float]]): The list of label names and their
24+
corresponding confidence scores, ordered by confidence score in
25+
descending order.
26+
metrics (dict[str, float]): The performance metrics of the classification.
27+
Each key is the name of the metric (a step in the inference
28+
process), and the value is the time taken in seconds.
29+
The following metrics are provided:
30+
- preprocess_time: time taken to preprocess the image
31+
- grpc_request_build_time: time taken to build the gRPC request
32+
- triton_inference_time: time taken for Triton inference
33+
- postprocess_time: time taken to postprocess the results
2734
"""
28-
Applies a series of image transformations including resizing, center
29-
cropping, normalization, and conversion to a NumPy array.
30-
31-
Transformation steps is based on the one used in the Ultralytics library:
32-
https://github.com/ultralytics/ultralytics/blob/main/ultralytics/data/augment.py#L2319
33-
34-
:param img: Input Pillow image.
35-
:param size: The target size for the transformed image (shortest edge).
36-
:param mean: Mean values for each RGB channel used in normalization.
37-
:param std: Standard deviation values for each RGB channel used in
38-
normalization.
39-
:param interpolation: Interpolation method from PIL (
40-
Image.Resampling.NEAREST, Image.Resampling.BILINEAR,
41-
Image.Resampling.BICUBIC).
42-
:param crop_fraction: Fraction of the image to be cropped.
43-
:return: The transformed image as a NumPy array.
44-
"""
45-
if img.mode != "RGB":
46-
img = img.convert("RGB")
47-
48-
# Rotate the image based on the EXIF orientation if needed
49-
img = typing.cast(Image.Image, ImageOps.exif_transpose(img))
50-
51-
# Step 1: Resize while preserving the aspect ratio
52-
width, height = img.size
53-
54-
# Calculate scale size while preserving aspect ratio
55-
scale_size = math.floor(size / crop_fraction)
56-
57-
aspect_ratio = width / height
58-
if width < height:
59-
new_width = scale_size
60-
new_height = int(new_width / aspect_ratio)
61-
else:
62-
new_height = scale_size
63-
new_width = int(new_height * aspect_ratio)
64-
65-
img = img.resize((new_width, new_height), interpolation)
66-
67-
# Step 2: Center crop
68-
left = (new_width - size) // 2
69-
top = (new_height - size) // 2
70-
right = left + size
71-
bottom = top + size
72-
img = img.crop((left, top, right, bottom))
7335

74-
# Step 3: Convert the image to a NumPy array and scale pixel values to
75-
# [0, 1]
76-
img_array = np.array(img).astype(np.float32) / 255.0
36+
predictions: list[tuple[str, float]]
37+
metrics: dict[str, float]
7738

78-
# Step 4: Normalize the image
79-
mean_np = np.array(mean, dtype=np.float32).reshape(1, 1, 3)
80-
std_np = np.array(std, dtype=np.float32).reshape(1, 1, 3)
81-
img_array = (img_array - mean_np) / std_np
8239

83-
# Step 5: Change the order of dimensions from (H, W, C) to (C, H, W)
84-
img_array = np.transpose(img_array, (2, 0, 1))
85-
return img_array
40+
def _classify_transform(
41+
max_size: int,
42+
normalize_mean: tuple[float, float, float] = (0.0, 0.0, 0.0),
43+
normalize_std: tuple[float, float, float] = (1.0, 1.0, 1.0),
44+
):
45+
return A.Compose(
46+
[
47+
A.LongestMaxSize(max_size=max_size, p=1.0),
48+
A.PadIfNeeded(min_height=max_size, min_width=max_size, p=1.0),
49+
A.ToRGB(p=1.0),
50+
A.Normalize(mean=normalize_mean, std=normalize_std, p=1.0),
51+
]
52+
)
8653

8754

8855
class ImageClassifier:
@@ -101,49 +68,58 @@ def __init__(self, model_name: str, label_names: list[str], image_size: int = 22
10168

10269
def predict(
10370
self,
104-
image: Image.Image,
71+
image: np.ndarray,
10572
triton_uri: str,
106-
model_version: Optional[str] = None,
107-
) -> list[tuple[str, float]]:
73+
model_version: str | None = None,
74+
) -> ImageClassificationResult:
10875
"""Run an image classification model on an image.
10976
11077
The model is expected to have been trained with Ultralytics library
111-
(Yolov8).
78+
(any Yolo classification model).
11279
113-
:param image: the input Pillow image
80+
:param image: the input NumPy array image
11481
:param triton_uri: URI of the Triton Inference Server, defaults to
11582
None. If not provided, the default value from settings is used.
116-
:return: the prediction results as a list of tuples (label, confidence)
83+
:param model_version: the version of the model to use, defaults to
84+
None. If not provided, the latest version is used.
85+
:return: the prediction results as an ImageClassificationResult
11786
"""
118-
image_array = self.preprocess(image)
119-
120-
grpc_stub = get_triton_inference_stub(triton_uri)
121-
request = service_pb2.ModelInferRequest()
122-
request.model_name = self.model_name
123-
if model_version:
124-
request.model_version = model_version
125-
add_triton_infer_input_tensor(
126-
request, name="images", data=image_array, datatype="FP32"
127-
)
128-
start_time = time.monotonic()
129-
response = grpc_stub.ModelInfer(request)
130-
latency = time.monotonic() - start_time
131-
logger.debug("Inference time for %s: %s", self.model_name, latency)
132-
133-
start_time = time.monotonic()
134-
result = self.postprocess(response)
135-
latency = time.monotonic() - start_time
136-
logger.debug("Post-processing time for %s: %s", self.model_name, latency)
137-
return result
138-
139-
def preprocess(self, image: Image.Image) -> np.ndarray:
87+
metrics: dict[str, float] = {}
88+
89+
with PerfTimer("preprocess_time", metrics):
90+
image_array = self.preprocess(image)
91+
92+
with PerfTimer("grpc_request_build_time", metrics):
93+
request = service_pb2.ModelInferRequest()
94+
request.model_name = self.model_name
95+
if model_version:
96+
request.model_version = model_version
97+
add_triton_infer_input_tensor(
98+
request, name="images", data=image_array, datatype="FP32"
99+
)
100+
101+
with PerfTimer("triton_inference_time", metrics):
102+
grpc_stub = get_triton_inference_stub(triton_uri)
103+
response = grpc_stub.ModelInfer(request)
104+
105+
with PerfTimer("postprocess_time", metrics):
106+
predictions = self.postprocess(response)
107+
108+
return ImageClassificationResult(predictions=predictions, metrics=metrics)
109+
110+
def preprocess(self, image: np.ndarray) -> np.ndarray:
140111
"""Preprocess an image for object detection.
141112
142-
:param image: the input Pillow image
113+
:param image: the input NumPy array image
143114
:return: the preprocessed image as a NumPy array
144115
"""
145-
image_array = classify_transforms(image, size=self.image_size)
146-
return np.expand_dims(image_array, axis=0)
116+
with warnings.catch_warnings():
117+
warnings.filterwarnings("ignore", message="The image is already an RGB")
118+
image_array = _classify_transform(max_size=self.image_size)(image=image)[
119+
"image"
120+
]
121+
image_array = np.transpose(image_array, (2, 0, 1))[np.newaxis, :] # HWC to CHW
122+
return image_array
147123

148124
def postprocess(
149125
self, response: service_pb2.ModelInferResponse

tests/ml/test_image_classification.py

Lines changed: 86 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,91 @@
1+
import warnings
12
from unittest.mock import MagicMock, patch
23

34
import numpy as np
45
from PIL import Image
56

6-
from openfoodfacts.ml.image_classification import ImageClassifier, classify_transforms
7+
from openfoodfacts.ml.image_classification import (
8+
ImageClassificationResult,
9+
ImageClassifier,
10+
classify_transform,
11+
)
712

813

9-
class TestClassifyTransforms:
14+
class TestClassifyTransform:
1015
def test_rgb_image(self):
11-
img = Image.new("RGB", (300, 300), color="red")
12-
transformed_img = classify_transforms(img)
13-
assert transformed_img.shape == (3, 224, 224)
16+
img = np.array(Image.new("RGB", (300, 300), color="red"))
17+
with warnings.catch_warnings():
18+
warnings.filterwarnings("ignore", message="The image is already an RGB")
19+
transformed_img = classify_transform(max_size=224)(image=img)["image"]
20+
assert transformed_img.shape == (224, 224, 3)
1421
assert transformed_img.dtype == np.float32
1522

23+
def test_non_square_image_aspect_ratio_lt_1(self):
24+
# width=150, height=300
25+
img = np.array(Image.new("RGB", (150, 300), color="red"))
26+
with warnings.catch_warnings():
27+
warnings.filterwarnings("ignore", message="The image is already an RGB")
28+
transformed_img = classify_transform(max_size=300)(image=img)["image"]
29+
assert transformed_img.shape == (300, 300, 3)
30+
assert transformed_img.dtype == np.float32
31+
# assert that the green and blue channels are zero
32+
assert np.sum(transformed_img[:, :, 1:3]) == 0.0
33+
# image is in HWC
34+
red_channel = transformed_img[:, :, 0]
35+
assert np.all(red_channel[:, :75] == 0.0)
36+
assert np.all(red_channel[:, 75:150] == 1.0)
37+
assert np.all(red_channel[:, 225:] == 0.0)
38+
39+
def test_non_square_image_aspect_ratio_gt_1(self):
40+
# width=600, height=300
41+
img = np.array(Image.new("RGB", (600, 300), color="red"))
42+
with warnings.catch_warnings():
43+
warnings.filterwarnings("ignore", message="The image is already an RGB")
44+
transformed_img = classify_transform(max_size=300)(image=img)["image"]
45+
assert transformed_img.shape == (300, 300, 3)
46+
assert transformed_img.dtype == np.float32
47+
# assert that the green and blue channels are zero
48+
assert np.sum(transformed_img[:, :, 1:3]) == 0.0
49+
# image is in HWC
50+
red_channel = transformed_img[:, :, 0]
51+
assert np.all(red_channel[:75, :] == 0.0)
52+
assert np.all(red_channel[75:150, :] == 1.0)
53+
assert np.all(red_channel[225:, :] == 0.0)
54+
1655
def test_non_rgb_image(self):
17-
img = Image.new("L", (300, 300), color="red")
18-
transformed_img = classify_transforms(img)
19-
assert transformed_img.shape == (3, 224, 224)
56+
img = np.array(Image.new("L", (300, 300), color="red"))
57+
transformed_img = classify_transform(max_size=224)(image=img)["image"]
58+
assert transformed_img.shape == (224, 224, 3)
2059
assert transformed_img.dtype == np.float32
2160

2261
def test_custom_size(self):
23-
img = Image.new("RGB", (300, 300), color="red")
24-
transformed_img = classify_transforms(img, size=128)
25-
assert transformed_img.shape == (3, 128, 128)
62+
img = np.array(Image.new("RGB", (300, 300), color="red"))
63+
with warnings.catch_warnings():
64+
warnings.filterwarnings("ignore", message="The image is already an RGB")
65+
transformed_img = classify_transform(max_size=128)(image=img)["image"]
66+
assert transformed_img.shape == (128, 128, 3)
2667
assert transformed_img.dtype == np.float32
2768

2869
def test_custom_mean_std(self):
29-
img = Image.new("RGB", (300, 300), color="red")
70+
img = np.array(Image.new("RGB", (300, 300), color="red"))
3071
mean = (0.5, 0.5, 0.5)
3172
std = (0.5, 0.5, 0.5)
32-
transformed_img = classify_transforms(img, mean=mean, std=std)
33-
assert transformed_img.shape == (3, 224, 224)
73+
with warnings.catch_warnings():
74+
warnings.filterwarnings("ignore", message="The image is already an RGB")
75+
transformed_img = classify_transform(max_size=224)(
76+
image=img, normalize_mean=mean, normalize_std=std
77+
)["image"]
78+
assert transformed_img.shape == (224, 224, 3)
3479
assert transformed_img.dtype == np.float32
3580

3681
def test_custom_interpolation(self):
37-
img = Image.new("RGB", (300, 300), color="red")
38-
transformed_img = classify_transforms(
39-
img, interpolation=Image.Resampling.NEAREST
40-
)
41-
assert transformed_img.shape == (3, 224, 224)
42-
assert transformed_img.dtype == np.float32
43-
44-
def test_custom_crop_fraction(self):
45-
img = Image.new("RGB", (300, 300), color="red")
46-
transformed_img = classify_transforms(img, crop_fraction=0.8)
47-
assert transformed_img.shape == (3, 224, 224)
82+
img = np.array(Image.new("RGB", (300, 300), color="red"))
83+
with warnings.catch_warnings():
84+
warnings.filterwarnings("ignore", message="The image is already an RGB")
85+
transformed_img = classify_transform(max_size=224)(
86+
image=img, interpolation=Image.Resampling.NEAREST
87+
)["image"]
88+
assert transformed_img.shape == (224, 224, 3)
4889
assert transformed_img.dtype == np.float32
4990

5091

@@ -55,13 +96,15 @@ def __init__(self, name):
5596

5697
class TestImageClassifier:
5798
def test_preprocess_rgb_image(self):
58-
img = Image.new("RGB", (300, 300), color="red")
99+
img = np.array(Image.new("RGB", (300, 300), color="red"))
59100
classifier = ImageClassifier(
60101
model_name="test_model", label_names=["label1", "label2"]
61102
)
62103
preprocessed_img = classifier.preprocess(img)
63104
assert preprocessed_img.shape == (1, 3, 224, 224)
64105
assert preprocessed_img.dtype == np.float32
106+
assert np.all(preprocessed_img[:, 0, :, :] == 1.0) # red channel
107+
assert np.all(preprocessed_img[:, 1:, :, :] == 0.0) # green and blue channels
65108

66109
def test_postprocess_single_output(self):
67110
classifier = ImageClassifier(
@@ -115,7 +158,7 @@ def test_postprocess_multiple_raw_output_contents(self):
115158
assert str(e) == "expected 1 raw output content, got 2"
116159

117160
def test_predict(self):
118-
img = Image.new("RGB", (300, 300), color="red")
161+
img = np.array(Image.new("RGB", (300, 300), color="red"))
119162
classifier = ImageClassifier(
120163
model_name="test_model", label_names=["label1", "label2"]
121164
)
@@ -144,11 +187,22 @@ def test_predict(self):
144187
):
145188
result = classifier.predict(img, triton_uri)
146189

147-
assert len(result) == 2
148-
assert result[0][0] == "label1"
149-
assert np.isclose(result[0][1], 0.8)
150-
assert result[1][0] == "label2"
151-
assert np.isclose(result[1][1], 0.2)
190+
assert isinstance(result, ImageClassificationResult)
191+
predictions = result.predictions
192+
assert len(predictions) == 2
193+
assert predictions[0][0] == "label1"
194+
assert np.isclose(predictions[0][1], 0.8)
195+
assert predictions[1][0] == "label2"
196+
assert np.isclose(predictions[1][1], 0.2)
197+
198+
assert isinstance(result.metrics, dict)
199+
assert result.metrics.keys() == {
200+
"preprocess_time",
201+
"grpc_request_build_time",
202+
"triton_inference_time",
203+
"postprocess_time",
204+
}
205+
assert all(isinstance(value, float) for value in result.metrics.values())
152206

153207
classifier.preprocess.assert_called_once_with(img)
154208
grpc_stub.ModelInfer.assert_called_once()

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)