|
1 | 1 | import torch
|
2 |
| -from torch import nn |
| 2 | +from torch import nn, einsum |
3 | 3 | import torch.nn.functional as F
|
4 | 4 |
|
| 5 | +from einops import rearrange |
| 6 | + |
5 | 7 | # helper functions
|
6 | 8 |
|
| 9 | +def exists(val): |
| 10 | + return val is not None |
| 11 | + |
| 12 | +def default(val, d): |
| 13 | + return val if exists(val) else d |
| 14 | + |
7 | 15 | def calc_same_padding(kernel_size):
|
8 | 16 | pad = kernel_size // 2
|
9 | 17 | return (pad, pad - (kernel_size + 1) % 2)
|
@@ -42,7 +50,86 @@ def forward(self, x):
|
42 | 50 | x = F.pad(x, self.padding)
|
43 | 51 | return self.conv(x)
|
44 | 52 |
|
45 |
| -# main class |
| 53 | +# attention, feedforward, and conv module |
| 54 | + |
| 55 | +class Scale(nn.Module): |
| 56 | + def __init__(self, scale, fn): |
| 57 | + super().__init__() |
| 58 | + self.fn = fn |
| 59 | + self.scale = scale |
| 60 | + |
| 61 | + def forward(self, x, **kwargs): |
| 62 | + return self.fn(x, **kwargs) * self.scale |
| 63 | + |
| 64 | +class PreNorm(nn.Module): |
| 65 | + def __init__(self, dim, fn): |
| 66 | + super().__init__() |
| 67 | + self.fn = fn |
| 68 | + self.norm = nn.LayerNorm(dim) |
| 69 | + |
| 70 | + def forward(self, x, **kwargs): |
| 71 | + x = self.norm(x) |
| 72 | + return self.fn(x, **kwargs) |
| 73 | + |
| 74 | +class Attention(nn.Module): |
| 75 | + def __init__( |
| 76 | + self, |
| 77 | + dim, |
| 78 | + heads = 8, |
| 79 | + dim_head = 64, |
| 80 | + dropout = 0. |
| 81 | + ): |
| 82 | + super().__init__() |
| 83 | + inner_dim = dim_head * heads |
| 84 | + self.heads= heads |
| 85 | + self.scale = dim_head ** -0.5 |
| 86 | + self.to_q = nn.Linear(dim, inner_dim, bias = False) |
| 87 | + self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False) |
| 88 | + self.to_out = nn.Linear(inner_dim, dim) |
| 89 | + |
| 90 | + self.dropout = nn.Dropout(dropout) |
| 91 | + |
| 92 | + def forward(self, x, context = None, mask = None, context_mask = None): |
| 93 | + device, h, has_context = x.device, self.heads, exists(context) |
| 94 | + context = default(context, x) |
| 95 | + |
| 96 | + q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1)) |
| 97 | + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) |
| 98 | + |
| 99 | + dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale |
| 100 | + |
| 101 | + if exists(mask) or exists(context_mask): |
| 102 | + mask = default(mask, lambda: torch.ones(*x.shape[:2], device = device)) |
| 103 | + context_mask = default(context_mask, mask) if not has_context else default(context_mask, lambda: torch.ones(*context.shape[:2], device = device)) |
| 104 | + mask_value = -torch.finfo(dots.dtype).max |
| 105 | + mask = mask[:, None, :, None] * context_mask[:, None, None, :] |
| 106 | + dots.masked_fill_(~mask, mask_value) |
| 107 | + |
| 108 | + attn = dots.softmax(dim = -1) |
| 109 | + attn = self.dropout(attn) |
| 110 | + |
| 111 | + out = einsum('b h i j, b h j d -> b h i d', attn, v) |
| 112 | + out = rearrange(out, 'b h n d -> b n (h d)') |
| 113 | + return self.to_out(out) |
| 114 | + |
| 115 | +class FeedForward(nn.Module): |
| 116 | + def __init__( |
| 117 | + self, |
| 118 | + dim, |
| 119 | + mult = 4, |
| 120 | + dropout = 0. |
| 121 | + ): |
| 122 | + super().__init__() |
| 123 | + self.net = nn.Sequential( |
| 124 | + nn.Linear(dim, dim * mult), |
| 125 | + Swish(), |
| 126 | + nn.Dropout(dropout), |
| 127 | + nn.Linear(dim * mult, dim), |
| 128 | + nn.Dropout(dropout) |
| 129 | + ) |
| 130 | + |
| 131 | + def forward(self, x): |
| 132 | + return self.net(x) |
46 | 133 |
|
47 | 134 | class ConformerConvModule(nn.Module):
|
48 | 135 | def __init__(
|
@@ -72,3 +159,39 @@ def __init__(
|
72 | 159 |
|
73 | 160 | def forward(self, x):
|
74 | 161 | return self.net(x)
|
| 162 | + |
| 163 | +# Conformer Block |
| 164 | + |
| 165 | +class ConformerBlock(nn.Module): |
| 166 | + def __init__( |
| 167 | + self, |
| 168 | + *, |
| 169 | + dim, |
| 170 | + dim_head = 64, |
| 171 | + heads = 8, |
| 172 | + ff_mult = 4, |
| 173 | + conv_expansion_factor = 2, |
| 174 | + conv_kernel_size = 31, |
| 175 | + attn_dropout = 0., |
| 176 | + ff_dropout = 0., |
| 177 | + conv_dropout = 0. |
| 178 | + ): |
| 179 | + super().__init__() |
| 180 | + self.ff1 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout) |
| 181 | + self.attn = Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout) |
| 182 | + self.conv = ConformerConvModule(dim = dim, causal = False, expansion_factor = conv_expansion_factor, kernel_size = conv_kernel_size, dropout = conv_dropout) |
| 183 | + self.ff2 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout) |
| 184 | + |
| 185 | + self.attn = PreNorm(dim, self.attn) |
| 186 | + self.ff1 = Scale(0.5, PreNorm(dim, self.ff1)) |
| 187 | + self.ff2 = Scale(0.5, PreNorm(dim, self.ff2)) |
| 188 | + |
| 189 | + self.post_norm = nn.LayerNorm(dim) |
| 190 | + |
| 191 | + def forward(self, x, mask = None): |
| 192 | + x = self.ff1(x) + x |
| 193 | + x = self.attn(x, mask = mask) + x |
| 194 | + x = self.conv(x) + x |
| 195 | + x = self.ff2(x) + x |
| 196 | + x = self.post_norm(x) |
| 197 | + return x |
0 commit comments