Commit 62fcac3

Some refactors, some consistency, some features

1 parent 34caab4 commit 62fcac3

16 files changed: +209 -119 lines changed

.github/workflows/CI.yml (+1 -2)

@@ -28,8 +28,7 @@ jobs:
         suite:
           - '["AlexNet", "VGG"]'
           - '["GoogLeNet", "SqueezeNet", "MobileNet"]'
-          - '"EfficientNet"'
-          - '"EfficientNetv2"'
+          - '"EfficientNet"'
           - 'r"/*/ResNet*"'
           - '[r"ResNeXt", r"SEResNet"]'
           - '[r"Res2Net", r"Res2NeXt"]'

src/convnets/convmixer.jl (+10 -5)

@@ -17,16 +17,21 @@ Creates a ConvMixer model.
   - `nclasses`: number of classes in the output
 """
 function convmixer(planes::Integer, depth::Integer; kernel_size::Dims{2} = (9, 9),
-                   patch_size::Dims{2} = (7, 7), activation = gelu,
+                   patch_size::Dims{2} = (7, 7), activation = gelu, dropout_rate = nothing,
                    inchannels::Integer = 3, nclasses::Integer = 1000)
-    stem = conv_norm(patch_size, inchannels, planes, activation; preact = true,
-                     stride = patch_size[1])
-    blocks = [Chain(SkipConnection(Chain(conv_norm(kernel_size, planes, planes, activation;
+    layers = []
+    # stem of the model
+    append!(layers,
+            conv_norm(patch_size, inchannels, planes, activation; preact = true,
+                      stride = patch_size[1]))
+    # stages of the model
+    stages = [Chain(SkipConnection(Chain(conv_norm(kernel_size, planes, planes, activation;
                                                    preact = true, groups = planes,
                                                    pad = SamePad())), +),
                     conv_norm((1, 1), planes, planes, activation; preact = true)...)
               for _ in 1:depth]
-    return Chain(Chain(stem..., Chain(blocks...)), create_classifier(planes, nclasses))
+    append!(layers, stages)
+    return Chain(Chain(layers...), create_classifier(planes, nclasses; dropout_rate))
 end

 const CONVMIXER_CONFIGS = Dict(:base => ((1536, 20),
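
Note: the new `dropout_rate` keyword defaults to `nothing`, which (following the convention introduced throughout this commit) omits dropout from the classifier head. A hypothetical direct call to this internal builder, with `(1536, 20)` taken from `CONVMIXER_CONFIGS[:base]` above:

    # Sketch only: `convmixer` is an internal builder, not the exported constructor.
    model = convmixer(1536, 20; kernel_size = (9, 9), patch_size = (7, 7),
                      dropout_rate = 0.1, nclasses = 1000)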

src/convnets/densenet.jl (+9 -6)

@@ -55,7 +55,8 @@ function dense_block(inplanes::Integer, growth_rates)
 end

 """
-    densenet(inplanes, growth_rates; reduction = 0.5, nclasses::Integer = 1000)
+    densenet(inplanes, growth_rates; reduction = 0.5, dropout_rate = nothing,
+             inchannels::Integer = 3, nclasses::Integer = 1000)

 Create a DenseNet model
 ([reference](https://arxiv.org/abs/1608.06993)).
@@ -66,10 +67,11 @@ Create a DenseNet model
   - `growth_rates`: the growth rates of output feature maps within each
     [`dense_block`](#) (a vector of vectors)
   - `reduction`: the factor by which the number of feature maps is scaled across each transition
+  - `dropout_rate`: the dropout rate for the classifier head. Set to `nothing` to disable dropout.
   - `nclasses`: the number of output classes
 """
-function densenet(inplanes::Integer, growth_rates; reduction = 0.5, inchannels::Integer = 3,
-                  nclasses::Integer = 1000)
+function densenet(inplanes::Integer, growth_rates; reduction = 0.5, dropout_rate = nothing,
+                  inchannels::Integer = 3, nclasses::Integer = 1000)
     layers = []
     append!(layers,
             conv_norm((7, 7), inchannels, inplanes; stride = 2, pad = (3, 3)))
@@ -83,7 +85,7 @@ function densenet(inplanes::Integer, growth_rates; reduction = 0.5, inchannels::
         inplanes = floor(Int, outplanes * reduction)
     end
     push!(layers, BatchNorm(outplanes, relu))
-    return Chain(Chain(layers...), create_classifier(outplanes, nclasses))
+    return Chain(Chain(layers...), create_classifier(outplanes, nclasses; dropout_rate))
 end

 """
@@ -100,9 +102,10 @@ Create a DenseNet model
   - `nclasses`: the number of output classes
 """
 function densenet(nblocks::AbstractVector{<:Integer}; growth_rate::Integer = 32,
-                  reduction = 0.5, inchannels::Integer = 3, nclasses::Integer = 1000)
+                  reduction = 0.5, dropout_rate = nothing, inchannels::Integer = 3,
+                  nclasses::Integer = 1000)
     return densenet(2 * growth_rate, [fill(growth_rate, n) for n in nblocks];
-                    reduction, inchannels, nclasses)
+                    reduction, dropout_rate, inchannels, nclasses)
 end

 const DENSENET_CONFIGS = Dict(121 => [6, 12, 24, 16],
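
The vector-of-blocks method now forwards `dropout_rate` down to the core builder. A hypothetical call against the updated signature, using the DenseNet-121 block counts from `DENSENET_CONFIGS`:

    # Sketch only: internal builder, not the exported `DenseNet` constructor.
    model = densenet([6, 12, 24, 16]; growth_rate = 32, reduction = 0.5,
                     dropout_rate = 0.2, nclasses = 1000)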

src/convnets/efficientnets/core.jl (+5 -5)

@@ -4,13 +4,13 @@ struct MBConvConfig <: _MBConfig
     kernel_size::Dims{2}
     inplanes::Integer
     outplanes::Integer
-    expansion::Number
+    expansion::Real
     stride::Integer
     nrepeats::Integer
 end
 function MBConvConfig(kernel_size::Integer, inplanes::Integer, outplanes::Integer,
-                      expansion::Number, stride::Integer, nrepeats::Integer,
-                      width_mult::Number = 1, depth_mult::Number = 1)
+                      expansion::Real, stride::Integer, nrepeats::Integer,
+                      width_mult::Real = 1, depth_mult::Real = 1)
     inplanes = _round_channels(inplanes * width_mult, 8)
     outplanes = _round_channels(outplanes * width_mult, 8)
     nrepeats = ceil(Int, nrepeats * depth_mult)
@@ -35,12 +35,12 @@ struct FusedMBConvConfig <: _MBConfig
     kernel_size::Dims{2}
     inplanes::Integer
     outplanes::Integer
-    expansion::Number
+    expansion::Real
     stride::Integer
     nrepeats::Integer
 end
 function FusedMBConvConfig(kernel_size::Integer, inplanes::Integer, outplanes::Integer,
-                           expansion::Number, stride::Integer, nrepeats::Integer)
+                           expansion::Real, stride::Integer, nrepeats::Integer)
     return FusedMBConvConfig((kernel_size, kernel_size), inplanes, outplanes, expansion,
                              stride, nrepeats)
 end
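
Tightening `Number` to `Real` rules out non-real numeric types that would make the channel arithmetic meaningless:

    julia> 2.5 isa Real
    true

    julia> (1 + 2im) isa Number   # accepted by the old `expansion::Number`
    true

    julia> (1 + 2im) isa Real     # rejected by the new `expansion::Real`
    false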

src/convnets/inceptions/inceptionresnetv2.jl (+3 -3)

@@ -64,18 +64,18 @@ function block8(scale = 1.0f0; activation = identity)
 end

 """
-    inceptionresnetv2(; inchannels::Integer = 3, dropout_rate = 0.0, nclasses::Integer = 1000)
+    inceptionresnetv2(; inchannels::Integer = 3, dropout_rate = nothing, nclasses::Integer = 1000)

 Creates an InceptionResNetv2 model.
 ([reference](https://arxiv.org/abs/1602.07261))

 # Arguments

   - `inchannels`: number of input channels.
-  - `dropout_rate`: rate of dropout in classifier head.
+  - `dropout_rate`: rate of dropout in classifier head. Set to `nothing` to disable dropout.
   - `nclasses`: the number of output classes.
 """
-function inceptionresnetv2(; dropout_rate = 0.0, inchannels::Integer = 3,
+function inceptionresnetv2(; dropout_rate = nothing, inchannels::Integer = 3,
                            nclasses::Integer = 1000)
     backbone = Chain(basic_conv_bn((3, 3), inchannels, 32; stride = 2)...,
                      basic_conv_bn((3, 3), 32, 32)...,

src/convnets/inceptions/inceptionv4.jl (+3 -3)

@@ -85,18 +85,18 @@ function inceptionv4_c()
 end

 """
-    inceptionv4(; inchannels::Integer = 3, dropout_rate = 0.0, nclasses::Integer = 1000)
+    inceptionv4(; inchannels::Integer = 3, dropout_rate = nothing, nclasses::Integer = 1000)

 Create an Inceptionv4 model.
 ([reference](https://arxiv.org/abs/1602.07261))

 # Arguments

   - `inchannels`: number of input channels.
-  - `dropout_rate`: rate of dropout in classifier head.
+  - `dropout_rate`: rate of dropout in classifier head. Set to `nothing` to disable dropout.
   - `nclasses`: the number of output classes.
 """
-function inceptionv4(; dropout_rate = 0.0, inchannels::Integer = 3,
+function inceptionv4(; dropout_rate = nothing, inchannels::Integer = 3,
                      nclasses::Integer = 1000)
     backbone = Chain(basic_conv_bn((3, 3), inchannels, 32; stride = 2)...,
                      basic_conv_bn((3, 3), 32, 32)...,

src/convnets/inceptions/xception.jl (+2 -2)

@@ -43,14 +43,14 @@ function xception_block(inchannels::Integer, outchannels::Integer, nrepeats::Int
 end

 """
-    xception(; dropout_rate = 0.0, inchannels::Integer = 3, nclasses::Integer = 1000)
+    xception(; dropout_rate = nothing, inchannels::Integer = 3, nclasses::Integer = 1000)

 Creates an Xception model.
 ([reference](https://arxiv.org/abs/1610.02357))

 # Arguments

-  - `dropout_rate`: rate of dropout in classifier head.
+  - `dropout_rate`: rate of dropout in classifier head. Set to `nothing` to disable dropout.
   - `inchannels`: number of input channels.
   - `nclasses`: the number of output classes.
 """

src/convnets/mobilenets/mobilenetv1.jl (+6 -3)

@@ -1,5 +1,6 @@
 """
-    mobilenetv1(width_mult::Real, config::AbstractVector{<:Tuple}; activation = relu,
+    mobilenetv1(width_mult::Real, config::AbstractVector{<:Tuple};
+                activation = relu, dropout_rate = nothing,
                 inchannels::Integer = 3, nclasses::Integer = 1000)

 Create a MobileNetv1 model ([reference](https://arxiv.org/abs/1704.04861v1)).
@@ -16,10 +17,12 @@ Create a MobileNetv1 model ([reference](https://arxiv.org/abs/1704.04861v1)).
       + `s`: The stride of the convolutional kernel
       + `r`: The number of time this configuration block is repeated
   - `activate`: The activation function to use throughout the network
+  - `dropout_rate`: The dropout rate to use in the classifier head. Set to `nothing` to disable.
   - `inchannels`: The number of input channels. The default value is 3.
   - `nclasses`: The number of output classes
 """
-function mobilenetv1(width_mult::Real, config::AbstractVector{<:Tuple}; activation = relu,
+function mobilenetv1(width_mult::Real, config::AbstractVector{<:Tuple};
+                     activation = relu, dropout_rate = nothing,
                      inchannels::Integer = 3, nclasses::Integer = 1000)
     layers = []
     for (dw, outchannels, stride, nrepeats) in config
@@ -33,7 +36,7 @@ function mobilenetv1(width_mult::Real, config::AbstractVector{<:Tuple}; activati
             inchannels = outchannels
         end
     end
-    return Chain(Chain(layers...), create_classifier(inchannels, nclasses))
+    return Chain(Chain(layers...), create_classifier(inchannels, nclasses; dropout_rate))
 end

 # Layer configurations for MobileNetv1

src/convnets/mobilenets/mobilenetv2.jl (+5 -4)

@@ -20,7 +20,7 @@ Create a MobileNetv2 model.
     (with 1 being the default in the paper)
   - `max_width`: The maximum number of feature maps in any layer of the network
   - `divisor`: The divisor used to round the number of feature maps in each block
-  - `dropout_rate`: rate of dropout in the classifier head
+  - `dropout_rate`: rate of dropout in the classifier head. Set to `nothing` to disable dropout.
   - `inchannels`: The number of input channels.
   - `nclasses`: The number of output classes
 """
@@ -33,12 +33,13 @@ function mobilenetv2(configs::AbstractVector{<:Tuple}; width_mult::Real = 1,
     append!(layers,
             conv_norm((3, 3), inchannels, inplanes; pad = 1, stride = 2))
     # building inverted residual blocks
-    for (t, c, n, s, a) in configs
+    for (t, c, n, s, activation) in configs
         outplanes = _round_channels(c * width_mult, divisor)
         for i in 1:n
+            stride = i == 1 ? s : 1
             push!(layers,
-                  mbconv((3, 3), inplanes, round(Int, inplanes * t), outplanes, a;
-                         stride = i == 1 ? s : 1))
+                  mbconv((3, 3), inplanes, round(Int, inplanes * t), outplanes,
+                         activation; stride))
             inplanes = outplanes
         end
     end
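
The `; stride))` form above relies on Julia's implicit keyword-argument shorthand: when a local variable shares the keyword's name, `f(x; stride)` means `f(x; stride = stride)`. A standalone illustration with Flux's `Conv` (not the code above):

    using Flux
    stride = 2
    layer = Conv((3, 3), 16 => 32; stride)   # identical to `stride = stride`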

src/convnets/mobilenets/mobilenetv3.jl (+24 -13)

@@ -1,7 +1,7 @@
 """
     mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1,
-                max_width::Integer = 1024, inchannels::Integer = 3,
-                nclasses::Integer = 1000)
+                max_width::Integer = 1024, dropout_rate = 0.2,
+                inchannels::Integer = 3, nclasses::Integer = 1000)

 Create a MobileNetv3 model.
 ([reference](https://arxiv.org/abs/1905.02244)).
@@ -19,38 +19,49 @@ Create a MobileNetv3 model.

   - `width_mult`: Controls the number of output feature maps in each block
     (with 1 being the default in the paper; this is usually a value between 0.1 and 1.4.)
-  - `inchannels`: The number of input channels.
   - `max_width`: The maximum number of feature maps in any layer of the network
+  - `dropout_rate`: The dropout rate to use in the classifier head. Set to `nothing` to disable.
+  - `inchannels`: The number of input channels.
   - `nclasses`: the number of output classes
 """
 function mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1,
-                     max_width::Integer = 1024, dropout_rate = 0.2,
+                     max_width::Integer = 1024, reduced_tail::Bool = false,
+                     tail_dilated::Bool = false, dropout_rate = 0.2,
                      inchannels::Integer = 3, nclasses::Integer = 1000)
     # building first layer
     inplanes = _round_channels(16 * width_mult, 8)
     layers = []
     append!(layers,
             conv_norm((3, 3), inchannels, inplanes, hardswish; stride = 2, pad = 1))
     explanes = 0
+    nstages = length(configs)
+    reduced_divider = 1
     # building inverted residual blocks
-    for (k, t, c, reduction, activation, stride) in configs
+    for (i, (k, t, c, reduction, activation, stride)) in enumerate(configs)
+        dilation = 1
+        if nstages - i <= 2
+            if reduced_tail
+                reduced_divider = 2
+                c /= reduced_divider
+            end
+            if tail_dilated
+                dilation = 2
+            end
+        end
         # inverted residual layers
         outplanes = _round_channels(c * width_mult, 8)
         explanes = _round_channels(inplanes * t, 8)
         push!(layers,
               mbconv((k, k), inplanes, explanes, outplanes, activation;
-                     stride, reduction))
+                     stride, reduction, dilation))
         inplanes = outplanes
     end
     # building last layers
-    headplanes = width_mult > 1.0 ? _round_channels(max_width * width_mult, 8) :
-                 max_width
+    headplanes = _round_channels(max_width ÷ reduced_divider * width_mult, 8)
     append!(layers, conv_norm((1, 1), inplanes, explanes, hardswish))
-    classifier = Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten,
-                       Dense(explanes, headplanes, hardswish),
-                       Dropout(dropout_rate),
-                       Dense(headplanes, nclasses))
-    return Chain(Chain(layers...), classifier)
+    return Chain(Chain(layers...),
+                 create_classifier(explanes, headplanes, nclasses,
                                   (hardswish, identity); dropout_rate))
 end

 # Layer configurations for small and large models for MobileNetv3
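
With the new `reduced_tail` option, the last two stages halve their channel counts and the head shrinks to match via `reduced_divider`. A worked sketch of the new `headplanes` expression, assuming `_round_channels(x, 8)` rounds to a multiple of 8 (the real helper lives in Metalhead's utilities and differs in detail):

    # Hypothetical stand-in for `_round_channels`:
    round_channels(x, d) = max(d, (floor(Int, x) + d ÷ 2) ÷ d * d)

    max_width, width_mult = 1024, 1.0
    round_channels(max_width ÷ 1 * width_mult, 8)   # 1024 (reduced_tail = false)
    round_channels(max_width ÷ 2 * width_mult, 8)   # 512  (reduced_tail = true)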

src/convnets/resnets/core.jl (+9 -10)

@@ -27,12 +27,11 @@ function basicblock(inplanes::Integer, planes::Integer; stride::Integer = 1,
                     drop_block = identity, drop_path = identity,
                     attn_fn = planes -> identity)
     first_planes = planes ÷ reduction_factor
-    outplanes = planes
     conv_bn1 = conv_norm((3, 3), inplanes => first_planes, identity; norm_layer, revnorm,
                          stride, pad = 1)
-    conv_bn2 = conv_norm((3, 3), first_planes => outplanes, identity; norm_layer, revnorm,
+    conv_bn2 = conv_norm((3, 3), first_planes => planes, identity; norm_layer, revnorm,
                          pad = 1)
-    layers = [conv_bn1..., drop_block, activation, conv_bn2..., attn_fn(outplanes),
+    layers = [conv_bn1..., drop_block, activation, conv_bn2..., attn_fn(planes),
               drop_path]
     return Chain(filter!(!=(identity), layers)...)
 end
@@ -201,7 +200,7 @@ function basicblock_builder(block_repeats::AbstractVector{<:Integer};
                             expansion::Integer = 1, norm_layer = BatchNorm,
                             revnorm::Bool = false, activation = relu,
                             attn_fn = planes -> identity,
-                            drop_block_rate = 0.0, drop_path_rate = 0.0,
+                            drop_block_rate = nothing, drop_path_rate = nothing,
                             stride_fn = resnet_stride, planes_fn = resnet_planes,
                             downsample_tuple = (downsample_conv, downsample_identity))
     pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats))
@@ -236,7 +235,7 @@ function bottleneck_builder(block_repeats::AbstractVector{<:Integer};
                             expansion::Integer = 4, norm_layer = BatchNorm,
                             revnorm::Bool = false, activation = relu,
                             attn_fn = planes -> identity,
-                            drop_block_rate = 0.0, drop_path_rate = 0.0,
+                            drop_block_rate = nothing, drop_path_rate = nothing,
                             stride_fn = resnet_stride, planes_fn = resnet_planes,
                             downsample_tuple = (downsample_conv, downsample_identity))
     pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats))
@@ -295,8 +294,8 @@ function resnet(block_type, block_repeats::AbstractVector{<:Integer},
                 inchannels::Integer = 3, stem_fn = resnet_stem, connection = addact,
                 activation = relu, norm_layer = BatchNorm, revnorm::Bool = false,
                 attn_fn = planes -> identity, pool_layer = AdaptiveMeanPool((1, 1)),
-                use_conv::Bool = false, drop_block_rate = 0.0, drop_path_rate = 0.0,
-                dropout_rate = 0.0, nclasses::Integer = 1000, kwargs...)
+                use_conv::Bool = false, drop_block_rate = nothing, drop_path_rate = nothing,
+                dropout_rate = nothing, nclasses::Integer = 1000, kwargs...)
     # Build stem
     stem = stem_fn(; inchannels)
     # Block builder
@@ -319,8 +318,8 @@ function resnet(block_type, block_repeats::AbstractVector{<:Integer},
                                         downsample_tuple = downsample_opt,
                                         kwargs...)
     elseif block_type == bottle2neck
-        @assert drop_block_rate==0.0 "DropBlock not supported for `bottle2neck`. Set `drop_block_rate` to 0.0"
-        @assert drop_path_rate==0.0 "DropPath not supported for `bottle2neck`. Set `drop_path_rate` to 0.0"
+        @assert isnothing(drop_block_rate) "DropBlock not supported for `bottle2neck`. Set `drop_block_rate` to nothing"
+        @assert isnothing(drop_path_rate) "DropPath not supported for `bottle2neck`. Set `drop_path_rate` to nothing"
         @assert reduction_factor==1 "Reduction factor not supported for `bottle2neck`. Set `reduction_factor` to 1"
         get_layers = bottle2neck_builder(block_repeats; inplanes, cardinality, base_width,
                                          activation, norm_layer, revnorm, attn_fn,
@@ -347,7 +346,7 @@ const RESNET_CONFIGS = Dict(18 => (basicblock, [2, 2, 2, 2]),
                             50 => (bottleneck, [3, 4, 6, 3]),
                             101 => (bottleneck, [3, 4, 23, 3]),
                             152 => (bottleneck, [3, 8, 36, 3]))
-
+# larger ResNet-like models
 const LRESNET_CONFIGS = Dict(50 => (bottleneck, [3, 4, 6, 3]),
                              101 => (bottleneck, [3, 4, 23, 3]),
                              152 => (bottleneck, [3, 8, 36, 3]))
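
The assertions change in lockstep with the defaults: had they kept comparing against `0.0`, the new `nothing` default would trip them, since in Julia `nothing` never compares equal to a number:

    julia> nothing == 0.0
    false

    julia> isnothing(nothing)
    true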

src/layers/Layers.jl (+4 -1)

@@ -28,7 +28,10 @@ include("embeddings.jl")
 export PatchEmbedding, ViPosEmbedding, ClassTokens

 include("mlp.jl")
-export mlp_block, gated_mlp_block, create_fc, create_classifier
+export mlp_block, gated_mlp_block
+
+include("classifier.jl")
+export create_classifier

 include("normalise.jl")
 export prenorm, ChannelLayerNorm
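
`create_classifier` moves out of `mlp.jl` into a dedicated `classifier.jl`, and `create_fc` drops out of the exports. Judging from the mobilenetv3 hunk above, its two-activation form replaces a hand-rolled head; a rough sketch of what such a head looks like, reconstructed from the code it replaces and not the actual implementation:

    using Flux, MLUtils

    # Hypothetical reconstruction from the removed lines in mobilenetv3.jl:
    function classifier_sketch(inplanes, hidplanes, nclasses, (act1, act2);
                               dropout_rate = nothing)
        drop = isnothing(dropout_rate) ? identity : Dropout(dropout_rate)
        return Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten,
                     Dense(inplanes, hidplanes, act1), drop,
                     Dense(hidplanes, nclasses, act2))
    end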
