"""
    istraining()

Return `false`. At run time training mode is off; a ChainRulesCore `rrule`
overrides this to `true` under automatic differentiation, so training-only
code paths (e.g. dropout) activate exactly when gradients are being taken.
"""
function istraining()
    return false
end
22
# AD override for `istraining`: report `true` as the primal so that
# differentiation takes the training code path, and return a zero tangent
# (`NoTangent()` for the function itself) from the pullback.
ChainRulesCore.rrule(::typeof(istraining)) = true, _ -> (NoTangent(),)
44
"""
    _isactive(m)

Decide whether layer `m` is in training mode: an explicitly set `m.active`
flag takes precedence; `nothing` means "auto", deferring to `istraining()`.
"""
function _isactive(m)
    flag = m.active
    return flag === nothing ? istraining() : flag
end
66
@@ -38,12 +38,6 @@ function dropout(rng, x, p; dims=:, active::Bool=true)
3838end
"""
    dropout(x, p; kwargs...)

Convenience method: pick an RNG suited to the array type of `x` via
`rng_from_array` and forward to `dropout(rng, x, p; kwargs...)`.
"""
function dropout(x, p; kwargs...)
    rng = rng_from_array(x)
    return dropout(rng, x, p; kwargs...)
end
4040
41- @adjoint function dropout(rng, x, p; dims= :, active:: Bool = true )
42- active || return x, Δ -> (Δ, nothing )
43- y = dropout_mask(rng, x, p, dims= dims)
44- return x .* y, Δ -> (nothing , Δ .* y, nothing )
45- end
46-
# GPU path: a CuArray's dropout mask must be sampled with CUDA's own RNG;
# delegate straight to the generic mask builder.
dropout_mask(rng::CUDA.RNG, x::CuArray, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
# Guard method: any non-CUDA RNG paired with a CuArray is a programmer
# error — fail eagerly, naming the offending RNG type.
# (Also fixes the error-message grammar: "only support" -> "only supports".)
dropout_mask(rng, x::CuArray, p; kwargs...) =
  throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropout_mask only supports CUDA.RNG for CuArrays."))
@@ -56,7 +50,7 @@ function _dropout_mask(rng, x, p; dims=:)
5650end
5751
# TODO move this to NNlib
# Mask generation is random sampling, not a differentiable computation:
# mark `dropout_mask` opaque to AD for any argument types.
ChainRulesCore.@non_differentiable dropout_mask(::Any, ::Any, ::Any)
6054
6155"""
6256 Dropout(p; dims=:, rng = rng_from_array())
@@ -234,7 +228,8 @@ function _track_stats!(
234228 bn. σ² = res_mtm .* bn. σ² .+ mtm .* (m / (m - one(V))) .* σ²new
235229 return nothing
236230end
# Running-statistics tracking is in-place bookkeeping, not part of the
# differentiable computation: hide `_track_stats!` from AD for all arguments.
ChainRulesCore.@non_differentiable _track_stats!(::Any...)
238233
239234"""
240235 BatchNorm(channels::Integer, λ=identity;
0 commit comments