
Inplace version of batched adjoint/transpose #502

Open
@chengchingwen


We are missing an in-place version of batched adjoint/transpose. It is required to avoid GPU scalar indexing with Base.copy, as in copy(batched_adjoint(CUDA.randn(3,5,2))). It can be implemented as:

using LinearAlgebra

# In-place (CPU): transpose each batch slice, one slice per thread iteration
function batched_transpose_f!(f, B::AbstractArray{T, 3}, A::AbstractArray{T, 3}) where T
    axes(B,1) == axes(A,2) && axes(B,2) == axes(A,1) && axes(A,3) == axes(B,3) || throw(DimensionMismatch(string(f)))
    Threads.@threads for i in axes(A,3)
        Bi = @view B[:, :, i]
        Ai = @view A[:, :, i]
        LinearAlgebra.transpose_f!(f, Bi, Ai)
    end
    return B
end

using GPUArrays
function batched_transpose_f!(f, B::AnyGPUArray{T, 3}, A::AnyGPUArray{T, 3}) where T
    axes(B,1) == axes(A,2) && axes(B,2) == axes(A,1) && axes(A,3) == axes(B,3) || throw(DimensionMismatch(string(f)))
    GPUArrays.gpu_call(B, A) do ctx, B, A
        idx = GPUArrays.@cartesianidx A
        @inbounds B[idx[2], idx[1], idx[3]] = f(A[idx[1], idx[2], idx[3]])
        return
    end
    return B
end

batched_adjoint!(B, A) = batched_transpose_f!(adjoint, B, A)
batched_transpose!(B, A) = batched_transpose_f!(transpose, B, A)

# copy
function Base.copy(x::BatchedAdjoint)
    p = parent(x)
    a1, a2, a3 = axes(p)
    return batched_adjoint!(similar(p, (a2, a1, a3)), p)
end
function Base.copy(x::BatchedTranspose)
    p = parent(x)
    a1, a2, a3 = axes(p)
    return batched_transpose!(similar(p, (a2, a1, a3)), p)
end

which requires an extra dependency on GPUArrays. I have no idea where we should put this code under ext.
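
For comparison, here is a minimal sketch of a broadcast-based variant that needs no GPUArrays import. The name `batched_transpose_bcast!` is hypothetical and not part of the proposal above. It relies on the assumption that broadcasting over a `PermutedDimsArray` of a GPU array is handled by the array package without scalar indexing; this has not been verified here for any particular backend or version, and it may not match the performance of a dedicated kernel.

```julia
# Hypothetical helper (the name is an assumption, not from the proposal above).
# Writes f applied elementwise to the batched transpose of A into B,
# using only Base broadcasting.
function batched_transpose_bcast!(f, B::AbstractArray{T,3}, A::AbstractArray{T,3}) where T
    axes(B, 1) == axes(A, 2) && axes(B, 2) == axes(A, 1) && axes(B, 3) == axes(A, 3) ||
        throw(DimensionMismatch(string(f)))
    # PermutedDimsArray is a lazy wrapper that swaps the first two dimensions;
    # B and A must not alias each other.
    B .= f.(PermutedDimsArray(A, (2, 1, 3)))
    return B
end

# CPU usage example with complex entries, so adjoint and transpose differ.
A = randn(ComplexF64, 3, 5, 2)
B = similar(A, 5, 3, 2)
batched_transpose_bcast!(adjoint, B, A)
```

Because `f` is applied per element, `adjoint` reduces to `conj` and `transpose` to the identity for numeric element types, which is the same per-element behavior as the GPU kernel above.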
