 from jaxtyping import Float
 from torch import Tensor

+from pyfastkron import fastkrontorch as fktorch
+
 from linear_operator import settings
 from linear_operator.operators._linear_operator import IndexType, LinearOperator
 from linear_operator.operators.dense_linear_operator import to_linear_operator
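The new import is PyFastKron's PyTorch frontend, which the hunks below use in place of the hand-rolled per-factor matmul helpers. For orientation, here is a minimal plain-PyTorch sketch of the operation in question, a Kronecker-structured matmul computed without materializing the dense product (factor sizes and names are illustrative only, not part of this change):

import torch

# Two small Kronecker factors; KroneckerProductLinearOperator never builds
# their dense Kronecker product explicitly.
A = torch.randn(2, 2)
B = torch.randn(3, 3)
rhs = torch.randn(6, 4)

dense = torch.kron(A, B)      # explicit (2*3) x (2*3) matrix, for reference only
res_dense = dense @ rhs

# Factored computation: reshape rhs so each factor contracts its own axis.
X = rhs.reshape(2, 3, 4)      # split the 6 rows into a 2 x 3 grid
res_factored = torch.einsum("ip,jq,pqc->ijc", A, B, X).reshape(6, 4)

assert torch.allclose(res_dense, res_factored, atol=1e-5)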
@@ -267,14 +269,13 @@ def _matmul(
         self: Float[LinearOperator, "*batch M N"],
         rhs: Union[Float[torch.Tensor, "*batch2 N C"], Float[torch.Tensor, "*batch2 N"]],
     ) -> Union[Float[torch.Tensor, "... M C"], Float[torch.Tensor, "... M"]]:
-        is_vec = rhs.ndimension() == 1
-        if is_vec:
-            rhs = rhs.unsqueeze(-1)
-
-        res = _matmul(self.linear_ops, self.shape, rhs.contiguous())
-
-        if is_vec:
-            res = res.squeeze(-1)
+        res = fktorch.gekmm([op.to_dense() for op in self.linear_ops], rhs.contiguous())
+        return res
+
+    def rmatmul(self: Float[LinearOperator, "... M N"],
+        rhs: Union[Float[Tensor, "... P M"], Float[Tensor, "... M"], Float[LinearOperator, "... P M"]],
+    ) -> Union[Float[Tensor, "... P N"], Float[Tensor, "N"], Float[LinearOperator, "... P N"]]:
+        res = fktorch.gemkm(rhs.contiguous(), [op.to_dense() for op in self.linear_ops])
         return res

     @cached(name="root_decomposition")
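Both new bodies delegate to PyFastKron: gekmm multiplies the Kronecker product of the factors by a matrix on the right (_matmul), while gemkm multiplies a matrix by the Kronecker product on the left (rmatmul). Below is a rough usage sketch that mirrors the exact call pattern above and checks it against torch.kron; anything beyond the two calls shown in the hunk (batching, dtype, or device handling) is an assumption here, not something this diff establishes:

import torch
from pyfastkron import fastkrontorch as fktorch

A, B = torch.randn(2, 2), torch.randn(3, 3)
rhs = torch.randn(6, 4)   # right operand for (A kron B) @ rhs
lhs = torch.randn(5, 6)   # left operand for lhs @ (A kron B)
dense = torch.kron(A, B)

# gekmm([F1, ..., Fk], rhs): Kronecker product of the factors times rhs, as in _matmul.
out_right = fktorch.gekmm([A, B], rhs.contiguous())
assert torch.allclose(out_right, dense @ rhs, atol=1e-4)

# gemkm(lhs, [F1, ..., Fk]): lhs times the Kronecker product of the factors, as in rmatmul.
out_left = fktorch.gemkm(lhs.contiguous(), [A, B])
assert torch.allclose(out_left, lhs @ dense, atol=1e-4)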
@@ -357,14 +358,7 @@ def _t_matmul(
         self: Float[LinearOperator, "*batch M N"],
         rhs: Union[Float[Tensor, "*batch2 M P"], Float[LinearOperator, "*batch2 M P"]],
     ) -> Union[Float[LinearOperator, "... N P"], Float[Tensor, "... N P"]]:
-        is_vec = rhs.ndimension() == 1
-        if is_vec:
-            rhs = rhs.unsqueeze(-1)
-
-        res = _t_matmul(self.linear_ops, self.shape, rhs.contiguous())
-
-        if is_vec:
-            res = res.squeeze(-1)
+        res = fktorch.gekmm([op.to_dense().mT for op in self.linear_ops], rhs.contiguous())
         return res

     def _transpose_nonbatch(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch N M"]:
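_t_matmul needs (F1 ⊗ ... ⊗ Fk)^T @ rhs, and the hunk gets it by transposing each factor with .mT before calling gekmm, relying on the identity that the transpose of a Kronecker product is the Kronecker product of the transposes. A quick plain-PyTorch check of that identity (illustrative shapes only):

import torch

A, B = torch.randn(2, 4), torch.randn(3, 5)

# (A kron B)^T == (A^T kron B^T), so transposing the factors individually
# is enough; the dense Kronecker product never has to be formed.
assert torch.allclose(torch.kron(A, B).mT, torch.kron(A.mT, B.mT), atol=1e-6)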