Skip to content

Commit d73f17d

Browse files
committed
maximize performance with embedding lookup for genomic string to one hot / seq indices
1 parent f90b082 commit d73f17d

4 files changed

Lines changed: 57 additions & 22 deletions

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ Special thanks goes out to <a href="https://www.eleuther.ai/">EleutherAI</a> for
322322
- [x] allow for fine tuning with only unfrozen layernorms (technique from fine tuning transformers)
323323
- [x] fix handling of 'N' in sequence, figure out representation of N in basenji barnyard
324324
- [x] take care of shift augmentation in `GenomicIntervalDataset`
325-
- [ ] speed up `str_to_seq_indices` using https://github.com/lucidrains/enformer-tensorflow-sonnet-training-script/blob/main/sequence.py#L12-L27
325+
- [x] speed up `str_to_seq_indices`
326326
- [ ] offer some basic training utils, as gradient accumulation will be needed for fine tuning
327327
- [ ] add to EleutherAI huggingface
328328

enformer_pytorch/data.py

Lines changed: 51 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
11
import torch
22
import torch.nn.functional as F
3+
from torch.utils.data import Dataset
4+
5+
import polars as pl
6+
import numpy as np
7+
from random import randrange
8+
from pathlib import Path
9+
from pyfaidx import Fasta
10+
11+
# helper functions
312

413
def exists(val):
514
return val is not None
@@ -10,12 +19,47 @@ def identity(t):
1019
def cast_list(t):
1120
return t if isinstance(t, list) else [t]
1221

13-
def str_to_seq_indices(seq_strs, padding = '.'):
22+
# genomic function transforms
23+
24+
seq_indices_embed = torch.zeros(256).long()
25+
seq_indices_embed[ord('a')] = 0
26+
seq_indices_embed[ord('c')] = 1
27+
seq_indices_embed[ord('g')] = 2
28+
seq_indices_embed[ord('t')] = 3
29+
seq_indices_embed[ord('n')] = 4
30+
seq_indices_embed[ord('A')] = 0
31+
seq_indices_embed[ord('C')] = 1
32+
seq_indices_embed[ord('G')] = 2
33+
seq_indices_embed[ord('T')] = 3
34+
seq_indices_embed[ord('N')] = 4
35+
seq_indices_embed[ord('.')] = -1
36+
37+
one_hot_embed = torch.zeros(256, 4)
38+
one_hot_embed[ord('a')] = torch.Tensor([1., 0., 0., 0.])
39+
one_hot_embed[ord('c')] = torch.Tensor([0., 1., 0., 0.])
40+
one_hot_embed[ord('g')] = torch.Tensor([0., 0., 1., 0.])
41+
one_hot_embed[ord('t')] = torch.Tensor([0., 0., 0., 1.])
42+
one_hot_embed[ord('n')] = torch.Tensor([0., 0., 0., 0.])
43+
one_hot_embed[ord('A')] = torch.Tensor([1., 0., 0., 0.])
44+
one_hot_embed[ord('C')] = torch.Tensor([0., 1., 0., 0.])
45+
one_hot_embed[ord('G')] = torch.Tensor([0., 0., 1., 0.])
46+
one_hot_embed[ord('T')] = torch.Tensor([0., 0., 0., 1.])
47+
one_hot_embed[ord('N')] = torch.Tensor([0., 0., 0., 0.])
48+
one_hot_embed[ord('.')] = torch.Tensor([0.25, 0.25, 0.25, 0.25])
49+
50+
def torch_fromstring(seq_strs):
1451
seq_strs = cast_list(seq_strs)
15-
char_to_index_map = {'a': 0, 'c': 1, 'g': 2, 't': 3, 'n': 4, padding: -1}
16-
seq_strs = map(lambda x: x.lower(), seq_strs)
17-
seq_indices = list(map(lambda seq_str: torch.Tensor(list(map(lambda char: char_to_index_map[char], seq_str))), seq_strs))
18-
return torch.stack(seq_indices).long()
52+
np_seq_chrs = list(map(lambda t: np.fromstring(t, dtype = np.uint8), seq_strs))
53+
seq_chrs = list(map(torch.from_numpy, np_seq_chrs))
54+
return torch.stack(seq_chrs)
55+
56+
def str_to_seq_indices(seq_strs):
57+
seq_chrs = torch_fromstring(seq_strs)
58+
return seq_indices_embed[seq_chrs.long()]
59+
60+
def str_to_one_hot(seq_strs):
61+
seq_chrs = torch_fromstring(seq_strs)
62+
return one_hot_embed[seq_chrs.long()]
1963

2064
def seq_indices_to_one_hot(t, padding = -1):
2165
is_padding = t == padding
@@ -27,12 +71,6 @@ def seq_indices_to_one_hot(t, padding = -1):
2771

2872
# processing bed files
2973

30-
import polars as pl
31-
from random import randrange
32-
from pathlib import Path
33-
from pyfaidx import Fasta
34-
from torch.utils.data import Dataset
35-
3674
class GenomeIntervalDataset(Dataset):
3775
def __init__(
3876
self,
@@ -112,10 +150,8 @@ def __getitem__(self, ind):
112150
end = chromosome_length
113151

114152
seq = ('.' * left_padding) + str(chromosome[start:end]) + ('.' * right_padding)
115-
seq_indices = str_to_seq_indices(seq)
116153

117154
if self.return_seq_indices:
118-
return seq_indices.squeeze(0)
155+
return str_to_seq_indices(seq).squeeze(0)
119156

120-
seq_onehot = seq_indices_to_one_hot(seq_indices)
121-
return seq_onehot.squeeze(0)
157+
return str_to_one_hot(seq).squeeze(0)

enformer_pytorch/enformer_pytorch.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from einops import rearrange, reduce
88
from einops.layers.torch import Rearrange
99

10-
from enformer_pytorch.data import str_to_seq_indices, seq_indices_to_one_hot
10+
from enformer_pytorch.data import str_to_one_hot, seq_indices_to_one_hot
1111

1212
# constants
1313

@@ -400,11 +400,9 @@ def forward(
400400
head = None
401401
):
402402
if isinstance(x, list):
403-
x = str_to_seq_indices(x)
403+
x = str_to_one_hot(x)
404404

405-
dtype = x.dtype
406-
407-
if x.dtype == torch.long:
405+
elif x.dtype == torch.long:
408406
x = seq_indices_to_one_hot(x)
409407

410408
no_batch = x.ndim == 2

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
name = 'enformer-pytorch',
55
packages = find_packages(exclude=[]),
66
include_package_data = True,
7-
version = '0.2.7',
7+
version = '0.2.8',
88
license='MIT',
99
description = 'Enformer - Pytorch',
1010
author = 'Phil Wang',
@@ -17,6 +17,7 @@
1717
],
1818
install_requires=[
1919
'einops>=0.3',
20+
'numpy',
2021
'torch>=1.6',
2122
'polars',
2223
'pyfaidx',

0 commit comments

Comments
 (0)