1
+ from pathlib import Path
2
+ from enum import Enum
3
+
4
+ import json
5
+ import pandas as pd
6
+ import matplotlib .pyplot as plt
7
+
8
# Report files for the product-prediction runs with greedy speculative
# generation, keyed by batch size (the ``bs_<N>`` part of the file name).
PRODUCTS_GREEDY_SPECULATIVE_BATCH_SIZE: dict[int, str] = {
    1: "results_product_500_greedy_speculative_bs_1_report.txt",
    4: "results_product_500_greedy_speculative_bs_4_report.txt",
    16: "results_product_500_greedy_speculative_bs_16_report.txt",
    32: "results_product_500_greedy_speculative_bs_32_report.txt",
}

# Report files for the product-prediction runs with speculative beam search,
# keyed by batch size.
PRODUCTS_BEAM_SEARCH_SPECULATIVE_BATCH_SIZE: dict[int, str] = {
    1: "results_product_500_beam_search_speculative_bs_1_report.txt",
    2: "results_product_500_beam_search_speculative_bs_2_report.txt",
    3: "results_product_500_beam_search_speculative_bs_3_report.txt",
    4: "results_product_500_beam_search_speculative_bs_4_report.txt",
}

# Report files for the retrosynthesis runs with speculative beam search at
# batch size 1, keyed by the number of best sequences (``nbest_<N>``).
RETRO_BEAM_SEARCH_SPECULATIVE_BS_1_NBEST: dict[int, str] = {
    5: "results_retro_500_beam_search_speculative_bs_1_nbest_5_report.txt",
    10: "results_retro_500_beam_search_speculative_bs_1_nbest_10_report.txt",
    15: "results_retro_500_beam_search_speculative_bs_1_nbest_15_report.txt",
    20: "results_retro_500_beam_search_speculative_bs_1_nbest_20_report.txt",
}

# Report files for the retrosynthesis runs with speculative beam search and
# 10 best sequences, keyed by batch size.
# NOTE: the ``bs_1`` entry is the same file as the ``nbest_10`` entry of
# RETRO_BEAM_SEARCH_SPECULATIVE_BS_1_NBEST.
RETRO_BEAM_SEARCH_SPECULATIVE_NBEST_10_BATCH_SIZE: dict[int, str] = {
    1: "results_retro_500_beam_search_speculative_bs_1_nbest_10_report.txt",
    2: "results_retro_500_beam_search_speculative_bs_2_nbest_10_report.txt",
    4: "results_retro_500_beam_search_speculative_bs_4_nbest_10_report.txt",
    8: "results_retro_500_beam_search_speculative_bs_8_nbest_10_report.txt",
}
35
+
36
class Experiment(Enum):
    """Identifier for each group of report files that is plotted as one row."""

    # Product prediction, greedy speculative; varied over batch size.
    PRODUCTS_GREEDY_SPECULATIVE = 1
    # Product prediction, speculative beam search; varied over batch size.
    PRODUCTS_BEAM_SEARCH_SPECULATIVE = 2
    # Retrosynthesis, speculative beam search, batch size 1; varied over
    # the number of best sequences.
    RETRO_BEAM_SEARCH_SPECULATIVE_BS_1 = 3
    # Retrosynthesis, speculative beam search, 10 best sequences; varied
    # over batch size.
    RETRO_BEAM_SEARCH_SPECULATIVE_NBEST_10 = 4
41
+
42
# Registry used by ``load_reports``: for every experiment, the mapping from
# the varied hyperparameter value to the report file holding its results.
EXPERIMENTS: dict[Experiment, dict[int, str]] = {
    Experiment.PRODUCTS_GREEDY_SPECULATIVE: PRODUCTS_GREEDY_SPECULATIVE_BATCH_SIZE,
    Experiment.PRODUCTS_BEAM_SEARCH_SPECULATIVE: PRODUCTS_BEAM_SEARCH_SPECULATIVE_BATCH_SIZE,
    Experiment.RETRO_BEAM_SEARCH_SPECULATIVE_BS_1: RETRO_BEAM_SEARCH_SPECULATIVE_BS_1_NBEST,
    Experiment.RETRO_BEAM_SEARCH_SPECULATIVE_NBEST_10: RETRO_BEAM_SEARCH_SPECULATIVE_NBEST_10_BATCH_SIZE,
}
48
+
49
+
50
def load_reports(experiment: Experiment) -> dict[int, pd.DataFrame]:
    """Load every report file registered for ``experiment``.

    Each report is read as JSON lines: every non-blank line is parsed with
    ``json.loads`` and becomes part of the resulting frame (one row per line
    when the JSON object holds scalar values, with its keys as columns).

    Args:
        experiment: Key into ``EXPERIMENTS`` selecting the group of files.

    Returns:
        A dict with the same keys as ``EXPERIMENTS[experiment]``; each value
        is the concatenated, re-indexed frame of that file's records.

    Raises:
        KeyError: If ``experiment`` is not registered in ``EXPERIMENTS``.
        OSError: If a report file cannot be opened.
        ValueError: If a line is not valid JSON or a file has no records.
    """
    reports = {}
    for key, path in EXPERIMENTS[experiment].items():
        with open(Path(path), "r", encoding="utf-8") as file:
            # Iterate the handle directly instead of materialising all lines,
            # and skip blank lines (e.g. a trailing newline at end of file)
            # that would otherwise make json.loads fail.
            frames = [
                pd.DataFrame.from_dict(json.loads(line), orient="index").T
                for line in file
                if line.strip()
            ]
        if not frames:
            # pd.concat([]) raises a ValueError too, but without naming the file.
            raise ValueError(f"Report file {path!r} contains no records")
        reports[key] = pd.concat(frames).reset_index(drop=True)
    return reports
