Skip to content

Commit dce5709

Browse files
committed
only allow seq length of 1536 if using tf gamma
1 parent 87cc4c0 commit dce5709

3 files changed

Lines changed: 5 additions & 12 deletions

File tree

enformer_pytorch/modeling_enformer.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,6 @@
2727
DIR = Path(__file__).parents[0]
2828
TF_GAMMAS = torch.load(str(DIR / "precomputed"/ "tf_gammas.pt"))
2929

30-
def get_tf_gamma(seq_len, device):
31-
tf_gammas = TF_GAMMAS.to(device)
32-
pad = 1536 - seq_len
33-
34-
if pad == 0:
35-
return tf_gammas
36-
37-
return tf_gammas[pad:-pad]
38-
3930
# helpers
4031

4132
def exists(val):
@@ -112,10 +103,12 @@ def get_positional_features_gamma(positions, features, seq_len, stddev = None, s
112103
def get_positional_embed(seq_len, feature_size, device, use_tf_gamma):
113104
distances = torch.arange(-seq_len + 1, seq_len, device = device)
114105

106+
assert not use_tf_gamma or seq_len == 1536, 'if using tf gamma, only sequence length of 1536 allowed for now'
107+
115108
feature_functions = [
116109
get_positional_features_exponential,
117110
get_positional_features_central_mask,
118-
get_positional_features_gamma if not use_tf_gamma else always(get_tf_gamma(seq_len, device))
111+
get_positional_features_gamma if not use_tf_gamma else always(TF_GAMMAS.to(device))
119112
]
120113

121114
num_components = len(feature_functions) * 2

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
name = 'enformer-pytorch',
55
packages = find_packages(exclude=[]),
66
include_package_data = True,
7-
version = '0.8.4',
7+
version = '0.8.5',
88
license='MIT',
99
description = 'Enformer - Pytorch',
1010
author = 'Phil Wang',

test_pretrained.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch
22
from enformer_pytorch import from_pretrained
33

4-
enformer = from_pretrained('EleutherAI/enformer-official-rough').cuda()
4+
enformer = from_pretrained('EleutherAI/enformer-official-rough', use_tf_gamma = False).cuda()
55
enformer.eval()
66

77
data = torch.load('./data/test-sample.pt')

0 commit comments

Comments
 (0)