Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RFC] Implement PRNet #16

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions torchvision_models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
# The code structure is based on an older version of TorchVision
from .resnet import *
from . import segmentation
24 changes: 24 additions & 0 deletions torchvision_models/lane_detection/common_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,30 @@ def forward(self, input):
return output


# SCNN_D (more efficient and used by a lot of people nowadays)
class SCNN_D(nn.Module):
    """Downward-only spatial message passing (the "D" pass of SCNN).

    A single (1 x 9) convolution is applied row by row along dim 2 of a 4-D
    feature map: each row receives ``relu(conv(previous row))`` added on top of
    its own activations, so information flows from the first row to the last.

    Args:
        num_channels: Number of input (and output) channels of the feature map.
    """

    def __init__(self, num_channels=128):
        super().__init__()
        # 1 x 9 kernel with padding 4 keeps the last dimension unchanged.
        self.conv_d = nn.Conv2d(num_channels, num_channels, (1, 9), padding=(0, 4))
        self._adjust_initializations(num_channels=num_channels)

    def _adjust_initializations(self, num_channels=128):
        """Re-initialize the message-passing weights with a small uniform bound."""
        # https://github.com/XingangPan/SCNN/issues/82
        bound = math.sqrt(2.0 / (num_channels * 9 * 5))
        nn.init.uniform_(self.conv_d.weight, -bound, bound)

    def forward(self, input):
        """Propagate messages from row 0 downwards.

        NOTE: ``input`` is updated IN PLACE and the very same tensor object is
        returned; callers that still need the original values must clone first.
        """
        output = input

        # First one remains unchanged (according to the original paper), why not add a relu afterwards?
        # Update and send to next
        # Down
        for i in range(1, output.shape[2]):
            # Row i-1 has already been updated, so messages accumulate sequentially.
            output[:, :, i:i + 1, :].add_(F.relu(self.conv_d(output[:, :, i - 1:i, :])))

        return output


# Typical lane existence head originated from the SCNN paper
class SimpleLaneExist(nn.Module):
def __init__(self, num_output, flattened_size=4500):
Expand Down
71 changes: 71 additions & 0 deletions torchvision_models/lane_detection/prnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Implementation of Polynomial Regression Network based on the original paper (PRNet):
# http://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123630698.pdf
import torch.nn as nn
from collections import OrderedDict
from common_models import RESAReducer, SCNN_D
from .. import resnet
from ..segmentation import erfnet_resnet
from .._utils import IntermediateLayerGetter


# One convolution layer for each branch
# The kernel size 3x3 is an educated guess, the 3 branches are implemented separately for future flexibility
class PolynomialBranch(nn.Module):
    """Head that predicts polynomial coefficients with one 3x3 convolution.

    Args:
        in_channels: Number of channels of the incoming feature map.
        order: Polynomial order; the head outputs ``order + 1`` channels
            (one per coefficient).
    """

    def __init__(self, in_channels, order=2):
        super().__init__()
        # stride 1 / padding 1 keeps the spatial size; no bias term is used.
        self.conv = nn.Conv2d(in_channels, order + 1, kernel_size=3, stride=1, padding=1, bias=False)

    def forward(self, inputs):
        """Return the ``order + 1``-channel coefficient map for ``inputs``."""
        return self.conv(inputs)


class InitializationBranch(nn.Module):
    """Head that produces a single-channel initialization map via one 3x3 convolution.

    Args:
        in_channels: Number of channels of the incoming feature map.
    """

    def __init__(self, in_channels):
        super().__init__()
        # stride 1 / padding 1 keeps the spatial size; no bias term is used.
        self.conv = nn.Conv2d(in_channels, 1, kernel_size=3, stride=1, padding=1, bias=False)

    def forward(self, inputs):
        """Return the 1-channel map computed from ``inputs``."""
        return self.conv(inputs)


class HeightBranch(nn.Module):
    """Head that produces a single-channel height map via one 3x3 convolution.

    Args:
        in_channels: Number of channels of the incoming feature map.
    """

    def __init__(self, in_channels):
        super().__init__()
        # stride 1 / padding 1 keeps the spatial size; no bias term is used.
        self.conv = nn.Conv2d(in_channels, 1, kernel_size=3, stride=1, padding=1, bias=False)

    def forward(self, inputs):
        """Return the 1-channel map computed from ``inputs``."""
        return self.conv(inputs)


