@@ -241,13 +241,17 @@ class ModelTraining(dj.Computed):
241241 # https://github.com/DeepLabCut/DeepLabCut/issues/70
242242
243243 def make (self , key ):
244- from deeplabcut import train_network # isort:skip
244+ from deeplabcut import train_network # isort:skip
245+
245246 try :
246- from deeplabcut .utils .auxiliaryfunctions import get_model_folder # isort:skip
247+ from deeplabcut .utils .auxiliaryfunctions import (
248+ get_model_folder ,
249+ edit_config ,
250+ ) # isort:skip
247251 except ImportError :
248252 from deeplabcut .utils .auxiliaryfunctions import (
249- GetModelFolder as get_model_folder
250- ) # isort:skip
253+ GetModelFolder as get_model_folder ,
254+ ) # isort:skip
251255
252256 """Launch training for each train.TrainingTask training_id via `.populate()`."""
253257 project_path , model_prefix = (TrainingTask & key ).fetch1 (
@@ -275,11 +279,26 @@ def make(self, key):
275279 # Write dlc config file to base project folder
276280 dlc_cfg_filepath = dlc_reader .save_yaml (project_path , dlc_config )
277281
282+ # ---- Update the project path in the DLC pose configuration (yaml) files ----
283+ model_folder = get_model_folder (
284+ trainFraction = dlc_config ["train_fraction" ],
285+ shuffle = dlc_config ["shuffle" ],
286+ cfg = dlc_config ,
287+ modelprefix = dlc_config ["modelprefix" ],
288+ )
289+ model_train_folder = project_path / model_folder / "train"
290+
291+ edit_config (
292+ model_train_folder / "pose_cfg.yaml" ,
293+ {"project_path" : project_path .as_posix ()},
294+ )
295+
278296 # ---- Trigger DLC model training job ----
279297 train_network_input_args = list (inspect .signature (train_network ).parameters )
280298 train_network_kwargs = {
281- k : int (v ) if k in ("shuffle" , "trainingsetindex" , "maxiters" ) else v
282- for k , v in dlc_config .items () if k in train_network_input_args
299+ k : int (v ) if k in ("shuffle" , "trainingsetindex" , "maxiters" ) else v
300+ for k , v in dlc_config .items ()
301+ if k in train_network_input_args
283302 }
284303 for k in ["shuffle" , "trainingsetindex" , "maxiters" ]:
285304 train_network_kwargs [k ] = int (train_network_kwargs [k ])
@@ -289,18 +308,7 @@ def make(self, key):
289308 except KeyboardInterrupt : # Instructions indicate to train until interrupt
290309 print ("DLC training stopped via Keyboard Interrupt" )
291310
292- snapshots = list (
293- (
294- project_path
295- / get_model_folder (
296- trainFraction = dlc_config ["train_fraction" ],
297- shuffle = dlc_config ["shuffle" ],
298- cfg = dlc_config ,
299- modelprefix = dlc_config ["modelprefix" ],
300- )
301- / "train"
302- ).glob ("*index*" )
303- )
311+ snapshots = list (model_train_folder .glob ("*index*" ))
304312 max_modified_time = 0
305313 # DLC goes by snapshot magnitude when judging 'latest' for evaluation
306314 # Here, we mean most recently generated
0 commit comments