11import torch
22import torch .nn .functional as F
3+ from torch .utils .data import Dataset
4+
5+ import polars as pl
6+ import numpy as np
7+ from random import randrange
8+ from pathlib import Path
9+ from pyfaidx import Fasta
10+
# helper functions
312
def exists(val):
    """Return True when `val` is anything other than None.

    Falsy values such as 0, '' or [] still count as existing; only None
    is treated as missing.
    """
    return val is not None
@@ -10,12 +19,47 @@ def identity(t):
def cast_list(t):
    """Return `t` unchanged if it is already a list, otherwise wrap it in a one-element list.

    Only `list` instances pass through as-is; tuples, strings and other
    iterables are wrapped rather than converted.
    """
    return t if isinstance(t, list) else [t]
1221
13- def str_to_seq_indices (seq_strs , padding = '.' ):
# genomic function transforms
23+
# Lookup tables indexed by byte value (0-255), so a tensor of character codes
# can be converted with a single indexing operation.

# Integer class per character: a/c/g/t -> 0..3, n -> 4 (either case),
# '.' (padding) -> -1. Bytes that are not listed keep the zero
# initialisation, so they are indistinguishable from 'a' in this table.
seq_indices_embed = torch.zeros(256).long()
for _index, _base in enumerate('acgtn'):
    seq_indices_embed[ord(_base)] = _index
    seq_indices_embed[ord(_base.upper())] = _index
seq_indices_embed[ord('.')] = -1

# One-hot row per character over the four bases a/c/g/t (either case).
# 'n' / 'N' and any unlisted byte stay all-zero; '.' (padding) is spread
# uniformly as 0.25 across the four bases.
one_hot_embed = torch.zeros(256, 4)
for _index, _base in enumerate('acgt'):
    one_hot_embed[ord(_base), _index] = 1.
    one_hot_embed[ord(_base.upper()), _index] = 1.
one_hot_embed[ord('.')] = 0.25
49+
def torch_fromstring(seq_strs):
    """Convert one string, or a list of strings, into a uint8 tensor of byte values.

    Args:
        seq_strs: a single string or a list of strings. A single string is
            wrapped into a one-element list via `cast_list`.

    Returns:
        A uint8 tensor with one row per input string, each row holding the
        UTF-8 byte values of that string.

    Raises:
        RuntimeError: from `torch.stack` when the list is empty or the
            strings do not all encode to the same number of bytes.
    """
    seq_strs = cast_list(seq_strs)

    # np.fromstring's binary mode is deprecated; np.frombuffer is the
    # supported replacement. A bytearray (rather than bytes) keeps the
    # resulting array writable, which torch.from_numpy expects.
    byte_rows = [
        np.frombuffer(bytearray(seq_str, 'utf-8'), dtype=np.uint8)
        for seq_str in seq_strs
    ]

    # torch.stack copies, so the result does not alias the temporary buffers
    return torch.stack([torch.from_numpy(row) for row in byte_rows])
55+
def str_to_seq_indices(seq_strs):
    """Map sequence string(s) to integer class indices via `seq_indices_embed`.

    Args:
        seq_strs: a single string or a list of strings, forwarded to
            `torch_fromstring`.

    Returns:
        A long tensor with one row per input string, where each character
        has been replaced by its entry in `seq_indices_embed`.
    """
    char_codes = torch_fromstring(seq_strs)
    # indexing requires an integer (long) tensor of positions into the table
    return seq_indices_embed[char_codes.long()]
59+
def str_to_one_hot(seq_strs):
    """Map sequence string(s) to per-character 4-vectors via `one_hot_embed`.

    Args:
        seq_strs: a single string or a list of strings, forwarded to
            `torch_fromstring`.

    Returns:
        A float tensor with one row per input string and, for every
        character, the 4-element row of `one_hot_embed` for that character.
    """
    char_codes = torch_fromstring(seq_strs)
    # indexing requires an integer (long) tensor of positions into the table
    return one_hot_embed[char_codes.long()]
1963
2064def seq_indices_to_one_hot (t , padding = - 1 ):
2165 is_padding = t == padding
@@ -27,12 +71,6 @@ def seq_indices_to_one_hot(t, padding = -1):
2771
# processing bed files
2973
30- import polars as pl
31- from random import randrange
32- from pathlib import Path
33- from pyfaidx import Fasta
34- from torch .utils .data import Dataset
35-
3674class GenomeIntervalDataset (Dataset ):
3775 def __init__ (
3876 self ,
@@ -112,10 +150,8 @@ def __getitem__(self, ind):
112150 end = chromosome_length
113151
114152 seq = ('.' * left_padding ) + str (chromosome [start :end ]) + ('.' * right_padding )
115- seq_indices = str_to_seq_indices (seq )
116153
117154 if self .return_seq_indices :
118- return seq_indices .squeeze (0 )
155+ return str_to_seq_indices ( seq ) .squeeze (0 )
119156
120- seq_onehot = seq_indices_to_one_hot (seq_indices )
121- return seq_onehot .squeeze (0 )
157+ return str_to_one_hot (seq ).squeeze (0 )
0 commit comments