Open
Description
Our current `rrule` for sparse matrix-vector products is very inefficient and runs out of memory with large sparse CPU or GPU arrays. The current `rrule(*, sparse(A), x)` is implemented like this:
```julia
function rrule(
    ::typeof(*),
    A::AbstractVecOrMat{<:CommutativeMulNumber},
    B::AbstractVecOrMat{<:CommutativeMulNumber},
)
    project_A = ProjectTo(A)
    ...
    dA = @thunk(project_A(Ȳ * B'))
    ...
end
```

So we first compute the dense product `Ȳ * B'`, which can easily exceed memory when `A` is very large but very sparse, and only then project it back to a sparse tangent.
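For a matrix-vector product, where `Ȳ` and `B` are vectors, the projected cotangent only needs the entries of `Ȳ * B'` at the stored positions of `A`, namely `dA[i, j] = Ȳ[i] * conj(B[j])`. Below is a minimal sketch, assuming `A` is a CPU `SparseMatrixCSC`, that fills those entries straight from the CSC structure so memory stays O(nnz(A)) rather than O(m·n). The helper name `sparse_outer_cotangent` is hypothetical and not a ChainRules function.

```julia
using SparseArrays

# Hypothetical helper (not part of ChainRules): the cotangent of `A` in `A * B`
# for a vector `B`, restricted to the sparsity pattern of `A`. Its stored values
# match the dense `Ȳ * B'` at A's stored positions, but it uses O(nnz(A)) memory.
function sparse_outer_cotangent(A::SparseMatrixCSC, Ȳ::AbstractVector, B::AbstractVector)
    m, n = size(A)
    nz = nnz(A)
    rows = rowvals(A)
    T = promote_type(eltype(Ȳ), eltype(B))
    nzval = Vector{T}(undef, nz)
    for j in 1:n
        bj = conj(B[j])                 # entry j of the adjoint B'
        for k in nzrange(A, j)          # stored entries of column j
            nzval[k] = Ȳ[rows[k]] * bj
        end
    end
    # Reuse A's column pointers and row indices; only the values differ.
    return SparseMatrixCSC(m, n, copy(A.colptr), rows[1:nz], nzval)
end
```

This is a scalar loop, so it only illustrates the memory argument for CPU sparse matrices; GPU sparse types would need a kernel-based approach.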
The best fix (at least when `Ȳ` and `B` are vectors) might be to add a dedicated "vector outer product" array type for read-only `vector * adjoint(vector)` products that computes `getindex` on the fly; such a type might also be useful in general. Or do we already have something like that somewhere?
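A rough sketch of the proposed read-only outer-product type, assuming vector inputs: a hypothetical `OuterProduct <: AbstractMatrix` (not an existing ChainRules or LinearAlgebra type) stores only `u` and `v` and computes `u[i] * conj(v[j])` in `getindex`. The equally hypothetical helper `project_to_pattern` then reads it only at the stored positions of a sparse `A`, which is roughly what projecting onto a sparse tangent needs.

```julia
using SparseArrays

# Hypothetical lazy, read-only outer product `u * v'` (not an existing
# ChainRules or LinearAlgebra type). Only `u` and `v` are stored, so it
# costs O(length(u) + length(v)) memory instead of O(m * n).
struct OuterProduct{T, U<:AbstractVector, V<:AbstractVector} <: AbstractMatrix{T}
    u::U
    v::V
end

OuterProduct(u::AbstractVector, v::AbstractVector) =
    OuterProduct{promote_type(eltype(u), eltype(v)), typeof(u), typeof(v)}(u, v)

Base.size(P::OuterProduct) = (length(P.u), length(P.v))

# Entries are computed on the fly, matching `(u * v')[i, j]`.
Base.getindex(P::OuterProduct, i::Int, j::Int) = P.u[i] * conj(P.v[j])

# Hypothetical helper: read `P` only at the stored positions of `A`,
# touching nnz(A) entries rather than all m * n.
function project_to_pattern(A::SparseMatrixCSC, P::AbstractMatrix)
    size(A) == size(P) || throw(DimensionMismatch("size(A) != size(P)"))
    rows, cols, _ = findnz(A)
    vals = [P[i, j] for (i, j) in zip(rows, cols)]
    return sparse(rows, cols, vals, size(A)...)
end
```

With `u = Ȳ` and `v = B`, `project_to_pattern(A, OuterProduct(Ȳ, B))` never allocates the dense `m × n` product.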