Commit a70f5d9

release 1 conformer block

1 parent 9453375 commit a70f5d9

4 files changed, +149 -4 lines changed

README.md (+21)
@@ -28,6 +28,27 @@ x = torch.randn(1, 1024, 512)
 x = layer(x) + x
 ```
 
+## 1 Conformer Block
+
+```python
+import torch
+from conformer import ConformerBlock
+
+block = ConformerBlock(
+    dim = 512,
+    dim_head = 64,
+    heads = 8,
+    ff_mult = 4,
+    conv_expansion_factor = 2,
+    conv_kernel_size = 31,
+    attn_dropout = 0.,
+    ff_dropout = 0.,
+    conv_dropout = 0.
+)
+
+x = torch.randn(1, 1024, 512)
+block(x) # (1, 1024, 512)
+```
 ## Citations
 
 ```bibtex
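
The block's forward also takes an optional mask that is forwarded to the attention layer (see conformer/conformer.py below). A minimal sketch of supplying a boolean padding mask; the (batch, seq) shape with True marking positions to keep is an assumption inferred from how Attention broadcasts the mask, not something stated in the README:

```python
import torch
from conformer import ConformerBlock

block = ConformerBlock(dim = 512)

x = torch.randn(1, 1024, 512)

# hypothetical padding mask: True = attend to this position
mask = torch.ones(1, 1024).bool()
mask[:, 800:] = False

out = block(x, mask = mask)  # (1, 1024, 512)
```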

conformer/__init__.py (+1 -1)
@@ -1 +1 @@
-from conformer.conformer import ConformerConvModule
+from conformer.conformer import ConformerConvModule, ConformerBlock
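
With this change both ConformerConvModule and ConformerBlock are importable from the package root. A quick smoke-test sketch, using only constructor arguments that appear in the diffs on this page:

```python
import torch
from conformer import ConformerConvModule, ConformerBlock

# standalone convolution module: the residual is added by the caller
layer = ConformerConvModule(dim = 512, causal = False, expansion_factor = 2, kernel_size = 31, dropout = 0.)

# full block: feedforwards, attention, and conv with residuals handled internally
block = ConformerBlock(dim = 512)

x = torch.randn(1, 1024, 512)
x = layer(x) + x   # (1, 1024, 512)
x = block(x)       # (1, 1024, 512)
```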

conformer/conformer.py (+125 -2)
@@ -1,9 +1,17 @@
 import torch
-from torch import nn
+from torch import nn, einsum
 import torch.nn.functional as F
 
+from einops import rearrange
+
 # helper functions
 
+def exists(val):
+    return val is not None
+
+def default(val, d):
+    return val if exists(val) else d
+
 def calc_same_padding(kernel_size):
     pad = kernel_size // 2
     return (pad, pad - (kernel_size + 1) % 2)
@@ -42,7 +50,86 @@ def forward(self, x):
         x = F.pad(x, self.padding)
         return self.conv(x)
 
-# main class
+# attention, feedforward, and conv module
+
+class Scale(nn.Module):
+    def __init__(self, scale, fn):
+        super().__init__()
+        self.fn = fn
+        self.scale = scale
+
+    def forward(self, x, **kwargs):
+        return self.fn(x, **kwargs) * self.scale
+
+class PreNorm(nn.Module):
+    def __init__(self, dim, fn):
+        super().__init__()
+        self.fn = fn
+        self.norm = nn.LayerNorm(dim)
+
+    def forward(self, x, **kwargs):
+        x = self.norm(x)
+        return self.fn(x, **kwargs)
+
+class Attention(nn.Module):
+    def __init__(
+        self,
+        dim,
+        heads = 8,
+        dim_head = 64,
+        dropout = 0.
+    ):
+        super().__init__()
+        inner_dim = dim_head * heads
+        self.heads = heads
+        self.scale = dim_head ** -0.5
+        self.to_q = nn.Linear(dim, inner_dim, bias = False)
+        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
+        self.to_out = nn.Linear(inner_dim, dim)
+
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x, context = None, mask = None, context_mask = None):
+        device, h, has_context = x.device, self.heads, exists(context)
+        context = default(context, x)
+
+        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
+
+        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
+
+        if exists(mask) or exists(context_mask):
+            mask = default(mask, lambda: torch.ones(*x.shape[:2], device = device))
+            context_mask = default(context_mask, mask) if not has_context else default(context_mask, lambda: torch.ones(*context.shape[:2], device = device))
+            mask_value = -torch.finfo(dots.dtype).max
+            mask = mask[:, None, :, None] * context_mask[:, None, None, :]
+            dots.masked_fill_(~mask, mask_value)
+
+        attn = dots.softmax(dim = -1)
+        attn = self.dropout(attn)
+
+        out = einsum('b h i j, b h j d -> b h i d', attn, v)
+        out = rearrange(out, 'b h n d -> b n (h d)')
+        return self.to_out(out)
+
+class FeedForward(nn.Module):
+    def __init__(
+        self,
+        dim,
+        mult = 4,
+        dropout = 0.
+    ):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Linear(dim, dim * mult),
+            Swish(),
+            nn.Dropout(dropout),
+            nn.Linear(dim * mult, dim),
+            nn.Dropout(dropout)
+        )
+
+    def forward(self, x):
+        return self.net(x)
 
 class ConformerConvModule(nn.Module):
     def __init__(
@@ -72,3 +159,39 @@ def __init__(
 
     def forward(self, x):
         return self.net(x)
+
+# Conformer Block
+
+class ConformerBlock(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        dim_head = 64,
+        heads = 8,
+        ff_mult = 4,
+        conv_expansion_factor = 2,
+        conv_kernel_size = 31,
+        attn_dropout = 0.,
+        ff_dropout = 0.,
+        conv_dropout = 0.
+    ):
+        super().__init__()
+        self.ff1 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
+        self.attn = Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout)
+        self.conv = ConformerConvModule(dim = dim, causal = False, expansion_factor = conv_expansion_factor, kernel_size = conv_kernel_size, dropout = conv_dropout)
+        self.ff2 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
+
+        self.attn = PreNorm(dim, self.attn)
+        self.ff1 = Scale(0.5, PreNorm(dim, self.ff1))
+        self.ff2 = Scale(0.5, PreNorm(dim, self.ff2))
+
+        self.post_norm = nn.LayerNorm(dim)
+
+    def forward(self, x, mask = None):
+        x = self.ff1(x) + x
+        x = self.attn(x, mask = mask) + x
+        x = self.conv(x) + x
+        x = self.ff2(x) + x
+        x = self.post_norm(x)
+        return x
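
Because each ConformerBlock applies its residual connections and final LayerNorm internally, stacking blocks requires no extra wiring. A minimal sketch of a depth-N encoder built on top of the new class; the ConformerEncoder wrapper is illustrative and not part of this commit:

```python
import torch
from torch import nn
from conformer import ConformerBlock

class ConformerEncoder(nn.Module):
    # hypothetical wrapper, not part of this commit
    def __init__(self, dim, depth, **block_kwargs):
        super().__init__()
        self.blocks = nn.ModuleList([ConformerBlock(dim = dim, **block_kwargs) for _ in range(depth)])

    def forward(self, x, mask = None):
        # each block already ends with its own residuals and post-norm
        for block in self.blocks:
            x = block(x, mask = mask)
        return x

encoder = ConformerEncoder(dim = 512, depth = 4, conv_kernel_size = 31)
x = torch.randn(1, 1024, 512)
encoder(x)  # (1, 1024, 512)
```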

setup.py (+2 -1)
@@ -3,14 +3,15 @@
 setup(
   name = 'conformer',
   packages = find_packages(),
-  version = '0.1.0',
+  version = '0.2.0',
   license='MIT',
   description = 'The convolutional module from the Conformer paper',
   author = 'Phil Wang',
   author_email = '[email protected]',
   url = 'https://github.com/lucidrains/conformer',
   keywords = ['transformers', 'artificial intelligence', 'transformer'],
   install_requires=[
+    'einops',
     'torch'
   ],
   classifiers=[
