diff --git a/src/lib/array.jl b/src/lib/array.jl index 1c6e09916..b7c6d45ee 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -2,6 +2,8 @@ using Random, FillArrays, AbstractFFTs using FillArrays: AbstractFill, getindex_value using Base.Broadcast: broadcasted, broadcast_shape using Distributed: pmap, AbstractWorkerPool +using LinearAlgebra: Diagonal, Hermitian, LowerTriangular, UpperTriangular +using LinearAlgebra: UnitLowerTriangular, UnitUpperTriangular @adjoint Array(xs::AbstractArray) = Array(xs), ȳ -> (ȳ,) @adjoint Array(xs::Array) = Array(xs), ȳ -> (ȳ,) @@ -165,10 +167,21 @@ end # This is also used by comprehensions, which do guarantee iteration order. # Not done for pmap, presumably because all is lost if you are relying on its order. _tryreverse(m, backs, Δ) = backs, Δ -_tryreverse(m::typeof(map), backs, Δ) = reverse(backs), reverse(Δ) +_tryreverse(m::typeof(map), backs, Δ) = _reverse(backs), _reverse(Δ) _tryreverse(m, x) = x -_tryreverse(m::typeof(map), x) = reverse(x) +_tryreverse(m::typeof(map), x) = _reverse(x) + +# Fallback +_reverse(x) = reverse(x) + +# Known cases in the standard library on which `reverse` errors (issue #1393) +_reverse(x::LowerTriangular) = UpperTriangular(_reverse(parent(x))) +_reverse(x::UpperTriangular) = LowerTriangular(_reverse(parent(x))) +_reverse(x::UnitLowerTriangular) = UnitUpperTriangular(_reverse(parent(x))) +_reverse(x::UnitUpperTriangular) = UnitLowerTriangular(_reverse(parent(x))) +_reverse(x::Hermitian) = Hermitian(_reverse(x.data), x.uplo == 'U' ? :L : :U) +_reverse(x::Symmetric) = Symmetric(_reverse(x.data), x.uplo == 'U' ? :L : :U) # With mismatched lengths, map stops early. With mismatched shapes, it makes a vector. # So we keep axes(x) to restore gradient dx to its full length & correct shape. diff --git a/test/lib/array.jl b/test/lib/array.jl index d02e9f9d3..889301c1e 100644 --- a/test/lib/array.jl +++ b/test/lib/array.jl @@ -1,5 +1,7 @@ using ChainRulesTestUtils -using Zygote: ZygoteRuleConfig, _pullback +using LinearAlgebra: Diagonal, Hermitian, LowerTriangular, UpperTriangular +using LinearAlgebra: UnitLowerTriangular, UnitUpperTriangular +using Zygote: ZygoteRuleConfig, _pullback, _reverse # issue 897 @@ -65,3 +67,32 @@ end end @test gradient(f_comprehension, w)[1] == ones(5) end + +@testset "_reverse" begin + m = [1 2 3; 4 5 6; 7 8 9] + @testset "$wrapper" for wrapper in [ + Hermitian, Symmetric, LowerTriangular, UpperTriangular, + UnitLowerTriangular, UnitUpperTriangular, + ] + M = wrapper(m) + @test collect(_reverse(M)) == _reverse(collect(M)) + end +end + +@testset "rrule for `map`" begin + @testset "MWE from #1393" begin + # https://github.com/FluxML/Zygote.jl/issues/1393#issuecomment-1468496804 + struct Foo1393 x::Float64 end + (f::Foo1393)(x) = f.x * x + x = randn(5, 5) + out, pb = Zygote.pullback(x -> map(Foo1393(5.0), x), x) + @testset "$wrapper" for wrapper in [ + Hermitian, Symmetric, LowerTriangular, UpperTriangular, + UnitLowerTriangular, UnitUpperTriangular, + ] + m = wrapper(rand(5, 5)) + res = only(pb(m)) + @test res == 5m + end + end +end