Skip to content

Commit 73244fc

Browse files
committed
add GenomicIntervalDataset, for easy data fetching from .bed and .fasta files
1 parent 342915c commit 73244fc

5 files changed

Lines changed: 125 additions & 20 deletions

File tree

README.md

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ You can also directly pass in the sequence as one-hot encodings, which must be f
3535

3636
```python
3737
import torch
38-
import torch.nn.functional as F
3938
from enformer_pytorch import Enformer, seq_indices_to_one_hot
4039

4140
model = Enformer(
@@ -59,7 +58,6 @@ Finally, one can fetch the embeddings, for fine-tuning and otherwise, by setting
5958

6059
```python
6160
import torch
62-
import torch.nn.functional as F
6361
from enformer_pytorch import Enformer, seq_indices_to_one_hot
6462

6563
model = Enformer(
@@ -266,6 +264,34 @@ loss = model(
266264
loss.backward()
267265
```
268266

267+
## Data
268+
269+
You can use the `GenomicIntervalDataset` to easily fetch sequences of any length from a `.bed` file, with greater context length dynamically computed if specified
270+
271+
```python
272+
import torch
273+
from enformer_pytorch import Enformer, GenomeIntervalDataset
274+
275+
ds = GenomeIntervalDataset(
276+
bed_file = './sequences.bed', # bed file
277+
fasta_file = './hg38.fa', # path to fasta file
278+
context_length = 196_608
279+
# this can be longer than the interval designated in the .bed file,
280+
# in which case it will take care of lengthening the interval on either sides
281+
# as well as proper padding if at the end of the chromosomes
282+
)
283+
284+
model = Enformer(
285+
dim = 1536,
286+
depth = 11,
287+
heads = 8,
288+
output_heads = dict(human = 5313, mouse = 1643),
289+
target_length = 896,
290+
)
291+
292+
pred = model(ds[0], head = 'human') # (896, 5313)
293+
```
294+
269295
## Appreciation
270296

271297
Special thanks goes out to <a href="https://www.eleuther.ai/">EleutherAI</a> for providing the resources to retrain the model in an acceptable amount of time

enformer_pytorch/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
1-
from enformer_pytorch.enformer_pytorch import Enformer, SEQUENCE_LENGTH, AttentionPool, seq_indices_to_one_hot
1+
from enformer_pytorch.enformer_pytorch import Enformer, SEQUENCE_LENGTH, AttentionPool
22
from enformer_pytorch.model_loader import load_pretrained_model
3+
from enformer_pytorch.data import seq_indices_to_one_hot, GenomeIntervalDataset

enformer_pytorch/data.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import torch
2+
import torch.nn.functional as F
3+
4+
def exists(val):
5+
return val is not None
6+
7+
def identity(t):
8+
return t
9+
10+
def cast_list(t):
11+
return t if isinstance(t, list) else [t]
12+
13+
def str_to_seq_indices(seq_strs, padding = '.'):
14+
seq_strs = cast_list(seq_strs)
15+
char_to_index_map = {'a': 0, 'c': 1, 'g': 2, 't': 3, 'n': 4, padding: -1}
16+
seq_strs = map(lambda x: x.lower(), seq_strs)
17+
seq_indices = list(map(lambda seq_str: torch.Tensor(list(map(lambda char: char_to_index_map[char], seq_str))), seq_strs))
18+
return torch.stack(seq_indices).long()
19+
20+
def seq_indices_to_one_hot(t, padding = -1):
21+
is_padding = t == padding
22+
t = t.clamp(min = 0)
23+
one_hot = F.one_hot(t, num_classes = 5)
24+
out = one_hot[..., :4].float()
25+
out = out.masked_fill(is_padding[..., None], 0.25)
26+
return out
27+
28+
# processing bed files
29+
30+
import pandas as pd
31+
from pathlib import Path
32+
from pyfaidx import Fasta
33+
from torch.utils.data import Dataset
34+
35+
class GenomeIntervalDataset(Dataset):
36+
def __init__(
37+
self,
38+
bed_file,
39+
fasta_file,
40+
context_length = None,
41+
filter_df_fn = identity
42+
):
43+
super().__init__()
44+
bed_path = Path(bed_file)
45+
fasta_file = Path(fasta_file)
46+
47+
assert bed_path.exists(), 'path to .bed file must exist'
48+
assert fasta_file.exists(), 'path to fasta file must exist'
49+
50+
df = pd.read_csv(str(bed_path), sep = '\t', header = None, names = ['chr', 'start', 'end', 'type'])
51+
df = filter_df_fn(df)
52+
53+
self.df = df
54+
self.seqs = Fasta(str(fasta_file))
55+
self.context_length = context_length
56+
57+
def __len__(self):
58+
return len(self.df)
59+
60+
def __getitem__(self, ind):
61+
interval = self.df.iloc[ind]
62+
chr_name, start, end = (interval.chr, interval.start, interval.end)
63+
interval_length = end - start
64+
65+
chromosome = self.seqs[chr_name]
66+
chromosome_length = len(chromosome)
67+
68+
left_padding = right_padding = 0
69+
70+
if exists(self.context_length) and interval_length < self.context_length:
71+
extra_seq = self.context_length - interval_length
72+
73+
extra_left_seq = extra_seq // 2
74+
extra_right_seq = extra_seq - extra_left_seq
75+
76+
start -= extra_left_seq
77+
end += extra_right_seq
78+
79+
if start < 0:
80+
left_padding = -start
81+
start = 0
82+
83+
if end > chromosome_length:
84+
right_padding = end - chromosome_length
85+
end = chromosome_length
86+
87+
seq = ('.' * left_padding) + str(chromosome[start:end]) + ('.' * right_padding)
88+
seq_indices = str_to_seq_indices(seq)
89+
seq_onehot = seq_indices_to_one_hot(seq_indices)
90+
return seq_onehot.squeeze(0)

enformer_pytorch/enformer_pytorch.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from einops import rearrange, reduce
88
from einops.layers.torch import Rearrange
99

10+
from enformer_pytorch.data import str_to_seq_indices, seq_indices_to_one_hot
11+
1012
# constants
1113

1214
SEQUENCE_LENGTH = 196_608
@@ -33,22 +35,6 @@ def _round(x):
3335
def log(t, eps = 1e-20):
3436
return torch.log(t.clamp(min = eps))
3537

36-
# sequence helpers
37-
38-
def str_to_seq_indices(seq_strs, padding = '.'):
39-
char_to_index_map = {'a': 0, 'c': 1, 'g': 2, 't': 3, 'n': 4, padding: -1}
40-
seq_strs = map(lambda x: x.lower(), seq_strs)
41-
seq_indices = list(map(lambda seq_str: torch.Tensor(list(map(lambda char: char_to_index_map[char], seq_str))), seq_strs))
42-
return torch.stack(seq_indices).long()
43-
44-
def seq_indices_to_one_hot(t, padding = -1):
45-
is_padding = t == padding
46-
t = t.clamp(min = 0)
47-
one_hot = F.one_hot(t, num_classes = 5)
48-
out = one_hot[..., :4].float()
49-
out = out.masked_fill(is_padding[..., None], 0.25)
50-
return out
51-
5238
# losses and metrics
5339

5440
def poisson_loss(pred, target):

setup.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
name = 'enformer-pytorch',
55
packages = find_packages(exclude=[]),
66
include_package_data = True,
7-
version = '0.1.29',
7+
version = '0.2.0',
88
license='MIT',
99
description = 'Enformer - Pytorch',
1010
author = 'Phil Wang',
@@ -18,6 +18,8 @@
1818
install_requires=[
1919
'einops>=0.3',
2020
'torch>=1.6',
21+
'pandas',
22+
'pyfaidx',
2123
'pyyaml'
2224
],
2325
classifiers=[

0 commit comments

Comments
 (0)