Skip to content

Commit 6734e7d

Browse files
committed
feat: optimization script for PCS23UCS
rough around the edges, but a working version. please create issues if you want to control more parameters
1 parent beb481a commit 6734e7d

5 files changed

Lines changed: 235 additions & 0 deletions

File tree

.vscode/settings.json

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,12 @@
55
"vsl_ial/eval/*.json"
66
],
77
"url": "./vsl_ial/eval/schema.json"
8+
},
9+
{
10+
"fileMatch": [
11+
"vsl_ial/optimization/*.json"
12+
],
13+
"url": "./vsl_ial/optimization/schema.json"
814
}
915
],
1016
"cSpell.words": [

vsl_ial/optimization/__init__.py

Whitespace-only changes.

vsl_ial/optimization/__main__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .optimization import main
2+
3+
main()
Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
from __future__ import annotations
2+
from typing import Callable, TypeAlias, Any
3+
from time import perf_counter
4+
from pathlib import Path
5+
from collections.abc import Sequence
6+
from datetime import timedelta
7+
8+
import json5
9+
import numpy as np
10+
from scipy.optimize import minimize
11+
12+
from vsl_ial import FArray
13+
from vsl_ial.cs.pcs23 import PCS23UCS, CS
14+
from vsl_ial.datasets.sensitivities import load as load_sensitivity
15+
from vsl_ial.cs.xyz import XYZ
16+
from vsl_ial.cs.ciexyy import CIExyY
17+
18+
from ..eval._base import StrictModel
19+
from ..eval.dataset import DatasetConfig, WeightedDataset
20+
from ..eval.metrics import Metrics
21+
22+
23+
class Config(StrictModel):
24+
datasets: list[DatasetConfig]
25+
loss: Metrics
26+
27+
28+
Float = np.floating[Any]
29+
F32Array = np.ndarray[Any, np.dtype(np.float32)]
30+
31+
LossFunction: TypeAlias = Callable[[Sequence[FArray], Sequence[FArray]], Float]
32+
33+
34+
def _point_inside(x: float, y: float, poly: list[tuple[float, float]]):
35+
"""
36+
ToDo: this is a very slow function.
37+
Can be optimized 100x after this module is working
38+
"""
39+
n = len(poly)
40+
inside = False
41+
p2x = 0.0
42+
p2y = 0.0
43+
xints = 0.0
44+
p1x, p1y = poly[0]
45+
for i in range(n + 1):
46+
p2x, p2y = poly[i % n]
47+
if y > min(p1y, p2y):
48+
if y <= max(p1y, p2y):
49+
if x <= max(p1x, p2x):
50+
if p1y != p2y:
51+
xints = (y - p1y) * (p2x - p1x) / (p2y - p1y) + p1x
52+
if p1x == p2x or x <= xints:
53+
inside = not inside
54+
p1x, p1y = p2x, p2y
55+
return inside
56+
57+
58+
class MonotonicityLoss:
59+
def __init__(self) -> None:
60+
full_xy_grid = self._create_grid(
61+
np.linspace(0, 1, 30, dtype=np.float32),
62+
np.linspace(0, 1, 30, dtype=np.float32),
63+
)
64+
self.Y = Y = np.linspace(0.02, 1.0, 49)
65+
self._sensitivity_xyz = sensitivity_xyz = np.asarray(
66+
load_sensitivity("cie-1931-2")["xyz"], dtype=np.float64
67+
).T
68+
spectral_locus_xy = (
69+
CIExyY().from_XYZ(XYZ(), sensitivity_xyz)[:, :2].tolist()
70+
)
71+
xy_points = np.asarray(
72+
[
73+
pt
74+
for pt in full_xy_grid
75+
if _point_inside(*pt.tolist(), spectral_locus_xy)
76+
],
77+
dtype=np.float32,
78+
)
79+
self.xy_points = xy_points
80+
xy_points_repeated = xy_points[None].repeat(Y.size, 0)
81+
n_points = len(xy_points)
82+
xyY = np.dstack(
83+
(xy_points_repeated, Y.repeat(n_points).reshape(-1, n_points, 1))
84+
)
85+
self.XYZ_diff = CIExyY().to_XYZ(XYZ(), xyY)
86+
self.xyY_diff = xyY
87+
88+
def __call__(self, cs: CS) -> float:
89+
"""
90+
Formula 10 of PCS23-UCS
91+
"""
92+
res = cs.from_XYZ(XYZ(), self.XYZ_diff)
93+
minimum_diff = 0.02
94+
L_plus = res[..., 0]
95+
θ = float(
96+
np.mean(
97+
np.maximum(
98+
0,
99+
-np.diff(L_plus, axis=0) / np.diff(self.Y)[..., None]
100+
+ minimum_diff,
101+
)
102+
)
103+
)
104+
return θ
105+
106+
@staticmethod
107+
def _create_grid(x: F32Array, y: F32Array) -> F32Array:
108+
return np.dstack(np.meshgrid(x, y, indexing="ij")).reshape(-1, 2)
109+
110+
111+
def train(
112+
model_cls: type[PCS23UCS],
113+
loaded_datasets: list[WeightedDataset],
114+
loss_function: LossFunction,
115+
) -> None:
116+
monotonicity_loss = MonotonicityLoss()
117+
118+
def evaluate(x: list[float]) -> float:
119+
120+
opt_model = model_cls(
121+
F_LA_or_D=0.0, illuminant_xyz=None, V=x[:39], H=x[39:]
122+
)
123+
124+
stress = 0.0
125+
for loaded_dataset in loaded_datasets:
126+
ref: list[FArray] = []
127+
exp: list[FArray] = []
128+
for dataset in loaded_dataset.datasets:
129+
assert dataset.F is not None, dataset
130+
model = model_cls(
131+
F_LA_or_D=(dataset.F, dataset.L_A),
132+
illuminant_xyz=dataset.illuminant,
133+
V=x[:39],
134+
H=x[39:],
135+
)
136+
137+
model_coordinates = model.from_XYZ(XYZ(), dataset.xyz)
138+
a_colors = np.empty((len(dataset.pairs), 3), dtype=np.float64)
139+
b_colors = np.empty_like(a_colors)
140+
for idx, (a_idx, b_idx) in enumerate(dataset.pairs):
141+
a_colors[idx] = model_coordinates[a_idx]
142+
b_colors[idx] = model_coordinates[b_idx]
143+
exp_distance = np.linalg.norm(
144+
a_colors - b_colors, axis=1, ord=2
145+
)
146+
exp.append(exp_distance)
147+
ref.append(dataset.dv)
148+
stress += loss_function(ref, exp) * loaded_dataset.weight
149+
monotonicity = monotonicity_loss(opt_model)
150+
loss = stress + 0.1 * monotonicity
151+
print(f"{stress=}, {monotonicity=}, {loss=}")
152+
return loss
153+
154+
x0 = np.random.rand(39 + 8)
155+
start = perf_counter()
156+
for i in range(60):
157+
res = minimize(
158+
fun=evaluate,
159+
method="Nelder-Mead",
160+
x0=x0,
161+
tol=1e-2,
162+
options={"maxiter": 150, "maxfev": 150},
163+
)
164+
print(
165+
f"step {i+1}: {res.x.tolist()} t={timedelta(seconds=perf_counter() - start)}"
166+
)
167+
x0 = res.x
168+
print("minimization result", res)
169+
170+
171+
def main():
172+
from argparse import ArgumentParser
173+
174+
parser = ArgumentParser()
175+
parser.add_argument("config", type=Path)
176+
parser.add_argument("--update-schema", action="store_true")
177+
args = parser.parse_args()
178+
if args.update_schema:
179+
schema_path = Path(__file__).with_name("schema.json")
180+
schema_path.write_text(
181+
json5.dumps(
182+
Config.model_json_schema(), ensure_ascii=False, quote_keys=True
183+
)
184+
)
185+
return
186+
config = Config(**json5.loads(args.config.read_text()))
187+
loaded_datasets = [dataset.load() for dataset in config.datasets]
188+
loss = config.loss.load()
189+
model = PCS23UCS
190+
for loaded_dataset in loaded_datasets:
191+
print(f"* {loaded_dataset.name}")
192+
if len(loaded_dataset.datasets) != 1:
193+
for subset in loaded_dataset.datasets:
194+
print(f" - {subset.name}")
195+
196+
print(f"model = {model}")
197+
print(f"loss = {loss}")
198+
train(model, loaded_datasets, loss)

vsl_ial/optimization/pcs23ucs.json

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
{
2+
"datasets": [
3+
{
4+
"name": "combvd",
5+
"weight": 1.0
6+
},
7+
{
8+
"name": "munsell",
9+
"weight": 0.8,
10+
"version": "3.1.0",
11+
"where": [
12+
{
13+
"group": "HC",
14+
"match": {
15+
"chroma": {
16+
"min": 18,
17+
"max": 32
18+
}
19+
}
20+
}
21+
],
22+
"min_subset_size": 3
23+
}
24+
],
25+
"loss": {
26+
"name": "group_stress"
27+
}
28+
}

0 commit comments

Comments
 (0)