44
55import polars as pl
66import numpy as np
7- from random import randrange
7+ from random import randrange , random
88from pathlib import Path
99from pyfaidx import Fasta
1010
@@ -19,6 +19,9 @@ def identity(t):
1919def cast_list (t ):
2020 return t if isinstance (t , list ) else [t ]
2121
22+ def coin_flip ():
23+ return random () > 0.5
24+
2225# genomic function transforms
2326
2427seq_indices_embed = torch .zeros (256 ).long ()
@@ -86,7 +89,8 @@ def __init__(
8689 fasta_file ,
8790 context_length = None ,
8891 return_seq_indices = False ,
89- shift_augs = None
92+ shift_augs = None ,
93+ rc_aug = False
9094 ):
9195 fasta_file = Path (fasta_file )
9296 assert fasta_file .exists (), 'path to fasta file must exist'
@@ -95,6 +99,7 @@ def __init__(
9599 self .return_seq_indices = return_seq_indices
96100 self .context_length = context_length
97101 self .shift_augs = shift_augs
102+ self .rc_aug = rc_aug
98103
99104 def __call__ (self , chr_name , start , end ):
100105 interval_length = end - start
@@ -134,9 +139,16 @@ def __call__(self, chr_name, start, end):
134139 seq = ('.' * left_padding ) + str (chromosome [start :end ]) + ('.' * right_padding )
135140
136141 if self .return_seq_indices :
142+ assert not self .rc_aug , 'reverse complement augmentation not available yet for seq indices'
143+
137144 return str_to_seq_indices (seq )
138145
139- return str_to_one_hot (seq )
146+ one_hot = str_to_one_hot (seq )
147+
148+ if self .rc_aug and coin_flip ():
149+ one_hot = one_hot_reverse_complement (one_hot )
150+
151+ return one_hot
140152
141153
142154class GenomeIntervalDataset (Dataset ):
@@ -148,7 +160,8 @@ def __init__(
148160 chr_bed_to_fasta_map = dict (),
149161 context_length = None ,
150162 return_seq_indices = False ,
151- shift_augs = None
163+ shift_augs = None ,
164+ rc_aug = False
152165 ):
153166 super ().__init__ ()
154167 bed_path = Path (bed_file )
@@ -166,7 +179,8 @@ def __init__(
166179 fasta_file = fasta_file ,
167180 context_length = context_length ,
168181 return_seq_indices = return_seq_indices ,
169- shift_augs = shift_augs
182+ shift_augs = shift_augs ,
183+ rc_aug = rc_aug
170184 )
171185
172186 def __len__ (self ):
0 commit comments