Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 82 additions & 34 deletions Artifacts.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,75 @@
[resnet101]
git-tree-sha1 = "68d563526ab34d3e5aa66b7d96278d2acde212f9"
lazy = true

[[resnet101.download]]
sha256 = "0725f05db5772cfab1024b8d0d6c85ac1fc5a83eb6f0fe02b67b1e689d5a28db"
url = "https://huggingface.co/FluxML/resnet101/resolve/980158099e6917b74ade2b0a9599359f06057d21/resnet101.tar.gz"

[resnet152]
git-tree-sha1 = "85a97464b6cef66e1217ae829d3620651cffab47"
lazy = true

[[resnet152.download]]
sha256 = "a8d30a735ef5649ec40a74a0515ee3d6774499267be06f5f2b372259c5ced8d6"
url = "https://huggingface.co/FluxML/resnet152/resolve/a66a3e1f5056179d167cb2165401950e3890b34d/resnet152.tar.gz"

[resnet18]
git-tree-sha1 = "4ced5a0338c0f0293940f1deb63e1c463125a6ff"
lazy = true

[[resnet18.download]]
sha256 = "9444ef2285f507bd890d2ca852d663749f079110ed19b544e8d91f67f3cc6b83"
url = "https://huggingface.co/FluxML/resnet18/resolve/9b1c6c4f7c5dbe734d80d7d4b5f132ef58bf2467/resnet18.tar.gz"

[resnet34]
git-tree-sha1 = "485519977f375ca1770b3ff3971f61e438823f5a"
lazy = true

[[resnet34.download]]
sha256 = "71ed75be6db0160af7f30be33e2f4a44836310949d1374267510e5803b1fb313"
url = "https://huggingface.co/FluxML/resnet34/resolve/0988ae2d4a86da06eefa6b61edf3e728861e286c/resnet34.tar.gz"

[resnet50]
git-tree-sha1 = "2973be0da60544080105756ecb3951cca2e007da"
lazy = true

[[resnet50.download]]
sha256 = "60ad32eaf160444f3bfdb6f6d81ec1e5c36a3769be7df22aaa75127b16bb1501"
url = "https://huggingface.co/FluxML/resnet50/resolve/1529d6ddca42e3e705cb708c9de6f79188ce8ad5/resnet50.tar.gz"

[resnext101_32x8d]
git-tree-sha1 = "d13a85131b2c0c62ef2af79a09137d4e0760a685"
lazy = true

[[resnext101_32x8d.download]]
sha256 = "aeb48f86f50ee8b0ca7dc01ca0ff5a2d2b2163e43c524203c4a8bd589db9bcc6"
url = "https://huggingface.co/FluxML/resnext101_32x8d/resolve/e060f030c445f644112efa2a00e3c544944046e1/resnext101_32x8d.tar.gz"

[resnext101_64x4d]
git-tree-sha1 = "db50f48614e673a40f98fb80a17688b34f42067a"
lazy = true

[[resnext101_64x4d.download]]
sha256 = "89764dd7dc3b3432f0424cb592cec5d9db2fb802ab1646f0e3c2cca2b2e5386b"
url = "https://huggingface.co/FluxML/resnext101_64x4d/resolve/0d0485da04efe5a53289a560d105c42d3ca5435c/resnext101_64x4d.tar.gz"

[resnext50_32x4d]
git-tree-sha1 = "1e7a08a4acae690b635e8d1caa06e75eeb2dd2fe"
lazy = true

[[resnext50_32x4d.download]]
sha256 = "084ccbc40fde07496c401ee2bc389b9cd1d60b1ac3b7ccbfde05479ea91ca707"
url = "https://huggingface.co/FluxML/resnext50_32x4d/resolve/150c52c9646fe697030d38ab2be767564fb4f28c/resnext50_32x4d.tar.gz"

[squeezenet]
git-tree-sha1 = "e2eeee109fda46470d657b13669cca09d5ef2f8c"
lazy = true

