-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvector_stores.py
More file actions
187 lines (164 loc) · 7.16 KB
/
Copy pathvector_stores.py
File metadata and controls
187 lines (164 loc) · 7.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
from langchain_chroma import Chroma
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
import config_data as config
import os
class VectorStoreService:
    """Load, split and index documents into a persistent Chroma collection.

    The collection name, persistence directory and text-splitter settings
    are all read from the project's ``config_data`` module.
    """

    def __init__(self, embedding):
        """Create the Chroma store and the text splitter.

        Args:
            embedding: Embedding function object handed to Chroma as
                ``embedding_function``.
        """
        self.embedding = embedding
        self.vector_store = self._build_vector_store()
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=config.chunk_size,
            chunk_overlap=config.chunk_overlap,
            separators=config.separators,
        )

    def _build_vector_store(self):
        """Return a Chroma client bound to the configured collection."""
        return Chroma(
            collection_name=config.collections_name,
            embedding_function=self.embedding,
            persist_directory=config.persist_directory,
        )

    def get_retriever(self):
        """Return a retriever over the vector store.

        NOTE(review): ``config.similarity_threshold`` is passed as the
        retriever's ``k`` search argument; confirm the config value is meant
        to be a result count rather than a score threshold.
        """
        return self.vector_store.as_retriever(
            search_kwargs={"k": config.similarity_threshold}
        )

    def add_documents(self, documents) -> int:
        """Split ``documents`` into chunks and add them to the store.

        Args:
            documents: Sequence of ``Document`` objects; may be empty/None.

        Returns:
            The number of chunks written (0 when ``documents`` is falsy).
        """
        if not documents:
            return 0
        chunks = self.text_splitter.split_documents(documents)
        self.vector_store.add_documents(chunks)
        return len(chunks)

    def add_text(self, text, metadata=None) -> int:
        """Index a raw string as a single document.

        Args:
            text: The content to index.
            metadata: Optional metadata dict attached to the document.

        Returns:
            The number of chunks written.
        """
        return self.add_documents(
            [Document(page_content=text, metadata=metadata or {})]
        )

    def _load_file(self, file_path):
        """Dispatch ``file_path`` to a loader based on its extension.

        Raises:
            ValueError: If the (case-insensitive) extension is unsupported.
        """
        ext = os.path.splitext(file_path)[1].lower()
        if ext == ".pdf":
            return self._load_pdf_langchain(file_path)
        if ext == ".docx":
            return self._load_docx_langchain(file_path)
        if ext in (".xlsx", ".xls"):
            return self._load_excel_langchain(file_path)
        if ext in (".txt", ".md", ".json"):
            return self._load_text_file(file_path)
        raise ValueError(f"不支持的文件格式: {ext}")

    def _load_pdf_langchain(self, file_path):
        """Load a PDF, trying PyPDFLoader, then PDFMinerLoader, then PyMuPDF.

        Any non-import failure of PyPDFLoader also falls through to the
        PyMuPDF fallback.
        """
        try:
            from langchain_community.document_loaders import PyPDFLoader
            return PyPDFLoader(file_path).load()
        except ImportError:
            try:
                from langchain_community.document_loaders import PDFMinerLoader
                return PDFMinerLoader(file_path).load()
            except ImportError:
                return self._load_pdf_fallback(file_path)
        except Exception:
            # Best effort: a parse failure in PyPDFLoader gets one more
            # attempt through PyMuPDF before giving up.
            return self._load_pdf_fallback(file_path)

    def _load_pdf_fallback(self, file_path):
        """Extract PDF text with PyMuPDF (``fitz``) into a single document.

        Raises:
            ValueError: If PyMuPDF is not installed or the file cannot be
                parsed.
        """
        try:
            import fitz

            # The context manager closes the PyMuPDF document even when a
            # page fails to extract.
            with fitz.open(file_path) as pdf:
                text = "\n".join(page.get_text() for page in pdf)
            return [Document(page_content=text.strip(), metadata={"source": file_path})]
        except ImportError as exc:
            raise ValueError(
                "无法解析PDF文件。请安装以下依赖之一:\n"
                "pip install PyPDF2\n"
                "pip install pdfminer.six\n"
                "pip install pymupdf"
            ) from exc
        except Exception as exc:
            raise ValueError(f"PDF解析失败: {str(exc)}") from exc

    def _load_docx_langchain(self, file_path):
        """Load a .docx via Docx2txtLoader, then Unstructured, then python-docx."""
        try:
            from langchain_community.document_loaders import Docx2txtLoader
            return Docx2txtLoader(file_path).load()
        except ImportError:
            try:
                from langchain_community.document_loaders import (
                    UnstructuredWordDocumentLoader,
                )
                return UnstructuredWordDocumentLoader(file_path).load()
            except ImportError:
                return self._load_docx_fallback(file_path)

    def _load_docx_fallback(self, file_path):
        """Extract paragraph text with python-docx into a single document.

        Raises:
            ValueError: If python-docx is not installed or parsing fails.
        """
        try:
            # Alias the python-docx factory so it does not shadow the
            # LangChain ``Document`` class used for the return value.
            from docx import Document as DocxDocument

            word_doc = DocxDocument(file_path)
            text = "\n".join(para.text for para in word_doc.paragraphs)
            return [Document(page_content=text.strip(), metadata={"source": file_path})]
        except ImportError as exc:
            raise ValueError(
                "无法解析Word文件。请安装以下依赖之一:\n"
                "pip install docx2txt\n"
                "pip install python-docx\n"
                "pip install unstructured"
            ) from exc
        except Exception as exc:
            raise ValueError(f"Word文档解析失败: {str(exc)}") from exc

    def _load_excel_langchain(self, file_path):
        """Load a spreadsheet via LangChain loaders, falling back to pandas."""
        try:
            from langchain_community.document_loaders import UnstructuredExcelLoader
            return UnstructuredExcelLoader(file_path).load()
        except ImportError:
            try:
                from langchain_community.document_loaders import ExcelLoader
                return ExcelLoader(file_path).load()
            except ImportError:
                return self._load_excel_fallback(file_path)

    def _load_excel_fallback(self, file_path):
        """Render the spreadsheet with pandas as one plain-text document.

        Uses ``pd.read_excel`` with its default arguments.

        Raises:
            ValueError: If pandas is not installed or reading fails.
        """
        try:
            import pandas as pd

            frame = pd.read_excel(file_path)
            text = frame.to_string(index=False)
            return [Document(page_content=text.strip(), metadata={"source": file_path})]
        except ImportError as exc:
            raise ValueError(
                "无法解析Excel文件。请安装:\n"
                "pip install pandas openpyxl xlrd"
            ) from exc
        except Exception as exc:
            raise ValueError(f"Excel文档解析失败: {str(exc)}") from exc

    def _load_text_file(self, file_path):
        """Load a UTF-8 text file, with a plain ``open`` as the fallback."""
        try:
            from langchain_community.document_loaders import TextLoader
            return TextLoader(file_path, encoding='utf-8').load()
        except Exception:
            # Covers both a missing langchain_community and a loader failure;
            # errors from the direct read below propagate to the caller.
            with open(file_path, 'r', encoding='utf-8') as f:
                text = f.read()
            return [Document(page_content=text, metadata={"source": file_path})]

    def add_file(self, file_path, metadata=None) -> int:
        """Load a file from disk and index its contents.

        Args:
            file_path: Path of the file to ingest.
            metadata: Optional dict merged into every loaded document's
                metadata.

        Returns:
            The number of chunks written, or 0 if ``file_path`` does not
            exist.

        Raises:
            ValueError: If the format is unsupported or parsing fails.
        """
        if not os.path.exists(file_path):
            return 0
        docs = self._load_file(file_path)
        file_name = os.path.basename(file_path)
        for doc in docs:
            if metadata:
                doc.metadata.update(metadata)
            # setdefault keeps any source/file_name supplied by the loader
            # or by the caller's metadata.
            doc.metadata.setdefault('source', file_path)
            doc.metadata.setdefault('file_name', file_name)
        return self.add_documents(docs)

    def get_collection_stats(self):
        """Return the result of the vector store's ``get()`` call."""
        return self.vector_store.get()

    def clear_all_documents(self):
        """Delete the collection and recreate an empty one in its place."""
        self.vector_store.delete_collection()
        self.vector_store = self._build_vector_store()

    @staticmethod
    def get_supported_formats():
        """Return the file extensions (without dots) that can be ingested."""
        return ["txt", "md", "json", "pdf", "docx", "xlsx", "xls"]
if __name__ == '__main__':
    # Manual smoke test: build the service and obtain a retriever.
    from langchain_community.embeddings import DashScopeEmbeddings

    # Credentials must never be committed to source control; read the key
    # from the environment instead. Any key previously hard-coded here
    # should be treated as leaked and revoked.
    api_key = os.environ.get("DASHSCOPE_API_KEY")
    if not api_key:
        raise SystemExit("DASHSCOPE_API_KEY environment variable is not set")

    retriever = VectorStoreService(
        DashScopeEmbeddings(model="text-embedding-v4", dashscope_api_key=api_key)
    ).get_retriever()