Skip to content

Commit 99367ef

Browse files
committed
able to return nucleotide indices from dataset
1 parent fb6bd3f commit 99367ef

3 files changed

Lines changed: 10 additions & 4 deletions

File tree

README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,7 @@ ds = GenomeIntervalDataset(
276276
bed_file = './sequences.bed', # bed file
277277
fasta_file = './hg38.ml.fa', # path to fasta file
278278
filter_df_fn = lambda df: df[df.type == 'train'], # filter dataframe function
279+
return_seq_indices = True, # return nucleotide indices (ACGTN) or one hot encodings
279280
context_length = 196_608,
280281
# this can be longer than the interval designated in the .bed file,
281282
# in which case it will take care of lengthening the interval on either sides
@@ -290,7 +291,8 @@ model = Enformer(
290291
target_length = 896,
291292
)
292293

293-
pred = model(ds[0], head = 'human') # (896, 5313)
294+
seq = ds[0] # (196608,)
295+
pred = model(seq, head = 'human') # (896, 5313)
294296
```
295297

296298
## Appreciation

enformer_pytorch/data.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def __init__(
3838
bed_file,
3939
fasta_file,
4040
context_length = None,
41-
shift_augmentation_range = None,
41+
return_seq_indices = False,
4242
filter_df_fn = identity
4343
):
4444
super().__init__()
@@ -54,7 +54,7 @@ def __init__(
5454
self.df = df
5555
self.seqs = Fasta(str(fasta_file))
5656
self.context_length = context_length
57-
self.shift_augmentation = shift_augmentation
57+
self.return_seq_indices = return_seq_indices
5858

5959
def __len__(self):
6060
return len(self.df)
@@ -88,5 +88,9 @@ def __getitem__(self, ind):
8888

8989
seq = ('.' * left_padding) + str(chromosome[start:end]) + ('.' * right_padding)
9090
seq_indices = str_to_seq_indices(seq)
91+
92+
if self.return_seq_indices:
93+
return seq_indices.squeeze(0)
94+
9195
seq_onehot = seq_indices_to_one_hot(seq_indices)
9296
return seq_onehot.squeeze(0)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
name = 'enformer-pytorch',
55
packages = find_packages(exclude=[]),
66
include_package_data = True,
7-
version = '0.2.0',
7+
version = '0.2.1',
88
license='MIT',
99
description = 'Enformer - Pytorch',
1010
author = 'Phil Wang',

0 commit comments

Comments
 (0)