diff --git a/linear_operator/operators/kronecker_product_linear_operator.py b/linear_operator/operators/kronecker_product_linear_operator.py index 1cece8b0..8308cbc9 100644 --- a/linear_operator/operators/kronecker_product_linear_operator.py +++ b/linear_operator/operators/kronecker_product_linear_operator.py @@ -6,6 +6,8 @@ import torch from jaxtyping import Float + +from pyfastkron import fastkrontorch as fktorch from torch import Tensor from linear_operator import settings @@ -267,14 +269,14 @@ def _matmul( self: Float[LinearOperator, "*batch M N"], rhs: Union[Float[torch.Tensor, "*batch2 N C"], Float[torch.Tensor, "*batch2 N"]], ) -> Union[Float[torch.Tensor, "... M C"], Float[torch.Tensor, "... M"]]: - is_vec = rhs.ndimension() == 1 - if is_vec: - rhs = rhs.unsqueeze(-1) - - res = _matmul(self.linear_ops, self.shape, rhs.contiguous()) + res = fktorch.gekmm([op.to_dense() for op in self.linear_ops], rhs.contiguous()) + return res - if is_vec: - res = res.squeeze(-1) + def rmatmul( + self: Float[LinearOperator, "... M N"], + other: Union[Float[Tensor, "... P M"], Float[Tensor, "... M"], Float[LinearOperator, "... P M"]], + ) -> Union[Float[Tensor, "... P N"], Float[Tensor, "N"], Float[LinearOperator, "... P N"]]: + res = fktorch.gemkm(other.contiguous(), [op.to_dense() for op in self.linear_ops]) return res @cached(name="root_decomposition") @@ -357,14 +359,7 @@ def _t_matmul( self: Float[LinearOperator, "*batch M N"], rhs: Union[Float[Tensor, "*batch2 M P"], Float[LinearOperator, "*batch2 M P"]], ) -> Union[Float[LinearOperator, "... N P"], Float[Tensor, "... N P"]]: - is_vec = rhs.ndimension() == 1 - if is_vec: - rhs = rhs.unsqueeze(-1) - - res = _t_matmul(self.linear_ops, self.shape, rhs.contiguous()) - - if is_vec: - res = res.squeeze(-1) + res = fktorch.gekmm([op.to_dense().mT for op in self.linear_ops], rhs.contiguous()) return res def _transpose_nonbatch(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch N M"]: diff --git a/setup.py b/setup.py index 1c318a2f..6260a4ed 100644 --- a/setup.py +++ b/setup.py @@ -41,6 +41,7 @@ "scipy", "jaxtyping", "mpmath>=0.19,<=1.3", # avoid incompatibiltiy with torch+sympy with mpmath 1.4 + "pyfastkron" ]