Removing evaluate_kernel #15

Open · wants to merge 1 commit into main
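This PR removes the deprecated `evaluate_kernel` method (already marked `# TODO: remove`) from `LinearOperator`, along with its override in `AddedDiagLinearOperator`, and updates the internal call sites in `_cholesky`, `_mul_matrix`, and `zero_mean_mvn_samples` to operate on `self` directly.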
linear_operator/operators/_linear_operator.py: 18 changes (3 additions, 15 deletions)
```diff
@@ -469,12 +469,10 @@ def _cholesky(self, upper: bool = False) -> "TriangularLinearOperator":  # noqa
         from .keops_linear_operator import KeOpsLinearOperator
         from .triangular_linear_operator import TriangularLinearOperator
 
-        evaluated_kern_mat = self.evaluate_kernel()
-
-        if any(isinstance(sub_mat, KeOpsLinearOperator) for sub_mat in evaluated_kern_mat._args):
+        if any(isinstance(sub_mat, KeOpsLinearOperator) for sub_mat in self._args):
             raise RuntimeError("Cannot run Cholesky with KeOps: it will either be really slow or not work.")
 
-        evaluated_mat = evaluated_kern_mat.to_dense()
+        evaluated_mat = self.to_dense()
 
         # if the tensor is a scalar, we can just take the square root
         if evaluated_mat.size(-1) == 1:
@@ -554,8 +552,6 @@ def _mul_matrix(self, other: Union[torch.Tensor, "LinearOperator"]) -> LinearOperator:
         from .dense_linear_operator import DenseLinearOperator
         from .mul_linear_operator import MulLinearOperator
 
-        self = self.evaluate_kernel()
-        other = other.evaluate_kernel()
         if isinstance(self, DenseLinearOperator) or isinstance(other, DenseLinearOperator):
             return DenseLinearOperator(self.to_dense() * other.to_dense())
         else:
@@ -1445,14 +1441,6 @@ def eigvalsh(self) -> torch.Tensor:
             pass
         return self._symeig(eigenvectors=False)[0]
 
-    # TODO: remove
-    def evaluate_kernel(self):
-        """
-        Return a new LinearOperator representing the same one as this one, but with
-        all lazily evaluated kernels actually evaluated.
-        """
-        return self.representation_tree()(*self.representation())
-
     @_implements(torch.exp)
     def exp(self) -> "LinearOperator":
         # Only implemented by some LinearOperator subclasses
@@ -2522,7 +2510,7 @@ def zero_mean_mvn_samples(self, num_samples: int) -> torch.Tensor:
         base_samples = base_samples.permute(-1, *range(self.dim() - 1)).contiguous()
         base_samples = base_samples.unsqueeze(-1)
         solves, weights, _, _ = contour_integral_quad(
-            self.evaluate_kernel(),
+            self,
             base_samples,
             inverse=False,
             num_contour_quadrature=settings.num_contour_quadrature.value(),
```
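For context, the removed `evaluate_kernel` was a thin round trip through the operator's representation: it flattened the operator into its component tensors and rebuilt it, materializing any lazily evaluated kernels along the way. A minimal sketch of the old and new call patterns, using `DenseLinearOperator` purely for illustration:

```python
import torch
from linear_operator.operators import DenseLinearOperator

mat = torch.randn(4, 4)
op = DenseLinearOperator(mat @ mat.mT + torch.eye(4))  # symmetric positive definite

# Old pattern: evaluate_kernel() was equivalent to rebuilding the
# operator from its flattened tensor representation.
rebuilt = op.representation_tree()(*op.representation())

# New pattern after this PR: call sites use the operator itself,
# e.g. _cholesky() now inspects self._args and calls self.to_dense().
dense = op.to_dense()
```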
linear_operator/operators/added_diag_linear_operator.py: 4 changes (0 additions, 4 deletions)
```diff
@@ -186,7 +186,3 @@ def _symeig(self, eigenvectors: bool = False) -> Tuple[Tensor, Optional[LinearOperator]]:
             evals = evals_ + self._diag_tensor._diagonal()
             return evals, evecs
         return super()._symeig(eigenvectors=eigenvectors)
-
-    def evaluate_kernel(self) -> LinearOperator:
-        added_diag_linear_op = self.representation_tree()(*self.representation())
-        return added_diag_linear_op._linear_op + added_diag_linear_op._diag_tensor
```
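The removed override here performed the same representation round trip and then recomposed the sum from its two parts. A minimal sketch of the composition it decomposed, assuming the public `AddedDiagLinearOperator` and `DiagLinearOperator` constructors:

```python
import torch
from linear_operator.operators import (
    AddedDiagLinearOperator,
    DenseLinearOperator,
    DiagLinearOperator,
)

base = DenseLinearOperator(2.0 * torch.eye(3))
diag = DiagLinearOperator(0.1 * torch.ones(3))

# The removed override returned rebuilt._linear_op + rebuilt._diag_tensor,
# which is just this sum:
op = AddedDiagLinearOperator(base, diag)
dense = op.to_dense()  # == 2.0 * torch.eye(3) + 0.1 * torch.eye(3)
```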