Skip to content

Commit 055a6d9

Browse files
committed
feat: rework dropout layers
1 parent 0262582 commit 055a6d9

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

src/layers/dropout.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,10 @@ function AlphaDropout(p::T) where {T<:Real}
5050
return AlphaDropout(p, α, γ, β)
5151
end
5252

53-
function (d::AlphaDropout)(x, _, st::NamedTuple)
54-
y, rng = alpha_dropout(st.rng, x, d.p, st.training, d.alpha, d.scale, d.bias)
55-
return y, (; rng, st.training)
53+
function apply(::Type{<:AlphaDropout}, d, x::AbstractArray)
54+
y, rng = alpha_dropout(d.rng, x, d.p, d.training, d.alpha, d.scale, d.bias)
55+
d.rng = rng
56+
return y
5657
end
5758

5859
Base.show(io::IO, d::AlphaDropout) = print(io, "AlphaDropout(", d.p, ")")
@@ -106,9 +107,10 @@ function Dropout(p; dims=:)
106107
return Dropout(p, 1 / (1 - p), dims)
107108
end
108109

109-
function (d::Dropout)(x, _, st::NamedTuple)
110-
y, _, rng = dropout(st.rng, x, d.p, st.training, d.q, d.dims)
111-
return y, (; rng, st.training)
110+
function apply(::Type{<:Dropout}, d, x::AbstractArray)
111+
y, _, rng = dropout(d.rng, x, d.p, d.training, d.q, d.dims)
112+
d.rng = rng
113+
return y
112114
end
113115

114116
function Base.show(io::IO, d::Dropout)
@@ -177,6 +179,8 @@ function VariationalHiddenDropout(p; dims=:)
177179
return VariationalHiddenDropout(p, 1 / (1 - p), dims)
178180
end
179181

182+
# Note that we don't use `apply` here. While we support non-fixed state types, that
183+
# api is inherently type-unstable.
180184
function (d::VariationalHiddenDropout)(x, _, st::NamedTuple)
181185
maskₒ = st.mask === nothing ? x : st.mask
182186
y, mask, rng = dropout(st.rng, x, maskₒ, d.p, st.training, st.update_mask, d.q, d.dims)

0 commit comments

Comments
 (0)