
batched_transpose with multiple batch dimensions #588

Open
@AntonOresten

Description


Motivation and description

There is already a batched_mul method that reshapes its inputs so that any number of trailing batch dimensions is supported:

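function batched_mul(x::AbstractArray{T1,N}, y::AbstractArray{T2,N}) where {T1,T2,N}
    # Every dimension after the first two is treated as a batch dimension
    batch_size = size(x)[3:end]
    @assert batch_size == size(y)[3:end] "batch size has to be the same for the two arrays."
    # Fold all batch dimensions into one so the 3-D method can be reused
    x2 = reshape(x, size(x, 1), size(x, 2), :)
    y2 = reshape(y, size(y, 1), size(y, 2), :)
    z = batched_mul(x2, y2)
    # Unfold the single batch dimension back into the original batch shape
    return reshape(z, size(z, 1), size(z, 2), batch_size...)
end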
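To try the reshape trick without the library, here is a minimal plain-Julia sketch. naive_batched_mul is a hypothetical, loop-based stand-in for the 3-D kernel, not the library's implementation; its N-D method mirrors the code above, with DimensionMismatch used in place of @assert.

```julia
# Hypothetical loop-based stand-in for a 3-D batched matrix multiply
# (not the library's kernel): z[:, :, k] = x[:, :, k] * y[:, :, k].
function naive_batched_mul(x::AbstractArray{T1,3}, y::AbstractArray{T2,3}) where {T1,T2}
    size(x, 3) == size(y, 3) || throw(DimensionMismatch("batch sizes differ"))
    T = promote_type(T1, T2)
    z = similar(x, T, size(x, 1), size(y, 2), size(x, 3))
    for k in axes(x, 3)
        z[:, :, k] = x[:, :, k] * y[:, :, k]
    end
    return z
end

# N-D method using the same reshape trick: fold the trailing dimensions
# into one batch dimension, multiply, then unfold them again.
function naive_batched_mul(x::AbstractArray{T1,N}, y::AbstractArray{T2,N}) where {T1,T2,N}
    batch_size = size(x)[3:end]
    batch_size == size(y)[3:end] || throw(DimensionMismatch("batch sizes differ"))
    x2 = reshape(x, size(x, 1), size(x, 2), :)
    y2 = reshape(y, size(y, 1), size(y, 2), :)
    z = naive_batched_mul(x2, y2)
    return reshape(z, size(z, 1), size(z, 2), batch_size...)
end

x = rand(2, 3, 4, 5)   # 2×3 matrices arranged in a 4×5 grid of batches
y = rand(3, 6, 4, 5)
z = naive_batched_mul(x, y)
size(z)                                        # (2, 6, 4, 5)
z[:, :, 2, 3] ≈ x[:, :, 2, 3] * y[:, :, 2, 3]  # true
```

Because the 3-D method fixes N = 3 in its signature, Julia dispatches 3-D inputs to it rather than recursing into the N-D method.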

It would be useful for batched_transpose and batched_adjoint to support multiple batch dimensions in the same way.

Possible Implementation

The existing implementation is lazy (it returns a wrapper around the parent array instead of copying it), so an eager one-liner like this wouldn't fly:

batched_transpose(A::AbstractArray{T, N}) where {T <: Real, N} = permutedims(A, (2, 1, 3:N...))
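A minimal sketch of the eager approach, under the hypothetical name eager_batched_transpose, shows both why it gives the right values and why it is not lazy. The range must be splatted (3:N...) so that the permutation is a flat tuple such as (2, 1, 3, 4).

```julia
# Eager N-D batched transpose: swap the first two dimensions and keep
# every batch dimension in place. For N = 2 the range 3:N is empty,
# so the permutation is just (2, 1).
eager_batched_transpose(A::AbstractArray{T,N}) where {T,N} =
    permutedims(A, (2, 1, 3:N...))

A = rand(2, 3, 4, 5)
B = eager_batched_transpose(A)
@assert size(B) == (3, 2, 4, 5)
# Each batch slice is the ordinary transpose of the matching slice of A.
@assert all(B[:, :, i, j] == permutedims(A[:, :, i, j]) for i in 1:4, j in 1:5)
# The result is a fresh copy: writing to B leaves A untouched.
B[1, 1, 1, 1] = -1.0
@assert A[1, 1, 1, 1] != -1.0
```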

I imagine it would be possible to generalize the code beyond three dimensions, though. The indexing methods are currently hard-coded for three indices, and other methods such as strides would also need to be generalized, for example:

function Base.strides(A::Union{BatchedTranspose, BatchedAdjoint{<:Real}})
    sp = strides(A.parent)
    (sp[2], sp[1], sp[3:end]...)
end
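A rough sketch of what generalizing size, indexing, and strides could look like, using a hypothetical wrapper type NDBatchedTranspose written from scratch. It is not the existing BatchedTranspose type and leaves out the adjoint variant and everything beyond these three methods.

```julia
# Hypothetical lazy wrapper, written for illustration: it keeps the parent
# array and swaps the first two indices on every access, for any number
# of batch dimensions.
struct NDBatchedTranspose{T,N,S<:AbstractArray{T,N}} <: AbstractArray{T,N}
    parent::S
    NDBatchedTranspose(A::AbstractArray{T,N}) where {T,N} = new{T,N,typeof(A)}(A)
end

Base.parent(A::NDBatchedTranspose) = A.parent

# Size: swap the first two dimensions and keep all batch dimensions.
function Base.size(A::NDBatchedTranspose)
    s = size(A.parent)
    return (s[2], s[1], s[3:end]...)
end

# Indexing: a single method taking exactly N indices replaces a method
# hard-coded to (i, j, k). Base.tail(Base.tail(I)) drops the first two indices.
function Base.getindex(A::NDBatchedTranspose{T,N}, I::Vararg{Int,N}) where {T,N}
    return A.parent[I[2], I[1], Base.tail(Base.tail(I))...]
end

# Strides: swap the first two strides of the parent, as in the snippet above.
function Base.strides(A::NDBatchedTranspose)
    sp = strides(A.parent)
    return (sp[2], sp[1], sp[3:end]...)
end

P = rand(2, 3, 4, 5)
L = NDBatchedTranspose(P)
size(L)                          # (3, 2, 4, 5)
L[3, 1, 2, 4] == P[1, 3, 2, 4]   # true
strides(L)                       # (2, 1, 6, 24)
```

The inner constructor fills in T and N from the parent, so NDBatchedTranspose(P) works without spelling out type parameters.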

Is it better to just use PermutedDimsArray?
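For comparison, a minimal sketch of the PermutedDimsArray route (the helper name lazy_batched_transpose is hypothetical). PermutedDimsArray is part of Base, is lazy, and derives its strides from the parent array.

```julia
# Lazy N-D batched transpose built on Base's PermutedDimsArray.
# Hypothetical helper name; the permutation is (2, 1, 3, ..., N).
lazy_batched_transpose(A::AbstractArray{T,N}) where {T,N} =
    PermutedDimsArray(A, (2, 1, 3:N...))

A = rand(2, 3, 4, 5)
V = lazy_batched_transpose(A)        # a view: no data is copied
size(V)                              # (3, 2, 4, 5)
V[3, 1, 2, 4] == A[1, 3, 2, 4]       # true
strides(V)                           # (2, 1, 6, 24), the parent's strides permuted
A[1, 3, 2, 4] = 0.0                  # writes to A are visible through V
V[3, 1, 2, 4]                        # 0.0
```

One trade-off to weigh: any methods that dispatch on the dedicated BatchedTranspose and BatchedAdjoint types would not automatically apply to a PermutedDimsArray, so they would need extra methods or a conversion. An adjoint would also still need conj applied to complex elements.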

Metadata


Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
