
Commit 7a3ba7c

complete the conformer, just layers of conformer blocks, to ready for soundstorm
1 parent 0e91f03 commit 7a3ba7c

4 files changed: +86 -5 lines

README.md (+31)

````diff
@@ -49,8 +49,39 @@ block = ConformerBlock(
 )
 
 x = torch.randn(1, 1024, 512)
+
 block(x) # (1, 1024, 512)
 ```
+
+Conformer - just multiple `ConformerBlock` from above
+
+```python
+import torch
+from conformer import Conformer
+
+conformer = Conformer(
+    dim = 512,
+    depth = 12,          # 12 blocks
+    dim_head = 64,
+    heads = 8,
+    ff_mult = 4,
+    conv_expansion_factor = 2,
+    conv_kernel_size = 31,
+    attn_dropout = 0.,
+    ff_dropout = 0.,
+    conv_dropout = 0.
+)
+
+x = torch.randn(1, 1024, 512)
+
+conformer(x) # (1, 1024, 512)
+```
+
+## Todo
+
+- [ ] switch to a better relative positional encoding. shaw's is dated
+- [ ] flash attention with a better RPE
+
 ## Citations
 
 ```bibtex
````

conformer/__init__.py (+1 -1)

```diff
@@ -1 +1 @@
-from conformer.conformer import ConformerConvModule, ConformerBlock
+from conformer.conformer import ConformerConvModule, ConformerBlock, Conformer
```
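With this change `Conformer` is re-exported at the package top level alongside the existing modules. A minimal import check, as a sketch that assumes this repo's package is installed locally (e.g. via `pip install -e .`):

```python
# quick check of the new top-level export; assumes the package from this repo
# is installed in the current environment
from conformer import ConformerConvModule, ConformerBlock, Conformer

print(Conformer)  # <class 'conformer.conformer.Conformer'>
```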

conformer/conformer.py (+46 -1)

```diff
@@ -85,7 +85,13 @@ def __init__(
 
         self.dropout = nn.Dropout(dropout)
 
-    def forward(self, x, context = None, mask = None, context_mask = None):
+    def forward(
+        self,
+        x,
+        context = None,
+        mask = None,
+        context_mask = None
+    ):
         n, device, h, max_pos_emb, has_context = x.shape[-2], x.device, self.heads, self.max_pos_emb, exists(context)
         context = default(context, x)
 
```

```diff
@@ -95,6 +101,7 @@ def forward(self, x, context = None, mask = None, context_mask = None):
         dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
 
         # shaw's relative positional embedding
+
         seq = torch.arange(n, device = device)
         dist = rearrange(seq, 'i -> i ()') - rearrange(seq, 'j -> () j')
         dist = dist.clamp(-max_pos_emb, max_pos_emb) + max_pos_emb
```
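The hunk above only adds a blank line, but it sits in the Shaw-style relative positional embedding that the README's new Todo item flags as dated. As a standalone illustration of what the surrounding three lines compute, here is a small sketch; the toy sizes are assumptions for the demo, not values taken from the repo:

```python
import torch
from einops import rearrange

# toy sizes chosen only for illustration
n, max_pos_emb = 6, 3

seq = torch.arange(n)
# pairwise offsets i - j between query position i and key position j
dist = rearrange(seq, 'i -> i ()') - rearrange(seq, 'j -> () j')
# clamp long-range offsets, then shift into [0, 2 * max_pos_emb] so each entry
# can index an embedding table with 2 * max_pos_emb + 1 rows
dist = dist.clamp(-max_pos_emb, max_pos_emb) + max_pos_emb
print(dist)
```

In Shaw-style schemes these clamped indices look up learned per-offset embeddings that bias the attention logits, which is why distant positions beyond the clamp share a single embedding.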
```diff
@@ -199,3 +206,41 @@ def forward(self, x, mask = None):
         x = self.ff2(x) + x
         x = self.post_norm(x)
         return x
+
+# Conformer
+
+class Conformer(nn.Module):
+    def __init__(
+        self,
+        dim,
+        *,
+        depth,
+        dim_head = 64,
+        heads = 8,
+        ff_mult = 4,
+        conv_expansion_factor = 2,
+        conv_kernel_size = 31,
+        attn_dropout = 0.,
+        ff_dropout = 0.,
+        conv_dropout = 0.
+    ):
+        super().__init__()
+        self.layers = nn.ModuleList([])
+
+        for _ in range(depth):
+            self.layers.append(ConformerBlock(
+                dim = dim,
+                dim_head = dim_head,
+                heads = heads,
+                ff_mult = ff_mult,
+                conv_expansion_factor = conv_expansion_factor,
+                conv_kernel_size = conv_kernel_size,
+
+            ))
+
+    def forward(self, x):
+
+        for block in self.layers:
+            x = block(x)
+
+        return x
```
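As committed, `Conformer.forward` accepts only `x` (no mask argument), and the dropout keyword arguments are accepted by `__init__` but not yet forwarded to each `ConformerBlock`. A minimal sanity-check sketch of the new class, mirroring the README usage; the small sizes below are arbitrary demo values, not suggested hyperparameters:

```python
import torch
from conformer import Conformer

# tiny configuration purely for a quick shape check
model = Conformer(
    dim = 64,
    depth = 2,
    dim_head = 32,
    heads = 2
)

assert len(model.layers) == 2        # one ConformerBlock per depth step

x = torch.randn(1, 128, 64)
out = model(x)
assert out.shape == (1, 128, 64)     # each block preserves the (batch, seq, dim) shape
```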

setup.py (+8 -3)

```diff
@@ -3,15 +3,20 @@
 setup(
   name = 'conformer',
   packages = find_packages(),
-  version = '0.2.5',
+  version = '0.3.0',
   license='MIT',
   description = 'The convolutional module from the Conformer paper',
   author = 'Phil Wang',
   author_email = '[email protected]',
   url = 'https://github.com/lucidrains/conformer',
-  keywords = ['transformers', 'artificial intelligence', 'transformer'],
+  keywords = [
+    'artificial intelligence',
+    'deep learning',
+    'transformers',
+    'audio'
+  ],
   install_requires=[
-    'einops',
+    'einops>=0.6.1',
     'torch'
   ],
   classifiers=[
```
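The dependency floor moves from any `einops` to `einops>=0.6.1` alongside the version bump to 0.3.0. A small, hedged way to confirm a local environment meets the new pin, using only the standard library (the version parsing here is a crude illustration, not how pip resolves requirements):

```python
from importlib.metadata import version

installed = version('einops')
# crude major.minor.patch comparison against the floor declared in setup.py
assert tuple(map(int, installed.split('.')[:3])) >= (0, 6, 1), installed
print(f'einops {installed} satisfies the new >=0.6.1 requirement')
```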
