Skip to content

Commit ec5b3c6

Browse files
committed
separate fasta interval logic from bed file parsing
1 parent 1c8e733 commit ec5b3c6

3 files changed

Lines changed: 47 additions & 33 deletions

File tree

enformer_pytorch/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from enformer_pytorch.enformer_pytorch import Enformer, SEQUENCE_LENGTH, AttentionPool
22
from enformer_pytorch.model_loader import load_pretrained_model
3-
from enformer_pytorch.data import seq_indices_to_one_hot, GenomeIntervalDataset
3+
from enformer_pytorch.data import seq_indices_to_one_hot, str_to_one_hot, GenomeIntervalDataset, FastaInterval

enformer_pytorch/data.py

Lines changed: 45 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -72,51 +72,25 @@ def seq_indices_to_one_hot(t, padding = -1):
7272

7373
# processing bed files
7474

75-
class GenomeIntervalDataset(Dataset):
75+
class FastaInterval():
7676
def __init__(
7777
self,
78-
bed_file,
78+
*,
7979
fasta_file,
8080
context_length = None,
8181
return_seq_indices = False,
82-
filter_df_fn = identity,
83-
shift_augs = None,
84-
chr_bed_to_fasta_map = dict()
82+
shift_augs = None
8583
):
86-
super().__init__()
87-
bed_path = Path(bed_file)
8884
fasta_file = Path(fasta_file)
89-
90-
assert bed_path.exists(), 'path to .bed file must exist'
9185
assert fasta_file.exists(), 'path to fasta file must exist'
9286

93-
df = pl.read_csv(str(bed_path), sep = '\t', has_headers = False)
94-
df = filter_df_fn(df)
95-
96-
self.df = df
9787
self.seqs = Fasta(str(fasta_file))
98-
self.context_length = context_length
9988
self.return_seq_indices = return_seq_indices
100-
101-
if exists(shift_augs):
102-
assert len(shift_augs) == 2, 'shift augs needs to be a tuple of 2, indicating min and max relative shifts inclusive - ex. (-2, 2) for [-2, -1, 0, 1, 2]'
103-
89+
self.context_length = context_length
10490
self.shift_augs = shift_augs
10591

106-
# if the chromosome name in the bed file is different than the keyname in the fasta
107-
# can remap on the fly
108-
self.chr_bed_to_fasta_map = chr_bed_to_fasta_map
109-
110-
def __len__(self):
111-
return len(self.df)
112-
113-
def __getitem__(self, ind):
114-
interval = self.df.row(ind)
115-
chr_name, start, end = (interval[0], interval[1], interval[2])
92+
def __call__(self, chr_name, start, end):
11693
interval_length = end - start
117-
118-
chr_name = self.chr_bed_to_fasta_map.get(chr_name, chr_name)
119-
12094
chromosome = self.seqs[chr_name]
12195
chromosome_length = len(chromosome)
12296

@@ -156,3 +130,43 @@ def __getitem__(self, ind):
156130
return str_to_seq_indices(seq)
157131

158132
return str_to_one_hot(seq)
133+
134+
135+
class GenomeIntervalDataset(Dataset):
136+
def __init__(
137+
self,
138+
bed_file,
139+
fasta_file,
140+
filter_df_fn = identity,
141+
chr_bed_to_fasta_map = dict(),
142+
context_length = None,
143+
return_seq_indices = False,
144+
shift_augs = None
145+
):
146+
super().__init__()
147+
bed_path = Path(bed_file)
148+
assert bed_path.exists(), 'path to .bed file must exist'
149+
150+
df = pl.read_csv(str(bed_path), sep = '\t', has_headers = False)
151+
df = filter_df_fn(df)
152+
self.df = df
153+
154+
# if the chromosome name in the bed file is different than the keyname in the fasta
155+
# can remap on the fly
156+
self.chr_bed_to_fasta_map = chr_bed_to_fasta_map
157+
158+
self.fasta = FastaInterval(
159+
fasta_file = fasta_file,
160+
context_length = context_length,
161+
return_seq_indices = return_seq_indices,
162+
shift_augs = shift_augs
163+
)
164+
165+
def __len__(self):
166+
return len(self.df)
167+
168+
def __getitem__(self, ind):
169+
interval = self.df.row(ind)
170+
chr_name, start, end = (interval[0], interval[1], interval[2])
171+
chr_name = self.chr_bed_to_fasta_map.get(chr_name, chr_name)
172+
return self.fasta(chr_name, start, end)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
name = 'enformer-pytorch',
55
packages = find_packages(exclude=[]),
66
include_package_data = True,
7-
version = '0.2.10',
7+
version = '0.2.11',
88
license='MIT',
99
description = 'Enformer - Pytorch',
1010
author = 'Phil Wang',

0 commit comments

Comments
 (0)