create_cluster.py
"""
Code for creating clusters from embeddings.
Author: Peter Zhang, Ruiqi Zhong
"""
import glob
import json
import os
from argparse import ArgumentParser
from collections import defaultdict
from datetime import datetime
from typing import Tuple
import numpy as np
import sklearn.cluster
import sklearn.decomposition
import sklearn.mixture
import tqdm
from sklearn.cluster import KMeans
from sklearn.metrics.pairwise import euclidean_distances
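
# Assumed on-disk layout (inferred from load_data below; not stated elsewhere in
# this file): each shard i under <embed_dir> is a pair of files
#   <embed_dir>/<i>.npy   -- float array of shape (n_i, d) with the embeddings
#   <embed_dir>/<i>.json  -- list of the n_i texts matching those rows
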
def load_data(embed_dir: str, subset_size: int) -> Tuple[np.ndarray, list]:
    """
    Load embeddings (*.npy) and their corresponding texts (*.json) from
    embed_dir, keeping at most subset_size examples.
    """
f_prefixes = sorted(
[f.split(".")[0] for f in os.listdir(embed_dir) if f.endswith(".npy")],
key=lambda x: int(x),
)
all_embeddings, all_texts = [], []
for f in tqdm.tqdm(f_prefixes):
new_embeddings = np.load(os.path.join(embed_dir, f + ".npy"))
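        # only add embeddings from shards whose saved array is 2-D (shape (n_shard, d))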
if len(new_embeddings.shape) == 2:
all_embeddings.extend(new_embeddings)
all_texts.extend(json.load(open(os.path.join(embed_dir, f + ".json"))))
if len(all_embeddings) >= subset_size:
break
all_embeddings = np.array(all_embeddings)[:subset_size]
all_texts = all_texts[:subset_size]
return all_embeddings, all_texts

def make_clusters(all_embeddings, first_pc, last_pc, cluster_method, k):
    """
    Project the embeddings onto principal components first_pc..last_pc, then
    cluster the projected points into k clusters with a GMM or k-means.
    """
    print(f"finished loading {len(all_embeddings)} embeddings")
    # reduce dimensionality with PCA and keep only the requested components
    pca = sklearn.decomposition.PCA(n_components=1 + last_pc)
    all_embs = pca.fit_transform(all_embeddings)
    all_embs = all_embs[:, first_pc : last_pc + 1]
    print("finished PCA")
    # define the clustering model (Gaussian mixture or k-means)
if cluster_method == "gmm":
cluster = sklearn.mixture.GaussianMixture(
n_components=k, covariance_type="full"
)
elif cluster_method == "kmeans":
cluster = KMeans(n_clusters=k)
cluster.fit(all_embs)
if cluster_method == "gmm":
centers = cluster.means_
elif cluster_method == "kmeans":
centers = cluster.cluster_centers_
print("finished clustering")
    # hard cluster assignment for every embedding
    cluster_idxes = cluster.predict(all_embs)
    print("finished predicting cluster assignments")
    # pairwise L2 distances between the cluster centers
    center_pairwise_distances = euclidean_distances(centers, centers)
return cluster_idxes, center_pairwise_distances

def save_results(save_dir, cluster_idxes, all_texts, center_pairwise_distances):
    """
    Save the clustering results: clusters.json maps each cluster index to the
    texts assigned to it, and l2_distance.json stores the pairwise L2
    distances between cluster centers.
    """
# saving the results
if not os.path.exists(save_dir):
os.makedirs(save_dir)
clusters = defaultdict(list)
for cluster, text in zip(cluster_idxes, all_texts):
clusters[int(cluster)].append(text)
    with open(os.path.join(save_dir, "clusters.json"), "w") as f:
        json.dump(clusters, f)
    l2_distances = dict(enumerate(map(list, center_pairwise_distances.astype(float))))
    with open(os.path.join(save_dir, "l2_distance.json"), "w") as f:
        json.dump(l2_distances, f)
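
# Example of what the two output files look like (hypothetical contents):
#   clusters.json:    {"0": ["text a", "text b", ...], "1": [...], ...}
#   l2_distance.json: {"0": [0.0, 3.2, ...], "1": [3.2, 0.0, ...], ...}
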
def main():
parser = ArgumentParser()
parser.add_argument("--make_all", action="store_true")
parser.add_argument("--dataset", type=str)
parser.add_argument("--first_pc", type=int, default=1)
parser.add_argument("--last_pc", type=int, default=30)
parser.add_argument("--subset_size", type=int, default=100000)
parser.add_argument("--sqrt_size", action="store_true")
parser.add_argument("--k", type=int, default=128)
parser.add_argument("--cluster_method", type=str, default="kmeans")
args = parser.parse_args()
make_all = args.make_all
dataset = args.dataset
first_pc = args.first_pc
last_pc = args.last_pc
subset_size = args.subset_size
sqrt_size = args.sqrt_size
k = args.k
cluster_method = args.cluster_method
    if make_all:
        # glob returns paths like results/<name>_embeddings; recover the dataset names
        paths = glob.glob("results/*_embeddings")
        datasets = [os.path.basename(p)[: -len("_embeddings")] for p in paths]
    else:
        datasets = [dataset]
for dataset in datasets:
embed_dir = f"results/{dataset}_embeddings"
all_embeddings, all_texts = load_data(embed_dir, subset_size)
        if sqrt_size:
            # heuristic: scale the number of clusters with the square root of the sample size
            k = int(np.sqrt(len(all_embeddings)) / 2)
            print(f"using sqrt size for dataset {dataset}, k={k}")
cluster_idxes, center_pairwise_distances = make_clusters(
all_embeddings, first_pc, last_pc, cluster_method, k
)
        time = datetime.now().strftime("%Y%m%d_%H%M%S")
if sqrt_size:
save_dir = f"results/{dataset}_{time}_clusters_sqrtsize"
else:
save_dir = f"results/{dataset}_{time}_clusters_{k}"
save_results(
save_dir,
cluster_idxes,
all_texts,
center_pairwise_distances,
)

if __name__ == "__main__":
    main()
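
# Example invocations (hypothetical dataset name; the script expects embedding
# shards under results/<dataset>_embeddings/):
#   python create_cluster.py --dataset yelp --k 128 --cluster_method kmeans
#   python create_cluster.py --make_all --sqrt_size --cluster_method gmm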