|
| 1 | +from __future__ import annotations |
| 2 | +from typing import Callable, TypeAlias, Any |
| 3 | +from time import perf_counter |
| 4 | +from pathlib import Path |
| 5 | +from collections.abc import Sequence |
| 6 | +from datetime import timedelta |
| 7 | + |
| 8 | +import json5 |
| 9 | +import numpy as np |
| 10 | +from scipy.optimize import minimize |
| 11 | + |
| 12 | +from vsl_ial import FArray |
| 13 | +from vsl_ial.cs.pcs23 import PCS23UCS, CS |
| 14 | +from vsl_ial.datasets.sensitivities import load as load_sensitivity |
| 15 | +from vsl_ial.cs.xyz import XYZ |
| 16 | +from vsl_ial.cs.ciexyy import CIExyY |
| 17 | + |
| 18 | +from ..eval._base import StrictModel |
| 19 | +from ..eval.dataset import DatasetConfig, WeightedDataset |
| 20 | +from ..eval.metrics import Metrics |
| 21 | + |
| 22 | + |
# Training configuration parsed from the JSON5 file given on the command line.
# NOTE: documented with comments rather than a docstring on purpose; a class
# docstring may end up in the schema written by ``--update-schema``.
class Config(StrictModel):
    # Dataset descriptions; ``main`` calls ``.load()`` on each entry.
    datasets: list[DatasetConfig]
    # Loss/metric description; ``main`` calls ``.load()`` on it and passes
    # the result to ``train`` as the loss function.
    loss: Metrics
| 26 | + |
| 27 | + |
| 28 | +Float = np.floating[Any] |
| 29 | +F32Array = np.ndarray[Any, np.dtype(np.float32)] |
| 30 | + |
# Signature of the loss handed to ``train``. It is called as
# ``loss_function(ref, exp)`` with two lists built in lockstep (one array per
# dataset: reference values first, model distances second) and its result is
# multiplied by the dataset weight.
LossFunction: TypeAlias = Callable[[Sequence[FArray], Sequence[FArray]], Float]
| 32 | + |
| 33 | + |
| 34 | +def _point_inside(x: float, y: float, poly: list[tuple[float, float]]): |
| 35 | + """ |
| 36 | + ToDo: this is a very slow function. |
| 37 | + Can be optimized 100x after this module is working |
| 38 | + """ |
| 39 | + n = len(poly) |
| 40 | + inside = False |
| 41 | + p2x = 0.0 |
| 42 | + p2y = 0.0 |
| 43 | + xints = 0.0 |
| 44 | + p1x, p1y = poly[0] |
| 45 | + for i in range(n + 1): |
| 46 | + p2x, p2y = poly[i % n] |
| 47 | + if y > min(p1y, p2y): |
| 48 | + if y <= max(p1y, p2y): |
| 49 | + if x <= max(p1x, p2x): |
| 50 | + if p1y != p2y: |
| 51 | + xints = (y - p1y) * (p2x - p1x) / (p2y - p1y) + p1x |
| 52 | + if p1x == p2x or x <= xints: |
| 53 | + inside = not inside |
| 54 | + p1x, p1y = p2x, p2y |
| 55 | + return inside |
| 56 | + |
| 57 | + |
class MonotonicityLoss:
    """
    Penalty for a lightness coordinate that does not grow fast enough with Y.

    Construction samples a 30x30 grid of xy chromaticities over the unit
    square, keeps the points that fall inside the polygon formed by the xy
    coordinates of the ``"cie-1931-2"`` sensitivity data, and pairs every
    kept chromaticity with 49 luminance levels between 0.02 and 1.0.
    Calling the instance converts those samples with the given colour space
    and scores how its first coordinate changes along the luminance axis.
    """

    def __init__(self) -> None:
        axis = np.linspace(0, 1, 30, dtype=np.float32)
        candidates = self._create_grid(axis, axis)

        self.Y = np.linspace(0.02, 1.0, 49)
        self._sensitivity_xyz = np.asarray(
            load_sensitivity("cie-1931-2")["xyz"], dtype=np.float64
        ).T

        # Polygon of xy chromaticities used as the boundary of valid points.
        locus = CIExyY().from_XYZ(XYZ(), self._sensitivity_xyz)
        locus_xy = locus[:, :2].tolist()

        keep = np.array(
            [_point_inside(*xy.tolist(), locus_xy) for xy in candidates],
            dtype=bool,
        )
        self.xy_points = candidates[keep]

        # Build a (levels, points, 3) xyY block: the same chromaticities are
        # repeated for every luminance level.
        n_levels = self.Y.size
        n_points = len(self.xy_points)
        chromaticity = np.repeat(self.xy_points[None], n_levels, axis=0)
        luminance = np.repeat(self.Y, n_points).reshape(-1, n_points, 1)
        self.xyY_diff = np.dstack((chromaticity, luminance))
        self.XYZ_diff = CIExyY().to_XYZ(XYZ(), self.xyY_diff)

    def __call__(self, cs: CS) -> float:
        """
        Formula 10 of PCS23-UCS

        Mean, over all samples, of ``max(0, minimum_diff - dL/dY)`` where
        ``L`` is the first coordinate produced by ``cs`` and the derivative
        is a finite difference between consecutive luminance levels.
        """
        minimum_diff = 0.02
        lightness = cs.from_XYZ(XYZ(), self.XYZ_diff)[..., 0]
        slope = np.diff(lightness, axis=0) / np.diff(self.Y)[..., None]
        shortfall = np.maximum(0, -slope + minimum_diff)
        return float(np.mean(shortfall))

    @staticmethod
    def _create_grid(x: F32Array, y: F32Array) -> F32Array:
        """Return every ``(x[i], y[j])`` combination as rows, ``x`` varying slowest."""
        xx, yy = np.meshgrid(x, y, indexing="ij")
        return np.stack((xx.ravel(), yy.ravel()), axis=1)
