Skip to content

Commit 0b09444

Browse files
committed
feat: implement flair sklearn vectorizer wrappers
1 parent 51f981f commit 0b09444

File tree

4 files changed

+53
-10
lines changed

4 files changed

+53
-10
lines changed
+5-9
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Any, Dict, Optional
22

33
import pandas as pd
4+
import scipy
45
from sklearn.base import BaseEstimator as AnySklearnVectorizer
56

67
from embeddings.embedding.embedding import Embedding
@@ -9,21 +10,16 @@
910

1011
class SklearnEmbedding(Embedding[ArrayLike, pd.DataFrame]):
1112
def __init__(
12-
self,
13-
vectorizer: AnySklearnVectorizer,
14-
vectorizer_has_sparse_output: bool = True,
15-
vectorizer_kwargs: Optional[Dict[str, Any]] = None,
13+
self, vectorizer: AnySklearnVectorizer, vectorizer_kwargs: Optional[Dict[str, Any]] = None
1614
):
1715
super().__init__()
18-
self.vectorizer_kwargs = vectorizer_kwargs if vectorizer_kwargs else {}
19-
self.vectorizer_has_sparse_output = vectorizer_has_sparse_output
20-
self.vectorizer = vectorizer(**self.vectorizer_kwargs)
16+
self.vectorizer = vectorizer(**vectorizer_kwargs if vectorizer_kwargs else {})
2117

2218
def fit(self, data: ArrayLike) -> None:
2319
self.vectorizer.fit(data)
2420

2521
def embed(self, data: ArrayLike) -> pd.DataFrame:
2622
embedded = self.vectorizer.transform(data)
27-
if self.vectorizer_has_sparse_output:
23+
if scipy.sparse.issparse(embedded):
2824
embedded = embedded.A
29-
return pd.DataFrame(embedded, columns=self.vectorizer.get_feature_names_out())
25+
return pd.DataFrame(embedded)

embeddings/embedding/vectorizer/__init__.py

Whitespace-only changes.
+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import abc
2+
from typing import Any, Dict, Generic, List, Optional, TypeVar
3+
4+
import numpy as np
5+
from flair.data import Sentence
6+
from numpy import typing as nptyping
7+
from sklearn.base import BaseEstimator, TransformerMixin
8+
from sklearn.feature_extraction.text import _VectorizerMixin
9+
10+
from embeddings.embedding.flair_embedding import FlairEmbedding
11+
from embeddings.utils.array_like import ArrayLike
12+
13+
Output = TypeVar("Output")
14+
15+
16+
# ignoring the mypy error due to no types (Any) in TransformerMixin and BaseEstimator classes
17+
class FlairVectorizer(TransformerMixin, _VectorizerMixin, BaseEstimator, Generic[Output]): # type: ignore
18+
def __init__(self, flair_embedding: FlairEmbedding) -> None:
19+
self.embedder = flair_embedding
20+
21+
def fit(self, x: ArrayLike, y: Optional[ArrayLike] = None) -> None:
22+
pass
23+
24+
@abc.abstractmethod
25+
def transform(self, x: ArrayLike) -> Output:
26+
pass
27+
28+
def fit_transform(self, x: ArrayLike, y: Optional[ArrayLike] = None, **kwargs: Any) -> Output:
29+
return self.transform(x)
30+
31+
32+
class FlairDocumentVectorizer(FlairVectorizer[nptyping.NDArray[np.float_]]):
33+
def transform(self, x: ArrayLike) -> nptyping.NDArray[np.float_]:
34+
sentences = [Sentence(example) for example in x]
35+
embeddings = [sentence.embedding.numpy() for sentence in self.embedder.embed(sentences)]
36+
return np.vstack(embeddings)
37+
38+
39+
class FlairWordVectorizer(FlairVectorizer[List[List[Dict[int, float]]]]):
40+
def transform(self, x: ArrayLike) -> List[List[Dict[int, float]]]:
41+
sentences = [Sentence(example) for example in x]
42+
embeddings = [sentence for sentence in self.embedder.embed(sentences)]
43+
return [
44+
[{i: value for i, value in enumerate(word.embedding.numpy())} for word in sent]
45+
for sent in embeddings
46+
]

pyproject.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,8 @@ module = [
119119
"spacy",
120120
"appdirs",
121121
"dataset.arrow_dataset",
122-
"seqeval.*"
122+
"seqeval.*",
123+
"scipy"
123124
]
124125
ignore_missing_imports = true
125126

0 commit comments

Comments
 (0)