Skip to content

Commit 8a1f919

Browse files
committed
allow for fine tuning on new static head
1 parent c49c2bb commit 8a1f919

4 files changed

Lines changed: 73 additions & 5 deletions

File tree

README.md

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,35 @@ model = load_pretrained_model('preview')
135135

136136
## Fine-tuning (wip)
137137

138-
This repository will also allow for easy fine-tuning of Enformer. For starters, the following example shows a single step for finetuning on contextual data (cell type, transcription factor, etc)
138+
This repository will also allow for easy fine-tuning of Enformer.
139+
140+
Fine-tuning on new tracks
141+
142+
```python
143+
import torch
144+
from enformer_pytorch import Enformer
145+
from enformer_pytorch.finetune import HeadAdapterWrapper
146+
147+
enformer = Enformer(
148+
dim = 1536,
149+
depth = 1,
150+
heads = 8,
151+
target_length = 200,
152+
)
153+
154+
model = HeadAdapterWrapper(
155+
enformer = enformer,
156+
num_tracks = 128
157+
).cuda()
158+
159+
seq = torch.randint(0, 4, (1, 196_608 // 2,)).cuda()
160+
target = torch.randn(1, 200, 128).cuda() # 128 tracks
161+
162+
loss = model(seq, target = target)
163+
loss.backward()
164+
```
165+
166+
Finetuning on contextual data (cell type, transcription factor, etc)
139167

140168
```python
141169
import torch
@@ -151,7 +179,6 @@ enformer = Enformer(
151179

152180
model = ContextAdapterWrapper(
153181
enformer = enformer,
154-
enformer_dim = 1536,
155182
context_dim = 1024
156183
).cuda()
157184

enformer_pytorch/enformer_pytorch.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,7 @@ def __init__(
265265
pos_dropout = 0.01
266266
):
267267
super().__init__()
268+
self.dim = dim
268269
self.num_alphabet = num_alphabet
269270
half_dim = dim // 2
270271
twice_dim = dim * 2

enformer_pytorch/finetune.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import torch
22
from contextlib import contextmanager
3+
import torch.nn.functional as F
34
from torch import nn, einsum
45
from einops import rearrange
56
from enformer_pytorch.enformer_pytorch import Enformer, poisson_loss
@@ -26,19 +27,56 @@ def freeze_batchnorm_context(model):
2627
for p, requires_grad in zip(bn.parameters(), state['requires_grad']):
2728
p.requires_grad = requires_grad
2829

30+
class HeadAdapterWrapper(nn.Module):
31+
def __init__(
32+
self,
33+
*,
34+
enformer,
35+
num_tracks
36+
):
37+
super().__init__()
38+
assert isinstance(enformer, Enformer)
39+
self.enformer = enformer
40+
41+
self.to_tracks = nn.Sequential(
42+
nn.Linear(enformer.dim * 2, num_tracks),
43+
nn.Softplus()
44+
)
45+
46+
def forward(
47+
self,
48+
seq,
49+
*,
50+
target = None,
51+
freeze_enformer = False
52+
):
53+
enformer_context = freeze_batchnorm_context(self.enformer) if not freeze_enformer else torch.no_grad()
54+
55+
with enformer_context:
56+
_, embeddings = self.enformer(seq, return_embeddings = True)
57+
58+
if freeze_enformer:
59+
embeddings.detach_()
60+
61+
preds = self.to_tracks(embeddings)
62+
63+
if not exists(target):
64+
return preds
65+
66+
return poisson_loss(preds, target)
67+
2968
class ContextAdapterWrapper(nn.Module):
3069
def __init__(
3170
self,
3271
*,
3372
enformer,
34-
enformer_dim,
3573
context_dim
3674
):
3775
super().__init__()
3876
assert isinstance(enformer, Enformer)
3977
self.enformer = enformer
4078

41-
self.to_context_weights = nn.Parameter(torch.randn(context_dim, enformer_dim * 2))
79+
self.to_context_weights = nn.Parameter(torch.randn(context_dim, enformer.dim * 2))
4280
self.to_context_bias = nn.Parameter(torch.randn(context_dim))
4381

4482
def forward(
@@ -62,6 +100,8 @@ def forward(
62100

63101
pred = einsum('b n d, t d -> b n t', embeddings, weights) + bias
64102

103+
pred = F.softplus(pred)
104+
65105
if not exists(target):
66106
return pred
67107

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'enformer-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.1.4',
6+
version = '0.1.5',
77
license='MIT',
88
description = 'Enformer - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)