|
1 | 1 | from abc import ABC
|
| 2 | +from copy import deepcopy |
2 | 3 | from dataclasses import dataclass, field
|
3 | 4 | from pathlib import Path
|
4 | 5 | from tempfile import TemporaryDirectory
|
@@ -73,6 +74,15 @@ class _OptimizedFlairPipelineDefaultsBase(_HuggingFaceOptimizedPipelineDefaultsB
|
73 | 74 | init=False, default_factory=TemporaryDirectory
|
74 | 75 | )
|
75 | 76 |
|
| 77 | + @staticmethod |
| 78 | + def _revert_default_hps_task_train_kwargs( |
| 79 | + task_train_kwargs: Dict[str, ParameterValues] |
| 80 | + ) -> Dict[str, ParameterValues]: |
| 81 | + out = deepcopy(task_train_kwargs) |
| 82 | + out["param_selection_mode"] = False |
| 83 | + out["save_final_model"] = True |
| 84 | + return out |
| 85 | + |
76 | 86 |
|
77 | 87 | # Mypy currently properly don't handle dataclasses with abstract methods https://github.com/python/mypy/issues/5374
|
78 | 88 | @dataclass # type: ignore
|
@@ -104,14 +114,6 @@ def _pop_sampled_parameters(
|
104 | 114 | assert isinstance(load_model_kwargs, dict)
|
105 | 115 | return embedding_name, document_embedding, task_train_kwargs, load_model_kwargs
|
106 | 116 |
|
107 |
| - @staticmethod |
108 |
| - def _revert_default_hps_task_train_kwargs( |
109 |
| - task_train_kwargs: Dict[str, ParameterValues] |
110 |
| - ) -> Dict[str, ParameterValues]: |
111 |
| - task_train_kwargs["param_selection_mode"] = False |
112 |
| - task_train_kwargs["save_final_model"] = True |
113 |
| - return task_train_kwargs |
114 |
| - |
115 | 117 |
|
116 | 118 | @dataclass
|
117 | 119 | class OptimizedFlairClassificationPipeline(
|
@@ -396,14 +398,6 @@ def _pop_sampled_parameters(
|
396 | 398 | assert isinstance(task_model_kwargs, dict)
|
397 | 399 | return embedding_name, hidden_size, task_train_kwargs, task_model_kwargs
|
398 | 400 |
|
399 |
| - @staticmethod |
400 |
| - def _revert_default_hps_task_train_kwargs( |
401 |
| - task_train_kwargs: Dict[str, ParameterValues] |
402 |
| - ) -> Dict[str, ParameterValues]: |
403 |
| - task_train_kwargs["param_selection_mode"] = False |
404 |
| - task_train_kwargs["save_final_model"] = True |
405 |
| - return task_train_kwargs |
406 |
| - |
407 | 401 | def _get_metadata(self, parameters: SampledParameters) -> FlairSequenceLabelingPipelineMetadata:
|
408 | 402 | (
|
409 | 403 | embedding_name,
|
|
0 commit comments