-
Notifications
You must be signed in to change notification settings - Fork 594
Expand file tree
/
Copy pathtest_inv_quad_logdet.py
More file actions
80 lines (56 loc) · 3.4 KB
/
test_inv_quad_logdet.py
File metadata and controls
80 lines (56 loc) · 3.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#!/usr/bin/env python3
import unittest
import torch
from gpytorch.functions import TensorInvQuadLogdet
from gpytorch.kernels import RBFKernel
class TestInvQuadLogdet(unittest.TestCase):
    """Compare ``TensorInvQuadLogdet`` against the linear-operator ``inv_quad_logdet`` path.

    Each test builds the same RBF covariance matrix (plus jitter) twice:
    - once densified with ``to_dense()`` and passed to ``TensorInvQuadLogdet.apply``;
    - once kept as the lazily-evaluated kernel object and passed to its ``inv_quad_logdet``.

    It then checks that the two paths agree on:
    - the forward values (inverse-quadratic term, log-determinant, and their sum);
    - the gradients that ``backward`` produces for the kernel's ``raw_lengthscale``;
    - the gradients for the right-hand side.
    """

    def _assert_allclose(self, actual, expected, name):
        """Fail the test unless ``torch.allclose(actual, expected)`` holds.

        Uses ``torch.allclose``'s default tolerances, exactly like a bare
        ``assertTrue(torch.allclose(...))``. On failure it reports which quantity diverged
        and the largest absolute difference, instead of an uninformative "False is not true".

        :param actual: Tensor produced by the ``TensorInvQuadLogdet`` path.
        :param expected: Tensor produced by the linear-operator path.
        :param name: Human-readable label for the compared quantity, used in the failure message.
        """
        if not torch.allclose(actual, expected):
            max_abs_diff = (actual - expected).abs().max().item()
            self.fail(
                f"{name} mismatch (max abs diff {max_abs_diff:.3e}):\n"
                f"  TensorInvQuadLogdet: {actual}\n"
                f"  linear operator:     {expected}"
            )

    def _check_against_linear_operator(self, batch_shape, num_data=3, jitter=1e-4):
        """Run the forward/backward comparison for a kernel with the given ``batch_shape``.

        :param batch_shape: ``torch.Size`` batch shape of the kernel and its inputs;
            ``torch.Size([])`` means a single, non-batched covariance matrix.
        :param num_data: Number of data points, i.e. the side length of each covariance matrix.
        :param jitter: Diagonal jitter added to the covariance matrix on both paths.
        """
        # NOTE: Use small matrices here to avoid flakiness since we are testing in `float32` and `torch.allclose` by
        # default is pretty stringent.
        num_total = batch_shape.numel() * num_data  # numel() of an empty Size is 1
        train_x = torch.linspace(0, 1, num_total).view(*batch_shape, num_data, 1)

        # Forward and backward using `InvQuadLogdet`
        covar_module = RBFKernel(batch_shape=batch_shape)
        covar_matrix = covar_module(train_x).evaluate_kernel().add_jitter(jitter).to_dense()
        inv_quad_rhs = torch.linspace(0, 1, num_total).view(*batch_shape, num_data).requires_grad_(True)
        inv_quad, logdet = TensorInvQuadLogdet.apply(covar_matrix, inv_quad_rhs.unsqueeze(-1))
        # Summing turns any batch of results into the scalar `backward()` needs; it is a no-op without batching.
        inv_quad_logdet = torch.sum(inv_quad + logdet)
        inv_quad_logdet.backward()

        # Forward and backward using linear operators. A fresh kernel and a detached copy of the
        # right-hand side keep the two paths' gradients from accumulating into the same tensors.
        covar_module_linop = RBFKernel(batch_shape=batch_shape)
        covar_matrix_linop = covar_module_linop(train_x).evaluate_kernel().add_jitter(jitter)
        inv_quad_rhs_linop = inv_quad_rhs.detach().clone().requires_grad_(True)
        inv_quad_linop, logdet_linop = covar_matrix_linop.inv_quad_logdet(inv_quad_rhs_linop.unsqueeze(-1), logdet=True)
        inv_quad_logdet_linop = torch.sum(inv_quad_linop + logdet_linop)
        inv_quad_logdet_linop.backward()

        self._assert_allclose(inv_quad, inv_quad_linop, "inv_quad")
        self._assert_allclose(logdet, logdet_linop, "logdet")
        self._assert_allclose(inv_quad_logdet, inv_quad_logdet_linop, "inv_quad + logdet")
        self._assert_allclose(
            covar_module.raw_lengthscale.grad, covar_module_linop.raw_lengthscale.grad, "raw_lengthscale.grad"
        )
        self._assert_allclose(inv_quad_rhs.grad, inv_quad_rhs_linop.grad, "inv_quad_rhs.grad")

    def test_inv_quad_logdet(self):
        """Single (non-batched) covariance matrix."""
        self._check_against_linear_operator(batch_shape=torch.Size([]))

    def test_batch_inv_quad_logdet(self):
        """Batch of two covariance matrices."""
        self._check_against_linear_operator(batch_shape=torch.Size([2]))
if __name__ == "__main__":
    # Executed only when this file is run as a script; test runners that import
    # the module skip this and collect the TestCase classes directly.
    unittest.main()