Skip to content

Commit bda46f8

Browse files
brianhou0208rwightman
authored andcommitted
Add num_classes assertion after reset_classifier
1 parent 17eabaa commit bda46f8

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

tests/test_models.py

+2
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,7 @@ def test_model_default_cfgs(model_name, batch_size):
265265

266266
# test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features
267267
model.reset_classifier(0)
268+
assert model.num_classes == 0, f'Expected num_classes to be 0 after reset_classifier(0), but got {model.num_classes}'
268269
model.to(torch_device)
269270
outputs = model.forward(input_tensor)
270271
assert len(outputs.shape) == 2
@@ -339,6 +340,7 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
339340

340341
# test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features
341342
model.reset_classifier(0)
343+
assert model.num_classes == 0, f'Expected num_classes to be 0 after reset_classifier(0), but got {model.num_classes}'
342344
model.to(torch_device)
343345
outputs = model.forward(input_tensor)
344346
if isinstance(outputs, (tuple, list)):

0 commit comments

Comments
 (0)