Skip to content

Commit 701135b

Browse files
bors[bot]Evgeny Tankhilevich
and
Evgeny Tankhilevich
authored
Merge #515
515: Fix Flux.flip by providing an adjoint for Base.reverse r=dhairyagandhi96 a=tanhevg The main motivation behind this PR is to address various issues concerning `Flux.flip()` (used mainly for bRNNs), e.g. FluxML/Flux.jl#962, FluxML/Flux.jl#990 and FluxML/model-zoo#179 Co-authored-by: Evgeny Tankhilevich <[email protected]>
2 parents 94441dd + 8f7da4d commit 701135b

File tree

2 files changed

+10
-0
lines changed

2 files changed

+10
-0
lines changed

src/lib/array.jl

+5
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,11 @@ end
6363
circshift(A, shifts), Δ -> (circshift(Δ, map(-, shifts)), nothing)
6464
end
6565

66+
@adjoint function reverse(x::AbstractArray, args...; kwargs...)
67+
_reverse(t) = reverse(t, args...; kwargs...)
68+
_reverse(x), Δ->(_reverse(Δ), map(_->nothing, args)...)
69+
end
70+
6671
@adjoint permutedims(xs) = permutedims(xs), Δ -> (permutedims(Δ),)
6772

6873
@adjoint permutedims(xs::AbstractVector) = permutedims(xs), Δ -> (vec(permutedims(Δ)),)

test/gradcheck.jl

+5
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,11 @@ end
129129
@test gradtest(x -> meanpool(x, pdims), x)
130130
end
131131

132+
@test gradtest(x -> reverse(x), rand(17))
133+
@test gradtest(x -> reverse(x, 8), rand(17))
134+
@test gradtest(x -> reverse(x, 8, 13), rand(17))
135+
@test gradtest(x -> reverse(x, dims=2), rand(17, 42))
136+
132137
@test gradtest(x -> permutedims(x), rand(2))
133138
@test gradtest(x -> permutedims(x), rand(2,3))
134139
@test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6))

0 commit comments

Comments
 (0)