Open
Description
Motivation and description
Say I have the arrays x[i,b], y[j,b] and A[i,j,b]. Is there an efficient way to do the following "batched dot" operation:
[sum(x[i,b] * A[i,j,b] * y[j,b] for i = axes(A,1) for j = axes(A,2)) for b = ...]
where b traverses the batch dimension. As usual, we could have size(x,2) == 1, size(A,3)==1, ..., which would mean the corresponding missing dimension is broadcasted.
Apologies if there is already a way to do this efficiently with existing functions in NNlib; I could not figure it out.
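To make the request concrete, here is a minimal sketch of the operation using only Base Julia. The function names `batched_bilinear_ref` and `batched_bilinear` are hypothetical (they are not NNlib functions). The broadcast version is not claimed to be efficient: it materializes an intermediate array the size of the broadcast product. Reshaping the inputs and calling NNlib's `batched_mul` might be another route, but that is not shown or benchmarked here.

```julia
# Reference version: a direct translation of the comprehension above.
# A batch dimension of size 1 is treated as broadcast (always index 1).
function batched_bilinear_ref(x::AbstractMatrix, A::AbstractArray{<:Any,3}, y::AbstractMatrix)
    nb = max(size(x, 2), size(A, 3), size(y, 2))
    bidx(n, b) = n == 1 ? 1 : b   # map batch index b onto a possibly size-1 dimension
    return [sum(x[i, bidx(size(x, 2), b)] *
                A[i, j, bidx(size(A, 3), b)] *
                y[j, bidx(size(y, 2), b)]
                for i in axes(A, 1), j in axes(A, 2)) for b in 1:nb]
end

# Broadcast version: reshape x to (I, 1, B) and y to (1, J, B), multiply
# elementwise with A (I, J, B), then sum over the first two dimensions.
# Size-1 batch dimensions are expanded by Julia's ordinary broadcasting rules.
# Note: this allocates an I x J x B temporary before summing.
function batched_bilinear(x::AbstractMatrix, A::AbstractArray{<:Any,3}, y::AbstractMatrix)
    xr = reshape(x, size(x, 1), 1, size(x, 2))
    yr = reshape(y, 1, size(y, 1), size(y, 2))
    return vec(sum(xr .* A .* yr; dims=(1, 2)))
end
```

For example, with `x = ones(2, 3)`, `A = ones(2, 2, 1)` and `y = ones(2, 3)`, both functions return a length-3 vector whose entries all equal 4.0, since `A` is broadcast across the three batches.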
Possible Implementation
No response