Skip to content

Commit 2490c79

Browse files
simsuracedevmotion
andauthored
Fix reverse failure (#1396)
* Add internal function `_reverse` and overloads * Add unit tests * Correct issue number * Label testset * Add missing wrappers * Avoid `collect` in `_reverse` for `Hermitian` and `Symmetric` Co-authored-by: David Widmann <[email protected]> * Use `_reverse` instead of `reverse` Co-authored-by: David Widmann <[email protected]> * Fix wrong names :) Co-authored-by: David Widmann <[email protected]> * Add end user test case * Add `using Zygote: _reverse` --------- Co-authored-by: David Widmann <[email protected]>
1 parent 756dd37 commit 2490c79

File tree

2 files changed

+47
-3
lines changed

2 files changed

+47
-3
lines changed

src/lib/array.jl

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ using Random, FillArrays, AbstractFFTs
22
using FillArrays: AbstractFill, getindex_value
33
using Base.Broadcast: broadcasted, broadcast_shape
44
using Distributed: pmap, AbstractWorkerPool
5+
using LinearAlgebra: Diagonal, Hermitian, LowerTriangular, UpperTriangular
6+
using LinearAlgebra: UnitLowerTriangular, UnitUpperTriangular
57

68
@adjoint Array(xs::AbstractArray) = Array(xs), ȳ -> (ȳ,)
79
@adjoint Array(xs::Array) = Array(xs), ȳ -> (ȳ,)
@@ -165,10 +167,21 @@ end
165167
# This is also used by comprehensions, which do guarantee iteration order.
166168
# Not done for pmap, presumably because all is lost if you are relying on its order.
167169
_tryreverse(m, backs, Δ) = backs, Δ
168-
_tryreverse(m::typeof(map), backs, Δ) = reverse(backs), reverse(Δ)
170+
_tryreverse(m::typeof(map), backs, Δ) = _reverse(backs), _reverse(Δ)
169171

170172
_tryreverse(m, x) = x
171-
_tryreverse(m::typeof(map), x) = reverse(x)
173+
_tryreverse(m::typeof(map), x) = _reverse(x)
174+
175+
# Fallback
176+
_reverse(x) = reverse(x)
177+
178+
# Known cases in the standard library on which `reverse` errors (issue #1393)
179+
_reverse(x::LowerTriangular) = UpperTriangular(_reverse(parent(x)))
180+
_reverse(x::UpperTriangular) = LowerTriangular(_reverse(parent(x)))
181+
_reverse(x::UnitLowerTriangular) = UnitUpperTriangular(_reverse(parent(x)))
182+
_reverse(x::UnitUpperTriangular) = UnitLowerTriangular(_reverse(parent(x)))
183+
_reverse(x::Hermitian) = Hermitian(_reverse(x.data), x.uplo == 'U' ? :L : :U)
184+
_reverse(x::Symmetric) = Symmetric(_reverse(x.data), x.uplo == 'U' ? :L : :U)
172185

173186
# With mismatched lengths, map stops early. With mismatched shapes, it makes a vector.
174187
# So we keep axes(x) to restore gradient dx to its full length & correct shape.

test/lib/array.jl

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
using ChainRulesTestUtils
2-
using Zygote: ZygoteRuleConfig, _pullback
2+
using LinearAlgebra: Diagonal, Hermitian, LowerTriangular, UpperTriangular
3+
using LinearAlgebra: UnitLowerTriangular, UnitUpperTriangular
4+
using Zygote: ZygoteRuleConfig, _pullback, _reverse
35

46
# issue 897
57

@@ -65,3 +67,32 @@ end
6567
end
6668
@test gradient(f_comprehension, w)[1] == ones(5)
6769
end
70+
71+
@testset "_reverse" begin
72+
m = [1 2 3; 4 5 6; 7 8 9]
73+
@testset "$wrapper" for wrapper in [
74+
Hermitian, Symmetric, LowerTriangular, UpperTriangular,
75+
UnitLowerTriangular, UnitUpperTriangular,
76+
]
77+
M = wrapper(m)
78+
@test collect(_reverse(M)) == _reverse(collect(M))
79+
end
80+
end
81+
82+
@testset "rrule for `map`" begin
83+
@testset "MWE from #1393" begin
84+
# https://github.com/FluxML/Zygote.jl/issues/1393#issuecomment-1468496804
85+
struct Foo1393 x::Float64 end
86+
(f::Foo1393)(x) = f.x * x
87+
x = randn(5, 5)
88+
out, pb = Zygote.pullback(x -> map(Foo1393(5.0), x), x)
89+
@testset "$wrapper" for wrapper in [
90+
Hermitian, Symmetric, LowerTriangular, UpperTriangular,
91+
UnitLowerTriangular, UnitUpperTriangular,
92+
]
93+
m = wrapper(rand(5, 5))
94+
res = only(pb(m))
95+
@test res == 5m
96+
end
97+
end
98+
end

0 commit comments

Comments
 (0)