Skip to content

Commit 98e3602

Browse files
committed
Prep for accepting sequences in string form, with a period "." as padding
1 parent 9756c22 commit 98e3602

2 files changed

Lines changed: 9 additions & 5 deletions

File tree

enformer_pytorch/enformer_pytorch.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,19 @@ def log(t, eps = 1e-20):
3535

3636
# sequence helpers
3737

38-
def str_to_seq_indices(seq_strs):
39-
char_to_index_map = dict(a = 0, c = 1, g = 2, t = 3, n = 4)
38+
def str_to_seq_indices(seq_strs, padding = '.'):
39+
char_to_index_map = {'a': 0, 'c': 1, 'g': 2, 't': 3, 'n': 4, padding: -1}
4040
seq_strs = map(lambda x: x.lower(), seq_strs)
4141
seq_indices = list(map(lambda seq_str: torch.Tensor(list(map(lambda char: char_to_index_map[char], seq_str))), seq_strs))
4242
return torch.stack(seq_indices).long()
4343

44-
def seq_indices_to_one_hot(t):
44+
def seq_indices_to_one_hot(t, padding = -1):
45+
is_padding = t == padding
46+
t = t.clamp(min = 0)
4547
one_hot = F.one_hot(t, num_classes = 5)
46-
return one_hot[..., :4].float()
48+
out = one_hot[..., :4].float()
49+
out = out.masked_fill(is_padding[..., None], 0.25)
50+
return out
4751

4852
# losses and metrics
4953

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
name = 'enformer-pytorch',
55
packages = find_packages(exclude=[]),
66
include_package_data = True,
7-
version = '0.1.26',
7+
version = '0.1.28',
88
license='MIT',
99
description = 'Enformer - Pytorch',
1010
author = 'Phil Wang',

0 commit comments

Comments (0)