The conversion to and from ChainRules types should be essentially free for single structs, but for arrays of structs it currently involves a copy in each direction. For example:
The arrays here have exactly the same contents, so ideally this would be done by reinterpreting the data. The code needed looks something like this:
```julia
@inline function wrap_chainrules_output(dxs::AbstractArray{<:ChainRules.Tangent{<:Any, B}}) where {B}
  if isbitstype(B)
    # B is the backing type. It still contains NoTangent etc, which need converting to Nothing
    reinterpret(wrap_chainrules_output(B), dxs)
  else
    map(wrap_chainrules_output, dxs)
  end
end

wrap_chainrules_output(::Type{<:AbstractZero}) = Nothing
wrap_chainrules_output(::Type{NamedTuple{L,T}}) where {L,T} = NamedTuple{L,wrap_chainrules_output(T)}
@generated function wrap_chainrules_output(::Type{T}) where T<:Tuple
  inner = map(wrap_chainrules_output, T.parameters)
  :(Tuple{$(inner...)})
end

wrap_chainrules_input(::Type{Nothing}) = NoTangent
wrap_chainrules_input(::Type{NamedTuple{L,T}}) where {L,T} = NamedTuple{L,wrap_chainrules_input(T)}
@generated function wrap_chainrules_input(::Type{T}) where T<:Tuple
  inner = map(wrap_chainrules_input, T.parameters)
  :(Tuple{$(inner...)})
end

function z2d(dx::AbstractArray{S}, primal::AbstractArray{P}) where {S,P}
  if isbitstype(S)
    T = wrap_chainrules_input(S)
    reinterpret(Tangent{P,T}, dx)
  else
    map(z2d, dx, primal)
  end
end
```
But at present, pasting this in causes some 2nd derivative tests to fail.
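To show the zero-copy idea in isolation, here is a minimal sketch using only Base Julia. The `Point` struct and the `PointNT` alias are hypothetical names chosen for illustration; they are not part of Zygote or ChainRules, and this does not touch the failing second-derivative tests. It only demonstrates that an array of isbits structs can be viewed as an array of NamedTuples with the same layout, instead of being copied with `map`:

```julia
# Hypothetical struct, used only for illustration.
struct Point
    x::Float64
    y::Float64
end

# A NamedTuple type with the same field names and field types as Point.
const PointNT = NamedTuple{(:x, :y), Tuple{Float64, Float64}}

pts = [Point(1.0, 2.0), Point(3.0, 4.0)]

# Copying conversion: allocates a new array of NamedTuples.
copied = map(p -> (x = p.x, y = p.y), pts)

# Zero-copy conversion: a lazy view over the same memory. This is only
# valid because both element types are isbits and have the same size.
viewed = reinterpret(PointNT, pts)

@assert isbitstype(Point) && isbitstype(PointNT)
@assert sizeof(Point) == sizeof(PointNT)
@assert viewed == copied
```

The result of `reinterpret` is a wrapper array rather than a plain `Vector`, so downstream code has to accept a general `AbstractArray`.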