# Currently supported backbones: ERFNet, ResNets
class PRNet(nn.Module):
    """Polynomial Regression Network: backbone -> channel reducer -> SCNN_D -> 3 heads.

    Args:
        backbone_name: ``'erfnet'`` or the name of a constructor found in the
            ``resnet`` module (e.g. ``'resnet18'``, ``'resnet50'``).
        dropout_1: Forwarded to ``erfnet_resnet`` (ERFNet backbone only).
        dropout_2: Forwarded to ``erfnet_resnet`` (ERFNet backbone only).
        order: Polynomial order; the polynomial head predicts ``order + 1`` maps.

    NOTE(review): the dropout defaults here (0.3, 0.03) are the reverse of the
    defaults of ``erfnet_resnet`` (0.03, 0.3) -- confirm this is intended.
    """

    def __init__(self, backbone_name, dropout_1=0.3, dropout_2=0.03, order=2):
        super().__init__()
        if backbone_name == 'erfnet':
            self.backbone = erfnet_resnet(dropout_1=dropout_1, dropout_2=dropout_2, encoder_only=True)
            in_channels = 128
        else:
            # The two listed deep ResNets expose 2048 channels; anything else is treated as 512.
            in_channels = 2048 if backbone_name in ('resnet50', 'resnet101') else 512
            backbone = resnet.__dict__[backbone_name](
                pretrained=True,
                replace_stride_with_dilation=[False, True, True])
            return_layers = {'layer4': 'out'}
            self.backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)

        self.channel_reducer = RESAReducer(in_channels=in_channels)
        self.spatial_conv = SCNN_D()
        self.polynomial_branch = PolynomialBranch(in_channels=128, order=order)
        self.initialization_branch = InitializationBranch(in_channels=128)
        self.height_branch = HeightBranch(in_channels=128)

    def forward(self, inputs):
        """Run the network and return an OrderedDict with the three head outputs.

        Keys (in order): ``'polynomials'``, ``'initializations'``, ``'heights'``.
        """
        # Encoder (8x down-sampling) -> channel reduction (128, another educated guess) -> SCNN_D -> 3 branches
        features = self.backbone(inputs)
        if isinstance(features, dict):
            # IntermediateLayerGetter returns a dict of named feature maps, not a
            # tensor; the map requested in __init__ is stored under 'out'.
            # NOTE(review): confirm the ERFNet encoder-only path also exposes
            # its features under 'out'.
            features = features['out']
        features = self.channel_reducer(features)
        features = self.spatial_conv(features)

        out = OrderedDict()
        out['polynomials'] = self.polynomial_branch(features)
        out['initializations'] = self.initialization_branch(features)
        out['heights'] = self.height_branch(features)

        return out
2 changes: 1 addition & 1 deletion torchvision_models/segmentation/enet.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from torch.nn.parameter import Parameter
from ..lane_detection.common_models import EDLaneExist


class InitialBlock(nn.Module):
"""The initial block is composed of two branches:
1. a main branch which performs a regular convolution with stride 2;
Expand Down Expand Up @@ -696,7 +697,6 @@ def forward(self, x):
stage2_input_size, input_size)
out['out'] = x


return out

# net = ENet(num_classes=19,encoder_only=True)
Expand Down
20 changes: 14 additions & 6 deletions torchvision_models/segmentation/erfnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,14 +167,15 @@ def forward(self, input):
# ERFNet
class ERFNet(nn.Module):
def __init__(self, num_classes, encoder=None, num_lanes=0, dropout_1=0.03, dropout_2=0.3, flattened_size=3965,
scnn=False):
scnn=False, encoder_only=False):
super().__init__()
if encoder is None:
self.encoder = Encoder(num_classes=num_classes, dropout_1=dropout_1, dropout_2=dropout_2)
else:
self.encoder = encoder

self.decoder = Decoder(num_classes)
# Only encoder (to be used as backbone)
self.decoder = None if encoder_only else Decoder(num_classes)

if scnn:
self.spatial_conv = SpatialConv()
Expand All @@ -187,16 +188,23 @@ def __init__(self, num_classes, encoder=None, num_lanes=0, dropout_1=0.03, dropo
else:
self.lane_classifier = None

def forward(self, input, only_encode=False):
def forward(self, inputs, only_encode=False):
# only_encode=True is for the pre-training step of 2-step segmentation training,
# in order to match with the original implementation.
# If encoder is used as feature extractor, set encoder_only=True in class init, but do not change this variable
out = OrderedDict()
if only_encode:
return self.encoder.forward(input, predict=True)
return self.encoder.forward(inputs, predict=True)
else:
output = self.encoder(input) # predict=False by default
output = self.encoder(inputs) # predict=False by default

if self.spatial_conv is not None:
output = self.spatial_conv(output)
out['out'] = self.decoder.forward(output)

if self.decoder is not None:
out['out'] = self.decoder.forward(output)

if self.lane_classifier is not None:
out['lane'] = self.lane_classifier(output)

return out
6 changes: 4 additions & 2 deletions torchvision_models/segmentation/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,14 +228,16 @@ def deeplabv3_resnet101(pretrained=False, progress=True,


def erfnet_resnet(pretrained_weights='erfnet_encoder_pretrained.pth.tar', num_classes=19, num_lanes=0,
dropout_1=0.03, dropout_2=0.3, flattened_size=3965, scnn=False):
dropout_1=0.03, dropout_2=0.3, flattened_size=3965, scnn=False, encoder_only=False):
"""Constructs a ERFNet model with ResNet-style backbone.

