Skip to content
This repository was archived by the owner on Mar 17, 2021. It is now read-only.

Commit 92a8e3e

Browse files
committed
updates squeeze excitation layers
- added doc strings for squeeze excitation layers - renaming layer/squeeze_excitation_layer to layer/squeeze_excitation
1 parent cfbb9b7 commit 92a8e3e

File tree

3 files changed

+32
-9
lines changed

3 files changed

+32
-9
lines changed

niftynet/layer/squeeze_excitation_layer.py renamed to niftynet/layer/squeeze_excitation.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@
1212

1313

1414
class ChannelSELayer(Layer):
15+
"""
16+
Re-implementation of Squeeze-and-Excitation (SE) block described in::
17+
18+
Hu et al., Squeeze-and-Excitation Networks, arXiv:1709.01507
19+
"""
1520
def __init__(self,
1621
func='AVG',
1722
reduction_ratio=16,
@@ -62,8 +67,17 @@ def layer_op(self, input_tensor):
6267
output_tensor = tf.multiply(input_tensor, fc_out_2)
6368

6469
return output_tensor
65-
70+
71+
6672
class SpatialSELayer(Layer):
73+
"""
74+
Re-implementation of SE block -- squeezing spatially
75+
and exciting channel-wise described in::
76+
77+
Roy et al., Concurrent Spatial and Channel Squeeze & Excitation
78+
in Fully Convolutional Networks, arXiv:1803.02579
79+
80+
"""
6781
def __init__(self,
6882
name='spatial_squeeze_excitation'):
6983
super(SpatialSELayer, self).__init__(name=name)
@@ -75,15 +89,24 @@ def layer_op(self, input_tensor):
7589
with_bn=False,
7690
acti_func='sigmoid',
7791
name="se_conv")
78-
92+
7993
squeeze_tensor = conv(input_tensor)
80-
94+
8195
# spatial excitation
8296
output_tensor = tf.multiply(input_tensor, squeeze_tensor)
8397

8498
return output_tensor
85-
99+
100+
86101
class ChannelSpatialSELayer(Layer):
102+
"""
103+
Re-implementation of concurrent spatial and channel
104+
squeeze & excitation::
105+
106+
Roy et al., Concurrent Spatial and Channel Squeeze & Excitation
107+
in Fully Convolutional Networks, arXiv:1803.02579
108+
109+
"""
87110
def __init__(self,
88111
func='AVG',
89112
reduction_ratio=16,
@@ -99,7 +122,7 @@ def layer_op(self, input_tensor):
99122
reduction_ratio=self.reduction_ratio,
100123
name='cSE')
101124
sSE = SpatialSELayer(name='sSE')
102-
125+
103126
output_tensor = tf.add(cSE(input_tensor), sSE(input_tensor))
104127

105128
return output_tensor

niftynet/network/se_resnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from niftynet.layer.fully_connected import FCLayer
1111
from niftynet.layer.base_layer import TrainableLayer
1212
from niftynet.layer.convolution import ConvolutionalLayer
13-
from niftynet.layer.squeeze_excitation_layer import ChannelSELayer
13+
from niftynet.layer.squeeze_excitation import ChannelSELayer
1414
from niftynet.network.base_net import BaseNet
1515

1616
SE_ResNetDesc = namedtuple('SE_ResNetDesc', ['bn', 'fc', 'conv1', 'blocks'])

tests/squeeze_excitation_layer_test.py renamed to tests/squeeze_excitation_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
from __future__ import absolute_import, print_function
33

44
import tensorflow as tf
5-
from niftynet.layer.squeeze_excitation_layer import ChannelSELayer
6-
from niftynet.layer.squeeze_excitation_layer import SpatialSELayer
7-
from niftynet.layer.squeeze_excitation_layer import ChannelSpatialSELayer
5+
from niftynet.layer.squeeze_excitation import ChannelSELayer
6+
from niftynet.layer.squeeze_excitation import SpatialSELayer
7+
from niftynet.layer.squeeze_excitation import ChannelSpatialSELayer
88

99
class SETest(tf.test.TestCase):
1010
def test_cSE_3d_shape(self):

0 commit comments

Comments
 (0)