[[squeezenet.download]]
sha256 = "aebfa06f44767e5ff728b7b67b2d01352b4618bd5d305c26e603aabcd5ba593d"
url = "https://huggingface.co/FluxML/squeezenet/resolve/01ef4221df5260bd992c669ab587eed74df0c39f/squeezenet.tar.gz"

[vgg11]
git-tree-sha1 = "78ffe7d74c475cc28175f9e23a545ce2f17b1520"
lazy = true
Expand Down Expand Up @@ -30,42 +102,18 @@ lazy = true
sha256 = "5fe26391572b9f6ac84eaa0541d27e959f673f82e6515026cdcd3262cbd93ceb"
url = "https://huggingface.co/FluxML/vgg19/resolve/88e9056f60b054eccdc190a2eeb23731d5c693b6/vgg19.tar.gz"

[resnet18]
git-tree-sha1 = "7b555ed2708e551bfdbcb7e71b25001f4b3731c6"
lazy = true

[[resnet18.download]]
sha256 = "d5782fd873a3072df251c7a4b3cf16efca8ee1da1180ff815bc107833f84bb26"
url = "https://huggingface.co/FluxML/resnet18/resolve/ef9c74047fda4a4a503b1f72553ec05acc90929f/resnet18.tar.gz"

[resnet34]
git-tree-sha1 = "e6e79666cd0fc81cd828508314e6c7f66df8d43d"
lazy = true

[[resnet34.download]]
sha256 = "a8dec13609a86f7a2adac6a44b3af912a863bc2d7319120066c5fdaa04c3f395"
url = "https://huggingface.co/FluxML/resnet34/resolve/42061ddb463902885eea4fcc85275462a5445987/resnet34.tar.gz"

[resnet50]
git-tree-sha1 = "5c442ffd6c51a70c3bc36d849fca86beced446d4"
lazy = true

[[resnet50.download]]
sha256 = "5325920ec91c2a4499ad7e659961f9eaac2b1a3a2905ca6410eaa593ecd35503"
url = "https://huggingface.co/FluxML/resnet50/resolve/10e601719e1cd5b0cab87ce7fd1e8f69a07ce042/resnet50.tar.gz"

[resnet101]
git-tree-sha1 = "694a8563ec20fb826334dd663d532b10bb2b3c97"
[wideresnet101]
git-tree-sha1 = "b881a9469fb230faff414ce9f983bc113061ab1c"
lazy = true

[[resnet101.download]]
sha256 = "f4d737ce640957c30f76bfa642fc9da23e6852d81474d58a2338c1148e55bff0"
url = "https://huggingface.co/FluxML/resnet101/resolve/ea37819163cc3f4a41989a6239ce505e483b112d/resnet101.tar.gz"
[[wideresnet101.download]]
sha256 = "defa61fd80a988bb07bb9db00c692d8d0a30d95e6276add1413fb7f1f3aa2607"
url = "https://huggingface.co/FluxML/wideresnet101/resolve/ad4df1016bb5eba4c10e2d37b049ca5d2a455670/wideresnet101.tar.gz"

[resnet152]
git-tree-sha1 = "55eb883248a276d710d75ecaecfbd2427e50cc0a"
[wideresnet50]
git-tree-sha1 = "bbc6bc632e743c992784b5121dcb0f6082c66b1f"
lazy = true

