@@ -21,7 +21,8 @@ def batch(command, jobs_cfg):
         jobs_cfg["model"] = [jobs_cfg["model"]]
 
     # Same for dataset
-    if not isinstance(jobs_cfg["dataset"], list):
+    has_dataset = "dataset" in jobs_cfg
+    if has_dataset and not isinstance(jobs_cfg["dataset"], list):
         jobs_cfg["dataset"] = [jobs_cfg["dataset"]]
 
     # Build list of model configurations
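The hunk above makes the "dataset" key optional: `has_dataset` is computed once and guards the list normalization, so a model-only batch no longer raises a KeyError. A minimal standalone sketch of that behavior, using a made-up config:

# Hypothetical model-only job config: no "dataset" key at all
jobs_cfg = {"model": {"id": "resnet50", "path": "weights.onnx"}}

if not isinstance(jobs_cfg["model"], list):
    jobs_cfg["model"] = [jobs_cfg["model"]]

# Normalize "dataset" to a list only when the key exists
has_dataset = "dataset" in jobs_cfg
if has_dataset and not isinstance(jobs_cfg["dataset"], list):
    jobs_cfg["dataset"] = [jobs_cfg["dataset"]]

print(has_dataset)  # False -> the old code would have raised KeyError here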
@@ -54,11 +55,16 @@ def batch(command, jobs_cfg):
 
     # Build list of jobs (IDs must be unique)
     all_jobs = {}
-    for model_cfg, dataset_cfg in product(model_cfgs, jobs_cfg["dataset"]):
-        job_id = f"{model_cfg['id']}-{dataset_cfg['id']}"
+    job_iter = product(model_cfgs, jobs_cfg["dataset"]) if has_dataset else model_cfgs
+    for job_components in job_iter:
+        if not isinstance(job_components, tuple):
+            job_components = (job_components,)
+
+        job_id = "-".join([str(jc["id"]) for jc in job_components])
         if job_id in all_jobs:
             raise ValueError(f"Job ID {job_id} is not unique")
-        all_jobs[job_id] = (model_cfg, dataset_cfg)
+
+        all_jobs[job_id] = job_components
 
     print("\n" + "-" * 80)
     print(f"{len(all_jobs)} job(s) will be executed:")
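With the dataset made optional, job enumeration branches: `itertools.product` pairs every model with every dataset when datasets exist; otherwise the iterator is just the model list, and bare model configs are wrapped in one-element tuples so the `job_id` join works for both shapes. A self-contained sketch with hypothetical `id` values:

from itertools import product

model_cfgs = [{"id": "m1"}, {"id": "m2"}]
dataset_cfgs = [{"id": "d1"}, {"id": "d2"}]
has_dataset = True

job_iter = product(model_cfgs, dataset_cfgs) if has_dataset else model_cfgs
all_jobs = {}
for job_components in job_iter:
    # Model-only iteration yields bare dicts; wrap them for a uniform shape
    if not isinstance(job_components, tuple):
        job_components = (job_components,)
    job_id = "-".join(str(jc["id"]) for jc in job_components)
    if job_id in all_jobs:
        raise ValueError(f"Job ID {job_id} is not unique")
    all_jobs[job_id] = job_components

print(list(all_jobs))  # ['m1-d1', 'm1-d2', 'm2-d1', 'm2-d2']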
@@ -69,7 +75,7 @@ def batch(command, jobs_cfg):
     # Start processing jobs
     pbar = tqdm(all_jobs.items(), total=len(all_jobs), leave=True)
     preds_outdir = None
-    for job_id, (model_cfg, dataset_cfg) in pbar:
+    for job_id, job_components in pbar:
         job_out_fname = os.path.join(jobs_cfg["outdir"], f"{job_id}.csv")
         if jobs_cfg.get("store_results_per_sample", False):
             preds_outdir = os.path.join(jobs_cfg["outdir"], f"preds-{job_id}")
@@ -84,31 +90,50 @@ def batch(command, jobs_cfg):
 
         ctx = click.get_current_context()
         try:
-            result = ctx.invoke(
-                cli_registry[command],
-                task=jobs_cfg["task"],
-                input_type=jobs_cfg["input_type"],
-                model_format=model_cfg["format"],
-                model=model_cfg["path"],
-                model_ontology=model_cfg["ontology"],
-                model_cfg=model_cfg["cfg"],
-                dataset_format=dataset_cfg["format"],
-                dataset_fname=dataset_cfg.get("fname", None),
-                dataset_dir=dataset_cfg.get("dir", None),
-                split_dir=dataset_cfg.get("split_dir", None),
-                train_dataset_dir=dataset_cfg.get("train_dir", None),
-                val_dataset_dir=dataset_cfg.get("val_dir", None),
-                test_dataset_dir=dataset_cfg.get("test_dir", None),
-                images_dir=dataset_cfg.get("data_dir", None),
-                labels_dir=dataset_cfg.get("labels_dir", None),
-                data_suffix=dataset_cfg.get("data_suffix", None),
-                label_suffix=dataset_cfg.get("label_suffix", None),
-                dataset_ontology=dataset_cfg.get("ontology", None),
-                split=dataset_cfg["split"],
-                ontology_translation=jobs_cfg.get("ontology_translation", None),
-                out_fname=job_out_fname,
-                predictions_outdir=preds_outdir,
+            params = {
+                "task": jobs_cfg["task"],
+                "input_type": jobs_cfg["input_type"],
+            }
+
+            model_cfg = job_components[0]
+            params.update(
+                {
+                    "model_format": model_cfg["format"],
+                    "model": model_cfg["path"],
+                    "model_ontology": model_cfg["ontology"],
+                    "model_cfg": model_cfg["cfg"],
+                    # "image_size": model_cfg.get("image_size", None),
+                }
             )
+            if has_dataset:
+                dataset_cfg = job_components[1]
+                params.update(
+                    {
+                        "dataset_format": dataset_cfg.get("format", None),
+                        "dataset_fname": dataset_cfg.get("fname", None),
+                        "dataset_dir": dataset_cfg.get("dir", None),
+                        "split_dir": dataset_cfg.get("split_dir", None),
+                        "train_dataset_dir": dataset_cfg.get("train_dir", None),
+                        "val_dataset_dir": dataset_cfg.get("val_dir", None),
+                        "test_dataset_dir": dataset_cfg.get("test_dir", None),
+                        "images_dir": dataset_cfg.get("data_dir", None),
+                        "labels_dir": dataset_cfg.get("labels_dir", None),
+                        "data_suffix": dataset_cfg.get("data_suffix", None),
+                        "label_suffix": dataset_cfg.get("label_suffix", None),
+                        "dataset_ontology": dataset_cfg.get("ontology", None),
+                        "split": dataset_cfg["split"],
+                        "ontology_translation": jobs_cfg.get(
+                            "ontology_translation", None
+                        ),
+                    }
+                )
+
+            params.update({"out_fname": job_out_fname})
+            if preds_outdir is not None:
+                params.update({"predictions_outdir": preds_outdir})
+
+            result = ctx.invoke(cli_registry[command], **params)
+
         except Exception as e:
             print(f"Error processing job {job_id}: {e}")
             continue
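The final hunk replaces the monolithic `ctx.invoke(...)` call with a `params` dict built up in stages, so dataset-related keyword arguments (and `predictions_outdir`) are only passed when they apply; anything omitted falls back to the target command's own defaults. A toy, self-contained illustration of that click pattern (the `evaluate` command and its options are hypothetical, not part of this repo):

import click

@click.command()
@click.option("--task", default=None)
@click.option("--dataset-format", default=None)
def evaluate(task, dataset_format):
    click.echo(f"task={task} dataset_format={dataset_format}")

@click.command()
@click.pass_context
def batch(ctx):
    params = {"task": "segmentation"}  # always-present arguments
    has_dataset = False
    if has_dataset:  # dataset arguments are added only when configured
        params.update({"dataset_format": "coco"})
    # Options not supplied here keep evaluate's defaults (None)
    ctx.invoke(evaluate, **params)

if __name__ == "__main__":
    batch()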