Skip to content

Commit 2750d0f

Browse files
committed
nearly there!
1 parent fd426f7 commit 2750d0f

File tree

4 files changed

+619
-80
lines changed

4 files changed

+619
-80
lines changed

analysis/plot_helper.py

Lines changed: 300 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,300 @@
1+
# flake8: noqa: E501
2+
from matplotlib.colors import LinearSegmentedColormap, ListedColormap
3+
import numpy as np
4+
from pal.problem.constrained_problem import ConstrainedProblem
5+
from pal.logic.lra_torch import lra_to_torch
6+
from scipy import ndimage
7+
import torch
8+
9+
10+
blue_red6 = LinearSegmentedColormap.from_list(
11+
"my_gradient",
12+
(
13+
# Edit this gradient at https://eltos.github.io/gradient/#0:FFC94C-20:75EB6E-40.1:2B72DE-44:2900B3-50:FFFFFF-56:C70160-75.2:DE5422-100:471000
14+
(0.000, (1.000, 0.788, 0.298)),
15+
(0.200, (0.459, 0.922, 0.431)),
16+
(0.401, (0.169, 0.447, 0.871)),
17+
(0.430, (0.161, 0.000, 0.702)),
18+
(0.500, (1.000, 1.000, 1.000)),
19+
(0.570, (0.780, 0.004, 0.376)),
20+
(0.752, (0.871, 0.329, 0.133)),
21+
(1.000, (0.278, 0.063, 0.000)),
22+
),
23+
N=2048,
24+
)
25+
26+
blue_red5 = LinearSegmentedColormap.from_list(
27+
"my_gradient",
28+
(
29+
# Edit this gradient at https://eltos.github.io/gradient/#0:4CFFC8-35:2B72DE-47:2900B3-50:000000-53:C70160-65:DE5422-100:F9FC4A
30+
(0.000, (1.000, 0.788, 0.298)),
31+
(0.200, (0.459, 0.922, 0.431)),
32+
(0.401, (0.169, 0.447, 0.871)),
33+
(0.470, (0.161, 0.000, 0.702)),
34+
(0.500, (1.000, 1.000, 1.000)),
35+
(0.530, (0.780, 0.004, 0.376)),
36+
(0.752, (0.871, 0.329, 0.133)),
37+
(1.000, (0.278, 0.063, 0.000)),
38+
),
39+
N=2048,
40+
)
41+
42+
blue_red5_transparent = LinearSegmentedColormap.from_list(
43+
"my_gradient",
44+
(
45+
# Edit this gradient at https://eltos.github.io/gradient/#0:4CFFC8-35:2B72DE-47:2900B3-50:000000-53:C70160-65:DE5422-100:F9FC4A
46+
(0.000, (1.000, 0.788, 0.298, 1.0)),
47+
(0.200, (0.459, 0.922, 0.431, 1.0)),
48+
(0.401, (0.169, 0.447, 0.871, 1.0)),
49+
(0.470, (0.161, 0.000, 0.702, 1.0)),
50+
(0.499, (0.161, 0.000, 0.702, 0.0)),
51+
(0.500, (1.000, 1.000, 1.000, 0.0)),
52+
(0.501, (0.780, 0.004, 0.376, 0.0)),
53+
(0.530, (0.780, 0.004, 0.376, 1.0)),
54+
(0.752, (0.871, 0.329, 0.133, 1.0)),
55+
(1.000, (0.278, 0.063, 0.000, 1.0)),
56+
),
57+
N=2048,
58+
)
59+
60+
blue_red6 = LinearSegmentedColormap.from_list(
61+
"my_gradient",
62+
(
63+
# Edit this gradient at https://eltos.github.io/gradient/#0:4CFFC8-35:2B72DE-47:2900B3-50:000000-53:C70160-65:DE5422-100:F9FC4A
64+
(0.000, (1.000, 0.788, 0.298)),
65+
(0.200, (0.459, 0.922, 0.431)),
66+
(0.401, (0.169, 0.447, 0.871)),
67+
(0.480, (0.161, 0.000, 0.702)),
68+
(0.500, (1.000, 1.000, 1.000)),
69+
(0.520, (0.780, 0.004, 0.376)),
70+
(0.752, (0.871, 0.329, 0.133)),
71+
(1.000, (0.278, 0.063, 0.000)),
72+
),
73+
N=2048,
74+
)
75+
76+
blue_red6_transapernt = LinearSegmentedColormap.from_list(
77+
"my_gradient",
78+
(
79+
# Edit this gradient at https://eltos.github.io/gradient/#0:4CFFC8-35:2B72DE-47:2900B3-50:000000-53:C70160-65:DE5422-100:F9FC4A
80+
(0.000, (1.000, 0.788, 0.298, 1.0)),
81+
(0.200, (0.459, 0.922, 0.431, 1.0)),
82+
(0.401, (0.169, 0.447, 0.871, 1.0)),
83+
(0.480, (0.161, 0.000, 0.702, 1.0)),
84+
(0.499, (0.161, 0.000, 0.702, 0.0)),
85+
(0.500, (1.000, 1.000, 1.000, 0.0)),
86+
(0.501, (0.780, 0.004, 0.376, 0.0)),
87+
(0.520, (0.780, 0.004, 0.376, 1.0)),
88+
(0.752, (0.871, 0.329, 0.133, 1.0)),
89+
(1.000, (0.278, 0.063, 0.000, 1.0)),
90+
),
91+
N=2048,
92+
)
93+
94+
blue_red8 = LinearSegmentedColormap.from_list(
95+
"my_gradient",
96+
(
97+
# Edit this gradient at https://eltos.github.io/gradient/#0:FFC94C-20:75EB6E-35:2B72DE-41:2900B3-50:FFFFFF-59:C70160-75.2:DE5422-100:471000
98+
(0.000, (1.000, 0.788, 0.298)),
99+
(0.200, (0.459, 0.922, 0.431)),
100+
(0.350, (0.169, 0.447, 0.871)),
101+
(0.400, (0.161, 0.000, 0.702)),
102+
(0.500, (1.000, 1.000, 1.000)),
103+
(0.600, (0.780, 0.004, 0.376)),
104+
(0.752, (0.871, 0.329, 0.133)),
105+
(1.000, (0.278, 0.063, 0.000)),
106+
),
107+
N=2048,
108+
)
109+
110+
blue_red8_transparent = LinearSegmentedColormap.from_list(
111+
"my_gradient",
112+
(
113+
# Edit this gradient at https://eltos.github.io/gradient/#0:FFC94C-20:75EB6E-35:2B72DE-41:2900B3-50:FFFFFF-59:C70160-75.2:DE5422-100:471000
114+
(0.000, (1.000, 0.788, 0.298, 1.0)),
115+
(0.200, (0.459, 0.922, 0.431, 1.0)),
116+
(0.350, (0.169, 0.447, 0.871, 1.0)),
117+
(0.400, (0.161, 0.000, 0.702, 1.0)),
118+
(0.499, (0.161, 0.000, 0.702, 0.0)),
119+
(0.500, (1.000, 1.000, 1.000, 0.0)),
120+
(0.501, (1.000, 1.000, 1.000, 0.0)),
121+
(0.600, (0.780, 0.004, 0.376, 1.0)),
122+
(0.752, (0.871, 0.329, 0.133, 1.0)),
123+
(1.000, (0.278, 0.063, 0.000, 1.0)),
124+
),
125+
N=2048,
126+
)
127+
128+
the_div_colormap = blue_red8
129+
130+
131+
def prepare_meshgrid(problem: ConstrainedProblem, resolution):
132+
var_dict = problem.get_y_vars()
133+
constraints = problem.create_constraints()
134+
135+
y_pos_dict = {i: name for name, i in var_dict.items()}
136+
# plot the fitted polynomial
137+
limits = constraints.get_global_limits()
138+
linspaces = [
139+
torch.linspace(limits[y_pos_dict[i]][0], limits[y_pos_dict[i]][1], resolution)
140+
for i in range(len(var_dict))
141+
]
142+
143+
img_extent = [
144+
limits[y_pos_dict[0]][0],
145+
limits[y_pos_dict[0]][1],
146+
limits[y_pos_dict[1]][0],
147+
limits[y_pos_dict[1]][1],
148+
]
149+
150+
mesh_grid = torch.meshgrid(*linspaces)
151+
mesh = torch.stack(mesh_grid, dim=-1).reshape(-1, len(var_dict))
152+
153+
assert constraints.expression is not None
154+
pytorch_constraints = lra_to_torch(
155+
constraints.expression, var_dict=problem.get_y_vars()
156+
)
157+
158+
valid = pytorch_constraints(mesh).detach().reshape(resolution, resolution)
159+
160+
return mesh, mesh_grid, img_extent, valid
161+
162+
163+
def prepare_border(problem: ConstrainedProblem, meshgrid):
164+
constraints_generic = problem.create_constraints()
165+
assert constraints_generic.expression is not None
166+
pytorch_constraints = lra_to_torch(
167+
constraints_generic.expression, var_dict=problem.get_y_vars()
168+
)
169+
170+
var_dict = problem.get_y_vars()
171+
172+
# check if meshgrid element is torch tensor
173+
if not isinstance(meshgrid[0], torch.Tensor):
174+
meshgrid = [torch.tensor(m) for m in meshgrid]
175+
ys = torch.stack(meshgrid, dim=-1).reshape(-1, len(var_dict))
176+
valid_ys: torch.Tensor = pytorch_constraints(ys).detach()
177+
178+
shape_image = valid_ys.reshape(meshgrid[0].shape[0], meshgrid[1].shape[0])
179+
180+
# Convert the shape image to grayscale if it's a color image
181+
if shape_image.ndim == 3:
182+
shape_image = np.mean(shape_image, axis=2)
183+
184+
# Apply a binary threshold
185+
binary = shape_image > 0.5
186+
187+
# Find the edges using binary erosion
188+
eroded_image = ndimage.binary_erosion(~binary)
189+
border = binary & ~eroded_image
190+
191+
return border
192+
193+
194+
def plot_in_paper_diverging_theme(
195+
the_ax,
196+
data,
197+
valid,
198+
border,
199+
img_extent,
200+
norm=None,
201+
plot_max_value=None,
202+
the_cmap=the_div_colormap,
203+
with_contour=False,
204+
with_valid=False,
205+
contour_lw=0.5,
206+
spine_lw=1.5,
207+
hatched=False,
208+
plot_borders_below=False,
209+
):
210+
# plt.figure()
211+
# f = plt.gcf()
212+
# the_ax = plt.gca()
213+
if plot_max_value is None:
214+
the_max = data.max()
215+
else:
216+
the_max = plot_max_value
217+
218+
if norm is not None:
219+
the_max = norm.vmax
220+
221+
if with_contour and plot_borders_below:
222+
if hatched:
223+
contour = the_ax.contourf(
224+
valid.T.to(float),
225+
levels=[-0.5, 0.5],
226+
colors="None",
227+
hatches=["//", None],
228+
ls=None,
229+
extent=img_extent,
230+
origin="upper",
231+
zorder=0,
232+
)
233+
for collection in contour.collections:
234+
collection.set_edgecolor("pink")
235+
collection.set_linewidth(0) # Remove contour lines
236+
the_ax.contour(
237+
border,
238+
colors="pink",
239+
linewidths=contour_lw,
240+
extent=img_extent,
241+
origin="upper",
242+
zorder=0,
243+
)
244+
the_data = data.copy()
245+
246+
# negative values are valid, positve values are invalid
247+
# I know this is stupid, but it is what it is
248+
the_data[valid.T] = (-1) * the_data[valid.T]
249+
# the_data[~valid] = (-1) * the_data[~valid]
250+
251+
the_ax.imshow(
252+
the_data,
253+
extent=img_extent,
254+
origin="upper",
255+
cmap=the_cmap,
256+
vmin=-the_max,
257+
vmax=the_max,
258+
interpolation="nearest",
259+
)
260+
261+
if with_valid:
262+
# plot valid tensor with true being pink and false transparent
263+
white_cmap = ListedColormap(["white", "none"])
264+
265+
the_ax.imshow(valid.T, extent=img_extent, origin="upper", cmap=white_cmap)
266+
267+
if with_contour and not plot_borders_below:
268+
if hatched:
269+
contour = the_ax.contourf(
270+
valid.T.to(float),
271+
levels=[-0.5, 0.5],
272+
colors="None",
273+
hatches=["//", None],
274+
ls=None,
275+
extent=img_extent,
276+
origin="upper",
277+
)
278+
for collection in contour.collections:
279+
collection.set_edgecolor("pink")
280+
collection.set_linewidth(0) # Remove contour lines
281+
the_ax.contour(
282+
border,
283+
colors="pink",
284+
linewidths=contour_lw,
285+
extent=img_extent,
286+
origin="upper",
287+
)
288+
289+
# black frame, nothing else
290+
the_ax.set_xticks([])
291+
the_ax.set_yticks([])
292+
the_ax.set_xticklabels([])
293+
the_ax.set_yticklabels([])
294+
295+
# frame a bit thicker
296+
for spine in the_ax.spines.values():
297+
spine.set_linewidth(spine_lw)
298+
299+
300+
plot_in_paper_theme = plot_in_paper_diverging_theme

0 commit comments

Comments
 (0)