Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 75 additions & 99 deletions src/openfoodfacts/ml/image_classification.py
Original file line number Diff line number Diff line change
@@ -1,88 +1,55 @@
import dataclasses
import logging
import math
import time
import typing
from typing import Optional
import warnings

import albumentations as A
import numpy as np
from PIL import Image, ImageOps
from tritonclient.grpc import service_pb2

from openfoodfacts.ml.triton import (
add_triton_infer_input_tensor,
get_triton_inference_stub,
)
from openfoodfacts.utils import PerfTimer

logger = logging.getLogger(__name__)


def classify_transforms(
img: Image.Image,
size: int = 224,
mean: tuple[float, float, float] = (0.0, 0.0, 0.0),
std: tuple[float, float, float] = (1.0, 1.0, 1.0),
interpolation: Image.Resampling = Image.Resampling.BILINEAR,
crop_fraction: float = 1.0,
) -> np.ndarray:
@dataclasses.dataclass
class ImageClassificationResult:
"""The result of an image classification model.

Attributes:
predictions (list[tuple[str, float]]): The list of label names and their
corresponding confidence scores, ordered by confidence score in
descending order.
metrics (dict[str, float]): The performance metrics of the classification.
Each key is the name of the metric (a step in the inference
process), and the value is the time taken in seconds.
The following metrics are provided:
- preprocess_time: time taken to preprocess the image
- grpc_request_build_time: time taken to build the gRPC request
- triton_inference_time: time taken for Triton inference
- postprocess_time: time taken to postprocess the results
"""
Applies a series of image transformations including resizing, center
cropping, normalization, and conversion to a NumPy array.

Transformation steps is based on the one used in the Ultralytics library:
https://github.com/ultralytics/ultralytics/blob/main/ultralytics/data/augment.py#L2319

:param img: Input Pillow image.
:param size: The target size for the transformed image (shortest edge).
:param mean: Mean values for each RGB channel used in normalization.
:param std: Standard deviation values for each RGB channel used in
normalization.
:param interpolation: Interpolation method from PIL (
Image.Resampling.NEAREST, Image.Resampling.BILINEAR,
Image.Resampling.BICUBIC).
:param crop_fraction: Fraction of the image to be cropped.
:return: The transformed image as a NumPy array.
"""
if img.mode != "RGB":
img = img.convert("RGB")

# Rotate the image based on the EXIF orientation if needed
img = typing.cast(Image.Image, ImageOps.exif_transpose(img))

# Step 1: Resize while preserving the aspect ratio
width, height = img.size

# Calculate scale size while preserving aspect ratio
scale_size = math.floor(size / crop_fraction)

aspect_ratio = width / height
if width < height:
new_width = scale_size
new_height = int(new_width / aspect_ratio)
else:
new_height = scale_size
new_width = int(new_height * aspect_ratio)

img = img.resize((new_width, new_height), interpolation)

# Step 2: Center crop
left = (new_width - size) // 2
top = (new_height - size) // 2
right = left + size
bottom = top + size
img = img.crop((left, top, right, bottom))

# Step 3: Convert the image to a NumPy array and scale pixel values to
# [0, 1]
img_array = np.array(img).astype(np.float32) / 255.0
predictions: list[tuple[str, float]]
metrics: dict[str, float]

# Step 4: Normalize the image
mean_np = np.array(mean, dtype=np.float32).reshape(1, 1, 3)
std_np = np.array(std, dtype=np.float32).reshape(1, 1, 3)
img_array = (img_array - mean_np) / std_np

# Step 5: Change the order of dimensions from (H, W, C) to (C, H, W)
img_array = np.transpose(img_array, (2, 0, 1))
return img_array
def _classify_transform(
    max_size: int,
    normalize_mean: tuple[float, float, float] = (0.0, 0.0, 0.0),
    normalize_std: tuple[float, float, float] = (1.0, 1.0, 1.0),
):
    """Build the albumentations preprocessing pipeline for image classification.

    The pipeline resizes the image so its longest edge equals ``max_size``,
    pads it to a ``max_size`` x ``max_size`` square, converts it to RGB, and
    normalizes the pixel values.

    :param max_size: target side length (in pixels) of the square output.
    :param normalize_mean: per-channel mean used for normalization.
    :param normalize_std: per-channel standard deviation used for
        normalization.
    :return: the composed albumentations transform.
    """
    steps = [
        A.LongestMaxSize(max_size=max_size, p=1.0),
        A.PadIfNeeded(min_height=max_size, min_width=max_size, p=1.0),
        A.ToRGB(p=1.0),
        A.Normalize(mean=normalize_mean, std=normalize_std, p=1.0),
    ]
    return A.Compose(steps)


class ImageClassifier:
Expand All @@ -101,49 +68,58 @@ def __init__(self, model_name: str, label_names: list[str], image_size: int = 22

def predict(
self,
image: Image.Image,
image: np.ndarray,
triton_uri: str,
model_version: Optional[str] = None,
) -> list[tuple[str, float]]:
model_version: str | None = None,
) -> ImageClassificationResult:
"""Run an image classification model on an image.

