Skip to content

Commit 10f3827

Browse files
committed
Another plotting script.
1 parent 882dbf2 commit 10f3827

File tree

1 file changed

+213
-0
lines changed

1 file changed

+213
-0
lines changed
Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
import numpy as np
2+
import pandas as pd
3+
from matplotlib import pyplot as plt
4+
5+
from diffusion_for_multi_scale_molecular_dynamics import TOP_DIR
6+
from diffusion_for_multi_scale_molecular_dynamics.analysis import (
7+
PLEASANT_FIG_SIZE, PLOT_STYLE_PATH)
8+
9+
plt.style.use(PLOT_STYLE_PATH)
10+
11+
experiment_dir = TOP_DIR / "experiments/active_learning/pretraining_flare/"
12+
output_dir = experiment_dir / "validation_performance"
13+
14+
images_dir = experiment_dir / "images"
15+
images_dir.mkdir(parents=True, exist_ok=True)
16+
17+
18+
if __name__ == "__main__":
19+
df = pd.read_pickle(output_dir / "validation_set_performance.pkl").sort_values(
20+
["sigma", "number_of_structures"]
21+
)
22+
23+
fig = plt.figure(figsize=PLEASANT_FIG_SIZE)
24+
fig.suptitle("Training FLARE from scratch on Si 2x2x2.")
25+
ax1 = fig.add_subplot(121)
26+
ax2 = fig.add_subplot(122)
27+
28+
ax1.set_title("Energy Errors")
29+
ax1.set_xlabel("Number of Training Structures")
30+
ax1.set_ylabel("Validation Energy RMSE (eV)")
31+
32+
ax2.set_title("Force Errors")
33+
ax2.set_xlabel("Number of Training Structures")
34+
ax2.set_ylabel(r"Validation Mean Force RMSE (eV / $\AA$)")
35+
36+
for sigma, group_df in df.groupby(by="sigma"):
37+
number_of_training_structures = group_df["number_of_structures"].values
38+
validation_energy_rmse = group_df["flare_energy_rmse"].values
39+
validation_mean_force_rmse = group_df["flare_mean_force_rmse"]
40+
41+
mapped_validation_energy_rmse = group_df["mapped_flare_energy_rmse"].values
42+
mapped_validation_mean_force_rmse = group_df[
43+
"mapped_flare_mean_force_rmse"
44+
].values
45+
46+
(lines,) = ax1.semilogy(
47+
number_of_training_structures,
48+
validation_energy_rmse,
49+
"-",
50+
label=rf"$\sigma$ = {sigma}",
51+
)
52+
color = lines.get_color()
53+
54+
ax1.semilogy(
55+
number_of_training_structures,
56+
mapped_validation_energy_rmse,
57+
"*",
58+
ms=10,
59+
color=color,
60+
label=rf"$\sigma$ = {sigma} (MAPPED)",
61+
)
62+
63+
ax2.semilogy(
64+
number_of_training_structures, validation_mean_force_rmse, "-", color=color
65+
)
66+
67+
ax2.semilogy(
68+
number_of_training_structures,
69+
mapped_validation_mean_force_rmse,
70+
"*",
71+
ms=10,
72+
color=color,
73+
)
74+
75+
xmin = df["number_of_structures"].min() - 1
76+
xmax = df["number_of_structures"].max() + 1
77+
ax2.hlines(0.01, xmin, xmax, color="green", label="ARTn Force Threshold")
78+
79+
for ax in [ax1, ax2]:
80+
ax.set_ylim(ymin=0)
81+
ax.set_xlim(xmin, xmax)
82+
83+
handles, labels = ax1.get_legend_handles_labels()
84+
85+
legend = fig.legend(
86+
handles,
87+
labels,
88+
loc="lower center", # Place the legend's center below the plots
89+
bbox_to_anchor=(
90+
0.5,
91+
-0.025,
92+
), # x=0.5 (center), y=-0.05 (just below the figure bottom)
93+
ncol=4, # Number of columns for legend entries
94+
fancybox=True,
95+
shadow=True,
96+
borderaxespad=1.0,
97+
)
98+
99+
plt.subplots_adjust(bottom=0.25)
100+
fig.savefig(images_dir / "errors_vs_number_of_structures.png")
101+
102+
# ================================================================================
103+
104+
sub_df = df[df["number_of_structures"] == 10]
105+
106+
common_params = dict(linestyle="None", marker="o", markersize=5, mew=0, alpha=0.5)
107+
108+
figsize = (1.5 * PLEASANT_FIG_SIZE[0], PLEASANT_FIG_SIZE[1])
109+
fig = plt.figure(figsize=figsize)
110+
fig.suptitle(
111+
"Training FLARE from scratch on Si 2x2x2.\n Models trained with 10 structures."
112+
)
113+
ax1 = fig.add_subplot(121)
114+
ax2 = fig.add_subplot(122)
115+
116+
ax1.set_xlabel("FLARE Uncertainty")
117+
ax1.set_ylabel("FLARE Force Error")
118+
ax2.set_xlabel("MAPPED FLARE Uncertainty")
119+
ax2.set_ylabel("MAPPED FLARE Force Error")
120+
121+
xmax1, ymax1, xmax2, ymax2 = 0, 0, 0, 0
122+
123+
for sigma, sigma_df in sub_df.groupby(by="sigma"):
124+
125+
# ================================================================================
126+
list_flare_force_errors = sigma_df["flare_all_force_errors"].values[0]
127+
list_flare_uncertainties = sigma_df["flare_all_uncertainties"].values[0]
128+
129+
coeffs = np.polyfit(list_flare_uncertainties, list_flare_force_errors, deg=1)
130+
x = np.array([0.0, list_flare_uncertainties.max()])
131+
y = np.poly1d(coeffs)(x)
132+
133+
(line,) = ax1.plot(
134+
list_flare_uncertainties,
135+
list_flare_force_errors,
136+
**common_params,
137+
label=rf"$\sigma$ = {sigma}, slope = {coeffs[0]:4.1f}",
138+
)
139+
140+
ax1.plot(x, y, "-", c=line.get_color(), lw=4, label="__nolabel__")
141+
xmax1 = np.max([xmax1, list_flare_uncertainties.max()])
142+
ymax1 = np.max([ymax1, list_flare_force_errors.max()])
143+
144+
# ================================================================================
145+
list_mapped_flare_force_errors = sigma_df[
146+
"mapped_flare_all_force_errors"
147+
].values[0]
148+
list_mapped_flare_uncertainties = sigma_df[
149+
"mapped_flare_all_uncertainties"
150+
].values[0]
151+
152+
coeffs = np.polyfit(
153+
list_mapped_flare_uncertainties, list_mapped_flare_force_errors, deg=1
154+
)
155+
x = np.array([0.0, list_mapped_flare_uncertainties.max()])
156+
y = np.poly1d(coeffs)(x)
157+
(line,) = ax2.plot(
158+
list_mapped_flare_uncertainties,
159+
list_mapped_flare_force_errors,
160+
**common_params,
161+
label=rf"$\sigma$ = {sigma}, slope = {coeffs[0]:4.1f}",
162+
)
163+
ax2.plot(x, y, "-", c=line.get_color(), lw=4, label="__nolabel__")
164+
165+
xmax2 = np.max([xmax2, list_mapped_flare_uncertainties.max()])
166+
ymax2 = np.max([ymax2, list_mapped_flare_force_errors.max()])
167+
168+
ax1.legend(loc=0)
169+
ax1.set_xlim(xmin=0, xmax=xmax1)
170+
ax1.set_ylim(ymin=0, ymax=ymax1)
171+
172+
ax2.legend(loc=0)
173+
ax2.set_xlim(xmin=0, xmax=xmax2)
174+
ax2.set_ylim(ymin=0, ymax=ymax2)
175+
176+
fig.tight_layout()
177+
fig.savefig(images_dir / "errors_vs_uncertainties.png")
178+
179+
# ================================================================================
180+
181+
fig = plt.figure(figsize=PLEASANT_FIG_SIZE)
182+
fig.suptitle("Comparing FLARE and MAPPED FLARE")
183+
ax = fig.add_subplot(111)
184+
ax.set_xlabel("FLARE Uncertainty")
185+
ax.set_ylabel("MAPPED FLARE Uncertainty")
186+
187+
for sigma, sigma_df in sub_df.groupby(by="sigma"):
188+
189+
list_flare_uncertainties = sigma_df["flare_all_uncertainties"].values[0]
190+
list_mapped_flare_uncertainties = sigma_df[
191+
"mapped_flare_all_uncertainties"
192+
].values[0]
193+
coeffs = np.polyfit(
194+
list_flare_uncertainties, list_mapped_flare_uncertainties, deg=1
195+
)
196+
x = np.array([0.0, list_flare_uncertainties.max()])
197+
y = np.poly1d(coeffs)(x)
198+
199+
(line,) = ax.plot(
200+
list_flare_uncertainties,
201+
list_mapped_flare_uncertainties,
202+
"o",
203+
label=rf"$\sigma$ = {sigma}, slope = {coeffs[0]:4.3f}",
204+
)
205+
206+
ax.plot(x, y, "-", c=line.get_color(), lw=4, label="__nolabel__")
207+
208+
ax.legend(loc=0)
209+
ax.set_xlim(xmin=0)
210+
ax.set_ylim(ymin=0)
211+
212+
fig.tight_layout()
213+
fig.savefig(images_dir / "flare_vs_mapped_flare_uncertainties.png")

0 commit comments

Comments
 (0)