Skip to content

Commit 5c0bee2

Browse files
committed
fix(flair_pipelines): Fix flair pipelines
1 parent 362feb0 commit 5c0bee2

File tree

3 files changed

+49
-10
lines changed

3 files changed

+49
-10
lines changed

embeddings/pipeline/flair_classification.py

+16 -3
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,10 @@
2323
from embeddings.transformation.flair_transformation.split_sample_corpus_transformation import (
2424
SampleSplitsFlairCorpusTransformation,
2525
)
26-
from embeddings.transformation.transformation import Transformation
26+
from embeddings.transformation.hf_transformation.class_encode_column_transformation import (
27+
ClassEncodeColumnTransformation,
28+
)
29+
from embeddings.transformation.transformation import DummyTransformation, Transformation
2730
from embeddings.utils.json_dict_persister import JsonPersister
2831

2932

@@ -45,14 +48,24 @@ def __init__(
4548
sample_missing_splits: Optional[Tuple[Optional[float], Optional[float]]] = None,
4649
seed: int = 441,
4750
load_dataset_kwargs: Optional[Dict[str, Any]] = None,
51+
encode_classes: bool = False,
4852
):
4953
output_path = Path(output_path)
5054
dataset = Dataset(dataset_name, **load_dataset_kwargs if load_dataset_kwargs else {})
5155
data_loader = HuggingFaceDataLoader()
5256
transformation: Union[
53-
Transformation[datasets.DatasetDict, Corpus], Transformation[Corpus, Corpus]
57+
Transformation[datasets.DatasetDict, datasets.DatasetDict],
58+
Transformation[datasets.DatasetDict, Corpus],
59+
Transformation[Corpus, Corpus],
5460
]
55-
transformation = ClassificationCorpusTransformation(input_column_name, target_column_name)
61+
transformation = DummyTransformation()
62+
if encode_classes:
63+
transformation = transformation.then(
64+
ClassEncodeColumnTransformation(column=target_column_name)
65+
)
66+
transformation = transformation.then(
67+
ClassificationCorpusTransformation(input_column_name, target_column_name)
68+
)
5669
if sample_missing_splits:
5770
transformation = transformation.then(
5871
SampleSplitsFlairCorpusTransformation(*sample_missing_splits, seed=seed)

embeddings/pipeline/flair_pair_classification.py

+15 -4
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,10 @@
2323
from embeddings.transformation.flair_transformation.split_sample_corpus_transformation import (
2424
SampleSplitsFlairCorpusTransformation,
2525
)
26-
from embeddings.transformation.transformation import Transformation
26+
from embeddings.transformation.hf_transformation.class_encode_column_transformation import (
27+
ClassEncodeColumnTransformation,
28+
)
29+
from embeddings.transformation.transformation import DummyTransformation, Transformation
2730
from embeddings.utils.json_dict_persister import JsonPersister
2831

2932

@@ -45,15 +48,23 @@ def __init__(
4548
sample_missing_splits: Optional[Tuple[Optional[float], Optional[float]]] = None,
4649
seed: int = 441,
4750
load_dataset_kwargs: Optional[Dict[str, Any]] = None,
51+
encode_classes: bool = False,
4852
):
4953
output_path = Path(output_path)
5054
dataset = Dataset(dataset_name, **load_dataset_kwargs if load_dataset_kwargs else {})
5155
data_loader = HuggingFaceDataLoader()
5256
transformation: Union[
53-
Transformation[datasets.DatasetDict, Corpus], Transformation[Corpus, Corpus]
57+
Transformation[datasets.DatasetDict, datasets.DatasetDict],
58+
Transformation[datasets.DatasetDict, Corpus],
59+
Transformation[Corpus, Corpus],
5460
]
55-
transformation = PairClassificationCorpusTransformation(
56-
input_columns_names_pair, target_column_name
61+
transformation = DummyTransformation()
62+
if encode_classes:
63+
transformation = transformation.then(
64+
ClassEncodeColumnTransformation(column=target_column_name)
65+
)
66+
transformation = transformation.then(
67+
PairClassificationCorpusTransformation(input_columns_names_pair, target_column_name)
5768
)
5869
if sample_missing_splits:
5970
transformation = transformation.then(

embeddings/pipeline/flair_sequence_labeling.py

+18 -3
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,10 @@
2323
from embeddings.transformation.flair_transformation.split_sample_corpus_transformation import (
2424
SampleSplitsFlairCorpusTransformation,
2525
)
26-
from embeddings.transformation.transformation import Transformation
26+
from embeddings.transformation.hf_transformation.class_encode_column_transformation import (
27+
ClassEncodeColumnTransformation,
28+
)
29+
from embeddings.transformation.transformation import DummyTransformation, Transformation
2730
from embeddings.utils.json_dict_persister import JsonPersister
2831

2932

@@ -47,14 +50,26 @@ def __init__(
4750
sample_missing_splits: Optional[Tuple[Optional[float], Optional[float]]] = None,
4851
seed: int = 441,
4952
load_dataset_kwargs: Optional[Dict[str, Any]] = None,
53+
encode_classes: bool = True,
5054
):
5155
output_path = Path(output_path)
5256
dataset = Dataset(dataset_name, **load_dataset_kwargs if load_dataset_kwargs else {})
5357
data_loader = HuggingFaceDataLoader()
58+
5459
transformation: Union[
55-
Transformation[datasets.DatasetDict, Corpus], Transformation[Corpus, Corpus]
60+
Transformation[datasets.DatasetDict, datasets.DatasetDict],
61+
Transformation[datasets.DatasetDict, Corpus],
62+
Transformation[Corpus, Corpus],
5663
]
57-
transformation = ColumnCorpusTransformation(input_column_name, target_column_name)
64+
transformation = DummyTransformation()
65+
if encode_classes:
66+
transformation = transformation.then(
67+
ClassEncodeColumnTransformation(column=target_column_name)
68+
)
69+
transformation = transformation.then(
70+
ColumnCorpusTransformation(input_column_name, target_column_name)
71+
)
72+
5873
if sample_missing_splits:
5974
transformation = transformation.then(
6075
SampleSplitsFlairCorpusTransformation(*sample_missing_splits, seed=seed)

0 commit comments

Comments
 (0)