File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -258,9 +258,9 @@ def __init__(
258258 heads = 8 ,
259259 output_heads = dict (human = 5313 , mouse = 1643 ),
260260 target_length = TARGET_LENGTH ,
261- dropout_rate = 0.4 ,
262261 num_alphabet = 4 ,
263262 attn_dim_key = 64 ,
263+ dropout_rate = 0.4 ,
264264 attn_dropout = 0.05 ,
265265 pos_dropout = 0.01
266266 ):
@@ -359,6 +359,10 @@ def __init__(
359359 nn .Softplus ()
360360 ), output_heads ))
361361
362+ def set_target_length (self , target_length ):
363+ crop_module = self ._trunk [- 2 ]
364+ crop_module .target_length = target_length
365+
362366 @property
363367 def trunk (self ):
364368 return self ._trunk
Original file line number Diff line number Diff line change @@ -35,9 +35,8 @@ def remove_nones(d):
3535def load_pretrained_model (
3636 slug ,
3737 force = False ,
38- target_length = None ,
39- dropout_rate = None ,
40- model = None
38+ model = None ,
39+ ** kwargs
4140):
4241 if slug not in CONFIG :
4342 print (f'model { slug } not found among available choices: [{ ", " .join (CONFIG .keys ())} ]' )
@@ -58,7 +57,7 @@ def load_pretrained_model(
5857
5958 # load
6059
61- override_params = remove_nones ({ 'target_length' : target_length , 'dropout_rate' : dropout_rate } )
60+ override_params = remove_nones (kwargs )
6261 params = {** config ['params' ], ** override_params }
6362
6463 if not exists (model ):
Original file line number Diff line number Diff line change 44 name = 'enformer-pytorch' ,
55 packages = find_packages (exclude = []),
66 include_package_data = True ,
7- version = '0.1.19 ' ,
7+ version = '0.1.20 ' ,
88 license = 'MIT' ,
99 description = 'Enformer - Pytorch' ,
1010 author = 'Phil Wang' ,
You can’t perform that action at this time.
0 commit comments