@@ -50,9 +50,10 @@ function AlphaDropout(p::T) where {T<:Real}
5050 return AlphaDropout(p, α, γ, β)
5151end
5252
53- function (d:: AlphaDropout )(x, _, st:: NamedTuple )
54- y, rng = alpha_dropout(st. rng, x, d. p, st. training, d. alpha, d. scale, d. bias)
55- return y, (; rng, st. training)
53+ function apply(:: Type{<:AlphaDropout} , d, x:: AbstractArray )
54+ y, rng = alpha_dropout(d. rng, x, d. p, d. training, d. alpha, d. scale, d. bias)
55+ d. rng = rng
56+ return y
5657end
5758
5859Base. show(io:: IO , d:: AlphaDropout ) = print(io, " AlphaDropout(" , d. p, " )" )
@@ -106,9 +107,10 @@ function Dropout(p; dims=:)
106107 return Dropout(p, 1 / (1 - p), dims)
107108end
108109
109- function (d:: Dropout )(x, _, st:: NamedTuple )
110- y, _, rng = dropout(st. rng, x, d. p, st. training, d. q, d. dims)
111- return y, (; rng, st. training)
110+ function apply(:: Type{<:Dropout} , d, x:: AbstractArray )
111+ y, _, rng = dropout(d. rng, x, d. p, d. training, d. q, d. dims)
112+ d. rng = rng
113+ return y
112114end
113115
114116function Base. show(io:: IO , d:: Dropout )
@@ -177,6 +179,8 @@ function VariationalHiddenDropout(p; dims=:)
177179 return VariationalHiddenDropout(p, 1 / (1 - p), dims)
178180end
179181
182+ # Note that we don't use `apply` here. While we support non-fixed state types, that
183+ # api is inherently type-unstable.
180184function (d:: VariationalHiddenDropout )(x, _, st:: NamedTuple )
181185 maskₒ = st. mask === nothing ? x : st. mask
182186 y, mask, rng = dropout(st. rng, x, maskₒ, d. p, st. training, st. update_mask, d. q, d. dims)
0 commit comments