@@ -85,7 +85,7 @@ def __init__(
8585 self .box = np .array (list (range (max_len )))
8686 self .len = len (self .x )
8787
88- def _mask_rna (self , x_masked : Tensor , mask_positions : list [int ]) -> Tensor :
88+ def _mask_rna (self , x_masked : Tensor , mask_positions : list [int ]) -> tuple [ Tensor , list [ int ]] :
8989 """Mask adjacent nucleotides for RNA sequences.
9090
9191 Parameters
@@ -97,20 +97,20 @@ def _mask_rna(self, x_masked: Tensor, mask_positions: list[int]) -> Tensor:
9797
9898 Returns
9999 -------
100- Tensor
101- The tensor with adjacent nucleotides masked.
100+ tuple[ Tensor, list[int]]
101+ The tensor with adjacent nucleotides masked and the list of adjacent positions .
102102 """
103103 adjacent_positions = []
104104 for pos in mask_positions :
105105 # mask position + 1 (if within bounds)
106- if pos < self .max_len - 1 :
106+ if pos < self .max_len - 1 and x_masked [ pos + 1 ] > 0 :
107107 adjacent_positions .append (pos + 1 )
108108 # mask position - 1 (if within bounds)
109- if pos > 0 :
109+ if pos > 0 and x_masked [ pos - 1 ] > 0 :
110110 adjacent_positions .append (pos - 1 )
111111 x_masked [adjacent_positions ] = self .mask_idx
112112
113- return x_masked
113+ return x_masked , adjacent_positions
114114
115115 def __len__ (self ) -> int :
116116 """
@@ -151,7 +151,7 @@ def __getitem__(self, index: int) -> tuple[Tensor, Tensor, Tensor, Tensor]:
151151 y = torch .tensor (self .y [index ], dtype = torch .int64 )
152152
153153 x_masked = x .clone ().detach ()
154- y_masked = x .clone ().detach ()
154+ y_masked = y .clone ().detach ()
155155
156156 # non-padding positions (0 is padding)
157157 seq_len = torch .sum (x_masked > 0 )
@@ -173,7 +173,8 @@ def __getitem__(self, index: int) -> tuple[Tensor, Tensor, Tensor, Tensor]:
173173
174174 # for RNA, also mask adjacent nucleotides for base pairing
175175 if self .is_rna :
176- x_masked = self ._mask_rna (x_masked , actual_mask_positions )
176+ x_masked , adjacent_positions = self ._mask_rna (x_masked , actual_mask_positions )
177+ no_mask_positions = [pos for pos in no_mask_positions if pos not in adjacent_positions ]
177178
178179 # zero out non-masked positions in target
179180 y_masked [no_mask_positions ] = 0
0 commit comments