Description
I created a fake input to check each layer's output shape, but I got an error. Could anyone suggest a fix or share the correct code?
import torch
import torch.nn as nn
from structure.builder import Builder
from structure.model import SegDetectorModel
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_args = {
    'backbone': 'deformable_resnet18',  # selected backbone
    'decoder': 'SegDetector',  # selected decoder
    'decoder_args': {
        'adaptive': True,
        'in_channels': [64, 128, 256, 512],
        'k': 50
    },
    'loss_class': 'L1BalanceCELoss',  # selected loss class
}
model = SegDetectorModel(model_args, device)
fake_input = torch.randn(size=(1, 3, 640, 640), dtype=torch.float32).to(device)
for names, layers in model.model.module.named_children():
    for name, layer in layers.named_children():
        fake_input = layer(fake_input)
        print(f'{name} Shape: {fake_input.shape}')
The command output:
conv1 Shape: torch.Size([1, 64, 320, 320])
bn1 Shape: torch.Size([1, 64, 320, 320])
relu Shape: torch.Size([1, 64, 320, 320])
maxpool Shape: torch.Size([1, 64, 160, 160])
layer1 Shape: torch.Size([1, 64, 160, 160])
layer2 Shape: torch.Size([1, 128, 80, 80])
layer3 Shape: torch.Size([1, 256, 40, 40])
layer4 Shape: torch.Size([1, 512, 20, 20])
avgpool Shape: torch.Size([1, 512, 14, 14])
Traceback (most recent call last):
  File "E:\Anaconda\envs\yolov5\lib\site-packages\IPython\core\interactiveshell.py", line 3460, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "", line 3, in
    fake_input = layer(fake_input)
  File "E:\Anaconda\envs\yolov5\lib\site-packages\torch\nn\modules\module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "E:\Anaconda\envs\yolov5\lib\site-packages\torch\nn\modules\linear.py", line 96, in forward
    return F.linear(input, self.weight, self.bias)
  File "E:\Anaconda\envs\yolov5\lib\site-packages\torch\nn\functional.py", line 1847, in linear
    return torch._C._nn.linear(input, weight, bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (7168x14 and 512x1000)
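A note on the error itself: the traceback ends in torch\nn\modules\linear.py, so the layer that fails is an nn.Linear that received the 4-D avgpool output directly. nn.Linear acts on the last dimension only, so a [1, 512, 14, 14] tensor is treated as 1 * 512 * 14 = 7168 rows of 14 features, while the layer expects 512 features. More generally, named_children() lists submodules in the order they were registered, which is not necessarily how (or whether) forward() uses them, so chaining them by hand can go wrong. One possible alternative is to run a single real forward pass and record shapes with forward hooks. The sketch below is only an illustration under stated assumptions: TinyBackbone and collect_output_shapes are hypothetical names, and the tiny model is a stand-in, not the model from this repository.

```python
import torch
import torch.nn as nn


class TinyBackbone(nn.Module):
    """Hypothetical stand-in for a ResNet-style backbone (not the real model)."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1)
        self.relu = nn.ReLU()
        self.layer1 = nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1)
        # Registered but never called in forward(): chaining named_children()
        # by hand would wrongly push a 4-D tensor through these two modules.
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(16, 10)

    def forward(self, x):
        x = self.relu(self.conv1(x))
        return self.layer1(x)


def collect_output_shapes(model, example_input):
    """Run one real forward pass and record each leaf module's output shape."""
    shapes = {}
    handles = []

    def make_hook(name):
        def hook(module, inputs, output):
            # Only record plain tensor outputs; skip tuples, dicts, etc.
            if isinstance(output, torch.Tensor):
                shapes[name] = tuple(output.shape)
        return hook

    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # leaf modules only
            handles.append(module.register_forward_hook(make_hook(name)))

    try:
        with torch.no_grad():
            model(example_input)
    finally:
        # Always detach the hooks so the model is left unchanged.
        for handle in handles:
            handle.remove()
    return shapes


model = TinyBackbone().eval()
shapes = collect_output_shapes(model, torch.randn(1, 3, 64, 64))
for name, shape in shapes.items():
    print(f"{name} shape: {shape}")
```

Modules that forward() never calls (avgpool and fc in this stand-in) simply do not show up in the result, which avoids the mismatch above. To try this on the model from the question, one could pass model.model.module and fake_input to collect_output_shapes, assuming its forward() accepts a plain image tensor.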