Skip to content

Commit adb23bf

Browse files
authored
Skip sequences that are less than minimum sequence length (#44)
1 parent 8270e52 commit adb23bf

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

src/gtnet/sequence.py

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,8 @@ def __init__(self, window, step, vocab=None, padval=None, min_seq_len=100, devic
3434
self.device = device
3535

3636
def encode(self, seq):
37+
if len(seq) < self.min_seq_len:
38+
raise ValueError(f"Minimum sequence length is {self.min_seq_len} - got {len(seq)}")
3739
if seq.dtype == np.dtype('S1'):
3840
seq = seq.view(np.uint8)
3941
elif seq.dtype == np.dtype('U1'):
@@ -143,9 +145,14 @@ def readfiles(cls, encoder, fastas):
143145
for fa in fastas:
144146
logging.debug(f'loading {fa}')
145147
for seqid, values in cls.readfile(fa):
146-
batches = encoder.encode(values)
147-
val = (fa, seqid, len(values), batches)
148-
yield val
148+
if len(values) < encoder.min_seq_len:
149+
logging.warning((f"Skipping {seqid} from {fa} - length less than "
150+
"minimum sequence length {encoder.min_seq}"))
151+
yield (fa, seqid, len(values), torch.zeros((0, 0, 0), dtype=torch.uint8))
152+
else:
153+
batches = encoder.encode(values)
154+
val = (fa, seqid, len(values), batches)
155+
yield val
149156

150157

151158
class SerialLoader(Loader):

0 commit comments

Comments
 (0)