Skip to content

Commit bab07bb

Browse files
authored
Update inc_net.py
1 parent 964c2a8 commit bab07bb

File tree

1 file changed

+3
-80
lines changed

1 file changed

+3
-80
lines changed

Diff for: utils/inc_net.py

+3-80
Original file line numberDiff line numberDiff line change
@@ -11,34 +11,8 @@
1111
from convs.linears import SimpleLinear, SplitCosineLinear, CosineLinear
1212
from convs.modified_represnet import resnet18_rep,resnet34_rep
1313
from convs.resnet_cbam import resnet18_cbam,resnet34_cbam,resnet50_cbam
14-
15-
# FOR MEMO
16-
from convs.memo_resnet import get_resnet18_imagenet as get_memo_resnet18 #for imagenet
17-
from convs.memo_cifar_resnet import get_resnet32_a2fc as get_memo_resnet32 #for cifar
18-
19-
# FOR AUC & DER
20-
from convs.conv_cifar import conv2 as conv2_cifar
21-
from convs.cifar_resnet import resnet14 as resnet14_cifar
22-
from convs.cifar_resnet import resnet20 as resnet20_cifar
23-
from convs.cifar_resnet import resnet26 as resnet26_cifar
24-
25-
from convs.conv_imagenet import conv4 as conv4_imagenet
26-
from convs.resnet import resnet10 as resnet10_imagenet
27-
from convs.resnet import resnet26 as resnet26_imagenet
28-
from convs.resnet import resnet34 as resnet34_imagenet
29-
from convs.resnet import resnet50 as resnet50_imagenet
30-
31-
# FOR AUC & MEMO
32-
from convs.conv_cifar import get_conv_a2fc as memo_conv2_cifar
33-
from convs.memo_cifar_resnet import get_resnet14_a2fc as memo_resnet14_cifar
34-
from convs.memo_cifar_resnet import get_resnet20_a2fc as memo_resnet20_cifar
35-
from convs.memo_cifar_resnet import get_resnet26_a2fc as memo_resnet26_cifar
36-
37-
from convs.conv_imagenet import conv_a2fc_imagenet as memo_conv4_imagenet
38-
from convs.memo_resnet import get_resnet10_imagenet as memo_resnet10_imagenet
39-
from convs.memo_resnet import get_resnet26_imagenet as memo_resnet26_imagenet
40-
from convs.memo_resnet import get_resnet34_imagenet as memo_resnet34_imagenet
41-
from convs.memo_resnet import get_resnet50_imagenet as memo_resnet50_imagenet
14+
from convs.memo_resnet import get_resnet18_imagenet as get_memo_resnet18 #for MEMO imagenet
15+
from convs.memo_cifar_resnet import get_resnet32_a2fc as get_memo_resnet32 #for MEMO cifar
4216

4317
def get_convnet(args, pretrained=False):
4418
name = args["convnet_type"].lower()
@@ -75,57 +49,6 @@ def get_convnet(args, pretrained=False):
7549
_basenet, _adaptive_net = get_memo_resnet32()
7650
return _basenet, _adaptive_net
7751

78-
# AUC
79-
## cifar
80-
elif name == 'conv2':
81-
return conv2_cifar()
82-
elif name == 'resnet14_cifar':
83-
return resnet14_cifar()
84-
elif name == 'resnet20_cifar':
85-
return resnet20_cifar()
86-
elif name == 'resnet26_cifar':
87-
return resnet26_cifar()
88-
89-
elif name == 'memo_conv2':
90-
g_blocks, s_blocks = memo_conv2_cifar() # generalized/specialized
91-
return g_blocks, s_blocks
92-
elif name == 'memo_resnet14_cifar':
93-
g_blocks, s_blocks = memo_resnet14_cifar() # generalized/specialized
94-
return g_blocks, s_blocks
95-
elif name == 'memo_resnet20_cifar':
96-
g_blocks, s_blocks = memo_resnet20_cifar() # generalized/specialized
97-
return g_blocks, s_blocks
98-
elif name == 'memo_resnet26_cifar':
99-
g_blocks, s_blocks = memo_resnet26_cifar() # generalized/specialized
100-
return g_blocks, s_blocks
101-
102-
## imagenet
103-
elif name == 'conv4':
104-
return conv4_imagenet()
105-
elif name == 'resnet10_imagenet':
106-
return resnet10_imagenet()
107-
elif name == 'resnet26_imagenet':
108-
return resnet26_imagenet()
109-
elif name == 'resnet34_imagenet':
110-
return resnet34_imagenet()
111-
elif name == 'resnet50_imagenet':
112-
return resnet50_imagenet()
113-
114-
elif name == 'memo_conv4':
115-
g_blcoks, s_blocks = memo_conv4_imagenet()
116-
return g_blcoks, s_blocks
117-
elif name == 'memo_resnet10_imagenet':
118-
g_blcoks, s_blocks = memo_resnet10_imagenet()
119-
return g_blcoks, s_blocks
120-
elif name == 'memo_resnet26_imagenet':
121-
g_blcoks, s_blocks = memo_resnet26_imagenet()
122-
return g_blcoks, s_blocks
123-
elif name == 'memo_resnet34_imagenet':
124-
g_blocks, s_blocks = memo_resnet34_imagenet()
125-
return g_blocks, s_blocks
126-
elif name == 'memo_resnet50_imagenet':
127-
g_blcoks, s_blocks = memo_resnet50_imagenet()
128-
return g_blcoks, s_blocks
12952
else:
13053
raise NotImplementedError("Unknown type {}".format(name))
13154

@@ -894,4 +817,4 @@ def load_checkpoint(self, args):
894817
self.AdaptiveExtractors[0].load_state_dict(adap_state_dict)
895818
self.fc.load_state_dict(model_infos['fc'])
896819
test_acc = model_infos['test_acc']
897-
return test_acc
820+
return test_acc

0 commit comments

Comments
 (0)