4  |  4 |
5  |  5 | import torch
6  |  6 |
   |  7 | +from linear_operator.operators import DiagLinearOperator
   |  8 | +
7  |  9 | from ..distributions import base_distributions, MultivariateNormal
8  | 10 | from ..likelihoods import _GaussianLikelihoodBase
9  | 11 | from .noise_models import MultitaskHomoskedasticNoise

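For readers unfamiliar with the new import: `DiagLinearOperator` wraps a tensor of diagonals as a lazy (optionally batched) diagonal matrix, so products against it never materialize a dense n x n matrix. A minimal sketch of its behavior, separate from the diff (the tensor values below are made up for illustration):

```python
import torch
from linear_operator.operators import DiagLinearOperator

# Two 3 x 3 diagonal matrices, represented only by their diagonals (shape 2 x 3).
diag = torch.tensor([[1.0, 0.0, 1.0],
                     [0.0, 1.0, 0.0]])
op = DiagLinearOperator(diag)

print(op.shape)                # torch.Size([2, 3, 3])
vec = torch.ones(3, 1)
print((op @ vec).squeeze(-1))  # equals `diag`: the matmul reduces to an element-wise product
```
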
@@ -54,24 +56,14 @@ def raw_noise(self, value: torch.Tensor) -> None:
54 | 56 |         self.noise_covar.initialize(raw_noise=value)
55 | 57 |
56 | 58 |     def _shaped_noise_covar(self, base_shape: torch.Size, *params: Any, **kwargs: Any):
57 |    | -        # params contains training data
   | 59 | +        # params contains task indexes
58 | 60 |         task_idxs = params[0][-1]
59 | 61 |         noise_base_covar_matrix = self.noise_covar(*params, shape=base_shape, **kwargs)
60 |    | -        # initialize masking
61 |    | -        mask = torch.zeros(size=noise_base_covar_matrix.shape)
62 |    | -        # for each task create a masking
63 |    | -        for task_num in range(self.num_tasks):
64 |    | -            # create vector of indexes
65 |    | -            task_idx_diag = (task_idxs == task_num).int().reshape(-1).diag()
66 |    | -            mask[..., task_num, :, :] = task_idx_diag
67 |    | -        # multiply covar by masking
68 |    | -        # there seems to be problems when base_shape is singleton, so we need to squeeze
69 |    | -        if base_shape == torch.Size([1]):
70 |    | -            noise_base_covar_matrix = noise_base_covar_matrix.squeeze(-1).mul(mask.squeeze(-1))
71 |    | -            noise_covar_matrix = noise_base_covar_matrix.unsqueeze(-1).sum(dim=1)
72 |    | -        else:
73 |    | -            noise_covar_matrix = noise_base_covar_matrix.mul(mask).sum(dim=1)
74 |    | -        return noise_covar_matrix
   | 62 | +
   | 63 | +        all_tasks = torch.arange(self.num_tasks)[:, None]
   | 64 | +        diag = torch.eq(all_tasks, task_idxs.mT)
   | 65 | +        mask = DiagLinearOperator(diag)
   | 66 | +        return (noise_base_covar_matrix @ mask).sum(dim=-3)
75 | 67 |
76 | 68 |     def forward(
77 | 69 |         self,
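To see what the replacement computes, here is a self-contained sketch of the same masking idea with made-up shapes and values (`noise_levels`, `task_idxs`, and the printed result are illustrative assumptions, not part of the PR): each row of the broadcasted comparison marks the observations belonging to one task, the resulting diagonal masks zero out every other task's noise, and summing over the task batch dimension leaves each observation with exactly its own task's noise variance.

```python
import torch
from linear_operator.operators import DiagLinearOperator

num_tasks, n = 3, 5

# Made-up per-task homoskedastic noise levels: a batch of n x n diagonal
# covariances, one per task, with shape (num_tasks, n, n).
noise_levels = torch.tensor([0.1, 0.5, 1.0])
noise_base_covar = DiagLinearOperator(noise_levels[:, None].expand(num_tasks, n))

# One task index per observation, as a column vector like `task_idxs` in the diff.
task_idxs = torch.tensor([[0], [2], [1], [0], [2]])  # shape (n, 1)

# Broadcast comparison: row t is True exactly where an observation belongs to task t.
all_tasks = torch.arange(num_tasks)[:, None]   # (num_tasks, 1)
diag = torch.eq(all_tasks, task_idxs.mT)       # (num_tasks, n), boolean
mask = DiagLinearOperator(diag.to(noise_levels.dtype))

# Masking zeroes each task's noise on every other task's observations; summing over
# the task batch dimension (-3) collapses the batch into a single n x n covariance.
# (Densified here for readability; the likelihood keeps everything as LinearOperators.)
noise_covar = (noise_base_covar @ mask).to_dense().sum(dim=-3)
print(noise_covar.diagonal())
# tensor([0.1000, 1.0000, 0.5000, 0.1000, 1.0000]) -- each point gets its own task's noise
```

Compared with the removed loop, this avoids allocating a dense zero mask of the full covariance shape and drops the special case for singleton `base_shape`, since the broadcasted comparison and lazy diagonal handle arbitrary batch shapes directly.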