Skip to content

Commit 3f6635c

Browse files
committed
Another plotting script.
1 parent 10f3827 commit 3f6635c

File tree

1 file changed

+102
-0
lines changed

1 file changed

+102
-0
lines changed
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
import numpy as np
2+
from flare.bffs.sgp.sparse_gp import compute_negative_likelihood_grad_stable
3+
from matplotlib import pyplot as plt
4+
5+
from diffusion_for_multi_scale_molecular_dynamics import TOP_DIR
6+
from diffusion_for_multi_scale_molecular_dynamics.active_learning_loop.trainer.flare_trainer import \
7+
FlareTrainer
8+
from diffusion_for_multi_scale_molecular_dynamics.analysis import (
9+
PLEASANT_FIG_SIZE, PLOT_STYLE_PATH)
10+
11+
plt.style.use(PLOT_STYLE_PATH)
12+
13+
experiment_dir = TOP_DIR / "experiments/active_learning/pretraining_flare/"
14+
checkpoint_dir = experiment_dir / "flare_checkpoints" / "sigma_1000.0_n_10"
15+
16+
images_dir = experiment_dir / "images"
17+
images_dir.mkdir(parents=True, exist_ok=True)
18+
19+
if __name__ == "__main__":
20+
21+
checkpoint_path = checkpoint_dir / "flare_model_pretrained.json"
22+
flare_trainer = FlareTrainer.from_checkpoint(checkpoint_path)
23+
24+
flare_trainer.sgp_model.sparse_gp.precompute_KnK()
25+
26+
starting_hyperparameters = 1.0 * flare_trainer.sgp_model.sparse_gp.hyperparameters
27+
initial_sigma, initial_sigma_e, initial_sigma_f, _ = starting_hyperparameters
28+
29+
list_sigma = np.exp(np.linspace(np.log(0.1), np.log(10000), 81))
30+
list_sigma_e = np.exp(np.linspace(np.log(0.0001), np.log(500), 81))
31+
list_sigma_f = np.exp(np.linspace(np.log(0.001), np.log(5), 81))
32+
33+
list_log_likelihoods_vs_sigma = []
34+
list_log_likelihoods_vs_sigma_e = []
35+
list_log_likelihoods_vs_sigma_f = []
36+
37+
for sigma in list_sigma:
38+
hyperparameters = 1.0 * starting_hyperparameters
39+
hyperparameters[0] = sigma
40+
41+
nll, grads = compute_negative_likelihood_grad_stable(hyperparameters,
42+
flare_trainer.sgp_model.sparse_gp,
43+
precomputed=True)
44+
list_log_likelihoods_vs_sigma.append(-nll)
45+
list_log_likelihoods_vs_sigma = np.array(list_log_likelihoods_vs_sigma)
46+
47+
for sigma_e in list_sigma_e:
48+
hyperparameters = 1.0 * starting_hyperparameters
49+
hyperparameters[1] = sigma_e
50+
51+
nll, grads = compute_negative_likelihood_grad_stable(hyperparameters,
52+
flare_trainer.sgp_model.sparse_gp,
53+
precomputed=True)
54+
list_log_likelihoods_vs_sigma_e.append(-nll)
55+
list_log_likelihoods_vs_sigma_e = np.array(list_log_likelihoods_vs_sigma_e)
56+
57+
for sigma_f in list_sigma_f:
58+
hyperparameters = 1.0 * starting_hyperparameters
59+
hyperparameters[2] = sigma_f
60+
61+
nll, grads = compute_negative_likelihood_grad_stable(hyperparameters,
62+
flare_trainer.sgp_model.sparse_gp,
63+
precomputed=True)
64+
list_log_likelihoods_vs_sigma_f.append(-nll)
65+
list_log_likelihoods_vs_sigma_f = np.array(list_log_likelihoods_vs_sigma_f)
66+
67+
figsize = (1.5 * PLEASANT_FIG_SIZE[0], PLEASANT_FIG_SIZE[1])
68+
fig = plt.figure(figsize=figsize)
69+
fig.suptitle("FLARE on Si 2x2x2: Log Likelihood\n FLARE trained on 10 structures")
70+
ax1 = fig.add_subplot(131)
71+
ax2 = fig.add_subplot(132)
72+
ax3 = fig.add_subplot(133)
73+
74+
ax1.set_title(rf"$\sigma_e$ = {initial_sigma_e}, $\sigma_f$ = {initial_sigma_f}")
75+
ax2.set_title(rf"$\sigma$ = {initial_sigma}, $\sigma_f$ = {initial_sigma_f}")
76+
ax3.set_title(rf", $\sigma$ = {initial_sigma}, $\sigma_e$ = {initial_sigma_e}")
77+
78+
ax1.loglog(list_sigma, list_log_likelihoods_vs_sigma, '-', color='k')
79+
ax2.loglog(list_sigma_e, list_log_likelihoods_vs_sigma_e, '-', color='k')
80+
ax3.loglog(list_sigma_f, list_log_likelihoods_vs_sigma_f, '-', color='k')
81+
82+
ymin1, ymax1 = ax1.get_ylim()
83+
ymin2, ymax2 = ax2.get_ylim()
84+
ymin3, ymax3 = ax3.get_ylim()
85+
86+
ax1.vlines(1000.0, ymin1, ymax1, color='red', label=r'$\sigma$ = 1000.')
87+
ax2.vlines(1.0, ymin2, ymax2, color='red', label=r'$\sigma_e$ = 1.0')
88+
ax3.vlines(0.05, ymin3, ymax3, color='red', label=r'$\sigma_f$ = 0.05')
89+
90+
ax1.set_xlabel(r"$\sigma$")
91+
ax2.set_xlabel(r"$\sigma_e$")
92+
ax3.set_xlabel(r"$\sigma_f$")
93+
94+
ax1.set_ylim(ymin1, ymax1)
95+
ax2.set_ylim(ymin2, ymax2)
96+
ax3.set_ylim(ymin3, ymax3)
97+
98+
for ax in [ax1, ax2, ax3]:
99+
ax.set_ylabel("Log Likelihood")
100+
ax.legend(loc=0)
101+
fig.tight_layout()
102+
fig.savefig(images_dir / "log_likelihood.png")

0 commit comments

Comments
 (0)