Skip to content

Commit 792076f

Browse files
authored
Merge pull request #151 from theabhirath/conv_bn
Improved time to first gradient
2 parents e88e478 + 9f5295a commit 792076f

14 files changed

+138
-135
lines changed

src/convnets/convmixer.jl

+4 -4
Original file line number | Diff line number | Diff line change
@@ -17,11 +17,11 @@ Creates a ConvMixer model.
1717
function convmixer(planes, depth; inchannels = 3, kernel_size = (9, 9),
1818
patch_size::Dims{2} = (7, 7), activation = gelu, nclasses = 1000)
1919
stem = conv_bn(patch_size, inchannels, planes, activation; preact = true, stride = patch_size[1])
20-
blocks = [Chain(SkipConnection(Chain(conv_bn(kernel_size, planes, planes, activation;
21-
preact = true, groups = planes, pad = SamePad())...), +),
22-
conv_bn((1, 1), planes, planes, activation; preact = true)...) for _ in 1:depth]
20+
blocks = [Chain(SkipConnection(conv_bn(kernel_size, planes, planes, activation;
21+
preact = true, groups = planes, pad = SamePad()), +),
22+
conv_bn((1, 1), planes, planes, activation; preact = true)) for _ in 1:depth]
2323
head = Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, Dense(planes, nclasses))
24-
return Chain(Chain(stem..., blocks...), head)
24+
return Chain(Chain(stem, Chain(blocks)), head)
2525
end
2626

