Skip to content

Commit edde3d0

Browse files
committed
Add rrule for flatview on ArrayOfSimilarArrays
1 parent b1a2614 commit edde3d0

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

src/array_of_similar_arrays.jl

+10-1
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,16 @@ Returns the array of dimensionality `L = M + N` wrapped by `A`. The shape of
141141
the result may be freely changed without breaking the inner consistency of
142142
`A`.
143143
"""
144-
flatview(A::ArrayOfSimilarArrays) = A.data
144+
flatview(A::ArrayOfSimilarArrays{T,M,N}) where {T,M,N} = A.data
145+
146+
function ChainRulesCore.rrule(::typeof(flatview), A::ArrayOfSimilarArrays{T,M,N}) where {T,M,N}
147+
function flatview_pullback(ΔΩ)
148+
data = unthunk(ΔΩ)
149+
NoTangent(), ArrayOfSimilarArrays{eltype(data),M,N}(data)
150+
end
151+
152+
return flatview(A), flatview_pullback
153+
end
145154

146155

147156
Base.size(A::ArrayOfSimilarArrays{T,M,N}) where {T,M,N} = split_tuple(size(A.data), Val{M}())[2]

0 commit comments

Comments
 (0)