Skip to content

test_rrule for gather!, scatter! #429

Open
@vpuri3

Description

@vpuri3

NNlib.jl/src/scatter.jl

Lines 206 to 212 in 023cd3d

# Reverse-mode rule for the in-place `scatter!(op, dst, src, idx)`.
#
# Runs the primal (mutating `dst`) and returns the updated `dst` together with
# a pullback. The pullback returns five entries, laid out to match the primal
# call `scatter!(op, dst, src, idx)`: `NoTangent()` for `scatter!` itself, for
# `op` and for `idx`, and computed tangents for `dst` and `src`.
function rrule(
    ::typeof(scatter!), op, dst::AbstractArray, src::AbstractArray, idx::AbstractArray
)
    # Snapshot `dst` before it is overwritten: `∇scatter!_dst` is handed both
    # the pre-update copy and the post-update array.
    dst_old = copy(dst)
    scatter!(op, dst, src, idx)
    function scatter!_pullback(Δ)
        # Materialise the incoming cotangent once and share it between both
        # gradient helpers, rather than calling `unthunk` once per helper.
        Δy = unthunk(Δ)
        ∂dst = ∇scatter!_dst(op, Δy, dst_old, dst)
        ∂src = ∇scatter!_src(op, Δy, dst, src, idx)
        return (NoTangent(), NoTangent(), ∂dst, ∂src, NoTangent())
    end
    return dst, scatter!_pullback
end

NNlib.jl/src/gather.jl

Lines 82 to 87 in 023cd3d

# Reverse-mode rule for the in-place `gather!(dst, src, idx)`.
#
# Runs the primal and returns its result together with a pullback. The pullback
# returns four entries, laid out to match the primal call
# `gather!(dst, src, idx)`: `NoTangent()` for `gather!` itself, for `dst` and
# for `idx`, and a tangent for `src` computed by `∇gather_src`.
function rrule(
    ::typeof(gather!), dst::AbstractArray, src::AbstractArray, idx::AbstractArray
)
    result = gather!(dst, src, idx)
    # The pullback captures only the size of `src`, which is what
    # `∇gather_src` is given alongside the cotangent and `idx`.
    sz = size(src)
    function gather!_pullback(Δ)
        ∂src = ∇gather_src(unthunk(Δ), sz, idx)
        return (NoTangent(), NoTangent(), ∂src, NoTangent())
    end
    return result, gather!_pullback
end

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions