diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index d22799d..2f78fc3 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -23,9 +23,6 @@ jobs: - name: Extract version from Git tag id: get_version run: echo "VERSION=${GITHUB_REF#refs/tags/v}" >> $GITHUB_ENV - - # - name: Install the project - # run: uv sync --all-groups --no-group dev-model - name: Clear directory run: rm -rf dist diff --git a/.gitignore b/.gitignore index 6f4255f..6267f05 100644 --- a/.gitignore +++ b/.gitignore @@ -16,4 +16,8 @@ wheels/ .pytest_cache/ .ruff_cache/ -*.pyc \ No newline at end of file +*.pyc + +tmp/ +zzz/ +color_correction_asdfghjkl/asset/images/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 49cff5c..180e211 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,7 +12,7 @@ repos: - id: trailing-whitespace # python code formatting - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.9.2 + rev: v0.9.4 hooks: - id: ruff types_or: [python, pyi, jupyter] diff --git a/.vscode/settings.json b/.vscode/settings.json index 3a0710d..366f08c 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,8 +1,24 @@ { "ruff.lint.args": [ - "--config=${workspaceFolder}/pyproject.toml" + "--config=${workspaceFolder}/pyproject.toml" ], "ruff.format.args": [ - "--config=${workspaceFolder}/pyproject.toml" - ] - } \ No newline at end of file + "--config=${workspaceFolder}/pyproject.toml" + ], + "python.testing.pytestArgs": [ + "tests -sv" + ], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true, + + // Auto activate virtual environment + "python.terminal.activateEnvironment": true, + "python.defaultInterpreterPath": "${workspaceFolder}/.venv/bin/python", + + // Configure terminal to open in workspace + "terminal.integrated.cwd": "${workspaceFolder}", + + // Git settings for pre-commit + "git.postCommitCommand": "push", + "git.hooks.enable": true +} diff --git a/CHANGELOG.md b/CHANGELOG.md index 49aa8ff..61bbe0d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,45 @@ # Changelog +## [v0.0.1b0] - 2025-02-03 + +### πŸ”§ Improvements +- **Color Correction Core** + - Added new correction models: polynomial, linear regression, and affine regression + - Improved patch detection and processing pipeline + - Added support for debug visualization outputs + - Enhanced color patch extraction with better error handling + +### 🎨 Features +- **Reference Colors** + - Added RGB format reference colors alongside BGR + - Improved color patch visualization and comparison tools + - Added support for custom reference images + +### πŸ“ Documentation +- **README Updates** + - Simplified usage documentation with clearer examples + - Added visual explanation of color correction workflow + - Updated installation and usage instructions + +### πŸ› οΈ Development +- **Project Structure** + - Reorganized core modules for better maintainability + - Added new utility modules for image processing + - Updated VSCode settings for better development experience + +### πŸ”¨ Build +- **Dependencies** + - Added scikit-learn for advanced correction models + - Updated ruff to v0.9.4 + - Added pre-commit hooks configuration + +### πŸ§ͺ Testing +- **Test Coverage** + - Added new test cases for image processing utilities + - Improved test organization and structure + + + ## [v0.0.1a2] - 2025-01-27 ### πŸš€ New Features @@ -51,4 +91,4 @@ ### πŸ“ Initial Setup - **Initialize project with Python version, .gitignore, VSCode settings, pre-commit configuration, and pyproject.toml** (71a8c74) -- **Add README.md for Color Correction package documentation** (2b35650) \ No newline at end of file +- **Add README.md for Color Correction package documentation** (2b35650) diff --git a/README.md b/README.md index 1108600..7affa76 100644 --- a/README.md +++ b/README.md @@ -1,38 +1,10 @@ -# 🎨 Color Correction Toolkit +# 🎨 Color Correction > **Note:** The "asdfghjkl" is just a placeholder due to some naming difficulties. This package is designed to perform color correction on images using the Color Checker Classic 24 Patch card. It provides a robust solution for ensuring accurate color representation in your images. -## πŸ› οΈ How It Works - -1. **Input Requirements**: - - An image containing the Color Checker Classic 24 Patch card (referred to as the color checker). - - A reference image of the color checker. - -2. **Process**: - - The package computes a color correction matrix based on the input image and reference image. - - This matrix can then be used to correct the colors of the input image or other images with the same color profile. - - **Workflow**: - ``` - Input Image + Reference Image -> Color Correction Matrix -> Corrected Image - ``` - -3. **Advanced Usage**: - - For multiple images, it’s recommended to preprocess the images using white balance before calculating the color correction matrix. This ensures that all images are in the same color profile, leading to a more accurate correction matrix. - - **Enhanced Workflow**: - ``` - Input Images -> White Balance -> Compute (Input Image with Color Checker, Reference Image) -> Color Correction Matrix -> Corrected Images - ``` - -## πŸ“ˆ Benefits -- **Consistency**: Ensure uniform color correction across multiple images. -- **Accuracy**: Leverage the color correction matrix for precise color adjustments. -- **Flexibility**: Adaptable for various image sets with different color profiles. - ## Installation ```bash @@ -41,24 +13,42 @@ pip install color-correction-asdfghjkl ## Usage ```python -import cv2 +# Step 1: Define the path to the input image +image_path = "asset/images/cc-19.png" -from color_correction_asdfghjkl import ColorCorrection +# Step 2: Load the input image +input_image = cv2.imread(image_path) -cc = ColorCorrection( +# Step 3: Initialize the color correction model with specified parameters +color_corrector = ColorCorrection( detection_model="yolov8", + detection_conf_th=0.25, correction_model="least_squares", - use_gpu=False, + degree=2, # for polynomial correction model + use_gpu=True, +) + +# Step 4: Extract color patches from the input image +color_corrector.set_input_patches(image=input_image, debug=True) +color_corrector.fit() +corrected_image = color_corrector.predict( + input_image=input_image, + debug=True, + debug_output_dir="zzz", ) -input_image = cv2.imread("cc-19.png") -cc.fit(input_image=input_image) -corrected_image = cc.correct_image(input_image=input_image) -cv2.imwrite("corrected_image.png", corrected_image) ``` Sample output: ![Sample Output](assets/sample-output-usage.png) +## πŸ“ˆ Benefits +- **Consistency**: Ensure uniform color correction across multiple images. +- **Accuracy**: Leverage the color correction matrix for precise color adjustments. +- **Flexibility**: Adaptable for various image sets with different color profiles. + +![How it works](assets/color-correction-how-it-works.png) + + ## πŸ“š References - [Color Checker Classic 24 Patch Card](https://www.xrite.com/categories/calibration-profiling/colorchecker-classic) @@ -68,3 +58,4 @@ Sample output: - [Automatic color correction with OpenCV and Python (PyImageSearch)](https://pyimagesearch.com/2021/02/15/automatic-color-correction-with-opencv-and-python/) - [ONNX-YOLOv8-Object-Detection](https://github.com/ibaiGorordo/ONNX-YOLOv8-Object-Detection) - [yolov8-triton](https://github.com/omarabid59/yolov8-triton/tree/main) +- [Streamlined Data Science Development: Organizing, Developing and Documenting Your Code](https://medium.com/henkel-data-and-analytics/streamlined-data-science-development-organizing-developing-and-documenting-your-code-bfd69e3ef4fb) diff --git a/assets/color-correction-how-it-works.png b/assets/color-correction-how-it-works.png new file mode 100644 index 0000000..1993883 Binary files /dev/null and b/assets/color-correction-how-it-works.png differ diff --git a/color_correction_asdfghjkl/__init__.py b/color_correction_asdfghjkl/__init__.py index 4279555..f74d97b 100644 --- a/color_correction_asdfghjkl/__init__.py +++ b/color_correction_asdfghjkl/__init__.py @@ -1,16 +1,18 @@ __version__ = "0.0.1a1" # fmt: off -from .constant.color_checker import reference_color_d50 as REFERENCE_COLOR_D50 # noqa: N812, I001 -from .core.card_detection.yolov8_det_onnx import YOLOv8CardDetector -from .schemas.yolov8_det import DetectionResult as YOLOv8DetectionResult +from .constant.color_checker import reference_color_d50_bgr as REFERENCE_COLOR_D50_BGR # noqa: N812, I001 +from .constant.color_checker import reference_color_d50_rgb as REFERENCE_COLOR_D50_RGB # noqa: N812, I001 +from .core.card_detection.det_yv8_onnx import YOLOv8CardDetector +from .schemas.det_yv8 import DetectionResult as YOLOv8DetectionResult from .services.color_correction import ColorCorrection # fmt: on __all__ = [ "__version__", - "REFERENCE_COLOR_D50", + "REFERENCE_COLOR_D50_BGR", + "REFERENCE_COLOR_D50_RGB", "ColorCorrection", "YOLOv8CardDetector", "YOLOv8DetectionResult", diff --git a/color_correction_asdfghjkl/constant/color_checker.py b/color_correction_asdfghjkl/constant/color_checker.py index c0ef633..eb8969c 100644 --- a/color_correction_asdfghjkl/constant/color_checker.py +++ b/color_correction_asdfghjkl/constant/color_checker.py @@ -1,7 +1,7 @@ import numpy as np # in BGR format -reference_color_d50 = np.array( +reference_color_d50_bgr = np.array( [ [68, 82, 115], # 1. Dark skin [128, 149, 195], # 2. Light skin @@ -29,3 +29,32 @@ [51, 50, 50], # 24. Black 2 ], ) + +reference_color_d50_rgb = np.array( + [ + [115, 82, 68], # 1. Dark skin + [195, 149, 128], # 2. Light skin + [93, 123, 157], # 3. Blue sky + [91, 108, 65], # 4. Foliage + [130, 129, 175], # 5. Blue flower + [99, 191, 171], # 6. Bluish green + [220, 123, 46], # 7. Orange + [72, 92, 168], # 8. Purplish blue + [194, 84, 97], # 9. Moderate red + [91, 59, 104], # 10. Purple + [161, 189, 62], # 11. Yellow green + [229, 161, 40], # 12. Orange yellow + [42, 63, 147], # 13. Blue + [72, 149, 72], # 14. Green + [175, 50, 57], # 15. Red + [238, 200, 22], # 16. Yellow + [188, 84, 150], # 17. Magenta + [0, 137, 166], # 18. Cyan + [245, 245, 240], # 19. White 9.5 + [201, 202, 201], # 20. Neutral 8 + [161, 162, 162], # 21. Neutral 6.5 + [120, 121, 121], # 22. Neutral 5 + [83, 85, 85], # 23. Neutral 3.5 + [50, 50, 51], # 24. Black 2 + ], +) diff --git a/color_correction_asdfghjkl/core/card_detection/base.py b/color_correction_asdfghjkl/core/card_detection/base.py index 4240e6b..a050e85 100644 --- a/color_correction_asdfghjkl/core/card_detection/base.py +++ b/color_correction_asdfghjkl/core/card_detection/base.py @@ -2,7 +2,7 @@ import numpy as np -from color_correction_asdfghjkl.schemas.yolov8_det import DetectionResult +from color_correction_asdfghjkl.schemas.det_yv8 import DetectionResult class BaseCardDetector(ABC): diff --git a/color_correction_asdfghjkl/core/card_detection/yolov8_det_onnx.py b/color_correction_asdfghjkl/core/card_detection/det_yv8_onnx.py similarity index 97% rename from color_correction_asdfghjkl/core/card_detection/yolov8_det_onnx.py rename to color_correction_asdfghjkl/core/card_detection/det_yv8_onnx.py index 15748d4..ebc5cf7 100644 --- a/color_correction_asdfghjkl/core/card_detection/yolov8_det_onnx.py +++ b/color_correction_asdfghjkl/core/card_detection/det_yv8_onnx.py @@ -5,7 +5,7 @@ import onnxruntime from color_correction_asdfghjkl.core.card_detection.base import BaseCardDetector -from color_correction_asdfghjkl.schemas.yolov8_det import DetectionResult +from color_correction_asdfghjkl.schemas.det_yv8 import DetectionResult from color_correction_asdfghjkl.utils.downloader import downloader_model_yolov8 from color_correction_asdfghjkl.utils.yolo_utils import ( multiclass_nms, @@ -34,7 +34,6 @@ def __init__( self.iou_threshold = iou_th self.use_gpu = use_gpu if path is None: - print("Auto downloading YOLOv8 model...") path = downloader_model_yolov8(use_gpu) self.__initialize_model(path) @@ -53,7 +52,7 @@ def detect(self, image: np.ndarray) -> DetectionResult: A dataclass containing detected bounding boxes, confidence scores, and class IDs. """ - input_tensor = self.__prepare_input(image) + input_tensor = self.__prepare_input(image.copy()) outputs = self.__inference(input_tensor) boxes, scores, class_ids = self.__process_output(outputs) diff --git a/color_correction_asdfghjkl/core/correction/__init__.py b/color_correction_asdfghjkl/core/correction/__init__.py new file mode 100644 index 0000000..cf1004a --- /dev/null +++ b/color_correction_asdfghjkl/core/correction/__init__.py @@ -0,0 +1,17 @@ +from color_correction_asdfghjkl.core.correction._factory import ( + CorrectionModelFactory, +) +from color_correction_asdfghjkl.core.correction.affine_reg import AffineReg +from color_correction_asdfghjkl.core.correction.least_squares import ( + LeastSquaresRegression, +) +from color_correction_asdfghjkl.core.correction.linear_reg import LinearReg +from color_correction_asdfghjkl.core.correction.polynomial import Polynomial + +__all__ = [ + "CorrectionModelFactory", + "LeastSquaresRegression", + "Polynomial", + "LinearReg", + "AffineReg", +] diff --git a/color_correction_asdfghjkl/core/correction/_factory.py b/color_correction_asdfghjkl/core/correction/_factory.py new file mode 100644 index 0000000..c7d6290 --- /dev/null +++ b/color_correction_asdfghjkl/core/correction/_factory.py @@ -0,0 +1,18 @@ +from color_correction_asdfghjkl.core.correction.affine_reg import AffineReg +from color_correction_asdfghjkl.core.correction.least_squares import ( + LeastSquaresRegression, +) +from color_correction_asdfghjkl.core.correction.linear_reg import LinearReg +from color_correction_asdfghjkl.core.correction.polynomial import Polynomial + + +class CorrectionModelFactory: + @staticmethod + def create(model_name: str, **kwargs: dict) -> ...: + model_registry = { + "least_squares": LeastSquaresRegression(), + "polynomial": Polynomial(**kwargs), + "linear_reg": LinearReg(), + "affine_reg": AffineReg(), + } + return model_registry.get(model_name) diff --git a/color_correction_asdfghjkl/core/correction/affine_reg.py b/color_correction_asdfghjkl/core/correction/affine_reg.py new file mode 100644 index 0000000..e6f7f32 --- /dev/null +++ b/color_correction_asdfghjkl/core/correction/affine_reg.py @@ -0,0 +1,41 @@ +import time + +import numpy as np +from sklearn.linear_model import LinearRegression + +from color_correction_asdfghjkl.core.correction.base import BaseComputeCorrection +from color_correction_asdfghjkl.utils.correction import ( + postprocessing_compute, + preprocessing_compute, +) + + +class AffineReg(BaseComputeCorrection): + def __init__(self) -> None: + self.model = None + + def fit( + self, + x_patches: np.ndarray, # input patches + y_patches: np.ndarray, # reference patches + ) -> np.ndarray: + start_time = time.perf_counter() + x_patches = np.array(x_patches) + print("x_patches.shape", x_patches.shape) + x_patches = np.hstack([x_patches, np.ones((x_patches.shape[0], 1))]) + self.model = LinearRegression(fit_intercept=False).fit(x_patches, y_patches) + + exc_time = time.perf_counter() - start_time + print(f"{self.__class__.__name__} Fit: {exc_time} seconds") + return self.model + + def compute_correction(self, input_image: np.ndarray) -> np.ndarray: + if self.model is None: + raise ValueError("Model is not fitted yet. Please call fit() method first.") + + org_input_shape = input_image.shape + input_image = preprocessing_compute(input_image) + input_image = np.hstack([input_image, np.ones((input_image.shape[0], 1))]) + image = self.model.predict(input_image) + corrected_image = postprocessing_compute(org_input_shape, image) + return corrected_image diff --git a/color_correction_asdfghjkl/core/correction/least_squares.py b/color_correction_asdfghjkl/core/correction/least_squares.py index 6dbdf03..1375d8e 100644 --- a/color_correction_asdfghjkl/core/correction/least_squares.py +++ b/color_correction_asdfghjkl/core/correction/least_squares.py @@ -3,6 +3,10 @@ import numpy as np from color_correction_asdfghjkl.core.correction.base import BaseComputeCorrection +from color_correction_asdfghjkl.utils.correction import ( + postprocessing_compute, + preprocessing_compute, +) class LeastSquaresRegression(BaseComputeCorrection): @@ -11,31 +15,28 @@ def __init__(self) -> None: def fit( self, - input_patches: np.ndarray, - reference_patches: np.ndarray, + x_patches: np.ndarray, # input patches + y_patches: np.ndarray, # reference patches ) -> np.ndarray: start_time = time.perf_counter() self.model = np.linalg.lstsq( - a=input_patches, - b=reference_patches, + a=x_patches, + b=y_patches, rcond=None, )[0] # get only matrix of coefficients exc_time = time.perf_counter() - start_time - print(f"Least Squares Regression: {exc_time} seconds") + print(f"{self.__class__.__name__} Fit: {exc_time} seconds") return self.model def compute_correction(self, input_image: np.ndarray) -> np.ndarray: if self.model is None: raise ValueError("Model is not fitted yet. Please call fit() method first.") - # Reshape - h, w, c = input_image.shape - image = input_image.reshape(-1, 3).astype(np.float32) - + # Input adalah array (N,3) dari nilai warna patches + org_input_shape = input_image.shape + input_image = preprocessing_compute(input_image) image = np.dot(input_image, self.model) - - # Clip dan convert kembali ke uint8 - corrected_image = np.clip(image, 0, 255).astype(np.uint8).reshape(h, w, c) + corrected_image = postprocessing_compute(org_input_shape, image) return corrected_image diff --git a/color_correction_asdfghjkl/core/correction/linear_reg.py b/color_correction_asdfghjkl/core/correction/linear_reg.py new file mode 100644 index 0000000..37db8bc --- /dev/null +++ b/color_correction_asdfghjkl/core/correction/linear_reg.py @@ -0,0 +1,38 @@ +import time + +import numpy as np +from sklearn.linear_model import LinearRegression + +from color_correction_asdfghjkl.core.correction.base import BaseComputeCorrection +from color_correction_asdfghjkl.utils.correction import ( + postprocessing_compute, + preprocessing_compute, +) + + +class LinearReg(BaseComputeCorrection): + def __init__(self) -> None: + self.model = None + + def fit( + self, + x_patches: np.ndarray, # input patches + y_patches: np.ndarray, # reference patches + ) -> np.ndarray: + start_time = time.perf_counter() + + self.model = LinearRegression(fit_intercept=False).fit(x_patches, y_patches) + + exc_time = time.perf_counter() - start_time + print(f"{self.__class__.__name__} Fit: {exc_time} seconds") + return self.model + + def compute_correction(self, input_image: np.ndarray) -> np.ndarray: + if self.model is None: + raise ValueError("Model is not fitted yet. Please call fit() method first.") + + org_input_shape = input_image.shape + input_image = preprocessing_compute(input_image) + image = self.model.predict(input_image) + corrected_image = postprocessing_compute(org_input_shape, image) + return corrected_image diff --git a/color_correction_asdfghjkl/core/correction/polynomial.py b/color_correction_asdfghjkl/core/correction/polynomial.py new file mode 100644 index 0000000..50080bc --- /dev/null +++ b/color_correction_asdfghjkl/core/correction/polynomial.py @@ -0,0 +1,46 @@ +import time + +import numpy as np +from sklearn.linear_model import LinearRegression +from sklearn.pipeline import make_pipeline +from sklearn.preprocessing import PolynomialFeatures + +from color_correction_asdfghjkl.core.correction.base import BaseComputeCorrection +from color_correction_asdfghjkl.utils.correction import ( + postprocessing_compute, + preprocessing_compute, +) + + +class Polynomial(BaseComputeCorrection): + def __init__(self, **kwargs: dict) -> None: + self.model = None + self.degree = kwargs.get("degree", 2) + + def fit( + self, + x_patches: np.ndarray, # input patches + y_patches: np.ndarray, # reference patches + **kwargs: dict, + ) -> np.ndarray: + start_time = time.perf_counter() + + degree = kwargs.get("degree", self.degree) + self.model = make_pipeline( + PolynomialFeatures(degree), + LinearRegression(), + ).fit(x_patches, y_patches) + + exc_time = time.perf_counter() - start_time + print(f"{self.__class__.__name__} Fit: {exc_time} seconds") + return self.model + + def compute_correction(self, input_image: np.ndarray) -> np.ndarray: + if self.model is None: + raise ValueError("Model is not fitted yet. Please call fit() method first.") + + org_input_shape = input_image.shape + input_image = preprocessing_compute(input_image) + image = self.model.predict(input_image) + corrected_image = postprocessing_compute(org_input_shape, image) + return corrected_image diff --git a/color_correction_asdfghjkl/processor/det_yv8.py b/color_correction_asdfghjkl/processor/det_yv8.py new file mode 100644 index 0000000..bfd0bda --- /dev/null +++ b/color_correction_asdfghjkl/processor/det_yv8.py @@ -0,0 +1,304 @@ +import cv2 +import numpy as np + +from color_correction_asdfghjkl.schemas.det_yv8 import DetectionResult +from color_correction_asdfghjkl.utils.geometry_processing import ( + extract_intersecting_patches, + generate_expected_patches, + suggest_missing_patch_coordinates, +) +from color_correction_asdfghjkl.utils.image_processing import ( + calc_mean_color_patch, + crop_region_with_margin, +) + +# Type aliases for better readability +BoundingBox = tuple[int, int, int, int] +RGB = tuple[float, float, float] +BGR = tuple[float, float, float] + + +class DetectionProcessor: + """ + A class to process color calibration card detections and extract color patches. + + This class handles the detection and processing of color calibration cards and their + individual color patches, including visualization and RGB value extraction. + """ + + @staticmethod + def get_each_class_box( + prediction: DetectionResult, + ) -> tuple[list[BoundingBox], list[BoundingBox]]: + """ + Separate detection boxes by class (cards and patches). + + Parameters + ---------- + prediction : DetectionResult + Detection results containing boxes and class IDs + + Returns + ------- + Tuple[List[BoundingBox], List[BoundingBox]] + Two lists containing card boxes and patch boxes respectively + """ + ls_cards = [ + box + for box, class_id in zip( + prediction.boxes, + prediction.class_ids, + strict=False, + ) + if class_id == 1 + ] + ls_patches = [ + box + for box, class_id in zip( + prediction.boxes, + prediction.class_ids, + strict=False, + ) + if class_id == 0 + ] + return ls_cards, ls_patches + + @staticmethod + def print_summary(prediction: DetectionResult) -> None: + """ + Print a summary of detected objects. + + Parameters + ---------- + prediction : DetectionResult + Detection results to summarize + """ + ls_cards, ls_patches = DetectionProcessor.get_each_class_box(prediction) + print(f"Number of cards detected: {len(ls_cards)}") + print(f"Number of patches detected: {len(ls_patches)}") + + @staticmethod + def process_patches( + input_image: np.ndarray, + ordered_patches: list[tuple[BoundingBox, tuple[int, int]] | None], + ) -> tuple[list[RGB], np.ndarray]: + """ + Process detected patches to extract RGB values and create a visualization. + + Parameters + ---------- + input_image : np.ndarray + Input image containing the patches + ordered_patches : List[Optional[BoundingBox]] + List of ordered patch coordinates + + Returns + ------- + Tuple[List[RGB], np.ndarray] + List of RGB values and visualization image + """ + patch_size = (50, 50, 1) + ls_bgr_mean_patch = [] + ls_horizontal_patch = [] + ls_vertical_patch = [] + + for idx, coord_patch in enumerate(ordered_patches, start=1): + if coord_patch is None: + continue + + bbox_patch, _ = coord_patch + + # Extract and process each patch + cropped_patch = crop_region_with_margin( + image=input_image, + coordinates=bbox_patch, + margin_ratio=0.2, + ) + bgr_mean_patch = calc_mean_color_patch(cropped_patch) + ls_bgr_mean_patch.append(bgr_mean_patch) + + # Build visualization + patch_viz = np.tile(bgr_mean_patch, patch_size) + ls_horizontal_patch.append(patch_viz) + if idx % 6 == 0: + ls_vertical_patch.append(np.hstack(ls_horizontal_patch)) + ls_horizontal_patch = [] + + patches_image = np.vstack(ls_vertical_patch) + return ls_bgr_mean_patch, patches_image + + @staticmethod + def extract_color_patches( + input_image: np.ndarray, + prediction: DetectionResult, + draw_processed_image: bool = False, + ) -> tuple[list[BGR], np.ndarray, np.ndarray | None]: + """ + Extract and process color patches from the detected calibration card. + + Parameters + ---------- + input_image : np.ndarray + Input image containing the color calibration card + prediction : DetectionResult + Detection results from YOLOv8 + draw_processed_image : bool, optional + Whether to create a visualization image, by default False + + Returns + ------- + Tuple[List[BGR], np.ndarray, Optional[np.ndarray]] + BGR values, patch visualization, and optional detection visualization + + Raises + ------ + ValueError + If no cards or patches are detected + """ + ls_cards, ls_patches = DetectionProcessor.get_each_class_box(prediction) + + if not ls_cards: + raise ValueError("No cards detected") + if not ls_patches: + raise ValueError("No patches detected") + + # Generate expected patch grid + card_box = ls_cards[0] + ls_grid_card = generate_expected_patches(card_box) + + # Match detected patches with grid + ls_ordered_patch_bbox = extract_intersecting_patches( + ls_patches=ls_patches, + ls_grid_card=ls_grid_card, + ) + + # Handle missing patches + d_suggest = None + if None in ls_ordered_patch_bbox: + print("Auto filling missing patches...") + ls_ordered_bbox_only = [ + patch[0] if patch is not None else None + for patch in ls_ordered_patch_bbox + ] + d_suggest = suggest_missing_patch_coordinates(ls_ordered_bbox_only) + for idx, patch in d_suggest.items(): + cxpatch = (patch[0] + patch[2]) // 2 + cypatch = (patch[1] + patch[3]) // 2 + ls_ordered_patch_bbox[idx] = (patch, (cxpatch, cypatch)) + + # Process patches and create visualizations + ls_bgr_mean_patch, grid_patch_img = DetectionProcessor.process_patches( + input_image=input_image, + ordered_patches=ls_ordered_patch_bbox, + ) + + detection_viz = None + if draw_processed_image: + detection_viz = DetectionProcessor.draw_preprocess( + image=input_image, + expected_boxes=ls_grid_card, + prediction=prediction, + ls_ordered_patch_bbox=ls_ordered_patch_bbox, + suggested_patches=d_suggest, + ) + + return ls_bgr_mean_patch, grid_patch_img, detection_viz + + @staticmethod + def draw_preprocess( + image: np.ndarray, + expected_boxes: list[BoundingBox], + prediction: DetectionResult, + ls_ordered_patch_bbox: list[BoundingBox | None], + suggested_patches: dict[int, BoundingBox] | None = None, + ) -> np.ndarray: + """ + Draw detection visualizations on the image. + + Parameters + ---------- + image : np.ndarray + Input image to draw on + boxes : List[BoundingBox] + List of bounding boxes to draw + patch_indices : Optional[List[int]] + Indices to label the patches + suggested_patches : Optional[Dict[int, BoundingBox]] + Additional suggested patch locations to draw + + Returns + ------- + np.ndarray + Image with visualizations + """ + color_green = (0, 255, 0) + color_cyan = (255, 255, 10) + color_violet = (255, 0, 255) + color_red = (0, 0, 255) + color_blue = (255, 0, 0) + + result_image = image.copy() + + # Draw all expected boxes + for idx_b, box in enumerate(expected_boxes): + cv2.rectangle( + img=result_image, + pt1=(box[0], box[1]), + pt2=(box[2], box[3]), + color=color_green, + thickness=2, + ) + + # draw connection lines between expected and intersecting patches + patch = ls_ordered_patch_bbox[idx_b] + if patch is None: + continue + cx, cy = patch[1] + crefx, crefy = (box[0] + box[2]) // 2, (box[1] + box[3]) // 2 + cv2.line( + img=result_image, + pt1=(cx, cy), + pt2=(crefx, crefy), + color=color_blue, + thickness=1, + ) + + # draw all predicted boxes + for pbox, pids, pscore in zip( + prediction.boxes, + prediction.class_ids, + prediction.scores, + strict=False, + ): + if pids == 1: + continue + cv2.rectangle( + img=result_image, + pt1=(pbox[0], pbox[1]), + pt2=(pbox[2], pbox[3]), + color=color_cyan, + thickness=2, + ) + cv2.putText( + img=result_image, + text=f"{pids} {pscore:.2f}", + org=(pbox[0] + 3, pbox[1] + 12), + fontFace=cv2.FONT_HERSHEY_SIMPLEX, + fontScale=0.4, + color=color_red, + thickness=1, + lineType=cv2.LINE_AA, + ) + + # Draw suggested patches if provided + if suggested_patches: + for box in suggested_patches.values(): + cv2.rectangle( + img=result_image, + pt1=(box[0], box[1]), + pt2=(box[2], box[3]), + color=color_violet, + thickness=2, + ) + + return result_image diff --git a/color_correction_asdfghjkl/schemas/det_yv8.py b/color_correction_asdfghjkl/schemas/det_yv8.py new file mode 100644 index 0000000..f6e8724 --- /dev/null +++ b/color_correction_asdfghjkl/schemas/det_yv8.py @@ -0,0 +1,19 @@ +import numpy as np +from pydantic import BaseModel + +from color_correction_asdfghjkl.utils.yolo_utils import draw_detections + +BoundingBox = tuple[int, int, int, int] + + +class DetectionResult(BaseModel): + boxes: list[BoundingBox] + scores: list[float] + class_ids: list[int] + + def draw_detections( + self, + image: np.ndarray, + ) -> np.ndarray: + """Draw detection boxes on image.""" + return draw_detections(image, self.boxes, self.scores, self.class_ids) diff --git a/color_correction_asdfghjkl/schemas/yolov8_det.py b/color_correction_asdfghjkl/schemas/yolov8_det.py deleted file mode 100644 index 75c7bca..0000000 --- a/color_correction_asdfghjkl/schemas/yolov8_det.py +++ /dev/null @@ -1,91 +0,0 @@ -import numpy as np -from pydantic import BaseModel - -from color_correction_asdfghjkl.utils.geometry_processing import ( - extract_intersecting_patches, - generate_expected_patches, - suggest_missing_patch_coordinates, -) -from color_correction_asdfghjkl.utils.image_processing import ( - calculate_mean_rgb, - crop_region_with_margin, -) -from color_correction_asdfghjkl.utils.yolo_utils import draw_detections - -box_tuple = tuple[int, int, int, int] - - -class DetectionResult(BaseModel): - boxes: list[box_tuple] - scores: list[float] - class_ids: list[int] - - def get_each_class_box(self) -> tuple[list[box_tuple], list[box_tuple]]: - """ - Return - ------ - Tuple[list[box_tuple], list[box_tuple]] - A tuple of two lists, where the first list contains the bounding boxes - of the cards and the second list contains the bounding boxes of the patches. - """ - ls_cards = [] - ls_patches = [] - for box, class_id in zip(self.boxes, self.class_ids, strict=False): - if class_id == 0: - ls_patches.append(box) - if class_id == 1: - ls_cards.append(box) - return ls_cards, ls_patches - - def print_summary(self) -> None: - ls_cards, ls_patches = self.get_each_class_box() - print(f"Number of cards detected: {len(ls_cards)}") - print(f"Number of patches detected: {len(ls_patches)}") - - def draw_detections(self, image: np.ndarray, mask_alpha: float = 0.2) -> np.ndarray: - return draw_detections( - image=image, - boxes=self.boxes, - scores=self.scores, - class_ids=self.class_ids, - mask_alpha=mask_alpha, - ) - - def get_list_patches(self, input_image: np.ndarray) -> list[np.ndarray]: - ls_cards, ls_patches = self.get_each_class_box() - - if len(ls_cards) == 0: - raise ValueError("No cards detected") - - if len(ls_patches) == 0: - raise ValueError("No patches detected") - - # Extract card coordinates - card_box = ls_cards[0] - ls_grid_card = generate_expected_patches(card_box) - - # get ls grid card - ls_ordered_patch_bbox = extract_intersecting_patches( - ls_patches=ls_patches, - ls_grid_card=ls_grid_card, - ) - - if None in ls_ordered_patch_bbox: - # Auto fill missing patches ---------------- - print("Auto fill missing patch...", ls_ordered_patch_bbox) - d_suggest = suggest_missing_patch_coordinates(ls_ordered_patch_bbox) - for idx, patch in d_suggest.items(): - ls_ordered_patch_bbox[idx] = patch - print(f"result len = {len(ls_ordered_patch_bbox)}") - - ls_rgb_mean_patch = [] - for coord_patch in ls_ordered_patch_bbox: - cropped_patch = crop_region_with_margin( - image=input_image, - coordinates=coord_patch, - margin_ratio=0.2, - ) - rgb_mean_patch = calculate_mean_rgb(cropped_patch) - ls_rgb_mean_patch.append(rgb_mean_patch) - - return ls_rgb_mean_patch diff --git a/color_correction_asdfghjkl/services/color_correction.py b/color_correction_asdfghjkl/services/color_correction.py index cf056c2..3a12c2d 100644 --- a/color_correction_asdfghjkl/services/color_correction.py +++ b/color_correction_asdfghjkl/services/color_correction.py @@ -1,3 +1,4 @@ +import os from typing import Literal import colour as cl @@ -5,82 +6,272 @@ import numpy as np from numpy.typing import NDArray -from color_correction_asdfghjkl.constant.color_checker import reference_color_d50 -from color_correction_asdfghjkl.core.card_detection.yolov8_det_onnx import ( +from color_correction_asdfghjkl.constant.color_checker import reference_color_d50_bgr +from color_correction_asdfghjkl.core.card_detection.det_yv8_onnx import ( YOLOv8CardDetector, ) -from color_correction_asdfghjkl.core.correction.least_squares import ( - LeastSquaresRegression, +from color_correction_asdfghjkl.core.correction import CorrectionModelFactory +from color_correction_asdfghjkl.processor.det_yv8 import DetectionProcessor +from color_correction_asdfghjkl.utils.image_patch import ( + create_patch_tiled_image, + visualize_patch_comparison, +) +from color_correction_asdfghjkl.utils.visualization_utils import ( + create_image_grid_visualization, ) -from color_correction_asdfghjkl.utils.image_processing import generate_image_patches ColorPatchType = NDArray[np.uint8] ImageType = NDArray[np.uint8] +LiteralModelCorrection = Literal[ + "least_squares", + "polynomial", + "linear_reg", + "affine_reg", +] class ColorCorrection: """Color correction handler using color card detection and correction models. + This class handles the complete workflow of color correction, including: + - Color card detection in images + - Color patch extraction + - Color correction model training + - Image correction application + Parameters ---------- detection_model : {'yolov8'} The model to use for color card detection. - correction_model : {'least_squares'} + detection_conf_th : float, optional + Confidence threshold for card detection. + correction_model : {'least_squares', 'polynomial', 'linear_reg', 'affine_reg'} The model to use for color correction. - reference_color_card : str, optional - Path to the reference color card image. + reference_image : NDArray[np.uint8] | None, optional + Reference image containing color checker card. + If None, uses standard D50 values. use_gpu : bool, default=True Whether to use GPU for card detection. + **kwargs : dict + Additional parameters for the correction model. + + Attributes + ---------- + reference_patches : List[ColorPatchType] | None + Extracted color patches from reference image. + reference_grid_image : ImageType | None + Visualization of reference color patches in grid format. + reference_debug_image : ImageType | None + Debug visualization of reference image preprocessing. """ def __init__( self, detection_model: Literal["yolov8"] = "yolov8", - correction_model: Literal["least_squares"] = "least_squares", - reference_color_card: str | None = None, + detection_conf_th: float = 0.25, + correction_model: LiteralModelCorrection = "least_squares", + reference_image: ImageType | None = None, use_gpu: bool = True, + **kwargs: dict, ) -> None: - self.reference_color_card = reference_color_card or reference_color_d50 - self.correction_model = self._initialize_correction_model(correction_model) - self.card_detector = self._initialize_detector(detection_model, use_gpu) - self.correction_weights: NDArray | None = None + # Initialize reference image attributes + self.reference_patches = None + self.reference_grid_image = None + self.reference_debug_image = None + + # Initialize input image attributes + self.input_patches = None + self.input_grid_image = None + self.input_debug_image = None + + # Initialize model attributes + self.trained_model = None + self.correction_model = CorrectionModelFactory.create( + model_name=correction_model, + **kwargs, + ) + self.card_detector = self._create_detector( + model_name=detection_model, + conf_th=detection_conf_th, + use_gpu=use_gpu, + ) - def _initialize_correction_model(self, model_name: str) -> LeastSquaresRegression: - if model_name == "least_squares": - return LeastSquaresRegression() - raise ValueError(f"Unsupported correction model: {model_name}") + # Set reference patches + self.set_reference_patches(image=reference_image) - def _initialize_detector( + def _create_detector( self, model_name: str, - use_gpu: bool, + conf_th: float = 0.25, + use_gpu: bool = False, ) -> YOLOv8CardDetector: - if model_name == "yolov8": - return YOLOv8CardDetector(use_gpu=use_gpu) - raise ValueError(f"Unsupported detection model: {model_name}") + """Create a card detector instance. + + Parameters + ---------- + model_name : str + Name of the detector model to create. + conf_th : float, optional + Confidence threshold for card detection. Default is 0.25. + use_gpu : bool, optional + Whether to use GPU for detection. Default is False. - def extract_color_patches(self, input_image: ImageType) -> list[ColorPatchType]: - """Extract color patches from input image using card detection. + Returns + ------- + YOLOv8CardDetector + Initialized detector instance. + + Raises + ------ + ValueError + If the model name is not supported. + """ + if model_name != "yolov8": + raise ValueError(f"Unsupported detection model: {model_name}") + return YOLOv8CardDetector(use_gpu=use_gpu, conf_th=conf_th) + + def _extract_color_patches( + self, + image: ImageType, + debug: bool = False, + ) -> tuple[list[ColorPatchType], ImageType, ImageType | None]: + """Extract color patches from an image using card detection. Parameters ---------- - input_image : NDArray - Input image from which to extract color patches. + image : ImageType + Input image in BGR format. + debug : bool, optional + Whether to generate debug visualizations. Returns ------- - list[NDArray] - List of BGR mean values for each detected patch. - Each element is an array of shape (3,) containing [B, G, R] values. + tuple[list[ColorPatchType], ImageType, ImageType | None] + - List of BGR mean values for each detected patch + - Grid visualization of detected patches + - Debug visualization (if debug=True) """ - detection_result = self.card_detector.detect(image=input_image.copy()) - return detection_result.get_list_patches(input_image=input_image.copy()) + prediction = self.card_detector.detect(image=image) + ls_bgr_mean_patch, grid_patch_img, debug_detection_viz = ( + DetectionProcessor.extract_color_patches( + input_image=image, + prediction=prediction, + draw_processed_image=debug, + ) + ) + return ls_bgr_mean_patch, grid_patch_img, debug_detection_viz - def fit( + def _save_debug_output( self, input_image: ImageType, - reference_image: ImageType | None = None, - ) -> tuple[NDArray, list[ColorPatchType], list[ColorPatchType]]: + corrected_image: ImageType, + output_directory: str, + ) -> None: + """Save debug visualizations to disk. + + Parameters + ---------- + input_image : ImageType + The input image. + corrected_image : ImageType + The color-corrected image. + output_directory : str + Directory to save debug outputs. + """ + predicted_patches = self.correction_model.compute_correction( + input_image=np.array(self.input_patches), + ) + predicted_grid = create_patch_tiled_image(predicted_patches) + + before_comparison = visualize_patch_comparison( + ls_mean_in=self.input_patches, + ls_mean_ref=self.reference_patches, + ) + after_comparison = visualize_patch_comparison( + ls_mean_in=predicted_patches, + ls_mean_ref=self.reference_patches, + ) + + # Create output directories + run_dir = self._create_debug_directory(output_directory) + + # Prepare debug image grid + image_collection = [ + ("Input Image", input_image), + ("Corrected Image", corrected_image), + ("Debug Preprocess", self.input_debug_image), + ("Reference vs Input", before_comparison), + ("Reference vs Corrected", after_comparison), + ("[Free Space]", None), + ("Patch Input", self.input_grid_image), + ("Patch Corrected", predicted_grid), + ("Patch Reference", self.reference_grid_image), + ] + + # Save debug grid + save_path = os.path.join(run_dir, "debug.jpg") + create_image_grid_visualization( + images=image_collection, + grid_size=((len(image_collection) // 3) + 1, 3), + figsize=(15, ((len(image_collection) // 3) + 1) * 4), + save_path=save_path, + ) + print(f"Debug output saved to: {save_path}") + + def _create_debug_directory(self, base_dir: str) -> str: + """Create and return a unique debug output directory. + + Parameters + ---------- + base_dir : str + Base directory for debug outputs. + + Returns + ------- + str + Path to the created directory. + """ + os.makedirs(base_dir, exist_ok=True) + run_number = len(os.listdir(base_dir)) + 1 + run_dir = os.path.join(base_dir, f"{run_number}-{self.model_name}") + os.makedirs(run_dir, exist_ok=True) + return run_dir + + @property + def model_name(self) -> str: + return self.correction_model.__class__.__name__ + + @property + def img_grid_patches_ref(self) -> np.ndarray: + return create_patch_tiled_image(self.reference_color_card) + + def set_reference_patches( + self, + image: np.ndarray | None, + debug: bool = False, + ) -> None: + if image is None: + self.reference_patches = reference_color_d50_bgr + self.reference_grid_image = create_patch_tiled_image(self.reference_patches) + else: + ( + self.reference_patches, + self.reference_grid_image, + self.reference_debug_image, + ) = self._extract_color_patches(image=image, debug=debug) + + def set_input_patches(self, image: np.ndarray, debug: bool = False) -> None: + self.input_patches = None + self.input_grid_image = None + self.input_debug_image = None + + ( + self.input_patches, + self.input_grid_image, + self.input_debug_image, + ) = self._extract_color_patches(image=image, debug=debug) + + def fit(self) -> tuple[NDArray, list[ColorPatchType], list[ColorPatchType]]: """Fit color correction model using input and reference images. Parameters @@ -95,40 +286,63 @@ def fit( Tuple[NDArray, List[NDArray], List[NDArray]] Correction weights, input patches, and reference patches. """ - input_patches = self.extract_color_patches(input_image=input_image) - reference_patches = ( - reference_color_d50 - if reference_image is None - else self.extract_color_patches(reference_image) - ) + if self.reference_patches is None: + raise RuntimeError("Reference patches must be set before fitting model") + + if self.input_patches is None: + raise RuntimeError("Input patches must be set before fitting model") - self.correction_weights = self.correction_model.fit( - input_patches=input_patches, - reference_patches=reference_patches, + self.trained_model = self.correction_model.fit( + x_patches=self.input_patches, + y_patches=self.reference_patches, ) - return self.correction_weights, input_patches, reference_patches - def correct_image(self, input_image: ImageType) -> ImageType: + return self.trained_model + + def predict( + self, + input_image: ImageType, + debug: bool = False, + debug_output_dir: str = "output-debug", + ) -> ImageType: """Apply color correction to input image. Parameters ---------- - input_image : NDArray + input_image : ImageType Image to be color corrected. + debug : bool, optional + Whether to save debug visualizations. + debug_output_dir : str, optional + Directory to save debug outputs. Returns ------- - NDArray + ImageType Color corrected image. + + Raises + ------ + RuntimeError + If model has not been fitted. """ - if self.correction_weights is None: + if self.trained_model is None: raise RuntimeError("Model must be fitted before correction") - return self.correction_model.compute_correction( + corrected_image = self.correction_model.compute_correction( input_image=input_image.copy(), ) - def calculate_color_difference( + if debug: + self._save_debug_output( + input_image=input_image, + corrected_image=corrected_image, + output_directory=debug_output_dir, + ) + + return corrected_image + + def calc_color_diff( self, image1: ImageType, image2: ImageType, @@ -162,24 +376,26 @@ def calculate_color_difference( if __name__ == "__main__": - import os + # Step 1: Define the path to the input image + image_path = "asset/images/cc-19.png" - image_path = "color_correction_asdfghjkl/asset/images/cc-1.jpg" - image_path = "color_correction_asdfghjkl/asset/images/cc-19.png" - filename = os.path.basename(image_path) - cc = ColorCorrection(detection_model="yolov8", correction_model="least_squares") + # Step 2: Load the input image input_image = cv2.imread(image_path) - _, ls_input_patches, ls_reference_patches = cc.fit(input_image=input_image) - corrected_image = cc.correct_image(input_image=input_image) - - in_img_patch = generate_image_patches(ls_input_patches) - ref_img_patch = generate_image_patches(ls_reference_patches) - cv2.imwrite(f"input_image_patches-{filename}.jpg", in_img_patch) - cv2.imwrite(f"reference_image_patches-{filename}.jpg", ref_img_patch) - cc.calculate_color_difference(in_img_patch, ref_img_patch) - - ls_correct_patch = cc.extract_color_patches(input_image=corrected_image) - corrected_img_patch = generate_image_patches(ls_correct_patch) - cv2.imwrite(f"corrected_image_patches-{filename}.jpg", corrected_img_patch) - cc.calculate_color_difference(corrected_img_patch, ref_img_patch) - cv2.imwrite(f"corrected_image-{filename}.jpg", corrected_image) + + # Step 3: Initialize the color correction model with specified parameters + color_corrector = ColorCorrection( + detection_model="yolov8", + detection_conf_th=0.25, + correction_model="least_squares", + degree=2, # for polynomial correction model + use_gpu=True, + ) + + # Step 4: Extract color patches from the input image + color_corrector.set_input_patches(image=input_image, debug=True) + color_corrector.fit() + corrected_image = color_corrector.predict( + input_image=input_image, + debug=True, + debug_output_dir="zzz", + ) diff --git a/color_correction_asdfghjkl/utils/correction.py b/color_correction_asdfghjkl/utils/correction.py new file mode 100644 index 0000000..91f6f47 --- /dev/null +++ b/color_correction_asdfghjkl/utils/correction.py @@ -0,0 +1,25 @@ +import numpy as np + + +def preprocessing_compute(input_image: np.ndarray) -> np.ndarray: + if input_image.shape == (24, 3): + # to handle grid image patches only + image = input_image.astype(np.float32) + else: + image = input_image.reshape(-1, 3).astype(np.float32) + return image + + +def postprocessing_compute( + original_shape: tuple, + predict_image: np.ndarray, +) -> np.ndarray: + if len(original_shape) == 2: + # to handle grid image patches only + corrected_image = np.clip(predict_image, 0, 255).astype(np.uint8) + else: + h, w, c = original_shape + corrected_image = ( + np.clip(predict_image, 0, 255).astype(np.uint8).reshape(h, w, c) + ) + return corrected_image diff --git a/color_correction_asdfghjkl/utils/downloader.py b/color_correction_asdfghjkl/utils/downloader.py index 4c13df1..33ff30c 100644 --- a/color_correction_asdfghjkl/utils/downloader.py +++ b/color_correction_asdfghjkl/utils/downloader.py @@ -77,7 +77,7 @@ def downloader_model_yolov8(use_gpu: bool = False) -> str: fullpath = os.path.join(model_folder, filename) if os.path.exists(fullpath): return fullpath - + print("Auto downloading YOLOv8 model...") download_google_drive_file(fileid, fullpath) return fullpath diff --git a/color_correction_asdfghjkl/utils/geometry_processing.py b/color_correction_asdfghjkl/utils/geometry_processing.py index aebf049..812dd92 100644 --- a/color_correction_asdfghjkl/utils/geometry_processing.py +++ b/color_correction_asdfghjkl/utils/geometry_processing.py @@ -50,10 +50,10 @@ def generate_expected_patches(card_box: box_tuple) -> list[box_tuple]: expected_patches = [] for row in range(4): for col in range(6): - x1 = card_x1 + col * patch_width - y1 = card_y1 + row * patch_height - x2 = x1 + patch_width - y2 = y1 + patch_height + x1 = int(card_x1 + col * patch_width) + y1 = int(card_y1 + row * patch_height) + x2 = int(x1 + patch_width) + y2 = int(y1 + patch_height) expected_patches.append((x1, y1, x2, y2)) return expected_patches @@ -62,7 +62,7 @@ def generate_expected_patches(card_box: box_tuple) -> list[box_tuple]: def extract_intersecting_patches( ls_patches: list[box_tuple], ls_grid_card: list[box_tuple], -) -> list[box_tuple]: +) -> list[tuple[box_tuple, tuple[int, int]]]: ls_ordered_patch = [] for _, grid_card in enumerate(ls_grid_card, start=1): # get intesect patch @@ -81,7 +81,8 @@ def extract_intersecting_patches( ) # intersect_box = ls_intersect[max_id] val = box_to_xyxy(intersect_box) - ls_ordered_patch.append(val) + xy = box_centroid_xy(intersect_box) + ls_ordered_patch.append((val, xy)) else: ls_ordered_patch.append(None) return ls_ordered_patch @@ -119,7 +120,6 @@ def calculate_patch_statistics(ls_ordered_patch: list[box_tuple]) -> tuple: mean_w = np.mean(ls_w_grid) mean_h = np.mean(ls_h_grid) - print(ls_dx, mean_dx) return mean_dx, mean_dy, mean_w, mean_h diff --git a/color_correction_asdfghjkl/utils/image_patch.py b/color_correction_asdfghjkl/utils/image_patch.py new file mode 100644 index 0000000..1d8767d --- /dev/null +++ b/color_correction_asdfghjkl/utils/image_patch.py @@ -0,0 +1,118 @@ +import numpy as np + + +def create_patch_tiled_image( + ls_patches: list[tuple[int, int, int, int]], + patch_size: tuple[int, int, int] = (50, 50, 1), +) -> np.ndarray: + """Generate a color patch image from a list of BGR values. + + This function creates a color patch image by tiling BGR values into patches and + arranging them in a 4x6 grid pattern. Each patch is repeated according to the + specified patch size. + + Parameters + ---------- + ls_patches : list of tuple + List containing 24 BGR color tuples, where each tuple has three integers + representing (B, G, R) values. + patch_size : tuple of int, optional + Size of each individual patch in pixels, by default (50, 50, 1). + Format is (height, width, channels). + + Returns + ------- + numpy.ndarray + Generated image as a numpy array with shape determined by patch_size and + arrangement (4 rows x 6 columns). Array type is uint8. + + Notes + ----- + This function is specifically designed to work with 24 color patches arranged + in a 4x6 grid pattern. + + Examples + -------- + >>> patches = [(255, 0, 0), (0, 255, 0), ...] # 24 BGR tuples + >>> patch_size = (50, 50, 1) + >>> image = generate_image_patches(patches, patch_size) + """ + + ls_stack_h = [] + ls_stack_v = [] + + if len(ls_patches) != 24: + raise ValueError("Failed to generate image. The number of patches must be 24.") + + for _idx, patch in enumerate(ls_patches, start=1): + patch_img = np.tile(patch, patch_size) + ls_stack_h.append(patch_img) + if _idx % 6 == 0: + row = np.hstack(ls_stack_h) + ls_stack_v.append(row) + ls_stack_h = [] + image = np.vstack(ls_stack_v).astype(np.uint8) + return image + + +def visualize_patch_comparison( + ls_mean_ref: np.ndarray, + ls_mean_in: np.ndarray, + patch_size: tuple[int, int, int] = (100, 100, 1), +) -> np.ndarray: + """ + Compare two sets of image patches by inserting a resized inner patch into + the center of an outer patch. This visualization grid helps in comparing + the reference and input images in a structured manner. + + Parameters + ---------- + ls_mean_ref : list of np.ndarray + List of outer image patches. Each patch is repeated to form the full + grid background. + ls_mean_in : list of np.ndarray + List of inner image patches meant to be resized and placed into the + center of the outer patches. + patch_size : tuple of int, optional + A tuple specifying the size of the patch in the format (height, width, channels) + by default (100, 100, 1). + + Returns + ------- + np.ndarray + The final composited image with each outer patch modified with the + corresponding resized inner patch, arranged in a grid format. + """ + + ls_stack_h = [] + ls_stack_v = [] + + h = patch_size[0] + w = patch_size[1] + h_2 = h // 2 + w_2 = w // 2 + y1 = h_2 - (h // 4) - 1 + y2 = h_2 + (h // 4) + x1 = w_2 - (w // 4) - 1 + x2 = w_2 + (w // 4) + + for _idx, (patch_ref, patch_in) in enumerate( + zip(ls_mean_ref, ls_mean_in, strict=False), + start=1, + ): + img_patch_ref = np.tile(patch_ref, patch_size) + img_patch_in = np.tile( + patch_in, + (y2 - y1, x2 - x1, patch_size[2]), + ) + + # img_patch_in = cv2.resize(img_patch_in, (y2 - y1, x2 - x1)) + img_patch_ref[y1:y2, x1:x2] = img_patch_in + ls_stack_h.append(img_patch_ref) + + if _idx % 6 == 0: + row = np.hstack(ls_stack_h) + ls_stack_v.append(row) + ls_stack_h = [] + image = np.vstack(ls_stack_v).astype(np.uint8) + return image diff --git a/color_correction_asdfghjkl/utils/image_processing.py b/color_correction_asdfghjkl/utils/image_processing.py index 2bef3fc..a1e3971 100644 --- a/color_correction_asdfghjkl/utils/image_processing.py +++ b/color_correction_asdfghjkl/utils/image_processing.py @@ -1,5 +1,3 @@ -import matplotlib.figure -import matplotlib.pyplot as plt import numpy as np @@ -40,8 +38,8 @@ def crop_region_with_margin( return image[crop_y1:crop_y2, crop_x1:crop_x2] -def calculate_mean_rgb(img: np.ndarray) -> np.ndarray: - """Calculate mean RGB values across spatial dimensions. +def calc_mean_color_patch(img: np.ndarray) -> np.ndarray: + """Calculate mean RGB/BGR values across spatial dimensions. Parameters ---------- @@ -54,90 +52,3 @@ def calculate_mean_rgb(img: np.ndarray) -> np.ndarray: Array of mean RGB values, shape (C,), dtype uint8. """ return np.mean(img, axis=(0, 1)).astype(np.uint8) - - -def generate_image_patches( - ls_patches: list[tuple[int, int, int, int]], - patch_size: tuple[int, int, int] = (50, 50, 1), -) -> np.ndarray: - ls_stack_h = [] - ls_stack_v = [] - - for _idx, patch in enumerate(ls_patches, start=1): - patch_img = np.tile(patch, patch_size) - ls_stack_h.append(patch_img) - if _idx % 6 == 0: - row = np.hstack(ls_stack_h) - ls_stack_v.append(row) - ls_stack_h = [] - image = np.vstack(ls_stack_v).astype(np.uint8) - return image - - -def display_image_grid( - images: list[tuple[str, np.ndarray | matplotlib.figure.Figure]], - grid_size: tuple[int, int] = (2, 3), - figsize: tuple[int, int] = (15, 10), - save_path: str | None = None, - dpi: int = 300, -) -> matplotlib.figure.Figure: - """ - Display images in a grid layout with titles - - Parameters: - ----------- - images : List[Tuple[str, Union[np.ndarray, matplotlib.figure.Figure]]] - List of tuples containing (title, image) - grid_size : Tuple[int, int] - Grid layout in (rows, columns) format - figsize : Tuple[int, int] - Size of the entire figure in inches - save_path : Optional[str] - If provided, save the figure to this path - dpi : int - DPI for saved figure - - Returns: - -------- - matplotlib.figure.Figure - The figure object containing the grid - """ - - rows, cols = grid_size - fig = plt.figure(figsize=figsize) - - for idx, (title, img) in enumerate(images): - if idx >= rows * cols: - print( - f"Warning: Only showing first {rows * cols} images due to " - "grid size limitation", - ) - break - - ax = fig.add_subplot(rows, cols, idx + 1) - - # Handle different image types - if isinstance(img, np.ndarray): - if len(img.shape) == 2: # Grayscale - ax.imshow(img, cmap="gray") - else: # RGB/RGBA - ax.imshow(img) - elif isinstance(img, matplotlib.figure.Figure): - # Convert matplotlib figure to image array - fig.canvas.draw() - img_array = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) - img_array = img_array.reshape(fig.canvas.get_width_height()[::-1] + (3,)) - ax.imshow(img_array) - - ax.set_title(title) - ax.axis("off") - - plt.tight_layout() - - # Save figure if path is provided - if save_path: - fig.savefig(save_path, dpi=dpi, bbox_inches="tight") - print(f"Figure saved to: {save_path}") - - plt.close() # Close the figure to free memory - return fig diff --git a/color_correction_asdfghjkl/utils/visualization_utils.py b/color_correction_asdfghjkl/utils/visualization_utils.py new file mode 100644 index 0000000..2f81ccf --- /dev/null +++ b/color_correction_asdfghjkl/utils/visualization_utils.py @@ -0,0 +1,74 @@ +import cv2 +import matplotlib.figure +import matplotlib.pyplot as plt +import numpy as np + + +def create_image_grid_visualization( + images: list[tuple[str, np.ndarray | matplotlib.figure.Figure | None]], + grid_size: tuple[int, int] = (2, 3), + figsize: tuple[int, int] = (15, 10), + save_path: str | None = None, + dpi: int = 300, +) -> matplotlib.figure.Figure: + """ + Display images in a grid layout with titles + + Parameters: + ----------- + images : List[Tuple[str, Union[np.ndarray, matplotlib.figure.Figure, None]]] + List of tuples containing (title, image) + grid_size : Tuple[int, int] + Grid layout in (rows, columns) format + figsize : Tuple[int, int] + Size of the entire figure in inches + save_path : Optional[str] + If provided, save the figure to this path + dpi : int + DPI for saved figure + + Returns: + -------- + matplotlib.figure.Figure + The figure object containing the grid + """ + + rows, cols = grid_size + fig = plt.figure(figsize=figsize) + + for idx, (title, img) in enumerate(images): + if idx >= rows * cols: + print( + f"Warning: Only showing first {rows * cols} images due to " + "grid size limitation", + ) + break + + ax = fig.add_subplot(rows, cols, idx + 1) + + # Handle different image types + if isinstance(img, np.ndarray): + if len(img.shape) == 2: # Grayscale + ax.imshow(img, cmap="gray") + else: # RGB/RGBA + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + ax.imshow(img) + elif isinstance(img, matplotlib.figure.Figure): + # Convert matplotlib figure to image array + fig.canvas.draw() + img_array = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) + img_array = img_array.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + ax.imshow(img_array) + + ax.set_title(title) + ax.axis("off") + + plt.tight_layout() + + # Save figure if path is provided + if save_path: + fig.savefig(save_path, dpi=dpi, bbox_inches="tight") + print(f"Figure saved to: {save_path}") + + plt.close() # Close the figure to free memory + return fig diff --git a/pyproject.toml b/pyproject.toml index cb29ef9..7c3361e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "color-correction-asdfghjkl" -version = "0.0.1a2" +version = "0.0.1b0" description = "help to do color correction on images based on color checker card classic 24 patch." keywords = ["color correction", "color-correction", "color consistency", "color-consistency"] readme = "README.md" @@ -14,6 +14,7 @@ dependencies = [ "pydantic-settings>=2.7.1", "shapely>=2.0.6", "colour-science>=0.4.6", + "scikit-learn>=1.6.1", ] classifiers = [ 'Programming Language :: Python :: 3 :: Only', @@ -48,6 +49,7 @@ dev = [ "pytest-cov==6.0.0", "pytest>=8.3.4", "ruff>=0.9.2", + "pre-commit>=4.1.0", ] dev-model = [ "ultralytics>=8.3.65", @@ -108,7 +110,8 @@ exclude = [ "node_modules", "site-packages", "venv", - "tmp" + "tmp", + "tests", ] line-length = 88 diff --git a/tests/core/card_detection/test_yolov8_detector.py b/tests/core/card_detection/test_yolov8_detector.py index 4f4589e..40c6939 100644 --- a/tests/core/card_detection/test_yolov8_detector.py +++ b/tests/core/card_detection/test_yolov8_detector.py @@ -1,10 +1,7 @@ import numpy as np import pytest -from color_correction_asdfghjkl.core.card_detection.yolov8_det_onnx import ( - YOLOv8CardDetector, -) - +from color_correction_asdfghjkl.core.card_detection.det_yv8_onnx import YOLOv8CardDetector @pytest.mark.skip(reason="Test is not implemented") def test_detector_init(sample_image: np.ndarray) -> None: diff --git a/tests/utils/test_image_patch.py b/tests/utils/test_image_patch.py new file mode 100644 index 0000000..53e8cb8 --- /dev/null +++ b/tests/utils/test_image_patch.py @@ -0,0 +1,67 @@ +import numpy as np +import pytest +import cv2 + +from color_correction_asdfghjkl.utils.image_patch import ( + create_patch_tiled_image, + visualize_patch_comparison, +) + + +@pytest.fixture +def sample_patches_bgr() -> list[tuple[int, int, int]]: + return np.random.randint(0, 255, size=(24, 3)) + + +@pytest.fixture +def sample_outer_patches() -> list[np.ndarray]: + return np.random.randint(0, 255, size=(24, 3)) + + +@pytest.fixture +def sample_inner_patches() -> list[np.ndarray]: + return np.random.randint(0, 255, size=(24, 3)) + + +def test_create_patch_tiled_image(sample_patches_bgr) -> None: # noqa: ANN001 + patch_size = (50, 50, 1) + image = create_patch_tiled_image(sample_patches_bgr, patch_size) + assert image.shape == (4 * patch_size[0], 6 * patch_size[1], 3 * patch_size[2]) + assert image.dtype == np.uint8 + + +def test_create_patch_tiled_image_custom_size(sample_patches_bgr) -> None: # noqa: ANN001 + patch_size = (30, 30, 1) + image = create_patch_tiled_image(sample_patches_bgr, patch_size) + assert image.shape == (4 * patch_size[0], 6 * patch_size[1], 3 * patch_size[2]) + + +def test_visualize_patch_comparison(sample_outer_patches, sample_inner_patches): + patch_size = (100, 100, 1) + image = visualize_patch_comparison( + sample_outer_patches, + sample_inner_patches, + patch_size, + ) + assert image.shape == (4 * patch_size[0], 6 * patch_size[1], 3 * patch_size[2]) + assert image.dtype == np.uint8 + + +def test_visualize_inner_patch_center(sample_outer_patches, sample_inner_patches): + patch_size = (100, 100, 1) + h, w, _ = patch_size + h_half = h // 2 + w_half = w // 2 + y1 = h_half - (h // 4) + y2 = h_half + (h // 4) + x1 = w_half - (w // 4) + x2 = w_half + (w // 4) + + image = visualize_patch_comparison( + sample_outer_patches, + sample_inner_patches, + patch_size, + ) + inner_patch = sample_inner_patches[0] + inner_patch = np.tile(inner_patch, (y2 - y1, x2 - x1, 1)) + assert np.array_equal(image[y1:y2, x1:x2], inner_patch) diff --git a/tests/utils/test_image_processing.py b/tests/utils/test_image_processing.py new file mode 100644 index 0000000..3293c05 --- /dev/null +++ b/tests/utils/test_image_processing.py @@ -0,0 +1,52 @@ +import numpy as np +import pytest +from color_correction_asdfghjkl.utils.image_processing import crop_region_with_margin, calc_mean_color_patch + +@pytest.fixture +def known_image() -> np.ndarray: + # Create an image with a known pattern using np.arange, + # reshape to (100, 100, 3) and wrap values with modulo 255 + img = np.arange(100 * 100 * 3, dtype=np.uint8) % 255 + return img.reshape((100, 100, 3)) + +@pytest.mark.parametrize( + "coordinates, margin_ratio, expected_slice", + [ + ( + (10, 20, 90, 80), + 0.2, + (slice(32, 68), slice(26, 74)) + ), + ( + (10, 20, 90, 80), + 0.0, + (slice(20, 80), slice(10, 90)) + ), + ], +) +def test_crop_region_with_margin(known_image: np.ndarray, coordinates, margin_ratio, expected_slice) -> None: + result = crop_region_with_margin(known_image, coordinates, margin_ratio) + expected = known_image[expected_slice[0], expected_slice[1]] + np.testing.assert_array_equal(result, expected) + +@pytest.mark.parametrize( + "img, expected_mean", + [ + ( + np.full((10, 10, 3), fill_value=50, dtype=np.uint8), + np.array([50, 50, 50], dtype=np.uint8) + ), + ( + # Create an image where channel 0 is 0, channel 1 is 100, channel 2 is 200. + np.stack([ + np.zeros((20, 20), dtype=np.uint8), + np.full((20, 20), 100, dtype=np.uint8), + np.full((20, 20), 200, dtype=np.uint8) + ], axis=-1), + np.array([0, 100, 200], dtype=np.uint8) + ), + ], +) +def test_calc_mean_color_patch(img: np.ndarray, expected_mean) -> None: + mean_color = calc_mean_color_patch(img) + np.testing.assert_array_equal(mean_color, expected_mean) diff --git a/uv.lock b/uv.lock index e88d6ce..8424ade 100644 --- a/uv.lock +++ b/uv.lock @@ -49,6 +49,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a5/32/8f6669fc4798494966bf446c8c4a162e0b5d893dff088afddf76414f70e1/certifi-2024.12.14-py3-none-any.whl", hash = "sha256:1275f7a45be9464efc1173084eaa30f866fe2e47d389406136d332ed4967ec56", size = 164927 }, ] +[[package]] +name = "cfgv" +version = "3.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/11/74/539e56497d9bd1d484fd863dd69cbbfa653cd2aa27abfe35653494d85e94/cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560", size = 7114 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c5/55/51844dd50c4fc7a33b653bfaba4c2456f06955289ca770a5dbd5fd267374/cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9", size = 7249 }, +] + [[package]] name = "charset-normalizer" version = "3.4.1" @@ -112,13 +121,14 @@ wheels = [ [[package]] name = "color-correction-asdfghjkl" -version = "0.0.1a1" +version = "0.0.1b0" source = { editable = "." } dependencies = [ { name = "colour-science" }, { name = "httpx" }, { name = "pydantic" }, { name = "pydantic-settings" }, + { name = "scikit-learn" }, { name = "shapely" }, ] @@ -130,6 +140,7 @@ default = [ { name = "opencv-python-headless" }, ] dev = [ + { name = "pre-commit" }, { name = "pytest" }, { name = "pytest-cov" }, { name = "ruff" }, @@ -151,6 +162,7 @@ requires-dist = [ { name = "httpx", specifier = ">=0.28.1" }, { name = "pydantic", specifier = ">=2.10.5" }, { name = "pydantic-settings", specifier = ">=2.7.1" }, + { name = "scikit-learn", specifier = ">=1.6.1" }, { name = "shapely", specifier = ">=2.0.6" }, ] @@ -158,6 +170,7 @@ requires-dist = [ analyze = [{ name = "matplotlib", specifier = ">=3.10.0" }] default = [{ name = "opencv-python-headless", specifier = ">=4.11.0.86" }] dev = [ + { name = "pre-commit", specifier = ">=4.1.0" }, { name = "pytest", specifier = ">=8.3.4" }, { name = "pytest-cov", specifier = "==6.0.0" }, { name = "ruff", specifier = ">=0.9.2" }, @@ -344,6 +357,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e7/05/c19819d5e3d95294a6f5947fb9b9629efb316b96de511b418c53d245aae6/cycler-0.12.1-py3-none-any.whl", hash = "sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30", size = 8321 }, ] +[[package]] +name = "distlib" +version = "0.3.9" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0d/dd/1bec4c5ddb504ca60fc29472f3d27e8d4da1257a854e1d96742f15c1d02d/distlib-0.3.9.tar.gz", hash = "sha256:a60f20dea646b8a33f3e7772f74dc0b2d0772d2837ee1342a00645c81edf9403", size = 613923 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/91/a1/cf2472db20f7ce4a6be1253a81cfdf85ad9c7885ffbed7047fb72c24cf87/distlib-0.3.9-py2.py3-none-any.whl", hash = "sha256:47f8c22fd27c27e25a65601af709b38e4f0a45ea4fc2e710f65755fa8caaaf87", size = 468973 }, +] + [[package]] name = "exceptiongroup" version = "1.2.2" @@ -470,6 +492,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f0/0f/310fb31e39e2d734ccaa2c0fb981ee41f7bd5056ce9bc29b2248bd569169/humanfriendly-10.0-py2.py3-none-any.whl", hash = "sha256:1697e1a8a8f550fd43c2865cd84542fc175a61dcb779b6fee18cf6b6ccba1477", size = 86794 }, ] +[[package]] +name = "identify" +version = "2.6.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/82/bf/c68c46601bacd4c6fb4dd751a42b6e7087240eaabc6487f2ef7a48e0e8fc/identify-2.6.6.tar.gz", hash = "sha256:7bec12768ed44ea4761efb47806f0a41f86e7c0a5fdf5950d4648c90eca7e251", size = 99217 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/74/a1/68a395c17eeefb04917034bd0a1bfa765e7654fa150cca473d669aa3afb5/identify-2.6.6-py2.py3-none-any.whl", hash = "sha256:cbd1810bce79f8b671ecb20f53ee0ae8e86ae84b557de31d89709dc2a48ba881", size = 99083 }, +] + [[package]] name = "idna" version = "3.10" @@ -514,6 +545,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/bd/0f/2ba5fbcd631e3e88689309dbe978c5769e883e4b84ebfe7da30b43275c5a/jinja2-3.1.5-py3-none-any.whl", hash = "sha256:aba0f4dc9ed8013c424088f68a5c226f7d6097ed89b246d7749c2ec4175c6adb", size = 134596 }, ] +[[package]] +name = "joblib" +version = "1.4.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/64/33/60135848598c076ce4b231e1b1895170f45fbcaeaa2c9d5e38b04db70c35/joblib-1.4.2.tar.gz", hash = "sha256:2382c5816b2636fbd20a09e0f4e9dad4736765fdfb7dca582943b9c1366b3f0e", size = 2116621 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/91/29/df4b9b42f2be0b623cbd5e2140cafcaa2bef0759a00b7b70104dcfe2fb51/joblib-1.4.2-py3-none-any.whl", hash = "sha256:06d478d5674cbc267e7496a410ee875abd68e4340feff4490bcb7afb88060ae6", size = 301817 }, +] + [[package]] name = "kiwisolver" version = "1.4.8" @@ -730,6 +770,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b9/54/dd730b32ea14ea797530a4479b2ed46a6fb250f682a9cfb997e968bf0261/networkx-3.4.2-py3-none-any.whl", hash = "sha256:df5d4365b724cf81b8c6a7312509d0c22386097011ad1abe274afd5e9d3bbc5f", size = 1723263 }, ] +[[package]] +name = "nodeenv" +version = "1.9.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/43/16/fc88b08840de0e0a72a2f9d8c6bae36be573e475a6326ae854bcc549fc45/nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f", size = 47437 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d2/1d/1b658dbd2b9fa9c4c9f32accbfc0205d532c8c6194dc0f2a4c0428e7128a/nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9", size = 22314 }, +] + [[package]] name = "numpy" version = "1.26.4" @@ -1151,6 +1200,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/41/67/936f9814bdd74b2dfd4822f1f7725ab5d8ff4103919a1664eb4874c58b2f/pillow-11.1.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:4637b88343166249fe8aa94e7c4a62a180c4b3898283bb5d3d2fd5fe10d8e4e0", size = 2626353 }, ] +[[package]] +name = "platformdirs" +version = "4.3.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/13/fc/128cc9cb8f03208bdbf93d3aa862e16d376844a14f9a0ce5cf4507372de4/platformdirs-4.3.6.tar.gz", hash = "sha256:357fb2acbc885b0419afd3ce3ed34564c13c9b95c89360cd9563f73aa5e2b907", size = 21302 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3c/a6/bc1012356d8ece4d66dd75c4b9fc6c1f6650ddd5991e421177d9f8f671be/platformdirs-4.3.6-py3-none-any.whl", hash = "sha256:73e575e1408ab8103900836b97580d5307456908a03e92031bab39e4554cc3fb", size = 18439 }, +] + [[package]] name = "pluggy" version = "1.5.0" @@ -1160,6 +1218,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/88/5f/e351af9a41f866ac3f1fac4ca0613908d9a41741cfcf2228f4ad853b697d/pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669", size = 20556 }, ] +[[package]] +name = "pre-commit" +version = "4.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cfgv" }, + { name = "identify" }, + { name = "nodeenv" }, + { name = "pyyaml" }, + { name = "virtualenv" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/2a/13/b62d075317d8686071eb843f0bb1f195eb332f48869d3c31a4c6f1e063ac/pre_commit-4.1.0.tar.gz", hash = "sha256:ae3f018575a588e30dfddfab9a05448bfbd6b73d78709617b5a2b853549716d4", size = 193330 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/43/b3/df14c580d82b9627d173ceea305ba898dca135feb360b6d84019d0803d3b/pre_commit-4.1.0-py2.py3-none-any.whl", hash = "sha256:d29e7cb346295bcc1cc75fc3e92e343495e3ea0196c9ec6ba53f49f10ab6ae7b", size = 220560 }, +] + [[package]] name = "protobuf" version = "5.29.3" @@ -1462,6 +1536,45 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0e/4e/33df635528292bd2d18404e4daabcd74ca8a9853b2e1df85ed3d32d24362/ruff-0.9.2-py3-none-win_arm64.whl", hash = "sha256:a1b63fa24149918f8b37cef2ee6fff81f24f0d74b6f0bdc37bc3e1f2143e41c6", size = 10001738 }, ] +[[package]] +name = "scikit-learn" +version = "1.6.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "joblib" }, + { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform == 'darwin'" }, + { name = "numpy", version = "2.2.2", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'darwin'" }, + { name = "scipy" }, + { name = "threadpoolctl" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9e/a5/4ae3b3a0755f7b35a280ac90b28817d1f380318973cff14075ab41ef50d9/scikit_learn-1.6.1.tar.gz", hash = "sha256:b4fc2525eca2c69a59260f583c56a7557c6ccdf8deafdba6e060f94c1c59738e", size = 7068312 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2e/3a/f4597eb41049110b21ebcbb0bcb43e4035017545daa5eedcfeb45c08b9c5/scikit_learn-1.6.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d056391530ccd1e501056160e3c9673b4da4805eb67eb2bdf4e983e1f9c9204e", size = 12067702 }, + { url = "https://files.pythonhosted.org/packages/37/19/0423e5e1fd1c6ec5be2352ba05a537a473c1677f8188b9306097d684b327/scikit_learn-1.6.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:0c8d036eb937dbb568c6242fa598d551d88fb4399c0344d95c001980ec1c7d36", size = 11112765 }, + { url = "https://files.pythonhosted.org/packages/70/95/d5cb2297a835b0f5fc9a77042b0a2d029866379091ab8b3f52cc62277808/scikit_learn-1.6.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8634c4bd21a2a813e0a7e3900464e6d593162a29dd35d25bdf0103b3fce60ed5", size = 12643991 }, + { url = "https://files.pythonhosted.org/packages/b7/91/ab3c697188f224d658969f678be86b0968ccc52774c8ab4a86a07be13c25/scikit_learn-1.6.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:775da975a471c4f6f467725dff0ced5c7ac7bda5e9316b260225b48475279a1b", size = 13497182 }, + { url = "https://files.pythonhosted.org/packages/17/04/d5d556b6c88886c092cc989433b2bab62488e0f0dafe616a1d5c9cb0efb1/scikit_learn-1.6.1-cp310-cp310-win_amd64.whl", hash = "sha256:8a600c31592bd7dab31e1c61b9bbd6dea1b3433e67d264d17ce1017dbdce8002", size = 11125517 }, + { url = "https://files.pythonhosted.org/packages/6c/2a/e291c29670795406a824567d1dfc91db7b699799a002fdaa452bceea8f6e/scikit_learn-1.6.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:72abc587c75234935e97d09aa4913a82f7b03ee0b74111dcc2881cba3c5a7b33", size = 12102620 }, + { url = "https://files.pythonhosted.org/packages/25/92/ee1d7a00bb6b8c55755d4984fd82608603a3cc59959245068ce32e7fb808/scikit_learn-1.6.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:b3b00cdc8f1317b5f33191df1386c0befd16625f49d979fe77a8d44cae82410d", size = 11116234 }, + { url = "https://files.pythonhosted.org/packages/30/cd/ed4399485ef364bb25f388ab438e3724e60dc218c547a407b6e90ccccaef/scikit_learn-1.6.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dc4765af3386811c3ca21638f63b9cf5ecf66261cc4815c1db3f1e7dc7b79db2", size = 12592155 }, + { url = "https://files.pythonhosted.org/packages/a8/f3/62fc9a5a659bb58a03cdd7e258956a5824bdc9b4bb3c5d932f55880be569/scikit_learn-1.6.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:25fc636bdaf1cc2f4a124a116312d837148b5e10872147bdaf4887926b8c03d8", size = 13497069 }, + { url = "https://files.pythonhosted.org/packages/a1/a6/c5b78606743a1f28eae8f11973de6613a5ee87366796583fb74c67d54939/scikit_learn-1.6.1-cp311-cp311-win_amd64.whl", hash = "sha256:fa909b1a36e000a03c382aade0bd2063fd5680ff8b8e501660c0f59f021a6415", size = 11139809 }, + { url = "https://files.pythonhosted.org/packages/0a/18/c797c9b8c10380d05616db3bfb48e2a3358c767affd0857d56c2eb501caa/scikit_learn-1.6.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:926f207c804104677af4857b2c609940b743d04c4c35ce0ddc8ff4f053cddc1b", size = 12104516 }, + { url = "https://files.pythonhosted.org/packages/c4/b7/2e35f8e289ab70108f8cbb2e7a2208f0575dc704749721286519dcf35f6f/scikit_learn-1.6.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:2c2cae262064e6a9b77eee1c8e768fc46aa0b8338c6a8297b9b6759720ec0ff2", size = 11167837 }, + { url = "https://files.pythonhosted.org/packages/a4/f6/ff7beaeb644bcad72bcfd5a03ff36d32ee4e53a8b29a639f11bcb65d06cd/scikit_learn-1.6.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1061b7c028a8663fb9a1a1baf9317b64a257fcb036dae5c8752b2abef31d136f", size = 12253728 }, + { url = "https://files.pythonhosted.org/packages/29/7a/8bce8968883e9465de20be15542f4c7e221952441727c4dad24d534c6d99/scikit_learn-1.6.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2e69fab4ebfc9c9b580a7a80111b43d214ab06250f8a7ef590a4edf72464dd86", size = 13147700 }, + { url = "https://files.pythonhosted.org/packages/62/27/585859e72e117fe861c2079bcba35591a84f801e21bc1ab85bce6ce60305/scikit_learn-1.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:70b1d7e85b1c96383f872a519b3375f92f14731e279a7b4c6cfd650cf5dffc52", size = 11110613 }, + { url = "https://files.pythonhosted.org/packages/2e/59/8eb1872ca87009bdcdb7f3cdc679ad557b992c12f4b61f9250659e592c63/scikit_learn-1.6.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:2ffa1e9e25b3d93990e74a4be2c2fc61ee5af85811562f1288d5d055880c4322", size = 12010001 }, + { url = "https://files.pythonhosted.org/packages/9d/05/f2fc4effc5b32e525408524c982c468c29d22f828834f0625c5ef3d601be/scikit_learn-1.6.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:dc5cf3d68c5a20ad6d571584c0750ec641cc46aeef1c1507be51300e6003a7e1", size = 11096360 }, + { url = "https://files.pythonhosted.org/packages/c8/e4/4195d52cf4f113573fb8ebc44ed5a81bd511a92c0228889125fac2f4c3d1/scikit_learn-1.6.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c06beb2e839ecc641366000ca84f3cf6fa9faa1777e29cf0c04be6e4d096a348", size = 12209004 }, + { url = "https://files.pythonhosted.org/packages/94/be/47e16cdd1e7fcf97d95b3cb08bde1abb13e627861af427a3651fcb80b517/scikit_learn-1.6.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8ca8cb270fee8f1f76fa9bfd5c3507d60c6438bbee5687f81042e2bb98e5a97", size = 13171776 }, + { url = "https://files.pythonhosted.org/packages/34/b0/ca92b90859070a1487827dbc672f998da95ce83edce1270fc23f96f1f61a/scikit_learn-1.6.1-cp313-cp313-win_amd64.whl", hash = "sha256:7a1c43c8ec9fde528d664d947dc4c0789be4077a3647f232869f41d9bf50e0fb", size = 11071865 }, + { url = "https://files.pythonhosted.org/packages/12/ae/993b0fb24a356e71e9a894e42b8a9eec528d4c70217353a1cd7a48bc25d4/scikit_learn-1.6.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:a17c1dea1d56dcda2fac315712f3651a1fea86565b64b48fa1bc090249cbf236", size = 11955804 }, + { url = "https://files.pythonhosted.org/packages/d6/54/32fa2ee591af44507eac86406fa6bba968d1eb22831494470d0a2e4a1eb1/scikit_learn-1.6.1-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:6a7aa5f9908f0f28f4edaa6963c0a6183f1911e63a69aa03782f0d924c830a35", size = 11100530 }, + { url = "https://files.pythonhosted.org/packages/3f/58/55856da1adec655bdce77b502e94a267bf40a8c0b89f8622837f89503b5a/scikit_learn-1.6.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0650e730afb87402baa88afbf31c07b84c98272622aaba002559b614600ca691", size = 12433852 }, + { url = "https://files.pythonhosted.org/packages/ff/4f/c83853af13901a574f8f13b645467285a48940f185b690936bb700a50863/scikit_learn-1.6.1-cp313-cp313t-win_amd64.whl", hash = "sha256:3f59fe08dc03ea158605170eb52b22a105f238a5d512c4470ddeca71feae8e5f", size = 11337256 }, +] + [[package]] name = "scipy" version = "1.15.1" @@ -1594,6 +1707,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/99/ff/c87e0622b1dadea79d2fb0b25ade9ed98954c9033722eb707053d310d4f3/sympy-1.13.3-py3-none-any.whl", hash = "sha256:54612cf55a62755ee71824ce692986f23c88ffa77207b30c1368eda4a7060f73", size = 6189483 }, ] +[[package]] +name = "threadpoolctl" +version = "3.5.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/bd/55/b5148dcbf72f5cde221f8bfe3b6a540da7aa1842f6b491ad979a6c8b84af/threadpoolctl-3.5.0.tar.gz", hash = "sha256:082433502dd922bf738de0d8bcc4fdcbf0979ff44c42bd40f5af8a282f6fa107", size = 41936 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4b/2c/ffbf7a134b9ab11a67b0cf0726453cedd9c5043a4fe7a35d1cefa9a1bcfb/threadpoolctl-3.5.0-py3-none-any.whl", hash = "sha256:56c1e26c150397e58c4926da8eeee87533b1e32bef131bd4bf6a2f45f3185467", size = 18414 }, +] + [[package]] name = "tomli" version = "2.2.1" @@ -1789,3 +1911,17 @@ sdist = { url = "https://files.pythonhosted.org/packages/aa/63/e53da845320b757bf wheels = [ { url = "https://files.pythonhosted.org/packages/c8/19/4ec628951a74043532ca2cf5d97b7b14863931476d117c471e8e2b1eb39f/urllib3-2.3.0-py3-none-any.whl", hash = "sha256:1cee9ad369867bfdbbb48b7dd50374c0967a0bb7710050facf0dd6911440e3df", size = 128369 }, ] + +[[package]] +name = "virtualenv" +version = "20.29.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "distlib" }, + { name = "filelock" }, + { name = "platformdirs" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a7/ca/f23dcb02e161a9bba141b1c08aa50e8da6ea25e6d780528f1d385a3efe25/virtualenv-20.29.1.tar.gz", hash = "sha256:b8b8970138d32fb606192cb97f6cd4bb644fa486be9308fb9b63f81091b5dc35", size = 7658028 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/89/9b/599bcfc7064fbe5740919e78c5df18e5dceb0887e676256a1061bb5ae232/virtualenv-20.29.1-py3-none-any.whl", hash = "sha256:4e4cb403c0b0da39e13b46b1b2476e505cb0046b25f242bee80f62bf990b2779", size = 4282379 }, +]