Args:
pretrained_weights (str): If not None, load ImageNet pre-trained weights from this filename
encoder_only (bool): If True, only encoder is returned as a feature extractor, ImageNet weights loading
will not be affected
"""
net = ERFNet(num_classes=num_classes, encoder=None, num_lanes=num_lanes, dropout_1=dropout_1, dropout_2=dropout_2,
flattened_size=flattened_size, scnn=scnn)
flattened_size=flattened_size, scnn=scnn, encoder_only=encoder_only)
if pretrained_weights is not None: # Load ImageNet pre-trained weights
saved_weights = load(pretrained_weights)['state_dict']
original_weights = net.state_dict()
Expand Down
46 changes: 46 additions & 0 deletions utils/losses/pr_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Loss for PRNet
import torch
from torch.nn import functional as F
from ._utils import WeightedLoss


def polynomial_curve_without_projection(coefficients, y):
    """Evaluate a polynomial curve and return its x coordinates.

    Computes ``x = c_0 + c_1 * y + ... + c_{m-1} * y ** (m - 1)``.

    Args:
        coefficients: Tensor ``[d1, d2, ..., m]`` holding ``m`` coefficients in
            increasing order. It is never modified.
        y: Tensor ``[d1, d2, ..., N]`` of sample positions; the leading
            dimensions must broadcast against those of ``coefficients``.

    Returns:
        Tensor ``[d1, d2, ..., N]`` of x coordinates.
    """
    num_coefficients = coefficients.shape[-1]

    # Horner's scheme, highest order first. Everything is out-of-place so the
    # caller's coefficients are left untouched, and the leading zeros_like(y)
    # term guarantees the [..., N] result shape even for a single coefficient.
    x = torch.zeros_like(y) + coefficients[..., -1:]
    for i in range(num_coefficients - 2, -1, -1):
        x = x * y + coefficients[..., i:i + 1]

    return x


class PRLoss(WeightedLoss):
    """Loss for PRNet (work in progress: ``forward`` does not compute a loss yet).

    Args:
        polynomial_weight: Stored as ``self.polynomial_weight``.
        initialization_weight: Stored as ``self.initialization_weight``.
        height_weight: Stored as ``self.height_weight``.
        beta: Beta for the smoothed L1 loss.
        m: Number of sample points to calculate the polynomial regression loss.
        weight, size_average, reduce, reduction: Forwarded unchanged to
            ``WeightedLoss.__init__``.
    """
    __constants__ = ['reduction']
    ignore_index: int

    def __init__(self, polynomial_weight=1, initialization_weight=1, height_weight=0.1, beta=0.005, m=20,
                 weight=None, size_average=None, reduce=None, reduction='mean'):
        super().__init__(weight, size_average, reduce, reduction)
        self.polynomial_weight = polynomial_weight
        self.initialization_weight = initialization_weight
        self.height_weight = height_weight
        self.beta = beta  # Beta for smoothed L1 loss
        self.m = m  # Number of sample points to calculate polynomial regression loss

    def forward(self, inputs, targets, masks, net):
        """Run ``net`` on ``inputs``; the loss itself is not implemented yet.

        Currently returns ``None``.
        """
        # masks: True for polynomial points (which have height & polynomial regression losses)
        outputs = net(inputs)  # noqa: F841 -- TODO: build the loss terms from these outputs

        return None

    @staticmethod
    def beta_smoothed_l1_loss(inputs, targets, beta=0.005):
        """Element-wise smoothed L1 loss with a hyper-parameter (as in PRNet paper).

        Quadratic (``0.5 * t ** 2 / beta``) where ``t = |inputs - targets|`` is
        below ``beta``, linear (``t - 0.5 * beta``) elsewhere. No reduction is
        applied. The original torch F.smooth_l1_loss() is equivalent to beta=1.
        """
        t = torch.abs(inputs - targets)
        if isinstance(beta, (int, float)) and beta == 0:
            # Pure L1: the quadratic branch would divide by zero, and although
            # torch.where never selects it, its NaN gradient would still leak
            # into backward.
            return t

        return torch.where(t < beta, 0.5 * t ** 2 / beta, t - 0.5 * beta)