Skip to content

Commit f2461a5

Browse files
committed
Moving closer to the one true function
1 parent fc03d70 commit f2461a5

File tree

5 files changed

+90
-87
lines changed

5 files changed

+90
-87
lines changed

src/convnets/efficientnets/core.jl

+28-24
Original file line numberDiff line numberDiff line change
@@ -1,62 +1,66 @@
1-
function mbconv_builder(block_configs::AbstractVector{NTuple{6, Int}},
2-
stage_idx::Integer; scalings::NTuple{2, Real} = (1, 1),
3-
norm_layer = BatchNorm)
1+
function mbconv_builder(block_configs::AbstractVector{<:Tuple},
2+
inplanes::Integer, stage_idx::Integer;
3+
scalings::NTuple{2, Real} = (1, 1), norm_layer = BatchNorm,
4+
round_fn = planes -> _round_channels(planes, 8))
45
width_mult, depth_mult = scalings
5-
k, inplanes, outplanes, expansion, stride, nrepeats = block_configs[stage_idx]
6-
inplanes = _round_channels(inplanes * width_mult, 8)
6+
k, outplanes, expansion, stride, nrepeats, reduction, activation = block_configs[stage_idx]
7+
inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1][2]
8+
inplanes = round_fn(inplanes * width_mult)
79
outplanes = _round_channels(outplanes * width_mult, 8)
810
function get_layers(block_idx)
911
inplanes = block_idx == 1 ? inplanes : outplanes
1012
explanes = _round_channels(inplanes * expansion, 8)
1113
stride = block_idx == 1 ? stride : 1
12-
block = mbconv((k, k), inplanes, explanes, outplanes, swish; norm_layer,
13-
stride, reduction = 4)
14+
block = mbconv((k, k), inplanes, explanes, outplanes, activation; norm_layer,
15+
stride, reduction)
1416
return stride == 1 && inplanes == outplanes ? (identity, block) : (block,)
1517
end
1618
return get_layers, ceil(Int, nrepeats * depth_mult)
1719
end
1820

