11import torch
2+ from typing import Optional
3+
24from copy import deepcopy
35from contextlib import contextmanager
46import torch .nn .functional as F
1315def exists (val ):
1416 return val is not None
1517
18+ def default (val , d ):
19+ return val if exists (val ) else d
20+
1621@contextmanager
1722def null_context ():
1823 yield
@@ -101,6 +106,7 @@ def __init__(
101106 bottleneck_num_codebooks = 4 ,
102107 bottleneck_decay = 0.9 ,
103108 transformer_embed_fn : nn .Module = nn .Identity (),
109+ output_activation : Optional [nn .Module ] = nn .Softplus (),
104110 auto_set_target_length = True
105111 ):
106112 super ().__init__ ()
@@ -135,9 +141,9 @@ def __init__(
135141 nn .LayerNorm (enformer_hidden_dim ) if post_transformer_embed else None
136142 )
137143
138- self .to_tracks = nn . Sequential (
144+ self .to_tracks = Sequential (
139145 nn .Linear (enformer_hidden_dim , num_tracks ),
140- nn . Softplus ()
146+ output_activation
141147 )
142148
143149 def forward (
@@ -179,7 +185,8 @@ def __init__(
179185 bottleneck_num_memories = 256 ,
180186 bottleneck_num_codebooks = 4 ,
181187 bottleneck_decay = 0.9 ,
182- auto_set_target_length = True
188+ auto_set_target_length = True ,
189+ output_activation : Optional [nn .Module ] = nn .Softplus ()
183190 ):
184191 super ().__init__ ()
185192 assert isinstance (enformer , Enformer )
@@ -204,6 +211,8 @@ def __init__(
204211 self .to_context_weights = nn .Parameter (torch .randn (context_dim , enformer_hidden_dim ))
205212 self .to_context_bias = nn .Parameter (torch .randn (context_dim ))
206213
214+ self .activation = default (output_activation , nn .Identity ())
215+
207216 def forward (
208217 self ,
209218 seq ,
@@ -229,7 +238,7 @@ def forward(
229238
230239 pred = einsum ('b n d, t d -> b n t' , embeddings , weights ) + bias
231240
232- pred = F . softplus (pred )
241+ pred = self . activation (pred )
233242
234243 if not exists (target ):
235244 return pred
@@ -250,7 +259,8 @@ def __init__(
250259 bottleneck_num_memories = 256 ,
251260 bottleneck_num_codebooks = 4 ,
252261 bottleneck_decay = 0.9 ,
253- auto_set_target_length = True
262+ auto_set_target_length = True ,
263+ output_activation : Optional [nn .Module ] = None
254264 ):
255265 super ().__init__ ()
256266 assert isinstance (enformer , Enformer )
@@ -286,10 +296,10 @@ def __init__(
286296 self .to_key_values = nn .Linear (context_dim , inner_dim * 2 , bias = False )
287297 self .to_out = nn .Linear (inner_dim , enformer_hidden_dim )
288298
289- self .to_pred = nn . Sequential (
299+ self .to_pred = Sequential (
290300 nn .Linear (enformer_hidden_dim , 1 ),
291301 Rearrange ('b c ... 1 -> b ... c' ),
292- nn . Softplus ()
302+ output_activation
293303 )
294304
295305 def forward (
0 commit comments