Skip to content

Commit 6f72ed4

Browse files
committed
add ability to only return embeddings, without human and mouse head calculation
1 parent f5072e1 commit 6f72ed4

3 files changed

Lines changed: 9 additions & 5 deletions

File tree

enformer_pytorch/enformer_pytorch.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,7 @@ def forward(
373373
target = None,
374374
return_corr_coef = False,
375375
return_embeddings = False,
376+
return_only_embeddings = False,
376377
head = None
377378
):
378379
dtype = x.dtype
@@ -386,12 +387,15 @@ def forward(
386387
x = rearrange(x, '... -> () ...')
387388

388389
x = self._trunk(x)
389-
out = map_values(lambda fn: fn(x), self._heads)
390390

391391
if no_batch:
392-
out = map_values(lambda t: rearrange(t, '() ... -> ...'), out)
393392
x = rearrange(x, '() ... -> ...')
394393

394+
if return_only_embeddings:
395+
return x
396+
397+
out = map_values(lambda fn: fn(x), self._heads)
398+
395399
if exists(head):
396400
assert head in self._heads, f'head {head} not found'
397401
out = out[head]

enformer_pytorch/finetune.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def forward(
5353
enformer_context = freeze_batchnorm_context(self.enformer) if not freeze_enformer else torch.no_grad()
5454

5555
with enformer_context:
56-
_, embeddings = self.enformer(seq, return_embeddings = True)
56+
embeddings = self.enformer(seq, return_only_embeddings = True)
5757

5858
if freeze_enformer:
5959
embeddings.detach_()
@@ -90,7 +90,7 @@ def forward(
9090
enformer_context = freeze_batchnorm_context(self.enformer) if not freeze_enformer else torch.no_grad()
9191

9292
with enformer_context:
93-
_, embeddings = self.enformer(seq, return_embeddings = True)
93+
embeddings = self.enformer(seq, return_only_embeddings = True)
9494

9595
if freeze_enformer:
9696
embeddings.detach_()

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'enformer-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.1.5',
6+
version = '0.1.6',
77
license='MIT',
88
description = 'Enformer - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)