-
Notifications
You must be signed in to change notification settings - Fork 594
reduce linear operator overhead in exact marginal log likelihood computation #2682
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
1252854
21c231f
229b6d9
d4018b2
a64b8b8
b807e34
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,59 @@ | ||
| #!/usr/bin/env python3 | ||
|
|
||
| import torch | ||
|
|
||
| from linear_operator.utils.cholesky import psd_safe_cholesky | ||
| from torch import Tensor | ||
|
|
||
|
|
||
| class TensorInvQuadLogdet(torch.autograd.Function): | ||
| r"""This function computes the inverse quadratic form and the log determinant of a positive semi-definite matrix. | ||
| This is a light weight implementation of `LinearOperator.inv_quad_logdet`. The main motivation is to avoid the | ||
| overhead of linear operators for dense kernel matrices by doing linear algebra operations directly on torch tensors. | ||
| """ | ||
|
|
||
| @staticmethod | ||
| def forward( | ||
| ctx, | ||
| matrix: Tensor, | ||
| inv_quad_rhs: Tensor, | ||
| ) -> tuple[Tensor, Tensor]: | ||
| r"""Compute the inverse quadratic form and the log determinant. | ||
|
|
||
| :param matrix: A positive semi-definite matrix of size `(..., N, N)`. | ||
|
||
| :param inv_quad_rhs: The right-hand side vector of size `(..., N, 1)`. | ||
| :return: The inverse quadratic form and the log determinant, both of size `(...)`. | ||
| """ | ||
| chol = psd_safe_cholesky(matrix) | ||
|
|
||
| # The inverse quadratic term | ||
| inv_quad_solves = torch.cholesky_solve(inv_quad_rhs, chol) | ||
| inv_quad_term = (inv_quad_solves * inv_quad_rhs).sum(-2) | ||
| inv_quad_term = inv_quad_term.squeeze(-1) | ||
|
|
||
| # The log determinant term | ||
| logdet_term = 2.0 * chol.diagonal(dim1=-1, dim2=-2).log().sum(-1) | ||
|
|
||
| ctx.save_for_backward(chol, inv_quad_solves) | ||
|
kayween marked this conversation as resolved.
|
||
|
|
||
| return inv_quad_term, logdet_term | ||
|
|
||
| @staticmethod | ||
| def backward(ctx, d_inv_quad: Tensor, d_logdet: Tensor) -> tuple[Tensor, Tensor]: | ||
| r"""Compute the backward pass for the inverse quadratic form and the log determinant. | ||
|
|
||
| :param d_inv_quad: The gradient of the inverse quadratic form of size `(...)`. | ||
| :param d_logdet: The gradient of the log determinant of size `(...)`. | ||
| :return: The gradients with respect to the input matrix and the right-hand side vector. | ||
| """ | ||
| chol, inv_quad_solves = ctx.saved_tensors | ||
|
|
||
| d_matrix_one = ( | ||
| -1.0 * inv_quad_solves @ inv_quad_solves.transpose(-2, -1) * d_inv_quad.unsqueeze(-1).unsqueeze(-1) | ||
| ) | ||
| d_matrix_two = torch.cholesky_inverse(chol) * d_logdet.unsqueeze(-1).unsqueeze(-1) | ||
| d_matrix = d_matrix_one + d_matrix_two | ||
|
|
||
| d_inv_quad_rhs = 2.0 * inv_quad_solves * d_inv_quad.unsqueeze(-1).unsqueeze(-1) | ||
|
|
||
| return d_matrix, d_inv_quad_rhs | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -461,6 +461,17 @@ class use_keops(_feature_flag): | |
| _default = True | ||
|
|
||
|
|
||
| class use_torch_tensors(_feature_flag): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What do we think about making this on by default up to some N?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Indeed, the first version of this PR turns on this flag up to some N as you suggested. But the benchmark shows speed up even for N=1000 (whereas the default threshold for Cholesky decomposition is N=800). So I decided to turns this on as long as Cholesky decomposition is used for training and inference. I think the design here is intertwined with your comments below---what would happen for larger N. I'll circle back on this once we have benchmark results for larger N. |
||
| """ | ||
| Whether or not to use torch tensors instead of linear operators. If true, then we will use torch tensors as much as | ||
| possible to avoid the overhead of linear operators for dense kernel matrices. | ||
|
|
||
| (Default: False) | ||
| """ | ||
|
|
||
| _default = False | ||
|
|
||
|
|
||
| __all__ = [ | ||
| "_linalg_dtype_symeig", | ||
| "_linalg_dtype_cholesky", | ||
|
|
@@ -502,6 +513,7 @@ class use_keops(_feature_flag): | |
| "tridiagonal_jitter", | ||
| "use_keops", | ||
| "use_toeplitz", | ||
| "use_torch_tensors", | ||
| "variational_cholesky_jitter", | ||
| "verbose_linalg", | ||
| ] | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,80 @@ | ||
| #!/usr/bin/env python3 | ||
|
|
||
| import unittest | ||
|
|
||
| import torch | ||
|
|
||
| from gpytorch.functions import TensorInvQuadLogdet | ||
| from gpytorch.kernels import RBFKernel | ||
|
|
||
|
|
||
| class TestInvQuadLogdet(unittest.TestCase): | ||
| def test_inv_quad_logdet(self): | ||
| # NOTE: Use small matrices here to avoid flakiness since we are testing in `float32` and `torch.allclose` by | ||
| # default is pretty stringent. | ||
| num_data = 3 | ||
| jitter = 1e-4 | ||
|
|
||
| train_x = torch.linspace(0, 1, num_data).view(num_data, 1) | ||
|
|
||
| # Forward and backward using `InvQuadLogdet` | ||
| covar_module = RBFKernel() | ||
| covar_matrix = covar_module(train_x).evaluate_kernel().add_jitter(jitter).to_dense() | ||
|
|
||
| inv_quad_rhs = torch.linspace(0, 1, num_data).requires_grad_(True) | ||
|
|
||
| inv_quad, logdet = TensorInvQuadLogdet.apply(covar_matrix, inv_quad_rhs.unsqueeze(-1)) | ||
| inv_quad_logdet = inv_quad + logdet | ||
| inv_quad_logdet.backward() | ||
|
|
||
| # Forward and backward using linear operators | ||
| covar_module_linop = RBFKernel() | ||
| covar_matrix_linop = covar_module_linop(train_x).evaluate_kernel().add_jitter(jitter) | ||
|
|
||
| inv_quad_rhs_linop = inv_quad_rhs.detach().clone().requires_grad_(True) | ||
|
|
||
| inv_quad_linop, logdet_linop = covar_matrix_linop.inv_quad_logdet(inv_quad_rhs_linop.unsqueeze(-1), logdet=True) | ||
| inv_quad_logdet_linop = inv_quad_linop + logdet_linop | ||
| inv_quad_logdet_linop.backward() | ||
|
|
||
| self.assertTrue(torch.allclose(inv_quad, inv_quad_linop)) | ||
| self.assertTrue(torch.allclose(logdet, logdet_linop)) | ||
| self.assertTrue(torch.allclose(inv_quad_logdet, inv_quad_logdet_linop)) | ||
| self.assertTrue(torch.allclose(covar_module.raw_lengthscale.grad, covar_module_linop.raw_lengthscale.grad)) | ||
| self.assertTrue(torch.allclose(inv_quad_rhs.grad, inv_quad_rhs_linop.grad)) | ||
|
kayween marked this conversation as resolved.
|
||
|
|
||
| def test_batch_inv_quad_logdet(self): | ||
| num_data = 3 | ||
| jitter = 1e-4 | ||
|
|
||
| train_x = torch.linspace(0, 1, 2 * num_data).view(2, num_data, 1) | ||
|
|
||
| # Forward and backward using `InvQuadLogdet` | ||
| covar_module = RBFKernel(batch_shape=torch.Size([2])) | ||
| covar_matrix = covar_module(train_x).evaluate_kernel().add_jitter(jitter).to_dense() | ||
|
|
||
| inv_quad_rhs = torch.linspace(0, 1, 2 * num_data).view(2, num_data).requires_grad_(True) | ||
|
|
||
| inv_quad, logdet = TensorInvQuadLogdet.apply(covar_matrix, inv_quad_rhs.unsqueeze(-1)) | ||
| inv_quad_logdet = torch.sum(inv_quad + logdet) | ||
| inv_quad_logdet.backward() | ||
|
|
||
| # Forward and backward using linear operators | ||
| covar_module_linop = RBFKernel(batch_shape=torch.Size([2])) | ||
| covar_matrix_linop = covar_module_linop(train_x).evaluate_kernel().add_jitter(jitter) | ||
|
|
||
| inv_quad_rhs_linop = inv_quad_rhs.detach().clone().requires_grad_(True) | ||
|
|
||
| inv_quad_linop, logdet_linop = covar_matrix_linop.inv_quad_logdet(inv_quad_rhs_linop.unsqueeze(-1), logdet=True) | ||
| inv_quad_logdet_linop = torch.sum(inv_quad_linop + logdet_linop) | ||
| inv_quad_logdet_linop.backward() | ||
|
|
||
| self.assertTrue(torch.allclose(inv_quad, inv_quad_linop)) | ||
| self.assertTrue(torch.allclose(logdet, logdet_linop)) | ||
| self.assertTrue(torch.allclose(inv_quad_logdet, inv_quad_logdet_linop)) | ||
| self.assertTrue(torch.allclose(covar_module.raw_lengthscale.grad, covar_module_linop.raw_lengthscale.grad)) | ||
| self.assertTrue(torch.allclose(inv_quad_rhs.grad, inv_quad_rhs_linop.grad)) | ||
|
|
||
|
kayween marked this conversation as resolved.
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() | ||
Uh oh!
There was an error while loading. Please reload this page.