From dca3803fc119a2fcf9f79e4948dfad419590d58a Mon Sep 17 00:00:00 2001 From: Sebastian Ament Date: Tue, 13 Sep 2022 15:15:51 -0400 Subject: [PATCH] removing evaluate_kernel --- linear_operator/operators/_linear_operator.py | 18 +++--------------- .../operators/added_diag_linear_operator.py | 4 ---- 2 files changed, 3 insertions(+), 19 deletions(-) diff --git a/linear_operator/operators/_linear_operator.py b/linear_operator/operators/_linear_operator.py index 177269ae..aceaaa5f 100644 --- a/linear_operator/operators/_linear_operator.py +++ b/linear_operator/operators/_linear_operator.py @@ -469,12 +469,10 @@ def _cholesky(self, upper: bool = False) -> "TriangularLinearOperator": # noqa from .keops_linear_operator import KeOpsLinearOperator from .triangular_linear_operator import TriangularLinearOperator - evaluated_kern_mat = self.evaluate_kernel() - - if any(isinstance(sub_mat, KeOpsLinearOperator) for sub_mat in evaluated_kern_mat._args): + if any(isinstance(sub_mat, KeOpsLinearOperator) for sub_mat in self._args): raise RuntimeError("Cannot run Cholesky with KeOps: it will either be really slow or not work.") - evaluated_mat = evaluated_kern_mat.to_dense() + evaluated_mat = self.to_dense() # if the tensor is a scalar, we can just take the square root if evaluated_mat.size(-1) == 1: @@ -554,8 +552,6 @@ def _mul_matrix(self, other: Union[torch.Tensor, "LinearOperator"]) -> LinearOpe from .dense_linear_operator import DenseLinearOperator from .mul_linear_operator import MulLinearOperator - self = self.evaluate_kernel() - other = other.evaluate_kernel() if isinstance(self, DenseLinearOperator) or isinstance(other, DenseLinearOperator): return DenseLinearOperator(self.to_dense() * other.to_dense()) else: @@ -1445,14 +1441,6 @@ def eigvalsh(self) -> torch.Tensor: pass return self._symeig(eigenvectors=False)[0] - # TODO: remove - def evaluate_kernel(self): - """ - Return a new LinearOperator representing the same one as this one, but with - all lazily evaluated kernels actually evaluated. - """ - return self.representation_tree()(*self.representation()) - @_implements(torch.exp) def exp(self) -> "LinearOperator": # Only implemented by some LinearOperator subclasses @@ -2522,7 +2510,7 @@ def zero_mean_mvn_samples(self, num_samples: int) -> torch.Tensor: base_samples = base_samples.permute(-1, *range(self.dim() - 1)).contiguous() base_samples = base_samples.unsqueeze(-1) solves, weights, _, _ = contour_integral_quad( - self.evaluate_kernel(), + self, base_samples, inverse=False, num_contour_quadrature=settings.num_contour_quadrature.value(), diff --git a/linear_operator/operators/added_diag_linear_operator.py b/linear_operator/operators/added_diag_linear_operator.py index 4cd5f61f..325c8c08 100644 --- a/linear_operator/operators/added_diag_linear_operator.py +++ b/linear_operator/operators/added_diag_linear_operator.py @@ -186,7 +186,3 @@ def _symeig(self, eigenvectors: bool = False) -> Tuple[Tensor, Optional[LinearOp evals = evals_ + self._diag_tensor._diagonal() return evals, evecs return super()._symeig(eigenvectors=eigenvectors) - - def evaluate_kernel(self) -> LinearOperator: - added_diag_linear_op = self.representation_tree()(*self.representation()) - return added_diag_linear_op._linear_op + added_diag_linear_op._diag_tensor