Skip to content

Commit d902d26

Browse files
committed
refactor: Refactor code after review
1 parent 9e69742 commit d902d26

File tree

2 files changed

+12
-18
lines changed

2 files changed

+12
-18
lines changed

embeddings/pipeline/flair_hps_pipeline.py

+10-16
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from abc import ABC
2+
from copy import deepcopy
23
from dataclasses import dataclass, field
34
from pathlib import Path
45
from tempfile import TemporaryDirectory
@@ -73,6 +74,15 @@ class _OptimizedFlairPipelineDefaultsBase(_HuggingFaceOptimizedPipelineDefaultsB
7374
init=False, default_factory=TemporaryDirectory
7475
)
7576

77+
@staticmethod
78+
def _revert_default_hps_task_train_kwargs(
79+
task_train_kwargs: Dict[str, ParameterValues]
80+
) -> Dict[str, ParameterValues]:
81+
out = deepcopy(task_train_kwargs)
82+
out["param_selection_mode"] = False
83+
out["save_final_model"] = True
84+
return out
85+
7686

7787
# Mypy currently properly don't handle dataclasses with abstract methods https://github.com/python/mypy/issues/5374
7888
@dataclass # type: ignore
@@ -104,14 +114,6 @@ def _pop_sampled_parameters(
104114
assert isinstance(load_model_kwargs, dict)
105115
return embedding_name, document_embedding, task_train_kwargs, load_model_kwargs
106116

107-
@staticmethod
108-
def _revert_default_hps_task_train_kwargs(
109-
task_train_kwargs: Dict[str, ParameterValues]
110-
) -> Dict[str, ParameterValues]:
111-
task_train_kwargs["param_selection_mode"] = False
112-
task_train_kwargs["save_final_model"] = True
113-
return task_train_kwargs
114-
115117

116118
@dataclass
117119
class OptimizedFlairClassificationPipeline(
@@ -396,14 +398,6 @@ def _pop_sampled_parameters(
396398
assert isinstance(task_model_kwargs, dict)
397399
return embedding_name, hidden_size, task_train_kwargs, task_model_kwargs
398400

399-
@staticmethod
400-
def _revert_default_hps_task_train_kwargs(
401-
task_train_kwargs: Dict[str, ParameterValues]
402-
) -> Dict[str, ParameterValues]:
403-
task_train_kwargs["param_selection_mode"] = False
404-
task_train_kwargs["save_final_model"] = True
405-
return task_train_kwargs
406-
407401
def _get_metadata(self, parameters: SampledParameters) -> FlairSequenceLabelingPipelineMetadata:
408402
(
409403
embedding_name,

embeddings/pipeline/flair_preprocessing_pipeline.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def _get_persister(self) -> FLAIR_PERSISTERS_TYPE:
8585
def _get_dataset(self) -> Dataset:
8686
return Dataset(
8787
self.dataset_name_or_path,
88-
**self.load_dataset_kwargs if self.load_dataset_kwargs else {}
88+
**self.load_dataset_kwargs if self.load_dataset_kwargs else {},
8989
)
9090

9191
def _get_dataloader(self, dataset: Dataset) -> FLAIR_DATALOADERS:
@@ -130,7 +130,7 @@ def _get_transformations(
130130
DownsampleFlairCorpusTransformation(
131131
*self.downsample_splits,
132132
stratify=self.downsample_splits_stratification,
133-
seed=self.seed
133+
seed=self.seed,
134134
)
135135
)
136136

0 commit comments

Comments
 (0)