Skip to content

Commit 66b0cb6

Browse files
committed
Cleaner shaped_noise_covar using linear operator
1 parent 40452ba commit 66b0cb6

File tree

1 file changed

+8
-16
lines changed

1 file changed

+8
-16
lines changed

gpytorch/likelihoods/hadamard_gaussian_likelihood.py

+8-16
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
import torch
66

7+
from linear_operator.operators import DiagLinearOperator
8+
79
from ..distributions import base_distributions, MultivariateNormal
810
from ..likelihoods import _GaussianLikelihoodBase
911
from .noise_models import MultitaskHomoskedasticNoise
@@ -54,24 +56,14 @@ def raw_noise(self, value: torch.Tensor) -> None:
5456
self.noise_covar.initialize(raw_noise=value)
5557

5658
def _shaped_noise_covar(self, base_shape: torch.Size, *params: Any, **kwargs: Any):
57-
# params contains training data
59+
# params contains task indexes
5860
task_idxs = params[0][-1]
5961
noise_base_covar_matrix = self.noise_covar(*params, shape=base_shape, **kwargs)
60-
# initialize masking
61-
mask = torch.zeros(size=noise_base_covar_matrix.shape)
62-
# for each task create a masking
63-
for task_num in range(self.num_tasks):
64-
# create vector of indexes
65-
task_idx_diag = (task_idxs == task_num).int().reshape(-1).diag()
66-
mask[..., task_num, :, :] = task_idx_diag
67-
# multiply covar by masking
68-
# there seems to be problems when base_shape is singleton, so we need to squeeze
69-
if base_shape == torch.Size([1]):
70-
noise_base_covar_matrix = noise_base_covar_matrix.squeeze(-1).mul(mask.squeeze(-1))
71-
noise_covar_matrix = noise_base_covar_matrix.unsqueeze(-1).sum(dim=1)
72-
else:
73-
noise_covar_matrix = noise_base_covar_matrix.mul(mask).sum(dim=1)
74-
return noise_covar_matrix
62+
63+
all_tasks = torch.arange(self.num_tasks)[:, None]
64+
diag = torch.eq(all_tasks, task_idxs.mT)
65+
mask = DiagLinearOperator(diag)
66+
return (noise_base_covar_matrix @ mask).sum(dim=-3)
7567

7668
def forward(
7769
self,

0 commit comments

Comments
 (0)