19-
function fused_mbconv_builder(block_configs::AbstractVector{NTuple{6, Int}},
20-
stage_idx::Integer; scalings::NTuple{2, Real} = (1, 1),
21-
norm_layer = BatchNorm)
22-
k, inplanes, outplanes, expansion, stride, nrepeats = block_configs[stage_idx]
21+
function fused_mbconv_builder(block_configs::AbstractVector{<:Tuple},
22+
inplanes::Integer, stage_idx::Integer;
23+
scalings::NTuple{2, Real} = (1, 1), norm_layer = BatchNorm)
24+
k, outplanes, expansion, stride, nrepeats, _, activation = block_configs[stage_idx]
25+
inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1][2]
2326
function get_layers(block_idx)
2427
inplanes = block_idx == 1 ? inplanes : outplanes
2528
explanes = _round_channels(inplanes * expansion, 8)
2629
stride = block_idx == 1 ? stride : 1
27-
block = fused_mbconv((k, k), inplanes, explanes, outplanes, swish;
30+
block = fused_mbconv((k, k), inplanes, explanes, outplanes, activation;
2831
norm_layer, stride)
2932
return stride == 1 && inplanes == outplanes ? (identity, block) : (block,)
3033
end
3134
return get_layers, nrepeats
3235
end
3336

34-
function efficientnet_builder(block_configs::AbstractVector{NTuple{6, Int}},
35-
residual_fns::AbstractVector;
36-
scalings::NTuple{2, Real} = (1, 1), norm_layer = BatchNorm)
37-
bxs = [residual_fn(block_configs, stage_idx; scalings, norm_layer)
37+
function mbconv_stack_builder(block_configs::AbstractVector{<:Tuple},
38+
residual_fns::AbstractVector; inplanes::Integer,
39+
scalings::NTuple{2, Real} = (1, 1),
40+
norm_layer = BatchNorm)
41+
bxs = [residual_fn(block_configs, inplanes, stage_idx; scalings, norm_layer)
3842
for (stage_idx, residual_fn) in enumerate(residual_fns)]
3943
return (stage_idx, block_idx) -> first.(bxs)[stage_idx](block_idx), last.(bxs)
4044
end
4145

42-
function efficientnet(block_configs::AbstractVector{NTuple{6, Int}},
43-
residual_fns::AbstractVector; scalings::NTuple{2, Real} = (1, 1),
46+
function efficientnet(block_configs::AbstractVector{<:Tuple},
47+
residual_fns::AbstractVector; inplanes::Integer,
48+
scalings::NTuple{2, Real} = (1, 1),
4449
headplanes::Integer = block_configs[end][3] * 4,
4550
norm_layer = BatchNorm, dropout_rate = nothing,
4651
inchannels::Integer = 3, nclasses::Integer = 1000)
4752
layers = []
4853
# stem of the model
4954
append!(layers,
50-
conv_norm((3, 3), inchannels,
51-
_round_channels(block_configs[1][2] * scalings[1], 8), swish;
52-
norm_layer, stride = 2, pad = SamePad()))
55+
conv_norm((3, 3), inchannels, _round_channels(inplanes * scalings[1], 8),
56+
swish; norm_layer, stride = 2, pad = SamePad()))
5357
# building inverted residual blocks
54-
get_layers, block_repeats = efficientnet_builder(block_configs, residual_fns;
55-
scalings, norm_layer)
58+
get_layers, block_repeats = mbconv_stack_builder(block_configs, residual_fns;
59+
inplanes, scalings, norm_layer)
5660
append!(layers, resnet_stages(get_layers, block_repeats, +))
5761
# building last layers
5862
append!(layers,
59-
conv_norm((1, 1), _round_channels(block_configs[end][3] * scalings[1], 8),
63+
conv_norm((1, 1), _round_channels(block_configs[end][2] * scalings[1], 8),
6064
headplanes, swish; pad = SamePad()))
6165
return Chain(Chain(layers...), create_classifier(headplanes, nclasses; dropout_rate))
6266
end

src/convnets/efficientnets/efficientnet.jl

+9-9
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
# block configs for EfficientNet
22
const EFFICIENTNET_BLOCK_CONFIGS = [
3-
# k, i, o, e, s, n
4-
(3, 32, 16, 1, 1, 1),
5-
(3, 16, 24, 6, 2, 2),
6-
(5, 24, 40, 6, 2, 2),
7-
(3, 40, 80, 6, 2, 3),
8-
(5, 80, 112, 6, 1, 3),
9-
(5, 112, 192, 6, 2, 4),
10-
(3, 192, 320, 6, 1, 1),
3+
# k: kernel size, c: output planes, e: expansion, s: stride, n: repeats, r: reduction, a: activation
4+
(3, 16, 1, 1, 1, 4, swish),
5+
(3, 24, 6, 2, 2, 4, swish),
6+
(5, 40, 6, 2, 2, 4, swish),
7+
(3, 80, 6, 2, 3, 4, swish),
8+
(5, 112, 6, 1, 3, 4, swish),
9+
(5, 192, 6, 2, 4, 4, swish),
10+
(3, 320, 6, 1, 1, 4, swish),
1111
]
1212
# Data is organised as (r, (w, d))
1313
# r: image resolution
@@ -46,7 +46,7 @@ function EfficientNet(config::Symbol; pretrain::Bool = false, inchannels::Intege
4646
scalings = EFFICIENTNET_GLOBAL_CONFIGS[config][2]
4747
layers = efficientnet(EFFICIENTNET_BLOCK_CONFIGS,
4848
fill(mbconv_builder, length(EFFICIENTNET_BLOCK_CONFIGS));
49-
scalings, inchannels, nclasses)
49+
inplanes = 32, scalings, inchannels, nclasses)
5050
if pretrain
5151
loadpretrain!(layers, string("efficientnet-", config))
5252
end

src/convnets/efficientnets/efficientnetv2.jl

+31-30
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,36 @@
11
# block configs for EfficientNetv2
2-
# data organised as (k, i, o, e, s, n)
2+
# data organised as (k, c, e, s, n, r, a): kernel size, output planes, expansion, stride, repeats, reduction, activation
33
const EFFNETV2_CONFIGS = Dict(:small => [
4-
(3, 24, 24, 1, 1, 2),
5-
(3, 24, 48, 4, 2, 4),
6-
(3, 48, 64, 4, 2, 4),
7-
(3, 64, 128, 4, 2, 6),
8-
(3, 128, 160, 6, 1, 9),
9-
(3, 160, 256, 6, 2, 15)],
4+
(3, 24, 1, 1, 2, nothing, swish),
5+
(3, 48, 4, 2, 4, nothing, swish),
6+
(3, 64, 4, 2, 4, nothing, swish),
7+
(3, 128, 4, 2, 6, 4, swish),
8+
(3, 160, 6, 1, 9, 4, swish),
9+
(3, 256, 6, 2, 15, 4, swish)],
1010
:medium => [
11-
(3, 24, 24, 1, 1, 3),
12-
(3, 24, 48, 4, 2, 5),
13-
(3, 48, 80, 4, 2, 5),
14-
(3, 80, 160, 4, 2, 7),
15-
(3, 160, 176, 6, 1, 14),
16-
(3, 176, 304, 6, 2, 18),
17-
(3, 304, 512, 6, 1, 5)],
11+
(3, 24, 1, 1, 3, nothing, swish),
12+
(3, 48, 4, 2, 5, nothing, swish),
13+
(3, 80, 4, 2, 5, nothing, swish),
14+
(3, 160, 4, 2, 7, 4, swish),
15+
(3, 176, 6, 1, 14, 4, swish),
16+
(3, 304, 6, 2, 18, 4, swish),
17+
(3, 512, 6, 1, 5, 4, swish)],
1818
:large => [
19-
(3, 32, 32, 1, 1, 4),
20-
(3, 32, 64, 4, 2, 7),
21-
(3, 64, 96, 4, 2, 7),
22-
(3, 96, 192, 4, 2, 10),
23-
(3, 192, 224, 6, 1, 19),
24-
(3, 224, 384, 6, 2, 25),
25-
(3, 384, 640, 6, 1, 7)],
19+
(3, 32, 1, 1, 4, nothing, swish),
20+
(3, 64, 4, 2, 7, nothing, swish),
21+
(3, 96, 4, 2, 7, nothing, swish),
22+
(3, 192, 4, 2, 10, 4, swish),
23+
(3, 224, 6, 1, 19, 4, swish),
24+
(3, 384, 6, 2, 25, 4, swish),
25+
(3, 640, 6, 1, 7, 4, swish)],
2626
:xlarge => [
27-
(3, 32, 32, 1, 1, 4),
28-
(3, 32, 64, 4, 2, 8),
29-
(3, 64, 96, 4, 2, 8),
30-
(3, 96, 192, 4, 2, 16),
31-
(3, 192, 224, 6, 1, 24),
32-
(3, 384, 512, 6, 2, 32),
33-
(3, 512, 768, 6, 1, 8)])
27+
(3, 32, 1, 1, 4, nothing, swish),
28+
(3, 64, 4, 2, 8, nothing, swish),
29+
(3, 96, 4, 2, 8, nothing, swish),
30+
(3, 192, 4, 2, 16, 4, swish),
31+
(3, 384, 6, 1, 24, 4, swish),
32+
(3, 512, 6, 2, 32, 4, swish),
33+
(3, 768, 6, 1, 8, 4, swish)])
3434

