8080 default = SEGMENTATION_HYPER_PARAMS ['min_delta' ],
8181 type = click .FLOAT ,
8282 help = 'Minimum improvement between epochs to reset early stopping. By default it scales the delta by the best loss' )
83- @click .option ('-d' , '--device' , show_default = True , default = 'cpu' , help = 'Select device to use (cpu, cuda:0, cuda:1, ...)' )
8483@click .option ('--optimizer' ,
8584 show_default = True ,
8685 default = SEGMENTATION_HYPER_PARAMS ['optimizer' ],
127126@click .option ('-e' , '--evaluation-files' , show_default = True , default = None , multiple = True ,
128127 callback = _validate_manifests , type = click .File (mode = 'r' , lazy = True ),
129128 help = 'File(s) with paths to evaluation data. Overrides the `-p` parameter' )
130- @click .option ('--workers' , show_default = True , default = 1 , type = click .IntRange (0 ), help = 'Number of data loading worker processes.' )
131129@click .option ('--load-hyper-parameters/--no-load-hyper-parameters' , show_default = True , default = False ,
132130 help = 'When loading an existing model, retrieve hyper-parameters from the model' )
133131@click .option ('--force-binarization/--no-binarization' , show_default = True ,
195193 help = 'Path to directory where the logger will store the logs. If not set, a directory will be created in the current working directory.' )
196194@click .argument ('ground_truth' , nargs = - 1 , callback = _expand_gt , type = click .Path (exists = False , dir_okay = False ))
197195def segtrain (ctx , output , spec , line_width , pad , load , freq , quit , epochs ,
198- min_epochs , lag , min_delta , device , optimizer , lrate ,
199- momentum , weight_decay , warmup , schedule , gamma , step_size ,
200- sched_patience , cos_max , cos_min_lr , partition , training_files ,
201- evaluation_files , workers , load_hyper_parameters ,
202- force_binarization , format_type , suppress_regions ,
203- suppress_baselines , valid_regions , valid_baselines , merge_regions ,
204- merge_baselines , merge_all_baselines , merge_all_regions ,
205- bounding_regions , augment , resize , topline , pl_logger , log_dir ,
206- ground_truth ):
196+ min_epochs , lag , min_delta , optimizer , lrate , momentum ,
197+ weight_decay , warmup , schedule , gamma , step_size , sched_patience ,
198+ cos_max , cos_min_lr , partition , training_files , evaluation_files ,
199+ load_hyper_parameters , force_binarization , format_type ,
200+ suppress_regions , suppress_baselines , valid_regions ,
201+ valid_baselines , merge_regions , merge_baselines ,
202+ merge_all_baselines , merge_all_regions , bounding_regions , augment ,
203+ resize , topline , pl_logger , log_dir , ground_truth ):
207204 """
208205 Trains a baseline labeling model for layout analysis
209206 """
@@ -279,7 +276,7 @@ def segtrain(ctx, output, spec, line_width, pad, load, freq, quit, epochs,
279276 topline = loc [topline ]
280277
281278 try :
282- accelerator , device = to_ptl_device (device )
279+ accelerator , device = to_ptl_device (ctx . meta [ ' device' ] )
283280 except Exception as e :
284281 raise click .BadOptionUsage ('device' , str (e ))
285282
@@ -295,7 +292,7 @@ def segtrain(ctx, output, spec, line_width, pad, load, freq, quit, epochs,
295292 training_data = ground_truth ,
296293 evaluation_data = evaluation_files ,
297294 partition = partition ,
298- num_workers = workers ,
295+ num_workers = ctx . meta [ ' workers' ] ,
299296 load_hyper_parameters = load_hyper_parameters ,
300297 force_binarization = force_binarization ,
301298 format_type = format_type ,
0 commit comments