Commit b182ae7

fix(flair_pipelines): Fix flair pipelines
1 parent 362feb0 commit b182ae7

3 files changed: +49 −13 lines changed

embeddings/pipeline/flair_classification.py (+16 −4)

@@ -23,7 +23,10 @@
 from embeddings.transformation.flair_transformation.split_sample_corpus_transformation import (
     SampleSplitsFlairCorpusTransformation,
 )
-from embeddings.transformation.transformation import Transformation
+from embeddings.transformation.hf_transformation.class_encode_column_transformation import (
+    ClassEncodeColumnTransformation,
+)
+from embeddings.transformation.transformation import DummyTransformation, Transformation
 from embeddings.utils.json_dict_persister import JsonPersister


@@ -45,14 +48,23 @@ def __init__(
         sample_missing_splits: Optional[Tuple[Optional[float], Optional[float]]] = None,
         seed: int = 441,
         load_dataset_kwargs: Optional[Dict[str, Any]] = None,
+        encode_classes: bool = False,
     ):
         output_path = Path(output_path)
         dataset = Dataset(dataset_name, **load_dataset_kwargs if load_dataset_kwargs else {})
         data_loader = HuggingFaceDataLoader()
         transformation: Union[
-            Transformation[datasets.DatasetDict, Corpus], Transformation[Corpus, Corpus]
-        ]
-        transformation = ClassificationCorpusTransformation(input_column_name, target_column_name)
+            Transformation[datasets.DatasetDict, datasets.DatasetDict],
+            Transformation[datasets.DatasetDict, Corpus],
+            Transformation[Corpus, Corpus],
+        ] = DummyTransformation()
+        if encode_classes:
+            transformation = transformation.then(
+                ClassEncodeColumnTransformation(column=target_column_name)
+            )
+        transformation = transformation.then(
+            ClassificationCorpusTransformation(input_column_name, target_column_name)
+        )
         if sample_missing_splits:
             transformation = transformation.then(
                 SampleSplitsFlairCorpusTransformation(*sample_missing_splits, seed=seed)
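
Note: all three patched pipelines now share the same construction pattern: start from a no-op DummyTransformation(), optionally chain ClassEncodeColumnTransformation, then chain the corpus transformation. The sketch below illustrates how that .then() chaining composes steps. The Transformation and DummyTransformation classes here are simplified stand-ins written for this illustration, not the actual classes from embeddings.transformation.transformation.

# Minimal stand-alone sketch of the chaining pattern this commit introduces.
# NOTE: these classes are simplified stand-ins for illustration only.
from typing import Any, Callable, List


class Transformation:
    """Stand-in: wraps a callable step and supports .then() composition."""

    def __init__(self, fn: Callable[[Any], Any]) -> None:
        self.steps: List[Callable[[Any], Any]] = [fn]

    def then(self, other: "Transformation") -> "Transformation":
        # Composition keeps the steps of both pipelines, in order.
        combined = Transformation(lambda x: x)
        combined.steps = self.steps + other.steps
        return combined

    def transform(self, data: Any) -> Any:
        for step in self.steps:
            data = step(data)
        return data


class DummyTransformation(Transformation):
    """Stand-in identity transformation used as a neutral head of the chain."""

    def __init__(self) -> None:
        super().__init__(lambda x: x)


# Mirror of the patched __init__ logic: start from a no-op head, optionally add
# class encoding, then add the corpus transformation.
encode_classes = True
transformation = DummyTransformation()
if encode_classes:
    # Stands in for ClassEncodeColumnTransformation(column=target_column_name).
    transformation = transformation.then(
        Transformation(lambda record: {**record, "target": 1})
    )
# Stands in for ClassificationCorpusTransformation(input_column_name, target_column_name).
transformation = transformation.then(
    Transformation(lambda record: {**record, "corpus": True})
)

print(transformation.transform({"text": "example", "target": "positive"}))
# -> {'text': 'example', 'target': 1, 'corpus': True}

In the real pipelines the generics on Transformation track the input and output types (datasets.DatasetDict vs. flair Corpus), which is why the Union annotation in the diff now lists three parameterizations instead of two.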

embeddings/pipeline/flair_pair_classification.py (+15 −5)

@@ -23,7 +23,10 @@
 from embeddings.transformation.flair_transformation.split_sample_corpus_transformation import (
     SampleSplitsFlairCorpusTransformation,
 )
-from embeddings.transformation.transformation import Transformation
+from embeddings.transformation.hf_transformation.class_encode_column_transformation import (
+    ClassEncodeColumnTransformation,
+)
+from embeddings.transformation.transformation import DummyTransformation, Transformation
 from embeddings.utils.json_dict_persister import JsonPersister


@@ -45,15 +48,22 @@ def __init__(
         sample_missing_splits: Optional[Tuple[Optional[float], Optional[float]]] = None,
         seed: int = 441,
         load_dataset_kwargs: Optional[Dict[str, Any]] = None,
+        encode_classes: bool = False,
     ):
         output_path = Path(output_path)
         dataset = Dataset(dataset_name, **load_dataset_kwargs if load_dataset_kwargs else {})
         data_loader = HuggingFaceDataLoader()
         transformation: Union[
-            Transformation[datasets.DatasetDict, Corpus], Transformation[Corpus, Corpus]
-        ]
-        transformation = PairClassificationCorpusTransformation(
-            input_columns_names_pair, target_column_name
+            Transformation[datasets.DatasetDict, datasets.DatasetDict],
+            Transformation[datasets.DatasetDict, Corpus],
+            Transformation[Corpus, Corpus],
+        ] = DummyTransformation()
+        if encode_classes:
+            transformation = transformation.then(
+                ClassEncodeColumnTransformation(column=target_column_name)
+            )
+        transformation = transformation.then(
+            PairClassificationCorpusTransformation(input_columns_names_pair, target_column_name)
         )
         if sample_missing_splits:
             transformation = transformation.then(

embeddings/pipeline/flair_sequence_labeling.py (+18 −4)

@@ -23,7 +23,10 @@
 from embeddings.transformation.flair_transformation.split_sample_corpus_transformation import (
     SampleSplitsFlairCorpusTransformation,
 )
-from embeddings.transformation.transformation import Transformation
+from embeddings.transformation.hf_transformation.class_encode_column_transformation import (
+    ClassEncodeColumnTransformation,
+)
+from embeddings.transformation.transformation import DummyTransformation, Transformation
 from embeddings.utils.json_dict_persister import JsonPersister


@@ -47,14 +50,25 @@ def __init__(
         sample_missing_splits: Optional[Tuple[Optional[float], Optional[float]]] = None,
         seed: int = 441,
         load_dataset_kwargs: Optional[Dict[str, Any]] = None,
+        encode_classes: bool = True,
     ):
         output_path = Path(output_path)
         dataset = Dataset(dataset_name, **load_dataset_kwargs if load_dataset_kwargs else {})
         data_loader = HuggingFaceDataLoader()
+
         transformation: Union[
-            Transformation[datasets.DatasetDict, Corpus], Transformation[Corpus, Corpus]
-        ]
-        transformation = ColumnCorpusTransformation(input_column_name, target_column_name)
+            Transformation[datasets.DatasetDict, datasets.DatasetDict],
+            Transformation[datasets.DatasetDict, Corpus],
+            Transformation[Corpus, Corpus],
+        ] = DummyTransformation()
+        if encode_classes:
+            transformation = transformation.then(
+                ClassEncodeColumnTransformation(column=target_column_name)
+            )
+        transformation = transformation.then(
+            ColumnCorpusTransformation(input_column_name, target_column_name)
+        )
+
         if sample_missing_splits:
             transformation = transformation.then(
                 SampleSplitsFlairCorpusTransformation(*sample_missing_splits, seed=seed)
