Needless copies in wrap_chainrules_input and friends #1112

Open
mcabbott opened this issue Oct 27, 2021 · 0 comments
Labels
ChainRules adjoint -> rrule, and further integration performance

Converting to and from ChainRules types should be essentially free for single structs, but for arrays of structs it currently involves a copy in each direction. For example:

julia> zyg = [(1,2,nothing), (3,4,nothing)];

julia> Zygote.wrap_chainrules_input(ans)
2-element Vector{ChainRulesCore.Tangent{Any, Tuple{Int64, Int64, ChainRulesCore.ZeroTangent}}}:
 Tangent{Any}(1, 2, ChainRulesCore.ZeroTangent())
 Tangent{Any}(3, 4, ChainRulesCore.ZeroTangent())

julia> Zygote.wrap_chainrules_output(ans)
2-element Vector{Tuple{Int64, Int64, Nothing}}:
 (1, 2, nothing)
 (3, 4, nothing)
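
As a rough, Base-only illustration of why a copy should be avoidable: the two element types differ only in a zero-size marker field, so their memory layouts match. The sketch below uses a hypothetical `MyZero` struct as a stand-in for ChainRulesCore's zero tangents; it is not Zygote's code, only a demonstration of `reinterpret` on such arrays.

```julia
# `MyZero` is a hypothetical stand-in for a zero-size marker type such as ZeroTangent.
struct MyZero end

xs = [(1, 2, nothing), (3, 4, nothing)]
S = eltype(xs)                      # Tuple{Int64, Int64, Nothing} on a 64-bit machine
T = Tuple{Int, Int, MyZero}         # same layout, different marker type

# Both element types are plain bits with the same size, so reinterpret applies.
@assert isbitstype(S) && isbitstype(T)
@assert sizeof(S) == sizeof(T)

ys = reinterpret(T, xs)             # lazy wrapper around xs, no data is copied

# Writes to the parent array are visible through the wrapper.
xs[1] = (10, 20, nothing)
@assert ys[1] == (10, 20, MyZero())
```

Because `ys` wraps `xs` rather than copying it, the cost is independent of the array length, which is the behaviour one would hope for from the conversion functions.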

The input and output arrays in the session above have exactly the same contents, so ideally the conversion would be done by reinterpreting the data rather than copying it. The code needed looks something like this:

@inline function wrap_chainrules_output(dxs::AbstractArray{<:ChainRules.Tangent{<:Any, B}}) where {B}
  if isbitstype(B)
    # B is the backing type. It still contains NoTangent etc, which need converting to Nothing
    reinterpret(wrap_chainrules_output(B), dxs)
  else
    map(wrap_chainrules_output, dxs)
  end
end
wrap_chainrules_output(::Type{<:AbstractZero}) = Nothing
wrap_chainrules_output(::Type{NamedTuple{L,T}}) where {L,T} = NamedTuple{L,wrap_chainrules_output(T)}
@generated function wrap_chainrules_output(::Type{T}) where T<:Tuple
  inner = map(wrap_chainrules_output, T.parameters)
  :(Tuple{$(inner...)})
end

wrap_chainrules_input(::Type{Nothing}) = NoTangent
wrap_chainrules_input(::Type{NamedTuple{L,T}}) where {L,T} = NamedTuple{L,wrap_chainrules_input(T)}
@generated function wrap_chainrules_input(::Type{T}) where T<:Tuple
  inner = map(wrap_chainrules_input, T.parameters)
  :(Tuple{$(inner...)})
end

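function z2d(dx::AbstractArray{S}, primal::AbstractArray{P}) where {S,P}
  if isbitstype(S)
    # Fast path: map the element type once, then reinterpret the data without copying
    T = wrap_chainrules_input(S)
    reinterpret(Tangent{P,T}, dx)
  else
    # Fallback: convert element by element, which allocates a new array
    map(z2d, dx, primal)
  end
end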

However, pasting this in currently causes some second-derivative tests to fail.
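
The type-level methods in the proposal map each field type of the backing tuple or named tuple to its counterpart. A minimal standalone sketch of that recursion, assuming hypothetical names `MyZero` and `to_marker` (neither exists in Zygote or ChainRulesCore) and an explicit pass-through method for numbers, might look like this:

```julia
# Hypothetical stand-in for a zero-size tangent marker; not a ChainRulesCore type.
struct MyZero end

# Map `Nothing` to the marker type and leave numeric types unchanged (assumed fallback).
to_marker(::Type{Nothing}) = MyZero
to_marker(::Type{T}) where {T<:Number} = T

# Recurse through the field types of tuples and named tuples.
to_marker(::Type{T}) where {T<:Tuple} = Tuple{map(to_marker, fieldtypes(T))...}
to_marker(::Type{NamedTuple{L,T}}) where {L,T} = NamedTuple{L,to_marker(T)}

@assert to_marker(Tuple{Int,Int,Nothing}) == Tuple{Int,Int,MyZero}
@assert to_marker(NamedTuple{(:a, :b),Tuple{Float64,Nothing}}) ==
        NamedTuple{(:a, :b),Tuple{Float64,MyZero}}
```

This version uses ordinary methods instead of `@generated` functions, so it is easier to read, but the compiler may not resolve the result type at compile time, which is presumably why the proposal uses `@generated`.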

mcabbott added ChainRules adjoint -> rrule, and further integration performance labels Jul 4, 2022