Commit 8771c29

feat(test_lda): add keywords (#392)
1 parent ce0d486 commit 8771c29


2 files changed: +294 -0 lines changed


tests/test_lda/step0_preprocess.py

Lines changed: 160 additions & 0 deletions
@@ -0,0 +1,160 @@
import jieba
import jieba.posseg as jp
import pdb
import json
import os
import re
from multiprocessing import Process, cpu_count
# https://blog.csdn.net/xyisv/article/details/104482818
import hashlib
import time

# tokens to drop: long hex strings (image names) and "x.y" / "x.y.z" section numbers
image_name = re.compile(r'[0-9a-f]{18,64}')
chapter2 = re.compile(r'[0-9]{1}\.[0-9]{1}')
chapter3 = re.compile(r'[0-9]{1}\.[0-9]{1}\.[0-9]{1}')

def load_stopwords():
    sw = []
    with open('cn_en_stopwords.txt') as f:
        for line in f:
            if len(line.strip()) > 0:
                sw.append(line.strip())
    return sw

def load_documents(n: int = 1):
    # walk the corpus directory and split the file list into n roughly equal shards
    basedir = '/home/data/khj/workspace/huixiangdou/repodir.lda'

    docs = []
    for root, _, files in os.walk(basedir):
        for file in files:
            if file.endswith('.jpg') or file.endswith('.png') or file.endswith('.jpeg'):
                # images are not expected in the text corpus; break into the debugger
                pdb.set_trace()
            else:
                docs.append((file, os.path.join(root, file)))

    length = len(docs)
    step = length // n
    remainder = length % n

    result = []
    start = 0
    for i in range(n):
        end = start + step + (1 if i < remainder else 0)
        result.append(docs[start:end])
        start = end

    return result

def load_newwords():
    # user-dictionary entries for jieba, stored as JSON lists under ./newwords
    words = []
    basename = './newwords'
    files = os.listdir(basename)
    for filename in files:
        filepath = os.path.join(basename, filename)
        with open(filepath, encoding='utf8') as f:
            words += json.load(f)
            print('load {}'.format(filepath))
    return words

def content_hash(input_str: str):
    # create a new sha256 hash object
    hash_object = hashlib.sha256()
    # update the hash object with the encoded (bytes) input string
    hash_object.update(input_str.encode())
    # get the hexadecimal digest
    hex_dig = hash_object.hexdigest()
    # return the first 6 characters
    return hex_dig[:6]

def process_data(documents: list, pid: int):
    # add new words to the jieba user dictionary
    t0 = time.time()
    new_words = load_newwords()
    for w in new_words:
        jieba.add_word(w, tag='n')

    stop_words = load_stopwords()
    print('{} start..'.format(pid))
    bad_patterns = [image_name, chapter2, chapter3]

    for filename, filepath in documents:
        d = ''
        with open(filepath) as f:
            d = f.read()
        # keep only the leading 80% of the content
        head_length = int(len(d) * 0.8)
        d = d[0:head_length]

        cuts = [w.word for w in jp.cut(d)]

        filtered = []
        for c in cuts:
            c = c.strip()
            if c in stop_words:
                continue

            if 'images' == c:
                continue

            skip = False
            for bad_pattern in bad_patterns:
                if bad_pattern.match(c):
                    skip = True
                    break
            if skip:
                continue

            filtered.append(c)

        if len(filtered) < 1:
            continue
        new_content = ' '.join(filtered)

        if len(new_content) < 300:
            continue
        dirname = os.path.join('preprocess', str(pid))
        if not os.path.exists(dirname):
            os.makedirs(dirname)

        hashname = content_hash(new_content)
        outfilepath = os.path.join(dirname, hashname + '.md')

        # record which original file this hash came from
        with open('name_map.txt', 'a') as f:
            f.write('{}\t {}'.format(hashname, filepath))
            f.write('\n')

        with open(outfilepath, 'w') as f:
            f.write(new_content)
            f.flush()
    print('{} finish, timecost {}'.format(pid, time.time() - t0))

def _get_num_processes():
    num_processes = cpu_count() - 1  # Good habit to leave 1 core.
    return num_processes

def main():
    debug_mode = False

    processes = []
    split_documents = load_documents(n=_get_num_processes())
    for process_id, documents in enumerate(split_documents):
        print(f'Distributing to process[{process_id}]...')

        if debug_mode:
            process_data(documents, process_id)
        else:
            # spawn a worker process for this shard of documents
            process = Process(
                target=process_data,
                args=(
                    documents,
                    process_id,
                ),
            )
            process.start()
            print(f'Distributed to process[{process_id}].')
            processes.append(process)

    for process in processes:
        process.join()

if __name__ == '__main__':
    main()
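The per-document loop above reduces to one step: segment the text with jieba.posseg, drop stop words and tokens that look like image hashes or chapter numbers, then join the remaining tokens with spaces. A minimal sketch of that step in isolation (the sample text and the helper name `clean` are illustrative, not part of the commit):

import re
import jieba.posseg as jp

image_name = re.compile(r'[0-9a-f]{18,64}')   # long hex strings, e.g. image file hashes
chapter2 = re.compile(r'[0-9]{1}\.[0-9]{1}')  # "3.2"-style section numbers

def clean(text: str, stop_words: set) -> str:
    # segment, strip, and filter tokens the same way process_data() does
    tokens = [w.word.strip() for w in jp.cut(text)]
    kept = [t for t in tokens
            if t and t not in stop_words
            and not image_name.match(t)
            and not chapter2.match(t)]
    return ' '.join(kept)

print(clean('如何安装 mmdeploy 3.2 节', {'如何'}))  # hypothetical sample input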

tests/test_lda/step1_countvec.py

Lines changed: 134 additions & 0 deletions
@@ -0,0 +1,134 @@
# Author: Olivier Grisel <[email protected]>
#         Lars Buitinck
#         Chyi-Kwei Yau <[email protected]>
# License: BSD 3 clause

from time import time
import shutil
import matplotlib.pyplot as plt
import pdb
import os
import numpy as np

from sklearn.datasets import fetch_20newsgroups
from sklearn.decomposition import LatentDirichletAllocation
from sklearn.feature_extraction.text import CountVectorizer
import jieba
import jieba.posseg as jp
import json
import re
from multiprocessing import Process, cpu_count
# https://blog.csdn.net/xyisv/article/details/104482818
import pickle as pkl

n_features = 2048
n_components = 100
n_top_words = 100
batch_size = 128

def files():
    # list all preprocessed documents produced by step0_preprocess.py
    basedir = '/home/data/khj/workspace/huixiangdou/lda/preprocess'

    docs = []
    for root, _, files in os.walk(basedir):
        for file in files:
            if file.endswith('.jpg') or file.endswith('.png') or file.endswith('.jpeg'):
                # images are not expected here; break into the debugger
                pdb.set_trace()
            else:
                docs.append((file, os.path.join(root, file)))
    return docs

def filecontents(dirname: str):
    # yield the text of every non-empty preprocessed document
    filepaths = files()
    for _, filepath in filepaths:
        with open(filepath) as f:
            content = f.read()
            if len(content) > 0:
                yield content

def load_namemap():
    # map the 6-character content hash back to the original file path
    namemap = dict()
    with open('name_map.txt') as f:
        for line in f:
            parts = line.split('\t')
            namemap[parts[0].strip()] = parts[1].strip()
    return namemap

# reference step https://blog.csdn.net/xyisv/article/details/104482818
def plot_top_words(model, feature_names, n_top_words, title):
    fig, axes = plt.subplots(2, 5, figsize=(30, 15), sharex=True)
    axes = axes.flatten()
    for topic_idx, topic in enumerate(model.components_):
        top_features_ind = topic.argsort()[-n_top_words:]
        top_features = feature_names[top_features_ind]
        weights = topic[top_features_ind]

        ax = axes[topic_idx]
        ax.barh(top_features, weights, height=0.7)
        ax.set_title(f"Topic {topic_idx + 1}", fontdict={"fontsize": 30})
        ax.tick_params(axis="both", which="major", labelsize=20)
        for i in "top right left".split():
            ax.spines[i].set_visible(False)
    fig.suptitle(title, fontsize=40)

    plt.subplots_adjust(top=0.90, bottom=0.05, wspace=0.90, hspace=0.3)
    plt.savefig('topic_centers.jpg')

def build_topic(dirname: str = 'preprocess'):
    namemap = load_namemap()
    pdb.set_trace()

    tf_vectorizer = CountVectorizer(
        max_df=0.95, min_df=2, max_features=n_features, stop_words="english"
    )

    t0 = time()
    tf = tf_vectorizer.fit_transform(filecontents(dirname))
    print("BoW in %0.3fs." % (time() - t0))

    lda = LatentDirichletAllocation(
        n_components=n_components,
        max_iter=5,
        learning_method="online",
        learning_offset=50.0,
        random_state=0,
    )
    t0 = time()
    doc_types = lda.fit_transform(tf)

    pdb.set_trace()
    print("lda train in %0.3fs." % (time() - t0))
    feature_names = tf_vectorizer.get_feature_names_out()

    # persist both the vectorizer and the trained LDA model
    models = {'CountVectorizer': tf_vectorizer, 'LatentDirichletAllocation': lda}
    with open('lda_models.pkl', 'wb') as model_file:
        pkl.dump(models, model_file)

    # collect the top words of every topic
    top_features_list = []
    for _, topic in enumerate(lda.components_):
        top_features_ind = topic.argsort()[-n_top_words:]
        top_features = feature_names[top_features_ind]
        weights = topic[top_features_ind]
        top_features_list.append(top_features.tolist())

    # make sure the output directory exists before writing the topic descriptions
    if not os.path.exists('cluster'):
        os.makedirs('cluster')
    with open(os.path.join('cluster', 'desc.json'), 'w') as f:
        json_str = json.dumps(top_features_list, ensure_ascii=False)
        f.write(json_str)

    filepaths = files()

    pdb.set_trace()
    # copy each source document into the directory of every topic it scores above 0.1 on
    for file_id, doc_score in enumerate(doc_types):
        basename, input_filepath = filepaths[file_id]
        hashname = basename.split('.')[0]
        source_filepath = namemap[hashname]
        indices_np = np.where(doc_score > 0.1)[0]
        for topic_id in indices_np:
            target_dir = os.path.join('cluster', str(topic_id))
            if not os.path.exists(target_dir):
                os.makedirs(target_dir)
            shutil.copy(source_filepath, target_dir)

if __name__ == '__main__':
    build_topic()
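Because build_topic() pickles both the fitted CountVectorizer and the trained LDA model into lda_models.pkl, they can later be reloaded to score a new document without retraining. A minimal sketch, assuming the input is already space-joined tokens as produced by step0_preprocess.py (the sample text is hypothetical):

import pickle as pkl

# load the vectorizer and LDA model saved by build_topic()
with open('lda_models.pkl', 'rb') as f:
    models = pkl.load(f)
tf_vectorizer = models['CountVectorizer']
lda = models['LatentDirichletAllocation']

new_doc = '安装 编译 mmdeploy 模型 转换'  # hypothetical segmented text
doc_topic = lda.transform(tf_vectorizer.transform([new_doc]))[0]
print(doc_topic.argsort()[::-1][:5])  # indices of the 5 most likely topics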
