
Batched dot(x, A, y) #641

@3f6a

Description


Motivation and description

Say I have arrays x[i,b], y[j,b], and A[i,j,b]. Is there an efficient way to compute the following "batched dot" operation:

[sum(x[i,b] * A[i,j,b] * y[j,b] for i = axes(A,1) for j = axes(A,2)) for b = ...]

where b traverses the batch dimension. As usual, we could have size(x,2) == 1, size(A,3) == 1, ..., in which case the corresponding singleton batch dimension is broadcast.
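Written per batch element, the requested quantity is a bilinear form. This suggests computing it as a batched matrix-vector product followed by a columnwise dot product. When an input has batch size 1, its batch index b is simply replaced by 1:

$$
z_b = \sum_{i}\sum_{j} x_{ib}\, A_{ijb}\, y_{jb}
    = \mathbf{x}_{:,b}^{\top} A_{:,:,b}\, \mathbf{y}_{:,b}
    = \sum_{i} x_{ib} \bigl(A_{:,:,b}\, \mathbf{y}_{:,b}\bigr)_{i}
$$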

Apologies if there is already a way to do this (efficiently) with existing functions in NNlib; I could not figure it out.
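For reference, here is a minimal sketch of the operation described above, including the size-1 batch broadcasting, in plain Julia (Base only). The names `batched_dot` and `batched_dot_loop` are hypothetical helpers and not part of NNlib. The first version reshapes x to (I, 1, B) and y to (1, J, B) so that Julia's own broadcasting expands singleton batch dimensions. It is short, but it allocates an I×J×B temporary. The second version loops over the batch to avoid that allocation. Neither has been tuned for GPUs or automatic differentiation.

```julia
# Hypothetical helper (not an NNlib function): z[b] = sum_{i,j} x[i,b] * A[i,j,b] * y[j,b].
# Any batch dimension of size 1 is broadcast against the others.
function batched_dot(x::AbstractMatrix, A::AbstractArray{<:Any,3}, y::AbstractMatrix)
    # Lay x out as (I, 1, Bx) and y as (1, J, By) to line up with A's (I, J, BA) layout.
    xr = reshape(x, size(x, 1), 1, size(x, 2))
    yr = reshape(y, 1, size(y, 1), size(y, 2))
    # Broadcasting expands singleton batch dims (and throws DimensionMismatch for
    # incompatible sizes); summing over dims 1 and 2 leaves a 1×1×B array.
    return vec(sum(xr .* A .* yr; dims = (1, 2)))
end

# Hypothetical lower-memory variant: loops over the batch and never builds the I×J×B product.
function batched_dot_loop(x::AbstractMatrix, A::AbstractArray{<:Any,3}, y::AbstractMatrix)
    ni, nj = size(A, 1), size(A, 2)
    size(x, 1) == ni && size(y, 1) == nj || throw(DimensionMismatch("x or y does not match A"))
    sizes = (size(x, 2), size(A, 3), size(y, 2))
    nb = maximum(sizes)
    all(n -> n == 1 || n == nb, sizes) || throw(DimensionMismatch("incompatible batch sizes"))
    # Map batch index b to 1 for any input whose batch size is 1 (broadcasting).
    bidx(n, b) = n == 1 ? 1 : b
    T = promote_type(eltype(x), eltype(A), eltype(y))
    z = zeros(T, nb)
    for b in 1:nb
        bx, bA, by = bidx(sizes[1], b), bidx(sizes[2], b), bidx(sizes[3], b)
        acc = zero(T)
        for j in 1:nj
            # Inner sum over i is the j-th entry of x[:, bx]' * A[:, :, bA].
            t = zero(T)
            for i in 1:ni
                t += x[i, bx] * A[i, j, bA]
            end
            acc += t * y[j, by]
        end
        z[b] = acc
    end
    return z
end
```

A possibly faster route, not shown here, would be to form A[:,:,b] * y[:,b] with NNlib's `batched_mul` and then take columnwise dot products with x. Whether `batched_mul` broadcasts a batch size of 1 in every case should be checked against its documentation.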

Possible Implementation

No response
