Skip to content

Commit 59d8599

Browse files
committed
proper typing in ketos CLI parameters
1 parent dce4a6d commit 59d8599

File tree

7 files changed

+350
-797
lines changed

7 files changed

+350
-797
lines changed

kraken/ketos/__init__.py

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from kraken.registry import PRECISIONS
3030

3131
from .dataset import compile
32+
from .util import _load_config
3233
from .pretrain import pretrain
3334
from .recognition import test, train
3435
from .repo import publish
@@ -45,41 +46,61 @@
4546
# raise default max image size to 20k * 20k pixels
4647
Image.MAX_IMAGE_PIXELS = 20000 ** 2
4748

49+
from kraken.configs import (Config,
50+
TrainingDataConfig,
51+
VGSLRecognitionTrainingConfig,
52+
VGSLRecognitionTrainingDataConfig,
53+
BLLASegmentationTrainingConfig,
54+
BLLASegmentationTrainingDataConfig)
55+
56+
@click.group(context_settings=dict(show_default=True,
57+
default_map={**Config().__dict__,
58+
**TrainingDataConfig().__dict__,
59+
'train': {**VGSLRecognitionTrainingConfig().__dict__, **VGSLRecognitionTrainingDataConfig().__dict__},
60+
'test': VGSLRecognitionTrainingDataConfig().__dict__,
61+
'segtrain': {**BLLASegmentationTrainingConfig().__dict__, **BLLASegmentationTrainingDataConfig().__dict__},
62+
'segtest': {**BLLASegmentationTrainingConfig().__dict__, **BLLASegmentationTrainingDataConfig().__dict__}}))
4863

