Skip to content

Commit bac14c0

Browse files
allow loading custom weights files for EfficientNet
1 parent 43e0e9d commit bac14c0

File tree

2 files changed

+20
-11
lines changed

2 files changed

+20
-11
lines changed

src/convnets/efficientnets/efficientnet.jl

+10-5
Original file line numberDiff line numberDiff line change
@@ -60,20 +60,21 @@ function efficientnet(config::Symbol; norm_layer = BatchNorm, stochastic_depth_p
6060
end
6161

6262
"""
63-
EfficientNet(config::Symbol; pretrain::Bool = false, inchannels::Integer = 3,
63+
EfficientNet(config::Symbol; pretrain::Union{Bool,String} = false, inchannels::Integer = 3,
6464
nclasses::Integer = 1000)
6565
6666
Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)).
6767
6868
# Arguments
6969
7070
- `config`: size of the model. Can be one of `[:b0, :b1, :b2, :b3, :b4, :b5, :b6, :b7, :b8]`.
71-
- `pretrain`: set to `true` to load the pre-trained weights for ImageNet
71+
- `pretrain`: set to `true` to load the pre-trained weights for ImageNet, or provide a local path string to load a
72+
custom weights file.
7273
- `inchannels`: number of input channels.
7374
- `nclasses`: number of output classes.
7475
7576
!!! warning
76-
77+
7778
EfficientNet does not currently support pretrained weights.
7879
7980
See also [`Metalhead.efficientnet`](@ref).
@@ -83,12 +84,16 @@ struct EfficientNet
8384
end
8485
@functor EfficientNet
8586

86-
function EfficientNet(config::Symbol; pretrain::Bool = false, inchannels::Integer = 3,
87+
function EfficientNet(config::Symbol; pretrain::Union{Bool,String} = false, inchannels::Integer = 3,
8788
nclasses::Integer = 1000)
8889
layers = efficientnet(config; inchannels, nclasses)
8990
model = EfficientNet(layers)
90-
if pretrain
91+
if pretrain === true
9192
loadpretrain!(model, string("efficientnet_", config))
93+
elseif pretrain isa String
94+
isfile(pretrain) || error("Weights file does not exist at `$pretrain`")
95+
m = load_weights_file(pretrain)
96+
Flux.loadmodel!(model, m)
9297
end
9398
return model
9499
end

src/pretrain.jl

+10-6
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ Load the pre-trained weights for `model` using the stored artifacts.
55
"""
66
function loadweights(artifact_name)
77
artifact_dir = try
8-
@artifact_str(artifact_name)
8+
@artifact_str(artifact_name)
99
catch e
1010
throw(ArgumentError("No pre-trained weights available for $artifact_name."))
1111
end
@@ -23,15 +23,19 @@ function loadweights(artifact_name)
2323
end
2424

2525
file_path = joinpath(artifact_dir, file_name)
26-
27-
if endswith(file_name, ".bson")
26+
27+
return load_weights_file(file_path)
28+
end
29+
30+
function load_weights_file(file_path::String)
31+
if endswith(file_path, ".bson")
2832
artifact = BSON.load(file_path, @__MODULE__)
2933
if haskey(artifact, :model_state)
3034
return artifact[:model_state]
3135
elseif haskey(artifact, :model)
3236
return artifact[:model]
3337
else
34-
throw(ErrorException("Found weight artifact for $artifact_name but the weights are not saved under the key :model_state or :model."))
38+
throw(ErrorException("Weights in the file `$file_path` are not saved under the key :model_state or :model."))
3539
end
3640
elseif endswith(file_path, ".jld2")
3741
artifact = JLD2.load(file_path)
@@ -40,10 +44,10 @@ function loadweights(artifact_name)
4044
elseif haskey(artifact, "model")
4145
return artifact["model"]
4246
else
43-
throw(ErrorException("Found weight artifact for $artifact_name but the weights are not saved under the key \"model_state\" or \"model\"."))
47+
throw(ErrorException("Weights in the file `$file_path` are not saved under the key \"model_state\" or \"model\"."))
4448
end
4549
else
46-
throw(ErrorException("Found weight artifact for $artifact_name but only jld2 and bson serialization format are supported."))
50+
throw(ErrorException("Only jld2 and bson serialization format are supported for weights files."))
4751
end
4852
end
4953

0 commit comments

Comments
 (0)