@@ -60,20 +60,21 @@ function efficientnet(config::Symbol; norm_layer = BatchNorm, stochastic_depth_p
60
60
end
61
61
62
62
"""
63
- EfficientNet(config::Symbol; pretrain::Bool = false, inchannels::Integer = 3,
63
+ EfficientNet(config::Symbol; pretrain::Union{ Bool,String} = false, inchannels::Integer = 3,
64
64
nclasses::Integer = 1000)
65
65
66
66
Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)).
67
67
68
68
# Arguments
69
69
70
70
- `config`: size of the model. Can be one of `[:b0, :b1, :b2, :b3, :b4, :b5, :b6, :b7, :b8]`.
71
- - `pretrain`: set to `true` to load the pre-trained weights for ImageNet
71
+ - `pretrain`: set to `true` to load the pre-trained weights for ImageNet, or provide a local path string to load a
72
+ custom weights file.
72
73
- `inchannels`: number of input channels.
73
74
- `nclasses`: number of output classes.
74
75
75
76
!!! warning
76
-
77
+
77
78
EfficientNet does not currently support pretrained weights.
78
79
79
80
See also [`Metalhead.efficientnet`](@ref).
@@ -83,12 +84,16 @@ struct EfficientNet
83
84
end
84
85
@functor EfficientNet
85
86
86
- function EfficientNet (config:: Symbol ; pretrain:: Bool = false , inchannels:: Integer = 3 ,
87
+ function EfficientNet (config:: Symbol ; pretrain:: Union{ Bool,String} = false , inchannels:: Integer = 3 ,
87
88
nclasses:: Integer = 1000 )
88
89
layers = efficientnet (config; inchannels, nclasses)
89
90
model = EfficientNet (layers)
90
- if pretrain
91
+ if pretrain === true
91
92
loadpretrain! (model, string (" efficientnet_" , config))
93
+ elseif pretrain isa String
94
+ isfile (pretrain) || error (" Weights file does not exist at `$pretrain `" )
95
+ m = load_weights_file (pretrain)
96
+ Flux. loadmodel! (model, m)
92
97
end
93
98
return model
94
99
end
0 commit comments