@@ -208,10 +208,8 @@ def __init__(self, pretrained=None, classify=False, num_classes=None, dropout_pr
208
208
tmp_classes = 8631
209
209
elif pretrained == 'casia-webface' :
210
210
tmp_classes = 10575
211
- elif pretrained is None and self .num_classes is None :
212
- raise Exception ('At least one of "pretrained" or "num_classes" must be specified' )
213
- else :
214
- tmp_classes = self .num_classes
211
+ elif pretrained is None and self .classify and self .num_classes is None :
212
+ raise Exception ('If "pretrained" is not specified and "classify" is True, "num_classes" must be specified' )
215
213
216
214
217
215
# Define layers
@@ -255,12 +253,12 @@ def __init__(self, pretrained=None, classify=False, num_classes=None, dropout_pr
255
253
self .dropout = nn .Dropout (dropout_prob )
256
254
self .last_linear = nn .Linear (1792 , 512 , bias = False )
257
255
self .last_bn = nn .BatchNorm1d (512 , eps = 0.001 , momentum = 0.1 , affine = True )
258
- self .logits = nn .Linear (512 , tmp_classes )
259
256
260
257
if pretrained is not None :
258
+ self .logits = nn .Linear (512 , tmp_classes )
261
259
load_weights (self , pretrained )
262
260
263
- if self .num_classes is not None :
261
+ if self .classify and self . num_classes is not None :
264
262
self .logits = nn .Linear (512 , self .num_classes )
265
263
266
264
self .device = torch .device ('cpu' )
0 commit comments