Skip to content

Commit 7928be1

Browse files
committed
allow variable lengthed contexts
1 parent f019818 commit 7928be1

3 files changed

Lines changed: 12 additions & 1 deletion

File tree

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,9 +244,12 @@ seq = torch.randint(0, 4, (1, 196_608 // 2,)).cuda()
244244
target = torch.randn(1, 200, 4).cuda() # 4 tracks
245245
context = torch.randn(4, 16, 1024).cuda() # 4 contexts for the different 'tracks', each with 16 tokens
246246

247+
context_mask = torch.ones(4, 16).bool().cuda() # optional context mask, in example, include all context tokens
248+
247249
loss = model(
248250
seq,
249251
context = context,
252+
context_mask = context_mask,
250253
target = target
251254
)
252255

enformer_pytorch/finetune.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ def forward(
149149
seq,
150150
*,
151151
context,
152+
context_mask = None,
152153
target = None,
153154
freeze_enformer = False
154155
):
@@ -183,6 +184,13 @@ def forward(
183184
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
184185
sim = einsum('b h i d, c h j d -> b c h i j', q, k) * self.scale
185186

187+
# masking
188+
189+
if exists(context_mask):
190+
context_mask = F.pad(context_mask, (1, 0), value = True)
191+
context_mask =rearrange(context_mask, 'b j -> b 1 1 1 j')
192+
sim = sim.masked_fill(~context_mask, -torch.finfo(sim.dtype).max)
193+
186194
# attention
187195

188196
attn = sim.softmax(dim = -1)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
name = 'enformer-pytorch',
55
packages = find_packages(exclude=[]),
66
include_package_data = True,
7-
version = '0.1.15',
7+
version = '0.1.16',
88
license='MIT',
99
description = 'Enformer - Pytorch',
1010
author = 'Phil Wang',

0 commit comments

Comments
 (0)