Skip to content

Commit 51f70df

Browse files
author
McCrearyD
committed
don't extract BN layers for list of weight matrices & wrote test
1 parent 1a8bbbb commit 51f70df

File tree

2 files changed

+14
-10
lines changed

2 files changed

+14
-10
lines changed

rigl_torch/util.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
import torchvision
33

44

5+
EXCLUDED_TYPES = (torch.nn.BatchNorm2d, )
6+
7+
58
def get_weighted_layers(model, i=0, layers=None, linear_layers_mask=None):
69
if layers is None:
710
layers = []
@@ -16,18 +19,9 @@ def get_weighted_layers(model, i=0, layers=None, linear_layers_mask=None):
1619
if isinstance(p, torch.nn.Linear):
1720
layers.append([p])
1821
linear_layers_mask.append(1)
19-
elif hasattr(p, 'weight'):
22+
elif hasattr(p, 'weight') and type(p) not in EXCLUDED_TYPES:
2023
layers.append([p])
2124
linear_layers_mask.append(0)
22-
elif isinstance(p, torch.nn.BatchNorm2d):
23-
layers[-1].append(p)
24-
elif isinstance(p, torch.nn.ReLU):
25-
layers[-1].append(p)
26-
elif isinstance(p, torch.nn.MaxPool2d) or isinstance(p, torch.nn.AdaptiveAvgPool2d):
27-
layers[-1].append(p)
28-
elif layer_name == 'downsample':
29-
layers.append(p)
30-
linear_layers_mask.append(0)
3125
elif isinstance(p, torchvision.models.resnet.Bottleneck) or isinstance(p, torchvision.models.resnet.BasicBlock):
3226
_, linear_layers_mask, i = get_weighted_layers(p, i=i + 1, layers=layers, linear_layers_mask=linear_layers_mask)
3327
else:

tests/test_rigl.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,16 @@
2424
criterion = torch.nn.functional.cross_entropy
2525

2626

27+
def test_lengths_of_W():
28+
resnet18 = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=False)
29+
resnet18_W = get_W(resnet18)
30+
assert len(resnet18_W) == 21, 'resnet18 should have 21 "weight" matrices'
31+
32+
resnet50 = torch.hub.load('pytorch/vision:v0.6.0', 'resnet50', pretrained=False)
33+
resnet50_W = get_W(resnet50)
34+
assert len(resnet50_W) == 54, 'resnet50 should have 54 "weight" matrices'
35+
36+
2737
def get_dummy_dataloader():
2838
X = torch.rand((max_iters, *image_dimensionality))
2939
T = (torch.rand(max_iters) * num_classes).long()

0 commit comments

Comments
 (0)