[[resnet152.download]]
sha256 = "57be335e6828d1965c9d11f933d2d41f51e5e534f9bfdbde01c6144fa8862a4d"
url = "https://huggingface.co/FluxML/resnet152/resolve/ba28814d5746643387b5c0e1d2269104e5e9bc8d/resnet152.tar.gz"
[[wideresnet50.download]]
sha256 = "7596a67b7aba762c2bfce8367055da79a8a3c117bd79ce11124c9f1f5a96c4e3"
url = "https://huggingface.co/FluxML/wideresnet50/resolve/5eca9979f5d9438a684b1e6a5b227f9d5611965a/wideresnet50.tar.gz"
9 changes: 4 additions & 5 deletions src/convnets/densenet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,7 @@ Create a DenseNet model
- `nclasses`: the number of output classes
"""
function densenet(nblocks::AbstractVector{<:Integer}; growth_rate::Integer = 32,
reduction = 0.5,
inchannels::Integer = 3, nclasses::Integer = 1000)
reduction = 0.5, inchannels::Integer = 3, nclasses::Integer = 1000)
return densenet(2 * growth_rate, [fill(growth_rate, n) for n in nblocks];
reduction, inchannels, nclasses)
end
Expand Down Expand Up @@ -133,11 +132,11 @@ end
function DenseNet(config::Integer; pretrain::Bool = false, growth_rate::Integer = 32,
reduction = 0.5, inchannels::Integer = 3, nclasses::Integer = 1000)
_checkconfig(config, keys(DENSENET_CONFIGS))
model = densenet(DENSENET_CONFIGS[config]; growth_rate, reduction, inchannels, nclasses)
layers = densenet(DENSENET_CONFIGS[config]; growth_rate, reduction, inchannels, nclasses)
if pretrain
loadpretrain!(model, string("DenseNet", config))
loadpretrain!(layers, string("densenet", config))
end
return model
return DenseNet(layers)
end

(m::DenseNet)(x) = m.layers(x)
Expand Down
5 changes: 1 addition & 4 deletions src/convnets/inception/inceptionv3.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
## Inceptionv3

"""
inceptionv3_a(inplanes, pool_proj)

