From 12528545ecf7b2551743f73d6c393aeaa71cc699 Mon Sep 17 00:00:00 2001 From: Kaiwen Wu Date: Tue, 11 Nov 2025 17:27:44 -0500 Subject: [PATCH 1/5] `MVN.log_prob` ditches linear operators if `use_torch_tensor=True` --- gpytorch/distributions/multivariate_normal.py | 18 ++++++++++++++++-- gpytorch/settings.py | 12 ++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/gpytorch/distributions/multivariate_normal.py b/gpytorch/distributions/multivariate_normal.py index a752e9632..3a384e748 100644 --- a/gpytorch/distributions/multivariate_normal.py +++ b/gpytorch/distributions/multivariate_normal.py @@ -245,9 +245,23 @@ def log_prob(self, value: Tensor) -> Tensor: 1, ) - # Get log determininant and first part of quadratic form covar = covar.evaluate_kernel() - inv_quad, logdet = covar.inv_quad_logdet(inv_quad_rhs=diff.unsqueeze(-1), logdet=True) + + if ( + settings.fast_computations.log_prob.off() or covar.size(-1) <= settings.max_cholesky_size.value() + ) and settings.use_torch_tensors.on(): + # If we are to use Cholesky decomposition for inference, and we are allowed to use torch tensors as opposed + # to linear operators, then do so. + covar = covar.to_dense() + chol = torch.linalg.cholesky(covar) + + chol_inv_diff = torch.linalg.solve_triangular(chol, diff.unsqueeze(-1), upper=False) + chol_inv_diff = chol_inv_diff.squeeze(-1) + inv_quad = chol_inv_diff.square().sum(-1) + + logdet = chol.diagonal(dim1=-1, dim2=-2).log().mul(2).sum(-1) + else: + inv_quad, logdet = covar.inv_quad_logdet(inv_quad_rhs=diff.unsqueeze(-1), logdet=True) res = -0.5 * sum([inv_quad, logdet, diff.size(-1) * math.log(2 * math.pi)]) return res diff --git a/gpytorch/settings.py b/gpytorch/settings.py index 99528c419..f22e81bfa 100644 --- a/gpytorch/settings.py +++ b/gpytorch/settings.py @@ -461,6 +461,17 @@ class use_keops(_feature_flag): _default = True +class use_torch_tensors(_feature_flag): + """ + Whether or not to use torch tensors instead of linear operators. If true, then we will use torch tensors as much as + possible to avoid the overhead of linear operators for dense kernel matrices. + + (Default: False) + """ + + _default = False + + __all__ = [ "_linalg_dtype_symeig", "_linalg_dtype_cholesky", @@ -502,6 +513,7 @@ class use_keops(_feature_flag): "tridiagonal_jitter", "use_keops", "use_toeplitz", + "use_torch_tensors", "variational_cholesky_jitter", "verbose_linalg", ] From 21c231f7d3cff44ed0f03ab10cfe5ce0220016a7 Mon Sep 17 00:00:00 2001 From: Kaiwen Wu Date: Thu, 13 Nov 2025 18:03:34 -0500 Subject: [PATCH 2/5] `MVN.log_prob` calls `inv_quad_logdet` on tensors --- gpytorch/distributions/multivariate_normal.py | 10 +-- gpytorch/functions/__init__.py | 2 + gpytorch/functions/inv_quad_logdet.py | 58 ++++++++++++++ test/functions/test_inv_quad_logdet.py | 79 +++++++++++++++++++ 4 files changed, 141 insertions(+), 8 deletions(-) create mode 100644 gpytorch/functions/inv_quad_logdet.py create mode 100644 test/functions/test_inv_quad_logdet.py diff --git a/gpytorch/distributions/multivariate_normal.py b/gpytorch/distributions/multivariate_normal.py index 3a384e748..345a3fc5c 100644 --- a/gpytorch/distributions/multivariate_normal.py +++ b/gpytorch/distributions/multivariate_normal.py @@ -15,6 +15,7 @@ from torch.distributions.kl import register_kl from torch.distributions.utils import _standard_normal, lazy_property +from gpytorch.functions import InvQuadLogdet from .. import settings from ..utils.warnings import NumericalWarning from .distribution import Distribution @@ -252,14 +253,7 @@ def log_prob(self, value: Tensor) -> Tensor: ) and settings.use_torch_tensors.on(): # If we are to use Cholesky decomposition for inference, and we are allowed to use torch tensors as opposed # to linear operators, then do so. - covar = covar.to_dense() - chol = torch.linalg.cholesky(covar) - - chol_inv_diff = torch.linalg.solve_triangular(chol, diff.unsqueeze(-1), upper=False) - chol_inv_diff = chol_inv_diff.squeeze(-1) - inv_quad = chol_inv_diff.square().sum(-1) - - logdet = chol.diagonal(dim1=-1, dim2=-2).log().mul(2).sum(-1) + inv_quad, logdet = InvQuadLogdet.apply(covar.to_dense(), diff.unsqueeze(-1)) else: inv_quad, logdet = covar.inv_quad_logdet(inv_quad_rhs=diff.unsqueeze(-1), logdet=True) diff --git a/gpytorch/functions/__init__.py b/gpytorch/functions/__init__.py index d3294a974..63a4dade2 100644 --- a/gpytorch/functions/__init__.py +++ b/gpytorch/functions/__init__.py @@ -9,6 +9,7 @@ import torch from ._log_normal_cdf import LogNormalCDF +from .inv_quad_logdet import InvQuadLogdet from .matern_covariance import MaternCovariance from .rbf_covariance import RBFCovariance @@ -39,6 +40,7 @@ def inv_matmul(mat, right_tensor, left_tensor=None): __all__ = [ + "InvQuadLogdet", "MaternCovariance", "RBFCovariance", "inv_matmul", diff --git a/gpytorch/functions/inv_quad_logdet.py b/gpytorch/functions/inv_quad_logdet.py new file mode 100644 index 000000000..27e8e9459 --- /dev/null +++ b/gpytorch/functions/inv_quad_logdet.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 + +import torch + +from torch import Tensor + + +class InvQuadLogdet(torch.autograd.Function): + r"""This function computes the inverse quadratic form and the log determinant of a positive semi-definite matrix. + This is a light weight implementation of `LinearOperator.inv_quad_logdet`. The main motivation is to avoid the + overhead of linear operators for dense kernel matrices by doing linear algebra operations directly on torch tensors. + """ + + @staticmethod + def forward( + ctx, + matrix: Tensor, + inv_quad_rhs: Tensor, + ) -> tuple[Tensor, Tensor]: + r"""Compute the inverse quadratic form and the log determinant. + + :param matrix: A positive semi-definite matrix of size `(..., N, N)`. + :param inv_quad_rhs: The right-hand side vector of size `(..., N, 1)`. + :return: The inverse quadratic form and the log determinant, both of size `(...)`. + """ + chol = torch.linalg.cholesky(matrix) + + # The inverse quadratic term + inv_quad_solves = torch.cholesky_solve(inv_quad_rhs, chol) + inv_quad_term = (inv_quad_solves * inv_quad_rhs).sum(-2) + inv_quad_term = inv_quad_term.squeeze(-1) + + # The log determinant term + logdet_term = 2.0 * chol.diagonal(dim1=-1, dim2=-2).log().sum(-1) + + ctx.save_for_backward(chol, inv_quad_solves) + + return inv_quad_term, logdet_term + + @staticmethod + def backward(ctx, d_inv_quad: Tensor, d_logdet: Tensor) -> tuple[Tensor, Tensor]: + r"""Compute the backward pass for the inverse quadratic form and the log determinant. + + :param d_inv_quad: The gradient of the inverse quadratic form of size `(...)`. + :param d_logdet: The gradient of the log determinant of size `(...)`. + :return: The gradients with respect to the input matrix and the right-hand side vector. + """ + chol, inv_quad_solves = ctx.saved_tensors + + d_matrix_one = ( + -1.0 * inv_quad_solves @ inv_quad_solves.transpose(-2, -1) * d_inv_quad.unsqueeze(-1).unsqueeze(-1) + ) + d_matrix_two = torch.cholesky_inverse(chol) * d_logdet.unsqueeze(-1).unsqueeze(-1) + d_matrix = d_matrix_one + d_matrix_two + + d_inv_quad_rhs = 2.0 * inv_quad_solves * d_inv_quad.unsqueeze(-1).unsqueeze(-1) + + return d_matrix, d_inv_quad_rhs diff --git a/test/functions/test_inv_quad_logdet.py b/test/functions/test_inv_quad_logdet.py new file mode 100644 index 000000000..5cca3149b --- /dev/null +++ b/test/functions/test_inv_quad_logdet.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python3 + +import unittest + +import torch + +from gpytorch.functions.inv_quad_logdet import InvQuadLogdet +from gpytorch.kernels.rbf_kernel import RBFKernel + + +class TestInvQuadLogdet(unittest.TestCase): + def test_inv_quad_logdet(self): + # NOTE: Use small matrics here to avoid flakiness since we are testing in `float32`. + num_data = 3 + jitter = 1e-4 + + train_x = torch.linspace(0, 1, num_data).view(num_data, 1) + + # Foward and backward using `InvQuadLogdet` + covar_module = RBFKernel() + covar_matrix = covar_module(train_x).evaluate_kernel().add_jitter(jitter).to_dense() + + inv_quad_rhs = torch.linspace(0, 1, num_data).requires_grad_(True) + + inv_quad, logdet = InvQuadLogdet.apply(covar_matrix, inv_quad_rhs.unsqueeze(-1)) + inv_quad_logdet = inv_quad + logdet + inv_quad_logdet.backward() + + # Forward and backward using linear operators + covar_module_linop = RBFKernel() + covar_matrix_linop = covar_module_linop(train_x).evaluate_kernel().add_jitter(jitter) + + inv_quad_rhs_linop = inv_quad_rhs.detach().clone().requires_grad_(True) + + inv_quad_linop, logdet_linop = covar_matrix_linop.inv_quad_logdet(inv_quad_rhs_linop.unsqueeze(-1), logdet=True) + inv_quad_logdet_linop = inv_quad_linop + logdet_linop + inv_quad_logdet_linop.backward() + + self.assertTrue(torch.allclose(inv_quad, inv_quad_linop)) + self.assertTrue(torch.allclose(logdet, logdet_linop)) + self.assertTrue(torch.allclose(inv_quad_logdet, inv_quad_logdet_linop)) + self.assertTrue(torch.allclose(covar_module.raw_lengthscale.grad, covar_module_linop.raw_lengthscale.grad)) + self.assertTrue(torch.allclose(inv_quad_rhs.grad, inv_quad_rhs_linop.grad)) + + def test_batch_inv_quad_logdet(self): + num_data = 3 + jitter = 1e-4 + + train_x = torch.linspace(0, 1, 2 * num_data).view(2, num_data, 1) + + # Foward and backward using `InvQuadLogdet` + covar_module = RBFKernel(batch_shape=torch.Size([2])) + covar_matrix = covar_module(train_x).evaluate_kernel().add_jitter(jitter).to_dense() + + inv_quad_rhs = torch.linspace(0, 1, 2 * num_data).view(2, num_data).requires_grad_(True) + + inv_quad, logdet = InvQuadLogdet.apply(covar_matrix, inv_quad_rhs.unsqueeze(-1)) + inv_quad_logdet = torch.sum(inv_quad + logdet) + inv_quad_logdet.backward() + + # Forward and backward using linear operators + covar_module_linop = RBFKernel(batch_shape=torch.Size([2])) + covar_matrix_linop = covar_module_linop(train_x).evaluate_kernel().add_jitter(jitter) + + inv_quad_rhs_linop = inv_quad_rhs.detach().clone().requires_grad_(True) + + inv_quad_linop, logdet_linop = covar_matrix_linop.inv_quad_logdet(inv_quad_rhs_linop.unsqueeze(-1), logdet=True) + inv_quad_logdet_linop = torch.sum(inv_quad_linop + logdet_linop) + inv_quad_logdet_linop.backward() + + self.assertTrue(torch.allclose(inv_quad, inv_quad_linop)) + self.assertTrue(torch.allclose(logdet, logdet_linop)) + self.assertTrue(torch.allclose(inv_quad_logdet, inv_quad_logdet_linop)) + self.assertTrue(torch.allclose(covar_module.raw_lengthscale.grad, covar_module_linop.raw_lengthscale.grad)) + self.assertTrue(torch.allclose(inv_quad_rhs.grad, inv_quad_rhs_linop.grad)) + + +if __name__ == "__main__": + unittest.main() From 229b6d9d18a5423c8c7b005da416a6b738e4f5cf Mon Sep 17 00:00:00 2001 From: Kaiwen Wu Date: Tue, 25 Nov 2025 17:16:22 -0500 Subject: [PATCH 3/5] `InvQuadLogdet` -> `TensorInvQuadLogdet` --- gpytorch/distributions/multivariate_normal.py | 4 ++-- gpytorch/functions/__init__.py | 4 ++-- gpytorch/functions/inv_quad_logdet.py | 2 +- test/functions/test_inv_quad_logdet.py | 8 ++++---- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/gpytorch/distributions/multivariate_normal.py b/gpytorch/distributions/multivariate_normal.py index 345a3fc5c..b6d47b24a 100644 --- a/gpytorch/distributions/multivariate_normal.py +++ b/gpytorch/distributions/multivariate_normal.py @@ -15,7 +15,7 @@ from torch.distributions.kl import register_kl from torch.distributions.utils import _standard_normal, lazy_property -from gpytorch.functions import InvQuadLogdet +from gpytorch.functions import TensorInvQuadLogdet from .. import settings from ..utils.warnings import NumericalWarning from .distribution import Distribution @@ -253,7 +253,7 @@ def log_prob(self, value: Tensor) -> Tensor: ) and settings.use_torch_tensors.on(): # If we are to use Cholesky decomposition for inference, and we are allowed to use torch tensors as opposed # to linear operators, then do so. - inv_quad, logdet = InvQuadLogdet.apply(covar.to_dense(), diff.unsqueeze(-1)) + inv_quad, logdet = TensorInvQuadLogdet.apply(covar.to_dense(), diff.unsqueeze(-1)) else: inv_quad, logdet = covar.inv_quad_logdet(inv_quad_rhs=diff.unsqueeze(-1), logdet=True) diff --git a/gpytorch/functions/__init__.py b/gpytorch/functions/__init__.py index 63a4dade2..fba396de1 100644 --- a/gpytorch/functions/__init__.py +++ b/gpytorch/functions/__init__.py @@ -9,7 +9,7 @@ import torch from ._log_normal_cdf import LogNormalCDF -from .inv_quad_logdet import InvQuadLogdet +from .inv_quad_logdet import TensorInvQuadLogdet from .matern_covariance import MaternCovariance from .rbf_covariance import RBFCovariance @@ -40,7 +40,7 @@ def inv_matmul(mat, right_tensor, left_tensor=None): __all__ = [ - "InvQuadLogdet", + "TensorInvQuadLogdet", "MaternCovariance", "RBFCovariance", "inv_matmul", diff --git a/gpytorch/functions/inv_quad_logdet.py b/gpytorch/functions/inv_quad_logdet.py index 27e8e9459..5c3b3a798 100644 --- a/gpytorch/functions/inv_quad_logdet.py +++ b/gpytorch/functions/inv_quad_logdet.py @@ -5,7 +5,7 @@ from torch import Tensor -class InvQuadLogdet(torch.autograd.Function): +class TensorInvQuadLogdet(torch.autograd.Function): r"""This function computes the inverse quadratic form and the log determinant of a positive semi-definite matrix. This is a light weight implementation of `LinearOperator.inv_quad_logdet`. The main motivation is to avoid the overhead of linear operators for dense kernel matrices by doing linear algebra operations directly on torch tensors. diff --git a/test/functions/test_inv_quad_logdet.py b/test/functions/test_inv_quad_logdet.py index 5cca3149b..b9ebbc1b8 100644 --- a/test/functions/test_inv_quad_logdet.py +++ b/test/functions/test_inv_quad_logdet.py @@ -4,8 +4,8 @@ import torch -from gpytorch.functions.inv_quad_logdet import InvQuadLogdet -from gpytorch.kernels.rbf_kernel import RBFKernel +from gpytorch.functions import TensorInvQuadLogdet +from gpytorch.kernels import RBFKernel class TestInvQuadLogdet(unittest.TestCase): @@ -22,7 +22,7 @@ def test_inv_quad_logdet(self): inv_quad_rhs = torch.linspace(0, 1, num_data).requires_grad_(True) - inv_quad, logdet = InvQuadLogdet.apply(covar_matrix, inv_quad_rhs.unsqueeze(-1)) + inv_quad, logdet = TensorInvQuadLogdet.apply(covar_matrix, inv_quad_rhs.unsqueeze(-1)) inv_quad_logdet = inv_quad + logdet inv_quad_logdet.backward() @@ -54,7 +54,7 @@ def test_batch_inv_quad_logdet(self): inv_quad_rhs = torch.linspace(0, 1, 2 * num_data).view(2, num_data).requires_grad_(True) - inv_quad, logdet = InvQuadLogdet.apply(covar_matrix, inv_quad_rhs.unsqueeze(-1)) + inv_quad, logdet = TensorInvQuadLogdet.apply(covar_matrix, inv_quad_rhs.unsqueeze(-1)) inv_quad_logdet = torch.sum(inv_quad + logdet) inv_quad_logdet.backward() From a64b8b8b8103cc92a2815db52f143586a87fedf1 Mon Sep 17 00:00:00 2001 From: Kaiwen Wu Date: Mon, 29 Dec 2025 17:57:15 -0500 Subject: [PATCH 4/5] typo fixes and remove a redundant if condition --- gpytorch/distributions/multivariate_normal.py | 4 +--- test/functions/test_inv_quad_logdet.py | 7 ++++--- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/gpytorch/distributions/multivariate_normal.py b/gpytorch/distributions/multivariate_normal.py index b6d47b24a..66b9bbd1b 100644 --- a/gpytorch/distributions/multivariate_normal.py +++ b/gpytorch/distributions/multivariate_normal.py @@ -248,9 +248,7 @@ def log_prob(self, value: Tensor) -> Tensor: covar = covar.evaluate_kernel() - if ( - settings.fast_computations.log_prob.off() or covar.size(-1) <= settings.max_cholesky_size.value() - ) and settings.use_torch_tensors.on(): + if covar.size(-1) <= settings.max_cholesky_size.value() and settings.use_torch_tensors.on(): # If we are to use Cholesky decomposition for inference, and we are allowed to use torch tensors as opposed # to linear operators, then do so. inv_quad, logdet = TensorInvQuadLogdet.apply(covar.to_dense(), diff.unsqueeze(-1)) diff --git a/test/functions/test_inv_quad_logdet.py b/test/functions/test_inv_quad_logdet.py index b9ebbc1b8..90acc494a 100644 --- a/test/functions/test_inv_quad_logdet.py +++ b/test/functions/test_inv_quad_logdet.py @@ -10,13 +10,14 @@ class TestInvQuadLogdet(unittest.TestCase): def test_inv_quad_logdet(self): - # NOTE: Use small matrics here to avoid flakiness since we are testing in `float32`. + # NOTE: Use small matrices here to avoid flakiness since we are testing in `float32` and `torch.allclose` by + # default is pretty stringent. num_data = 3 jitter = 1e-4 train_x = torch.linspace(0, 1, num_data).view(num_data, 1) - # Foward and backward using `InvQuadLogdet` + # Forward and backward using `InvQuadLogdet` covar_module = RBFKernel() covar_matrix = covar_module(train_x).evaluate_kernel().add_jitter(jitter).to_dense() @@ -48,7 +49,7 @@ def test_batch_inv_quad_logdet(self): train_x = torch.linspace(0, 1, 2 * num_data).view(2, num_data, 1) - # Foward and backward using `InvQuadLogdet` + # Forward and backward using `InvQuadLogdet` covar_module = RBFKernel(batch_shape=torch.Size([2])) covar_matrix = covar_module(train_x).evaluate_kernel().add_jitter(jitter).to_dense() From b807e34918c04905552728785cbd8d7c1a03870d Mon Sep 17 00:00:00 2001 From: Kaiwen Wu Date: Mon, 29 Dec 2025 18:14:40 -0500 Subject: [PATCH 5/5] use `psd_safe_cholesky` --- gpytorch/functions/inv_quad_logdet.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gpytorch/functions/inv_quad_logdet.py b/gpytorch/functions/inv_quad_logdet.py index 5c3b3a798..e4f970b75 100644 --- a/gpytorch/functions/inv_quad_logdet.py +++ b/gpytorch/functions/inv_quad_logdet.py @@ -2,6 +2,7 @@ import torch +from linear_operator.utils.cholesky import psd_safe_cholesky from torch import Tensor @@ -23,7 +24,7 @@ def forward( :param inv_quad_rhs: The right-hand side vector of size `(..., N, 1)`. :return: The inverse quadratic form and the log determinant, both of size `(...)`. """ - chol = torch.linalg.cholesky(matrix) + chol = psd_safe_cholesky(matrix) # The inverse quadratic term inv_quad_solves = torch.cholesky_solve(inv_quad_rhs, chol)