
Add a function to compute trace of matrix product efficiently #3161

Open
@jachymb


I figured something like this could be a useful (although not strictly necessary) prerequisite for #3149.

Consider rectangular matrices $A \in \mathbb{R}^{m \times n}, B \in \mathbb{R}^{n \times m}$. We can observe the following:

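$\mathrm{Trace}(A \cdot B) = \sum_{i=1}^{m} A_i \cdot (B^T)_{i} = \sum_{i=1}^{m} \sum_{j=1}^{n} A_{i,j} B_{j,i}$

where $A_i$ denotes the $i$-th row of $A$ and $(B^T)_i$ the $i$-th row of $B^T$, i.e. the $i$-th column of $B$.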
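As a small worked example (the matrices are chosen arbitrarily for illustration), take $m = 2$ and $n = 3$. Only the two diagonal entries of $A \cdot B$ are needed, and each is the dot product of a row of $A$ with the matching column of $B$:

$$
A = \begin{pmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \end{pmatrix}, \qquad
B = \begin{pmatrix} 7 & 8 \\ 9 & 10 \\ 11 & 12 \end{pmatrix}
$$

$$
\mathrm{Trace}(A \cdot B) = \underbrace{(1 \cdot 7 + 2 \cdot 9 + 3 \cdot 11)}_{58} + \underbrace{(4 \cdot 8 + 5 \cdot 10 + 6 \cdot 12)}_{154} = 212
$$

The off-diagonal entries of $A \cdot B$ (here 64 and 139) never have to be computed.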

The left-hand expression is an elegant and concise way to write it. From a programmer's perspective, it is convenient, especially when the inputs are already matrices.

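The middle expression seems like a good way to compute the value with vectorized dot products. It needs $m$ dot products of length $n$, i.e. $mn$ multiplications, whereas computing the whole $m \times m$ product and then taking the trace needs $m^2 n$, so it is asymptotically faster.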
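To make the cost difference concrete, here is a minimal pure-Python sketch (plain lists of lists, no external libraries; both function names are hypothetical) that counts scalar multiplications for the naive approach and for the dot-product form:

```python
# Sketch: cost of trace(A @ B) computed naively vs. via row-column dot products.
# A is m x n, B is n x m, both stored as lists of lists.

def trace_of_full_product(a, b):
    """Form the full m x m product, then sum its diagonal. Returns (trace, multiplications)."""
    m, n = len(a), len(b)
    mults = 0
    product = [[0.0] * m for _ in range(m)]
    for i in range(m):
        for k in range(m):
            for j in range(n):
                product[i][k] += a[i][j] * b[j][k]
                mults += 1
    return sum(product[i][i] for i in range(m)), mults


def trace_via_row_dots(a, b):
    """Sum of dot products of row i of A with column i of B. Returns (trace, multiplications)."""
    m, n = len(a), len(b)
    mults = 0
    total = 0.0
    for i in range(m):
        for j in range(n):
            total += a[i][j] * b[j][i]
            mults += 1
    return total, mults


a = [[1, 2, 3], [4, 5, 6]]          # 2 x 3
b = [[7, 8], [9, 10], [11, 12]]     # 3 x 2
print(trace_of_full_product(a, b))  # (212.0, 12)
print(trace_via_row_dots(a, b))     # (212.0, 6)
```

The counts match $m^2 n = 12$ and $mn = 6$ for this example; the gap grows linearly with $m$.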

The right-hand expression is just the expanded form and may appear in the wild. It suggests computing the value as the sum of the elementwise product of $A$ and $B^T$, which may be better computationally, but it somewhat obscures the linear-algebraic structure.
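The elementwise reading can be sketched in a few lines of pure Python (the helper names are hypothetical; shape checks are omitted for brevity):

```python
# Sketch: trace(A @ B) as the sum of the elementwise (Hadamard) product of A and B^T.
# Plain lists of lists; shape checks are omitted for brevity.

def transpose(mat):
    """Return the transpose of a list-of-lists matrix."""
    return [list(col) for col in zip(*mat)]


def trace_via_hadamard(a, b):
    """Sum of A[i][j] * (B^T)[i][j] over all i, j, i.e. sum of A[i][j] * B[j][i]."""
    b_t = transpose(b)
    return sum(x * y for row_a, row_bt in zip(a, b_t) for x, y in zip(row_a, row_bt))


a = [[1, 2, 3], [4, 5, 6]]       # 2 x 3
b = [[7, 8], [9, 10], [11, 12]]  # 3 x 2, so B^T is 2 x 3 like A
print(trace_via_hadamard(a, b))  # 212
```

It performs exactly the same $mn$ products as the row-by-column form; only the grouping of the loops differs, which may matter for how well it vectorizes.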

I thus suggest we add a function like `real trace_dot(matrix a, matrix b)`. Note that we already have related functions such as `trace_quad_form`, which computes $\mathrm{Trace}(B^T \cdot A \cdot B)$.
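The intended semantics of the proposed `trace_dot` could be pinned down with a reference implementation. The following is a hypothetical pure-Python sketch, not Stan code: the name follows the suggestion above, and raising `ValueError` on mismatched dimensions is an assumption about how errors would be reported.

```python
# Hypothetical reference semantics for the proposed `trace_dot(matrix a, matrix b)`.
# Not an existing Stan function; pure Python for illustration only.

def trace_dot(a, b):
    """Return trace(a @ b) for a (m x n) and b (n x m) without forming the product."""
    m = len(a)
    n = len(a[0]) if m else 0
    # Validate shapes (assumed behavior): a must be rectangular and b must be n x m.
    if any(len(row) != n for row in a):
        raise ValueError("a must be rectangular")
    if len(b) != n or any(len(row) != m for row in b):
        raise ValueError(f"b must be {n} x {m} to match a ({m} x {n})")
    # Row i of a dotted with column i of b, summed over i.
    return sum(sum(a[i][j] * b[j][i] for j in range(n)) for i in range(m))


print(trace_dot([[1, 2, 3], [4, 5, 6]], [[7, 8], [9, 10], [11, 12]]))  # 212
```

In Stan Math itself this would presumably be a C++ function with its own reverse-mode gradient; the sketch only fixes the value it should return.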

I'm not sure whether the property $\mathrm{Trace}(A \cdot B) = \mathrm{Trace}(B \cdot A)$ is useful here, but it lets us choose (in the middle formula) between more dot products of shorter vectors and fewer dot products of longer vectors: $m$ dot products of length $n$ for $\mathrm{Trace}(A \cdot B)$, or $n$ dot products of length $m$ for $\mathrm{Trace}(B \cdot A)$. Both amount to $mn$ multiplications, but perhaps one of them is preferable for hardware acceleration.
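Both orientations can be sketched side by side in pure Python (helper names are hypothetical). The sketch only shows that they agree and use the same number of multiplications; it says nothing about which one is faster on real hardware.

```python
# Sketch: trace(A @ B) == trace(B @ A), computed with dot products in either orientation.
# A is m x n, B is n x m, both stored as lists of lists.

def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(u, v))


def column(mat, k):
    """Column k of a list-of-lists matrix."""
    return [row[k] for row in mat]


def trace_ab(a, b):
    """m dot products, each of length n: row i of A with column i of B."""
    return sum(dot(a[i], column(b, i)) for i in range(len(a)))


def trace_ba(a, b):
    """n dot products, each of length m: row j of B with column j of A."""
    return sum(dot(b[j], column(a, j)) for j in range(len(b)))


a = [[1, 2, 3], [4, 5, 6]]       # m = 2, n = 3
b = [[7, 8], [9, 10], [11, 12]]  # n = 3, m = 2
print(trace_ab(a, b), trace_ba(a, b))  # 212 212
```

Here `trace_ab` does 2 dot products of length 3 and `trace_ba` does 3 dot products of length 2; either way that is 6 multiplications.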

For autodiff, I think the gradients are

$\frac{\partial}{\partial A} \mathrm{Trace}(A \cdot B) = B^T, \qquad \frac{\partial}{\partial B} \mathrm{Trace}(A \cdot B) = A^T$
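The gradient with respect to $A$ can be sanity-checked numerically. This is a minimal sketch using central finite differences in pure Python (function names and the step size `h` are illustrative choices); since $\mathrm{Trace}(A \cdot B)$ is linear in $A$, the finite differences should match $B^T$ up to rounding.

```python
# Sketch: numerically check d/dA trace(A @ B) = B^T with central finite differences.

def trace_ab(a, b):
    """trace(A @ B) via row-column products, without forming A @ B."""
    return sum(a[i][j] * b[j][i] for i in range(len(a)) for j in range(len(b)))


def numeric_grad_wrt_a(a, b, h=1e-6):
    """Central-difference gradient of trace(A @ B) with respect to each entry of A."""
    grad = [[0.0] * len(a[0]) for _ in a]
    for i in range(len(a)):
        for j in range(len(a[0])):
            plus = [row[:] for row in a]
            minus = [row[:] for row in a]
            plus[i][j] += h
            minus[i][j] -= h
            grad[i][j] = (trace_ab(plus, b) - trace_ab(minus, b)) / (2 * h)
    return grad


a = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
b = [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]
grad = numeric_grad_wrt_a(a, b)
b_t = [list(col) for col in zip(*b)]  # B^T has the same shape as A
print(all(abs(g - t) < 1e-4 for g_row, t_row in zip(grad, b_t) for g, t in zip(g_row, t_row)))  # True
```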
