Skip to content

Commit eb4e933

Browse files
committed
allow for activation in finetuning head to be customizable, addressing #23
1 parent 4e70710 commit eb4e933

2 files changed

Lines changed: 18 additions & 8 deletions

File tree

enformer_pytorch/finetune.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import torch
2+
from typing import Optional
3+
24
from copy import deepcopy
35
from contextlib import contextmanager
46
import torch.nn.functional as F
@@ -13,6 +15,9 @@
1315
def exists(val):
1416
return val is not None
1517

18+
def default(val, d):
19+
return val if exists(val) else d
20+
1621
@contextmanager
1722
def null_context():
1823
yield
@@ -101,6 +106,7 @@ def __init__(
101106
bottleneck_num_codebooks = 4,
102107
bottleneck_decay = 0.9,
103108
transformer_embed_fn: nn.Module = nn.Identity(),
109+
output_activation: Optional[nn.Module] = nn.Softplus(),
104110
auto_set_target_length = True
105111
):
106112
super().__init__()
@@ -135,9 +141,9 @@ def __init__(
135141
nn.LayerNorm(enformer_hidden_dim) if post_transformer_embed else None
136142
)
137143

138-
self.to_tracks = nn.Sequential(
144+
self.to_tracks = Sequential(
139145
nn.Linear(enformer_hidden_dim, num_tracks),
140-
nn.Softplus()
146+
output_activation
141147
)
142148

143149
def forward(
@@ -179,7 +185,8 @@ def __init__(
179185
bottleneck_num_memories = 256,
180186
bottleneck_num_codebooks = 4,
181187
bottleneck_decay = 0.9,
182-
auto_set_target_length = True
188+
auto_set_target_length = True,
189+
output_activation: Optional[nn.Module] = nn.Softplus()
183190
):
184191
super().__init__()
185192
assert isinstance(enformer, Enformer)
@@ -204,6 +211,8 @@ def __init__(
204211
self.to_context_weights = nn.Parameter(torch.randn(context_dim, enformer_hidden_dim))
205212
self.to_context_bias = nn.Parameter(torch.randn(context_dim))
206213

214+
self.activation = default(output_activation, nn.Identity())
215+
207216
def forward(
208217
self,
209218
seq,
@@ -229,7 +238,7 @@ def forward(
229238

230239
pred = einsum('b n d, t d -> b n t', embeddings, weights) + bias
231240

232-
pred = F.softplus(pred)
241+
pred = self.activation(pred)
233242

234243
if not exists(target):
235244
return pred
@@ -250,7 +259,8 @@ def __init__(
250259
bottleneck_num_memories = 256,
251260
bottleneck_num_codebooks = 4,
252261
bottleneck_decay = 0.9,
253-
auto_set_target_length = True
262+
auto_set_target_length = True,
263+
output_activation: Optional[nn.Module] = None
254264
):
255265
super().__init__()
256266
assert isinstance(enformer, Enformer)
@@ -286,10 +296,10 @@ def __init__(
286296
self.to_key_values = nn.Linear(context_dim, inner_dim * 2, bias = False)
287297
self.to_out = nn.Linear(inner_dim, enformer_hidden_dim)
288298

289-
self.to_pred = nn.Sequential(
299+
self.to_pred = Sequential(
290300
nn.Linear(enformer_hidden_dim, 1),
291301
Rearrange('b c ... 1 -> b ... c'),
292-
nn.Softplus()
302+
output_activation
293303
)
294304

295305
def forward(

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
name = 'enformer-pytorch',
55
packages = find_packages(exclude=[]),
66
include_package_data = True,
7-
version = '0.6.4',
7+
version = '0.7.0',
88
license='MIT',
99
description = 'Enformer - Pytorch',
1010
author = 'Phil Wang',

0 commit comments

Comments
 (0)