11
11
from convs .linears import SimpleLinear , SplitCosineLinear , CosineLinear
12
12
from convs .modified_represnet import resnet18_rep ,resnet34_rep
13
13
from convs .resnet_cbam import resnet18_cbam ,resnet34_cbam ,resnet50_cbam
14
-
15
- # FOR MEMO
16
- from convs .memo_resnet import get_resnet18_imagenet as get_memo_resnet18 #for imagenet
17
- from convs .memo_cifar_resnet import get_resnet32_a2fc as get_memo_resnet32 #for cifar
18
-
19
- # FOR AUC & DER
20
- from convs .conv_cifar import conv2 as conv2_cifar
21
- from convs .cifar_resnet import resnet14 as resnet14_cifar
22
- from convs .cifar_resnet import resnet20 as resnet20_cifar
23
- from convs .cifar_resnet import resnet26 as resnet26_cifar
24
-
25
- from convs .conv_imagenet import conv4 as conv4_imagenet
26
- from convs .resnet import resnet10 as resnet10_imagenet
27
- from convs .resnet import resnet26 as resnet26_imagenet
28
- from convs .resnet import resnet34 as resnet34_imagenet
29
- from convs .resnet import resnet50 as resnet50_imagenet
30
-
31
- # FOR AUC & MEMO
32
- from convs .conv_cifar import get_conv_a2fc as memo_conv2_cifar
33
- from convs .memo_cifar_resnet import get_resnet14_a2fc as memo_resnet14_cifar
34
- from convs .memo_cifar_resnet import get_resnet20_a2fc as memo_resnet20_cifar
35
- from convs .memo_cifar_resnet import get_resnet26_a2fc as memo_resnet26_cifar
36
-
37
- from convs .conv_imagenet import conv_a2fc_imagenet as memo_conv4_imagenet
38
- from convs .memo_resnet import get_resnet10_imagenet as memo_resnet10_imagenet
39
- from convs .memo_resnet import get_resnet26_imagenet as memo_resnet26_imagenet
40
- from convs .memo_resnet import get_resnet34_imagenet as memo_resnet34_imagenet
41
- from convs .memo_resnet import get_resnet50_imagenet as memo_resnet50_imagenet
14
+ from convs .memo_resnet import get_resnet18_imagenet as get_memo_resnet18 #for MEMO imagenet
15
+ from convs .memo_cifar_resnet import get_resnet32_a2fc as get_memo_resnet32 #for MEMO cifar
42
16
43
17
def get_convnet (args , pretrained = False ):
44
18
name = args ["convnet_type" ].lower ()
@@ -75,57 +49,6 @@ def get_convnet(args, pretrained=False):
75
49
_basenet , _adaptive_net = get_memo_resnet32 ()
76
50
return _basenet , _adaptive_net
77
51
78
- # AUC
79
- ## cifar
80
- elif name == 'conv2' :
81
- return conv2_cifar ()
82
- elif name == 'resnet14_cifar' :
83
- return resnet14_cifar ()
84
- elif name == 'resnet20_cifar' :
85
- return resnet20_cifar ()
86
- elif name == 'resnet26_cifar' :
87
- return resnet26_cifar ()
88
-
89
- elif name == 'memo_conv2' :
90
- g_blocks , s_blocks = memo_conv2_cifar () # generalized/specialized
91
- return g_blocks , s_blocks
92
- elif name == 'memo_resnet14_cifar' :
93
- g_blocks , s_blocks = memo_resnet14_cifar () # generalized/specialized
94
- return g_blocks , s_blocks
95
- elif name == 'memo_resnet20_cifar' :
96
- g_blocks , s_blocks = memo_resnet20_cifar () # generalized/specialized
97
- return g_blocks , s_blocks
98
- elif name == 'memo_resnet26_cifar' :
99
- g_blocks , s_blocks = memo_resnet26_cifar () # generalized/specialized
100
- return g_blocks , s_blocks
101
-
102
- ## imagenet
103
- elif name == 'conv4' :
104
- return conv4_imagenet ()
105
- elif name == 'resnet10_imagenet' :
106
- return resnet10_imagenet ()
107
- elif name == 'resnet26_imagenet' :
108
- return resnet26_imagenet ()
109
- elif name == 'resnet34_imagenet' :
110
- return resnet34_imagenet ()
111
- elif name == 'resnet50_imagenet' :
112
- return resnet50_imagenet ()
113
-
114
- elif name == 'memo_conv4' :
115
- g_blcoks , s_blocks = memo_conv4_imagenet ()
116
- return g_blcoks , s_blocks
117
- elif name == 'memo_resnet10_imagenet' :
118
- g_blcoks , s_blocks = memo_resnet10_imagenet ()
119
- return g_blcoks , s_blocks
120
- elif name == 'memo_resnet26_imagenet' :
121
- g_blcoks , s_blocks = memo_resnet26_imagenet ()
122
- return g_blcoks , s_blocks
123
- elif name == 'memo_resnet34_imagenet' :
124
- g_blocks , s_blocks = memo_resnet34_imagenet ()
125
- return g_blocks , s_blocks
126
- elif name == 'memo_resnet50_imagenet' :
127
- g_blcoks , s_blocks = memo_resnet50_imagenet ()
128
- return g_blcoks , s_blocks
129
52
else :
130
53
raise NotImplementedError ("Unknown type {}" .format (name ))
131
54
@@ -894,4 +817,4 @@ def load_checkpoint(self, args):
894
817
self .AdaptiveExtractors [0 ].load_state_dict (adap_state_dict )
895
818
self .fc .load_state_dict (model_infos ['fc' ])
896
819
test_acc = model_infos ['test_acc' ]
897
- return test_acc
820
+ return test_acc
0 commit comments