2424from PIL import Image
2525
2626from kraken .registry import OPTIMIZERS , SCHEDULERS , STOPPERS
27- from kraken .lib .default_specs import (RECOGNITION_PRETRAIN_HYPER_PARAMS ,
28- RECOGNITION_SPEC )
2927
3028from .util import _expand_gt , _validate_manifests , message , to_ptl_device
3129
3836
3937@click .command ('pretrain' )
4038@click .pass_context
41- @click .option ('-B' , '--batch-size' , show_default = True , type = click .INT ,
42- default = RECOGNITION_PRETRAIN_HYPER_PARAMS ['batch_size' ], help = 'batch sample size' )
43- @click .option ('--pad' , show_default = True , type = click .INT ,
44- default = RECOGNITION_PRETRAIN_HYPER_PARAMS ['pad' ],
45- help = 'Left and right padding around lines' )
46- @click .option ('-o' , '--output' , show_default = True , type = click .Path (), default = 'model' , help = 'Output model file' )
47- @click .option ('-s' , '--spec' , show_default = True , default = RECOGNITION_SPEC ,
48- help = 'VGSL spec of the network to train.' )
49- @click .option ('-i' , '--load' , show_default = True , type = click .Path (exists = True ,
50- readable = True ), help = 'Load existing file to continue training' )
51- @click .option ('-F' , '--freq' , show_default = True , default = RECOGNITION_PRETRAIN_HYPER_PARAMS ['freq' ], type = click .FLOAT ,
39+ @click .option ('-B' , '--batch-size' , type = int , help = 'batch sample size' )
40+ @click .option ('--pad' , 'padding' , type = int , help = 'Left and right padding around lines' )
41+ @click .option ('-o' , '--output' , 'checkpoint_path' , type = click .Path (), help = 'Output checkpoint path' )
42+ @click .option ('-s' , '--spec' , help = 'VGSL spec of the network to train.' )
43+ @click .option ('-i' , '--load' , type = click .Path (exists = True , readable = True ),
44+ help = 'Load existing file to continue training' )
45+ @click .option ('-F' ,
46+ '--freq' ,
47+ type = float ,
5248 help = 'Model saving and report generation frequency in epochs '
5349 'during training. If frequency is >1 it must be an integer, '
5450 'i.e. running validation every n-th epoch.' )
5551@click .option ('-q' ,
5652 '--quit' ,
57- show_default = True ,
58- default = RECOGNITION_PRETRAIN_HYPER_PARAMS ['quit' ],
5953 type = click .Choice (STOPPERS ),
6054 help = 'Stop condition for training. Set to `early` for early stooping or `fixed` for fixed number of epochs' )
6155@click .option ('-N' ,
6256 '--epochs' ,
63- show_default = True ,
64- default = RECOGNITION_PRETRAIN_HYPER_PARAMS ['epochs' ],
57+ type = int ,
6558 help = 'Number of epochs to train for' )
6659@click .option ('--min-epochs' ,
67- show_default = True ,
68- default = RECOGNITION_PRETRAIN_HYPER_PARAMS ['min_epochs' ],
60+ type = int ,
6961 help = 'Minimal number of epochs to train for when using early stopping.' )
7062@click .option ('--lag' ,
71- show_default = True ,
72- default = RECOGNITION_PRETRAIN_HYPER_PARAMS ['lag' ],
63+ type = int ,
7364 help = 'Number of evaluations (--report frequency) to wait before stopping training without improvement' )
7465@click .option ('--min-delta' ,
75- show_default = True ,
76- default = RECOGNITION_PRETRAIN_HYPER_PARAMS ['min_delta' ],
77- type = click .FLOAT ,
66+ type = float ,
7867 help = 'Minimum improvement between epochs to reset early stopping. Default is scales the delta by the best loss' )
7968@click .option ('--optimizer' ,
80- show_default = True ,
81- default = RECOGNITION_PRETRAIN_HYPER_PARAMS ['optimizer' ],
8269 type = click .Choice (OPTIMIZERS ),
8370 help = 'Select optimizer' )
84- @click .option ('-r' , '--lrate' , show_default = True , default = RECOGNITION_PRETRAIN_HYPER_PARAMS ['lrate' ], help = 'Learning rate' )
85- @click .option ('-m' , '--momentum' , show_default = True , default = RECOGNITION_PRETRAIN_HYPER_PARAMS ['momentum' ], help = 'Momentum' )
86- @click .option ('-w' , '--weight-decay' , show_default = True , type = float ,
87- default = RECOGNITION_PRETRAIN_HYPER_PARAMS ['weight_decay' ], help = 'Weight decay' )
88- @click .option ('--warmup' , show_default = True , type = float ,
89- default = RECOGNITION_PRETRAIN_HYPER_PARAMS ['warmup' ], help = 'Number of samples to ramp up to `lrate` initial learning rate.' )
71+ @click .option ('-r' ,
72+ '--lrate' ,
73+ type = float ,
74+ help = 'Learning rate' )
75+ @click .option ('-m' ,
76+ '--momentum' ,
77+ type = float ,
78+ help = 'Momentum' )
79+ @click .option ('-w' ,
80+ '--weight-decay' ,
81+ type = float , help = 'Weight decay' )
82+ @click .option ('--warmup' ,
83+ type = float ,
84+ help = 'Number of samples to ramp up to `lrate` initial learning rate.' )
9085@click .option ('--schedule' ,
91- show_default = True ,
9286 type = click .Choice (SCHEDULERS ),
93- default = RECOGNITION_PRETRAIN_HYPER_PARAMS ['schedule' ],
9487 help = 'Set learning rate scheduler. For 1cycle, cycle length is determined by the `--epoch` option.' )
9588@click .option ('-g' ,
9689 '--gamma' ,
97- show_default = True ,
98- default = RECOGNITION_PRETRAIN_HYPER_PARAMS ['gamma' ],
90+ type = float ,
9991 help = 'Decay factor for exponential, step, and reduceonplateau learning rate schedules' )
10092@click .option ('-ss' ,
10193 '--step-size' ,
102- show_default = True ,
103- default = RECOGNITION_PRETRAIN_HYPER_PARAMS ['step_size' ],
94+ type = int ,
10495 help = 'Number of validation runs between learning rate decay for exponential and step LR schedules' )
10596@click .option ('--sched-patience' ,
106- show_default = True ,
107- default = RECOGNITION_PRETRAIN_HYPER_PARAMS [ 'rop_patience' ] ,
97+ 'rop_patience' ,
98+ type = int ,
10899 help = 'Minimal number of validation runs between LR reduction for reduceonplateau LR schedule.' )
109100@click .option ('--cos-max' ,
110- show_default = True ,
111- default = RECOGNITION_PRETRAIN_HYPER_PARAMS [ 'cos_t_max' ] ,
101+ 'cos_max_t' ,
102+ type = int ,
112103 help = 'Epoch of minimal learning rate for cosine LR scheduler.' )
113104@click .option ('--cos-min-lr' ,
114- show_default = True ,
115- default = RECOGNITION_PRETRAIN_HYPER_PARAMS ['cos_min_lr' ],
105+ type = float ,
116106 help = 'Minimal final learning rate for cosine LR scheduler.' )
117- @click .option ('-p' , '--partition' , show_default = True , default = 0.9 ,
107+ @click .option ('-p' ,
108+ '--partition' ,
109+ type = float ,
118110 help = 'Ground truth data partition ratio between train/validation set' )
119- @click .option ('--fixed-splits/--ignore-fixed-splits' , show_default = True , default = False ,
111+ @click .option ('--fixed-splits/--ignore-fixed-splits' , default = False ,
120112 help = 'Whether to honor fixed splits in binary datasets.' )
121- @click .option ('-t' , '--training-files' , show_default = True , default = None , multiple = True ,
113+ @click .option ('-t' , '--training-files' , default = None , multiple = True ,
122114 callback = _validate_manifests , type = click .File (mode = 'r' , lazy = True ),
123115 help = 'File(s) with additional paths to training data' )
124- @click .option ('-e' , '--evaluation-files' , show_default = True , default = None , multiple = True ,
116+ @click .option ('-e' , '--evaluation-files' , default = None , multiple = True ,
125117 callback = _validate_manifests , type = click .File (mode = 'r' , lazy = True ),
126118 help = 'File(s) with paths to evaluation data. Overrides the `-p` parameter' )
127- @click .option ('--load-hyper-parameters/--no-load-hyper-parameters' , show_default = True , default = False ,
119+ @click .option ('--load-hyper-parameters/--no-load-hyper-parameters' , default = False ,
128120 help = 'When loading an existing model, retrieve hyperparameters from the model' )
129- @click .option ('--force-binarization/--no-binarization' , show_default = True ,
130- default = False , help = 'Forces input images to be binary, otherwise '
131- 'the appropriate color format will be auto-determined through the '
132- 'network specification. Will be ignored in `path` mode.' )
133- @click .option ('-f' , '--format-type' , type = click .Choice (['path' , 'xml' , 'alto' , 'page' , 'binary' ]), default = 'path' ,
121+ @click .option ('-f' , '--format-type' , type = click .Choice (['path' , 'xml' , 'alto' , 'page' , 'binary' ]),
134122 help = 'Sets the training data format. In ALTO and PageXML mode all '
135123 'data is extracted from xml files containing both line definitions and a '
136124 'link to source images. In `path` mode arguments are image files '
137125 'sharing a prefix up to the last extension with `.gt.txt` text files '
138126 'containing the transcription. In binary mode files are datasets '
139127 'files containing pre-extracted text lines.' )
140128@click .option ('--augment/--no-augment' ,
141- show_default = True ,
142- default = RECOGNITION_PRETRAIN_HYPER_PARAMS ['augment' ],
143129 help = 'Enable image augmentation' )
144- @click .option ('-mw' , '--mask-width' , show_default = True ,
145- default = RECOGNITION_PRETRAIN_HYPER_PARAMS ['mask_width' ],
130+ @click .option ('-mw' ,
131+ '--mask-width' ,
132+ type = int ,
146133 help = 'Width of sampled masks at scale of the sampled tensor, e.g. '
147134 '4X subsampling in convolutional layers with mask width 3 results '
148135 'in an effective mask width of 12.' )
149- @click .option ('-mp' , '--mask-probability' ,
150- show_default = True ,
151- default = RECOGNITION_PRETRAIN_HYPER_PARAMS [ 'mask_prob' ] ,
136+ @click .option ('-mp' ,
137+ '--mask-probability' ,
138+ type = float ,
152139 help = 'Probability of a particular position being the start position of a mask.' )
153- @click .option ('-nn' , '--num-negatives' ,
154- show_default = True ,
155- default = RECOGNITION_PRETRAIN_HYPER_PARAMS [ 'num_negatives' ] ,
140+ @click .option ('-nn' ,
141+ '--num-negatives' ,
142+ type = int ,
156143 help = 'Number of negative samples for the contrastive loss.' )
157- @click .option ('-lt' , '--logit-temp' ,
158- show_default = True ,
159- default = RECOGNITION_PRETRAIN_HYPER_PARAMS [ 'logit_temp' ] ,
144+ @click .option ('-lt' ,
145+ '--logit-temp' ,
146+ type = float ,
160147 help = 'Multiplicative factor for the logits used in contrastive loss.' )
148+ @click .option ('--legacy-polygons' , default = False , is_flag = True , help = 'Use the legacy polygon extractor.' )
161149@click .argument ('ground_truth' , nargs = - 1 , callback = _expand_gt , type = click .Path (exists = False , dir_okay = False ))
162- @click .option ('--legacy-polygons' , show_default = True , default = False , is_flag = True , help = 'Use the legacy polygon extractor.' )
163- def pretrain (ctx , batch_size , pad , output , spec , load , freq , quit , epochs ,
164- min_epochs , lag , min_delta , optimizer , lrate , momentum ,
165- weight_decay , warmup , schedule , gamma , step_size , sched_patience ,
166- cos_max , cos_min_lr , partition , fixed_splits , training_files ,
167- evaluation_files , load_hyper_parameters ,
168- force_binarization , format_type , augment , mask_probability ,
169- mask_width , num_negatives , logit_temp , ground_truth ,
170- legacy_polygons ):
150+ def pretrain (ctx , ** kwargs ):
171151 """
172152 Trains a model from image-text pairs.
173153 """
@@ -188,32 +168,6 @@ def pretrain(ctx, batch_size, pad, output, spec, load, freq, quit, epochs,
188168 RecognitionPretrainModel )
189169 from kraken .lib .train import KrakenTrainer
190170
191- hyper_params = RECOGNITION_PRETRAIN_HYPER_PARAMS .copy ()
192- hyper_params .update ({'freq' : freq ,
193- 'pad' : pad ,
194- 'batch_size' : batch_size ,
195- 'quit' : quit ,
196- 'epochs' : epochs ,
197- 'min_epochs' : min_epochs ,
198- 'lag' : lag ,
199- 'min_delta' : min_delta ,
200- 'optimizer' : optimizer ,
201- 'lrate' : lrate ,
202- 'momentum' : momentum ,
203- 'weight_decay' : weight_decay ,
204- 'warmup' : warmup ,
205- 'schedule' : schedule ,
206- 'gamma' : gamma ,
207- 'step_size' : step_size ,
208- 'rop_patience' : sched_patience ,
209- 'cos_t_max' : cos_max ,
210- 'cos_min_lr' : cos_min_lr ,
211- 'augment' : augment ,
212- 'mask_prob' : mask_probability ,
213- 'mask_width' : mask_width ,
214- 'num_negatives' : num_negatives ,
215- 'logit_temp' : logit_temp })
216-
217171 # disable automatic partition when given evaluation set explicitly
218172 if evaluation_files :
219173 partition = 1
@@ -250,7 +204,7 @@ def pretrain(ctx, batch_size, pad, output, spec, load, freq, quit, epochs,
250204 evaluation_data = evaluation_files ,
251205 partition = partition ,
252206 binary_dataset_split = fixed_splits ,
253- num_workers = ctx .meta ['workers ' ],
207+ num_workers = ctx .meta ['num_workers ' ],
254208 height = model .height ,
255209 width = model .width ,
256210 channels = model .channels ,
0 commit comments