@@ -14,10 +14,18 @@ def exists(val):
1414def null_context ():
1515 yield
1616
17+ # controlling freezing of layers
18+
1719def set_module_requires_grad_ (module , requires_grad ):
1820 for param in module .parameters ():
1921 param .requires_grad = requires_grad
2022
23+ def freeze_all_layers_ (module ):
24+ set_module_requires_grad_ (module , False )
25+
26+ def unfreeze_all_layers_ (module ):
27+ set_module_requires_grad_ (module , True )
28+
2129def freeze_batchnorms_ (model ):
2230 bns = [m for m in model .modules () if isinstance (m , nn .BatchNorm1d )]
2331
@@ -30,6 +38,24 @@ def freeze_all_but_layernorms_(model):
3038 for m in model .modules ():
3139 set_module_requires_grad_ (m , isinstance (m , nn .LayerNorm ))
3240
41+ def unfreeze_last_n_layers_ (enformer , n ):
42+ assert isinstance (enformer , Enformer )
43+ transformer_blocks = enformer .transformer [1 :]
44+
45+ for module in transformer_blocks [- n :]:
46+ set_module_requires_grad_ (module , True )
47+
48+ def freeze_all_but_last_n_layers_ (enformer , n ):
49+ assert isinstance (enformer , Enformer )
50+ freeze_all_layers_ (enformer )
51+
52+ transformer_blocks = enformer .transformer [1 :]
53+
54+ for module in transformer_blocks [- n :]:
55+ set_module_requires_grad_ (module , False )
56+
57+ # get enformer embeddings
58+
3359def get_enformer_embeddings (
3460 model ,
3561 seq ,
0 commit comments