Skip to content

Commit 9ac9d69

Browse files
committed
overrideable dropout rate for fine-tuning
1 parent 2195260 commit 9ac9d69

3 files changed

Lines changed: 5 additions & 4 deletions

File tree

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ You can also load, with overriding of the `target_length` parameter, if you are
138138
```python
139139
from enformer_pytorch import load_pretrained_model
140140

141-
model = load_pretrained_model('preview', target_length = 128)
141+
model = load_pretrained_model('preview', target_length = 128, dropout_rate = 0.1)
142142

143143
# do your fine-tuning
144144
```
@@ -148,7 +148,7 @@ You can also define the model externally, and then load the pretrained weights b
148148
```python
149149
from enformer_pytorch import Enformer, load_pretrained_model
150150

151-
enformer = Enformer(dim = 1536, depth = 11, target_length = 128)
151+
enformer = Enformer(dim = 1536, depth = 11, target_length = 128, dropout_rate = 0.1)
152152

153153
load_pretrained_model('preview', model = enformer)
154154

enformer_pytorch/model_loader.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def load_pretrained_model(
3636
slug,
3737
force = False,
3838
target_length = None,
39+
dropout_rate = None,
3940
model = None
4041
):
4142
if slug not in CONFIG:
@@ -57,7 +58,7 @@ def load_pretrained_model(
5758

5859
# load
5960

60-
override_params = remove_nones({'target_length': target_length})
61+
override_params = remove_nones({'target_length': target_length, 'dropout_rate': dropout_rate})
6162
params = {**config['params'], **override_params}
6263

6364
if not exists(model):

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
name = 'enformer-pytorch',
55
packages = find_packages(exclude=[]),
66
include_package_data = True,
7-
version = '0.1.17',
7+
version = '0.1.18',
88
license='MIT',
99
description = 'Enformer - Pytorch',
1010
author = 'Phil Wang',

0 commit comments

Comments
 (0)