3
3
# plot accuracy across hyperparameter values
4
4
5
5
# e.g. compare <logdirs-prefix>
6
- # --logdirs_prefix =trained- \
6
+ # --logdirs_filter =trained- \
7
7
# --loss=exclusive \
8
8
# --overlapped_prefix=not_
9
9
@@ -33,8 +33,9 @@ def main():
33
33
for key in sorted (flags .keys ()):
34
34
print ('%s = %s' % (key , flags [key ]))
35
35
36
- logdirs_prefix = FLAGS .logdirs_prefix
37
- basename , dirname = os .path .split (logdirs_prefix )
36
+ logdirs_filter = FLAGS .logdirs_filter
37
+ logdirs_dirname , logdirs_basename = os .path .split (logdirs_filter )
38
+ indepvar , * filters = logdirs_basename .split ('_' )
38
39
39
40
same_time = False
40
41
outlier_criteria = 50
@@ -51,20 +52,24 @@ def main():
51
52
nlayers = {}
52
53
hyperparameters = {}
53
54
54
- logdirs = list (filter (lambda x : x .startswith (dirname + '-' ) and \
55
- os .path .isdir (os .path .join (basename ,x )), os .listdir (basename )))
55
+ def filter_logdirs (logdir ):
56
+ params = logdir .split ('_' )
57
+ return all ([f in params for f in filters ]) and \
58
+ any ([p .startswith (indepvar ) for p in params ])
59
+
60
+ logdirs = list (filter (filter_logdirs , os .listdir (logdirs_dirname )))
56
61
57
62
for logdir in logdirs :
58
63
print (logdir )
59
- hyperparameters [logdir ] = set (logdir .split ('-' )[ - 1 ]. split ( ' _' ))
64
+ hyperparameters [logdir ] = set (logdir .split ('_' ))
60
65
_ , _ , train_time [logdir ], _ , \
61
66
_ , _ , validation_precision [logdir ], validation_recall [logdir ], \
62
67
validation_time [logdir ], validation_step [logdir ], \
63
68
_ , _ , _ , _ , \
64
69
labels_touse [logdir ], _ , \
65
70
nparameters_total [logdir ], nparameters_finallayer [logdir ], \
66
71
batch_size [logdir ], nlayers [logdir ] = \
67
- read_logs (os .path .join (basename ,logdir ))
72
+ read_logs (os .path .join (logdirs_dirname ,logdir ))
68
73
if len (set ([tuple (x ) for x in labels_touse [logdir ].values ()]))> 1 :
69
74
print ('WARNING: not all labels_touse are the same' )
70
75
if len (set (nparameters_total [logdir ].values ()))> 1 :
@@ -89,15 +94,20 @@ def main():
89
94
90
95
commonparameters = reduce (lambda x ,y : x & y , hyperparameters .values ())
91
96
differentparameters = {x :',' .join (natsorted (list (hyperparameters [x ]- commonparameters ))) \
92
- for x in natsorted (logdirs )}
97
+ for x in logdirs }
98
+
93
99
100
+ def sortby_indepvar (logdir ):
101
+ params = logdir .split ('_' )
102
+ iindepvar = next (i for i ,x in enumerate (params ) if x .startswith (indepvar ))
103
+ return str (params [iindepvar ]) + str (params [:iindepvar ]) + str (params [iindepvar + 1 :])
94
104
95
105
fig = plt .figure (figsize = (8 ,10 * 2 / 3 ))
96
106
97
107
ax = fig .add_subplot (2 ,2 ,1 )
98
108
99
109
precisions_mean , recalls_mean = [], []
100
- for (ilogdir ,logdir ) in enumerate (natsorted (logdirs )):
110
+ for (ilogdir ,logdir ) in enumerate (natsorted (logdirs , key = sortby_indepvar )):
101
111
color = cm .viridis (ilogdir / max (1 ,len (validation_recall )- 1 ))
102
112
precisions_all , recalls_all = [], []
103
113
for model in validation_recall [logdir ].keys ():
@@ -111,7 +121,7 @@ def main():
111
121
112
122
ax = fig .add_subplot (2 ,2 ,2 )
113
123
bottom = 100
114
- for (iexpt ,expt ) in enumerate (natsorted (validation_recall .keys ())):
124
+ for (iexpt ,expt ) in enumerate (natsorted (validation_recall .keys (), key = sortby_indepvar )):
115
125
color = cm .viridis (iexpt / max (1 ,len (validation_recall )- 1 ))
116
126
validation_recall_average = np .zeros (len (next (iter (validation_recall [expt ].values ()))))
117
127
for model in validation_time [expt ].keys ():
@@ -127,45 +137,45 @@ def main():
127
137
ax .set_ylim (bottom = bottom - 5 , top = 100 )
128
138
ax .set_xlabel ('Training time (min)' )
129
139
ax .set_ylabel ('Overall validation recall' )
130
- ax .legend (loc = 'lower right' , title = dirname , ncol = 2 if "Annotations" in dirname else 1 )
140
+ ax .legend (loc = 'lower right' , ncol = 2 if "Annotations" in logdirs_basename else 1 )
131
141
132
142
ax = fig .add_subplot (2 ,2 ,3 )
133
- ldata = natsorted (nparameters_total .keys ())
143
+ ldata = natsorted (nparameters_total .keys (), key = sortby_indepvar )
134
144
xdata = range (len (ldata ))
135
145
ydata = [next (iter (nparameters_total [x ].values ())) - \
136
146
next (iter (nparameters_finallayer [x ].values ())) for x in ldata ]
137
147
ydata2 = [next (iter (nparameters_finallayer [x ].values ())) for x in ldata ]
138
148
bar1 = ax .bar (xdata ,ydata ,color = 'k' )
139
149
bar2 = ax .bar (xdata ,ydata2 ,bottom = ydata ,color = 'gray' )
140
150
ax .legend ((bar2 ,bar1 ), ('last' ,'rest' ))
141
- ax .set_xlabel (dirname )
151
+ ax .set_xlabel (logdirs_basename )
142
152
ax .set_ylabel ('Trainable parameters' )
143
153
ax .set_xticks (xdata )
144
154
ax .set_xticklabels ([differentparameters [x ] for x in ldata ], rotation = 40 , ha = 'right' )
145
155
146
156
ax = fig .add_subplot (2 ,2 ,4 )
147
- data = {k :list ([np .median (np .diff (x )) for x in train_time [k ].values ()]) for k in train_time }
157
+ data = {k :list ([np .median (np .diff (x )) for x in train_time [k ].values ()])
158
+ for k in sorted (train_time .keys (), key = sortby_indepvar )}
148
159
ldata = jitter_plot (ax , data )
149
160
ax .set_ylabel ('time / step (ms)' )
150
- ax .set_xlabel (dirname )
161
+ ax .set_xlabel (logdirs_basename )
151
162
ax .set_xticks (range (len (ldata )))
152
163
ax .set_xticklabels ([differentparameters [x ] for x in ldata ], rotation = 40 , ha = 'right' )
153
164
154
- fig .suptitle (',' .join (list (commonparameters )))
155
-
156
- fig .tight_layout (rect = [0 , 0.03 , 1 , 0.95 ])
157
- plt .savefig (logdirs_prefix + '-compare-overall-params-speed.pdf' )
165
+ fig .suptitle (',' .join (list (commonparameters )), fontsize = 'xx-large' )
166
+ fig .tight_layout (rect = [0 , 0.03 , 1 , 0.97 ])
167
+ plt .savefig (logdirs_filter + '-compare-overall-params-speed.pdf' )
158
168
plt .close ()
159
169
160
170
161
171
recall_confusion_matrices = {}
162
172
precision_confusion_matrices = {}
163
173
labels = None
164
174
165
- for ilogdir ,logdir in enumerate (natsorted (logdirs )):
175
+ for ilogdir ,logdir in enumerate (natsorted (logdirs , key = sortby_indepvar )):
166
176
kind = next (iter (validation_time [logdir ].keys ())).split ('_' )[0 ]
167
177
confusion_matrices , theselabels = \
168
- parse_confusion_matrices (os .path .join (basename ,logdir ), kind , \
178
+ parse_confusion_matrices (os .path .join (logdirs_dirname ,logdir ), kind , \
169
179
idx_time = idx_time [logdir ] if same_time else None )
170
180
171
181
recall_confusion_matrices [logdir ]= {}
@@ -215,7 +225,7 @@ def main():
215
225
summed2_confusion_matrix ,
216
226
precision_summed_matrix , recall_summed_matrix ,
217
227
len (labels )< 10 ,
218
- logdir + "\n " ,
228
+ differentparameters [ logdir ] + "\n " ,
219
229
labels if FLAGS .loss == 'exclusive' else
220
230
["song" , FLAGS .overlapped_prefix + "song" ],
221
231
precision_summed , recall_summed )
@@ -229,12 +239,13 @@ def main():
229
239
summed_confusion_matrix [ilabel ], \
230
240
precision_summed_matrix , recall_summed_matrix , \
231
241
len (labels )< 10 ,
232
- logdir + "\n " ,
242
+ differentparameters [ logdir ] + "\n " ,
233
243
[labels [ilabel ], FLAGS .overlapped_prefix + labels [ilabel ]],
234
244
precision_summed , recall_summed )
235
245
236
- fig .tight_layout ()
237
- plt .savefig (logdirs_prefix + '-compare-confusion-matrices.pdf' )
246
+ fig .suptitle (',' .join (list (commonparameters )), fontsize = 'xx-large' )
247
+ fig .tight_layout (rect = [0 , 0.03 , 1 , 0.97 ])
248
+ plt .savefig (logdirs_filter + '-compare-confusion-matrices.pdf' )
238
249
plt .close ()
239
250
240
251
@@ -245,7 +256,7 @@ def main():
245
256
for (ilabel ,label ) in enumerate (labels ):
246
257
ax = fig .add_subplot (nrows , ncols , ilabel + 1 )
247
258
precisions_mean , recalls_mean = [], []
248
- for (ilogdir ,logdir ) in enumerate (natsorted (logdirs )):
259
+ for (ilogdir ,logdir ) in enumerate (natsorted (logdirs , key = sortby_indepvar )):
249
260
color = cm .viridis (ilogdir / max (1 ,len (validation_recall )- 1 ))
250
261
precisions_all , recalls_all = [], []
251
262
for (imodel ,model ) in enumerate (recall_confusion_matrices [logdir ].keys ()):
@@ -261,14 +272,15 @@ def main():
261
272
'o' , markeredgecolor = 'k' , color = color )
262
273
label_precisions_recall (ax , recalls_mean , precisions_mean , label + "\n " )
263
274
264
- fig .tight_layout ()
265
- plt .savefig (logdirs_prefix + '-compare-PR-classes.pdf' )
275
+ fig .suptitle (',' .join (list (commonparameters )), fontsize = 'xx-large' )
276
+ fig .tight_layout (rect = [0 , 0.03 , 1 , 0.97 ])
277
+ plt .savefig (logdirs_filter + '-compare-PR-classes.pdf' )
266
278
plt .close ()
267
279
268
280
if __name__ == "__main__" :
269
281
parser = argparse .ArgumentParser ()
270
282
parser .add_argument (
271
- '--logdirs_prefix ' ,
283
+ '--logdirs_filter ' ,
272
284
type = str ,
273
285
default = '/tmp/speech_commands_train' ,
274
286
help = 'Common prefix of the directories of logs and checkpoints' )
0 commit comments