Skip to content

Commit fc70d51

Browse files
committed
make conformer able to do things autoregressively, to save issues with variable lengths in soundstorm
1 parent a37a2ad commit fc70d51

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

conformer/conformer.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,8 @@ def __init__(
149149
causal = False,
150150
expansion_factor = 2,
151151
kernel_size = 31,
152-
dropout = 0.):
152+
dropout = 0.
153+
):
153154
super().__init__()
154155

155156
inner_dim = dim * expansion_factor
@@ -185,12 +186,13 @@ def __init__(
185186
conv_kernel_size = 31,
186187
attn_dropout = 0.,
187188
ff_dropout = 0.,
188-
conv_dropout = 0.
189+
conv_dropout = 0.,
190+
conv_causal = False
189191
):
190192
super().__init__()
191193
self.ff1 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
192194
self.attn = Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout)
193-
self.conv = ConformerConvModule(dim = dim, causal = False, expansion_factor = conv_expansion_factor, kernel_size = conv_kernel_size, dropout = conv_dropout)
195+
self.conv = ConformerConvModule(dim = dim, causal = conv_causal, expansion_factor = conv_expansion_factor, kernel_size = conv_kernel_size, dropout = conv_dropout)
194196
self.ff2 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
195197

196198
self.attn = PreNorm(dim, self.attn)
@@ -222,7 +224,8 @@ def __init__(
222224
conv_kernel_size = 31,
223225
attn_dropout = 0.,
224226
ff_dropout = 0.,
225-
conv_dropout = 0.
227+
conv_dropout = 0.,
228+
conv_causal = False
226229
):
227230
super().__init__()
228231
self.dim = dim
@@ -236,6 +239,7 @@ def __init__(
236239
ff_mult = ff_mult,
237240
conv_expansion_factor = conv_expansion_factor,
238241
conv_kernel_size = conv_kernel_size,
242+
conv_causal = conv_causal
239243

240244
))
241245

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'conformer',
55
packages = find_packages(),
6-
version = '0.3.1',
6+
version = '0.3.2',
77
license='MIT',
88
description = 'The convolutional module from the Conformer paper',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)