Skip to content

Commit 384f30e

Browse files
authored
Add deeplabv3+b model for semantic segmentation (#924)
* new seg model for citys * pylint * add model store * remove inplace add * remove copy * add init * fix pylint * fix pylint * fix shape * add unit test * add init * model zoo
1 parent 9ea81b4 commit 384f30e

9 files changed

Lines changed: 634 additions & 7 deletions

File tree

docs/model_zoo/segmentation.rst

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,14 +78,23 @@ Pascal VOC Dataset
7878
+-----------------------+-----------------+-----------+-----------+------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------+
7979
| deeplab_resnet152_voc | DeepLabV3 [4]_ | N/A | 86.7_ | `shell script <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/segmentation/deeplab_resnet152_voc.sh>`_ | `log <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/segmentation/deeplab_resnet152_voc.log>`_ |
8080
+-----------------------+-----------------+-----------+-----------+------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------+
81-
| psp_resnet101_citys | PSP [3]_ | N/A | 77.1 | `shell script <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/segmentation/psp_resnet101_city.sh>`_ | `log <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/segmentation/psp_resnet101_city.log>`_ |
82-
+-----------------------+-----------------+-----------+-----------+------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------+
8381

8482
.. _83.6: http://host.robots.ox.ac.uk:8080/anonymous/YB1AN5.html
8583
.. _85.1: http://host.robots.ox.ac.uk:8080/anonymous/9RTTZC.html
8684
.. _86.2: http://host.robots.ox.ac.uk:8080/anonymous/ZPN6II.html
8785
.. _86.7: http://host.robots.ox.ac.uk:8080/anonymous/XZEXL2.html
8886

87+
Cityscapes Dataset
88+
------------------
89+
90+
+-------------------------------------+-----------------+-----------+-----------+---------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------------------+
91+
| Name | Method | pixAcc | mIoU | Command | log |
92+
+=====================================+=================+===========+===========+=============================================================================================================================================+====================================================================================================================================+
93+
| psp_resnet101_citys | PSP [3]_ | N/A | 77.1 | `shell script <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/segmentation/psp_resnet101_city.sh>`_ | `log <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/segmentation/psp_resnet101_city.log>`_ |
94+
+-------------------------------------+-----------------+-----------+-----------+---------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------------------+
95+
| deeplab_v3b_plus_wideresnet_citys | VPLR [5]_ | N/A | 83.5 | `shell script <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/segmentation/deeplab_v3b_plus_wideresnet_citys.sh>`_ | `log <https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/logs/segmentation/deeplab_v3b_plus_wideresnet_citys.log>`_ |
96+
+-------------------------------------+-----------------+-----------+-----------+---------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------------------+
97+
8998

9099
Instance Segmentation
91100
~~~~~~~~~~~~~~~~~~~~~
@@ -125,12 +134,14 @@ MS COCO
125134
+------------------------------------+---------------------------+--------------------------+------------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------+
126135

127136
.. [1] He, Kaming, Georgia Gkioxari, Piotr Dollár and Ross Girshick. \
128-
"Mask R-CNN." \
129-
In IEEE International Conference on Computer Vision (ICCV), 2017.
137+
"Mask R-CNN." \
138+
In IEEE International Conference on Computer Vision (ICCV), 2017.
130139
.. [2] Long, Jonathan, Evan Shelhamer, and Trevor Darrell. \
131140
"Fully convolutional networks for semantic segmentation." \
132141
Proceedings of the IEEE conference on computer vision and pattern recognition. 2015.
133142
.. [3] Zhao, Hengshuang, Jianping Shi, Xiaojuan Qi, Xiaogang Wang, and Jiaya Jia. \
134-
"Pyramid scene parsing network." *CVPR*, 2017
143+
"Pyramid scene parsing network." *CVPR*, 2017.
135144
.. [4] Chen, Liang-Chieh, et al. "Rethinking atrous convolution for semantic image segmentation." \
136145
arXiv preprint arXiv:1706.05587 (2017).
146+
.. [5] Zhu, Yi, et al. "Improving Semantic Segmentation via Video Propagation and Label Relaxation." \
147+
CVPR 2019.

gluoncv/model_zoo/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,14 @@
1212
from .pspnet import *
1313
from .deeplabv3 import *
1414
from .deeplabv3_plus import *
15+
from .deeplabv3b_plus import *
1516
from . import segbase
1617
from .resnetv1b import *
1718
from .se_resnet import *
1819
from .nasnet import *
1920
from .simple_pose.simple_pose_resnet import *
2021
from .action_recognition import *
22+
from .wideresnet import *
2123

2224
from .alexnet import *
2325
from .densenet import *
Lines changed: 266 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,266 @@
1+
"""DeepLabV3+ with wideresnet backbone for semantic segmentation"""
2+
# pylint: disable=missing-docstring,arguments-differ,unused-argument
3+
from mxnet.gluon import nn
4+
from mxnet.context import cpu
5+
from mxnet.gluon.nn import HybridBlock
6+
from .wideresnet import wider_resnet38_a2
7+
8+
__all__ = ['DeepLabWV3Plus', 'get_deeplabv3b_plus', 'get_deeplab_v3b_plus_wideresnet_citys']
9+
10+
class DeepLabWV3Plus(HybridBlock):
11+
r"""DeepLabWV3Plus
12+
13+
Parameters
14+
----------
15+
nclass : int
16+
Number of categories for the training dataset.
17+
backbone : string
18+
Pre-trained dilated backbone network type (default:'wideresnet').
19+
norm_layer : object
20+
Normalization layer used in backbone network (default: :class:`mxnet.gluon.nn.BatchNorm`;
21+
for Synchronized Cross-GPU BachNormalization).
22+
aux : bool
23+
Auxiliary loss.
24+
25+
Reference:
26+
27+
Chen, Liang-Chieh, et al. "Encoder-Decoder with Atrous Separable Convolution for Semantic
28+
Image Segmentation.", https://arxiv.org/abs/1802.02611, ECCV 2018
29+
"""
30+
def __init__(self, nclass, backbone='wideresnet', aux=False, ctx=cpu(), pretrained_base=True,
31+
height=None, width=None, base_size=520, crop_size=480, dilated=True, **kwargs):
32+
super(DeepLabWV3Plus, self).__init__()
33+
34+
height = height if height is not None else crop_size
35+
width = width if width is not None else crop_size
36+
self._up_kwargs = {'height': height, 'width': width}
37+
self.base_size = base_size
38+
self.crop_size = crop_size
39+
print('self.crop_size', self.crop_size)
40+
41+
with self.name_scope():
42+
pretrained = wider_resnet38_a2(classes=1000, dilation=True)
43+
pretrained.initialize(ctx=ctx)
44+
self.mod1 = pretrained.mod1
45+
self.mod2 = pretrained.mod2
46+
self.mod3 = pretrained.mod3
47+
self.mod4 = pretrained.mod4
48+
self.mod5 = pretrained.mod5
49+
self.mod6 = pretrained.mod6
50+
self.mod7 = pretrained.mod7
51+
self.pool2 = pretrained.pool2
52+
self.pool3 = pretrained.pool3
53+
del pretrained
54+
self.head = _DeepLabHead(nclass, height=height//2, width=width//2, **kwargs)
55+
self.head.initialize(ctx=ctx)
56+
57+
def hybrid_forward(self, F, x):
58+
outputs = []
59+
x = self.mod1(x)
60+
m2 = self.mod2(self.pool2(x))
61+
x = self.mod3(self.pool3(m2))
62+
x = self.mod4(x)
63+
x = self.mod5(x)
64+
x = self.mod6(x)
65+
x = self.mod7(x)
66+
x = self.head(x, m2)
67+
x = F.contrib.BilinearResize2D(x, **self._up_kwargs)
68+
outputs.append(x)
69+
return tuple(outputs)
70+
71+
def demo(self, x):
72+
return self.predict(x)
73+
74+
def predict(self, x):
75+
h, w = x.shape[2:]
76+
self._up_kwargs['height'] = h
77+
self._up_kwargs['width'] = w
78+
x = self.mod1(x)
79+
m2 = self.mod2(self.pool2(x))
80+
x = self.mod3(self.pool3(m2))
81+
x = self.mod4(x)
82+
x = self.mod5(x)
83+
x = self.mod6(x)
84+
x = self.mod7(x)
85+
x = self.head.demo(x, m2)
86+
import mxnet.ndarray as F
87+
x = F.contrib.BilinearResize2D(x, **self._up_kwargs)
88+
return x
89+
90+
class _DeepLabHead(HybridBlock):
91+
def __init__(self, nclass, c1_channels=128, norm_layer=nn.BatchNorm, norm_kwargs=None,
92+
height=240, width=240, **kwargs):
93+
super(_DeepLabHead, self).__init__()
94+
self._up_kwargs = {'height': height, 'width': width}
95+
with self.name_scope():
96+
self.aspp = _ASPP(in_channels=4096, atrous_rates=[12, 24, 36], norm_layer=norm_layer,
97+
norm_kwargs=norm_kwargs, height=height//4, width=width//4, **kwargs)
98+
99+
self.c1_block = nn.HybridSequential(prefix='bot_fine_')
100+
self.c1_block.add(nn.Conv2D(in_channels=c1_channels, channels=48,
101+
kernel_size=1, use_bias=False))
102+
103+
self.block = nn.HybridSequential(prefix='final_')
104+
self.block.add(nn.Conv2D(in_channels=304, channels=256,
105+
kernel_size=3, padding=1, use_bias=False))
106+
self.block.add(norm_layer(in_channels=256,
107+
**({} if norm_kwargs is None else norm_kwargs)))
108+
self.block.add(nn.Activation('relu'))
109+
self.block.add(nn.Conv2D(in_channels=256, channels=256,
110+
kernel_size=3, padding=1, use_bias=False))
111+
self.block.add(norm_layer(in_channels=256,
112+
**({} if norm_kwargs is None else norm_kwargs)))
113+
self.block.add(nn.Activation('relu'))
114+
self.block.add(nn.Conv2D(in_channels=256, channels=nclass,
115+
kernel_size=1, use_bias=False))
116+
117+
def hybrid_forward(self, F, x, c1):
118+
c1 = self.c1_block(c1)
119+
x = self.aspp(x)
120+
x = F.contrib.BilinearResize2D(x, **self._up_kwargs)
121+
return self.block(F.concat(c1, x, dim=1))
122+
123+
def demo(self, x, c1):
124+
h, w = c1.shape[2:]
125+
self._up_kwargs['height'] = h
126+
self._up_kwargs['width'] = w
127+
c1 = self.c1_block(c1)
128+
x = self.aspp.demo(x)
129+
import mxnet.ndarray as F
130+
x = F.contrib.BilinearResize2D(x, **self._up_kwargs)
131+
return self.block(F.concat(c1, x, dim=1))
132+
133+
def _ASPPConv(in_channels, out_channels, atrous_rate, norm_layer, norm_kwargs):
134+
block = nn.HybridSequential()
135+
with block.name_scope():
136+
block.add(nn.Conv2D(in_channels=in_channels, channels=out_channels,
137+
kernel_size=3, padding=atrous_rate,
138+
dilation=atrous_rate, use_bias=False))
139+
block.add(norm_layer(in_channels=out_channels,
140+
**({} if norm_kwargs is None else norm_kwargs)))
141+
block.add(nn.Activation('relu'))
142+
return block
143+
144+
class _AsppPooling(nn.HybridBlock):
145+
def __init__(self, in_channels, out_channels, norm_layer, norm_kwargs,
146+
height=60, width=60, **kwargs):
147+
super(_AsppPooling, self).__init__()
148+
self.gap = nn.HybridSequential()
149+
self._up_kwargs = {'height': height, 'width': width}
150+
with self.gap.name_scope():
151+
self.gap.add(nn.GlobalAvgPool2D())
152+
self.gap.add(nn.Conv2D(in_channels=in_channels, channels=out_channels,
153+
kernel_size=1, use_bias=False))
154+
self.gap.add(norm_layer(in_channels=out_channels,
155+
**({} if norm_kwargs is None else norm_kwargs)))
156+
self.gap.add(nn.Activation("relu"))
157+
158+
def hybrid_forward(self, F, x):
159+
pool = self.gap(x)
160+
return F.contrib.BilinearResize2D(pool, **self._up_kwargs)
161+
162+
def demo(self, x):
163+
h, w = x.shape[2:]
164+
self._up_kwargs['height'] = h
165+
self._up_kwargs['width'] = w
166+
pool = self.gap(x)
167+
import mxnet.ndarray as F
168+
return F.contrib.BilinearResize2D(pool, **self._up_kwargs)
169+
170+
class _ASPP(nn.HybridBlock):
171+
def __init__(self, in_channels, atrous_rates, norm_layer, norm_kwargs,
172+
height=60, width=60):
173+
super(_ASPP, self).__init__()
174+
out_channels = 256
175+
self.b0 = nn.HybridSequential()
176+
self.b0.add(nn.Conv2D(in_channels=in_channels, channels=out_channels,
177+
kernel_size=1, use_bias=False))
178+
self.b0.add(norm_layer(in_channels=out_channels,
179+
**({} if norm_kwargs is None else norm_kwargs)))
180+
self.b0.add(nn.Activation("relu"))
181+
182+
rate1, rate2, rate3 = tuple(atrous_rates)
183+
self.b1 = _ASPPConv(in_channels, out_channels, rate1, norm_layer, norm_kwargs)
184+
self.b2 = _ASPPConv(in_channels, out_channels, rate2, norm_layer, norm_kwargs)
185+
self.b3 = _ASPPConv(in_channels, out_channels, rate3, norm_layer, norm_kwargs)
186+
self.b4 = _AsppPooling(in_channels, out_channels, norm_layer=norm_layer,
187+
norm_kwargs=norm_kwargs, height=height, width=width)
188+
189+
self.project = nn.HybridSequential(prefix='bot_aspp_')
190+
self.project.add(nn.Conv2D(in_channels=5*out_channels, channels=out_channels,
191+
kernel_size=1, use_bias=False))
192+
193+
def hybrid_forward(self, F, x):
194+
feat1 = self.b0(x)
195+
feat2 = self.b1(x)
196+
feat3 = self.b2(x)
197+
feat4 = self.b3(x)
198+
x = self.b4(x)
199+
x = F.concat(x, feat1, feat2, feat3, feat4, dim=1)
200+
return self.project(x)
201+
202+
def demo(self, x):
203+
feat1 = self.b0(x)
204+
feat2 = self.b1(x)
205+
feat3 = self.b2(x)
206+
feat4 = self.b3(x)
207+
x = self.b4.demo(x)
208+
import mxnet.ndarray as F
209+
x = F.concat(x, feat1, feat2, feat3, feat4, dim=1)
210+
return self.project(x)
211+
212+
def get_deeplabv3b_plus(dataset='citys', backbone='wideresnet', pretrained=False,
213+
root='~/.mxnet/models', ctx=cpu(0), **kwargs):
214+
r"""DeepLabWV3Plus
215+
Parameters
216+
----------
217+
dataset : str, default pascal_voc
218+
The dataset that model pretrained on. (pascal_voc, ade20k, citys)
219+
pretrained : bool or str
220+
Boolean value controls whether to load the default pretrained weights for model.
221+
String value represents the hashtag for a certain version of pretrained weights.
222+
ctx : Context, default CPU
223+
The context in which to load the pretrained weights.
224+
root : str, default '~/.mxnet/models'
225+
Location for keeping the model parameters.
226+
227+
Examples
228+
--------
229+
>>> model = get_deeplabv3b_plus(dataset='citys', backbone='wideresnet', pretrained=False)
230+
>>> print(model)
231+
"""
232+
acronyms = {
233+
'pascal_voc': 'voc',
234+
'pascal_aug': 'voc',
235+
'ade20k': 'ade',
236+
'coco': 'coco',
237+
'citys': 'citys',
238+
}
239+
from ..data import datasets
240+
# infer number of classes
241+
model = DeepLabWV3Plus(datasets[dataset].NUM_CLASS, backbone=backbone, ctx=ctx, **kwargs)
242+
model.classes = datasets[dataset].classes
243+
if pretrained:
244+
from .model_store import get_model_file
245+
model.load_parameters(get_model_file('deeplab_v3b_plus_%s_%s'%(backbone, acronyms[dataset]),
246+
tag=pretrained, root=root), ctx=ctx)
247+
return model
248+
249+
def get_deeplab_v3b_plus_wideresnet_citys(**kwargs):
250+
r"""DeepLabWV3Plus
251+
Parameters
252+
----------
253+
pretrained : bool or str
254+
Boolean value controls whether to load the default pretrained weights for model.
255+
String value represents the hashtag for a certain version of pretrained weights.
256+
ctx : Context, default CPU
257+
The context in which to load the pretrained weights.
258+
root : str, default '~/.mxnet/models'
259+
Location for keeping the model parameters.
260+
261+
Examples
262+
--------
263+
>>> model = get_deeplab_v3b_plus_wideresnet_citys(pretrained=True)
264+
>>> print(model)
265+
"""
266+
return get_deeplabv3b_plus('citys', 'wideresnet', **kwargs)

gluoncv/model_zoo/model_store.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@
107107
('3f220f537400dfa607c3d041ed3b172db39b0b01', 'psp_resnet50_ade'),
108108
('240a4758b506447faf7c55cd7a7837d66f5039a6', 'psp_resnet101_ade'),
109109
('0f49fb59180c4d91305b858380a4fd6eaf068b6c', 'psp_resnet101_citys'),
110+
('ef2bb40ad8f8f59f451969b2fabe4e548394e80a', 'deeplab_v3b_plus_wideresnet_citys'),
110111
('f5ece5ce1422eeca3ce2908004e469ffdf91fd41', 'yolo3_darknet53_voc'),
111112
('3b47835ac3dd80f29576633949aa58aee3094353', 'yolo3_mobilenet1.0_voc'),
112113
('66dbbae67be8f1e3cd3c995ce626a2bdc89769c6', 'yolo3_mobilenet1.0_coco'),

gluoncv/model_zoo/model_zoo.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from .cifarresnext import *
88
from .cifarwideresnet import *
99
from .deeplabv3 import *
10+
from .deeplabv3b_plus import *
1011
from .densenet import *
1112
from .faster_rcnn import *
1213
from .fcn import *
@@ -146,6 +147,7 @@
146147
'deeplab_resnet152_voc': get_deeplab_resnet152_voc,
147148
'deeplab_resnet50_ade': get_deeplab_resnet50_ade,
148149
'deeplab_resnet101_ade': get_deeplab_resnet101_ade,
150+
'deeplab_v3b_plus_wideresnet_citys': get_deeplab_v3b_plus_wideresnet_citys,
149151
'resnet18_v1b': resnet18_v1b,
150152
'resnet34_v1b': resnet34_v1b,
151153
'resnet50_v1b': resnet50_v1b,

gluoncv/model_zoo/pspnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def hybrid_forward(self, F, x):
6262
return tuple(outputs)
6363

6464
def demo(self, x):
65-
self.predict(x)
65+
return self.predict(x)
6666

6767
def predict(self, x):
6868
h, w = x.shape[2:]

gluoncv/model_zoo/segbase.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,13 @@ def get_segmentation_model(model, **kwargs):
1717
from .pspnet import get_psp
1818
from .deeplabv3 import get_deeplab
1919
from .deeplabv3_plus import get_deeplab_plus
20+
from .deeplabv3b_plus import get_deeplabv3b_plus
2021
models = {
2122
'fcn': get_fcn,
2223
'psp': get_psp,
2324
'deeplab': get_deeplab,
2425
'deeplabplus': get_deeplab_plus,
26+
'deeplabplusv3b': get_deeplabv3b_plus,
2527
}
2628
return models[model](**kwargs)
2729

0 commit comments

Comments
 (0)