diff --git a/pyaptamer/datasets/dataclasses/_masked.py b/pyaptamer/datasets/dataclasses/_masked.py index 53bf064b..8817d6c6 100644 --- a/pyaptamer/datasets/dataclasses/_masked.py +++ b/pyaptamer/datasets/dataclasses/_masked.py @@ -151,7 +151,7 @@ def __getitem__(self, index: int) -> tuple[Tensor, Tensor, Tensor, Tensor]: y = torch.tensor(self.y[index], dtype=torch.int64) x_masked = x.clone().detach() - y_masked = x.clone().detach() + y_masked = y.clone().detach() # non-padding positions (0 is padding) seq_len = torch.sum(x_masked > 0)