|
| 1 | +# flake8: noqa: E501 |
| 2 | +from matplotlib.colors import LinearSegmentedColormap, ListedColormap |
| 3 | +import numpy as np |
| 4 | +from pal.problem.constrained_problem import ConstrainedProblem |
| 5 | +from pal.logic.lra_torch import lra_to_torch |
| 6 | +from scipy import ndimage |
| 7 | +import torch |
| 8 | + |
| 9 | + |
| 10 | +blue_red6 = LinearSegmentedColormap.from_list( |
| 11 | + "my_gradient", |
| 12 | + ( |
| 13 | + # Edit this gradient at https://eltos.github.io/gradient/#0:FFC94C-20:75EB6E-40.1:2B72DE-44:2900B3-50:FFFFFF-56:C70160-75.2:DE5422-100:471000 |
| 14 | + (0.000, (1.000, 0.788, 0.298)), |
| 15 | + (0.200, (0.459, 0.922, 0.431)), |
| 16 | + (0.401, (0.169, 0.447, 0.871)), |
| 17 | + (0.430, (0.161, 0.000, 0.702)), |
| 18 | + (0.500, (1.000, 1.000, 1.000)), |
| 19 | + (0.570, (0.780, 0.004, 0.376)), |
| 20 | + (0.752, (0.871, 0.329, 0.133)), |
| 21 | + (1.000, (0.278, 0.063, 0.000)), |
| 22 | + ), |
| 23 | + N=2048, |
| 24 | +) |
| 25 | + |
| 26 | +blue_red5 = LinearSegmentedColormap.from_list( |
| 27 | + "my_gradient", |
| 28 | + ( |
| 29 | + # Edit this gradient at https://eltos.github.io/gradient/#0:4CFFC8-35:2B72DE-47:2900B3-50:000000-53:C70160-65:DE5422-100:F9FC4A |
| 30 | + (0.000, (1.000, 0.788, 0.298)), |
| 31 | + (0.200, (0.459, 0.922, 0.431)), |
| 32 | + (0.401, (0.169, 0.447, 0.871)), |
| 33 | + (0.470, (0.161, 0.000, 0.702)), |
| 34 | + (0.500, (1.000, 1.000, 1.000)), |
| 35 | + (0.530, (0.780, 0.004, 0.376)), |
| 36 | + (0.752, (0.871, 0.329, 0.133)), |
| 37 | + (1.000, (0.278, 0.063, 0.000)), |
| 38 | + ), |
| 39 | + N=2048, |
| 40 | +) |
| 41 | + |
| 42 | +blue_red5_transparent = LinearSegmentedColormap.from_list( |
| 43 | + "my_gradient", |
| 44 | + ( |
| 45 | + # Edit this gradient at https://eltos.github.io/gradient/#0:4CFFC8-35:2B72DE-47:2900B3-50:000000-53:C70160-65:DE5422-100:F9FC4A |
| 46 | + (0.000, (1.000, 0.788, 0.298, 1.0)), |
| 47 | + (0.200, (0.459, 0.922, 0.431, 1.0)), |
| 48 | + (0.401, (0.169, 0.447, 0.871, 1.0)), |
| 49 | + (0.470, (0.161, 0.000, 0.702, 1.0)), |
| 50 | + (0.499, (0.161, 0.000, 0.702, 0.0)), |
| 51 | + (0.500, (1.000, 1.000, 1.000, 0.0)), |
| 52 | + (0.501, (0.780, 0.004, 0.376, 0.0)), |
| 53 | + (0.530, (0.780, 0.004, 0.376, 1.0)), |
| 54 | + (0.752, (0.871, 0.329, 0.133, 1.0)), |
| 55 | + (1.000, (0.278, 0.063, 0.000, 1.0)), |
| 56 | + ), |
| 57 | + N=2048, |
| 58 | +) |
| 59 | + |
| 60 | +blue_red6 = LinearSegmentedColormap.from_list( |
| 61 | + "my_gradient", |
| 62 | + ( |
| 63 | + # Edit this gradient at https://eltos.github.io/gradient/#0:4CFFC8-35:2B72DE-47:2900B3-50:000000-53:C70160-65:DE5422-100:F9FC4A |
| 64 | + (0.000, (1.000, 0.788, 0.298)), |
| 65 | + (0.200, (0.459, 0.922, 0.431)), |
| 66 | + (0.401, (0.169, 0.447, 0.871)), |
| 67 | + (0.480, (0.161, 0.000, 0.702)), |
| 68 | + (0.500, (1.000, 1.000, 1.000)), |
| 69 | + (0.520, (0.780, 0.004, 0.376)), |
| 70 | + (0.752, (0.871, 0.329, 0.133)), |
| 71 | + (1.000, (0.278, 0.063, 0.000)), |
| 72 | + ), |
| 73 | + N=2048, |
| 74 | +) |
| 75 | + |
| 76 | +blue_red6_transapernt = LinearSegmentedColormap.from_list( |
| 77 | + "my_gradient", |
| 78 | + ( |
| 79 | + # Edit this gradient at https://eltos.github.io/gradient/#0:4CFFC8-35:2B72DE-47:2900B3-50:000000-53:C70160-65:DE5422-100:F9FC4A |
| 80 | + (0.000, (1.000, 0.788, 0.298, 1.0)), |
| 81 | + (0.200, (0.459, 0.922, 0.431, 1.0)), |
| 82 | + (0.401, (0.169, 0.447, 0.871, 1.0)), |
| 83 | + (0.480, (0.161, 0.000, 0.702, 1.0)), |
| 84 | + (0.499, (0.161, 0.000, 0.702, 0.0)), |
| 85 | + (0.500, (1.000, 1.000, 1.000, 0.0)), |
| 86 | + (0.501, (0.780, 0.004, 0.376, 0.0)), |
| 87 | + (0.520, (0.780, 0.004, 0.376, 1.0)), |
| 88 | + (0.752, (0.871, 0.329, 0.133, 1.0)), |
| 89 | + (1.000, (0.278, 0.063, 0.000, 1.0)), |
| 90 | + ), |
| 91 | + N=2048, |
| 92 | +) |
| 93 | + |
| 94 | +blue_red8 = LinearSegmentedColormap.from_list( |
| 95 | + "my_gradient", |
| 96 | + ( |
| 97 | + # Edit this gradient at https://eltos.github.io/gradient/#0:FFC94C-20:75EB6E-35:2B72DE-41:2900B3-50:FFFFFF-59:C70160-75.2:DE5422-100:471000 |
| 98 | + (0.000, (1.000, 0.788, 0.298)), |
| 99 | + (0.200, (0.459, 0.922, 0.431)), |
| 100 | + (0.350, (0.169, 0.447, 0.871)), |
| 101 | + (0.400, (0.161, 0.000, 0.702)), |
| 102 | + (0.500, (1.000, 1.000, 1.000)), |
| 103 | + (0.600, (0.780, 0.004, 0.376)), |
| 104 | + (0.752, (0.871, 0.329, 0.133)), |
| 105 | + (1.000, (0.278, 0.063, 0.000)), |
| 106 | + ), |
| 107 | + N=2048, |
| 108 | +) |
| 109 | + |
| 110 | +blue_red8_transparent = LinearSegmentedColormap.from_list( |
| 111 | + "my_gradient", |
| 112 | + ( |
| 113 | + # Edit this gradient at https://eltos.github.io/gradient/#0:FFC94C-20:75EB6E-35:2B72DE-41:2900B3-50:FFFFFF-59:C70160-75.2:DE5422-100:471000 |
| 114 | + (0.000, (1.000, 0.788, 0.298, 1.0)), |
| 115 | + (0.200, (0.459, 0.922, 0.431, 1.0)), |
| 116 | + (0.350, (0.169, 0.447, 0.871, 1.0)), |
| 117 | + (0.400, (0.161, 0.000, 0.702, 1.0)), |
| 118 | + (0.499, (0.161, 0.000, 0.702, 0.0)), |
| 119 | + (0.500, (1.000, 1.000, 1.000, 0.0)), |
| 120 | + (0.501, (1.000, 1.000, 1.000, 0.0)), |
| 121 | + (0.600, (0.780, 0.004, 0.376, 1.0)), |
| 122 | + (0.752, (0.871, 0.329, 0.133, 1.0)), |
| 123 | + (1.000, (0.278, 0.063, 0.000, 1.0)), |
| 124 | + ), |
| 125 | + N=2048, |
| 126 | +) |
| 127 | + |
| 128 | +the_div_colormap = blue_red8 |
| 129 | + |
| 130 | + |
| 131 | +def prepare_meshgrid(problem: ConstrainedProblem, resolution): |
| 132 | + var_dict = problem.get_y_vars() |
| 133 | + constraints = problem.create_constraints() |
| 134 | + |
| 135 | + y_pos_dict = {i: name for name, i in var_dict.items()} |
| 136 | + # plot the fitted polynomial |
| 137 | + limits = constraints.get_global_limits() |
| 138 | + linspaces = [ |
| 139 | + torch.linspace(limits[y_pos_dict[i]][0], limits[y_pos_dict[i]][1], resolution) |
| 140 | + for i in range(len(var_dict)) |
| 141 | + ] |
| 142 | + |
| 143 | + img_extent = [ |
| 144 | + limits[y_pos_dict[0]][0], |
| 145 | + limits[y_pos_dict[0]][1], |
| 146 | + limits[y_pos_dict[1]][0], |
| 147 | + limits[y_pos_dict[1]][1], |
| 148 | + ] |
| 149 | + |
| 150 | + mesh_grid = torch.meshgrid(*linspaces) |
| 151 | + mesh = torch.stack(mesh_grid, dim=-1).reshape(-1, len(var_dict)) |
| 152 | + |
| 153 | + assert constraints.expression is not None |
| 154 | + pytorch_constraints = lra_to_torch( |
| 155 | + constraints.expression, var_dict=problem.get_y_vars() |
| 156 | + ) |
| 157 | + |
| 158 | + valid = pytorch_constraints(mesh).detach().reshape(resolution, resolution) |
| 159 | + |
| 160 | + return mesh, mesh_grid, img_extent, valid |
| 161 | + |
| 162 | + |
| 163 | +def prepare_border(problem: ConstrainedProblem, meshgrid): |
| 164 | + constraints_generic = problem.create_constraints() |
| 165 | + assert constraints_generic.expression is not None |
| 166 | + pytorch_constraints = lra_to_torch( |
| 167 | + constraints_generic.expression, var_dict=problem.get_y_vars() |
| 168 | + ) |
| 169 | + |
| 170 | + var_dict = problem.get_y_vars() |
| 171 | + |
| 172 | + # check if meshgrid element is torch tensor |
| 173 | + if not isinstance(meshgrid[0], torch.Tensor): |
| 174 | + meshgrid = [torch.tensor(m) for m in meshgrid] |
| 175 | + ys = torch.stack(meshgrid, dim=-1).reshape(-1, len(var_dict)) |
| 176 | + valid_ys: torch.Tensor = pytorch_constraints(ys).detach() |
| 177 | + |
| 178 | + shape_image = valid_ys.reshape(meshgrid[0].shape[0], meshgrid[1].shape[0]) |
| 179 | + |
| 180 | + # Convert the shape image to grayscale if it's a color image |
| 181 | + if shape_image.ndim == 3: |
| 182 | + shape_image = np.mean(shape_image, axis=2) |
| 183 | + |
| 184 | + # Apply a binary threshold |
| 185 | + binary = shape_image > 0.5 |
| 186 | + |
| 187 | + # Find the edges using binary erosion |
| 188 | + eroded_image = ndimage.binary_erosion(~binary) |
| 189 | + border = binary & ~eroded_image |
| 190 | + |
| 191 | + return border |
| 192 | + |
| 193 | + |
| 194 | +def plot_in_paper_diverging_theme( |
| 195 | + the_ax, |
| 196 | + data, |
| 197 | + valid, |
| 198 | + border, |
| 199 | + img_extent, |
| 200 | + norm=None, |
| 201 | + plot_max_value=None, |
| 202 | + the_cmap=the_div_colormap, |
| 203 | + with_contour=False, |
| 204 | + with_valid=False, |
| 205 | + contour_lw=0.5, |
| 206 | + spine_lw=1.5, |
| 207 | + hatched=False, |
| 208 | + plot_borders_below=False, |
| 209 | +): |
| 210 | + # plt.figure() |
| 211 | + # f = plt.gcf() |
| 212 | + # the_ax = plt.gca() |
| 213 | + if plot_max_value is None: |
| 214 | + the_max = data.max() |
| 215 | + else: |
| 216 | + the_max = plot_max_value |
| 217 | + |
| 218 | + if norm is not None: |
| 219 | + the_max = norm.vmax |
| 220 | + |
| 221 | + if with_contour and plot_borders_below: |
| 222 | + if hatched: |
| 223 | + contour = the_ax.contourf( |
| 224 | + valid.T.to(float), |
| 225 | + levels=[-0.5, 0.5], |
| 226 | + colors="None", |
| 227 | + hatches=["//", None], |
| 228 | + ls=None, |
| 229 | + extent=img_extent, |
| 230 | + origin="upper", |
| 231 | + zorder=0, |
| 232 | + ) |
| 233 | + for collection in contour.collections: |
| 234 | + collection.set_edgecolor("pink") |
| 235 | + collection.set_linewidth(0) # Remove contour lines |
| 236 | + the_ax.contour( |
| 237 | + border, |
| 238 | + colors="pink", |
| 239 | + linewidths=contour_lw, |
| 240 | + extent=img_extent, |
| 241 | + origin="upper", |
| 242 | + zorder=0, |
| 243 | + ) |
| 244 | + the_data = data.copy() |
| 245 | + |
| 246 | + # negative values are valid, positve values are invalid |
| 247 | + # I know this is stupid, but it is what it is |
| 248 | + the_data[valid.T] = (-1) * the_data[valid.T] |
| 249 | + # the_data[~valid] = (-1) * the_data[~valid] |
| 250 | + |
| 251 | + the_ax.imshow( |
| 252 | + the_data, |
| 253 | + extent=img_extent, |
| 254 | + origin="upper", |
| 255 | + cmap=the_cmap, |
| 256 | + vmin=-the_max, |
| 257 | + vmax=the_max, |
| 258 | + interpolation="nearest", |
| 259 | + ) |
| 260 | + |
| 261 | + if with_valid: |
| 262 | + # plot valid tensor with true being pink and false transparent |
| 263 | + white_cmap = ListedColormap(["white", "none"]) |
| 264 | + |
| 265 | + the_ax.imshow(valid.T, extent=img_extent, origin="upper", cmap=white_cmap) |
| 266 | + |
| 267 | + if with_contour and not plot_borders_below: |
| 268 | + if hatched: |
| 269 | + contour = the_ax.contourf( |
| 270 | + valid.T.to(float), |
| 271 | + levels=[-0.5, 0.5], |
| 272 | + colors="None", |
| 273 | + hatches=["//", None], |
| 274 | + ls=None, |
| 275 | + extent=img_extent, |
| 276 | + origin="upper", |
| 277 | + ) |
| 278 | + for collection in contour.collections: |
| 279 | + collection.set_edgecolor("pink") |
| 280 | + collection.set_linewidth(0) # Remove contour lines |
| 281 | + the_ax.contour( |
| 282 | + border, |
| 283 | + colors="pink", |
| 284 | + linewidths=contour_lw, |
| 285 | + extent=img_extent, |
| 286 | + origin="upper", |
| 287 | + ) |
| 288 | + |
| 289 | + # black frame, nothing else |
| 290 | + the_ax.set_xticks([]) |
| 291 | + the_ax.set_yticks([]) |
| 292 | + the_ax.set_xticklabels([]) |
| 293 | + the_ax.set_yticklabels([]) |
| 294 | + |
| 295 | + # frame a bit thicker |
| 296 | + for spine in the_ax.spines.values(): |
| 297 | + spine.set_linewidth(spine_lw) |
| 298 | + |
| 299 | + |
| 300 | +plot_in_paper_theme = plot_in_paper_diverging_theme |
0 commit comments