Skip to content

Commit 9572f26

Browse files
committed
fix(flair_pipelines): Fix flair pipelines
1 parent 362feb0 commit 9572f26

5 files changed

+69
-28
lines changed

embeddings/pipeline/__init__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
DOWNSAMPLE_SPLITS_TYPE = Tuple[Optional[float], Optional[float], Optional[float]]
1010
SAMPLE_MISSING_SPLITS_TYPE = Optional[Tuple[Optional[float], Optional[float]]]
1111
FLAIR_DATASET_TRANSFORMATIONS_TYPE = Union[
12-
Transformation[datasets.DatasetDict, Corpus], Transformation[Corpus, Corpus]
12+
Transformation[datasets.DatasetDict, datasets.DatasetDict],
13+
Transformation[datasets.DatasetDict, Corpus],
14+
Transformation[Corpus, Corpus],
1315
]
1416
FLAIR_PERSISTERS_TYPE = Union[FlairConllPersister[Corpus], FlairPicklePersister[Corpus, Corpus]]

embeddings/pipeline/flair_classification.py

+15-6
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from pathlib import Path
2-
from typing import Any, Dict, Optional, Tuple, Union
2+
from typing import Any, Dict, Optional, Tuple
33

44
import datasets
55
from flair.data import Corpus
@@ -15,6 +15,7 @@
1515
from embeddings.embedding.flair_loader import FlairDocumentPoolEmbeddingLoader
1616
from embeddings.evaluator.text_classification_evaluator import TextClassificationEvaluator
1717
from embeddings.model.flair_model import FlairModel
18+
from embeddings.pipeline import FLAIR_DATASET_TRANSFORMATIONS_TYPE
1819
from embeddings.pipeline.standard_pipeline import StandardPipeline
1920
from embeddings.task.flair_task.text_classification import TextClassification
2021
from embeddings.transformation.flair_transformation.classification_corpus_transformation import (
@@ -23,7 +24,10 @@
2324
from embeddings.transformation.flair_transformation.split_sample_corpus_transformation import (
2425
SampleSplitsFlairCorpusTransformation,
2526
)
26-
from embeddings.transformation.transformation import Transformation
27+
from embeddings.transformation.hf_transformation.class_encode_column_transformation import (
28+
ClassEncodeColumnTransformation,
29+
)
30+
from embeddings.transformation.transformation import DummyTransformation
2731
from embeddings.utils.json_dict_persister import JsonPersister
2832

2933

@@ -45,14 +49,19 @@ def __init__(
4549
sample_missing_splits: Optional[Tuple[Optional[float], Optional[float]]] = None,
4650
seed: int = 441,
4751
load_dataset_kwargs: Optional[Dict[str, Any]] = None,
52+
encode_classes: bool = False,
4853
):
4954
output_path = Path(output_path)
5055
dataset = Dataset(dataset_name, **load_dataset_kwargs if load_dataset_kwargs else {})
5156
data_loader = HuggingFaceDataLoader()
52-
transformation: Union[
53-
Transformation[datasets.DatasetDict, Corpus], Transformation[Corpus, Corpus]
54-
]
55-
transformation = ClassificationCorpusTransformation(input_column_name, target_column_name)
57+
transformation: FLAIR_DATASET_TRANSFORMATIONS_TYPE = DummyTransformation()
58+
if encode_classes:
59+
transformation = transformation.then(
60+
ClassEncodeColumnTransformation(column=target_column_name)
61+
)
62+
transformation = transformation.then(
63+
ClassificationCorpusTransformation(input_column_name, target_column_name)
64+
)
5665
if sample_missing_splits:
5766
transformation = transformation.then(
5867
SampleSplitsFlairCorpusTransformation(*sample_missing_splits, seed=seed)

embeddings/pipeline/flair_pair_classification.py

+14-7
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from pathlib import Path
2-
from typing import Any, Dict, Optional, Tuple, Union
2+
from typing import Any, Dict, Optional, Tuple
33

44
import datasets
55
from flair.data import Corpus
@@ -15,6 +15,7 @@
1515
from embeddings.embedding.flair_loader import FlairDocumentPoolEmbeddingLoader
1616
from embeddings.evaluator.text_classification_evaluator import TextClassificationEvaluator
1717
from embeddings.model.flair_model import FlairModel
18+
from embeddings.pipeline import FLAIR_DATASET_TRANSFORMATIONS_TYPE
1819
from embeddings.pipeline.standard_pipeline import StandardPipeline
1920
from embeddings.task.flair_task.text_pair_classification import TextPairClassification
2021
from embeddings.transformation.flair_transformation.pair_classification_corpus_transformation import (
@@ -23,7 +24,10 @@
2324
from embeddings.transformation.flair_transformation.split_sample_corpus_transformation import (
2425
SampleSplitsFlairCorpusTransformation,
2526
)
26-
from embeddings.transformation.transformation import Transformation
27+
from embeddings.transformation.hf_transformation.class_encode_column_transformation import (
28+
ClassEncodeColumnTransformation,
29+
)
30+
from embeddings.transformation.transformation import DummyTransformation
2731
from embeddings.utils.json_dict_persister import JsonPersister
2832

2933

@@ -45,15 +49,18 @@ def __init__(
4549
sample_missing_splits: Optional[Tuple[Optional[float], Optional[float]]] = None,
4650
seed: int = 441,
4751
load_dataset_kwargs: Optional[Dict[str, Any]] = None,
52+
encode_classes: bool = False,
4853
):
4954
output_path = Path(output_path)
5055
dataset = Dataset(dataset_name, **load_dataset_kwargs if load_dataset_kwargs else {})
5156
data_loader = HuggingFaceDataLoader()
52-
transformation: Union[
53-
Transformation[datasets.DatasetDict, Corpus], Transformation[Corpus, Corpus]
54-
]
55-
transformation = PairClassificationCorpusTransformation(
56-
input_columns_names_pair, target_column_name
57+
transformation: FLAIR_DATASET_TRANSFORMATIONS_TYPE = DummyTransformation()
58+
if encode_classes:
59+
transformation = transformation.then(
60+
ClassEncodeColumnTransformation(column=target_column_name)
61+
)
62+
transformation = transformation.then(
63+
PairClassificationCorpusTransformation(input_columns_names_pair, target_column_name)
5764
)
5865
if sample_missing_splits:
5966
transformation = transformation.then(

embeddings/pipeline/flair_preprocessing_pipeline.py

+20-8
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@
3939
from embeddings.transformation.flair_transformation.split_sample_corpus_transformation import (
4040
SampleSplitsFlairCorpusTransformation,
4141
)
42+
from embeddings.transformation.hf_transformation.class_encode_column_transformation import (
43+
ClassEncodeColumnTransformation,
44+
)
4245
from embeddings.transformation.transformation import DummyTransformation
4346
from embeddings.utils.flair_corpus_persister import FlairConllPersister, FlairPicklePersister
4447

@@ -62,6 +65,7 @@ class FlairPreprocessingPipeline(
6265
ignore_test_subset: bool = False
6366
seed: int = 441
6467
load_dataset_kwargs: Optional[Dict[str, Any]] = None
68+
encode_labels: bool = False
6569

6670
def __post_init__(self) -> None:
6771
self.persister = self._get_persister()
@@ -71,7 +75,7 @@ def __post_init__(self) -> None:
7175
super(FlairPreprocessingPipeline, self).__init__(dataset, data_loader, transformation)
7276

7377
@abc.abstractmethod
74-
def _get_base_dataset_transformation(self) -> FLAIR_DATASET_TRANSFORMATIONS_TYPE:
78+
def _get_to_flair_dataset_transformation(self) -> FLAIR_DATASET_TRANSFORMATIONS_TYPE:
7579
pass
7680

7781
@abc.abstractmethod
@@ -89,17 +93,25 @@ def _get_dataloader(self, dataset: Dataset) -> FLAIR_DATALOADERS:
8993

9094
def _get_dataset_transformation(
9195
self, data_loader: FLAIR_DATALOADERS
92-
) -> FLAIR_DATASET_TRANSFORMATIONS_TYPE:
96+
) -> Optional[FLAIR_DATASET_TRANSFORMATIONS_TYPE]:
9397
if isinstance(data_loader, (ConllFlairCorpusDataLoader, PickleFlairCorpusDataLoader)):
94-
return DummyTransformation()
98+
return None
9599

96-
return self._get_base_dataset_transformation()
100+
return self._get_to_flair_dataset_transformation()
97101

98102
def _get_transformations(
99103
self, data_loader: FLAIR_DATALOADERS
100104
) -> FLAIR_DATASET_TRANSFORMATIONS_TYPE:
101105

102-
transformation = self._get_dataset_transformation(data_loader)
106+
transformation: FLAIR_DATASET_TRANSFORMATIONS_TYPE = DummyTransformation()
107+
if self.encode_labels:
108+
transformation = transformation.then(
109+
ClassEncodeColumnTransformation(column=self.target_column_name)
110+
)
111+
112+
to_flair_dataset_transformation = self._get_dataset_transformation(data_loader)
113+
if to_flair_dataset_transformation:
114+
transformation = transformation.then(to_flair_dataset_transformation)
103115

104116
if self.sample_missing_splits:
105117
transformation = transformation.then(
@@ -126,7 +138,7 @@ class FlairTextClassificationPreprocessingPipeline(FlairPreprocessingPipeline):
126138
def _get_persister(self) -> FLAIR_PERSISTERS_TYPE:
127139
return FlairPicklePersister(self.persist_path)
128140

129-
def _get_base_dataset_transformation(self) -> FLAIR_DATASET_TRANSFORMATIONS_TYPE:
141+
def _get_to_flair_dataset_transformation(self) -> FLAIR_DATASET_TRANSFORMATIONS_TYPE:
130142
assert isinstance(self.input_column_name, str)
131143
return ClassificationCorpusTransformation(
132144
input_column_name=self.input_column_name,
@@ -138,7 +150,7 @@ def _get_base_dataset_transformation(self) -> FLAIR_DATASET_TRANSFORMATIONS_TYPE
138150
class FlairTextPairClassificationPreprocessingPipeline(
139151
FlairTextClassificationPreprocessingPipeline
140152
):
141-
def _get_base_dataset_transformation(self) -> FLAIR_DATASET_TRANSFORMATIONS_TYPE:
153+
def _get_to_flair_dataset_transformation(self) -> FLAIR_DATASET_TRANSFORMATIONS_TYPE:
142154
assert isinstance(self.input_column_name, (tuple, list))
143155
return PairClassificationCorpusTransformation(
144156
input_columns_names_pair=self.input_column_name,
@@ -151,7 +163,7 @@ class FlairSequenceLabelingPreprocessingPipeline(FlairPreprocessingPipeline):
151163
def _get_persister(self) -> FLAIR_PERSISTERS_TYPE:
152164
return FlairConllPersister(self.persist_path)
153165

154-
def _get_base_dataset_transformation(self) -> FLAIR_DATASET_TRANSFORMATIONS_TYPE:
166+
def _get_to_flair_dataset_transformation(self) -> FLAIR_DATASET_TRANSFORMATIONS_TYPE:
155167
assert isinstance(self.input_column_name, str)
156168
return ColumnCorpusTransformation(
157169
input_column_name=self.input_column_name,

embeddings/pipeline/flair_sequence_labeling.py

+17-6
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from pathlib import Path
2-
from typing import Any, Dict, Optional, Tuple, Union
2+
from typing import Any, Dict, Optional, Tuple
33

44
import datasets
55
from flair.data import Corpus
@@ -15,6 +15,7 @@
1515
from embeddings.embedding.flair_loader import FlairWordEmbeddingLoader
1616
from embeddings.evaluator.sequence_labeling_evaluator import SequenceLabelingEvaluator
1717
from embeddings.model.flair_model import FlairModel
18+
from embeddings.pipeline import FLAIR_DATASET_TRANSFORMATIONS_TYPE
1819
from embeddings.pipeline.standard_pipeline import StandardPipeline
1920
from embeddings.task.flair_task.sequence_labeling import SequenceLabeling
2021
from embeddings.transformation.flair_transformation.column_corpus_transformation import (
@@ -23,7 +24,10 @@
2324
from embeddings.transformation.flair_transformation.split_sample_corpus_transformation import (
2425
SampleSplitsFlairCorpusTransformation,
2526
)
26-
from embeddings.transformation.transformation import Transformation
27+
from embeddings.transformation.hf_transformation.class_encode_column_transformation import (
28+
ClassEncodeColumnTransformation,
29+
)
30+
from embeddings.transformation.transformation import DummyTransformation
2731
from embeddings.utils.json_dict_persister import JsonPersister
2832

2933

@@ -47,14 +51,21 @@ def __init__(
4751
sample_missing_splits: Optional[Tuple[Optional[float], Optional[float]]] = None,
4852
seed: int = 441,
4953
load_dataset_kwargs: Optional[Dict[str, Any]] = None,
54+
encode_classes: bool = True,
5055
):
5156
output_path = Path(output_path)
5257
dataset = Dataset(dataset_name, **load_dataset_kwargs if load_dataset_kwargs else {})
5358
data_loader = HuggingFaceDataLoader()
54-
transformation: Union[
55-
Transformation[datasets.DatasetDict, Corpus], Transformation[Corpus, Corpus]
56-
]
57-
transformation = ColumnCorpusTransformation(input_column_name, target_column_name)
59+
60+
transformation: FLAIR_DATASET_TRANSFORMATIONS_TYPE = DummyTransformation()
61+
if encode_classes:
62+
transformation = transformation.then(
63+
ClassEncodeColumnTransformation(column=target_column_name)
64+
)
65+
transformation = transformation.then(
66+
ColumnCorpusTransformation(input_column_name, target_column_name)
67+
)
68+
5869
if sample_missing_splits:
5970
transformation = transformation.then(
6071
SampleSplitsFlairCorpusTransformation(*sample_missing_splits, seed=seed)

0 commit comments

Comments
 (0)