|
| 1 | +import os |
| 2 | + |
| 3 | +import numpy as np |
| 4 | +import pandas as pd |
| 5 | + |
| 6 | +from color_correction.constant.methods import ( |
| 7 | + LiteralModelCorrection, |
| 8 | + LiteralModelDetection, |
| 9 | +) |
| 10 | +from color_correction.services.color_correction import ColorCorrection |
| 11 | +from color_correction.utils.image_patch import ( |
| 12 | + visualize_patch_comparison, |
| 13 | +) |
| 14 | +from color_correction.utils.image_processing import calc_color_diff |
| 15 | +from color_correction.utils.report_generator import ReportGenerator |
| 16 | + |
| 17 | + |
| 18 | +class ColorCorrectionAnalyzer: |
| 19 | + def __init__( |
| 20 | + self, |
| 21 | + list_correction_methods: list[tuple[LiteralModelCorrection, dict]], |
| 22 | + list_detection_methods: list[tuple[LiteralModelDetection, dict]], |
| 23 | + use_gpu: bool = True, |
| 24 | + ) -> None: |
| 25 | + self.list_correction_methods = list_correction_methods |
| 26 | + self.list_detection_methods = list_detection_methods |
| 27 | + self.use_gpu = use_gpu |
| 28 | + self.rg = ReportGenerator() |
| 29 | + |
| 30 | + def _run_single_exp( |
| 31 | + self, |
| 32 | + idx: int, |
| 33 | + input_image: np.ndarray, |
| 34 | + det_method: LiteralModelDetection, |
| 35 | + det_params: dict, |
| 36 | + cc_method: LiteralModelCorrection, |
| 37 | + cc_params: dict, |
| 38 | + reference_image: np.ndarray | None = None, |
| 39 | + ) -> dict: |
| 40 | + cc = ColorCorrection( |
| 41 | + correction_model=cc_method, |
| 42 | + detection_model=det_method, |
| 43 | + detection_conf_th=det_params.get("detection_conf_th", 0.25), |
| 44 | + use_gpu=self.use_gpu, |
| 45 | + **cc_params, |
| 46 | + ) |
| 47 | + |
| 48 | + if reference_image is not None: |
| 49 | + cc.set_reference_image(reference_image) |
| 50 | + cc.set_input_patches(input_image, debug=True) |
| 51 | + cc.fit() |
| 52 | + corrected_image = cc.predict(input_image=input_image) |
| 53 | + eval_results = cc.calc_color_diff_patches() |
| 54 | + |
| 55 | + before_comparison = visualize_patch_comparison( |
| 56 | + ls_mean_in=cc.input_patches, |
| 57 | + ls_mean_ref=cc.reference_patches, |
| 58 | + ) |
| 59 | + after_comparison = visualize_patch_comparison( |
| 60 | + ls_mean_in=cc.corrected_patches, |
| 61 | + ls_mean_ref=cc.reference_patches, |
| 62 | + ) |
| 63 | + |
| 64 | + dE_image = calc_color_diff( # noqa: N806 |
| 65 | + image1=input_image, |
| 66 | + image2=corrected_image, |
| 67 | + ) |
| 68 | + |
| 69 | + one_row = { |
| 70 | + "Index": idx, |
| 71 | + "Detection Method": det_method, |
| 72 | + "Detection Parameters": det_params, |
| 73 | + "Drawed Preprocessing Input": cc.input_debug_image, |
| 74 | + "Drawed Preprocessing Reference": cc.reference_debug_image, |
| 75 | + "Correction Method": cc_method, |
| 76 | + "Correction Parameters": cc_params, |
| 77 | + "Color Patches - Before": before_comparison, |
| 78 | + "Color Patches - After": after_comparison, |
| 79 | + "Input Image": input_image, |
| 80 | + "Corrected Image": corrected_image, |
| 81 | + "Patch ΔE (Before) - Min": eval_results["initial"]["min"], |
| 82 | + "Patch ΔE (Before) - Max": eval_results["initial"]["max"], |
| 83 | + "Patch ΔE (Before) - Mean": eval_results["initial"]["mean"], |
| 84 | + "Patch ΔE (Before) - Std": eval_results["initial"]["std"], |
| 85 | + "Patch ΔE (After) - Min": eval_results["corrected"]["min"], |
| 86 | + "Patch ΔE (After) - Max": eval_results["corrected"]["max"], |
| 87 | + "Patch ΔE (After) - Mean": eval_results["corrected"]["mean"], |
| 88 | + "Patch ΔE (After) - Std": eval_results["corrected"]["std"], |
| 89 | + "Image ΔE - Min": dE_image["min"], |
| 90 | + "Image ΔE - Max": dE_image["max"], |
| 91 | + "Image ΔE - Mean": dE_image["mean"], |
| 92 | + "Image ΔE - Std": dE_image["std"], |
| 93 | + } |
| 94 | + return one_row |
| 95 | + |
| 96 | + def run( |
| 97 | + self, |
| 98 | + input_image: np.ndarray, |
| 99 | + output_dir: str = "benchmark_debug", |
| 100 | + reference_image: np.ndarray | None = None, |
| 101 | + ) -> pd.DataFrame: |
| 102 | + """ |
| 103 | + Fungsi ini menjalankan benchmark untuk model color correction. |
| 104 | + """ |
| 105 | + ls_data = [] |
| 106 | + idx = 1 |
| 107 | + for det_method, det_params in self.list_detection_methods: |
| 108 | + for cc_method, cc_params in self.list_correction_methods: |
| 109 | + print( |
| 110 | + f"Running benchmark for {cc_method} method with {cc_params}", |
| 111 | + ) |
| 112 | + data = self._run_single_exp( |
| 113 | + idx=idx, |
| 114 | + input_image=input_image, |
| 115 | + det_method=det_method, |
| 116 | + det_params=det_params, |
| 117 | + cc_method=cc_method, |
| 118 | + cc_params=cc_params, |
| 119 | + reference_image=reference_image, |
| 120 | + ) |
| 121 | + idx += 1 |
| 122 | + ls_data.append(data) |
| 123 | + df_results = pd.DataFrame(ls_data) |
| 124 | + |
| 125 | + # Generate HTML report path |
| 126 | + os.makedirs(output_dir, exist_ok=True) |
| 127 | + html_report_path = os.path.join(output_dir, "report.html") |
| 128 | + pickel_report_path = os.path.join(output_dir, "report.pkl") |
| 129 | + |
| 130 | + # Report Generator ----------------------------------------------------- |
| 131 | + self.rg.generate_html_report(df=df_results, path_html=html_report_path) |
| 132 | + self.rg.save_dataframe(df=df_results, filepath=pickel_report_path) |
| 133 | + |
| 134 | + # Save CSV report, but without image data |
| 135 | + df_results.drop( |
| 136 | + columns=[ |
| 137 | + "Drawed Preprocessing Input", |
| 138 | + "Drawed Preprocessing Reference", |
| 139 | + "Color Patches - Before", |
| 140 | + "Color Patches - After", |
| 141 | + "Corrected Image", |
| 142 | + "Input Image", |
| 143 | + ], |
| 144 | + ).to_csv(os.path.join(output_dir, "report_no_image.csv"), index=False) |
| 145 | + |
| 146 | + print("DataFrame shape:", df_results.shape) |
| 147 | + print("\nDataFrame columns:", df_results.columns.tolist()) |
| 148 | + |
| 149 | + |
| 150 | +if __name__ == "__main__": |
| 151 | + # Pastikan path image sesuai dengan lokasi image Anda |
| 152 | + input_image_path = "asset/images/cc-19.png" |
| 153 | + |
| 154 | + benchmark = ColorCorrectionAnalyzer( |
| 155 | + list_correction_methods=[ |
| 156 | + ("least_squares", {}), |
| 157 | + ("linear_reg", {}), |
| 158 | + ("affine_reg", {}), |
| 159 | + ("polynomial", {"degree": 2}), |
| 160 | + ("polynomial", {"degree": 3}), |
| 161 | + ("polynomial", {"degree": 4}), |
| 162 | + ], |
| 163 | + list_detection_methods=[ |
| 164 | + ("yolov8", {"detection_conf_th": 0.25}), |
| 165 | + ], |
| 166 | + ) |
| 167 | + |
| 168 | + benchmark.run( |
| 169 | + input_image_path, |
| 170 | + reference_image=None, |
| 171 | + output_dir="benchmark_debug", |
| 172 | + ) |
0 commit comments