11import math
22import torch
33from torch import nn , einsum
4+ import torch .nn .functional as F
45from torch .utils .checkpoint import checkpoint_sequential
6+
57from einops import rearrange , reduce
68from einops .layers .torch import Rearrange
79
8- import torch .nn .functional as F
9-
# constants

SEQUENCE_LENGTH = 196_608
@@ -33,12 +33,17 @@ def _round(x):
def log(t, eps=1e-20):
    """Return the elementwise natural log of `t`, clamped below at `eps`.

    Values smaller than `eps` are raised to `eps` before the log is taken,
    so zero inputs produce a finite result instead of -inf.

    Args:
        t: input tensor.
        eps: lower bound applied to `t` before taking the log.

    Returns:
        Tensor of `torch.log(t.clamp(min=eps))`.
    """
    return torch.log(t.clamp(min=eps))
3535
# sequence helpers
37+
def str_to_seq_indices(seq_strs):
    """Convert nucleotide strings into a stacked tensor of integer indices.

    Characters are matched case-insensitively with the mapping
    a=0, c=1, g=2, t=3, n=4.

    Args:
        seq_strs: iterable of strings; all strings must have the same length
            so the per-string tensors can be stacked.

    Returns:
        A `torch.long` tensor with one row per input string.

    Raises:
        KeyError: if a string contains a character outside "acgtn".
        RuntimeError: from `torch.stack` if `seq_strs` is empty or the
            strings differ in length.
    """
    char_to_index_map = dict(a=0, c=1, g=2, t=3, n=4)
    # build integer tensors directly instead of going through a float tensor
    seq_indices = [
        torch.tensor(
            [char_to_index_map[char] for char in seq_str.lower()],
            dtype=torch.long,
        )
        for seq_str in seq_strs
    ]
    return torch.stack(seq_indices)
43+
def seq_indices_to_one_hot(t):
    """One-hot encode sequence indices into four channels.

    Indices 0-3 each map to their own channel. Index 4 (the Ns in the
    sequence) maps to an all-zero row.

    Args:
        t: integer index tensor with values in 0..4, as required by
            `F.one_hot` with `num_classes = 5`.

    Returns:
        Float tensor with a trailing dimension of size 4 appended to the
        shape of `t`.
    """
    # encode with a fifth class for the wildcard, then drop that channel so
    # wildcard positions end up as all zeros
    one_hot = F.one_hot(t, num_classes=5)
    return one_hot[..., :4].float()
4247
# losses and metrics
4449
0 commit comments