@@ -24,7 +24,7 @@ model = Enformer(
2424 target_length = 896 ,
2525)
2626
27- seq = torch.randint(0 , 4 , (1 , 196_608 )) # for ACGT , in that order
27+ seq = torch.randint(0 , 5 , (1 , 196_608 )) # for ACGTN , in that order
2828output = model(seq)
2929
3030output[' human' ] # (1, 896, 5313)
@@ -36,7 +36,7 @@ You can also directly pass in the sequence as one-hot encodings, which must be f
3636``` python
3737import torch
3838import torch.nn.functional as F
39- from enformer_pytorch import Enformer
39+ from enformer_pytorch import Enformer, seq_indices_to_one_hot
4040
4141model = Enformer(
4242 dim = 1536 ,
@@ -46,8 +46,8 @@ model = Enformer(
4646 target_length = 896 ,
4747)
4848
49- seq = torch.randint(0 , 4 , (1 , 196_608 ))
50- one_hot = F.one_hot (seq, num_classes = 4 ).float( )
49+ seq = torch.randint(0 , 5 , (1 , 196_608 ))
50+ one_hot = seq_indices_to_one_hot (seq)
5151
5252output = model(one_hot)
5353
@@ -60,7 +60,7 @@ Finally, one can fetch the embeddings, for fine-tuning and otherwise, by setting
6060``` python
6161import torch
6262import torch.nn.functional as F
63- from enformer_pytorch import Enformer
63+ from enformer_pytorch import Enformer, seq_indices_to_one_hot
6464
6565model = Enformer(
6666 dim = 1536 ,
@@ -70,8 +70,8 @@ model = Enformer(
7070 target_length = 896 ,
7171)
7272
73- seq = torch.randint(0 , 4 , (1 , 196_608 ))
74- one_hot = F.one_hot (seq, num_classes = 4 ).float( )
73+ seq = torch.randint(0 , 5 , (1 , 196_608 ))
74+ one_hot = seq_indices_to_one_hot (seq)
7575
7676output, embeddings = model(one_hot, return_embeddings = True )
7777
@@ -82,7 +82,7 @@ For training, you can directly pass the head and target in to get the poisson lo
8282
8383``` python
8484import torch
85- from enformer_pytorch import Enformer
85+ from enformer_pytorch import Enformer, seq_indices_to_one_hot
8686
8787model = Enformer(
8888 dim = 1536 ,
@@ -92,7 +92,7 @@ model = Enformer(
9292 target_length = 200 ,
9393).cuda()
9494
95- seq = torch.randint(0 , 4 , (196_608 // 2 ,)).cuda()
95+ seq = torch.randint(0 , 5 , (196_608 // 2 ,)).cuda()
9696target = torch.randn(200 , 5313 ).cuda()
9797
9898loss = model(
@@ -188,7 +188,7 @@ model = HeadAdapterWrapper(
188188 num_tracks = 128
189189).cuda()
190190
191- seq = torch.randint(0 , 4 , (1 , 196_608 // 2 ,)).cuda()
191+ seq = torch.randint(0 , 5 , (1 , 196_608 // 2 ,)).cuda()
192192target = torch.randn(1 , 200 , 128 ).cuda() # 128 tracks
193193
194194loss = model(seq, target = target)
@@ -214,7 +214,7 @@ model = ContextAdapterWrapper(
214214 context_dim = 1024
215215).cuda()
216216
217- seq = torch.randint(0 , 4 , (1 , 196_608 // 2 ,)).cuda()
217+ seq = torch.randint(0 , 5 , (1 , 196_608 // 2 ,)).cuda()
218218
219219target = torch.randn(1 , 200 , 4 ).cuda() # 4 tracks
220220context = torch.randn(4 , 1024 ).cuda() # 4 contexts for the different 'tracks'
@@ -249,7 +249,7 @@ model = ContextAttentionAdapterWrapper(
249249 dim_head = 64 # dimension per head
250250).cuda()
251251
252- seq = torch.randint(0 , 4 , (1 , 196_608 // 2 ,)).cuda()
252+ seq = torch.randint(0 , 5 , (1 , 196_608 // 2 ,)).cuda()
253253
254254target = torch.randn(1 , 200 , 4 ).cuda() # 4 tracks
255255context = torch.randn(4 , 16 , 1024 ).cuda() # 4 contexts for the different 'tracks', each with 16 tokens
@@ -279,7 +279,7 @@ Special thanks goes out to <a href="https://www.eleuther.ai/">EleutherAI</a> for
279279- [x] build context manager for fine-tuning with unfrozen enformer but with frozen batchnorm
280280- [x] allow for plain fine-tune with fixed static context
281281- [x] allow for fine tuning with only unfrozen layernorms (technique from fine tuning transformers)
282- - [ ] fix handling of 'N' in sequence, figure out representation of N in basenji barnyard
282+ - [x ] fix handling of 'N' in sequence, figure out representation of N in basenji barnyard
283283- [ ] add to EleutherAI huggingface
284284
285285## Citations
0 commit comments