60
+
61
+
62
def _plot_timing_row(
    ax,
    experiment,
    *,
    panel_label,
    panel_label_y,
    y_limits,
    title_template,
    legend_on_last,
    legend_loc,
    legend_font_size,
    minor_text_size,
    marker_size,
    alpha,
):
    """Draw one row of "total seconds vs. draft length" panels.

    Shared implementation behind the public ``figure_*`` functions.

    Args:
        ax: Indexable collection of axes; ``ax[i]`` receives the i-th
            smallest key of the loaded report.
        experiment: ``Experiment`` member passed to ``load_reports``.
        panel_label: Bold letter written next to the leftmost panel.
        panel_label_y: Vertical position of that letter in axes coordinates.
        y_limits: ``(bottom, top)`` applied to every panel.
        title_template: ``str.format`` template that receives the panel key.
        legend_on_last: Put the legend on the rightmost panel when true,
            otherwise on the leftmost one.
        legend_loc: ``loc`` argument for the legend.
        legend_font_size: Font size of the legend text.
        minor_text_size: Font size for titles, tick labels and axis labels.
        marker_size: Marker size of the plotted curves.
        alpha: Opacity of the plotted curves.

    Returns:
        Dict mapping each report key to the axis it was drawn on.
    """
    report = load_reports(experiment)
    keys = sorted(report)
    axs = {key: ax[i] for i, key in enumerate(keys)}

    # Address the outer panels by position rather than by a hard-coded key,
    # so a change in the registered hyperparameter values cannot raise KeyError.
    first_key, last_key = keys[0], keys[-1]
    first_ax = axs[first_key]

    # Panel letter, placed slightly outside the top-left of the leftmost panel.
    first_ax.text(
        -0.25,
        panel_label_y,
        panel_label,
        transform=first_ax.transAxes,
        fontsize=23,
        fontweight="bold",
        va="center",
    )

    for key in keys:
        panel = axs[key]
        results = report[key]
        # One curve per distinct number of drafts.
        for n_drafts in sorted(results["n_drafts"].unique().tolist()):
            curve = results[results["n_drafts"] == n_drafts]
            panel.plot(
                curve["draft_len"],
                curve["total_seconds"],
                "-s",
                markersize=marker_size,
                alpha=alpha,
                label=f"{n_drafts} drafts",
            )
        panel.grid()
        panel.set_ylim(*y_limits)
        panel.set_title(title_template.format(key), size=minor_text_size)
        panel.tick_params(axis="both", labelsize=minor_text_size)
        panel.xaxis.label.set_size(minor_text_size)
        panel.yaxis.label.set_size(minor_text_size)
        panel.set_xlabel("Draft length")
        if key != first_key:
            # All panels share the same y-limits, so only the leftmost one
            # keeps its y tick labels.
            panel.set_yticklabels([])

    first_ax.set_ylabel("Total seconds")
    legend_ax = axs[last_key] if legend_on_last else first_ax
    legend_ax.legend(loc=legend_loc, fontsize=legend_font_size)
    return axs


def figure_products_greedy_speculative(
    ax,  # Array of axes
    major_text_size: int = 16,
    minor_text_size: int = 14,
    marker_size: int = 8,
    alpha: float = 1.0,
):
    """Plot panel row 'A': product prediction, greedy speculative.

    One panel per batch size; the legend goes on the rightmost panel.

    Args:
        ax: Indexable collection of axes, one per batch size.
        major_text_size: Currently unused; kept for interface compatibility.
        minor_text_size: Font size for titles, ticks, labels and legend.
        marker_size: Marker size of the curves.
        alpha: Opacity of the curves.

    Returns:
        Dict mapping each batch size to its axis.
    """
    return _plot_timing_row(
        ax,
        Experiment.PRODUCTS_GREEDY_SPECULATIVE,
        panel_label="A",
        panel_label_y=1.03,
        y_limits=(5, 60),
        title_template="Batch size {}",
        legend_on_last=True,
        legend_loc="upper right",
        legend_font_size=minor_text_size,
        minor_text_size=minor_text_size,
        marker_size=marker_size,
        alpha=alpha,
    )


