Skip to content

Commit fdbb114

Browse files
committed
allow for reverse complement augmentation from FastaInterval
1 parent 8df251c commit fdbb114

2 files changed

Lines changed: 20 additions & 6 deletions

File tree

enformer_pytorch/data.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import polars as pl
66
import numpy as np
7-
from random import randrange
7+
from random import randrange, random
88
from pathlib import Path
99
from pyfaidx import Fasta
1010

@@ -19,6 +19,9 @@ def identity(t):
1919
def cast_list(t):
2020
return t if isinstance(t, list) else [t]
2121

22+
def coin_flip():
23+
return random() > 0.5
24+
2225
# genomic function transforms
2326

2427
seq_indices_embed = torch.zeros(256).long()
@@ -86,7 +89,8 @@ def __init__(
8689
fasta_file,
8790
context_length = None,
8891
return_seq_indices = False,
89-
shift_augs = None
92+
shift_augs = None,
93+
rc_aug = False
9094
):
9195
fasta_file = Path(fasta_file)
9296
assert fasta_file.exists(), 'path to fasta file must exist'
@@ -95,6 +99,7 @@ def __init__(
9599
self.return_seq_indices = return_seq_indices
96100
self.context_length = context_length
97101
self.shift_augs = shift_augs
102+
self.rc_aug = rc_aug
98103

99104
def __call__(self, chr_name, start, end):
100105
interval_length = end - start
@@ -134,9 +139,16 @@ def __call__(self, chr_name, start, end):
134139
seq = ('.' * left_padding) + str(chromosome[start:end]) + ('.' * right_padding)
135140

136141
if self.return_seq_indices:
142+
assert not self.rc_aug, 'reverse complement augmentation not available yet for seq indices'
143+
137144
return str_to_seq_indices(seq)
138145

139-
return str_to_one_hot(seq)
146+
one_hot = str_to_one_hot(seq)
147+
148+
if self.rc_aug and coin_flip():
149+
one_hot = one_hot_reverse_complement(one_hot)
150+
151+
return one_hot
140152

141153

142154
class GenomeIntervalDataset(Dataset):
@@ -148,7 +160,8 @@ def __init__(
148160
chr_bed_to_fasta_map = dict(),
149161
context_length = None,
150162
return_seq_indices = False,
151-
shift_augs = None
163+
shift_augs = None,
164+
rc_aug = False
152165
):
153166
super().__init__()
154167
bed_path = Path(bed_file)
@@ -166,7 +179,8 @@ def __init__(
166179
fasta_file = fasta_file,
167180
context_length = context_length,
168181
return_seq_indices = return_seq_indices,
169-
shift_augs = shift_augs
182+
shift_augs = shift_augs,
183+
rc_aug = rc_aug
170184
)
171185

172186
def __len__(self):

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
name = 'enformer-pytorch',
55
packages = find_packages(exclude=[]),
66
include_package_data = True,
7-
version = '0.3.0',
7+
version = '0.3.1',
88
license='MIT',
99
description = 'Enformer - Pytorch',
1010
author = 'Phil Wang',

0 commit comments

Comments
 (0)