Needless copies in wrap_chainrules_input and friends #1112

Open
@mcabbott

Converting to and from ChainRules types should be essentially free for a single struct, but for an array of structs it currently makes a copy in each direction. For example:

julia> zyg = [(1,2,nothing), (3,4,nothing)];

julia> Zygote.wrap_chainrules_input(ans)
2-element Vector{ChainRulesCore.Tangent{Any, Tuple{Int64, Int64, ChainRulesCore.ZeroTangent}}}:
 Tangent{Any}(1, 2, ChainRulesCore.ZeroTangent())
 Tangent{Any}(3, 4, ChainRulesCore.ZeroTangent())

julia> Zygote.wrap_chainrules_output(ans)
2-element Vector{Tuple{Int64, Int64, Nothing}}:
 (1, 2, nothing)
 (3, 4, nothing)

The arrays here have exactly the same contents, so ideally this would be done by reinterpreting the data. The code needed looks something like this:

@inline function wrap_chainrules_output(dxs::AbstractArray{<:ChainRules.Tangent{<:Any, B}}) where {B}
  if isbitstype(B)
    # B is the backing type. It still contains NoTangent etc, which need converting to Nothing
    reinterpret(wrap_chainrules_output(B), dxs)
  else
    map(wrap_chainrules_output, dxs)
  end
end
wrap_chainrules_output(::Type{<:AbstractZero}) = Nothing
wrap_chainrules_output(::Type{NamedTuple{L,T}}) where {L,T} = NamedTuple{L,wrap_chainrules_output(T)}
@generated function wrap_chainrules_output(::Type{T}) where T<:Tuple
  inner = map(wrap_chainrules_output, T.parameters)
  :(Tuple{$(inner...)})
end

wrap_chainrules_input(::Type{Nothing}) = NoTangent
wrap_chainrules_input(::Type{NamedTuple{L,T}}) where {L,T} = NamedTuple{L,wrap_chainrules_input(T)}
@generated function wrap_chainrules_input(::Type{T}) where T<:Tuple
  inner = map(wrap_chainrules_input, T.parameters)
  :(Tuple{$(inner...)})
end

function z2d(dx::AbstractArray{S}, primal::AbstractArray{P}) where {S,P}
  if isbitstype(S)
    T = wrap_chainrules_input(S)
    reinterpret(Tangent{P,T}, dx)
  else
    map(z2d, dx, primal)
  end
end

However, pasting this code in currently causes some second-derivative tests to fail.
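
To show the zero-copy idea in isolation, here is a minimal sketch that uses only Base. `MyZero`, `Wrapped` and `to_wrapped` are hypothetical stand-ins for `ZeroTangent`, `Tangent` and the type-level `wrap_chainrules_input` above; they are not part of Zygote or ChainRulesCore, and the sketch says nothing about the failing second-derivative tests.

```julia
# Hypothetical stand-ins (assumptions for illustration only).
struct MyZero end            # singleton, plays the role of ZeroTangent / NoTangent
struct Wrapped{P,T}          # plays the role of Tangent{P,T}: one field holding the backing
    backing::T
end

# Type-level conversion: Nothing -> MyZero, recursing through tuple types.
to_wrapped(::Type{T}) where {T} = T
to_wrapped(::Type{Nothing}) = MyZero
to_wrapped(::Type{T}) where {T<:Tuple} = Tuple{map(to_wrapped, T.parameters)...}

xs = [(1, 2, nothing), (3, 4, nothing)]
S = eltype(xs)               # Tuple{Int, Int, Nothing}, a bits type
T = to_wrapped(S)            # Tuple{Int, Int, MyZero}, same size and layout

# Both element types are bits types of equal size, so the array can be
# reinterpreted lazily instead of mapped element by element.
ys = reinterpret(Wrapped{Any,T}, xs)

# Writing to the original array is visible through the reinterpreted one,
# which shows that no copy was made.
xs[1] = (10, 20, nothing)
first_backing = ys[1].backing

# Reinterpreting back recovers the original element type and contents.
roundtrip = reinterpret(S, ys)
```

Because `Nothing` and `MyZero` are both zero-size singletons, swapping one for the other does not change the memory layout, which is what makes the `isbitstype` branch in the proposal plausible.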