2727
convmixer_config = Dict(:base => Dict(:planes => 1536, :depth => 20, :kernel_size => (9, 9),

src/convnets/convnext.jl

+2 -2
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@ Creates a single block of ConvNeXt.
1010
- `λ`: Init value for LayerScale
1111
"""
1212
function convnextblock(planes, drop_path_rate = 0., λ = 1f-6)
13-
layers = SkipConnection(Chain(DepthwiseConv((7, 7), planes => planes; pad = 3),
13+
layers = SkipConnection(Chain(DepthwiseConv((7, 7), planes => planes; pad = 3),
1414
swapdims((3, 1, 2, 4)),
1515
LayerNorm(planes; ϵ = 1f-6),
1616
mlp_block(planes, 4 * planes),
@@ -61,7 +61,7 @@ function convnext(depths, planes; inchannels = 3, drop_path_rate = 0., λ = 1f-6
6161
LayerNorm(planes[end]),
6262
Dense(planes[end], nclasses))
6363

64-
return Chain(Chain(backbone...), head)
64+
return Chain(Chain(backbone), head)
6565
end
6666

6767
# Configurations for ConvNeXt models

src/convnets/densenet.jl

+9 -9
Original file line number | Diff line number | Diff line change
@@ -11,10 +11,10 @@ Create a Densenet bottleneck layer
1111
"""
1212
function dense_bottleneck(inplanes, outplanes)
1313
inner_channels = 4 * outplanes
14-
m = Chain(conv_bn((1, 1), inplanes, inner_channels; bias = false, rev = true)...,
15-
conv_bn((3, 3), inner_channels, outplanes; pad = 1, bias = false, rev = true)...)
14+
m = Chain(conv_bn((1, 1), inplanes, inner_channels; bias = false, rev = true),
15+
conv_bn((3, 3), inner_channels, outplanes; pad = 1, bias = false, rev = true))
1616

17-
SkipConnection(m, (mx, x) -> cat(x, mx; dims = 3))
17+
SkipConnection(m, cat_channels)
1818
end
1919

2020
"""
@@ -28,8 +28,7 @@ Create a DenseNet transition sequence
2828
- `outplanes`: number of output feature maps
2929
"""
3030
transition(inplanes, outplanes) =
31-
[conv_bn((1, 1), inplanes, outplanes; bias = false, rev = true)...,
32-
MeanPool((2, 2))]
31+
Chain(conv_bn((1, 1), inplanes, outplanes; bias = false, rev = true), MeanPool((2, 2)))
3332

3433
"""
3534
dense_block(inplanes, growth_rates)
@@ -60,20 +59,21 @@ Create a DenseNet model
6059
- `nclasses`: the number of output classes
6160
"""
6261
function densenet(inplanes, growth_rates; reduction = 0.5, nclasses = 1000)
63-
layers = conv_bn((7, 7), 3, inplanes; stride = 2, pad = (3, 3), bias = false)
62+
layers = []
63+
push!(layers, conv_bn((7, 7), 3, inplanes; stride = 2, pad = (3, 3), bias = false))
6464
push!(layers, MaxPool((3, 3), stride = 2, pad = (1, 1)))
6565

6666
outplanes = 0
6767
for (i, rates) in enumerate(growth_rates)
6868
outplanes = inplanes + sum(rates)
6969
append!(layers, dense_block(inplanes, rates))
70-
(i != length(growth_rates)) &&
71-
append!(layers, transition(outplanes, floor(Int, outplanes * reduction)))
70+
(i != length(growth_rates)) &&
71+
push!(layers, transition(outplanes, floor(Int, outplanes * reduction)))
7272
inplanes = floor(Int, outplanes * reduction)
7373
end
7474
push!(layers, BatchNorm(outplanes, relu))
7575

76-
return Chain(Chain(layers...),
76+
return Chain(Chain(layers),
7777
Chain(AdaptiveMeanPool((1, 1)),
7878
MLUtils.flatten,
7979
Dense(outplanes, nclasses)))

src/convnets/googlenet.jl

+4 -4
Original file line number | Diff line number | Diff line change
@@ -15,16 +15,16 @@ Create an inception module for use in GoogLeNet
1515
"""
1616
function _inceptionblock(inplanes, out_1x1, red_3x3, out_3x3, red_5x5, out_5x5, pool_proj)
1717
branch1 = Chain(Conv((1, 1), inplanes => out_1x1))
18-
18+
1919
branch2 = Chain(Conv((1, 1), inplanes => red_3x3),
2020
Conv((3, 3), red_3x3 => out_3x3; pad = 1))
21-
21+
2222
branch3 = Chain(Conv((1, 1), inplanes => red_5x5),
2323
Conv((5, 5), red_5x5 => out_5x5; pad = 2))
24-
24+
2525
branch4 = Chain(MaxPool((3, 3), stride=1, pad = 1),
2626
Conv((1, 1), inplanes => pool_proj))
27-
27+
2828
return Parallel(cat_channels,
2929
branch1, branch2, branch3, branch4)
3030
end

src/convnets/inception.jl

+46 -46
Original file line number | Diff line number | Diff line change
@@ -9,17 +9,17 @@ Create an Inception-v3 style-A module
99
- `pool_proj`: the number of output feature maps for the pooling projection
1010
"""
1111
function inception_a(inplanes, pool_proj)
12-
branch1x1 = Chain(conv_bn((1, 1), inplanes, 64)...)
13-
14-
branch5x5 = Chain(conv_bn((1, 1), inplanes, 48)...,
15-
conv_bn((5, 5), 48, 64; pad = 2)...)
12+
branch1x1 = conv_bn((1, 1), inplanes, 64)
1613

17-
branch3x3 = Chain(conv_bn((1, 1), inplanes, 64)...,
18-
conv_bn((3, 3), 64, 96; pad = 1)...,
19-
conv_bn((3, 3), 96, 96; pad = 1)...)
14+
branch5x5 = Chain(conv_bn((1, 1), inplanes, 48),
15+
conv_bn((5, 5), 48, 64; pad = 2))
16+
17+
branch3x3 = Chain(conv_bn((1, 1), inplanes, 64),
18+
conv_bn((3, 3), 64, 96; pad = 1),
19+
conv_bn((3, 3), 96, 96; pad = 1))
2020

2121
branch_pool = Chain(MeanPool((3, 3), pad = 1, stride = 1),
22-
conv_bn((1, 1), inplanes, pool_proj)...)
22+
conv_bn((1, 1), inplanes, pool_proj))
2323

2424
return Parallel(cat_channels,
2525
branch1x1, branch5x5, branch3x3, branch_pool)
@@ -35,13 +35,13 @@ Create an Inception-v3 style-B module
3535
- `inplanes`: number of input feature maps
3636
"""
3737
function inception_b(inplanes)
38-
branch3x3_1 = Chain(conv_bn((3, 3), inplanes, 384; stride = 2)...)
38+
branch3x3_1 = conv_bn((3, 3), inplanes, 384; stride = 2)
3939

40-
branch3x3_2 = Chain(conv_bn((1, 1), inplanes, 64)...,
41-
conv_bn((3, 3), 64, 96; pad = 1)...,
42-
conv_bn((3, 3), 96, 96; stride = 2)...)
40+
branch3x3_2 = Chain(conv_bn((1, 1), inplanes, 64),
41+
conv_bn((3, 3), 64, 96; pad = 1),
42+
conv_bn((3, 3), 96, 96; stride = 2))
4343

44-
branch_pool = Chain(MaxPool((3, 3), stride = 2))
44+
branch_pool = MaxPool((3, 3), stride = 2)
4545

4646
return Parallel(cat_channels,
4747
branch3x3_1, branch3x3_2, branch_pool)
@@ -59,20 +59,20 @@ Create an Inception-v3 style-C module
5959
- `n`: the "grid size" (kernel size) for the convolution layers
6060
"""
6161
function inception_c(inplanes, inner_planes, n = 7)
62-
branch1x1 = Chain(conv_bn((1, 1), inplanes, 192)...)
62+
branch1x1 = conv_bn((1, 1), inplanes, 192)
6363

64-
branch7x7_1 = Chain(conv_bn((1, 1), inplanes, inner_planes)...,
65-
conv_bn((1, n), inner_planes, inner_planes; pad = (0, 3))...,
66-
conv_bn((n, 1), inner_planes, 192; pad = (3, 0))...)
64+
branch7x7_1 = Chain(conv_bn((1, 1), inplanes, inner_planes),
65+
conv_bn((1, n), inner_planes, inner_planes; pad = (0, 3)),
66+
conv_bn((n, 1), inner_planes, 192; pad = (3, 0)))
6767

68-
branch7x7_2 = Chain(conv_bn((1, 1), inplanes, inner_planes)...,
69-
conv_bn((n, 1), inner_planes, inner_planes; pad = (3, 0))...,
70-
conv_bn((1, n), inner_planes, inner_planes; pad = (0, 3))...,
71-
conv_bn((n, 1), inner_planes, inner_planes; pad = (3, 0))...,
72-
conv_bn((1, n), inner_planes, 192; pad = (0, 3))...)
68+
branch7x7_2 = Chain(conv_bn((1, 1), inplanes, inner_planes),
69+
conv_bn((n, 1), inner_planes, inner_planes; pad = (3, 0)),
70+
conv_bn((1, n), inner_planes, inner_planes; pad = (0, 3)),
71+
conv_bn((n, 1), inner_planes, inner_planes; pad = (3, 0)),
72+
conv_bn((1, n), inner_planes, 192; pad = (0, 3)))
7373

74-
branch_pool = Chain(MeanPool((3, 3), pad = 1, stride=1),
75-
conv_bn((1, 1), inplanes, 192)...)
74+
branch_pool = Chain(MeanPool((3, 3), pad = 1, stride=1),
75+
conv_bn((1, 1), inplanes, 192))
7676

7777
return Parallel(cat_channels,
7878
branch1x1, branch7x7_1, branch7x7_2, branch_pool)
@@ -88,15 +88,15 @@ Create an Inception-v3 style-D module
8888
- `inplanes`: number of input feature maps
8989
"""
9090
function inception_d(inplanes)
91-
branch3x3 = Chain(conv_bn((1, 1), inplanes, 192)...,
92-
conv_bn((3, 3), 192, 320; stride = 2)...)
91+
branch3x3 = Chain(conv_bn((1, 1), inplanes, 192),
92+
conv_bn((3, 3), 192, 320; stride = 2))
9393

94-
branch7x7x3 = Chain(conv_bn((1, 1), inplanes, 192)...,
95-
conv_bn((1, 7), 192, 192; pad = (0, 3))...,
96-
conv_bn((7, 1), 192, 192; pad = (3, 0))...,
97-
conv_bn((3, 3), 192, 192; stride = 2)...)
94+
branch7x7x3 = Chain(conv_bn((1, 1), inplanes, 192),
95+
conv_bn((1, 7), 192, 192; pad = (0, 3)),
96+
conv_bn((7, 1), 192, 192; pad = (3, 0)),
97+
conv_bn((3, 3), 192, 192; stride = 2))
9898

99-
branch_pool = Chain(MaxPool((3, 3), stride=2))
99+
branch_pool = MaxPool((3, 3), stride=2)
100100

101101
return Parallel(cat_channels,
102102
branch3x3, branch7x7x3, branch_pool)
@@ -112,26 +112,26 @@ Create an Inception-v3 style-E module
112112
- `inplanes`: number of input feature maps
113113
"""
114114
function inception_e(inplanes)
115-
branch1x1 = Chain(conv_bn((1, 1), inplanes, 320)...)
115+
branch1x1 = conv_bn((1, 1), inplanes, 320)
116116

117-
branch3x3_1 = Chain(conv_bn((1, 1), inplanes, 384)...)
118-
branch3x3_1a = Chain(conv_bn((1, 3), 384, 384; pad = (0, 1))...)
119-
branch3x3_1b = Chain(conv_bn((3, 1), 384, 384; pad = (1, 0))...)
117+
branch3x3_1 = conv_bn((1, 1), inplanes, 384)
118+
branch3x3_1a = conv_bn((1, 3), 384, 384; pad = (0, 1))
119+
branch3x3_1b = conv_bn((3, 1), 384, 384; pad = (1, 0))
120120

121-
branch3x3_2 = Chain(conv_bn((1, 1), inplanes, 448)...,
122-
conv_bn((3, 3), 448, 384; pad = 1)...)
123-
branch3x3_2a = Chain(conv_bn((1, 3), 384, 384; pad = (0, 1))...)
124-
branch3x3_2b = Chain(conv_bn((3, 1), 384, 384; pad = (1, 0))...)
121+
branch3x3_2 = Chain(conv_bn((1, 1), inplanes, 448),
122+
conv_bn((3, 3), 448, 384; pad = 1))
123+
branch3x3_2a = conv_bn((1, 3), 384, 384; pad = (0, 1))
124+
branch3x3_2b = conv_bn((3, 1), 384, 384; pad = (1, 0))
125125

126126
branch_pool = Chain(MeanPool((3, 3), pad = 1, stride = 1),
127-
conv_bn((1, 1), inplanes, 192)...)
127+
conv_bn((1, 1), inplanes, 192))
128128

129129
return Parallel(cat_channels,
130130
branch1x1,
131131
Chain(branch3x3_1,
132132
Parallel(cat_channels,
133133
branch3x3_1a, branch3x3_1b)),
134-
134+
135135
Chain(branch3x3_2,
136136
Parallel(cat_channels,
137137
branch3x3_2a, branch3x3_2b)),
@@ -150,12 +150,12 @@ Create an Inception-v3 model ([reference](https://arxiv.org/abs/1512.00567v3)).
150150
`inception3` does not currently support pretrained weights.
151151
"""
152152
function inception3(; nclasses = 1000)
153-
layer = Chain(Chain(conv_bn((3, 3), 3, 32; stride = 2)...,
154-
conv_bn((3, 3), 32, 32)...,
155-
conv_bn((3, 3), 32, 64; pad = 1)...,
153+
layer = Chain(Chain(conv_bn((3, 3), 3, 32; stride = 2),
154+
conv_bn((3, 3), 32, 32),
155+
conv_bn((3, 3), 32, 64; pad = 1),
156156
MaxPool((3, 3), stride = 2),
157-
conv_bn((1, 1), 64, 80)...,
158-
conv_bn((3, 3), 80, 192)...,
157+
conv_bn((1, 1), 64, 80),
158+
conv_bn((3, 3), 80, 192),
159159
MaxPool((3, 3), stride = 2),
160160
inception_a(192, 32),
161161
inception_a(256, 64),

src/convnets/mobilenet.jl

+15 -19
Original file line number | Diff line number | Diff line change
@@ -31,17 +31,15 @@ function mobilenetv1(width_mult, config;
3131
for (dw, outch, stride, repeats) in config
3232
outch = Int(outch * width_mult)
3333
for _ in 1:repeats
34-
layer = if dw
35-
depthwise_sep_conv_bn((3, 3), inchannels, outch, activation; stride = stride, pad = 1)
36-
else
37-
conv_bn((3, 3), inchannels, outch, activation; stride = stride, pad = 1)
38-
end
39-
append!(layers, layer)
34+
layer = dw ? depthwise_sep_conv_bn((3, 3), inchannels, outch, activation;
35+
stride = stride, pad = 1) :
36+
conv_bn((3, 3), inchannels, outch, activation; stride = stride, pad = 1)
37+
push!(layers, layer)
4038
inchannels = outch
4139
end
4240
end
4341

44-
return Chain(Chain(layers...),
42+
return Chain(Chain(layers),
4543
Chain(GlobalMeanPool(),
4644
MLUtils.flatten,
4745
Dense(inchannels, fcsize, activation),
@@ -120,7 +118,7 @@ function mobilenetv2(width_mult, configs; max_width = 1280, nclasses = 1000)
120118
# building first layer
121119
inplanes = _round_channels(32 * width_mult, width_mult == 0.1 ? 4 : 8)
122120
layers = []
123-
append!(layers, conv_bn((3, 3), 3, inplanes, stride = 2))
121+
push!(layers, conv_bn((3, 3), 3, inplanes, stride = 2))
124122

125123
# building inverted residual blocks
126124
for (t, c, n, s, a) in configs
@@ -136,8 +134,7 @@ function mobilenetv2(width_mult, configs; max_width = 1280, nclasses = 1000)
136134
outplanes = (width_mult > 1) ? _round_channels(max_width * width_mult, width_mult == 0.1 ? 4 : 8) :
137135
max_width
138136

139-
return Chain(Chain(layers...,
140-
conv_bn((1, 1), inplanes, outplanes, relu6, bias = false)...),
137+
return Chain(Chain(Chain(layers), conv_bn((1, 1), inplanes, outplanes, relu6, bias = false)),
141138
Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, Dense(outplanes, nclasses)))
142139
end
143140

@@ -186,7 +183,7 @@ end
186183
(m::MobileNetv2)(x) = m.layers(x)
187184

188185
backbone(m::MobileNetv2) = m.layers[1]
189-
classifier(m::MobileNetv2) = m.layers[2:end]
186+
classifier(m::MobileNetv2) = m.layers[2]
190187

191188
# MobileNetv3
192189

@@ -214,7 +211,7 @@ function mobilenetv3(width_mult, configs; max_width = 1024, nclasses = 1000)
214211
# building first layer
215212
inplanes = _round_channels(16 * width_mult, 8)
216213
layers = []
217-
append!(layers, conv_bn((3, 3), 3, inplanes, hardswish; stride = 2))
214+
push!(layers, conv_bn((3, 3), 3, inplanes, hardswish; stride = 2))
218215
explanes = 0
219216
# building inverted residual blocks
220217
for (k, t, c, r, a, s) in configs
@@ -229,13 +226,12 @@ function mobilenetv3(width_mult, configs; max_width = 1024, nclasses = 1000)
229226
# building last several layers
230227
output_channel = max_width
231228
output_channel = width_mult > 1.0 ? _round_channels(output_channel * width_mult, 8) : output_channel
232-
classifier = (Dense(explanes, output_channel, hardswish),
233-
Dropout(0.2),
234-
Dense(output_channel, nclasses))
229+
classifier = Chain(Dense(explanes, output_channel, hardswish),
230+
Dropout(0.2),
231+
Dense(output_channel, nclasses))
235232

236-
return Chain(Chain(layers...,
237-
conv_bn((1, 1), inplanes, explanes, hardswish, bias = false)...),
238-
Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, classifier...))
233+
return Chain(Chain(Chain(layers), conv_bn((1, 1), inplanes, explanes, hardswish, bias = false)),
234+
Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, classifier))
239235
end
240236

241237
# Configurations for small and large mode for MobileNetv3
@@ -310,4 +306,4 @@ end
310306
(m::MobileNetv3)(x) = m.layers(x)
311307

312308
backbone(m::MobileNetv3) = m.layers[1]
313-
classifier(m::MobileNetv3) = m.layers[2:end]
309+
classifier(m::MobileNetv3) = m.layers[2]

src/convnets/resnet.jl

+8 -8
Original file line number | Diff line number | Diff line change
@@ -12,8 +12,8 @@ Create a basic residual block
1212
"""
1313
function basicblock(inplanes, outplanes, downsample = false)
1414
stride = downsample ? 2 : 1
15-
Chain(conv_bn((3, 3), inplanes, outplanes[1]; stride = stride, pad = 1, bias = false)...,
16-
conv_bn((3, 3), outplanes[1], outplanes[2], identity; stride = 1, pad = 1, bias = false)...)
15+
Chain(conv_bn((3, 3), inplanes, outplanes[1]; stride = stride, pad = 1, bias = false),
16+
conv_bn((3, 3), outplanes[1], outplanes[2], identity; stride = 1, pad = 1, bias = false))
1717
end
1818

1919
"""
@@ -36,9 +36,9 @@ The original paper uses `stride == [2, 1, 1]` when `downsample == true` instead.
3636
"""
3737
function bottleneck(inplanes, outplanes, downsample = false;
3838
stride = [1, (downsample ? 2 : 1), 1])
39-
Chain(conv_bn((1, 1), inplanes, outplanes[1]; stride = stride[1], bias = false)...,
40-
conv_bn((3, 3), outplanes[1], outplanes[2]; stride = stride[2], pad = 1, bias = false)...,
41-
conv_bn((1, 1), outplanes[2], outplanes[3], identity; stride = stride[3], bias = false)...)
39+
Chain(conv_bn((1, 1), inplanes, outplanes[1]; stride = stride[1], bias = false),
40+
conv_bn((3, 3), outplanes[1], outplanes[2]; stride = stride[2], pad = 1, bias = false),
41+
conv_bn((1, 1), outplanes[2], outplanes[3], identity; stride = stride[3], bias = false))
4242
end
4343

4444

@@ -82,7 +82,7 @@ function resnet(block, residuals::AbstractVector{<:NTuple{2, Any}}, connection =
8282
inplanes = 64
8383
baseplanes = 64
8484
layers = []
85-
append!(layers, conv_bn((7, 7), 3, inplanes; stride = 2, pad = 3, bias = false))
85+
push!(layers, conv_bn((7, 7), 3, inplanes; stride = 2, pad = 3, bias = false))
8686
push!(layers, MaxPool((3, 3), stride = (2, 2), pad = (1, 1)))
8787
for (i, nrepeats) in enumerate(block_config)
8888
# output planes within a block
@@ -102,7 +102,7 @@ function resnet(block, residuals::AbstractVector{<:NTuple{2, Any}}, connection =
102102
baseplanes *= 2
103103
end
104104

105-
return Chain(Chain(layers...),
105+
return Chain(Chain(layers),
106106
Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, Dense(inplanes, nclasses)))
107107
end
108108

@@ -246,7 +246,7 @@ function ResNet(depth::Int = 50; pretrain = false, nclasses = 1000)
246246
model
247247
end
248248

249-
# Compat with Methalhead 0.6; remove in 0.7
249+
# Compat with Metalhead 0.6; remove in 0.7
250250
@deprecate ResNet18(; kw...) ResNet(18; kw...)
251251
@deprecate ResNet34(; kw...) ResNet(34; kw...)
252252
@deprecate ResNet50(; kw...) ResNet(50; kw...)

0 commit comments

Comments
 (0)