Skip to content

Commit f6c1ec7

Browse files
committed
Merge branch 'main' into ASReview2-nb
2 parents 089806a + 102dd59 commit f6c1ec7

14 files changed

Lines changed: 455 additions & 82 deletions

File tree

src/feature_matrix_scripts/bge-m3.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pandas as pd
55
import synergy_dataset as sd
66
from FlagEmbedding import BGEM3FlagModel
7+
from sklearn.preprocessing import normalize
78
from tqdm import tqdm
89

910
FORCE = False
@@ -14,20 +15,21 @@
1415
model = BGEM3FlagModel("BAAI/bge-m3", devices=["cuda:0"])
1516

1617
for dataset in tqdm(sd.iter_datasets(), total=26):
17-
# Load dataset
18-
if dataset.name == "Moran_2021_corrected":
18+
if dataset.name == "Chou_2004" or dataset.name == "Jeyaraman_2020":
19+
continue
20+
elif dataset.name == "Moran_2021":
1921
df = pd.read_csv("./datasets/Moran_2021_corrected_shuffled_raw.csv")
20-
elif dataset.name == "Muthu_2021_corrected":
22+
dataset_name = "Moran_2021_corrected"
23+
elif dataset.name == "Muthu_2021":
2124
df = pd.read_csv("./datasets/Muthu_2021_corrected_shuffled_raw.csv")
25+
dataset_name = "Muthu_2021_corrected"
2226
else:
2327
df = dataset.to_frame().reset_index()
28+
dataset_name = dataset.name
2429

2530
# Combine 'title' and 'abstract' text
2631
combined_texts = (df["title"].fillna("") + " " + df["abstract"].fillna("")).tolist()
2732

28-
dataset_name = (
29-
dataset.name if dataset.name != "Moran_2021" else "Moran_2021_corrected"
30-
)
3133
pickle_file_path = folder_pickle_files / f"{dataset_name}.pkl"
3234

3335
# Check if the pickle file already exists
@@ -45,9 +47,8 @@
4547
return_colbert_vecs=False,
4648
)
4749

