Skip to content

Commit c785122

Browse files
committed
Add loguru and let the embedding step accept either precomputed embeddings or plain text #39
1 parent 4d7049f commit c785122

File tree

4 files changed

+39
-5
lines changed

4 files changed

+39
-5
lines changed

policy_analysis/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ dependencies = [
1313
"sentence-transformers>=5.1.2",
1414
"torch>=2.9.1",
1515
"matplotlib>=3.10.7",
16+
"loguru>=0.7.3",
1617
]
1718

1819
[project.optional-dependencies]

policy_analysis/src/policy_analysis/policies_clustering/clusterings.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from loguru import logger
12
from sklearn.base import BaseEstimator
23
from sklearn.pipeline import Pipeline
34
from sklearn.decomposition import TruncatedSVD
@@ -61,22 +62,28 @@ def build_hdbscan_pipeline(
6162
if __name__ == "__main__":
6263
import pandas as pd
6364
from pathlib import Path
65+
from sentence_transformers import SentenceTransformer
66+
from umap import UMAP
6467
root = Path().cwd()
6568
fp = root / "data/conclusions&pollitiques_synthetiques.jsonl"
69+
model_name = "all-MiniLM-L6-v2"
70+
model = SentenceTransformer(model_name)
6671
df = pd.read_json(fp, lines=True)
6772
texts = df["response"].tolist()
73+
embeddings = model.encode(texts, show_progress_bar=True)
74+
logger.info(f"Embeddings shape: {embeddings.shape}")
6875

6976
pipe = build_hdbscan_pipeline(
7077
embedding="sbert",
7178
n_components=5,
7279
min_cluster_size=2
7380
)
7481

75-
pipe.fit(texts)
82+
pipe.fit(embeddings)
7683

84+
umap_model = UMAP(n_components=2, n_neighbors=15, random_state=42,
85+
metric="cosine", verbose=True)
7786
labels = pipe.named_steps["cluster"].labels_
78-
X_2d = pipe.named_steps["umap"].transform(
79-
pipe.named_steps["embed"].transform(texts)
80-
)
87+
reduced_embeddings_2d = umap_model.fit_transform(embeddings)
8188

82-
plot_clusters_2d(X_2d, labels)
89+
plot_clusters_2d(reduced_embeddings_2d, labels)

policy_analysis/src/policy_analysis/policies_clustering/embeddings.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ def fit(self, X, y=None):
4040
return self
4141

4242
def transform(self, X):
43+
if isinstance(X, np.ndarray):
44+
return X
4345
embeddings = self.model.encode(
4446
X,
4547
batch_size=self.batch_size,

policy_analysis/uv.lock

Lines changed: 24 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)