Commit 0f8fbd3

incorporate some einops suggestions from @arogozhnikov
1 parent 86b8316 commit 0f8fbd3

1 file changed (+5, -13 lines)

conformer/conformer.py

@@ -3,6 +3,7 @@
 import torch.nn.functional as F
 
 from einops import rearrange
+from einops.layers.torch import Rearrange
 
 # helper functions
 

@@ -22,15 +23,6 @@ class Swish(nn.Module):
     def forward(self, x):
         return x * x.sigmoid()
 
-class Transpose(nn.Module):
-    def __init__(self, dims):
-        super().__init__()
-        assert len(dims) == 2, 'dims must be a tuple of two dimensions'
-        self.dims = dims
-
-    def forward(self, x):
-        return x.transpose(*self.dims)
-
 class GLU(nn.Module):
     def __init__(self, dim):
         super().__init__()
@@ -104,7 +96,7 @@ def forward(self, x, context = None, mask = None, context_mask = None):
 
         # shaw's relative positional embedding
         seq = torch.arange(n, device = device)
-        dist = seq[:, None] - seq[None, :]
+        dist = rearrange(seq, 'i -> i ()') - rearrange(seq, 'j -> () j')
         dist = dist.clip(-max_pos_emb, max_pos_emb) + max_pos_emb
         rel_pos_emb = self.rel_pos_emb(dist).to(q)
         pos_attn = einsum('b h n d, n r d -> b h n r', q, rel_pos_emb) * self.scale
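
As an aside on the hunk above: '()' in a rearrange output pattern inserts a singleton axis, so the two rearrange calls reproduce the seq[:, None] - seq[None, :] broadcast exactly. A minimal sketch of the equivalence (tensor names here are illustrative, not from the repo):

import torch
from einops import rearrange

seq = torch.arange(5)

# index-based broadcasting: column vector minus row vector -> (5, 5) distances
dist_index = seq[:, None] - seq[None, :]

# einops version: '()' in the output pattern inserts a singleton axis
dist_einops = rearrange(seq, 'i -> i ()') - rearrange(seq, 'j -> () j')

assert torch.equal(dist_index, dist_einops)  # identical (5, 5) result
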
@@ -114,7 +106,7 @@ def forward(self, x, context = None, mask = None, context_mask = None):
             mask = default(mask, lambda: torch.ones(*x.shape[:2], device = device))
             context_mask = default(context_mask, mask) if not has_context else default(context_mask, lambda: torch.ones(*context.shape[:2], device = device))
             mask_value = -torch.finfo(dots.dtype).max
-            mask = mask[:, None, :, None] * context_mask[:, None, None, :]
+            mask = rearrange(mask, 'b i -> b () i ()') * rearrange(context_mask, 'b j -> b () () j')
             dots.masked_fill_(~mask, mask_value)
 
         attn = dots.softmax(dim = -1)
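
Same pattern in this hunk: each boolean mask is padded with singleton axes so their product broadcasts to a per-pair attention mask. A small sketch under assumed shapes (b, i, j are illustrative sizes, not from the repo):

import torch
from einops import rearrange

b, i, j = 2, 4, 6  # illustrative batch and sequence lengths
mask = torch.ones(b, i, dtype = torch.bool)
context_mask = torch.ones(b, j, dtype = torch.bool)

# (b, 1, i, 1) * (b, 1, 1, j) broadcasts to (b, 1, i, j), which then
# broadcasts over the heads dimension of the attention logits
full_mask = rearrange(mask, 'b i -> b () i ()') * rearrange(context_mask, 'b j -> b () () j')
assert full_mask.shape == (b, 1, i, j)
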
@@ -158,14 +150,14 @@ def __init__(
 
         self.net = nn.Sequential(
             nn.LayerNorm(dim),
-            Transpose((1, 2)),
+            Rearrange('b n c -> b c n'),
             nn.Conv1d(dim, inner_dim * 2, 1),
             GLU(dim=1),
             DepthWiseConv1d(inner_dim, inner_dim, kernel_size = kernel_size, padding = padding),
             nn.BatchNorm1d(inner_dim) if not causal else nn.Identity(),
             Swish(),
             nn.Conv1d(inner_dim, dim, 1),
-            Transpose((1, 2)),
+            Rearrange('b c n -> b n c'),
             nn.Dropout(dropout)
         )
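
The upshot of the last hunk: Rearrange is itself an nn.Module, so it slots into nn.Sequential where the hand-rolled Transpose class used to sit, with the axis swap spelled out in the pattern string. A minimal sketch (layer sizes are illustrative, not from the repo):

import torch
from torch import nn
from einops.layers.torch import Rearrange

net = nn.Sequential(
    Rearrange('b n c -> b c n'),  # (batch, seq, channels) -> (batch, channels, seq) for Conv1d
    nn.Conv1d(8, 8, 1),
    Rearrange('b c n -> b n c'),  # back to (batch, seq, channels)
)

x = torch.randn(2, 10, 8)
assert net(x).shape == (2, 10, 8)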