50+
X["dense_vecs_norm"] = normalize(X["dense_vecs"], norm="l2")
51+
4852
# Save embeddings and labels as a pickle file
4953
with open(folder_pickle_files / f"{dataset_name}.pkl", "wb") as f:
50-
pickle.dump((X["dense_vecs"], df["label_included"].tolist()), f)
51-
52-
with open(folder_pickle_files / f"sparse-{dataset_name}.pkl", "wb") as f:
53-
pickle.dump((X["lexical_weights"], df["label_included"].tolist()), f)
54+
pickle.dump((X["dense_vecs_norm"], df["label_included"].tolist()), f)
Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
import pickle
2+
from pathlib import Path
3+
4+
import numpy as np
5+
import pandas as pd
6+
import synergy_dataset as sd
7+
from gensim.models.doc2vec import Doc2Vec as GenSimDoc2Vec
8+
from gensim.models.doc2vec import TaggedDocument
9+
from gensim.utils import simple_preprocess
10+
from sklearn.preprocessing import normalize as SKNormalize
11+
from tqdm import tqdm
12+
13+
14+
class Doc2Vec:
15+
"""
16+
Doc2Vec feature extraction technique (``doc2vec``).
17+
18+
Feature extraction technique provided by the `gensim
19+
<https://radimrehurek.com/gensim/>`__ package. It trains a model to generate
20+
document embeddings, which can reduce dimensionality and accelerate modeling.
21+
22+
.. note::
23+
24+
For fully reproducible runs, limit the model to a single worker thread
25+
(`n_jobs=1`) to eliminate potential variability due to thread scheduling.
26+
27+
Parameters
28+
----------
29+
vector_size : int, optional
30+
Dimensionality of the feature vectors. Default: 40
31+
epochs : int, optional
32+
Number of epochs to train the model. Default: 33
33+
min_count : int, optional
34+
Ignores all words with total frequency lower than this. Default: 1
35+
n_jobs : int, optional
36+
Number of threads to use during training. Default: 1
37+
window : int, optional
38+
Maximum distance between the current and predicted word. Default: 7
39+
dm_concat : bool, optional
40+
If True, concatenate word vectors. Default: False
41+
dm : int, optional
42+
Training model:
43+
- 0: Distributed Bag of Words (DBOW)
44+
- 1: Distributed Memory (DM)
45+
- 2: Both DBOW and DM (concatenated embeddings). Default: 2
46+
dbow_words : bool, optional
47+
Train word vectors alongside DBOW. Default: False
48+
normalize : bool, optional
49+
Normalize embeddings using min-max scaling. Default: True
50+
verbose : bool, optional
51+
Print progress and status updates. Default: True
52+
"""
53+
54+
def __init__(
55+
self,
56+
vector_size=40,
57+
epochs=33,
58+
min_count=1,
59+
n_jobs=1,
60+
window=7,
61+
dm_concat=False,
62+
dm=2,
63+
dbow_words=False,
64+
normalize=True,
65+
norm="l2",
66+
verbose=True,
67+
):
68+
self.vector_size = int(vector_size)
69+
self.epochs = int(epochs)
70+
self.min_count = int(min_count)
71+
self.n_jobs = int(n_jobs)
72+
self.window = int(window)
73+
self.dm_concat = 1 if dm_concat else 0
74+
self.dm = int(dm)
75+
self.dbow_words = 1 if dbow_words else 0
76+
self.normalize = normalize
77+
self.norm = norm
78+
self.verbose = verbose
79+
self._model_instance = None
80+
81+
self._tagged_document = TaggedDocument
82+
self._simple_preprocess = simple_preprocess
83+
self._model = GenSimDoc2Vec
84+
85+
def fit(self, X, y=None):
86+
if self.verbose:
87+
print("Preparing corpus...")
88+
corpus = [
89+
self._tagged_document(self._simple_preprocess(text), [i])
90+
for i, text in enumerate(X)
91+
]
92+
93+
model_param = {
94+
"vector_size": self.vector_size,
95+
"epochs": self.epochs,
96+
"min_count": self.min_count,
97+
"workers": self.n_jobs,
98+
"window": self.window,
99+
"dm_concat": self.dm_concat,
100+
"dbow_words": self.dbow_words,
101+
}
102+
103+
if self.dm == 2:
104+
# Train both DM and DBOW models
105+
model_param["vector_size"] = int(self.vector_size / 2)
106+
if self.verbose:
107+
print("Training DM model...")
108+
self._model_dm = self._train_model(corpus, **model_param, dm=1)
109+
if self.verbose:
110+
print("Training DBOW model...")
111+
self._model_dbow = self._train_model(corpus, **model_param, dm=0)
112+
else:
113+
if self.verbose:
114+
print(f"Training single model with dm={self.dm}...")
115+
self._model_instance = self._train_model(corpus, **model_param, dm=self.dm)
116+
117+
def transform(self, texts):
118+
if self.verbose:
119+
print("Preparing corpus for transformation...")
120+
corpus = [
121+
self._tagged_document(self._simple_preprocess(text), [i])
122+
for i, text in enumerate(texts)
123+
]
124+
125+
if self.dm == 2:
126+
X_dm = self._infer_vectors(self._model_dm, corpus)
127+
X_dbow = self._infer_vectors(self._model_dbow, corpus)
128+
X = np.concatenate((X_dm, X_dbow), axis=1)
129+
else:
130+
X = self._infer_vectors(self._model_instance, corpus)
131+
132+
if self.verbose:
133+
print("Finished transforming texts to vectors.")
134+
135+
if self.normalize:
136+
if self.verbose:
137+
print("Normalizing embeddings.")
138+
X = SKNormalize(X, norm=self.norm)
139+
140+
return X
141+
142+
def fit_transform(self, X, y):
143+
self.fit(X, y)
144+
return self.transform(X)
145+
146+
def _train_model(self, corpus, *args, **kwargs):
147+
model = self._model(*args, **kwargs)
148+
if self.verbose:
149+
print("Building vocabulary...")
150+
model.build_vocab(corpus)
151+
if self.verbose:
152+
print("Training model...")
153+
model.train(corpus, total_examples=model.corpus_count, epochs=model.epochs)
154+
if self.verbose:
155+
print("Model training complete.")
156+
return model
157+
158+
def _infer_vectors(self, model, corpus):
159+
if self.verbose:
160+
print("Inferring vectors for documents...")
161+
X = [model.infer_vector(doc.words) for doc in corpus]
162+
if self.verbose:
163+
print("Vector inference complete.")
164+
return np.array(X)
165+
166+
167+
FORCE = True
168+
169+
# Folder to save embeddings
170+
folder_pickle_files = Path("synergy-dataset", "pickles_doc2vec")
171+
folder_pickle_files.mkdir(parents=True, exist_ok=True)
172+
173+
model = Doc2Vec(n_jobs=10)
174+
175+
# Loop through datasets
176+
for dataset in tqdm(sd.iter_datasets(), total=26):
177+
if dataset.name == "Chou_2004" or dataset.name == "Jeyaraman_2020":
178+
continue
179+
elif dataset.name == "Moran_2021":
180+
df = pd.read_csv("./datasets/Moran_2021_corrected_shuffled_raw.csv")
181+
dataset_name = "Moran_2021_corrected"
182+
elif dataset.name == "Muthu_2021":
183+
df = pd.read_csv("./datasets/Muthu_2021_corrected_shuffled_raw.csv")
184+
dataset_name = "Muthu_2021_corrected"
185+
else:
186+
df = dataset.to_frame().reset_index()
187+
dataset_name = dataset.name
188+
189+
# Combine 'title' and 'abstract' text
190+
combined_texts = (df["title"].fillna("") + " " + df["abstract"].fillna("")).tolist()
191+
192+
pickle_file_path = folder_pickle_files / f"{dataset_name}.pkl"
193+
194+
# Check if the pickle file already exists
195+
if not FORCE and pickle_file_path.exists():
196+
print(f"Skipping {dataset_name}, pickle file already exists.")
197+
continue
198+
199+
# Generate embeddings
200+
X = model.fit_transform(combined_texts, [])
201+
202+
# Save embeddings and labels as a pickle file
203+
with open(folder_pickle_files / f"{dataset_name}.pkl", "wb") as f:
204+
pickle.dump(
205+
(
206+
X,
207+
df["label_included"].tolist(),
208+
),
209+
f,
210+
)

