# Internal function, used only for layers defined in this file.
_isactive(m, x) = isnothing(m.active) ? NNlib.within_gradient(x) : m.active

+ # Internal function, used only in this file.
+ _tidy_active(mode::Bool) = mode
+ _tidy_active(::Nothing) = nothing
+ _tidy_active(mode) = mode === :auto ? nothing : throw(ArgumentError("active = $(repr(mode)) is not accepted, must be true/false/nothing or :auto"))
+
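# A minimal sketch of how the helper above is expected to behave, assuming only
# the three methods added in this hunk (illustrative, not part of the source):
#
#   _tidy_active(true)      # → true
#   _tidy_active(nothing)   # → nothing
#   _tidy_active(:auto)     # → nothing
#   _tidy_active("on")      # → throws ArgumentError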
49"""
5- Dropout(p; [dims, rng])
10+ Dropout(p; [dims, rng, active ])
611
712Layer implementing [dropout](https://arxiv.org/abs/1207.0580) with the given probability.
813This is used as a regularisation, i.e. to reduce overfitting.
@@ -12,7 +17,8 @@ or else scales it by `1 / (1 - p)`, using the [`NNlib.dropout`](@ref) function.
While testing, it has no effect.

By default the mode will switch automatically, but it can also
- be controlled manually via [`Flux.testmode!`](@ref).
+ be controlled manually via [`Flux.testmode!`](@ref),
+ or by passing keyword `active=true` for training mode.

By default every input is treated independently. With the `dims` keyword,
instead it takes a random choice only along that dimension.
@@ -36,7 +42,11 @@ julia> m(ones(2, 7)) # test mode, no effect
 2.0  2.0  2.0  2.0  2.0  2.0  2.0
 2.0  2.0  2.0  2.0  2.0  2.0  2.0

- julia> Flux.trainmode!(m);  # equivalent to use within gradient
+ julia> Flux.trainmode!(m)  # equivalent to use within gradient
+ Chain(
+   Dense(2 => 3),  # 9 parameters
+   Dropout(0.4, active=true),
+ )

julia> m(ones(2, 7))
3×7 Matrix{Float64}:
@@ -63,9 +73,9 @@ mutable struct Dropout{F<:Real,D,R<:AbstractRNG}
end
Dropout(p::Real, dims, active) = Dropout(p, dims, active, default_rng_value())

- function Dropout(p::Real; dims=:, rng = default_rng_value())
+ function Dropout(p::Real; dims=:, active::Union{Bool,Nothing}=nothing, rng = default_rng_value())
  0 ≤ p ≤ 1 || throw(ArgumentError("Dropout expects 0 ≤ p ≤ 1, got p = $p"))
-   Dropout(p, dims, nothing, rng)
+   Dropout(p, dims, active, rng)
end

@functor Dropout
@@ -74,16 +84,17 @@ trainable(a::Dropout) = (;)
(a::Dropout)(x) = dropout(a.rng, x, a.p * _isactive(a, x); dims=a.dims)

testmode!(m::Dropout, mode=true) =
-   (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
+   (m.active = isnothing(_tidy_active(mode)) ? nothing : !mode; m)

function Base.show(io::IO, d::Dropout)
  print(io, "Dropout(", d.p)
-   d.dims != (:) && print(io, ", dims = $(repr(d.dims))")
+   d.dims != (:) && print(io, ", dims=", d.dims)
+   d.active == nothing || print(io, ", active=", d.active)
  print(io, ")")
end

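# A rough usage sketch, assuming the constructor, `show`, and `testmode!` methods
# above (REPL output is illustrative):
#
#   julia> d = Dropout(0.4; active=true)
#   Dropout(0.4, active=true)
#
#   julia> Flux.testmode!(d)          # test mode: active is forced to false
#   Dropout(0.4, active=false)
#
#   julia> Flux.testmode!(d, :auto)   # back to automatic switching
#   Dropout(0.4)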
8596"""
86- AlphaDropout(p; rng = default_rng_value() )
97+ AlphaDropout(p; [ rng, active] )
8798
8899A dropout layer. Used in
89100[Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515).
@@ -112,13 +123,13 @@ mutable struct AlphaDropout{F,R<:AbstractRNG}
  p::F
  active::Union{Bool, Nothing}
  rng::R
-   function AlphaDropout(p, active, rng)
-     @assert 0 ≤ p ≤ 1
-     new{typeof(p), typeof(rng)}(p, active, rng)
-   end
end
+
AlphaDropout(p, active) = AlphaDropout(p, active, default_rng_value())
- AlphaDropout(p; rng = default_rng_value()) = AlphaDropout(p, nothing, rng)
+ function AlphaDropout(p; rng = default_rng_value(), active::Union{Bool,Nothing}=nothing)
+   0 ≤ p ≤ 1 || throw(ArgumentError("AlphaDropout expects 0 ≤ p ≤ 1, got p = $p"))
+   AlphaDropout(p, active, rng)
+ end

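# A quick sketch of the reworked keyword constructor above; the range check now
# throws an ArgumentError rather than tripping an `@assert` (illustrative values):
#
#   AlphaDropout(0.5; active=true).active   # → true
#   AlphaDropout(2.0)                       # → ArgumentError: AlphaDropout expects 0 ≤ p ≤ 1, got p = 2.0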
@functor AlphaDropout
trainable(a::AlphaDropout) = (;)
@@ -138,7 +149,7 @@ function (a::AlphaDropout)(x::AbstractArray{T}) where T
end

testmode!(m::AlphaDropout, mode=true) =
-   (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
+   (m.active = isnothing(_tidy_active(mode)) ? nothing : !mode; m)

"""
    LayerNorm(size..., λ=identity; affine=true, ϵ=1f-5)
@@ -257,7 +268,7 @@ ChainRulesCore.@non_differentiable _track_stats!(::Any...)
257268"""
258269 BatchNorm(channels::Integer, λ=identity;
259270 initβ=zeros32, initγ=ones32,
260- affine = true, track_stats = true,
271+ affine= true, track_stats= true, active=nothing ,
261272 ϵ=1f-5, momentum= 0.1f0)
262273
263274[Batch Normalization](https://arxiv.org/abs/1502.03167) layer.
⋯

function BatchNorm(chs::Int, λ=identity;
                   initβ=zeros32, initγ=ones32,
-                    affine=true, track_stats=true,
+                    affine=true, track_stats=true, active::Union{Bool,Nothing}=nothing,
                   ϵ=1f-5, momentum=0.1f0)

  β = affine ? initβ(chs) : nothing
@@ -321,7 +332,7 @@ function BatchNorm(chs::Int, λ=identity;
  return BatchNorm(λ, β, γ,
                   μ, σ², ϵ, momentum,
                   affine, track_stats,
-                    nothing, chs)
+                    active, chs)
end

@functor BatchNorm
@@ -335,12 +346,13 @@ function (BN::BatchNorm)(x::AbstractArray{T,N}) where {T,N}
end

testmode!(m::BatchNorm, mode=true) =
-   (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
+   (m.active = isnothing(_tidy_active(mode)) ? nothing : !mode; m)

function Base.show(io::IO, l::BatchNorm)
  print(io, "BatchNorm($(l.chs)")
  (l.λ == identity) || print(io, ", $(l.λ)")
  hasaffine(l) || print(io, ", affine=false")
+   l.active == nothing || print(io, ", active=", l.active)
  print(io, ")")
end

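# A minimal sketch of the new keyword on the norm layers, assuming the constructor
# and `show` changes above; InstanceNorm and GroupNorm below follow the same pattern:
#
#   julia> bn = BatchNorm(3; active=true);
#
#   julia> string(bn)   # uses the 2-arg `show` defined above
#   "BatchNorm(3, active=true)"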
⋯

function InstanceNorm(chs::Int, λ=identity;
                      initβ=zeros32, initγ=ones32,
-                       affine=false, track_stats=false,
+                       affine=false, track_stats=false, active::Union{Bool,Nothing}=nothing,
                      ϵ=1f-5, momentum=0.1f0)

  β = affine ? initβ(chs) : nothing
@@ -410,7 +422,7 @@ function InstanceNorm(chs::Int, λ=identity;
  return InstanceNorm(λ, β, γ,
                      μ, σ², ϵ, momentum,
                      affine, track_stats,
-                       nothing, chs)
+                       active, chs)
end

@functor InstanceNorm
@@ -424,12 +436,13 @@ function (l::InstanceNorm)(x::AbstractArray{T,N}) where {T,N}
end

testmode!(m::InstanceNorm, mode=true) =
-   (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
+   (m.active = isnothing(_tidy_active(mode)) ? nothing : !mode; m)

function Base.show(io::IO, l::InstanceNorm)
  print(io, "InstanceNorm($(l.chs)")
  l.λ == identity || print(io, ", $(l.λ)")
  hasaffine(l) || print(io, ", affine=false")
+   l.active == nothing || print(io, ", active=", l.active)
  print(io, ")")
end

@@ -495,7 +508,7 @@ trainable(gn::GroupNorm) = hasaffine(gn) ? (β = gn.β, γ = gn.γ) : (;)

function GroupNorm(chs::Int, G::Int, λ=identity;
                   initβ=zeros32, initγ=ones32,
-                    affine=true, track_stats=false,
+                    affine=true, track_stats=false, active::Union{Bool,Nothing}=nothing,
                   ϵ=1f-5, momentum=0.1f0)

  if track_stats
⋯
                  μ, σ²,
                  ϵ, momentum,
                  affine, track_stats,
-                   nothing, chs)
+                   active, chs)
end

function (gn::GroupNorm)(x::AbstractArray)
@@ -529,13 +542,14 @@ function (gn::GroupNorm)(x::AbstractArray)
end

testmode!(m::GroupNorm, mode=true) =
-   (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
+   (m.active = isnothing(_tidy_active(mode)) ? nothing : !mode; m)

function Base.show(io::IO, l::GroupNorm)
  # print(io, "GroupNorm($(join(size(l.β), ", "))", ", ", l.G)
  print(io, "GroupNorm($(l.chs), $(l.G)")
  l.λ == identity || print(io, ", ", l.λ)
  hasaffine(l) || print(io, ", affine=false")
+   l.active == nothing || print(io, ", active=", l.active)
  print(io, ")")
end
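# A sketch of the stricter mode handling, assuming `_tidy_active` from the top of
# the file: any mode other than true/false/nothing/:auto is now rejected.
#
#   julia> gn = GroupNorm(6, 3);
#
#   julia> Flux.testmode!(gn, :auto).active === nothing
#   true
#
#   julia> Flux.testmode!(gn, "on")
#   ERROR: ArgumentError: active = "on" is not accepted, must be true/false/nothing or :auto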