-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmodel_types.py
276 lines (268 loc) · 12.7 KB
/
model_types.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
from enum import Enum
class ModelTypes(Enum):
    """Fine-grained categories used to group the evaluated models.

    Each member's value is a ``(label, color)`` pair: a human-readable
    category name and a colour spec from matplotlib's Tableau ``tab:*``
    palette (presumably used for plot legends/markers -- confirm at the
    plotting call sites).

    Values must stay pairwise distinct: ``Enum`` turns a member whose value
    repeats an earlier one into an alias instead of a separate member.
    """

    STANDARD = ('Standard training', 'tab:blue')
    LP_ADV = ('Lp adversarially robust', 'tab:olive')
    ROBUST_INTV = ('Other robustness intervention', 'tab:brown')
    MORE_DATA = ('Trained with more data', 'tab:green')
    # The tab:* palette has no yellow, so the former 'tab:yellow' was not a
    # valid matplotlib colour and raised ValueError when drawn. 'tab:red' is
    # one of the two tab:* colours not otherwise used by this enum.
    RANDOM_FEATURES = ('Random Features', 'tab:red')
    LINEAR_PIXELS = ('Linear Classifier on Pixels', 'tab:orange')
    RANDOM_FORESTS = ('Random Forests', 'tab:pink')
    NEAREST_NEIGHBORS = ('Nearest Neighbors', 'tab:purple')
    LOW_ACCURACY_CNN = ('Low Accuracy CNN', 'tab:cyan')