3535
"""
3636
EfficientNetv2(config::Symbol; pretrain::Bool = false, width_mult::Real = 1,
@@ -58,9 +58,10 @@ function EfficientNetv2(config::Symbol; pretrain::Bool = false,
5858
layers = efficientnet(EFFNETV2_CONFIGS[config],
5959
vcat(fill(fused_mbconv_builder, 3),
6060
fill(mbconv_builder, length(EFFNETV2_CONFIGS[config]) - 3));
61-
headplanes = 1280, inchannels, nclasses)
61+
inplanes = EFFNETV2_CONFIGS[config][1][2], headplanes = 1280,
62+
inchannels, nclasses)
6263
if pretrain
63-
loadpretrain!(layers, string("efficientnetv2"))
64+
loadpretrain!(layers, string("efficientnetv2-", config))
6465
end
6566
return EfficientNetv2(layers)
6667
end

src/convnets/mobilenets/mobilenetv2.jl

+20-23
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ Create a MobileNetv2 model.
1414
+ `c`: The number of output feature maps
1515
+ `n`: The number of times a block is repeated
1616
+ `s`: The stride of the convolutional kernel
17-
+ `a`: The activation function used in the bottleneck layer
1817
1918
- `width_mult`: Controls the number of output feature maps in each block
2019
(with 1 being the default in the paper)
@@ -24,41 +23,39 @@ Create a MobileNetv2 model.
2423
- `inchannels`: The number of input channels.
2524
- `nclasses`: The number of output classes
2625
"""
27-
function mobilenetv2(configs::AbstractVector{<:Tuple}; width_mult::Real = 1,
28-
max_width::Integer = 1280, divisor::Integer = 8, dropout_rate = 0.2,
26+
function mobilenetv2(block_configs::AbstractVector{<:Tuple}; width_mult::Real = 1,
27+
max_width::Integer = 1280, divisor::Integer = 8,
28+
inplanes::Integer = 32, dropout_rate = 0.2,
2929
inchannels::Integer = 3, nclasses::Integer = 1000)
3030
# building first layer
31-
inplanes = _round_channels(32 * width_mult, divisor)
31+
inplanes = _round_channels(inplanes * width_mult, divisor)
3232
layers = []
3333
append!(layers,
3434
conv_norm((3, 3), inchannels, inplanes; pad = 1, stride = 2))
3535
# building inverted residual blocks
36-
for (t, c, n, s, activation) in configs
37-
outplanes = _round_channels(c * width_mult, divisor)
38-
for i in 1:n
39-
stride = i == 1 ? s : 1
40-
push!(layers,
41-
mbconv((3, 3), inplanes, round(Int, inplanes * t), outplanes,
42-
activation; stride))
43-
inplanes = outplanes
44-
end
45-
end
36+
get_layers, block_repeats = mbconv_stack_builder(block_configs,
37+
fill(mbconv_builder,
38+
length(block_configs));
39+
inplanes)
40+
append!(layers, resnet_stages(get_layers, block_repeats, +))
4641
# building last layers
4742
outplanes = _round_channels(max_width * max(1, width_mult), divisor)
48-
append!(layers, conv_norm((1, 1), inplanes, outplanes, relu6))
43+
append!(layers,
44+
conv_norm((1, 1), _round_channels(block_configs[end][2], 8),
45+
outplanes, relu6))
4946
return Chain(Chain(layers...), create_classifier(outplanes, nclasses; dropout_rate))
5047
end
5148

5249
# Layer configurations for MobileNetv2
5350
const MOBILENETV2_CONFIGS = [
54-
# t, c, n, s, a
55-
(1, 16, 1, 1, relu6),
56-
(6, 24, 2, 2, relu6),
57-
(6, 32, 3, 2, relu6),
58-
(6, 64, 4, 2, relu6),
59-
(6, 96, 3, 1, relu6),
60-
(6, 160, 3, 2, relu6),
61-
(6, 320, 1, 1, relu6),
51+
# k: kernel size, c: output planes, e: expansion, s: stride, n: repeats, r: reduction, a: activation
52+
(3, 16, 1, 1, 1, nothing, relu6),
53+
(3, 24, 6, 2, 2, nothing, relu6),
54+
(3, 32, 6, 2, 3, nothing, relu6),
55+
(3, 64, 6, 2, 4, nothing, relu6),
56+
(3, 96, 6, 1, 3, nothing, relu6),
57+
(3, 160, 6, 2, 3, nothing, relu6),
58+
(3, 320, 6, 1, 1, nothing, relu6),
6259
]
6360

6461
"""

src/layers/mbconv.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ function fused_mbconv(kernel_size::Dims{2}, inplanes::Integer,
5959
append!(layers, conv_norm((1, 1), explanes, outplanes, identity; norm_layer))
6060
else
6161
append!(layers,
62-
conv_norm((1, 1), inplanes, outplanes, activation; norm_layer, stride))
62+
conv_norm(kernel_size, inplanes, outplanes, activation; pad = SamePad(),
63+
norm_layer, stride))
6364
end
6465
return Chain(layers...)
6566
end

0 commit comments

Comments
 (0)