Skip to content

Commit 79de4a3

Browse files
committed
handle single genomic string passed in gracefully
1 parent d73f17d commit 79de4a3

2 files changed

Lines changed: 5 additions & 4 deletions

File tree

enformer_pytorch/data.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,11 @@ def cast_list(t):
4848
one_hot_embed[ord('.')] = torch.Tensor([0.25, 0.25, 0.25, 0.25])
4949

5050
def torch_fromstring(seq_strs):
51+
batched = not isinstance(seq_strs, str)
5152
seq_strs = cast_list(seq_strs)
5253
np_seq_chrs = list(map(lambda t: np.fromstring(t, dtype = np.uint8), seq_strs))
5354
seq_chrs = list(map(torch.from_numpy, np_seq_chrs))
54-
return torch.stack(seq_chrs)
55+
return torch.stack(seq_chrs) if batched else seq_chrs[0]
5556

5657
def str_to_seq_indices(seq_strs):
5758
seq_chrs = torch_fromstring(seq_strs)
@@ -152,6 +153,6 @@ def __getitem__(self, ind):
152153
seq = ('.' * left_padding) + str(chromosome[start:end]) + ('.' * right_padding)
153154

154155
if self.return_seq_indices:
155-
return str_to_seq_indices(seq).squeeze(0)
156+
return str_to_seq_indices(seq)
156157

157-
return str_to_one_hot(seq).squeeze(0)
158+
return str_to_one_hot(seq)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
name = 'enformer-pytorch',
55
packages = find_packages(exclude=[]),
66
include_package_data = True,
7-
version = '0.2.8',
7+
version = '0.2.9',
88
license='MIT',
99
description = 'Enformer - Pytorch',
1010
author = 'Phil Wang',

0 commit comments

Comments
 (0)