Skip to content

Commit 9440265

Browse files
committed
fix handling of N in genetic sequence
1 parent 4d014cd commit 9440265

4 files changed

Lines changed: 23 additions & 18 deletions

File tree

README.md

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ model = Enformer(
2424
target_length = 896,
2525
)
2626

27-
seq = torch.randint(0, 4, (1, 196_608)) # for ACGT, in that order
27+
seq = torch.randint(0, 5, (1, 196_608)) # for ACGTN, in that order
2828
output = model(seq)
2929

3030
output['human'] # (1, 896, 5313)
@@ -36,7 +36,7 @@ You can also directly pass in the sequence as one-hot encodings, which must be f
3636
```python
3737
import torch
3838
import torch.nn.functional as F
39-
from enformer_pytorch import Enformer
39+
from enformer_pytorch import Enformer, seq_indices_to_one_hot
4040

4141
model = Enformer(
4242
dim = 1536,
@@ -46,8 +46,8 @@ model = Enformer(
4646
target_length = 896,
4747
)
4848

49-
seq = torch.randint(0, 4, (1, 196_608))
50-
one_hot = F.one_hot(seq, num_classes = 4).float()
49+
seq = torch.randint(0, 5, (1, 196_608))
50+
one_hot = seq_indices_to_one_hot(seq)
5151

5252
output = model(one_hot)
5353

@@ -60,7 +60,7 @@ Finally, one can fetch the embeddings, for fine-tuning and otherwise, by setting
6060
```python
6161
import torch
6262
import torch.nn.functional as F
63-
from enformer_pytorch import Enformer
63+
from enformer_pytorch import Enformer, seq_indices_to_one_hot
6464

6565
model = Enformer(
6666
dim = 1536,
@@ -70,8 +70,8 @@ model = Enformer(
7070
target_length = 896,
7171
)
7272

73-
seq = torch.randint(0, 4, (1, 196_608))
74-
one_hot = F.one_hot(seq, num_classes = 4).float()
73+
seq = torch.randint(0, 5, (1, 196_608))
74+
one_hot = seq_indices_to_one_hot(seq)
7575

7676
output, embeddings = model(one_hot, return_embeddings = True)
7777

@@ -82,7 +82,7 @@ For training, you can directly pass the head and target in to get the poisson lo
8282

8383
```python
8484
import torch
85-
from enformer_pytorch import Enformer
85+
from enformer_pytorch import Enformer, seq_indices_to_one_hot
8686

8787
model = Enformer(
8888
dim = 1536,
@@ -92,7 +92,7 @@ model = Enformer(
9292
target_length = 200,
9393
).cuda()
9494

95-
seq = torch.randint(0, 4, (196_608 // 2,)).cuda()
95+
seq = torch.randint(0, 5, (196_608 // 2,)).cuda()
9696
target = torch.randn(200, 5313).cuda()
9797

9898
loss = model(
@@ -188,7 +188,7 @@ model = HeadAdapterWrapper(
188188
num_tracks = 128
189189
).cuda()
190190

191-
seq = torch.randint(0, 4, (1, 196_608 // 2,)).cuda()
191+
seq = torch.randint(0, 5, (1, 196_608 // 2,)).cuda()
192192
target = torch.randn(1, 200, 128).cuda() # 128 tracks
193193

194194
loss = model(seq, target = target)
@@ -214,7 +214,7 @@ model = ContextAdapterWrapper(
214214
context_dim = 1024
215215
).cuda()
216216

217-
seq = torch.randint(0, 4, (1, 196_608 // 2,)).cuda()
217+
seq = torch.randint(0, 5, (1, 196_608 // 2,)).cuda()
218218

219219
target = torch.randn(1, 200, 4).cuda() # 4 tracks
220220
context = torch.randn(4, 1024).cuda() # 4 contexts for the different 'tracks'
@@ -249,7 +249,7 @@ model = ContextAttentionAdapterWrapper(
249249
dim_head = 64 # dimension per head
250250
).cuda()
251251

252-
seq = torch.randint(0, 4, (1, 196_608 // 2,)).cuda()
252+
seq = torch.randint(0, 5, (1, 196_608 // 2,)).cuda()
253253

254254
target = torch.randn(1, 200, 4).cuda() # 4 tracks
255255
context = torch.randn(4, 16, 1024).cuda() # 4 contexts for the different 'tracks', each with 16 tokens

enformer_pytorch/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
from enformer_pytorch.enformer_pytorch import Enformer, SEQUENCE_LENGTH, AttentionPool
1+
from enformer_pytorch.enformer_pytorch import Enformer, SEQUENCE_LENGTH, AttentionPool, seq_indices_to_one_hot
22
from enformer_pytorch.model_loader import load_pretrained_model

enformer_pytorch/enformer_pytorch.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,13 @@ def _round(x):
3333
def log(t, eps = 1e-20):
3434
return torch.log(t.clamp(min = eps))
3535

36+
def seq_indices_to_one_hot(t):
37+
wildcard = t == 4 # the Ns in the sequence
38+
t = t.clamp(max = 3)
39+
one_hot = F.one_hot(t, num_classes = 4)
40+
one_hot = one_hot.masked_fill(wildcard[..., None], 0.)
41+
return one_hot.float()
42+
3643
# losses and metrics
3744

3845
def poisson_loss(pred, target):
@@ -259,7 +266,6 @@ def __init__(
259266
heads = 8,
260267
output_heads = dict(human = 5313, mouse= 1643),
261268
target_length = TARGET_LENGTH,
262-
num_alphabet = 4,
263269
attn_dim_key = 64,
264270
dropout_rate = 0.4,
265271
attn_dropout = 0.05,
@@ -268,15 +274,14 @@ def __init__(
268274
):
269275
super().__init__()
270276
self.dim = dim
271-
self.num_alphabet = num_alphabet
272277
half_dim = dim // 2
273278
twice_dim = dim * 2
274279

275280
# create stem
276281

277282
self.stem = nn.Sequential(
278283
Rearrange('b n d -> b d n'),
279-
nn.Conv1d(num_alphabet, half_dim, 15, padding = 7),
284+
nn.Conv1d(4, half_dim, 15, padding = 7),
280285
Residual(ConvBlock(half_dim)),
281286
AttentionPool(half_dim, pool_size = 2)
282287
)
@@ -402,7 +407,7 @@ def forward(
402407
dtype = x.dtype
403408

404409
if x.dtype == torch.long:
405-
x = F.one_hot(x, num_classes = self.num_alphabet).float()
410+
x = seq_indices_to_one_hot(x)
406411

407412
no_batch = x.ndim == 2
408413

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
name = 'enformer-pytorch',
55
packages = find_packages(exclude=[]),
66
include_package_data = True,
7-
version = '0.1.24',
7+
version = '0.1.25',
88
license='MIT',
99
description = 'Enformer - Pytorch',
1010
author = 'Phil Wang',

0 commit comments

Comments
 (0)