Skip to content

Commit 9756c22

Browse files
committed
add util function for converting the string representation of a genetic sequence to one-hot
1 parent c9600d8 commit 9756c22

2 files changed

Lines changed: 13 additions & 8 deletions

File tree

enformer_pytorch/enformer_pytorch.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
import math
22
import torch
33
from torch import nn, einsum
4+
import torch.nn.functional as F
45
from torch.utils.checkpoint import checkpoint_sequential
6+
57
from einops import rearrange, reduce
68
from einops.layers.torch import Rearrange
79

8-
import torch.nn.functional as F
9-
1010
# constants
1111

1212
SEQUENCE_LENGTH = 196_608
@@ -33,12 +33,17 @@ def _round(x):
3333
def log(t, eps = 1e-20):
3434
return torch.log(t.clamp(min = eps))
3535

36+
# sequence helpers
37+
38+
def str_to_seq_indices(seq_strs):
39+
char_to_index_map = dict(a = 0, c = 1, g = 2, t = 3, n = 4)
40+
seq_strs = map(lambda x: x.lower(), seq_strs)
41+
seq_indices = list(map(lambda seq_str: torch.Tensor(list(map(lambda char: char_to_index_map[char], seq_str))), seq_strs))
42+
return torch.stack(seq_indices).long()
43+
3644
def seq_indices_to_one_hot(t):
37-
wildcard = t == 4 # the Ns in the sequence
38-
t = t.clamp(max = 3)
39-
one_hot = F.one_hot(t, num_classes = 4)
40-
one_hot = one_hot.masked_fill(wildcard[..., None], 0.)
41-
return one_hot.float()
45+
one_hot = F.one_hot(t, num_classes = 5)
46+
return one_hot[..., :4].float()
4247

4348
# losses and metrics
4449

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
name = 'enformer-pytorch',
55
packages = find_packages(exclude=[]),
66
include_package_data = True,
7-
version = '0.1.25',
7+
version = '0.1.26',
88
license='MIT',
99
description = 'Enformer - Pytorch',
1010
author = 'Phil Wang',

0 commit comments

Comments
 (0)