File tree 1 file changed +2
-0
lines changed
1 file changed +2
-0
lines changed Original file line number Diff line number Diff line change @@ -265,6 +265,7 @@ def test_model_default_cfgs(model_name, batch_size):
265
265
266
266
# test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features
267
267
model .reset_classifier (0 )
268
+ assert model .num_classes == 0 , f'Expected num_classes to be 0 after reset_classifier(0), but got { model .num_classes } '
268
269
model .to (torch_device )
269
270
outputs = model .forward (input_tensor )
270
271
assert len (outputs .shape ) == 2
@@ -339,6 +340,7 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
339
340
340
341
# test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features
341
342
model .reset_classifier (0 )
343
+ assert model .num_classes == 0 , f'Expected num_classes to be 0 after reset_classifier(0), but got { model .num_classes } '
342
344
model .to (torch_device )
343
345
outputs = model .forward (input_tensor )
344
346
if isinstance (outputs , (tuple , list )):
You can’t perform that action at this time.
0 commit comments