@@ -6,13 +6,14 @@ gate(x::AbstractMatrix, h, n) = view(x, gate(h,n), :)
66# AD-friendly helper for dividing monolithic RNN params into equally sized gates
77multigate(x:: AbstractArray , h, :: Val{N} ) where N = ntuple(n -> gate(x,h,n), N)
88
9- @adjoint function multigate( x:: AbstractArray , h, c)
9+ function ChainRulesCore . rrule( :: typeof (multigate), x:: AbstractArray , h, c)
1010 function multigate_pullback(dy)
11- dx = Zygote. _zero(x, eltype(x))
12- map(multigate(dx, h, c), dy) do dxᵢ, dyᵢ
13- dyᵢ != = nothing && (dxᵢ.= Zygote. accum.(dxᵢ, dyᵢ));
11+ dx = map!(zero, similar(x, float(eltype(x)), axes(x)), x)
12+ foreach(multigate(dx, h, c), dy) do dxᵢ, dyᵢ
13+ dyᵢ isa AbstractZero && return
14+ @. dxᵢ += dyᵢ
1415 end
15- return (dx, nothing , nothing )
16+ return (NoTangent(), dx, NoTangent(), NoTangent() )
1617 end
1718 return multigate(x, h, c), multigate_pullback
1819end
@@ -380,7 +381,7 @@ julia> g(rand(Float32, 3, 10)) |> size # batch size of 10
380381GRUv3(a... ; ka... ) = Recur(GRUv3Cell(a... ; ka... ))
381382Recur(m:: GRUv3Cell ) = Recur(m, m. state0)
382383
383-
384+ # TODO move to ChainRulesCore?
384385@adjoint function Broadcast. broadcasted(f:: Recur , args... )
385386 Zygote.∇map(__context__, f, args... )
386387end
0 commit comments