@@ -104,6 +104,14 @@ def _pop_sampled_parameters(
104
104
assert isinstance (load_model_kwargs , dict )
105
105
return embedding_name , document_embedding , task_train_kwargs , load_model_kwargs
106
106
107
+ @staticmethod
108
+ def _revert_default_hps_task_train_kwargs (
109
+ task_train_kwargs : Dict [str , ParameterValues ]
110
+ ) -> Dict [str , ParameterValues ]:
111
+ task_train_kwargs ["param_selection_mode" ] = False
112
+ task_train_kwargs ["save_final_model" ] = True
113
+ return task_train_kwargs
114
+
107
115
108
116
@dataclass
109
117
class OptimizedFlairClassificationPipeline (
@@ -157,6 +165,12 @@ def _get_metadata(self, parameters: SampledParameters) -> FlairClassificationPip
157
165
task_train_kwargs ,
158
166
load_model_kwargs ,
159
167
) = self ._pop_sampled_parameters (parameters = parameters )
168
+
169
+ task_train_kwargs = (
170
+ OptimizedFlairClassificationPipeline ._revert_default_hps_task_train_kwargs (
171
+ task_train_kwargs
172
+ )
173
+ )
160
174
metadata : FlairClassificationPipelineMetadata = {
161
175
"embedding_name" : embedding_name ,
162
176
"dataset_name" : str (self .dataset_name_or_path ),
@@ -257,6 +271,11 @@ def _get_metadata(
257
271
task_train_kwargs ,
258
272
load_model_kwargs ,
259
273
) = self ._pop_sampled_parameters (parameters = parameters )
274
+ task_train_kwargs = (
275
+ OptimizedFlairPairClassificationPipeline ._revert_default_hps_task_train_kwargs (
276
+ task_train_kwargs
277
+ )
278
+ )
260
279
metadata : FlairPairClassificationPipelineMetadata = {
261
280
"embedding_name" : embedding_name ,
262
281
"dataset_name" : str (self .dataset_name_or_path ),
@@ -377,13 +396,26 @@ def _pop_sampled_parameters(
377
396
assert isinstance (task_model_kwargs , dict )
378
397
return embedding_name , hidden_size , task_train_kwargs , task_model_kwargs
379
398
399
+ @staticmethod
400
+ def _revert_default_hps_task_train_kwargs (
401
+ task_train_kwargs : Dict [str , ParameterValues ]
402
+ ) -> Dict [str , ParameterValues ]:
403
+ task_train_kwargs ["param_selection_mode" ] = False
404
+ task_train_kwargs ["save_final_model" ] = True
405
+ return task_train_kwargs
406
+
380
407
def _get_metadata (self , parameters : SampledParameters ) -> FlairSequenceLabelingPipelineMetadata :
381
408
(
382
409
embedding_name ,
383
410
hidden_size ,
384
411
task_train_kwargs ,
385
412
task_model_kwargs ,
386
413
) = self ._pop_sampled_parameters (parameters )
414
+ task_train_kwargs = (
415
+ OptimizedFlairSequenceLabelingPipeline ._revert_default_hps_task_train_kwargs (
416
+ task_train_kwargs
417
+ )
418
+ )
387
419
metadata : FlairSequenceLabelingPipelineMetadata = {
388
420
"embedding_name" : embedding_name ,
389
421
"dataset_name" : str (self .dataset_name_or_path ),
0 commit comments