# Model name -> ModelTypes category. Each category's value carries a
# (label, colour) pair, presumably used to group and colour models in plots.
# Per-epoch checkpoint names (resnet18_50k_*, resnet18_100k_*, resnet101_*)
# are added to this dict programmatically right after this literal.
model_types_map = {
# BiT-M checkpoints, fine-tuned or not, are all MORE_DATA.
'BiT-M-R50x1-ILSVRC2012': ModelTypes.MORE_DATA,
'BiT-M-R50x3-ILSVRC2012': ModelTypes.MORE_DATA,
'BiT-M-R101x1-ILSVRC2012': ModelTypes.MORE_DATA,
'BiT-M-R101x3-ILSVRC2012': ModelTypes.MORE_DATA,
'BiT-M-R152x4-ILSVRC2012': ModelTypes.MORE_DATA,
'BiT-M-R50x1-nonfinetuned': ModelTypes.MORE_DATA,
'BiT-M-R50x3-nonfinetuned': ModelTypes.MORE_DATA,
'BiT-M-R101x1-nonfinetuned': ModelTypes.MORE_DATA,
'BiT-M-R101x3-nonfinetuned': ModelTypes.MORE_DATA,
'BiT-M-R152x4-nonfinetuned': ModelTypes.MORE_DATA,
'FixPNASNet': ModelTypes.STANDARD,
'FixResNeXt101_32x48d': ModelTypes.MORE_DATA,
'FixResNeXt101_32x48d_v2': ModelTypes.MORE_DATA,
'FixResNet50': ModelTypes.STANDARD,
'FixResNet50CutMix': ModelTypes.ROBUST_INTV,
'FixResNet50CutMix_v2': ModelTypes.ROBUST_INTV,
'FixResNet50_no_adaptation': ModelTypes.STANDARD,
'FixResNet50_v2': ModelTypes.STANDARD,
# Throughout the table, *_lpf2 / *_lpf3 / *_lpf5 variants are ROBUST_INTV
# while the corresponding base architecture is STANDARD.
'alexnet': ModelTypes.STANDARD,
'alexnet_lpf2': ModelTypes.ROBUST_INTV,
'alexnet_lpf3': ModelTypes.ROBUST_INTV,
'alexnet_lpf5': ModelTypes.ROBUST_INTV,
'bninception': ModelTypes.STANDARD,
'bninception-imagenet21k': ModelTypes.MORE_DATA,
'cafferesnet101': ModelTypes.STANDARD,
'densenet121': ModelTypes.STANDARD,
'densenet121_lpf2': ModelTypes.ROBUST_INTV,
'densenet121_lpf3': ModelTypes.ROBUST_INTV,
'densenet121_lpf5': ModelTypes.ROBUST_INTV,
'densenet161': ModelTypes.STANDARD,
'densenet169': ModelTypes.STANDARD,
'densenet201': ModelTypes.STANDARD,
'dpn107': ModelTypes.MORE_DATA,
'dpn131': ModelTypes.STANDARD,
'dpn68': ModelTypes.STANDARD,
'dpn68b': ModelTypes.MORE_DATA,
'dpn92': ModelTypes.MORE_DATA,
'dpn98': ModelTypes.STANDARD,
# EfficientNets: plain / -autoaug / -randaug are STANDARD; -advprop-autoaug
# variants are ROBUST_INTV; the noisy-student L2 model is MORE_DATA.
'efficientnet-b0': ModelTypes.STANDARD,
'efficientnet-b0-advprop-autoaug': ModelTypes.ROBUST_INTV,
'efficientnet-b0-autoaug': ModelTypes.STANDARD,
'efficientnet-b1': ModelTypes.STANDARD,
'efficientnet-b1-advprop-autoaug': ModelTypes.ROBUST_INTV,
'efficientnet-b1-autoaug': ModelTypes.STANDARD,
'efficientnet-b2': ModelTypes.STANDARD,
'efficientnet-b2-advprop-autoaug': ModelTypes.ROBUST_INTV,
'efficientnet-b2-autoaug': ModelTypes.STANDARD,
'efficientnet-b3': ModelTypes.STANDARD,
'efficientnet-b3-advprop-autoaug': ModelTypes.ROBUST_INTV,
'efficientnet-b3-autoaug': ModelTypes.STANDARD,
'efficientnet-b4': ModelTypes.STANDARD,
'efficientnet-b4-advprop-autoaug': ModelTypes.ROBUST_INTV,
'efficientnet-b4-autoaug': ModelTypes.STANDARD,
'efficientnet-b5': ModelTypes.STANDARD,
'efficientnet-b5-advprop-autoaug': ModelTypes.ROBUST_INTV,
'efficientnet-b5-autoaug': ModelTypes.STANDARD,
'efficientnet-b5-randaug': ModelTypes.STANDARD,
'efficientnet-b6-advprop-autoaug': ModelTypes.ROBUST_INTV,
'efficientnet-b6-autoaug': ModelTypes.STANDARD,
'efficientnet-b7-advprop-autoaug': ModelTypes.ROBUST_INTV,
'efficientnet-b7-autoaug': ModelTypes.STANDARD,
'efficientnet-b7-randaug': ModelTypes.STANDARD,
'efficientnet-b8-advprop-autoaug': ModelTypes.ROBUST_INTV,
'efficientnet-l2-noisystudent': ModelTypes.MORE_DATA,
'facebook_adv_trained_resnet152_baseline': ModelTypes.LP_ADV,
'facebook_adv_trained_resnet152_denoise': ModelTypes.LP_ADV,
'facebook_adv_trained_resnext101_denoiseAll': ModelTypes.LP_ADV,
'fbresnet152': ModelTypes.STANDARD,
'google_resnet101_jft-300M': ModelTypes.MORE_DATA,
'googlenet/inceptionv1': ModelTypes.STANDARD,
'inceptionresnetv2': ModelTypes.STANDARD,
'inceptionv3': ModelTypes.STANDARD,
'inceptionv4': ModelTypes.STANDARD,
'instagram-resnext101_32x16d': ModelTypes.MORE_DATA,
'instagram-resnext101_32x32d': ModelTypes.MORE_DATA,
'instagram-resnext101_32x48d': ModelTypes.MORE_DATA,
'instagram-resnext101_32x8d': ModelTypes.MORE_DATA,
'mnasnet0_5': ModelTypes.STANDARD,
'mnasnet1_0': ModelTypes.STANDARD,
'mobilenet_v2': ModelTypes.STANDARD,
'mobilenet_v2_lpf2': ModelTypes.ROBUST_INTV,
'mobilenet_v2_lpf3': ModelTypes.ROBUST_INTV,
'mobilenet_v2_lpf5': ModelTypes.ROBUST_INTV,
'nasnetalarge': ModelTypes.STANDARD,
'nasnetamobile': ModelTypes.STANDARD,
'pnasnet5large': ModelTypes.STANDARD,
'polynet': ModelTypes.STANDARD,
'resnet101': ModelTypes.STANDARD,
'resnet101-tencent-ml-images': ModelTypes.MORE_DATA,
'resnet101_cutmix': ModelTypes.ROBUST_INTV,
'resnet101_lpf2': ModelTypes.ROBUST_INTV,
'resnet101_lpf3': ModelTypes.ROBUST_INTV,
'resnet101_lpf5': ModelTypes.ROBUST_INTV,
'resnet152': ModelTypes.STANDARD,
'resnet152-imagenet11k': ModelTypes.MORE_DATA,
'resnet152_3x_simclrv2_linear_probe_tf_port': ModelTypes.STANDARD,
'resnet152_3x_simclrv2_finetuned_100pct_tf_port': ModelTypes.STANDARD,
'resnet18': ModelTypes.STANDARD,
'resnet18-rotation-nocrop_40': ModelTypes.ROBUST_INTV,
'resnet18-rotation-random_30': ModelTypes.ROBUST_INTV,
'resnet18-rotation-random_40': ModelTypes.ROBUST_INTV,
'resnet18-rotation-standard_40': ModelTypes.ROBUST_INTV,
'resnet18-rotation-worst10_30': ModelTypes.ROBUST_INTV,
'resnet18-rotation-worst10_40': ModelTypes.ROBUST_INTV,
'resnet18_lpf2': ModelTypes.ROBUST_INTV,
'resnet18_lpf3': ModelTypes.ROBUST_INTV,
'resnet18_lpf5': ModelTypes.ROBUST_INTV,
'resnet18_ssl': ModelTypes.MORE_DATA,
'resnet18_swsl': ModelTypes.MORE_DATA,
'resnet34': ModelTypes.STANDARD,
'resnet34_lpf2': ModelTypes.ROBUST_INTV,
'resnet34_lpf3': ModelTypes.ROBUST_INTV,
'resnet34_lpf5': ModelTypes.ROBUST_INTV,
'resnet50': ModelTypes.STANDARD,
# Randomized-smoothing ResNet-50s: the zero-noise run is STANDARD; every run
# with a nonzero noise level (and every smoothing_adversarial run) is LP_ADV.
'resnet50-randomized_smoothing_noise_0.00': ModelTypes.STANDARD,
'resnet50-randomized_smoothing_noise_0.25': ModelTypes.LP_ADV,
'resnet50-randomized_smoothing_noise_0.50': ModelTypes.LP_ADV,
'resnet50-randomized_smoothing_noise_1.00': ModelTypes.LP_ADV,
'resnet50-smoothing_adversarial_DNN_2steps_eps_512_noise_0.25': ModelTypes.LP_ADV,
'resnet50-smoothing_adversarial_DNN_2steps_eps_512_noise_0.50': ModelTypes.LP_ADV,
'resnet50-smoothing_adversarial_DNN_2steps_eps_512_noise_1.00': ModelTypes.LP_ADV,
'resnet50-smoothing_adversarial_PGD_1step_eps_512_noise_0.25': ModelTypes.LP_ADV,
'resnet50-smoothing_adversarial_PGD_1step_eps_512_noise_0.50': ModelTypes.LP_ADV,
'resnet50-smoothing_adversarial_PGD_1step_eps_512_noise_1.00': ModelTypes.LP_ADV,
'resnet50-vtab': ModelTypes.STANDARD,
'resnet50-vtab-exemplar': ModelTypes.STANDARD,
'resnet50-vtab-rotation': ModelTypes.STANDARD,
'resnet50-vtab-semi-exemplar': ModelTypes.STANDARD,
'resnet50-vtab-semi-rotation': ModelTypes.STANDARD,
'resnet50_adv-train-free': ModelTypes.LP_ADV,
'resnet50_augmix': ModelTypes.ROBUST_INTV,
'resnet50_aws_baseline': ModelTypes.STANDARD,
'resnet50_clip_zeroshot': ModelTypes.MORE_DATA,
'resnet50_cutmix': ModelTypes.ROBUST_INTV,
'resnet50_cutout': ModelTypes.ROBUST_INTV,
'resnet50_deepaugment': ModelTypes.ROBUST_INTV,
'resnet50_deepaugment_augmix': ModelTypes.ROBUST_INTV,
'resnet50_feature_cutmix': ModelTypes.ROBUST_INTV,
# ImageNet-subsample ResNet-50s: the 100% run is STANDARD; every subsampled
# run (per the names: a subset of classes or a 1/N fraction) is LOW_ACCURACY_CNN.
'resnet50_imagenet_100percent_batch64_original_images': ModelTypes.STANDARD,
'resnet50_imagenet_subsample_125_classes_batch64_original_images': ModelTypes.LOW_ACCURACY_CNN,
'resnet50_imagenet_subsample_1_of_16_batch64_original_images': ModelTypes.LOW_ACCURACY_CNN,
'resnet50_imagenet_subsample_1_of_2_batch64_original_images': ModelTypes.LOW_ACCURACY_CNN,
'resnet50_imagenet_subsample_1_of_32_batch64_original_images': ModelTypes.LOW_ACCURACY_CNN,
'resnet50_imagenet_subsample_1_of_4_batch64_original_images': ModelTypes.LOW_ACCURACY_CNN,
'resnet50_imagenet_subsample_1_of_8_batch64_original_images': ModelTypes.LOW_ACCURACY_CNN,
'resnet50_imagenet_subsample_250_classes_batch64_original_images': ModelTypes.LOW_ACCURACY_CNN,
'resnet50_imagenet_subsample_500_classes_batch64_original_images': ModelTypes.LOW_ACCURACY_CNN,
'resnet50_l2_eps3_robust': ModelTypes.LP_ADV,
'resnet50_linf_eps4_robust': ModelTypes.LP_ADV,
'resnet50_linf_eps8_robust': ModelTypes.LP_ADV,
'resnet50_lpf2': ModelTypes.ROBUST_INTV,
'resnet50_lpf3': ModelTypes.ROBUST_INTV,
'resnet50_lpf5': ModelTypes.ROBUST_INTV,
'resnet50_mixup': ModelTypes.ROBUST_INTV,
'resnet50_simclrv2_linear_probe_tf_port': ModelTypes.STANDARD,
'resnet50_simclrv2_finetuned_100pct_tf_port': ModelTypes.STANDARD,
'resnet50_simsiam': ModelTypes.STANDARD,
'resnet50_ssl': ModelTypes.MORE_DATA,
'resnet50_swav': ModelTypes.STANDARD,
'resnet50_swsl': ModelTypes.MORE_DATA,
'resnet50_trained_on_SIN': ModelTypes.ROBUST_INTV,
'resnet50_trained_on_SIN_and_IN': ModelTypes.ROBUST_INTV,
'resnet50_trained_on_SIN_and_IN_then_finetuned_on_IN': ModelTypes.ROBUST_INTV,
# All resnet50_with_<corruption>_aws variants are ROBUST_INTV
# (resnet50_aws_baseline above is STANDARD).
'resnet50_with_brightness_aws': ModelTypes.ROBUST_INTV,
'resnet50_with_contrast_aws': ModelTypes.ROBUST_INTV,
'resnet50_with_defocus_blur_aws': ModelTypes.ROBUST_INTV,
'resnet50_with_fog_aws': ModelTypes.ROBUST_INTV,
'resnet50_with_frost_aws': ModelTypes.ROBUST_INTV,
'resnet50_with_gaussian_noise_aws': ModelTypes.ROBUST_INTV,
'resnet50_with_gaussian_noise_contrast_motion_blur_jpeg_compression_aws': ModelTypes.ROBUST_INTV,
'resnet50_with_greyscale_aws': ModelTypes.ROBUST_INTV,
'resnet50_with_jpeg_compression_aws': ModelTypes.ROBUST_INTV,
'resnet50_with_motion_blur_aws': ModelTypes.ROBUST_INTV,
'resnet50_with_pixelate_aws': ModelTypes.ROBUST_INTV,
'resnet50_with_saturate_aws': ModelTypes.ROBUST_INTV,
'resnet50_with_spatter_aws': ModelTypes.ROBUST_INTV,
'resnet50_with_zoom_blur_aws': ModelTypes.ROBUST_INTV,
'resnext101_32x16d_ssl': ModelTypes.MORE_DATA,
'resnext101_32x4d': ModelTypes.STANDARD,
'resnext101_32x4d_ssl': ModelTypes.MORE_DATA,
'resnext101_32x4d_swsl': ModelTypes.MORE_DATA,
'resnext101_32x8d': ModelTypes.STANDARD,
'resnext101_32x8d_deepaugment_augmix': ModelTypes.ROBUST_INTV,
'resnext101_32x8d_ssl': ModelTypes.MORE_DATA,
'resnext101_32x8d_swsl': ModelTypes.MORE_DATA,
'resnext101_64x4d': ModelTypes.STANDARD,
'resnext50_32x4d': ModelTypes.STANDARD,
'resnext50_32x4d_ssl': ModelTypes.MORE_DATA,
'resnext50_32x4d_swsl': ModelTypes.MORE_DATA,
'se_resnet101': ModelTypes.STANDARD,
'se_resnet152': ModelTypes.STANDARD,
'se_resnet50': ModelTypes.STANDARD,
'se_resnext101_32x4d': ModelTypes.STANDARD,
'se_resnext50_32x4d': ModelTypes.STANDARD,
'senet154': ModelTypes.STANDARD,
'shufflenet_v2_x0_5': ModelTypes.STANDARD,
'shufflenet_v2_x1_0': ModelTypes.STANDARD,
'squeezenet1_0': ModelTypes.STANDARD,
'squeezenet1_1': ModelTypes.STANDARD,
'vit_b_32_clip_zeroshot': ModelTypes.MORE_DATA,
'vgg11': ModelTypes.STANDARD,
'vgg11_bn': ModelTypes.STANDARD,
'vgg13': ModelTypes.STANDARD,
'vgg13_bn': ModelTypes.STANDARD,
'vgg16': ModelTypes.STANDARD,
'vgg16_bn': ModelTypes.STANDARD,
'vgg16_bn_lpf2': ModelTypes.ROBUST_INTV,
'vgg16_bn_lpf3': ModelTypes.ROBUST_INTV,
'vgg16_bn_lpf5': ModelTypes.ROBUST_INTV,
'vgg16_lpf2': ModelTypes.ROBUST_INTV,
'vgg16_lpf3': ModelTypes.ROBUST_INTV,
'vgg16_lpf5': ModelTypes.ROBUST_INTV,
'vgg19': ModelTypes.STANDARD,
'vgg19_bn': ModelTypes.STANDARD,
# ViTs: only vit_small_patch16_224 is STANDARD; the base/large ones are MORE_DATA.
'vit_small_patch16_224': ModelTypes.STANDARD,
'vit_base_patch16_224': ModelTypes.MORE_DATA,
'vit_base_patch16_384': ModelTypes.MORE_DATA,
'vit_base_patch32_384': ModelTypes.MORE_DATA,
'vit_large_patch16_224': ModelTypes.MORE_DATA,
'vit_large_patch16_384': ModelTypes.MORE_DATA,
'vit_large_patch32_384': ModelTypes.MORE_DATA,
'wide_resnet101_2': ModelTypes.STANDARD,
'wide_resnet50_2': ModelTypes.STANDARD,
'xception': ModelTypes.STANDARD,
# Baseline methods, each with its own dedicated category.
'resnet50_lstsq': ModelTypes.RANDOM_FEATURES,
'identity32_lstsq': ModelTypes.LINEAR_PIXELS,
'identity32_random_forests': ModelTypes.RANDOM_FORESTS,
'identity32_one_nn': ModelTypes.NEAREST_NEIGHBORS
}
# Register intermediate training checkpoints as LOW_ACCURACY_CNN.
# Each row is (name template, number of checkpoints); checkpoint indices run
# 0 .. count-1. Rows are processed in order, so insertion order is preserved.
for _name_template, _n_checkpoints in (
    ("resnet18_50k_{}_epochs", 100),
    ("resnet18_100k_{}_epochs", 50),
    ("resnet101_{}_epochs", 10),
):
    for i in range(_n_checkpoints):
        model_types_map[_name_template.format(i)] = ModelTypes.LOW_ACCURACY_CNN
