Skip to content

Commit 18614f7

Browse files
committed
allow for overriding all dropouts, as well as convenience method for dynamically setting target crop length
1 parent 2bf8213 commit 18614f7

3 files changed

Lines changed: 9 additions & 6 deletions

File tree

enformer_pytorch/enformer_pytorch.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,9 +258,9 @@ def __init__(
258258
heads = 8,
259259
output_heads = dict(human = 5313, mouse= 1643),
260260
target_length = TARGET_LENGTH,
261-
dropout_rate = 0.4,
262261
num_alphabet = 4,
263262
attn_dim_key = 64,
263+
dropout_rate = 0.4,
264264
attn_dropout = 0.05,
265265
pos_dropout = 0.01
266266
):
@@ -359,6 +359,10 @@ def __init__(
359359
nn.Softplus()
360360
), output_heads))
361361

362+
def set_target_length(self, target_length):
363+
crop_module = self._trunk[-2]
364+
crop_module.target_length = target_length
365+
362366
@property
363367
def trunk(self):
364368
return self._trunk

enformer_pytorch/model_loader.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,8 @@ def remove_nones(d):
3535
def load_pretrained_model(
3636
slug,
3737
force = False,
38-
target_length = None,
39-
dropout_rate = None,
40-
model = None
38+
model = None,
39+
**kwargs
4140
):
4241
if slug not in CONFIG:
4342
print(f'model {slug} not found among available choices: [{", ".join(CONFIG.keys())}]')
@@ -58,7 +57,7 @@ def load_pretrained_model(
5857

5958
# load
6059

61-
override_params = remove_nones({'target_length': target_length, 'dropout_rate': dropout_rate})
60+
override_params = remove_nones(kwargs)
6261
params = {**config['params'], **override_params}
6362

6463
if not exists(model):

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
name = 'enformer-pytorch',
55
packages = find_packages(exclude=[]),
66
include_package_data = True,
7-
version = '0.1.19',
7+
version = '0.1.20',
88
license='MIT',
99
description = 'Enformer - Pytorch',
1010
author = 'Phil Wang',

0 commit comments

Comments
 (0)