def figure_products_beam_search_speculative(
    ax,  # Array of axes
    major_text_size: int = 16,
    minor_text_size: int = 14,
    marker_size: int = 8,
    alpha: float = 1.0,
):
    """Plot panel row 'B': product prediction, speculative beam search.

    One panel per batch size; the legend goes on the rightmost panel.

    Args:
        ax: Indexable collection of axes, one per batch size.
        major_text_size: Currently unused; kept for interface compatibility.
        minor_text_size: Font size for titles, ticks, labels and legend.
        marker_size: Marker size of the curves.
        alpha: Opacity of the curves.

    Returns:
        Dict mapping each batch size to its axis.
    """
    return _plot_timing_row(
        ax,
        Experiment.PRODUCTS_BEAM_SEARCH_SPECULATIVE,
        panel_label="B",
        panel_label_y=1.05,
        y_limits=(60, 150),
        title_template="Batch size {}",
        legend_on_last=True,
        legend_loc="upper left",
        legend_font_size=minor_text_size,
        minor_text_size=minor_text_size,
        marker_size=marker_size,
        alpha=alpha,
    )


def figure_retro_beam_search_speculative_bs_1(
    ax,  # Array of axes
    major_text_size: int = 16,
    minor_text_size: int = 14,
    marker_size: int = 8,
    alpha: float = 1.0,
):
    """Plot panel row 'C': retrosynthesis, speculative beam search, batch size 1.

    One panel per number of best sequences; the legend goes on the leftmost
    panel.

    Args:
        ax: Indexable collection of axes, one per n-best value.
        major_text_size: Currently unused; kept for interface compatibility.
        minor_text_size: Font size for titles, ticks, labels and legend.
        marker_size: Marker size of the curves.
        alpha: Opacity of the curves.

    Returns:
        Dict mapping each n-best value to its axis.
    """
    return _plot_timing_row(
        ax,
        Experiment.RETRO_BEAM_SEARCH_SPECULATIVE_BS_1,
        panel_label="C",
        panel_label_y=1.035,
        y_limits=(150, 410),
        title_template="{} best sequences",
        legend_on_last=False,
        legend_loc="upper right",
        legend_font_size=minor_text_size,
        minor_text_size=minor_text_size,
        marker_size=marker_size,
        alpha=alpha,
    )


def figure_retro_beam_search_speculative_nbest_10(
    ax,  # Array of axes
    major_text_size: int = 16,
    minor_text_size: int = 14,
    marker_size: int = 8,
    alpha: float = 1.0,
):
    """Plot panel row 'D': retrosynthesis, speculative beam search, 10 best.

    One panel per batch size; the legend goes on the leftmost panel and uses
    a font three points smaller than ``minor_text_size``.

    Args:
        ax: Indexable collection of axes, one per batch size.
        major_text_size: Currently unused; kept for interface compatibility.
        minor_text_size: Font size for titles, ticks and labels.
        marker_size: Marker size of the curves.
        alpha: Opacity of the curves.

    Returns:
        Dict mapping each batch size to its axis.
    """
    return _plot_timing_row(
        ax,
        Experiment.RETRO_BEAM_SEARCH_SPECULATIVE_NBEST_10,
        panel_label="D",
        panel_label_y=1.03,
        y_limits=(40, 330),
        title_template="Batch size {}",
        legend_on_last=False,
        legend_loc="lower left",
        legend_font_size=minor_text_size - 3,
        minor_text_size=minor_text_size,
        marker_size=marker_size,
        alpha=alpha,
    )
239
+
240
+
241
def main() -> None:
    """Render all four experiment rows into one figure and save it.

    The image is written to ``grid_search_summary.png``, a path relative to
    the current working directory.
    """
    fig = plt.figure(figsize=(15, 24))

    # 4x4 grid: one row of four panels per experiment.
    gs = fig.add_gridspec(4, 4)
    rows = [[fig.add_subplot(gs[r, c]) for c in range(4)] for r in range(4)]

    # Call the plotting functions with their respective rows of axes.
    marker_size = 9
    plotters = (
        figure_products_greedy_speculative,
        figure_products_beam_search_speculative,
        figure_retro_beam_search_speculative_bs_1,
        figure_retro_beam_search_speculative_nbest_10,
    )
    for plot_row, row_axes in zip(plotters, rows):
        plot_row(row_axes, marker_size=marker_size)

    # Overall title. Built from explicit lines so its content does not depend
    # on how this source file is indented.
    fig.suptitle(
        "Time it takes for the model to process 500 reactions with different hyperparameters.\n"
        "A - product prediction, greedy speculative.\n"
        "B - product prediction, speculative beam search.\n"
        "C - single-step retrosynthesis, speculative beam search, batch size 1.\n"
        "D - single-step retrosynthesis, speculative beam search, 10 best sequences.\n",
        size=18,
    )

    # Fixed margins and spacing. A preceding tight_layout() call was dropped:
    # every value it computes (left/right/top/bottom/hspace/wspace) is
    # overwritten by this call, so it had no effect on the saved figure.
    fig.subplots_adjust(
        left=0.1, right=0.9, top=0.9, bottom=0.1, hspace=0.3, wspace=0.05
    )
    fig.savefig("grid_search_summary.png", dpi=300, bbox_inches="tight")


if __name__ == "__main__":
    main()
0 commit comments