Skip to content

Commit 04735c5

Browse files
committed
add prenorm residual to attention aggregation fine tuning adapter
1 parent 55acb63 commit 04735c5

2 files changed

Lines changed: 24 additions & 11 deletions

File tree

enformer_pytorch/finetune.py

Lines changed: 23 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -122,20 +122,25 @@ def __init__(
122122
super().__init__()
123123
assert isinstance(enformer, Enformer)
124124
self.enformer = enformer
125+
enformer_hidden_dim = enformer.dim * 2
126+
127+
self.query_norm = nn.LayerNorm(enformer_hidden_dim)
128+
self.key_values_norm = nn.LayerNorm(context_dim)
125129

126130
self.scale = dim_head ** -0.5
127131
self.heads = heads
128132
inner_dim = heads * dim_head
129-
self.to_queries = nn.Linear(enformer.dim * 2, inner_dim)
133+
self.to_queries = nn.Linear(enformer_hidden_dim, inner_dim)
130134

131135
self.null_key = nn.Parameter(torch.randn(inner_dim))
132136
self.null_value = nn.Parameter(torch.randn(inner_dim))
133137

134138
self.to_key_values = nn.Linear(context_dim, inner_dim * 2, bias = False)
139+
self.to_out = nn.Linear(inner_dim, enformer_hidden_dim)
135140

136-
self.to_out = nn.Sequential(
137-
nn.Linear(inner_dim, 1),
138-
Rearrange('c ... 1 -> ... c'),
141+
self.to_pred = nn.Sequential(
142+
nn.Linear(enformer_hidden_dim, 1),
143+
Rearrange('b c ... 1 -> b ... c'),
139144
nn.Softplus()
140145
)
141146

@@ -155,8 +160,8 @@ def forward(
155160
if context.ndim == 2:
156161
context = rearrange(context, 'b d -> b 1 d')
157162

158-
q = self.to_queries(embeddings)
159-
k, v = self.to_key_values(context).chunk(2, dim = -1)
163+
q = self.to_queries(self.query_norm(embeddings))
164+
k, v = self.to_key_values(self.key_values_norm(context)).chunk(2, dim = -1)
160165

161166
null_k, null_v = map(lambda t: repeat(t, 'd -> b 1 d', b = context.shape[0]), (self.null_key, self.null_value))
162167

@@ -174,13 +179,21 @@ def forward(
174179

175180
# aggregate
176181

177-
out = einsum('b c h i j, c h j d -> c h i d', attn, v)
182+
out = einsum('b c h i j, c h j d -> b c h i d', attn, v)
183+
184+
out = rearrange(out, 'b c h n d -> b c n (h d)', h = h)
185+
186+
# combine heads
187+
188+
branch_out = self.to_out(out)
189+
190+
# residual
178191

179-
out = rearrange(out, 'c h n d -> c n (h d)', h = h)
192+
embeddings = embeddings + branch_out
180193

181-
# combine heads and project / softplus
194+
# to prediction
182195

183-
pred = self.to_out(out)
196+
pred = self.to_pred(embeddings)
184197

185198
if not exists(target):
186199
return pred

setup.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44
name = 'enformer-pytorch',
55
packages = find_packages(exclude=[]),
66
include_package_data = True,
7-
version = '0.1.14',
7+
version = '0.1.15',
88
license='MIT',
99
description = 'Enformer - Pytorch',
1010
author = 'Phil Wang',

0 commit comments

Comments
 (0)