Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ dependencies = [
"httpx>=0.28.1",
"pydantic>=2.10.5",
"pydantic-settings>=2.7.1",
"shapely>=2.0.6",
"colour-science>=0.4.6",
]

[build-system]
Expand Down
31 changes: 31 additions & 0 deletions src/color_correction_asdfghjkl/constant/color_checker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import numpy as np

# in BGR format
reference_color_d50 = np.array(
[
[68, 82, 115], # 1. Dark skin
[128, 149, 195], # 2. Light skin
[157, 123, 93], # 3. Blue sky
[65, 108, 91], # 4. Foliage
[175, 129, 130], # 5. Blue flower
[171, 191, 99], # 6. Bluish green
[46, 123, 220], # 7. Orange
[168, 92, 72], # 8. Purplish blue
[97, 84, 194], # 9. Moderate red
[104, 59, 91], # 10. Purple
[62, 189, 161], # 11. Yellow green
[40, 161, 229], # 12. Orange yellow
[147, 63, 42], # 13. Blue
[72, 149, 72], # 14. Green
[57, 50, 175], # 15. Red
[22, 200, 238], # 16. Yellow
[150, 84, 188], # 17. Magenta
[166, 137, 0], # 18. Cyan
[240, 245, 245], # 19. White 9.5
[201, 202, 201], # 20. Neutral 8
[162, 162, 161], # 21. Neutral 6.5
[121, 121, 120], # 22. Neutral 5
[85, 85, 83], # 23. Neutral 3.5
[51, 50, 50], # 24. Black 2
],
)
8 changes: 8 additions & 0 deletions src/color_correction_asdfghjkl/core/correction/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from abc import ABC, abstractmethod

import numpy as np


class BaseComputeCorrection(ABC):
@abstractmethod
def fit(self, image: np.ndarray) -> np.ndarray: ...
41 changes: 41 additions & 0 deletions src/color_correction_asdfghjkl/core/correction/least_squares.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import time

import numpy as np

from color_correction_asdfghjkl.core.correction.base import BaseComputeCorrection


class LeastSquaresRegression(BaseComputeCorrection):
def __init__(self) -> None:
self.model = None

def fit(
self,
input_patches: np.ndarray,
reference_patches: np.ndarray,
) -> np.ndarray:
start_time = time.perf_counter()

self.model = np.linalg.lstsq(
a=input_patches,
b=reference_patches,
rcond=None,
)[0] # get only matrix of coefficients

exc_time = time.perf_counter() - start_time
print(f"Least Squares Regression: {exc_time} seconds")
return self.model

def compute_correction(self, input_image: np.ndarray) -> np.ndarray:
if self.model is None:
raise ValueError("Model is not fitted yet. Please call fit() method first.")

# Reshape
h, w, c = input_image.shape
image = input_image.reshape(-1, 3).astype(np.float32)

image = np.dot(input_image, self.model)

# Clip dan convert kembali ke uint8
corrected_image = np.clip(image, 0, 255).astype(np.uint8).reshape(h, w, c)
return corrected_image
48 changes: 48 additions & 0 deletions src/color_correction_asdfghjkl/schemas/yolov8_det.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
import numpy as np
from pydantic import BaseModel

from color_correction_asdfghjkl.utils.geometry_processing import (
extract_intersecting_patches,
generate_expected_patches,
suggest_missing_patch_coordinates,
)
from color_correction_asdfghjkl.utils.image_processing import (
calculate_mean_rgb,
crop_region_with_margin,
)
from color_correction_asdfghjkl.utils.yolo_utils import draw_detections

box_tuple = tuple[int, int, int, int]
Expand Down Expand Up @@ -41,3 +50,42 @@ def draw_detections(self, image: np.ndarray, mask_alpha: float = 0.2) -> np.ndar
class_ids=self.class_ids,
mask_alpha=mask_alpha,
)

def get_list_patches(self, input_image: np.ndarray) -> list[np.ndarray]:
ls_cards, ls_patches = self.get_each_class_box()

if len(ls_cards) == 0:
raise ValueError("No cards detected")

if len(ls_patches) == 0:
raise ValueError("No patches detected")

# Extract card coordinates
card_box = ls_cards[0]
ls_grid_card = generate_expected_patches(card_box)

# get ls grid card
ls_ordered_patch_bbox = extract_intersecting_patches(
ls_patches=ls_patches,
ls_grid_card=ls_grid_card,
)

if None in ls_ordered_patch_bbox:
# Auto fill missing patches ----------------
print("Auto fill missing patch...", ls_ordered_patch_bbox)
d_suggest = suggest_missing_patch_coordinates(ls_ordered_patch_bbox)
for idx, patch in d_suggest.items():
ls_ordered_patch_bbox[idx] = patch
print(f"result len = {len(ls_ordered_patch_bbox)}")

ls_rgb_mean_patch = []
for coord_patch in ls_ordered_patch_bbox:
cropped_patch = crop_region_with_margin(
image=input_image,
coordinates=coord_patch,
margin_ratio=0.2,
)
rgb_mean_patch = calculate_mean_rgb(cropped_patch)
ls_rgb_mean_patch.append(rgb_mean_patch)

return ls_rgb_mean_patch
185 changes: 185 additions & 0 deletions src/color_correction_asdfghjkl/services/color_correction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
from typing import Literal

import colour as cl
import cv2
import numpy as np
from numpy.typing import NDArray

from color_correction_asdfghjkl.constant.color_checker import reference_color_d50
from color_correction_asdfghjkl.core.card_detection.yolov8_det_onnx import (
YOLOv8CardDetector,
)
from color_correction_asdfghjkl.core.correction.least_squares import (
LeastSquaresRegression,
)
from color_correction_asdfghjkl.utils.image_processing import generate_image_patches

