Open
Description
We are missing the in-place versions of batched adjoint/transpose. They are required to avoid GPU scalar indexing in `Base.copy`, e.g. `copy(batched_adjoint(CUDA.randn(3, 5, 2)))`. They can be implemented as:
# Inplace
"""
    batched_transpose_f!(f, B, A)

Fill `B` with `f` applied to the transpose of every slice of `A`, i.e.
`B[j, i, k] = f(A[i, j, k])`, and return `B`. Pass `f = adjoint` for a batched
adjoint and `f = transpose` for a batched transpose.

The third (batch) dimension is split across threads with `Threads.@threads`. Each
2-D slice is handled by `LinearAlgebra.transpose_f!`.

# Throws
- `DimensionMismatch` if `axes(B) != (axes(A, 2), axes(A, 1), axes(A, 3))`.
"""
function batched_transpose_f!(f, B::AbstractArray{T, 3}, A::AbstractArray{T, 3}) where T
    expected = (axes(A, 2), axes(A, 1), axes(A, 3))
    if axes(B) != expected
        # Report the concrete axes so the caller can see which dimension disagrees.
        throw(DimensionMismatch(
            "batched_transpose_f!($f): destination axes $(axes(B)) do not match " *
            "the expected axes $expected (source axes $(axes(A)) with dims 1 and 2 swapped)"))
    end
    # NOTE(review): `transpose_f!` is an unexported LinearAlgebra helper — confirm it
    # remains available on the Julia versions this package supports.
    Threads.@threads for k in axes(A, 3)
        LinearAlgebra.transpose_f!(f, view(B, :, :, k), view(A, :, :, k))
    end
    return B
end
using GPUArrays
"""
    batched_transpose_f!(f, B::AnyGPUArray, A::AnyGPUArray)

GPU method: set `B[j, i, k] = f(A[i, j, k])` for every element of `A` and return `B`.

The work runs as a single kernel launched through `GPUArrays.gpu_call`, so no scalar
indexing happens on the host.

Throws `DimensionMismatch` (message is `string(f)`) if `B`'s first two axes are not
`A`'s first two axes swapped, or if the third axes differ.
"""
function batched_transpose_f!(f, B::AnyGPUArray{T, 3}, A::AnyGPUArray{T, 3}) where T
axes(B,1) == axes(A,2) && axes(B,2) == axes(A,1) && axes(A,3) == axes(B,3) || throw(DimensionMismatch(string(f)))
# NOTE(review): presumably `gpu_call` sizes the launch from its first argument (`B`); that
# matches `length(A)` only because of the axes check above — confirm.
# NOTE(review): `gpu_call` may not exist in newer GPUArrays releases (KernelAbstractions
# based) — TODO confirm against the pinned GPUArrays version.
GPUArrays.gpu_call(B, A) do ctx, B, A
# NOTE(review): `@cartesianidx` appears to refer to a local named `ctx` implicitly, and
# presumably returns early for out-of-range threads — keep the `ctx` name as-is.
idx = GPUArrays.@cartesianidx A
# Write the element at (i, j, k) of A to (j, i, k) of B.
@inbounds B[idx[2], idx[1], idx[3]] = f(A[idx[1], idx[2], idx[3]])
return
end
return B
end
"""
    batched_adjoint!(B, A)

In-place batched adjoint: write `adjoint` of each slice `A[:, :, k]` into `B[:, :, k]`
and return `B`.
"""
function batched_adjoint!(B, A)
    return batched_transpose_f!(adjoint, B, A)
end

"""
    batched_transpose!(B, A)

In-place batched transpose: write `transpose` of each slice `A[:, :, k]` into
`B[:, :, k]` and return `B`.
"""
function batched_transpose!(B, A)
    return batched_transpose_f!(transpose, B, A)
end
# copy
"""
    copy(x::BatchedAdjoint)

Materialize the lazy batched adjoint `x` into a fresh array allocated with `similar`
on its parent. This keeps the storage type (e.g. a GPU array) and avoids element-wise
scalar indexing.
"""
function Base.copy(x::BatchedAdjoint)
    src = parent(x)
    rows, cols, batch = axes(src)
    dst = similar(src, (cols, rows, batch))  # first two axes swapped
    return batched_adjoint!(dst, src)
end
"""
    copy(x::BatchedTranspose)

Materialize the lazy batched transpose `x` into a fresh array allocated with `similar`
on its parent. This keeps the storage type (e.g. a GPU array) and avoids element-wise
scalar indexing.
"""
function Base.copy(x::BatchedTranspose)
    src = parent(x)
    rows, cols, batch = axes(src)
    dst = similar(src, (cols, rows, batch))  # first two axes swapped
    return batched_transpose!(dst, src)
end
These methods require an extra dependency on GPUArrays. I am not sure where this code should go under `ext/` (a package extension).