Skip to content

Commit 5dcce36

Browse files
authored
Merge pull request #87 from gdahia/num-classes-not-required
Allow None "num_classes" if not "classify"
2 parents e5a30d7 + 5fd1d05 commit 5dcce36

File tree

1 file changed

+4
-6
lines changed

1 file changed

+4
-6
lines changed

models/inception_resnet_v1.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -208,10 +208,8 @@ def __init__(self, pretrained=None, classify=False, num_classes=None, dropout_pr
208208
tmp_classes = 8631
209209
elif pretrained == 'casia-webface':
210210
tmp_classes = 10575
211-
elif pretrained is None and self.num_classes is None:
212-
raise Exception('At least one of "pretrained" or "num_classes" must be specified')
213-
else:
214-
tmp_classes = self.num_classes
211+
elif pretrained is None and self.classify and self.num_classes is None:
212+
raise Exception('If "pretrained" is not specified and "classify" is True, "num_classes" must be specified')
215213

216214

217215
# Define layers
@@ -255,12 +253,12 @@ def __init__(self, pretrained=None, classify=False, num_classes=None, dropout_pr
255253
self.dropout = nn.Dropout(dropout_prob)
256254
self.last_linear = nn.Linear(1792, 512, bias=False)
257255
self.last_bn = nn.BatchNorm1d(512, eps=0.001, momentum=0.1, affine=True)
258-
self.logits = nn.Linear(512, tmp_classes)
259256

260257
if pretrained is not None:
258+
self.logits = nn.Linear(512, tmp_classes)
261259
load_weights(self, pretrained)
262260

263-
if self.num_classes is not None:
261+
if self.classify and self.num_classes is not None:
264262
self.logits = nn.Linear(512, self.num_classes)
265263

266264
self.device = torch.device('cpu')

0 commit comments

Comments
 (0)