@@ -82,32 +82,37 @@ def get_model(
8282 # Init representation related variables
8383 sr , hop_len , patch_size = None , None , None
8484
85- config_file = Path (config_file )
86-
87- # Read previous config bindings
88- bindings = []
89- cfg_str = gin .config_str ()
90-
91- # If these are not empty, this model is part of a larger setup
92- # Do not finish the configuration now
9385 finalize_config = False
94- if cfg_str == "" :
95- finalize_config = True
96-
97- lines = cfg_str .split ("\n " )
98- bindings .extend (lines )
9986
100- if encodec_weights_path is not None :
101- bindings .append (f"nets.encodec.EnCodec.weights_path = '{ encodec_weights_path } '" )
102- bindings .append ("nets.encodec.EnCodec.stats_path = None" )
103-
104- # Parse the gin config
105- gin .parse_config_files_and_bindings (
106- [config_file ],
107- bindings ,
108- skip_unknown = True ,
109- finalize_config = finalize_config ,
110- )
87+ # When no config file is provided, it is assumed that an external
88+ # gin-config file with all the required fileds has already been parsed.
89+ # Don't try to moddify the gin configuration nor load a checkpoint.
90+ if config_file != "" :
91+ # Read previous config bindings
92+ bindings = []
93+ cfg_str = gin .config_str ()
94+
95+ # If these are not empty, this model is part of a larger setup
96+ # Do not finish the configuration now
97+ if cfg_str == "" :
98+ finalize_config = True
99+
100+ lines = cfg_str .split ("\n " )
101+ bindings .extend (lines )
102+
103+ if encodec_weights_path is not None :
104+ bindings .append (
105+ f"nets.encodec.EnCodec.weights_path = '{ encodec_weights_path } '"
106+ )
107+ bindings .append ("nets.encodec.EnCodec.stats_path = None" )
108+
109+ # Parse the gin config
110+ gin .parse_config_files_and_bindings (
111+ [str (config_file )],
112+ bindings ,
113+ skip_unknown = True ,
114+ finalize_config = finalize_config ,
115+ )
111116
112117 gin_config = gin .get_bindings (build_module )
113118
@@ -116,11 +121,6 @@ def get_model(
116121 representation = gin_config ["representation" ]
117122 module = gin_config ["module" ]
118123
119- # Make the checkpoint path relative to the config file location
120- # insted of taking the absolute path
121- ckpt_path = Path (gin_config ["ckpt_path" ])
122- ckpt_path = config_file .parent / ckpt_path .name
123-
124124 # Instantiate the classes
125125 net = net ()
126126
@@ -135,13 +135,21 @@ def get_model(
135135 sr = representation .sr
136136 hop_len = representation .hop_len
137137
138- module = module .load_from_checkpoint (
139- ckpt_path ,
140- net = net ,
141- representation = representation ,
142- strict = False ,
143- map_location = device ,
144- )
138+ if config_file != "" :
139+ # Make the checkpoint path relative to the config file location
140+ # insted of taking the absolute path
141+ ckpt_path = Path (gin_config ["ckpt_path" ])
142+ ckpt_path = Path (config_file ).parent / ckpt_path .name
143+
144+ module = module .load_from_checkpoint (
145+ ckpt_path ,
146+ net = net ,
147+ representation = representation ,
148+ strict = False ,
149+ map_location = device ,
150+ )
151+ else :
152+ module = module (net = net , representation = representation )
145153
146154 # Set the model to eval mode
147155 module .eval ()
0 commit comments