Motivation and description
There exists a method for `batched_mul` that reshapes arrays to allow for an arbitrary number of batch dimensions:
```julia
function batched_mul(x::AbstractArray{T1,N}, y::AbstractArray{T2,N}) where {T1,T2,N}
    batch_size = size(x)[3:end]
    @assert batch_size == size(y)[3:end] "batch size has to be the same for the two arrays."
    x2 = reshape(x, size(x, 1), size(x, 2), :)
    y2 = reshape(y, size(y, 1), size(y, 2), :)
    z = batched_mul(x2, y2)
    return reshape(z, size(z, 1), size(z, 2), batch_size...)
end
```
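The reshape trick is valid because Julia arrays are column-major: merging all trailing batch dimensions into one keeps every `m×n` matrix slice contiguous and in order, so batch index `(i, j)` becomes the single index `i + (j - 1) * size(x, 3)`. The sketch below checks this with a hypothetical loop-based reference, `naive_batched_mul` (not NNlib's implementation), that reuses the same N-d wrapper pattern as the method above.

```julia
using LinearAlgebra  # per-slice matrix products

# Hypothetical 3-D reference: multiply matching matrix slices one batch at a time.
function naive_batched_mul(x::AbstractArray{<:Any,3}, y::AbstractArray{<:Any,3})
    @assert size(x, 3) == size(y, 3) "batch size has to be the same for the two arrays."
    T = promote_type(eltype(x), eltype(y))
    z = similar(x, T, size(x, 1), size(y, 2), size(x, 3))
    for k in axes(x, 3)
        z[:, :, k] = x[:, :, k] * y[:, :, k]
    end
    return z
end

# Same N-d wrapper as the batched_mul method above: fold the batch dimensions
# into one, call the 3-D method (more specific, so it is chosen), unfold again.
function naive_batched_mul(x::AbstractArray{T1,N}, y::AbstractArray{T2,N}) where {T1,T2,N}
    batch_size = size(x)[3:end]
    @assert batch_size == size(y)[3:end] "batch size has to be the same for the two arrays."
    x2 = reshape(x, size(x, 1), size(x, 2), :)
    y2 = reshape(y, size(y, 1), size(y, 2), :)
    z = naive_batched_mul(x2, y2)
    return reshape(z, size(z, 1), size(z, 2), batch_size...)
end

# Integer entries keep the comparison exact.
x = rand(1:9, 2, 3, 4, 5)
y = rand(1:9, 3, 2, 4, 5)
z = naive_batched_mul(x, y)

size(z)                                             # (2, 2, 4, 5)
z[:, :, 3, 2] == x[:, :, 3, 2] * y[:, :, 3, 2]      # true
reshape(x, 2, 3, :)[:, :, 3 + (2 - 1) * 4] == x[:, :, 3, 2]  # true: slice (3, 2) is slice 7
```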
It would be useful to have support for this with `batched_transpose` and `batched_adjoint` as well.
Possible Implementation
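The existing code is quite sophisticated and "lazy" (the `BatchedTranspose` and `BatchedAdjoint` wrappers index into the parent array instead of copying it), so something like this wouldn't fly: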
```julia
batched_transpose(A::AbstractArray{T, N}) where {T <: Real, N} = permutedims(A, (2, 1, 3:N...))
```
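For comparison, the eager approach does produce correct results once the range is splatted into the permutation (`(2, 1, 3:N...)`; the unsplatted `(2, 1, 3:N)` is a tuple containing a range, not a permutation). The minimal sketch below uses hypothetical names, `eager_batched_transpose` and `eager_batched_adjoint`, so it does not clash with NNlib; its drawback is that every call allocates a full copy of the data.

```julia
# Hypothetical eager N-d versions: correct, but they copy the whole array.
# (2, 1, 3:N...) expands to (2, 1, 3, ..., N); for N == 2 it is just (2, 1).
eager_batched_transpose(A::AbstractArray{T,N}) where {T,N} = permutedims(A, (2, 1, 3:N...))

# The adjoint additionally conjugates every element (a no-op for real element types).
eager_batched_adjoint(A::AbstractArray{T,N}) where {T,N} = conj.(eager_batched_transpose(A))

A = rand(ComplexF64, 2, 3, 4, 5)
At = eager_batched_transpose(A)
Ah = eager_batched_adjoint(A)

size(At)                                            # (3, 2, 4, 5)
At[:, :, 2, 3] == permutedims(A[:, :, 2, 3])        # true
```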
I imagine it would be possible to generalize the code beyond three dimensions, though. The indexing methods are currently hard-coded for three indices, and things like the strides would also need to be generalized:
```julia
function Base.strides(A::Union{BatchedTranspose, BatchedAdjoint{<:Real}})
    sp = strides(A.parent)
    (sp[2], sp[1], sp[3:end]...)
end
```
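A minimal sketch of what such a generalization could look like, using a hypothetical wrapper `NDBatchedTranspose` (not NNlib's `BatchedTranspose`): indexing swaps the first two indices and passes any number of batch indices through, and `strides` uses the same swap as the method above. It covers only the basic `AbstractArray` interface plus `strides`; hooking it into NNlib's BLAS-backed `batched_mul` would need more than is shown here.

```julia
# Hypothetical lazy N-d batched transpose: stores the parent, never copies it.
struct NDBatchedTranspose{T,N,S<:AbstractArray{T,N}} <: AbstractArray{T,N}
    parent::S
    function NDBatchedTranspose(A::AbstractArray{T,N}) where {T,N}
        N >= 2 || throw(ArgumentError("need at least two dimensions"))
        return new{T,N,typeof(A)}(A)
    end
end

Base.parent(A::NDBatchedTranspose) = A.parent

# Swap the two matrix dimensions; batch dimensions are unchanged.
function Base.size(A::NDBatchedTranspose)
    sp = size(A.parent)
    return (sp[2], sp[1], sp[3:end]...)
end

# Not hard-coded to three indices: the trailing batch indices are splatted through.
Base.@propagate_inbounds function Base.getindex(A::NDBatchedTranspose{T,N}, I::Vararg{Int,N}) where {T,N}
    return A.parent[I[2], I[1], Base.tail(Base.tail(I))...]
end

# Only meaningful when the parent itself is strided (e.g. an Array).
function Base.strides(A::NDBatchedTranspose)
    sp = strides(A.parent)
    return (sp[2], sp[1], sp[3:end]...)
end

A = reshape(collect(1.0:120.0), 2, 3, 4, 5)
B = NDBatchedTranspose(A)

size(B)                                  # (3, 2, 4, 5)
B[:, :, 2, 3] == permutedims(A[:, :, 2, 3])  # true
strides(B)                               # (2, 1, 6, 24); strides(A) is (1, 2, 6, 24)
```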
Is it better to just use `PermutedDimsArray`?
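For comparison, a minimal sketch of the `PermutedDimsArray` route using only Base (`lazy_batched_transpose` is a hypothetical name). The result is lazy, works for any number of batch dimensions, and reports the parent's strides in permuted order. One caveat: `PermutedDimsArray` only reorders dimensions, so a `batched_adjoint` for complex element types would still need conjugation on top of it.

```julia
# Hypothetical lazy N-d batched transpose built on Base's PermutedDimsArray.
lazy_batched_transpose(A::AbstractArray{T,N}) where {T,N} = PermutedDimsArray(A, (2, 1, 3:N...))

A = rand(2, 3, 4, 5)
P = lazy_batched_transpose(A)

size(P)          # (3, 2, 4, 5)
strides(P)       # (2, 1, 6, 24): the parent's strides (1, 2, 6, 24), permuted
parent(P) === A  # true: no copy was made
```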