The model is expected to have been trained with Ultralytics library
(Yolov8).
(any Yolo classification model).

:param image: the input Pillow image
:param image: the input NumPy array image
:param triton_uri: URI of the Triton Inference Server, defaults to
None. If not provided, the default value from settings is used.
:return: the prediction results as a list of tuples (label, confidence)
:param model_version: the version of the model to use, defaults to
None. If not provided, the latest version is used.
:return: the prediction results as an ImageClassificationResult
"""
image_array = self.preprocess(image)

grpc_stub = get_triton_inference_stub(triton_uri)
request = service_pb2.ModelInferRequest()
request.model_name = self.model_name
if model_version:
request.model_version = model_version
add_triton_infer_input_tensor(
request, name="images", data=image_array, datatype="FP32"
)
start_time = time.monotonic()
response = grpc_stub.ModelInfer(request)
latency = time.monotonic() - start_time
logger.debug("Inference time for %s: %s", self.model_name, latency)

start_time = time.monotonic()
result = self.postprocess(response)
latency = time.monotonic() - start_time
logger.debug("Post-processing time for %s: %s", self.model_name, latency)
return result

def preprocess(self, image: Image.Image) -> np.ndarray:
metrics: dict[str, float] = {}

with PerfTimer("preprocess_time", metrics):
image_array = self.preprocess(image)

with PerfTimer("grpc_request_build_time", metrics):
request = service_pb2.ModelInferRequest()
request.model_name = self.model_name
if model_version:
request.model_version = model_version
add_triton_infer_input_tensor(
request, name="images", data=image_array, datatype="FP32"
)

with PerfTimer("triton_inference_time", metrics):
grpc_stub = get_triton_inference_stub(triton_uri)
response = grpc_stub.ModelInfer(request)

with PerfTimer("postprocess_time", metrics):
predictions = self.postprocess(response)

return ImageClassificationResult(predictions=predictions, metrics=metrics)

def preprocess(self, image: np.ndarray) -> np.ndarray:
"""Preprocess an image for object detection.