ColorPatchType = NDArray[np.uint8]
ImageType = NDArray[np.uint8]


class ColorCorrection:
"""Color correction handler using color card detection and correction models.

Parameters
----------
detection_model : {'yolov8'}
The model to use for color card detection.
correction_model : {'least_squares'}
The model to use for color correction.
reference_color_card : str, optional
Path to the reference color card image.
use_gpu : bool, default=True
Whether to use GPU for card detection.
"""

def __init__(
self,
detection_model: Literal["yolov8"] = "yolov8",
correction_model: Literal["least_squares"] = "least_squares",
reference_color_card: str | None = None,
use_gpu: bool = True,
) -> None:
self.reference_color_card = reference_color_card or reference_color_d50
self.correction_model = self._initialize_correction_model(correction_model)
self.card_detector = self._initialize_detector(detection_model, use_gpu)
self.correction_weights: NDArray | None = None

def _initialize_correction_model(self, model_name: str) -> LeastSquaresRegression:
if model_name == "least_squares":
return LeastSquaresRegression()
raise ValueError(f"Unsupported correction model: {model_name}")

def _initialize_detector(
self,
model_name: str,
use_gpu: bool,
) -> YOLOv8CardDetector:
if model_name == "yolov8":
return YOLOv8CardDetector(use_gpu=use_gpu)
raise ValueError(f"Unsupported detection model: {model_name}")

def extract_color_patches(self, input_image: ImageType) -> list[ColorPatchType]:
"""Extract color patches from input image using card detection.

Parameters
----------
input_image : NDArray
Input image from which to extract color patches.

Returns
-------
list[NDArray]
List of BGR mean values for each detected patch.
Each element is an array of shape (3,) containing [B, G, R] values.
"""
detection_result = self.card_detector.detect(image=input_image.copy())
return detection_result.get_list_patches(input_image=input_image.copy())

def fit(
self,
input_image: ImageType,
reference_image: ImageType | None = None,
) -> tuple[NDArray, list[ColorPatchType], list[ColorPatchType]]:
"""Fit color correction model using input and reference images.

Parameters
----------
input_image : NDArray
Image BGR to be corrected that contains color checker classic 24 patches.
reference_image : NDArray, optional
Image BGR to be reference that contains color checker classic 24 patches.

Returns
-------
Tuple[NDArray, List[NDArray], List[NDArray]]
Correction weights, input patches, and reference patches.
"""
input_patches = self.extract_color_patches(input_image=input_image)
reference_patches = (
reference_color_d50
if reference_image is None
else self.extract_color_patches(reference_image)
)

self.correction_weights = self.correction_model.fit(
input_patches=input_patches,
reference_patches=reference_patches,
)
return self.correction_weights, input_patches, reference_patches

def correct_image(self, input_image: ImageType) -> ImageType:
"""Apply color correction to input image.

Parameters
----------
input_image : NDArray
Image to be color corrected.

Returns
-------
NDArray
Color corrected image.
"""
if self.correction_weights is None:
raise RuntimeError("Model must be fitted before correction")

return self.correction_model.compute_correction(
input_image=input_image.copy(),
)

def calculate_color_difference(
self,
image1: ImageType,
image2: ImageType,
) -> tuple[float, float, float, float]:
"""Calculate color difference metrics between two images.

Parameters
----------
image1, image2 : NDArray
Images to compare in BGR format.

Returns
-------
Tuple[float, float, float, float]
Minimum, maximum, mean, and standard deviation of delta E values.
"""
rgb1 = cv2.cvtColor(image1, cv2.COLOR_BGR2RGB)
rgb2 = cv2.cvtColor(image2, cv2.COLOR_BGR2RGB)

lab1 = cl.XYZ_to_Lab(cl.sRGB_to_XYZ(rgb1 / 255))
lab2 = cl.XYZ_to_Lab(cl.sRGB_to_XYZ(rgb2 / 255))

delta_e = cl.difference.delta_E(lab1, lab2, method="CIE 2000")

return (
float(np.min(delta_e)),
float(np.max(delta_e)),
float(np.mean(delta_e)),
float(np.std(delta_e)),
)


if __name__ == "__main__":
import os

image_path = "color_correction_asdfghjkl/asset/images/cc-1.jpg"
image_path = "color_correction_asdfghjkl/asset/images/cc-19.png"
filename = os.path.basename(image_path)
cc = ColorCorrection(detection_model="yolov8", correction_model="least_squares")
input_image = cv2.imread(image_path)
_, ls_input_patches, ls_reference_patches = cc.fit(input_image=input_image)
corrected_image = cc.correct_image(input_image=input_image)

in_img_patch = generate_image_patches(ls_input_patches)
ref_img_patch = generate_image_patches(ls_reference_patches)
cv2.imwrite(f"input_image_patches-{filename}.jpg", in_img_patch)
cv2.imwrite(f"reference_image_patches-{filename}.jpg", ref_img_patch)
cc.calculate_color_difference(in_img_patch, ref_img_patch)

ls_correct_patch = cc.extract_color_patches(input_image=corrected_image)
corrected_img_patch = generate_image_patches(ls_correct_patch)
cv2.imwrite(f"corrected_image_patches-{filename}.jpg", corrected_img_patch)
cc.calculate_color_difference(corrected_img_patch, ref_img_patch)
cv2.imwrite(f"corrected_image-{filename}.jpg", corrected_image)
Loading
Loading