File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -373,6 +373,7 @@ def forward(
373373 target = None ,
374374 return_corr_coef = False ,
375375 return_embeddings = False ,
376+ return_only_embeddings = False ,
376377 head = None
377378 ):
378379 dtype = x .dtype
@@ -386,12 +387,15 @@ def forward(
386387 x = rearrange (x , '... -> () ...' )
387388
388389 x = self ._trunk (x )
389- out = map_values (lambda fn : fn (x ), self ._heads )
390390
391391 if no_batch :
392- out = map_values (lambda t : rearrange (t , '() ... -> ...' ), out )
393392 x = rearrange (x , '() ... -> ...' )
394393
394+ if return_only_embeddings :
395+ return x
396+
397+ out = map_values (lambda fn : fn (x ), self ._heads )
398+
395399 if exists (head ):
396400 assert head in self ._heads , f'head { head } not found'
397401 out = out [head ]
Original file line number Diff line number Diff line change @@ -53,7 +53,7 @@ def forward(
5353 enformer_context = freeze_batchnorm_context (self .enformer ) if not freeze_enformer else torch .no_grad ()
5454
5555 with enformer_context :
56- _ , embeddings = self .enformer (seq , return_embeddings = True )
56+ embeddings = self .enformer (seq , return_only_embeddings = True )
5757
5858 if freeze_enformer :
5959 embeddings .detach_ ()
@@ -90,7 +90,7 @@ def forward(
9090 enformer_context = freeze_batchnorm_context (self .enformer ) if not freeze_enformer else torch .no_grad ()
9191
9292 with enformer_context :
93- _ , embeddings = self .enformer (seq , return_embeddings = True )
93+ embeddings = self .enformer (seq , return_only_embeddings = True )
9494
9595 if freeze_enformer :
9696 embeddings .detach_ ()
Original file line number Diff line number Diff line change 33setup (
44 name = 'enformer-pytorch' ,
55 packages = find_packages (exclude = []),
6- version = '0.1.5 ' ,
6+ version = '0.1.6 ' ,
77 license = 'MIT' ,
88 description = 'Enformer - Pytorch' ,
99 author = 'Phil Wang' ,
You can’t perform that action at this time.
0 commit comments