11import os
22from typing import Literal
33
4- import colour as cl
54import cv2
65import numpy as np
76from numpy .typing import NDArray
1615 create_patch_tiled_image ,
1716 visualize_patch_comparison ,
1817)
18+ from color_correction_asdfghjkl .utils .image_processing import calc_color_diff
1919from color_correction_asdfghjkl .utils .visualization_utils import (
2020 create_image_grid_visualization ,
2121)
@@ -84,6 +84,10 @@ def __init__(
8484 self .input_grid_image = None
8585 self .input_debug_image = None
8686
87+ # Initialize correction output attributes
88+ self .corrected_patches = None
89+ self .corrected_grid_image = None
90+
8791 # Initialize model attributes
8892 self .trained_model = None
8993 self .correction_model = CorrectionModelFactory .create (
@@ -178,17 +182,12 @@ def _save_debug_output(
178182 output_directory : str
179183 Directory to save debug outputs.
180184 """
181- predicted_patches = self .correction_model .compute_correction (
182- input_image = np .array (self .input_patches ),
183- )
184- predicted_grid = create_patch_tiled_image (predicted_patches )
185-
186185 before_comparison = visualize_patch_comparison (
187186 ls_mean_in = self .input_patches ,
188187 ls_mean_ref = self .reference_patches ,
189188 )
190189 after_comparison = visualize_patch_comparison (
191- ls_mean_in = predicted_patches ,
190+ ls_mean_in = self . corrected_patches ,
192191 ls_mean_ref = self .reference_patches ,
193192 )
194193
@@ -204,7 +203,7 @@ def _save_debug_output(
204203 ("Reference vs Corrected" , after_comparison ),
205204 ("[Free Space]" , None ),
206205 ("Patch Input" , self .input_grid_image ),
207- ("Patch Corrected" , predicted_grid ),
206+ ("Patch Corrected" , self . corrected_grid_image ),
208207 ("Patch Reference" , self .reference_grid_image ),
209208 ]
210209
@@ -239,11 +238,17 @@ def _create_debug_directory(self, base_dir: str) -> str:
239238
240239 @property
241240 def model_name (self ) -> str :
241+ "Return the name of the correction model."
242242 return self .correction_model .__class__ .__name__
243243
244244 @property
245- def img_grid_patches_ref (self ) -> np .ndarray :
246- return create_patch_tiled_image (self .reference_color_card )
245+ def ref_patches (self ) -> np .ndarray :
246+ """Return grid image of reference color patches."""
247+ return (
248+ self .reference_patches ,
249+ self .reference_grid_image ,
250+ self .reference_debug_image ,
251+ )
247252
248253 def set_reference_patches (
249254 self ,
@@ -270,6 +275,7 @@ def set_input_patches(self, image: np.ndarray, debug: bool = False) -> None:
270275 self .input_grid_image ,
271276 self .input_debug_image ,
272277 ) = self ._extract_color_patches (image = image , debug = debug )
278+ return self .input_patches , self .input_grid_image , self .input_debug_image
273279
274280 def fit (self ) -> tuple [NDArray , list [ColorPatchType ], list [ColorPatchType ]]:
275281 """Fit color correction model using input and reference images.
@@ -297,6 +303,12 @@ def fit(self) -> tuple[NDArray, list[ColorPatchType], list[ColorPatchType]]:
297303 y_patches = self .reference_patches ,
298304 )
299305
306+ # Compute corrected patches
307+ self .corrected_patches = self .correction_model .compute_correction (
308+ input_image = np .array (self .input_patches ),
309+ )
310+ self .corrected_grid_image = create_patch_tiled_image (self .corrected_patches )
311+
300312 return self .trained_model
301313
302314 def predict (
@@ -342,42 +354,37 @@ def predict(
342354
343355 return corrected_image
344356
345- def calc_color_diff (
346- self ,
347- image1 : ImageType ,
348- image2 : ImageType ,
349- ) -> tuple [float , float , float , float ]:
350- """Calculate color difference metrics between two images.
351-
352- Parameters
353- ----------
354- image1, image2 : NDArray
355- Images to compare in BGR format.
357+ def calc_color_diff_patches (self ) -> dict :
358+ initial_color_diff = calc_color_diff (
359+ image1 = self .input_grid_image ,
360+ image2 = self .reference_grid_image ,
361+ )
356362
357- Returns
358- -------
359- Tuple[float, float, float, float]
360- Minimum, maximum, mean, and standard deviation of delta E values.
361- """
362- rgb1 = cv2 .cvtColor (image1 , cv2 .COLOR_BGR2RGB )
363- rgb2 = cv2 .cvtColor (image2 , cv2 .COLOR_BGR2RGB )
363+ corrected_color_diff = calc_color_diff (
364+ image1 = self .corrected_grid_image ,
365+ image2 = self .reference_grid_image ,
366+ )
364367
365- lab1 = cl .XYZ_to_Lab (cl .sRGB_to_XYZ (rgb1 / 255 ))
366- lab2 = cl .XYZ_to_Lab (cl .sRGB_to_XYZ (rgb2 / 255 ))
368+ delta_color_diff = {
369+ "min" : initial_color_diff ["min" ] - corrected_color_diff ["min" ],
370+ "max" : initial_color_diff ["max" ] - corrected_color_diff ["max" ],
371+ "mean" : initial_color_diff ["mean" ] - corrected_color_diff ["mean" ],
372+ "std" : initial_color_diff ["std" ] - corrected_color_diff ["std" ],
373+ }
367374
368- delta_e = cl .difference .delta_E (lab1 , lab2 , method = "CIE 2000" )
375+ info = {
376+ "initial" : initial_color_diff ,
377+ "corrected" : corrected_color_diff ,
378+ "delta" : delta_color_diff ,
379+ }
369380
370- return (
371- float (np .min (delta_e )),
372- float (np .max (delta_e )),
373- float (np .mean (delta_e )),
374- float (np .std (delta_e )),
375- )
381+ return info
376382
377383
378384if __name__ == "__main__" :
379385 # Step 1: Define the path to the input image
380386 image_path = "asset/images/cc-19.png"
387+ image_path = "asset/images/cc-1.jpg"
381388
382389 # Step 2: Load the input image
383390 input_image = cv2 .imread (image_path )
@@ -386,16 +393,25 @@ def calc_color_diff(
386393 color_corrector = ColorCorrection (
387394 detection_model = "yolov8" ,
388395 detection_conf_th = 0.25 ,
389- correction_model = "least_squares" ,
390- degree = 2 , # for polynomial correction model
396+ correction_model = "polynomial" ,
397+ # correction_model="least_squares",
398+ # correction_model="affine_reg",
399+ # correction_model="linear_reg",
400+ degree = 3 , # for polynomial correction model
391401 use_gpu = True ,
392402 )
393403
394404 # Step 4: Extract color patches from the input image
405+ # you can set reference patches from another image (image has color checker card)
406+ # or use the default D50
407+ # color_corrector.set_reference_patches(image=None, debug=True)
395408 color_corrector .set_input_patches (image = input_image , debug = True )
396409 color_corrector .fit ()
397410 corrected_image = color_corrector .predict (
398411 input_image = input_image ,
399412 debug = True ,
400413 debug_output_dir = "zzz" ,
401414 )
415+
416+ eval_result = color_corrector .calc_color_diff_patches ()
417+ print (eval_result )
0 commit comments