Skip to content

Commit 87cc4c0

Browse files
committed
address variable sequence lengths while using tf gamma #32
1 parent d2dbc21 commit 87cc4c0

2 files changed

Lines changed: 11 additions & 2 deletions

File tree

enformer_pytorch/modeling_enformer.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,15 @@
2727
DIR = Path(__file__).parents[0]
2828
TF_GAMMAS = torch.load(str(DIR / "precomputed"/ "tf_gammas.pt"))
2929

30+
def get_tf_gamma(seq_len, device):
31+
tf_gammas = TF_GAMMAS.to(device)
32+
pad = 1536 - seq_len
33+
34+
if pad == 0:
35+
return tf_gammas
36+
37+
return tf_gammas[pad:-pad]
38+
3039
# helpers
3140

3241
def exists(val):
@@ -106,7 +115,7 @@ def get_positional_embed(seq_len, feature_size, device, use_tf_gamma):
106115
feature_functions = [
107116
get_positional_features_exponential,
108117
get_positional_features_central_mask,
109-
get_positional_features_gamma if not use_tf_gamma else always(TF_GAMMAS.to(device))
118+
get_positional_features_gamma if not use_tf_gamma else always(get_tf_gamma(seq_len, device))
110119
]
111120

112121
num_components = len(feature_functions) * 2

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
name = 'enformer-pytorch',
55
packages = find_packages(exclude=[]),
66
include_package_data = True,
7-
version = '0.8.3',
7+
version = '0.8.4',
88
license='MIT',
99
description = 'Enformer - Pytorch',
1010
author = 'Phil Wang',

0 commit comments

Comments
 (0)