Skip to content

Commit 7ada950

Browse files
committed
add ability to use the embeddings from the output of the transformer block, and not from after the final pointwise, use layernorm if taking embeddings from post-transformer
1 parent a8d8d2f commit 7ada950

3 files changed

Lines changed: 24 additions & 3 deletions

File tree

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,8 @@ enformer = Enformer.from_hparams(
180180
181181
model = HeadAdapterWrapper(
182182
enformer = enformer,
183-
num_tracks = 128
183+
num_tracks = 128,
184+
post_transformer_embed = False # by default, embeddings are taken from after the final pointwise block w/ conv -> gelu - but if you'd like the embeddings right after the transformer block with a learned layernorm, set this to True
184185
).cuda()
185186

186187
seq = torch.randint(0, 5, (1, 196_608 // 2,)).cuda()

enformer_pytorch/finetune.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import torch
2+
from copy import deepcopy
23
from contextlib import contextmanager
34
import torch.nn.functional as F
45
from torch import nn, einsum
@@ -16,6 +17,11 @@ def exists(val):
1617
def null_context():
1718
yield
1819

20+
# better sequential
21+
22+
def Sequential(*modules):
23+
return nn.Sequential(*filter(exists, modules))
24+
1925
# controlling freezing of layers
2026

2127
def set_module_requires_grad_(module, requires_grad):
@@ -88,14 +94,16 @@ def __init__(
8894
*,
8995
enformer,
9096
num_tracks,
97+
post_transformer_embed = False, # whether to take the embeddings from right after the transformer, instead of after the final pointwise convolutional - this would add another layernorm
9198
discrete_key_value_bottleneck = False,
9299
bottleneck_num_memories = 256,
93100
bottleneck_num_codebooks = 4,
94101
bottleneck_decay = 0.9,
102+
transformer_embed_fn: nn.Module = nn.Identity()
95103
):
96104
super().__init__()
97105
assert isinstance(enformer, Enformer)
98-
enformer_hidden_dim = enformer.dim * 2
106+
enformer_hidden_dim = enformer.dim * (2 if not post_transformer_embed else 1)
99107

100108
self.discrete_key_value_bottleneck = discrete_key_value_bottleneck
101109

@@ -109,8 +117,20 @@ def __init__(
109117
decay = bottleneck_decay,
110118
)
111119

120+
self.post_transformer_embed = post_transformer_embed
121+
112122
self.enformer = enformer
113123

124+
if post_transformer_embed:
125+
self.enformer = deepcopy(enformer)
126+
self.enformer._trunk[-1] = nn.Identity()
127+
self.enformer.final_pointwise = nn.Identity()
128+
129+
self.post_embed_transform = Sequential(
130+
transformer_embed_fn,
131+
nn.LayerNorm(enformer_hidden_dim) if post_transformer_embed else None
132+
)
133+
114134
self.to_tracks = nn.Sequential(
115135
nn.Linear(enformer_hidden_dim, num_tracks),
116136
nn.Softplus()

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
name = 'enformer-pytorch',
55
packages = find_packages(exclude=[]),
66
include_package_data = True,
7-
version = '0.5.7',
7+
version = '0.6.0',
88
license='MIT',
99
description = 'Enformer - Pytorch',
1010
author = 'Phil Wang',

0 commit comments

Comments
 (0)