src/feature_matrix_scripts/e5-large.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from sentence_transformers import SentenceTransformer
88
from tqdm import tqdm
99

10-
FORCE = False
10+
FORCE = True
1111

1212
# Folder to save embeddings
1313
folder_pickle_files = Path("synergy-dataset", "pickles_e5")
@@ -22,20 +22,21 @@
2222

2323
# Loop through datasets
2424
for dataset in tqdm(sd.iter_datasets(), total=26):
25-
# Load dataset
26-
if dataset.name == "Moran_2021_corrected":
25+
if dataset.name == "Chou_2004" or dataset.name == "Jeyaraman_2020":
26+
continue
27+
elif dataset.name == "Moran_2021":
2728
df = pd.read_csv("./datasets/Moran_2021_corrected_shuffled_raw.csv")
28-
elif dataset.name == "Muthu_2021_corrected":
29+
dataset_name = "Moran_2021_corrected"
30+
elif dataset.name == "Muthu_2021":
2931
df = pd.read_csv("./datasets/Muthu_2021_corrected_shuffled_raw.csv")
32+
dataset_name = "Muthu_2021_corrected"
3033
else:
3134
df = dataset.to_frame().reset_index()
35+
dataset_name = dataset.name
3236

3337
# Combine 'title' and 'abstract' text
3438
combined_texts = (df["title"].fillna("") + " " + df["abstract"].fillna("")).tolist()
3539