:param image: the input Pillow image
:param image: the input NumPy array image
:return: the preprocessed image as a NumPy array
"""
image_array = classify_transforms(image, size=self.image_size)
return np.expand_dims(image_array, axis=0)
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="The image is already an RGB")
image_array = _classify_transform(max_size=self.image_size)(image=image)[
"image"
]
image_array = np.transpose(image_array, (2, 0, 1))[np.newaxis, :] # HWC to CHW
return image_array

def postprocess(
self, response: service_pb2.ModelInferResponse
Expand Down
118 changes: 86 additions & 32 deletions tests/ml/test_image_classification.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,91 @@
import warnings
from unittest.mock import MagicMock, patch

import numpy as np
from PIL import Image

from openfoodfacts.ml.image_classification import ImageClassifier, classify_transforms
from openfoodfacts.ml.image_classification import (
ImageClassificationResult,
ImageClassifier,
_classify_transform,
)


class TestClassifyTransforms:
class TestClassifyTransform:
def test_rgb_image(self):
img = Image.new("RGB", (300, 300), color="red")
transformed_img = classify_transforms(img)
assert transformed_img.shape == (3, 224, 224)
img = np.array(Image.new("RGB", (300, 300), color="red"))
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="The image is already an RGB")
transformed_img = _classify_transform(max_size=224)(image=img)["image"]
assert transformed_img.shape == (224, 224, 3)
assert transformed_img.dtype == np.float32

def test_non_square_image_aspect_ratio_lt_1(self):
# width=150, height=300

Check warning on line 24 in tests/ml/test_image_classification.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Remove this commented out code.

See more on https://sonarcloud.io/project/issues?id=openfoodfacts_openfoodfacts-python&issues=AZ0B86tfKlOOTc7jeBXK&open=AZ0B86tfKlOOTc7jeBXK&pullRequest=460
img = np.array(Image.new("RGB", (150, 300), color="red"))
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="The image is already an RGB")
transformed_img = _classify_transform(max_size=300)(image=img)["image"]
assert transformed_img.shape == (300, 300, 3)
assert transformed_img.dtype == np.float32
# assert that the green and blue channels are zero
assert np.sum(transformed_img[:, :, 1:3]) == 0.0

Check warning on line 32 in tests/ml/test_image_classification.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Do not perform equality checks with floating point values.

See more on https://sonarcloud.io/project/issues?id=openfoodfacts_openfoodfacts-python&issues=AZ0B86tfKlOOTc7jeBXL&open=AZ0B86tfKlOOTc7jeBXL&pullRequest=460
# image is in HWC
red_channel = transformed_img[:, :, 0]
assert np.all(red_channel[:, :75] == 0.0)

Check warning on line 35 in tests/ml/test_image_classification.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Do not perform equality checks with floating point values.

See more on https://sonarcloud.io/project/issues?id=openfoodfacts_openfoodfacts-python&issues=AZ0B86tfKlOOTc7jeBXM&open=AZ0B86tfKlOOTc7jeBXM&pullRequest=460
assert np.all(red_channel[:, 75:150] == 1.0)

Check warning on line 36 in tests/ml/test_image_classification.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Do not perform equality checks with floating point values.

See more on https://sonarcloud.io/project/issues?id=openfoodfacts_openfoodfacts-python&issues=AZ0B86tfKlOOTc7jeBXN&open=AZ0B86tfKlOOTc7jeBXN&pullRequest=460
assert np.all(red_channel[:, 225:] == 0.0)

Check warning on line 37 in tests/ml/test_image_classification.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Do not perform equality checks with floating point values.

See more on https://sonarcloud.io/project/issues?id=openfoodfacts_openfoodfacts-python&issues=AZ0B86tfKlOOTc7jeBXO&open=AZ0B86tfKlOOTc7jeBXO&pullRequest=460

def test_non_square_image_aspect_ratio_gt_1(self):
# width=600, height=300

Check warning on line 40 in tests/ml/test_image_classification.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Remove this commented out code.

See more on https://sonarcloud.io/project/issues?id=openfoodfacts_openfoodfacts-python&issues=AZ0B86tfKlOOTc7jeBXP&open=AZ0B86tfKlOOTc7jeBXP&pullRequest=460
img = np.array(Image.new("RGB", (600, 300), color="red"))
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="The image is already an RGB")
transformed_img = _classify_transform(max_size=300)(image=img)["image"]
assert transformed_img.shape == (300, 300, 3)
assert transformed_img.dtype == np.float32
# assert that the green and blue channels are zero
assert np.sum(transformed_img[:, :, 1:3]) == 0.0

Check warning on line 48 in tests/ml/test_image_classification.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Do not perform equality checks with floating point values.

See more on https://sonarcloud.io/project/issues?id=openfoodfacts_openfoodfacts-python&issues=AZ0B86tfKlOOTc7jeBXQ&open=AZ0B86tfKlOOTc7jeBXQ&pullRequest=460
# image is in HWC
red_channel = transformed_img[:, :, 0]
assert np.all(red_channel[:75, :] == 0.0)

Check warning on line 51 in tests/ml/test_image_classification.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Do not perform equality checks with floating point values.

See more on https://sonarcloud.io/project/issues?id=openfoodfacts_openfoodfacts-python&issues=AZ0B86tfKlOOTc7jeBXR&open=AZ0B86tfKlOOTc7jeBXR&pullRequest=460
assert np.all(red_channel[75:150, :] == 1.0)

Check warning on line 52 in tests/ml/test_image_classification.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Do not perform equality checks with floating point values.

See more on https://sonarcloud.io/project/issues?id=openfoodfacts_openfoodfacts-python&issues=AZ0B86tfKlOOTc7jeBXS&open=AZ0B86tfKlOOTc7jeBXS&pullRequest=460
assert np.all(red_channel[225:, :] == 0.0)

Check warning on line 53 in tests/ml/test_image_classification.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Do not perform equality checks with floating point values.

See more on https://sonarcloud.io/project/issues?id=openfoodfacts_openfoodfacts-python&issues=AZ0B86tfKlOOTc7jeBXT&open=AZ0B86tfKlOOTc7jeBXT&pullRequest=460

def test_non_rgb_image(self):
img = Image.new("L", (300, 300), color="red")
transformed_img = classify_transforms(img)
assert transformed_img.shape == (3, 224, 224)
img = np.array(Image.new("L", (300, 300), color="red"))
transformed_img = _classify_transform(max_size=224)(image=img)["image"]
assert transformed_img.shape == (224, 224, 3)
assert transformed_img.dtype == np.float32

def test_custom_size(self):
img = Image.new("RGB", (300, 300), color="red")
transformed_img = classify_transforms(img, size=128)
assert transformed_img.shape == (3, 128, 128)
img = np.array(Image.new("RGB", (300, 300), color="red"))
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="The image is already an RGB")
transformed_img = _classify_transform(max_size=128)(image=img)["image"]
assert transformed_img.shape == (128, 128, 3)
assert transformed_img.dtype == np.float32

def test_custom_mean_std(self):
img = Image.new("RGB", (300, 300), color="red")
img = np.array(Image.new("RGB", (300, 300), color="red"))
mean = (0.5, 0.5, 0.5)
std = (0.5, 0.5, 0.5)
transformed_img = classify_transforms(img, mean=mean, std=std)
assert transformed_img.shape == (3, 224, 224)
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="The image is already an RGB")
transformed_img = _classify_transform(max_size=224)(
image=img, normalize_mean=mean, normalize_std=std
)["image"]
assert transformed_img.shape == (224, 224, 3)
assert transformed_img.dtype == np.float32

def test_custom_interpolation(self):
img = Image.new("RGB", (300, 300), color="red")
transformed_img = classify_transforms(
img, interpolation=Image.Resampling.NEAREST
)
assert transformed_img.shape == (3, 224, 224)
assert transformed_img.dtype == np.float32

def test_custom_crop_fraction(self):
img = Image.new("RGB", (300, 300), color="red")
transformed_img = classify_transforms(img, crop_fraction=0.8)
assert transformed_img.shape == (3, 224, 224)
img = np.array(Image.new("RGB", (300, 300), color="red"))
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="The image is already an RGB")
transformed_img = _classify_transform(max_size=224)(
image=img, interpolation=Image.Resampling.NEAREST
)["image"]
assert transformed_img.shape == (224, 224, 3)
assert transformed_img.dtype == np.float32


Expand All @@ -55,13 +96,15 @@

class TestImageClassifier:
def test_preprocess_rgb_image(self):
img = Image.new("RGB", (300, 300), color="red")
img = np.array(Image.new("RGB", (300, 300), color="red"))
classifier = ImageClassifier(
model_name="test_model", label_names=["label1", "label2"]
)
preprocessed_img = classifier.preprocess(img)
assert preprocessed_img.shape == (1, 3, 224, 224)
assert preprocessed_img.dtype == np.float32
assert np.all(preprocessed_img[:, 0, :, :] == 1.0) # red channel

Check warning on line 106 in tests/ml/test_image_classification.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Do not perform equality checks with floating point values.

See more on https://sonarcloud.io/project/issues?id=openfoodfacts_openfoodfacts-python&issues=AZ0B86tfKlOOTc7jeBXU&open=AZ0B86tfKlOOTc7jeBXU&pullRequest=460
assert np.all(preprocessed_img[:, 1:, :, :] == 0.0) # green and blue channels

Check warning on line 107 in tests/ml/test_image_classification.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Do not perform equality checks with floating point values.

See more on https://sonarcloud.io/project/issues?id=openfoodfacts_openfoodfacts-python&issues=AZ0B86tfKlOOTc7jeBXV&open=AZ0B86tfKlOOTc7jeBXV&pullRequest=460

def test_postprocess_single_output(self):
classifier = ImageClassifier(
Expand Down Expand Up @@ -115,7 +158,7 @@
assert str(e) == "expected 1 raw output content, got 2"

def test_predict(self):
img = Image.new("RGB", (300, 300), color="red")
img = np.array(Image.new("RGB", (300, 300), color="red"))
classifier = ImageClassifier(
model_name="test_model", label_names=["label1", "label2"]
)
Expand Down Expand Up @@ -144,11 +187,22 @@
):
result = classifier.predict(img, triton_uri)

assert len(result) == 2
assert result[0][0] == "label1"
assert np.isclose(result[0][1], 0.8)
assert result[1][0] == "label2"
assert np.isclose(result[1][1], 0.2)
assert isinstance(result, ImageClassificationResult)
predictions = result.predictions
assert len(predictions) == 2
assert predictions[0][0] == "label1"
assert np.isclose(predictions[0][1], 0.8)
assert predictions[1][0] == "label2"
assert np.isclose(predictions[1][1], 0.2)

assert isinstance(result.metrics, dict)
assert result.metrics.keys() == {
"preprocess_time",
"grpc_request_build_time",
"triton_inference_time",
"postprocess_time",
}
assert all(isinstance(value, float) for value in result.metrics.values())

classifier.preprocess.assert_called_once_with(img)
grpc_stub.ModelInfer.assert_called_once()
Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading