Skip to content

Commit 811a746

Browse files
committed
allow for passing externally defined enformer model into load_pretrained_model method
1 parent 362a57c commit 811a746

3 files changed

Lines changed: 18 additions & 3 deletions

File tree

README.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,18 @@ model = load_pretrained_model('preview', target_length = 128)
143143
# do your fine-tuning
144144
```
145145

146+
You can also define the model externally, and then load the pretrained weights by passing it into `load_pretrained_model`
147+
148+
```python
149+
from enformer_pytorch import Enformer, load_pretrained_model
150+
151+
enformer = Enformer(dim = 1536, depth = 11, target_length = 128)
152+
153+
load_pretrained_model('preview', model = enformer)
154+
155+
# use enformer
156+
```
157+
146158
## Fine-tuning (wip)
147159

148160
This repository will also allow for easy fine-tuning of Enformer.

enformer_pytorch/model_loader.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ def remove_nones(d):
3535
def load_pretrained_model(
3636
slug,
3737
force = False,
38-
target_length = None
38+
target_length = None,
39+
model = None
3940
):
4041
if slug not in CONFIG:
4142
print(f'model {slug} not found among available choices: [{", ".join(CONFIG.keys())}]')
@@ -59,7 +60,9 @@ def load_pretrained_model(
5960
override_params = remove_nones({'target_length': target_length})
6061
params = {**config['params'], **override_params}
6162

62-
model = Enformer(**config['params'])
63+
if not exists(model):
64+
model = Enformer(**config['params'])
65+
6366
model.load_state_dict(torch.load(str(save_path)))
6467

6568
print(f'loaded {slug} successfully')

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
name = 'enformer-pytorch',
55
packages = find_packages(exclude=[]),
66
include_package_data = True,
7-
version = '0.1.9',
7+
version = '0.1.10',
88
license='MIT',
99
description = 'Enformer - Pytorch',
1010
author = 'Phil Wang',

0 commit comments

Comments
 (0)