49-
@click.group()
5064
@click.version_option()
5165
@click.pass_context
5266
@click.option('-v', '--verbose', default=0, count=True)
53-
@click.option('-d', '--device', default='cpu', show_default=True,
67+
@click.option('-d', '--device', show_default=True,
5468
help='Select device to use (cpu, cuda:0, cuda:1, ...)')
5569
@click.option('--precision',
56-
show_default=True,
57-
default='32-true',
5870
type=click.Choice(PRECISIONS),
5971
help='Numerical precision to use for training. Default is 32-bit single precision.')
60-
@click.option('--workers', show_default=True, default=1, type=click.IntRange(0), help='Number of data loading worker processes.')
61-
@click.option('--threads', show_default=True, default=1, type=click.IntRange(1), help='Maximum size of OpenMP/BLAS thread pool.')
72+
@click.option('--workers', 'num_workers', type=click.IntRange(0), help='Number of data loading worker processes.')
73+
@click.option('--threads', 'num_threads', type=click.IntRange(1), help='Maximum size of OpenMP/BLAS thread pool.')
6274
@click.option('-s', '--seed', default=None, type=click.INT,
6375
help='Seed for numpy\'s and torch\'s RNG. Set to a fixed value to '
6476
'ensure reproducible random splits of data')
65-
@click.option('-r', '--deterministic/--no-deterministic', default=False,
77+
@click.option('-r', '--deterministic/--no-deterministic',
6678
help="Enables deterministic training. If no seed is given and enabled the seed will be set to 42.")
67-
def cli(ctx, verbose, device, precision, workers, threads, seed, deterministic):
68-
ctx.meta['deterministic'] = False if not deterministic else 'warn'
69-
if seed:
79+
@click.option('--config',
80+
type=click.File(mode='r', lazy=True),
81+
help="Path to configuration file.",
82+
callback=_load_config,
83+
is_eager=True,
84+
expose_value=False,
85+
required=False)
86+
def cli(ctx, **kwargs):
87+
params = ctx.params
88+
89+
ctx.meta['deterministic'] = False if not params['deterministic'] else 'warn'
90+
if params['seed']:
7091
from lightning.pytorch import seed_everything
7192
seed_everything(seed, workers=True)
72-
elif deterministic:
93+
elif params['deterministic']:
7394
from lightning.pytorch import seed_everything
7495
seed_everything(42, workers=True)
7596

76-
ctx.meta['verbose'] = verbose
77-
ctx.meta['device'] = device
78-
ctx.meta['precision'] = precision
79-
ctx.meta['workers'] = workers
80-
ctx.meta['threads'] = threads
97+
ctx.meta['verbose'] = params.get('verbose')
98+
ctx.meta['device'] = params.get('device')
99+
ctx.meta['precision'] = params.get('precision')
100+
ctx.meta['num_workers'] = params.get('num_workers')
101+
ctx.meta['num_threads'] = params.get('num_threads')
81102

82-
log.set_logger(logger, level=30 - min(10 * verbose, 20))
103+
log.set_logger(logger, level=30 - min(10 * params['verbose'], 20))
83104

84105

85106
cli.add_command(compile)

kraken/ketos/dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def _update_bar(advance, total):
9292
arrow_dataset.build_binary_dataset(ground_truth,
9393
output,
9494
format_type,
95-
ctx.meta['workers'],
95+
ctx.meta['num_workers'],
9696
save_splits,
9797
random_split,
9898
force_type,

kraken/ketos/pretrain.py

Lines changed: 57 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@
2424
from PIL import Image
2525

2626
from kraken.registry import OPTIMIZERS, SCHEDULERS, STOPPERS
27-
from kraken.lib.default_specs import (RECOGNITION_PRETRAIN_HYPER_PARAMS,
28-
RECOGNITION_SPEC)
2927

3028
from .util import _expand_gt, _validate_manifests, message, to_ptl_device
3129

@@ -38,136 +36,118 @@
3836

3937
@click.command('pretrain')
4038
@click.pass_context
41-
@click.option('-B', '--batch-size', show_default=True, type=click.INT,
42-
default=RECOGNITION_PRETRAIN_HYPER_PARAMS['batch_size'], help='batch sample size')
43-
@click.option('--pad', show_default=True, type=click.INT,
44-
default=RECOGNITION_PRETRAIN_HYPER_PARAMS['pad'],
45-
help='Left and right padding around lines')
46-
@click.option('-o', '--output', show_default=True, type=click.Path(), default='model', help='Output model file')
47-
@click.option('-s', '--spec', show_default=True, default=RECOGNITION_SPEC,
48-
help='VGSL spec of the network to train.')
49-
@click.option('-i', '--load', show_default=True, type=click.Path(exists=True,
50-
readable=True), help='Load existing file to continue training')
51-
@click.option('-F', '--freq', show_default=True, default=RECOGNITION_PRETRAIN_HYPER_PARAMS['freq'], type=click.FLOAT,
39+
@click.option('-B', '--batch-size', type=int, help='batch sample size')
40+
@click.option('--pad', 'padding', type=int, help='Left and right padding around lines')
41+
@click.option('-o', '--output', 'checkpoint_path', type=click.Path(), help='Output checkpoint path')
42+
@click.option('-s', '--spec', help='VGSL spec of the network to train.')
43+
@click.option('-i', '--load', type=click.Path(exists=True, readable=True),
44+
help='Load existing file to continue training')
45+
@click.option('-F',
46+
'--freq',
47+
type=float,
5248
help='Model saving and report generation frequency in epochs '
5349
'during training. If frequency is >1 it must be an integer, '
5450
'i.e. running validation every n-th epoch.')
5551
@click.option('-q',
5652
'--quit',
57-
show_default=True,
58-
default=RECOGNITION_PRETRAIN_HYPER_PARAMS['quit'],
5953
type=click.Choice(STOPPERS),
6054
help='Stop condition for training. Set to `early` for early stopping or `fixed` for fixed number of epochs')
6155
@click.option('-N',
6256
'--epochs',
63-
show_default=True,
64-
default=RECOGNITION_PRETRAIN_HYPER_PARAMS['epochs'],
57+
type=int,
6558
help='Number of epochs to train for')
6659
@click.option('--min-epochs',
67-
show_default=True,
68-
default=RECOGNITION_PRETRAIN_HYPER_PARAMS['min_epochs'],
60+
type=int,
6961
help='Minimal number of epochs to train for when using early stopping.')
7062
@click.option('--lag',
71-
show_default=True,
72-
default=RECOGNITION_PRETRAIN_HYPER_PARAMS['lag'],
63+
type=int,
7364
help='Number of evaluations (--freq frequency) to wait before stopping training without improvement')
7465
@click.option('--min-delta',
75-
show_default=True,
76-
default=RECOGNITION_PRETRAIN_HYPER_PARAMS['min_delta'],
77-
type=click.FLOAT,
66+
type=float,
7867
help='Minimum improvement between epochs to reset early stopping. The default scales the delta by the best loss')
7968
@click.option('--optimizer',
80-
show_default=True,
81-
default=RECOGNITION_PRETRAIN_HYPER_PARAMS['optimizer'],
8269
type=click.Choice(OPTIMIZERS),
8370
help='Select optimizer')
84-
@click.option('-r', '--lrate', show_default=True, default=RECOGNITION_PRETRAIN_HYPER_PARAMS['lrate'], help='Learning rate')
85-
@click.option('-m', '--momentum', show_default=True, default=RECOGNITION_PRETRAIN_HYPER_PARAMS['momentum'], help='Momentum')
86-
@click.option('-w', '--weight-decay', show_default=True, type=float,
87-
default=RECOGNITION_PRETRAIN_HYPER_PARAMS['weight_decay'], help='Weight decay')
88-
@click.option('--warmup', show_default=True, type=float,
89-
default=RECOGNITION_PRETRAIN_HYPER_PARAMS['warmup'], help='Number of samples to ramp up to `lrate` initial learning rate.')
71+
@click.option('-r',
72+
'--lrate',
73+
type=float,
74+
help='Learning rate')
75+
@click.option('-m',
76+
'--momentum',
77+
type=float,
78+
help='Momentum')
79+
@click.option('-w',
80+
'--weight-decay',
81+
type=float, help='Weight decay')
82+
@click.option('--warmup',
83+
type=float,
84+
help='Number of samples to ramp up to `lrate` initial learning rate.')
9085
@click.option('--schedule',
91-
show_default=True,
9286
type=click.Choice(SCHEDULERS),
93-
default=RECOGNITION_PRETRAIN_HYPER_PARAMS['schedule'],
9487
help='Set learning rate scheduler. For 1cycle, cycle length is determined by the `--epochs` option.')
9588
@click.option('-g',
9689
'--gamma',
97-
show_default=True,
98-
default=RECOGNITION_PRETRAIN_HYPER_PARAMS['gamma'],
90+
type=float,
9991
help='Decay factor for exponential, step, and reduceonplateau learning rate schedules')
10092
@click.option('-ss',
10193
'--step-size',
102-
show_default=True,
103-
default=RECOGNITION_PRETRAIN_HYPER_PARAMS['step_size'],
94+
type=int,
10495
help='Number of validation runs between learning rate decay for exponential and step LR schedules')
10596
@click.option('--sched-patience',
106-
show_default=True,
107-
default=RECOGNITION_PRETRAIN_HYPER_PARAMS['rop_patience'],
97+
'rop_patience',
98+
type=int,
10899
help='Minimal number of validation runs between LR reduction for reduceonplateau LR schedule.')
109100
@click.option('--cos-max',
110-
show_default=True,
111-
default=RECOGNITION_PRETRAIN_HYPER_PARAMS['cos_t_max'],
101+
'cos_max_t',
102+
type=int,
112103
help='Epoch of minimal learning rate for cosine LR scheduler.')
113104
@click.option('--cos-min-lr',
114-
show_default=True,
115-
default=RECOGNITION_PRETRAIN_HYPER_PARAMS['cos_min_lr'],
105+
type=float,
116106
help='Minimal final learning rate for cosine LR scheduler.')
117-
@click.option('-p', '--partition', show_default=True, default=0.9,
107+
@click.option('-p',
108+
'--partition',
109+
type=float,
118110
help='Ground truth data partition ratio between train/validation set')
119-
@click.option('--fixed-splits/--ignore-fixed-splits', show_default=True, default=False,
111+
@click.option('--fixed-splits/--ignore-fixed-splits', default=False,
120112
help='Whether to honor fixed splits in binary datasets.')
121-
@click.option('-t', '--training-files', show_default=True, default=None, multiple=True,
113+
@click.option('-t', '--training-files', default=None, multiple=True,
122114
callback=_validate_manifests, type=click.File(mode='r', lazy=True),
123115
help='File(s) with additional paths to training data')
124-
@click.option('-e', '--evaluation-files', show_default=True, default=None, multiple=True,
116+
@click.option('-e', '--evaluation-files', default=None, multiple=True,
125117
callback=_validate_manifests, type=click.File(mode='r', lazy=True),
126118
help='File(s) with paths to evaluation data. Overrides the `-p` parameter')
127-
@click.option('--load-hyper-parameters/--no-load-hyper-parameters', show_default=True, default=False,
119+
@click.option('--load-hyper-parameters/--no-load-hyper-parameters', default=False,
128120
help='When loading an existing model, retrieve hyperparameters from the model')
129-
@click.option('--force-binarization/--no-binarization', show_default=True,
130-
default=False, help='Forces input images to be binary, otherwise '
131-
'the appropriate color format will be auto-determined through the '
132-
'network specification. Will be ignored in `path` mode.')
133-
@click.option('-f', '--format-type', type=click.Choice(['path', 'xml', 'alto', 'page', 'binary']), default='path',
121+
@click.option('-f', '--format-type', type=click.Choice(['path', 'xml', 'alto', 'page', 'binary']),
134122
help='Sets the training data format. In ALTO and PageXML mode all '
135123
'data is extracted from xml files containing both line definitions and a '
136124
'link to source images. In `path` mode arguments are image files '
137125
'sharing a prefix up to the last extension with `.gt.txt` text files '
138126
'containing the transcription. In binary mode files are datasets '
139127
'files containing pre-extracted text lines.')
140128
@click.option('--augment/--no-augment',
141-
show_default=True,
142-
default=RECOGNITION_PRETRAIN_HYPER_PARAMS['augment'],
143129
help='Enable image augmentation')
144-
@click.option('-mw', '--mask-width', show_default=True,
145-
default=RECOGNITION_PRETRAIN_HYPER_PARAMS['mask_width'],
130+
@click.option('-mw',
131+
'--mask-width',
132+
type=int,
146133
help='Width of sampled masks at scale of the sampled tensor, e.g. '
147134
'4X subsampling in convolutional layers with mask width 3 results '
148135
'in an effective mask width of 12.')
149-
@click.option('-mp', '--mask-probability',
150-
show_default=True,
151-
default=RECOGNITION_PRETRAIN_HYPER_PARAMS['mask_prob'],
136+
@click.option('-mp',
137+
'--mask-probability',
138+
type=float,
152139
help='Probability of a particular position being the start position of a mask.')
153-
@click.option('-nn', '--num-negatives',
154-
show_default=True,
155-
default=RECOGNITION_PRETRAIN_HYPER_PARAMS['num_negatives'],
140+
@click.option('-nn',
141+
'--num-negatives',
142+
type=int,
156143
help='Number of negative samples for the contrastive loss.')
157-
@click.option('-lt', '--logit-temp',
158-
show_default=True,
159-
default=RECOGNITION_PRETRAIN_HYPER_PARAMS['logit_temp'],
144+
@click.option('-lt',
145+
'--logit-temp',
146+
type=float,
160147
help='Multiplicative factor for the logits used in contrastive loss.')
148+
@click.option('--legacy-polygons', default=False, is_flag=True, help='Use the legacy polygon extractor.')
161149
@click.argument('ground_truth', nargs=-1, callback=_expand_gt, type=click.Path(exists=False, dir_okay=False))
162-
@click.option('--legacy-polygons', show_default=True, default=False, is_flag=True, help='Use the legacy polygon extractor.')
163-
def pretrain(ctx, batch_size, pad, output, spec, load, freq, quit, epochs,
164-
min_epochs, lag, min_delta, optimizer, lrate, momentum,
165-
weight_decay, warmup, schedule, gamma, step_size, sched_patience,
166-
cos_max, cos_min_lr, partition, fixed_splits, training_files,
167-
evaluation_files, load_hyper_parameters,
168-
force_binarization, format_type, augment, mask_probability,
169-
mask_width, num_negatives, logit_temp, ground_truth,
170-
legacy_polygons):
150+
def pretrain(ctx, **kwargs):
171151
"""
172152
Trains a model from image-text pairs.
173153
"""
@@ -188,32 +168,6 @@ def pretrain(ctx, batch_size, pad, output, spec, load, freq, quit, epochs,
188168
RecognitionPretrainModel)
189169
from kraken.lib.train import KrakenTrainer
190170

191-
hyper_params = RECOGNITION_PRETRAIN_HYPER_PARAMS.copy()
192-
hyper_params.update({'freq': freq,
193-
'pad': pad,
194-
'batch_size': batch_size,
195-
'quit': quit,
196-
'epochs': epochs,
197-
'min_epochs': min_epochs,
198-
'lag': lag,
199-
'min_delta': min_delta,
200-
'optimizer': optimizer,
201-
'lrate': lrate,
202-
'momentum': momentum,
203-
'weight_decay': weight_decay,
204-
'warmup': warmup,
205-
'schedule': schedule,
206-
'gamma': gamma,
207-
'step_size': step_size,
208-
'rop_patience': sched_patience,
209-
'cos_t_max': cos_max,
210-
'cos_min_lr': cos_min_lr,
211-
'augment': augment,
212-
'mask_prob': mask_probability,
213-
'mask_width': mask_width,
214-
'num_negatives': num_negatives,
215-
'logit_temp': logit_temp})
216-
217171
# disable automatic partition when given evaluation set explicitly
218172
if evaluation_files:
219173
partition = 1
@@ -250,7 +204,7 @@ def pretrain(ctx, batch_size, pad, output, spec, load, freq, quit, epochs,
250204
evaluation_data=evaluation_files,
251205
partition=partition,
252206
binary_dataset_split=fixed_splits,
253-
num_workers=ctx.meta['workers'],
207+
num_workers=ctx.meta['num_workers'],
254208
height=model.height,
255209
width=model.width,
256210
channels=model.channels,

0 commit comments

Comments
 (0)