Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions robosat/scse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""Squeeze and Excitation blocks - attention for classification and segmentation

See:
- https://arxiv.org/abs/1709.01507 - Squeeze-and-Excitation Networks
- https://arxiv.org/abs/1803.02579 - Concurrent Spatial and Channel 'Squeeze & Excitation' in Fully Convolutional Networks

"""

import torch
import torch.nn as nn


class SpatialSqChannelEx(nn.Module):
"""Spatial Squeeze and Channel Excitation (cSE) block
See https://arxiv.org/abs/1803.02579 Figure 1 b
"""

def __init__(self, num_in, r):
super().__init__()
self.fc0 = Conv1x1(num_in, num_in // r)
self.fc1 = Conv1x1(num_in // r, num_in)

def forward(self, x):
xx = nn.functional.adaptive_avg_pool2d(x, 1)
xx = self.fc0(xx)
xx = nn.functional.relu(xx, inplace=True)
xx = self.fc1(xx)
xx = torch.sigmoid(xx)
return x * xx


class ChannelSqSpatialEx(nn.Module):
"""Channel Squeeze and Spatial Excitation (sSE) block
See https://arxiv.org/abs/1803.02579 Figure 1 c
"""

def __init__(self, num_in):
super().__init__()
self.conv = Conv1x1(num_in, 1)

def forward(self, x):
xx = self.conv(x)
xx = torch.sigmoid(xx)
return x * xx


class SpatialChannelSqChannelEx(nn.Module):
"""Concurrent Spatial and Channel Squeeze and Channel Excitation (csSE) block
See https://arxiv.org/abs/1803.02579 Figure 1 d
"""

def __init__(self, num_in, r=16):
super().__init__()

self.cse = SpatialSqChannelEx(num_in, r)
self.sse = ChannelSqSpatialEx(num_in)

def forward(self, x):
return self.cse(x) + self.sse(x)


def Conv1x1(num_in, num_out):
return nn.Conv2d(num_in, num_out, kernel_size=1, bias=False)
38 changes: 27 additions & 11 deletions robosat/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

from torchvision.models import resnet50

from robosat.scse import SpatialChannelSqChannelEx


class ConvRelu(nn.Module):
"""3x3 convolution followed by ReLU activation building block.
Expand Down Expand Up @@ -91,10 +93,23 @@ def __init__(self, num_classes, num_filters=32, pretrained=True):

# Todo: make input channels configurable, not hard-coded to three channels for RGB

self.resnet = resnet50(pretrained=pretrained)

# Access resnet directly in forward pass; do not store refs here due to
# https://github.com/pytorch/pytorch/issues/8392
self.resnet = resnet50(pretrained=pretrained)

# seSE blocks to append to encoder and decoder as recommended by
# https://arxiv.org/abs/1803.02579
self.scse0 = SpatialChannelSqChannelEx(64)
self.scse1 = SpatialChannelSqChannelEx(256)
self.scse2 = SpatialChannelSqChannelEx(512)
self.scse3 = SpatialChannelSqChannelEx(1024)
self.scse4 = SpatialChannelSqChannelEx(2048)

self.scse5 = SpatialChannelSqChannelEx(num_filters * 8)
self.scse6 = SpatialChannelSqChannelEx(num_filters * 8)
self.scse7 = SpatialChannelSqChannelEx(num_filters * 2)
self.scse8 = SpatialChannelSqChannelEx(num_filters * 2 * 2)
self.scse9 = SpatialChannelSqChannelEx(num_filters)

self.center = DecoderBlock(2048, num_filters * 8)

Expand Down Expand Up @@ -122,20 +137,21 @@ def forward(self, x):
enc0 = self.resnet.conv1(x)
enc0 = self.resnet.bn1(enc0)
enc0 = self.resnet.relu(enc0)
enc0 = self.scse0(enc0)
enc0 = self.resnet.maxpool(enc0)

enc1 = self.resnet.layer1(enc0)
enc2 = self.resnet.layer2(enc1)
enc3 = self.resnet.layer3(enc2)
enc4 = self.resnet.layer4(enc3)
enc1 = self.scse1(self.resnet.layer1(enc0))
enc2 = self.scse2(self.resnet.layer2(enc1))
enc3 = self.scse3(self.resnet.layer3(enc2))
enc4 = self.scse4(self.resnet.layer4(enc3))

center = self.center(nn.functional.max_pool2d(enc4, kernel_size=2, stride=2))

dec0 = self.dec0(torch.cat([enc4, center], dim=1))
dec1 = self.dec1(torch.cat([enc3, dec0], dim=1))
dec2 = self.dec2(torch.cat([enc2, dec1], dim=1))
dec3 = self.dec3(torch.cat([enc1, dec2], dim=1))
dec4 = self.dec4(dec3)
dec0 = self.scse5(self.dec0(torch.cat([enc4, center], dim=1)))
dec1 = self.scse6(self.dec1(torch.cat([enc3, dec0], dim=1)))
dec2 = self.scse7(self.dec2(torch.cat([enc2, dec1], dim=1)))
dec3 = self.scse8(self.dec3(torch.cat([enc1, dec2], dim=1)))
dec4 = self.scse9(self.dec4(dec3))
dec5 = self.dec5(dec4)

return self.final(dec5)