import time
import traceback
from typing import Any, Dict, List, Optional
from graphlib import CycleError, TopologicalSorter
from collections import defaultdict, deque
from concurrent.futures import ThreadPoolExecutor
from functools import cached_property
from itertools import repeat

from lazyllm import LOG

from ..data_loaders import DirectoryReader
from ..doc_node import DocNode
from ..global_metadata import RAG_DOC_ID, RAG_DOC_PATH, RAG_KB_ID
from ..store import LAZY_IMAGE_GROUP, LAZY_ROOT_NAME
from ..store.document_store import _DocumentStore
from ..store.store_base import DEFAULT_KB_ID
from ..store.utils import fibonacci_backoff
from ..transform import AdaptiveTransform, make_transform
from ..utils import gen_docid
from ..doc_to_db import SchemaExtractor

class _NodeGroupDependencyGraph:
    '''Dependency graph over the active node groups: `_dep_graph` maps each group to the
    groups it depends on (its parent and optional ref), while `_forward_graph` keeps the
    reverse parent-to-children edges used for path finding.'''

    def __init__(self, node_groups: Dict[str, Dict], active: List[str]):
        self._shortest_path_cache: Dict[tuple[str, str], List[str]] = {}
        self._forward_graph = defaultdict(set)
        self._dep_graph = {node: set() for node in active}
        for group in active:
            cfg = node_groups.get(group)
            if not cfg:
                raise ValueError(f'Node group "{group}" does not exist. Please check the group name '
                                 'or add a new one through `create_node_group`.')
            if parent := cfg['parent']:
                self._forward_graph[parent].add(group)
                self._dep_graph[group].add(parent)
            if ref := cfg.get('ref'):
                self._dep_graph[group].add(ref)

    @cached_property
    def topological_order(self) -> List[str]:
        try:
            return list(TopologicalSorter(self._dep_graph).static_order())
        except CycleError as e:
            raise ValueError(f'Detected a cyclic dependency among node groups: {e}')

    def get_shortest_path(self, start: str, end: str) -> List[str]:
        # NOTE: The path from start to end is guaranteed to exist.
        # The returned list does not contain `start` itself, only intermediate nodes and `end`.
        key = (start, end)
        if key in self._shortest_path_cache:
            return self._shortest_path_cache[key]
        # Breadth-first search over the parent-to-children edges, so the first path
        # reaching `end` is the shortest one.
        queue = deque([(start, [])])
        visited = {start}
        while queue:
            current, path = queue.popleft()
            for neighbor in self._forward_graph.get(current, []):
                if neighbor == end:
                    result = path + [end]
                    self._shortest_path_cache[key] = result
                    return result
                if neighbor not in visited:
                    visited.add(neighbor)
                    queue.append((neighbor, path + [neighbor]))
        raise AssertionError(f'No path found from {start} to {end}, the dependency graph is not valid')
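
# Illustrative sketch (not part of the original module; group names are hypothetical):
# given node_groups
#     {'block':    {'parent': None},
#      'sentence': {'parent': 'block'},
#      'summary':  {'parent': 'sentence', 'ref': 'block'}}
# with all three groups active, `topological_order` places 'block' before 'sentence'
# and 'sentence' before 'summary', while get_shortest_path('block', 'summary') returns
# ['sentence', 'summary']: the intermediate nodes plus the target, excluding 'block'.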

class _Processor:
    def __init__(self, algo_id: str, store: _DocumentStore, reader: DirectoryReader, node_groups: Dict[str, Dict],
                 schema_extractor: Optional[SchemaExtractor] = None, display_name: Optional[str] = None,
                 description: Optional[str] = None, max_workers: int = 4):
        self._algo_id = algo_id
        self._store = store
        self._reader = reader
        self._node_groups = node_groups
        self._schema_extractor = schema_extractor
        self._display_name = display_name
        self._description = description
        self._max_workers = max_workers
        self._thread_pool = ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix=f'{self._algo_id}_processor')
        self._dependency_graph: Optional[_NodeGroupDependencyGraph] = None

    @property
    def store(self) -> _DocumentStore:
        return self._store

    @property
    def reader(self) -> DirectoryReader:
        return self._reader
    def add_doc(self, input_files: List[str], ids: Optional[List[str]] = None,  # noqa: C901
                metadatas: Optional[List[Dict[str, Any]]] = None, kb_id: Optional[str] = None):
        try:
            if not input_files: return
            if not ids: ids = [gen_docid(path) for path in input_files]
            temp_metas = [{RAG_DOC_ID: doc_id, RAG_DOC_PATH: path, RAG_KB_ID: kb_id or DEFAULT_KB_ID}
                          for doc_id, path in zip(ids, input_files)]
            # Caller-supplied metadata is merged last, so it can override the generated entries.
            metadatas = [{**temp, **metadata} for metadata, temp in zip(metadatas or repeat({}), temp_metas)]
            kb_id = metadatas[0].get(RAG_KB_ID, DEFAULT_KB_ID) if kb_id is None else kb_id
            root_nodes = self._reader.load_data(input_files, metadatas, split_nodes_by_type=True)
            schema_futures = []
            schema_errors: List[Exception] = []
            if self._schema_extractor:
                doc_to_root_nodes = defaultdict(list)
                for n in root_nodes[LAZY_ROOT_NAME]:
                    doc_to_root_nodes[n.global_metadata.get(RAG_DOC_ID)].append(n)
                if doc_to_root_nodes:
                    # Run schema extraction per document in the background while node groups are built.
                    for nodes in doc_to_root_nodes.values():
                        schema_futures.append(
                            self._thread_pool.submit(self._schema_extractor, nodes, algo_id=self._algo_id)
                        )
            for k, v in root_nodes.items():
                if not v: continue
                self._store.update_nodes(self._set_nodes_number(v))
                self._create_nodes_recursive(v, k)
            for future in schema_futures:
                try:
                    future.result()
                except Exception as exc:  # pragma: no cover - defensive
                    LOG.error(f'Schema extraction failed: {exc}')
                    schema_errors.append(exc)
            if schema_errors:
                raise schema_errors[0]
            LOG.info('Add documents done!')
        except Exception as e:
            LOG.error(f'Add documents failed: {e}, {traceback.format_exc()}')
            raise
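
    # Usage sketch (hypothetical objects and paths, for illustration only):
    #     processor = _Processor('algo-1', store, reader, node_groups)
    #     processor.add_doc(['/data/report.pdf'])  # ids derived via gen_docid
    #     processor.add_doc(['/data/a.md'], ids=['doc-a'],
    #                       metadatas=[{'author': 'alice'}], kb_id='kb-main')
    # Caller-supplied metadata is merged over the generated RAG_DOC_ID / RAG_DOC_PATH /
    # RAG_KB_ID entries, so explicit keys in `metadatas` win on conflict.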
    def close(self):
        self._thread_pool.shutdown(wait=True)
        self._thread_pool = None

    def _set_nodes_number(self, nodes: List[DocNode]) -> List[DocNode]:
        # Assign each node a 1-based sequence number within its (document, group) pair.
        doc_group_number = {}
        for node in nodes:
            doc_id = node.global_metadata.get(RAG_DOC_ID)
            group_name = node.group
            if doc_id not in doc_group_number:
                doc_group_number[doc_id] = {}
            if group_name not in doc_group_number[doc_id]:
                doc_group_number[doc_id][group_name] = 1
            node.metadata['lazyllm_store_num'] = doc_group_number[doc_id][group_name]
            doc_group_number[doc_id][group_name] += 1
        return nodes
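
    # Example (illustrative): for input nodes belonging to (doc1, block), (doc1, block),
    # (doc1, sentence) and (doc2, block), the assigned 'lazyllm_store_num' values are
    # 1, 2, 1 and 1 respectively: a fresh 1-based counter per (document, group) pair.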
    def _get_dependency_graph(self) -> _NodeGroupDependencyGraph:
        if self._dependency_graph is None:
            self._dependency_graph = _NodeGroupDependencyGraph(self._node_groups, self._store.activated_groups())
        return self._dependency_graph

    def _create_nodes_recursive(self, p_nodes: List[DocNode], p_name: str):
        graph = self._get_dependency_graph()
        for group_name in graph.topological_order:
            group = self._node_groups.get(group_name)
            if group['parent'] == p_name:
                ref_path = graph.get_shortest_path(group['parent'], group.get('ref')) if group.get('ref') else []
                nodes = self._create_nodes_impl(p_nodes, group_name, ref_path=ref_path)
                if nodes: self._create_nodes_recursive(nodes, group_name)
    def _create_nodes_impl(self, p_nodes, group_name, ref_path=None):
        # NOTE: transform.batch_forward sets children on p_nodes, but by the time it is
        # called, p_nodes have already been upserted into the store.
        t = self._node_groups[group_name]['transform']
        transform = AdaptiveTransform(t) if isinstance(t, list) or t.pattern else make_transform(t, group_name)
        nodes = transform.batch_forward(p_nodes, group_name, ref_path=ref_path)
        self._store.update_nodes(self._set_nodes_number(nodes))
        return nodes
    def _get_or_create_nodes(self, group_name, uids: Optional[List[str]] = None):
        nodes = self._store.get_nodes(uids=uids, group=group_name) if self._store.is_group_active(group_name) else []
        if not nodes and group_name not in (LAZY_IMAGE_GROUP, LAZY_ROOT_NAME):
            p_nodes = self._get_or_create_nodes(self._node_groups[group_name]['parent'], uids)
            nodes = self._create_nodes_impl(p_nodes, group_name)
        return nodes
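
    # Illustrative trace (hypothetical group names): if 'summary' has no stored nodes
    # but its parent 'sentence' does, the recursion resolves as
    #     _get_or_create_nodes('summary')
    #         -> _get_or_create_nodes('sentence')        # found in the store
    #         -> _create_nodes_impl(sentence_nodes, 'summary')
    # climbing toward LAZY_ROOT_NAME until materialized nodes are found, then
    # transforming back down one level at a time.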
    def reparse(self, group_name: str, uids: Optional[List[str]] = None, doc_ids: Optional[List[str]] = None,
                kb_id: Optional[str] = None, **kwargs):
        if doc_ids:
            self._reparse_docs(group_name=group_name, doc_ids=doc_ids, kb_id=kb_id, **kwargs)
        else:
            self._get_or_create_nodes(group_name, uids)
    def _reparse_docs(self, group_name: str, doc_ids: List[str], doc_paths: List[str], metadatas: List[Dict],
                      kb_id: Optional[str] = None, **kwargs):
        if not metadatas:
            raise ValueError('metadatas is required for reparse')
        kb_id = metadatas[0].get(RAG_KB_ID, None) if kb_id is None else kb_id
        if group_name == 'all':
            self._store.remove_nodes(doc_ids=doc_ids, kb_id=kb_id)
            # Poll with backoff until the removal is visible before reingesting the documents.
            removed_flag = False
            for wait_time in fibonacci_backoff():
                nodes = self._store.get_nodes(group=LAZY_ROOT_NAME, kb_id=kb_id, doc_ids=doc_ids)
                if not nodes:
                    removed_flag = True
                    break
                time.sleep(wait_time)
            if not removed_flag:
                raise RuntimeError(f'Failed to remove nodes for docs {doc_ids} from store')
            self.add_doc(input_files=doc_paths, ids=doc_ids, metadatas=metadatas, kb_id=kb_id)
            LOG.info(f'Reparse docs {doc_ids} from store done')
        else:
            p_nodes = self._store.get_nodes(group=self._node_groups[group_name]['parent'],
                                            kb_id=kb_id, doc_ids=doc_ids)
            self._reparse_group_recursive(p_nodes=p_nodes, cur_name=group_name,
                                          doc_ids=doc_ids, kb_id=kb_id)
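
    # The remove-then-poll pattern above guards against eventually consistent stores:
    # fibonacci_backoff() (from ..store.utils) is assumed to yield a bounded sequence of
    # increasing wait times, so reingestion starts only after the old nodes are confirmed
    # gone, and the failure branch can still trigger if the store never catches up.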
    def _reparse_group_recursive(self, p_nodes: List[DocNode], cur_name: str, doc_ids: List[str],
                                 kb_id: Optional[str] = None):
        kb_id = p_nodes[0].global_metadata.get(RAG_KB_ID, None) if kb_id is None else kb_id
        self._store.remove_nodes(group=cur_name, kb_id=kb_id, doc_ids=doc_ids)
        removed_flag = False
        for wait_time in fibonacci_backoff():
            nodes = self._store.get_nodes(group=cur_name, kb_id=kb_id, doc_ids=doc_ids)
            if not nodes:
                removed_flag = True
                break
            time.sleep(wait_time)
        if not removed_flag:
            raise RuntimeError(f'Failed to remove nodes for docs {doc_ids} group {cur_name} from store')
        t = self._node_groups[cur_name]['transform']
        transform = AdaptiveTransform(t) if isinstance(t, list) or t.pattern else make_transform(t, cur_name)
        nodes = transform.batch_forward(p_nodes, cur_name)
        # NOTE: reparse needs to set global_metadata on the regenerated nodes.
        self._store.update_nodes(self._set_nodes_number(nodes))
        for group_name in self._store.activated_groups():
            group = self._node_groups.get(group_name)
            if group is None:
                raise ValueError(f'Node group "{group_name}" does not exist. Please check the group name '
                                 'or add a new one through `create_node_group`.')
            if group['parent'] == cur_name:
                self._reparse_group_recursive(p_nodes=nodes, cur_name=group_name, doc_ids=doc_ids, kb_id=kb_id)
    def update_doc_meta(self, doc_id: str, metadata: dict, kb_id: Optional[str] = None):
        try:
            self._store.update_doc_meta(doc_id=doc_id, metadata=metadata, kb_id=kb_id)
        except Exception as e:
            LOG.error(f'Failed to update doc meta: {e}, {traceback.format_exc()}')
            raise

    def delete_doc(self, doc_ids: Optional[List[str]] = None, kb_id: Optional[str] = None) -> None:
        try:
            self._store.remove_nodes(kb_id=kb_id, doc_ids=doc_ids)
            if self._schema_extractor:
                self._schema_extractor._delete_extract_data(algo_id=self._algo_id, kb_id=kb_id, doc_ids=doc_ids)
        except Exception as e:
            LOG.error(f'Failed to delete doc: {e}, {traceback.format_exc()}')
            raise