39
39
from embeddings .transformation .flair_transformation .split_sample_corpus_transformation import (
40
40
SampleSplitsFlairCorpusTransformation ,
41
41
)
42
+ from embeddings .transformation .hf_transformation .class_encode_column_transformation import (
43
+ ClassEncodeColumnTransformation ,
44
+ )
42
45
from embeddings .transformation .transformation import DummyTransformation
43
46
from embeddings .utils .flair_corpus_persister import FlairConllPersister , FlairPicklePersister
44
47
@@ -62,6 +65,7 @@ class FlairPreprocessingPipeline(
62
65
ignore_test_subset : bool = False
63
66
seed : int = 441
64
67
load_dataset_kwargs : Optional [Dict [str , Any ]] = None
68
+ encode_labels : bool = False
65
69
66
70
def __post_init__ (self ) -> None :
67
71
self .persister = self ._get_persister ()
@@ -71,7 +75,7 @@ def __post_init__(self) -> None:
71
75
super (FlairPreprocessingPipeline , self ).__init__ (dataset , data_loader , transformation )
72
76
73
77
@abc.abstractmethod
def _get_to_flair_dataset_transformation(self) -> FLAIR_DATASET_TRANSFORMATIONS_TYPE:
    """Return the dataset transformation provided by the concrete pipeline.

    Abstract hook with no behaviour of its own: concrete pipeline classes
    must override it and return a ``FLAIR_DATASET_TRANSFORMATIONS_TYPE``.
    """
76
80
77
81
@abc .abstractmethod
@@ -89,17 +93,25 @@ def _get_dataloader(self, dataset: Dataset) -> FLAIR_DATALOADERS:
89
93
90
94
def _get_dataset_transformation(
    self, data_loader: FLAIR_DATALOADERS
) -> Optional[FLAIR_DATASET_TRANSFORMATIONS_TYPE]:
    """Pick the dataset transformation that matches the given data loader.

    Args:
        data_loader: The loader that will supply the data.

    Returns:
        ``None`` when ``data_loader`` is a ``ConllFlairCorpusDataLoader`` or a
        ``PickleFlairCorpusDataLoader``; otherwise the transformation built by
        ``_get_to_flair_dataset_transformation``.
    """
    corpus_loader_types = (ConllFlairCorpusDataLoader, PickleFlairCorpusDataLoader)
    needs_conversion = not isinstance(data_loader, corpus_loader_types)
    return self._get_to_flair_dataset_transformation() if needs_conversion else None
97
101
98
102
def _get_transformations (
99
103
self , data_loader : FLAIR_DATALOADERS
100
104
) -> FLAIR_DATASET_TRANSFORMATIONS_TYPE :
101
105
102
- transformation = self ._get_dataset_transformation (data_loader )
106
+ transformation : FLAIR_DATASET_TRANSFORMATIONS_TYPE = DummyTransformation ()
107
+ if self .encode_labels :
108
+ transformation = transformation .then (
109
+ ClassEncodeColumnTransformation (column = self .target_column_name )
110
+ )
111
+
112
+ to_flair_dataset_transformation = self ._get_dataset_transformation (data_loader )
113
+ if to_flair_dataset_transformation :
114
+ transformation = transformation .then (to_flair_dataset_transformation )
103
115
104
116
if self .sample_missing_splits :
105
117
transformation = transformation .then (
@@ -126,7 +138,7 @@ class FlairTextClassificationPreprocessingPipeline(FlairPreprocessingPipeline):
126
138
def _get_persister(self) -> FLAIR_PERSISTERS_TYPE:
    """Create a ``FlairPicklePersister`` constructed with ``self.persist_path``."""
    target_path = self.persist_path
    persister = FlairPicklePersister(target_path)
    return persister
128
140
129
- def _get_base_dataset_transformation (self ) -> FLAIR_DATASET_TRANSFORMATIONS_TYPE :
141
+ def _get_to_flair_dataset_transformation (self ) -> FLAIR_DATASET_TRANSFORMATIONS_TYPE :
130
142
assert isinstance (self .input_column_name , str )
131
143
return ClassificationCorpusTransformation (
132
144
input_column_name = self .input_column_name ,
@@ -138,7 +150,7 @@ def _get_base_dataset_transformation(self) -> FLAIR_DATASET_TRANSFORMATIONS_TYPE
138
150
class FlairTextPairClassificationPreprocessingPipeline (
139
151
FlairTextClassificationPreprocessingPipeline
140
152
):
141
- def _get_base_dataset_transformation (self ) -> FLAIR_DATASET_TRANSFORMATIONS_TYPE :
153
+ def _get_to_flair_dataset_transformation (self ) -> FLAIR_DATASET_TRANSFORMATIONS_TYPE :
142
154
assert isinstance (self .input_column_name , (tuple , list ))
143
155
return PairClassificationCorpusTransformation (
144
156
input_columns_names_pair = self .input_column_name ,
@@ -151,7 +163,7 @@ class FlairSequenceLabelingPreprocessingPipeline(FlairPreprocessingPipeline):
151
163
def _get_persister(self) -> FLAIR_PERSISTERS_TYPE:
    """Create a ``FlairConllPersister`` constructed with ``self.persist_path``."""
    target_path = self.persist_path
    persister = FlairConllPersister(target_path)
    return persister
153
165
154
- def _get_base_dataset_transformation (self ) -> FLAIR_DATASET_TRANSFORMATIONS_TYPE :
166
+ def _get_to_flair_dataset_transformation (self ) -> FLAIR_DATASET_TRANSFORMATIONS_TYPE :
155
167
assert isinstance (self .input_column_name , str )
156
168
return ColumnCorpusTransformation (
157
169
input_column_name = self .input_column_name ,
0 commit comments