Skip to content

Commit 51f84c9

Browse files
committed
pull device and worker parameters directly from base context.
Fixes #743
1 parent bdc0aa8 commit 51f84c9

File tree

2 files changed

+15
-21
lines changed

2 files changed

+15
-21
lines changed

kraken/ketos/ro.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,6 @@
118118
@click.option('-e', '--evaluation-files', show_default=True, default=None, multiple=True,
119119
callback=_validate_manifests, type=click.File(mode='r', lazy=True),
120120
help='File(s) with paths to evaluation data. Overrides the `-p` parameter')
121-
@click.option('--workers', show_default=True, default=1, type=click.IntRange(0), help='Number of worker proesses.')
122-
@click.option('--threads', show_default=True, default=1, type=click.IntRange(1), help='Maximum size of OpenMP/BLAS thread pool.')
123121
@click.option('--load-hyper-parameters/--no-load-hyper-parameters', show_default=True, default=False,
124122
help='When loading an existing model, retrieve hyper-parameters from the model')
125123
@click.option('-f', '--format-type', type=click.Choice(['xml', 'alto', 'page']), default='xml',
@@ -151,10 +149,9 @@
151149
def rotrain(ctx, batch_size, output, load, freq, quit, epochs, min_epochs, lag,
152150
min_delta, optimizer, lrate, momentum, weight_decay, warmup,
153151
schedule, gamma, step_size, sched_patience, cos_max, cos_min_lr,
154-
partition, training_files, evaluation_files, workers, threads,
155-
load_hyper_parameters, format_type, valid_entities, merge_entities,
156-
merge_all_entities, pl_logger, log_dir, level, reading_order,
157-
ground_truth):
152+
partition, training_files, evaluation_files, load_hyper_parameters,
153+
format_type, valid_entities, merge_entities, merge_all_entities,
154+
pl_logger, log_dir, level, reading_order, ground_truth):
158155
"""
159156
Trains a baseline labeling model for layout analysis
160157
"""
@@ -235,7 +232,7 @@ def rotrain(ctx, batch_size, output, load, freq, quit, epochs, min_epochs, lag,
235232
training_data=ground_truth,
236233
evaluation_data=evaluation_files,
237234
partition=partition,
238-
num_workers=workers,
235+
num_workers=ctx.meta['workers'],
239236
format_type=format_type,
240237
class_mapping=class_mapping,
241238
valid_entities=valid_entities,
@@ -269,7 +266,7 @@ def rotrain(ctx, batch_size, output, load, freq, quit, epochs, min_epochs, lag,
269266
log_dir=log_dir,
270267
**val_check_interval)
271268

272-
with threadpool_limits(limits=threads):
269+
with threadpool_limits(limits=ctx.meta['threads']):
273270
trainer.fit(model, dm)
274271

275272
if model.best_epoch == -1:

kraken/ketos/segmentation.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@
8080
default=SEGMENTATION_HYPER_PARAMS['min_delta'],
8181
type=click.FLOAT,
8282
help='Minimum improvement between epochs to reset early stopping. By default it scales the delta by the best loss')
83-
@click.option('-d', '--device', show_default=True, default='cpu', help='Select device to use (cpu, cuda:0, cuda:1, ...)')
8483
@click.option('--optimizer',
8584
show_default=True,
8685
default=SEGMENTATION_HYPER_PARAMS['optimizer'],
@@ -127,7 +126,6 @@
127126
@click.option('-e', '--evaluation-files', show_default=True, default=None, multiple=True,
128127
callback=_validate_manifests, type=click.File(mode='r', lazy=True),
129128
help='File(s) with paths to evaluation data. Overrides the `-p` parameter')
130-
@click.option('--workers', show_default=True, default=1, type=click.IntRange(0), help='Number of data loading worker processes.')
131129
@click.option('--load-hyper-parameters/--no-load-hyper-parameters', show_default=True, default=False,
132130
help='When loading an existing model, retrieve hyper-parameters from the model')
133131
@click.option('--force-binarization/--no-binarization', show_default=True,
@@ -195,15 +193,14 @@
195193
help='Path to directory where the logger will store the logs. If not set, a directory will be created in the current working directory.')
196194
@click.argument('ground_truth', nargs=-1, callback=_expand_gt, type=click.Path(exists=False, dir_okay=False))
197195
def segtrain(ctx, output, spec, line_width, pad, load, freq, quit, epochs,
198-
min_epochs, lag, min_delta, device, optimizer, lrate,
199-
momentum, weight_decay, warmup, schedule, gamma, step_size,
200-
sched_patience, cos_max, cos_min_lr, partition, training_files,
201-
evaluation_files, workers, load_hyper_parameters,
202-
force_binarization, format_type, suppress_regions,
203-
suppress_baselines, valid_regions, valid_baselines, merge_regions,
204-
merge_baselines, merge_all_baselines, merge_all_regions,
205-
bounding_regions, augment, resize, topline, pl_logger, log_dir,
206-
ground_truth):
196+
min_epochs, lag, min_delta, optimizer, lrate, momentum,
197+
weight_decay, warmup, schedule, gamma, step_size, sched_patience,
198+
cos_max, cos_min_lr, partition, training_files, evaluation_files,
199+
load_hyper_parameters, force_binarization, format_type,
200+
suppress_regions, suppress_baselines, valid_regions,
201+
valid_baselines, merge_regions, merge_baselines,
202+
merge_all_baselines, merge_all_regions, bounding_regions, augment,
203+
resize, topline, pl_logger, log_dir, ground_truth):
207204
"""
208205
Trains a baseline labeling model for layout analysis
209206
"""
@@ -279,7 +276,7 @@ def segtrain(ctx, output, spec, line_width, pad, load, freq, quit, epochs,
279276
topline = loc[topline]
280277

281278
try:
282-
accelerator, device = to_ptl_device(device)
279+
accelerator, device = to_ptl_device(ctx.meta['device'])
283280
except Exception as e:
284281
raise click.BadOptionUsage('device', str(e))
285282

@@ -295,7 +292,7 @@ def segtrain(ctx, output, spec, line_width, pad, load, freq, quit, epochs,
295292
training_data=ground_truth,
296293
evaluation_data=evaluation_files,
297294
partition=partition,
298-
num_workers=workers,
295+
num_workers=ctx.meta['workers'],
299296
load_hyper_parameters=load_hyper_parameters,
300297
force_binarization=force_binarization,
301298
format_type=format_type,

0 commit comments

Comments
 (0)