Skip to content

Commit 1509960

Browse files
committed
add more freezing unfreezing functions for finetuning
1 parent ceef4ab commit 1509960

2 files changed

Lines changed: 27 additions & 1 deletion

File tree

enformer_pytorch/finetune.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,18 @@ def exists(val):
1414
def null_context():
1515
yield
1616

17+
# controlling freezing of layers
18+
1719
def set_module_requires_grad_(module, requires_grad):
1820
for param in module.parameters():
1921
param.requires_grad = requires_grad
2022

23+
def freeze_all_layers_(module):
24+
set_module_requires_grad_(module, False)
25+
26+
def unfreeze_all_layers_(module):
27+
set_module_requires_grad_(module, True)
28+
2129
def freeze_batchnorms_(model):
2230
bns = [m for m in model.modules() if isinstance(m, nn.BatchNorm1d)]
2331

@@ -30,6 +38,24 @@ def freeze_all_but_layernorms_(model):
3038
for m in model.modules():
3139
set_module_requires_grad_(m, isinstance(m, nn.LayerNorm))
3240

41+
def unfreeze_last_n_layers_(enformer, n):
42+
assert isinstance(enformer, Enformer)
43+
transformer_blocks = enformer.transformer[1:]
44+
45+
for module in transformer_blocks[-n:]:
46+
set_module_requires_grad_(module, True)
47+
48+
def freeze_all_but_last_n_layers_(enformer, n):
49+
assert isinstance(enformer, Enformer)
50+
freeze_all_layers_(enformer)
51+
52+
transformer_blocks = enformer.transformer[1:]
53+
54+
for module in transformer_blocks[-n:]:
55+
set_module_requires_grad_(module, False)
56+
57+
# get enformer embeddings
58+
3359
def get_enformer_embeddings(
3460
model,
3561
seq,

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
name = 'enformer-pytorch',
55
packages = find_packages(exclude=[]),
66
include_package_data = True,
7-
version = '0.3.5',
7+
version = '0.3.6',
88
license='MIT',
99
description = 'Enformer - Pytorch',
1010
author = 'Phil Wang',

0 commit comments

Comments
 (0)