36-
dataset_name = (
37-
dataset.name if dataset.name != "Moran_2021" else "Moran_2021_corrected"
38-
)
3940
pickle_file_path = folder_pickle_files / f"{dataset_name}.pkl"
4041

4142
# Check if the pickle file already exists
@@ -46,7 +47,7 @@
4647
# Generate embeddings
4748
X = model.encode(
4849
combined_texts,
49-
batch_size=128,
50+
batch_size=512,
5051
show_progress_bar=False,
5152
device=device,
5253
normalize_embeddings=True,

src/feature_matrix_scripts/gist.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,20 +22,21 @@
2222

2323
# Loop through datasets
2424
for dataset in tqdm(sd.iter_datasets(), total=26):
25-
# Load dataset
26-
if dataset.name == "Moran_2021_corrected":
25+
if dataset.name == "Chou_2004" or dataset.name == "Jeyaraman_2020":
26+
continue
27+
elif dataset.name == "Moran_2021":
2728
df = pd.read_csv("./datasets/Moran_2021_corrected_shuffled_raw.csv")
28-
elif dataset.name == "Muthu_2021_corrected":
29+
dataset_name = "Moran_2021_corrected"
30+
elif dataset.name == "Muthu_2021":
2931
df = pd.read_csv("./datasets/Muthu_2021_corrected_shuffled_raw.csv")
32+
dataset_name = "Muthu_2021_corrected"
3033
else:
3134
df = dataset.to_frame().reset_index()
35+
dataset_name = dataset.name
3236

3337
# Combine 'title' and 'abstract' text
3438
combined_texts = (df["title"].fillna("") + " " + df["abstract"].fillna("")).tolist()
3539

36-
dataset_name = (
37-
dataset.name if dataset.name != "Moran_2021" else "Moran_2021_corrected"
38-
)
3940
pickle_file_path = folder_pickle_files / f"{dataset_name}.pkl"
4041

4142
# Check if the pickle file already exists
@@ -45,7 +46,11 @@
4546

4647
# Generate embeddings
4748
X = model.encode(
48-
combined_texts, batch_size=128, show_progress_bar=False, device=device
49+
combined_texts,
50+
batch_size=512,
51+
show_progress_bar=False,
52+
device=device,
53+
normalize_embeddings=True,
4954
)
5055

5156
# Save embeddings and labels as a pickle file

src/feature_matrix_scripts/gte.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,20 +22,21 @@
2222

2323
# Loop through datasets
2424
for dataset in tqdm(sd.iter_datasets(), total=26):
25-
# Load dataset
26-
if dataset.name == "Moran_2021_corrected":
25+
if dataset.name == "Chou_2004" or dataset.name == "Jeyaraman_2020":
26+
continue
27+
elif dataset.name == "Moran_2021":
2728
df = pd.read_csv("./datasets/Moran_2021_corrected_shuffled_raw.csv")
28-
elif dataset.name == "Muthu_2021_corrected":
29+
dataset_name = "Moran_2021_corrected"
30+
elif dataset.name == "Muthu_2021":
2931
df = pd.read_csv("./datasets/Muthu_2021_corrected_shuffled_raw.csv")
32+
dataset_name = "Muthu_2021_corrected"
3033
else:
3134
df = dataset.to_frame().reset_index()
35+
dataset_name = dataset.name
3236

3337
# Combine 'title' and 'abstract' text
3438
combined_texts = (df["title"].fillna("") + " " + df["abstract"].fillna("")).tolist()
3539

36-
dataset_name = (
37-
dataset.name if dataset.name != "Moran_2021" else "Moran_2021_corrected"
38-
)
3940
pickle_file_path = folder_pickle_files / f"{dataset_name}.pkl"
4041

4142
# Check if the pickle file already exists
@@ -46,7 +47,7 @@
4647
# Generate embeddings
4748
X = model.encode(
4849
combined_texts,
49-
batch_size=128,
50+
batch_size=512,
5051
show_progress_bar=False,
5152
device=device,
5253
normalize_embeddings=True,

0 commit comments

Comments (0)