|
| 1 | +import numpy as np |
| 2 | +import pandas as pd |
| 3 | +from matplotlib import pyplot as plt |
| 4 | + |
| 5 | +from diffusion_for_multi_scale_molecular_dynamics import TOP_DIR |
| 6 | +from diffusion_for_multi_scale_molecular_dynamics.analysis import ( |
| 7 | + PLEASANT_FIG_SIZE, PLOT_STYLE_PATH) |
| 8 | + |
| 9 | +plt.style.use(PLOT_STYLE_PATH) |
| 10 | + |
| 11 | +experiment_dir = TOP_DIR / "experiments/active_learning/pretraining_flare/" |
| 12 | +output_dir = experiment_dir / "validation_performance" |
| 13 | + |
| 14 | +images_dir = experiment_dir / "images" |
| 15 | +images_dir.mkdir(parents=True, exist_ok=True) |
| 16 | + |
| 17 | + |
| 18 | +if __name__ == "__main__": |
| 19 | + df = pd.read_pickle(output_dir / "validation_set_performance.pkl").sort_values( |
| 20 | + ["sigma", "number_of_structures"] |
| 21 | + ) |
| 22 | + |
| 23 | + fig = plt.figure(figsize=PLEASANT_FIG_SIZE) |
| 24 | + fig.suptitle("Training FLARE from scratch on Si 2x2x2.") |
| 25 | + ax1 = fig.add_subplot(121) |
| 26 | + ax2 = fig.add_subplot(122) |
| 27 | + |
| 28 | + ax1.set_title("Energy Errors") |
| 29 | + ax1.set_xlabel("Number of Training Structures") |
| 30 | + ax1.set_ylabel("Validation Energy RMSE (eV)") |
| 31 | + |
| 32 | + ax2.set_title("Force Errors") |
| 33 | + ax2.set_xlabel("Number of Training Structures") |
| 34 | + ax2.set_ylabel(r"Validation Mean Force RMSE (eV / $\AA$)") |
| 35 | + |
| 36 | + for sigma, group_df in df.groupby(by="sigma"): |
| 37 | + number_of_training_structures = group_df["number_of_structures"].values |
| 38 | + validation_energy_rmse = group_df["flare_energy_rmse"].values |
| 39 | + validation_mean_force_rmse = group_df["flare_mean_force_rmse"] |
| 40 | + |
| 41 | + mapped_validation_energy_rmse = group_df["mapped_flare_energy_rmse"].values |
| 42 | + mapped_validation_mean_force_rmse = group_df[ |
| 43 | + "mapped_flare_mean_force_rmse" |
| 44 | + ].values |
| 45 | + |
| 46 | + (lines,) = ax1.semilogy( |
| 47 | + number_of_training_structures, |
| 48 | + validation_energy_rmse, |
| 49 | + "-", |
| 50 | + label=rf"$\sigma$ = {sigma}", |
| 51 | + ) |
| 52 | + color = lines.get_color() |
| 53 | + |
| 54 | + ax1.semilogy( |
| 55 | + number_of_training_structures, |
| 56 | + mapped_validation_energy_rmse, |
| 57 | + "*", |
| 58 | + ms=10, |
| 59 | + color=color, |
| 60 | + label=rf"$\sigma$ = {sigma} (MAPPED)", |
| 61 | + ) |
| 62 | + |
| 63 | + ax2.semilogy( |
| 64 | + number_of_training_structures, validation_mean_force_rmse, "-", color=color |
| 65 | + ) |
| 66 | + |
| 67 | + ax2.semilogy( |
| 68 | + number_of_training_structures, |
| 69 | + mapped_validation_mean_force_rmse, |
| 70 | + "*", |
| 71 | + ms=10, |
| 72 | + color=color, |
| 73 | + ) |
| 74 | + |
| 75 | + xmin = df["number_of_structures"].min() - 1 |
| 76 | + xmax = df["number_of_structures"].max() + 1 |
| 77 | + ax2.hlines(0.01, xmin, xmax, color="green", label="ARTn Force Threshold") |
| 78 | + |
| 79 | + for ax in [ax1, ax2]: |
| 80 | + ax.set_ylim(ymin=0) |
| 81 | + ax.set_xlim(xmin, xmax) |
| 82 | + |
| 83 | + handles, labels = ax1.get_legend_handles_labels() |
| 84 | + |
| 85 | + legend = fig.legend( |
| 86 | + handles, |
| 87 | + labels, |
| 88 | + loc="lower center", # Place the legend's center below the plots |
| 89 | + bbox_to_anchor=( |
| 90 | + 0.5, |
| 91 | + -0.025, |
| 92 | + ), # x=0.5 (center), y=-0.05 (just below the figure bottom) |
| 93 | + ncol=4, # Number of columns for legend entries |
| 94 | + fancybox=True, |
| 95 | + shadow=True, |
| 96 | + borderaxespad=1.0, |
| 97 | + ) |
| 98 | + |
| 99 | + plt.subplots_adjust(bottom=0.25) |
| 100 | + fig.savefig(images_dir / "errors_vs_number_of_structures.png") |
| 101 | + |
| 102 | + # ================================================================================ |
| 103 | + |
| 104 | + sub_df = df[df["number_of_structures"] == 10] |
| 105 | + |
| 106 | + common_params = dict(linestyle="None", marker="o", markersize=5, mew=0, alpha=0.5) |
| 107 | + |
| 108 | + figsize = (1.5 * PLEASANT_FIG_SIZE[0], PLEASANT_FIG_SIZE[1]) |
| 109 | + fig = plt.figure(figsize=figsize) |
| 110 | + fig.suptitle( |
| 111 | + "Training FLARE from scratch on Si 2x2x2.\n Models trained with 10 structures." |
| 112 | + ) |
| 113 | + ax1 = fig.add_subplot(121) |
| 114 | + ax2 = fig.add_subplot(122) |
| 115 | + |
| 116 | + ax1.set_xlabel("FLARE Uncertainty") |
| 117 | + ax1.set_ylabel("FLARE Force Error") |
| 118 | + ax2.set_xlabel("MAPPED FLARE Uncertainty") |
| 119 | + ax2.set_ylabel("MAPPED FLARE Force Error") |
| 120 | + |
| 121 | + xmax1, ymax1, xmax2, ymax2 = 0, 0, 0, 0 |
| 122 | + |
| 123 | + for sigma, sigma_df in sub_df.groupby(by="sigma"): |
| 124 | + |
| 125 | + # ================================================================================ |
| 126 | + list_flare_force_errors = sigma_df["flare_all_force_errors"].values[0] |
| 127 | + list_flare_uncertainties = sigma_df["flare_all_uncertainties"].values[0] |
| 128 | + |
| 129 | + coeffs = np.polyfit(list_flare_uncertainties, list_flare_force_errors, deg=1) |
| 130 | + x = np.array([0.0, list_flare_uncertainties.max()]) |
| 131 | + y = np.poly1d(coeffs)(x) |
| 132 | + |
| 133 | + (line,) = ax1.plot( |
| 134 | + list_flare_uncertainties, |
| 135 | + list_flare_force_errors, |
| 136 | + **common_params, |
| 137 | + label=rf"$\sigma$ = {sigma}, slope = {coeffs[0]:4.1f}", |
| 138 | + ) |
| 139 | + |
| 140 | + ax1.plot(x, y, "-", c=line.get_color(), lw=4, label="__nolabel__") |
| 141 | + xmax1 = np.max([xmax1, list_flare_uncertainties.max()]) |
| 142 | + ymax1 = np.max([ymax1, list_flare_force_errors.max()]) |
| 143 | + |
| 144 | + # ================================================================================ |
| 145 | + list_mapped_flare_force_errors = sigma_df[ |
| 146 | + "mapped_flare_all_force_errors" |
| 147 | + ].values[0] |
| 148 | + list_mapped_flare_uncertainties = sigma_df[ |
| 149 | + "mapped_flare_all_uncertainties" |
| 150 | + ].values[0] |
| 151 | + |
| 152 | + coeffs = np.polyfit( |
| 153 | + list_mapped_flare_uncertainties, list_mapped_flare_force_errors, deg=1 |
| 154 | + ) |
| 155 | + x = np.array([0.0, list_mapped_flare_uncertainties.max()]) |
| 156 | + y = np.poly1d(coeffs)(x) |
| 157 | + (line,) = ax2.plot( |
| 158 | + list_mapped_flare_uncertainties, |
| 159 | + list_mapped_flare_force_errors, |
| 160 | + **common_params, |
| 161 | + label=rf"$\sigma$ = {sigma}, slope = {coeffs[0]:4.1f}", |
| 162 | + ) |
| 163 | + ax2.plot(x, y, "-", c=line.get_color(), lw=4, label="__nolabel__") |
| 164 | + |
| 165 | + xmax2 = np.max([xmax2, list_mapped_flare_uncertainties.max()]) |
| 166 | + ymax2 = np.max([ymax2, list_mapped_flare_force_errors.max()]) |
| 167 | + |
| 168 | + ax1.legend(loc=0) |
| 169 | + ax1.set_xlim(xmin=0, xmax=xmax1) |
| 170 | + ax1.set_ylim(ymin=0, ymax=ymax1) |
| 171 | + |
| 172 | + ax2.legend(loc=0) |
| 173 | + ax2.set_xlim(xmin=0, xmax=xmax2) |
| 174 | + ax2.set_ylim(ymin=0, ymax=ymax2) |
| 175 | + |
| 176 | + fig.tight_layout() |
| 177 | + fig.savefig(images_dir / "errors_vs_uncertainties.png") |
| 178 | + |
| 179 | + # ================================================================================ |
| 180 | + |
| 181 | + fig = plt.figure(figsize=PLEASANT_FIG_SIZE) |
| 182 | + fig.suptitle("Comparing FLARE and MAPPED FLARE") |
| 183 | + ax = fig.add_subplot(111) |
| 184 | + ax.set_xlabel("FLARE Uncertainty") |
| 185 | + ax.set_ylabel("MAPPED FLARE Uncertainty") |
| 186 | + |
| 187 | + for sigma, sigma_df in sub_df.groupby(by="sigma"): |
| 188 | + |
| 189 | + list_flare_uncertainties = sigma_df["flare_all_uncertainties"].values[0] |
| 190 | + list_mapped_flare_uncertainties = sigma_df[ |
| 191 | + "mapped_flare_all_uncertainties" |
| 192 | + ].values[0] |
| 193 | + coeffs = np.polyfit( |
| 194 | + list_flare_uncertainties, list_mapped_flare_uncertainties, deg=1 |
| 195 | + ) |
| 196 | + x = np.array([0.0, list_flare_uncertainties.max()]) |
| 197 | + y = np.poly1d(coeffs)(x) |
| 198 | + |
| 199 | + (line,) = ax.plot( |
| 200 | + list_flare_uncertainties, |
| 201 | + list_mapped_flare_uncertainties, |
| 202 | + "o", |
| 203 | + label=rf"$\sigma$ = {sigma}, slope = {coeffs[0]:4.3f}", |
| 204 | + ) |
| 205 | + |
| 206 | + ax.plot(x, y, "-", c=line.get_color(), lw=4, label="__nolabel__") |
| 207 | + |
| 208 | + ax.legend(loc=0) |
| 209 | + ax.set_xlim(xmin=0) |
| 210 | + ax.set_ylim(ymin=0) |
| 211 | + |
| 212 | + fig.tight_layout() |
| 213 | + fig.savefig(images_dir / "flare_vs_mapped_flare_uncertainties.png") |
0 commit comments