# Drop the table-driving helpers; like the original loops, this leaves only
# the loop variable ``i`` (ending at 9) in the module namespace.
del _name_template, _n_checkpoints
class NatModelTypes(Enum):
    """Coarser model categories (a merged view of ``ModelTypes``).

    Every member's value is a ``(label, color)`` pair: display text plus a
    matplotlib ``tab:*`` colour (presumably for plot legends -- confirm at
    call sites). Unlike ``ModelTypes`` there is no separate Lp-adversarial
    category, so ``ROBUST_INTV`` is labelled generically.
    Member order is significant: it is the Enum's iteration order.
    """

    # Baseline and robustness groups.
    STANDARD = ('Standard training', 'tab:blue')
    ROBUST_INTV = ('Robustness intervention', 'tab:brown')
    MORE_DATA = ('Trained with more data', 'tab:green')

    # Non-deep-network baselines; 'tab:olive' is free here because there is
    # no LP_ADV member.
    RANDOM_FEATURES = ('Random Features', 'tab:olive')
    LINEAR_PIXELS = ('Linear Classifier on Pixels', 'tab:orange')
    RANDOM_FORESTS = ('Random Forests', 'tab:pink')
    NEAREST_NEIGHBORS = ('Nearest Neighbors', 'tab:purple')

    # Deliberately weak networks.
    LOW_ACCURACY_CNN = ('Low Accuracy CNN', 'tab:cyan')
# Fold the fine-grained ModelTypes taxonomy onto NatModelTypes. Every
# ModelTypes member maps to the NatModelTypes member of the same name, except
# LP_ADV, which has no counterpart and is merged into ROBUST_INTV.
# Iterating ModelTypes yields members in definition order, so the key order
# is STANDARD, LP_ADV, ROBUST_INTV, ... exactly as declared.
mapper = {
    kind: (
        NatModelTypes.ROBUST_INTV
        if kind is ModelTypes.LP_ADV
        else NatModelTypes[kind.name]
    )
    for kind in ModelTypes
}

# Same model names (and order) as model_types_map, re-labelled with the
# coarser NatModelTypes categories.
nat_model_types_map = {
    model_name: mapper[category]
    for model_name, category in model_types_map.items()
}