Commit 0cb3a55

add shaws relative positional embedding
1 parent 7321090 commit 0cb3a55

File tree

1 file changed, +14 -2 lines changed


conformer/conformer.py

+14 -2

@@ -77,7 +77,8 @@ def __init__(
         dim,
         heads = 8,
         dim_head = 64,
-        dropout = 0.
+        dropout = 0.,
+        max_pos_emb = 512
     ):
         super().__init__()
         inner_dim = dim_head * heads
@@ -87,17 +88,28 @@ def __init__(
         self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
         self.to_out = nn.Linear(inner_dim, dim)
 
+        self.max_pos_emb = max_pos_emb
+        self.rel_pos_emb = nn.Embedding(2 * max_pos_emb + 1, dim_head)
+
         self.dropout = nn.Dropout(dropout)
 
     def forward(self, x, context = None, mask = None, context_mask = None):
-        device, h, has_context = x.device, self.heads, exists(context)
+        n, device, h, max_pos_emb, has_context = x.shape[-2], x.device, self.heads, self.max_pos_emb, exists(context)
         context = default(context, x)
 
         q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
 
         dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
 
+        # shaw's relative positional embedding
+        seq = torch.arange(n, device = device)
+        dist = seq[:, None] - seq[None, :]
+        dist = dist.clip(-max_pos_emb, max_pos_emb) + max_pos_emb
+        rel_pos_emb = self.rel_pos_emb(dist).to(q)
+        pos_attn = einsum('b h n d, n r d -> b h n r', q, rel_pos_emb) * self.scale
+        dots = dots + pos_attn
+
         if exists(mask) or exists(context_mask):
             mask = default(mask, lambda: torch.ones(*x.shape[:2], device = device))
             context_mask = default(context_mask, mask) if not has_context else default(context_mask, lambda: torch.ones(*context.shape[:2], device = device))
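For readers skimming the diff, the sketch below reproduces the attention computation with the Shaw-style relative positional term in isolation. It is a minimal reconstruction, not the repository's full Attention module: the class name RelPosAttentionSketch is made up for illustration, and dropout, masking, and cross-attention are omitted; only the relative-position lines mirror the diff above.

import torch
from torch import nn, einsum
from einops import rearrange

class RelPosAttentionSketch(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, max_pos_emb = 512):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.max_pos_emb = max_pos_emb
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)
        # one learned embedding per clipped relative offset in [-max_pos_emb, max_pos_emb]
        self.rel_pos_emb = nn.Embedding(2 * max_pos_emb + 1, dim_head)

    def forward(self, x):
        n, device, h, max_pos_emb = x.shape[-2], x.device, self.heads, self.max_pos_emb
        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # content-content attention scores
        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        # Shaw-style relative positional term: bias each (i, j) score by the
        # query's dot product with an embedding of the clipped offset i - j
        seq = torch.arange(n, device = device)
        dist = seq[:, None] - seq[None, :]                         # (n, n) signed offsets
        dist = dist.clip(-max_pos_emb, max_pos_emb) + max_pos_emb  # shift into [0, 2 * max_pos_emb]
        rel_pos_emb = self.rel_pos_emb(dist).to(q)                 # (n, n, dim_head)
        pos_attn = einsum('b h n d, n r d -> b h n r', q, rel_pos_emb) * self.scale
        dots = dots + pos_attn

        attn = dots.softmax(dim = -1)
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# usage: a batch of 1 sequence, 128 frames, model dimension 256
x = torch.randn(1, 128, 256)
print(RelPosAttentionSketch(dim = 256)(x).shape)  # torch.Size([1, 128, 256])

Clipping the offsets means positions farther apart than max_pos_emb share an embedding, so the table stays a fixed 2 * max_pos_emb + 1 rows regardless of sequence length.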