Expand Down Expand Up @@ -176,6 +174,7 @@ See also [`inceptionv3`](#).
struct Inceptionv3
layers::Any
end
@functor Inceptionv3

function Inceptionv3(; pretrain::Bool = false, inchannels::Integer = 3,
nclasses::Integer = 1000)
Expand All @@ -186,8 +185,6 @@ function Inceptionv3(; pretrain::Bool = false, inchannels::Integer = 3,
return Inceptionv3(layers)
end

@functor Inceptionv3

(m::Inceptionv3)(x) = m.layers(x)

backbone(m::Inceptionv3) = m.layers[1]
Expand Down
8 changes: 3 additions & 5 deletions src/convnets/resnets/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -280,15 +280,13 @@ function resnet_stages(get_layers, block_repeats::AbstractVector{<:Integer}, con
end

function resnet(img_dims, stem, get_layers, block_repeats::AbstractVector{<:Integer},
connection,
classifier_fn)
connection, classifier_fn)
# Build stages of the ResNet
stage_blocks = resnet_stages(get_layers, block_repeats, connection)
backbone = Chain(stem, stage_blocks)
# Build the classifier head
# Add classifier to the backbone
nfeaturemaps = Flux.outputsize(backbone, img_dims; padbatch = true)[3]
classifier = classifier_fn(nfeaturemaps)
return Chain(backbone, classifier)
return Chain(backbone, classifier_fn(nfeaturemaps))
end

function resnet(block_type::Symbol, block_repeats::AbstractVector{<:Integer};
Expand Down
10 changes: 1 addition & 9 deletions src/convnets/resnets/resnet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,6 @@ Creates a ResNet model with the specified depth.
- `inchannels`: The number of input channels.
- `nclasses`: the number of output classes

!!! warning

`ResNet` does not currently support pretrained weights.

Advanced users who want more configuration options will be better served by using [`resnet`](#).
"""
struct ResNet
Expand All @@ -27,7 +23,7 @@ function ResNet(depth::Integer; pretrain::Bool = false, inchannels::Integer = 3,
_checkconfig(depth, keys(RESNET_CONFIGS))
layers = resnet(RESNET_CONFIGS[depth]...; inchannels, nclasses)
if pretrain
loadpretrain!(layers, string("ResNet", depth))
loadpretrain!(layers, string("resnet", depth))
end
return ResNet(layers)
end
Expand All @@ -52,10 +48,6 @@ The number of channels in outer 1x1 convolutions is the same.
- `inchannels`: The number of input channels.
- `nclasses`: the number of output classes

!!! warning

`WideResNet` does not currently support pretrained weights.

Advanced users who want more configuration options will be better served by using [`resnet`](#).
"""
struct WideResNet
Expand Down
20 changes: 10 additions & 10 deletions src/convnets/resnets/resnext.jl
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
"""
ResNeXt(depth::Integer; pretrain::Bool = false, cardinality = 32,
base_width = 4, inchannels::Integer = 3, nclasses::Integer = 1000)
ResNeXt(depth::Integer; pretrain::Bool = false, cardinality::Integer = 32,
base_width::Integer = 4, inchannels::Integer = 3, nclasses::Integer = 1000)

Creates a ResNeXt model with the specified depth, cardinality, and base width.
([reference](https://arxiv.org/abs/1611.05431))

# Arguments

- `depth`: one of `[50, 101, 152]`. The depth of the ResNeXt model.
- `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet
- `pretrain`: set to `true` to load the model with pre-trained weights for ImageNet.
Supported configurations are:
- depth 50, cardinality of 32 and base width of 4.
- depth 101, cardinality of 32 and base width of 8.
- depth 101, cardinality of 64 and base width of 4.
- `cardinality`: the number of groups to be used in the 3x3 convolution in each block.
- `base_width`: the number of feature maps in each group.
- `inchannels`: the number of input channels.
- `nclasses`: the number of output classes

!!! warning

`ResNeXt` does not currently support pretrained weights.

Advanced users who want more configuration options will be better served by using [`resnet`](#).
"""
struct ResNeXt
Expand All @@ -27,12 +27,12 @@ end

(m::ResNeXt)(x) = m.layers(x)

function ResNeXt(depth::Integer; pretrain::Bool = false, cardinality = 32,
base_width = 4, inchannels::Integer = 3, nclasses::Integer = 1000)
function ResNeXt(depth::Integer; pretrain::Bool = false, cardinality::Integer = 32,
base_width::Integer = 4, inchannels::Integer = 3, nclasses::Integer = 1000)
_checkconfig(depth, sort(collect(keys(RESNET_CONFIGS)))[3:end])
layers = resnet(RESNET_CONFIGS[depth]...; inchannels, nclasses, cardinality, base_width)
if pretrain
loadpretrain!(layers, string("ResNeXt", depth, "_", cardinality, "x", base_width))
loadpretrain!(layers, string("resnext", depth, "_", cardinality, "x", base_width, "d"))
end
return ResNeXt(layers)
end
Expand Down
12 changes: 6 additions & 6 deletions src/convnets/resnets/seresnet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ function SEResNet(depth::Integer; pretrain::Bool = false, inchannels::Integer =
layers = resnet(RESNET_CONFIGS[depth]...; inchannels, nclasses,
attn_fn = squeeze_excite)
if pretrain
loadpretrain!(layers, string("SEResNet", depth))
loadpretrain!(layers, string("seresnet", depth))
end
return SEResNet(layers)
end
Expand All @@ -39,8 +39,8 @@ backbone(m::SEResNet) = m.layers[1]
classifier(m::SEResNet) = m.layers[2]

"""
SEResNeXt(depth::Integer; pretrain::Bool = false, cardinality = 32, base_width = 4,
inchannels::Integer = 3, nclasses::Integer = 1000)
SEResNeXt(depth::Integer; pretrain::Bool = false, cardinality::Integer = 32,
base_width::Integer = 4, inchannels::Integer = 3, nclasses::Integer = 1000)

Creates a SEResNeXt model with the specified depth, cardinality, and base width.
([reference](https://arxiv.org/pdf/1709.01507.pdf))
Expand All @@ -67,13 +67,13 @@ end

(m::SEResNeXt)(x) = m.layers(x)

function SEResNeXt(depth::Integer; pretrain::Bool = false, cardinality = 32, base_width = 4,
inchannels::Integer = 3, nclasses::Integer = 1000)
function SEResNeXt(depth::Integer; pretrain::Bool = false, cardinality::Integer = 32,
base_width::Integer = 4, inchannels::Integer = 3, nclasses::Integer = 1000)
_checkconfig(depth, sort(collect(keys(RESNET_CONFIGS)))[3:end])
layers = resnet(RESNET_CONFIGS[depth]...; inchannels, nclasses, cardinality, base_width,
attn_fn = squeeze_excite)
if pretrain
loadpretrain!(layers, string("SEResNeXt", depth, "_", cardinality, "x", base_width))
loadpretrain!(layers, string("seresnext", depth, "_", cardinality, "x", base_width))
end
return SEResNeXt(layers)
end
Expand Down
6 changes: 1 addition & 5 deletions src/convnets/squeezenet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,6 @@ Create a SqueezeNet
- `inchannels`: number of input channels.
- `nclasses`: the number of output classes.

!!! warning

`SqueezeNet` does not currently support pretrained weights.

See also [`squeezenet`](#).
"""
struct SqueezeNet
Expand All @@ -77,7 +73,7 @@ function SqueezeNet(; pretrain::Bool = false, inchannels::Integer = 3,
nclasses::Integer = 1000)
layers = squeezenet(; inchannels, nclasses)
if pretrain
loadpretrain!(layers, "SqueezeNet")
loadpretrain!(layers, "squeezenet")
end
return SqueezeNet(layers)
end
Expand Down
15 changes: 9 additions & 6 deletions src/layers/mlp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,16 +66,19 @@ Creates a classifier head to be used for models.
function create_classifier(inplanes::Integer, nclasses::Integer, activation = identity;
use_conv::Bool = false, pool_layer = AdaptiveMeanPool((1, 1)),
dropout_rate = nothing)
# Pooling
# Decide whether to flatten the input or not
flatten_in_pool = !use_conv && pool_layer !== identity
if use_conv
@assert pool_layer === identity
"`pool_layer` must be identity if `use_conv` is true"
end
global_pool = flatten_in_pool ? [pool_layer, MLUtils.flatten] : [pool_layer]
classifier = []
flatten_in_pool ? push!(classifier, pool_layer, MLUtils.flatten) :
push!(classifier, pool_layer)
# Dropout is applied after the pooling layer
isnothing(dropout_rate) ? nothing : push!(classifier, Dropout(dropout_rate))
# Fully-connected layer
fc = use_conv ? Conv((1, 1), inplanes => nclasses, activation) :
Dense(inplanes => nclasses, activation)
drop = isnothing(dropout_rate) ? [] : [Dropout(dropout_rate)]
return Chain(global_pool..., drop..., fc)
use_conv ? push!(classifier, Conv((1, 1), inplanes => nclasses, activation)) :
push!(classifier, Dense(inplanes => nclasses, activation))
return Chain(classifier...)
end
3 changes: 1 addition & 2 deletions src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,5 @@ end

# Utility function for depth and configuration checks in models
function _checkconfig(config, configs)
@assert config in configs
return "Invalid configuration. Must be one of $(sort(collect(configs)))."
@assert config in configs "Invalid configuration. Must be one of $(sort(collect(configs)))."
end
15 changes: 7 additions & 8 deletions test/convnets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,14 @@ end

@testset "ResNet" begin
# Tests for pretrained ResNets
## TODO: find a way to port pretrained models to the new ResNet API
@testset "ResNet($sz)" for sz in [18, 34, 50, 101, 152]
m = ResNet(sz)
@test size(m(x_224)) == (1000, 1)
# if (ResNet, sz) in PRETRAINED_MODELS
# @test acctest(ResNet(sz, pretrain = true))
# else
# @test_throws ArgumentError ResNet(sz, pretrain = true)
# end
if (ResNet, sz) in PRETRAINED_MODELS
@test acctest(ResNet(sz, pretrain = true))
else
@test_throws ArgumentError ResNet(sz, pretrain = true)
end
end

@testset "resnet" begin
Expand Down Expand Up @@ -79,9 +78,9 @@ end
m = ResNeXt(depth; cardinality, base_width)
@test size(m(x_224)) == (1000, 1)
if (ResNeXt, depth, cardinality, base_width) in PRETRAINED_MODELS
@test acctest(ResNeXt(depth, pretrain = true))
@test acctest(ResNeXt(depth; cardinality, base_width, pretrain = true))
else
@test_throws ArgumentError ResNeXt(depth, pretrain = true)
@test_throws ArgumentError ResNeXt(depth; cardinality, base_width, pretrain = true)
end
@test gradtest(m, x_224)
_gc()
Expand Down
Loading