| 109 | + |
| 110 | + |
def train(
    model_cls: type[PCS23UCS],
    loaded_datasets: list[WeightedDataset],
    loss_function: LossFunction,
) -> None:
    """
    Fit the ``V`` and ``H`` parameters of ``model_cls`` with Nelder-Mead.

    The objective is the weighted sum of ``loss_function`` over
    ``loaded_datasets`` plus 0.1 times the monotonicity penalty. The
    optimiser is restarted 60 times, each run capped at 150 iterations and
    150 evaluations and seeded with the previous run's result. Progress and
    the final result are printed; nothing is returned.
    """
    n_v = 39  # leading entries of the parameter vector, passed as ``V``
    n_h = 8  # remaining entries, passed as ``H``
    monotonicity_loss = MonotonicityLoss()

    def build_model(x, F_LA_or_D, illuminant_xyz):
        # Split the flat parameter vector into the two model arguments.
        return model_cls(
            F_LA_or_D=F_LA_or_D,
            illuminant_xyz=illuminant_xyz,
            V=x[:n_v],
            H=x[n_v:],
        )

    def pair_distances(coordinates, pairs):
        # Euclidean distance between the two colours of every pair.
        first = np.empty((len(pairs), 3), dtype=np.float64)
        second = np.empty_like(first)
        for row, (a_idx, b_idx) in enumerate(pairs):
            first[row] = coordinates[a_idx]
            second[row] = coordinates[b_idx]
        return np.linalg.norm(first - second, axis=1, ord=2)

    def evaluate(x: list[float]) -> float:
        opt_model = build_model(x, 0.0, None)

        stress = 0.0
        for group in loaded_datasets:
            ref: list[FArray] = []
            exp: list[FArray] = []
            for dataset in group.datasets:
                assert dataset.F is not None, dataset
                model = build_model(
                    x, (dataset.F, dataset.L_A), dataset.illuminant
                )
                coordinates = model.from_XYZ(XYZ(), dataset.xyz)
                exp.append(pair_distances(coordinates, dataset.pairs))
                ref.append(dataset.dv)
            stress += loss_function(ref, exp) * group.weight

        monotonicity = monotonicity_loss(opt_model)
        loss = stress + 0.1 * monotonicity
        print(f"{stress=}, {monotonicity=}, {loss=}")
        return loss

    x0 = np.random.rand(n_v + n_h)
    start = perf_counter()
    for step in range(1, 61):
        res = minimize(
            fun=evaluate,
            x0=x0,
            method="Nelder-Mead",
            tol=1e-2,
            options={"maxiter": 150, "maxfev": 150},
        )
        params = res.x.tolist()
        elapsed = timedelta(seconds=perf_counter() - start)
        print(f"step {step}: {params} t={elapsed}")
        # Warm-start the next run from where this one stopped.
        x0 = res.x
    print("minimization result", res)
| 169 | + |
| 170 | + |
def main() -> None:
    """
    Command-line entry point.

    With ``--update-schema`` the JSON schema of ``Config`` is written to
    ``schema.json`` next to this module and the program exits; no config
    file is needed in that mode. Otherwise ``config`` must name a JSON5 file
    describing the datasets and the loss, which are loaded, listed on
    stdout, and passed to ``train`` together with ``PCS23UCS``.
    """
    from argparse import ArgumentParser

    parser = ArgumentParser()
    # Optional at the parser level so that ``--update-schema`` can be used
    # on its own; the training path checks for it explicitly below.
    parser.add_argument("config", type=Path, nargs="?")
    parser.add_argument("--update-schema", action="store_true")
    args = parser.parse_args()
    if args.update_schema:
        schema_path = Path(__file__).with_name("schema.json")
        # ensure_ascii=False may emit non-ASCII characters, so the encoding
        # is pinned instead of depending on the platform default.
        schema_path.write_text(
            json5.dumps(
                Config.model_json_schema(), ensure_ascii=False, quote_keys=True
            ),
            encoding="utf-8",
        )
        return
    if args.config is None:
        parser.error("the following arguments are required: config")

    config = Config(**json5.loads(args.config.read_text(encoding="utf-8")))
    loaded_datasets = [dataset.load() for dataset in config.datasets]
    loss = config.loss.load()
    model = PCS23UCS
    for loaded_dataset in loaded_datasets:
        print(f"* {loaded_dataset.name}")
        # Sub-datasets are listed only when there is not exactly one.
        if len(loaded_dataset.datasets) != 1:
            for subset in loaded_dataset.datasets:
                print(f" - {subset.name}")

    print(f"model = {model}")
    print(f"loss = {loss}")
    train(model, loaded_datasets, loss)
0 commit comments