Skip to content

Commit aa8eec4

Browse files
committed
Added the figure to demonstrate the grid search for the optimal combination of draft length and draft number
1 parent 0a73352 commit aa8eec4

File tree

2 files changed

+273
-0
lines changed

2 files changed

+273
-0
lines changed

results_grid_search/figure_summary.py

Lines changed: 273 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,273 @@
1+
from pathlib import Path
2+
from enum import Enum
3+
4+
import json
5+
import pandas as pd
6+
import matplotlib.pyplot as plt
7+
8+
PRODUCTS_GREEDY_SPECULATIVE_BATCH_SIZE = {
9+
1: "results_product_500_greedy_speculative_bs_1_report.txt",
10+
4: "results_product_500_greedy_speculative_bs_4_report.txt",
11+
16: "results_product_500_greedy_speculative_bs_16_report.txt",
12+
32: "results_product_500_greedy_speculative_bs_32_report.txt"
13+
}
14+
15+
PRODUCTS_BEAM_SEARCH_SPECULATIVE_BATCH_SIZE = {
16+
1: "results_product_500_beam_search_speculative_bs_1_report.txt",
17+
2: "results_product_500_beam_search_speculative_bs_2_report.txt",
18+
3: "results_product_500_beam_search_speculative_bs_3_report.txt",
19+
4: "results_product_500_beam_search_speculative_bs_4_report.txt"
20+
}
21+
22+
RETRO_BEAM_SEARCH_SPECULATIVE_BS_1_NBEST = {
23+
5: "results_retro_500_beam_search_speculative_bs_1_nbest_5_report.txt",
24+
10: "results_retro_500_beam_search_speculative_bs_1_nbest_10_report.txt",
25+
15: "results_retro_500_beam_search_speculative_bs_1_nbest_15_report.txt",
26+
20: "results_retro_500_beam_search_speculative_bs_1_nbest_20_report.txt"
27+
}
28+
29+
RETRO_BEAM_SEARCH_SPECULATIVE_NBEST_10_BATCH_SIZE = {
30+
1: "results_retro_500_beam_search_speculative_bs_1_nbest_10_report.txt",
31+
2: "results_retro_500_beam_search_speculative_bs_2_nbest_10_report.txt",
32+
4: "results_retro_500_beam_search_speculative_bs_4_nbest_10_report.txt",
33+
8: "results_retro_500_beam_search_speculative_bs_8_nbest_10_report.txt"
34+
}
35+
36+
class Experiment(Enum):
37+
PRODUCTS_GREEDY_SPECULATIVE = 1
38+
PRODUCTS_BEAM_SEARCH_SPECULATIVE = 2
39+
RETRO_BEAM_SEARCH_SPECULATIVE_BS_1 = 3
40+
RETRO_BEAM_SEARCH_SPECULATIVE_NBEST_10 = 4
41+
42+
EXPERIMENTS = {
43+
Experiment.PRODUCTS_GREEDY_SPECULATIVE: PRODUCTS_GREEDY_SPECULATIVE_BATCH_SIZE,
44+
Experiment.PRODUCTS_BEAM_SEARCH_SPECULATIVE: PRODUCTS_BEAM_SEARCH_SPECULATIVE_BATCH_SIZE,
45+
Experiment.RETRO_BEAM_SEARCH_SPECULATIVE_BS_1: RETRO_BEAM_SEARCH_SPECULATIVE_BS_1_NBEST,
46+
Experiment.RETRO_BEAM_SEARCH_SPECULATIVE_NBEST_10: RETRO_BEAM_SEARCH_SPECULATIVE_NBEST_10_BATCH_SIZE,
47+
}
48+
49+
50+
def load_reports(experiment: Experiment) -> dict[int, pd.DataFrame]:
51+
report = {}
52+
for k, path in EXPERIMENTS[experiment].items():
53+
with open(Path(path), "r") as file:
54+
records = []
55+
for line in file.readlines():
56+
records.append(pd.DataFrame.from_dict(json.loads(line), orient="index").T)
57+
records = pd.concat(records).reset_index(drop=True)
58+
report[k] = records
59+
return report
60+
61+
62+
def figure_products_greedy_speculative(
63+
ax, # Array of axes
64+
major_text_size: int = 16,
65+
minor_text_size: int = 14,
66+
marker_size: int = 8,
67+
alpha=1.0,
68+
):
69+
# Products greedy speculative
70+
report = load_reports(Experiment.PRODUCTS_GREEDY_SPECULATIVE)
71+
batch_sizes = sorted(report.keys())
72+
axs = {}
73+
for i, batch_size in enumerate(batch_sizes):
74+
axs[batch_size] = ax[i] # Just use the provided axes directly
75+
76+
# Add 'A' label to the leftmost subplot
77+
axs[1].text(-0.25, 1.03, 'A', transform=axs[1].transAxes,
78+
fontsize=23, fontweight='bold', va='center')
79+
80+
for batch_size in report.keys():
81+
results = report[batch_size]
82+
unique_n_drafts = sorted(results["n_drafts"].unique().tolist())
83+
for i in unique_n_drafts:
84+
axs[batch_size].plot(
85+
results[results["n_drafts"] == i]["draft_len"],
86+
results[results["n_drafts"] == i]["total_seconds"],
87+
"-s",
88+
markersize=marker_size,
89+
alpha=alpha,
90+
label=f"{i} drafts"
91+
)
92+
axs[batch_size].grid()
93+
axs[batch_size].set_ylim(5, 60)
94+
axs[batch_size].set_title(f"Batch size {batch_size}", size=minor_text_size)
95+
axs[batch_size].tick_params(axis='both', labelsize=minor_text_size)
96+
axs[batch_size].xaxis.label.set_size(minor_text_size)
97+
axs[batch_size].yaxis.label.set_size(minor_text_size)
98+
axs[batch_size].set_xlabel("Draft length")
99+
if batch_size != 1: # Remove y-axis labels for all but first subplot
100+
axs[batch_size].set_yticklabels([])
101+
102+
axs[1].set_ylabel("Total seconds")
103+
axs[32].legend(loc="upper right", fontsize=minor_text_size)
104+
return axs
105+
106+
107+
def figure_products_beam_search_speculative(
108+
ax, # Array of axes
109+
major_text_size: int = 16,
110+
minor_text_size: int = 14,
111+
marker_size: int = 8,
112+
alpha=1.0,
113+
):
114+
# Products greedy speculative
115+
report = load_reports(Experiment.PRODUCTS_BEAM_SEARCH_SPECULATIVE)
116+
batch_sizes = sorted(report.keys())
117+
axs = {}
118+
for i, batch_size in enumerate(batch_sizes):
119+
axs[batch_size] = ax[i] # Just use the provided axes directly
120+
121+
# Add 'B' label to the leftmost subplot
122+
axs[1].text(-0.25, 1.05, 'B', transform=axs[1].transAxes,
123+
fontsize=23, fontweight='bold', va='center')
124+
125+
for batch_size in report.keys():
126+
results = report[batch_size]
127+
unique_n_drafts = sorted(results["n_drafts"].unique().tolist())
128+
for i in unique_n_drafts:
129+
axs[batch_size].plot(
130+
results[results["n_drafts"] == i]["draft_len"],
131+
results[results["n_drafts"] == i]["total_seconds"],
132+
"-s",
133+
markersize=marker_size,
134+
alpha=alpha,
135+
label=f"{i} drafts"
136+
)
137+
axs[batch_size].grid()
138+
axs[batch_size].set_ylim(60, 150)
139+
axs[batch_size].set_title(f"Batch size {batch_size}", size=minor_text_size)
140+
axs[batch_size].tick_params(axis='both', labelsize=minor_text_size)
141+
axs[batch_size].xaxis.label.set_size(minor_text_size)
142+
axs[batch_size].yaxis.label.set_size(minor_text_size)
143+
axs[batch_size].set_xlabel("Draft length")
144+
if batch_size != 1: # Remove y-axis labels for all but first subplot
145+
axs[batch_size].set_yticklabels([])
146+
147+
axs[1].set_ylabel("Total seconds")
148+
axs[4].legend(loc="upper left", fontsize=minor_text_size)
149+
return axs
150+
151+
152+
def figure_retro_beam_search_speculative_bs_1(
153+
ax, # Array of axes
154+
major_text_size: int = 16,
155+
minor_text_size: int = 14,
156+
marker_size: int = 8,
157+
alpha=1.0,
158+
):
159+
# Products greedy speculative
160+
report = load_reports(Experiment.RETRO_BEAM_SEARCH_SPECULATIVE_BS_1)
161+
n_best_values = sorted(report.keys())
162+
axs = {}
163+
for i, n_best in enumerate(n_best_values):
164+
axs[n_best] = ax[i] # Just use the provided axes directly
165+
166+
# Add 'C' label to the leftmost subplot
167+
axs[5].text(-0.25, 1.035, 'C', transform=axs[5].transAxes,
168+
fontsize=23, fontweight='bold', va='center')
169+
170+
for n_best in report.keys():
171+
results = report[n_best]
172+
unique_n_drafts = sorted(results["n_drafts"].unique().tolist())
173+
for i in unique_n_drafts:
174+
axs[n_best].plot(
175+
results[results["n_drafts"] == i]["draft_len"],
176+
results[results["n_drafts"] == i]["total_seconds"],
177+
"-s",
178+
markersize=marker_size,
179+
alpha=alpha,
180+
label=f"{i} drafts"
181+
)
182+
axs[n_best].grid()
183+
axs[n_best].set_ylim(150, 410)
184+
axs[n_best].set_title(f"{n_best} best sequences", size=minor_text_size)
185+
axs[n_best].tick_params(axis='both', labelsize=minor_text_size)
186+
axs[n_best].xaxis.label.set_size(minor_text_size)
187+
axs[n_best].yaxis.label.set_size(minor_text_size)
188+
axs[n_best].set_xlabel("Draft length")
189+
if n_best != 5: # Remove y-axis labels for all but first subplot
190+
axs[n_best].set_yticklabels([])
191+
192+
axs[5].set_ylabel("Total seconds")
193+
axs[5].legend(loc="upper right", fontsize=minor_text_size)
194+
return axs
195+
196+
def figure_retro_beam_search_speculative_nbest_10(
197+
ax, # Array of axes
198+
major_text_size: int = 16,
199+
minor_text_size: int = 14,
200+
marker_size: int = 8,
201+
alpha=1.0,
202+
):
203+
# Products greedy speculative
204+
report = load_reports(Experiment.RETRO_BEAM_SEARCH_SPECULATIVE_NBEST_10)
205+
batch_sizes = sorted(report.keys())
206+
axs = {}
207+
for i, batch_size in enumerate(batch_sizes):
208+
axs[batch_size] = ax[i] # Just use the provided axes directly
209+
210+
# Add 'D' label to the leftmost subplot
211+
axs[1].text(-0.25, 1.03, 'D', transform=axs[1].transAxes,
212+
fontsize=23, fontweight='bold', va='center')
213+
214+
for batch_size in report.keys():
215+
results = report[batch_size]
216+
unique_n_drafts = sorted(results["n_drafts"].unique().tolist())
217+
for i in unique_n_drafts:
218+
axs[batch_size].plot(
219+
results[results["n_drafts"] == i]["draft_len"],
220+
results[results["n_drafts"] == i]["total_seconds"],
221+
"-s",
222+
markersize=marker_size,
223+
alpha=alpha,
224+
label=f"{i} drafts"
225+
)
226+
axs[batch_size].grid()
227+
axs[batch_size].set_ylim(40, 330)
228+
axs[batch_size].set_title(f"Batch size {batch_size}", size=minor_text_size)
229+
axs[batch_size].tick_params(axis='both', labelsize=minor_text_size)
230+
axs[batch_size].xaxis.label.set_size(minor_text_size)
231+
axs[batch_size].yaxis.label.set_size(minor_text_size)
232+
axs[batch_size].set_xlabel("Draft length")
233+
if batch_size != 1: # Remove y-axis labels for all but first subplot
234+
axs[batch_size].set_yticklabels([])
235+
236+
axs[1].set_ylabel("Total seconds")
237+
axs[1].legend(loc="lower left", fontsize=minor_text_size - 3)
238+
return axs
239+
240+
241+
if __name__ == "__main__":
242+
fig = plt.figure(figsize=(15, 24))
243+
244+
# Create a 2x4 grid of subplots
245+
gs = fig.add_gridspec(4, 4)
246+
247+
# Create two rows of axes
248+
ax1 = [fig.add_subplot(gs[0, i]) for i in range(4)]
249+
ax2 = [fig.add_subplot(gs[1, i]) for i in range(4)]
250+
ax3 = [fig.add_subplot(gs[2, i]) for i in range(4)]
251+
ax4 = [fig.add_subplot(gs[3, i]) for i in range(4)]
252+
253+
# Call the plotting functions with their respective axes
254+
marker_size = 9
255+
figure_products_greedy_speculative(ax1, marker_size=marker_size)
256+
figure_products_beam_search_speculative(ax2, marker_size=marker_size)
257+
figure_retro_beam_search_speculative_bs_1(ax3, marker_size=marker_size)
258+
figure_retro_beam_search_speculative_nbest_10(ax4, marker_size=marker_size)
259+
260+
# Add overall title
261+
fig.suptitle(
262+
"""Time it takes for the model to process 500 reactions with different hyperparameters.
263+
A - product prediction, greedy speculative.
264+
B - product prediction, speculative beam search.
265+
C - single-step retrosynthesis, speculative beam search, batch size 1.
266+
D - single-step retrosynthesis, speculative beam search, 10 best sequences.
267+
""",
268+
size=18)
269+
270+
# Adjust layout to prevent overlap
271+
plt.tight_layout()
272+
plt.subplots_adjust(left=0.1, right=0.9, top=0.9, bottom=0.1, hspace=0.3, wspace=0.05)
273+
plt.savefig("grid_search_summary.png", dpi=300, bbox_inches='tight')
1.4 MB
Loading

0 commit comments

Comments
 (0)