Skip to content

Commit 4e957f2

Browse files
committed
fix(hps_metadata): Fix flair hps_metadata
1 parent c79029b commit 4e957f2

File tree

1 file changed

+32
-0
lines changed

1 file changed

+32
-0
lines changed

embeddings/pipeline/flair_hps_pipeline.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,14 @@ def _pop_sampled_parameters(
104104
assert isinstance(load_model_kwargs, dict)
105105
return embedding_name, document_embedding, task_train_kwargs, load_model_kwargs
106106

107+
@staticmethod
108+
def _revert_default_hps_task_train_kwargs(
109+
task_train_kwargs: Dict[str, ParameterValues]
110+
) -> Dict[str, ParameterValues]:
111+
task_train_kwargs["param_selection_mode"] = False
112+
task_train_kwargs["save_final_model"] = True
113+
return task_train_kwargs
114+
107115

108116
@dataclass
109117
class OptimizedFlairClassificationPipeline(
@@ -157,6 +165,12 @@ def _get_metadata(self, parameters: SampledParameters) -> FlairClassificationPip
157165
task_train_kwargs,
158166
load_model_kwargs,
159167
) = self._pop_sampled_parameters(parameters=parameters)
168+
169+
task_train_kwargs = (
170+
OptimizedFlairClassificationPipeline._revert_default_hps_task_train_kwargs(
171+
task_train_kwargs
172+
)
173+
)
160174
metadata: FlairClassificationPipelineMetadata = {
161175
"embedding_name": embedding_name,
162176
"dataset_name": str(self.dataset_name_or_path),
@@ -257,6 +271,11 @@ def _get_metadata(
257271
task_train_kwargs,
258272
load_model_kwargs,
259273
) = self._pop_sampled_parameters(parameters=parameters)
274+
task_train_kwargs = (
275+
OptimizedFlairPairClassificationPipeline._revert_default_hps_task_train_kwargs(
276+
task_train_kwargs
277+
)
278+
)
260279
metadata: FlairPairClassificationPipelineMetadata = {
261280
"embedding_name": embedding_name,
262281
"dataset_name": str(self.dataset_name_or_path),
@@ -377,13 +396,26 @@ def _pop_sampled_parameters(
377396
assert isinstance(task_model_kwargs, dict)
378397
return embedding_name, hidden_size, task_train_kwargs, task_model_kwargs
379398

399+
@staticmethod
400+
def _revert_default_hps_task_train_kwargs(
401+
task_train_kwargs: Dict[str, ParameterValues]
402+
) -> Dict[str, ParameterValues]:
403+
task_train_kwargs["param_selection_mode"] = False
404+
task_train_kwargs["save_final_model"] = True
405+
return task_train_kwargs
406+
380407
def _get_metadata(self, parameters: SampledParameters) -> FlairSequenceLabelingPipelineMetadata:
381408
(
382409
embedding_name,
383410
hidden_size,
384411
task_train_kwargs,
385412
task_model_kwargs,
386413
) = self._pop_sampled_parameters(parameters)
414+
task_train_kwargs = (
415+
OptimizedFlairSequenceLabelingPipeline._revert_default_hps_task_train_kwargs(
416+
task_train_kwargs
417+
)
418+
)
387419
metadata: FlairSequenceLabelingPipelineMetadata = {
388420
"embedding_name": embedding_name,
389421
"dataset_name": str(self.dataset_name_or_path),

0 commit comments

Comments
 (0)