Skip to content

Commit 06f97e2

Browse files
committed
use pyfastkron for kroneckerproduct (t/r)matmul
1 parent 6dad1cb commit 06f97e2

File tree

2 files changed

+11
-16
lines changed

2 files changed

+11
-16
lines changed

linear_operator/operators/kronecker_product_linear_operator.py

+10-16
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from jaxtyping import Float
99
from torch import Tensor
1010

11+
from pyfastkron import fastkrontorch as fktorch
12+
1113
from linear_operator import settings
1214
from linear_operator.operators._linear_operator import IndexType, LinearOperator
1315
from linear_operator.operators.dense_linear_operator import to_linear_operator
@@ -267,14 +269,13 @@ def _matmul(
267269
self: Float[LinearOperator, "*batch M N"],
268270
rhs: Union[Float[torch.Tensor, "*batch2 N C"], Float[torch.Tensor, "*batch2 N"]],
269271
) -> Union[Float[torch.Tensor, "... M C"], Float[torch.Tensor, "... M"]]:
270-
is_vec = rhs.ndimension() == 1
271-
if is_vec:
272-
rhs = rhs.unsqueeze(-1)
273-
274-
res = _matmul(self.linear_ops, self.shape, rhs.contiguous())
275-
276-
if is_vec:
277-
res = res.squeeze(-1)
272+
res = fktorch.gekmm([op.to_dense() for op in self.linear_ops], rhs.contiguous())
273+
return res
274+
275+
def rmatmul(self: Float[LinearOperator, "... M N"],
276+
rhs: Union[Float[Tensor, "... P M"], Float[Tensor, "... M"], Float[LinearOperator, "... P M"]],
277+
) -> Union[Float[Tensor, "... P N"], Float[Tensor, "N"], Float[LinearOperator, "... P N"]]:
278+
res = fktorch.gemkm(rhs.contiguous(), [op.to_dense() for op in self.linear_ops])
278279
return res
279280

280281
@cached(name="root_decomposition")
@@ -357,14 +358,7 @@ def _t_matmul(
357358
self: Float[LinearOperator, "*batch M N"],
358359
rhs: Union[Float[Tensor, "*batch2 M P"], Float[LinearOperator, "*batch2 M P"]],
359360
) -> Union[Float[LinearOperator, "... N P"], Float[Tensor, "... N P"]]:
360-
is_vec = rhs.ndimension() == 1
361-
if is_vec:
362-
rhs = rhs.unsqueeze(-1)
363-
364-
res = _t_matmul(self.linear_ops, self.shape, rhs.contiguous())
365-
366-
if is_vec:
367-
res = res.squeeze(-1)
361+
res = fktorch.gekmm([op.to_dense().mT for op in self.linear_ops], rhs.contiguous())
368362
return res
369363

370364
def _transpose_nonbatch(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch N M"]:

setup.py

+1
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
"scipy",
4242
"jaxtyping==0.2.19",
4343
"mpmath>=0.19,<=1.3", # avoid incompatibiltiy with torch+sympy with mpmath 1.4
44+
"pyfastkron"
4445
]
4546

4647

0 commit comments

Comments
 (0)