-
Notifications
You must be signed in to change notification settings - Fork 364
Expand file tree
/
Copy pathdocument.py
More file actions
442 lines (367 loc) · 21.2 KB
/
document.py
File metadata and controls
442 lines (367 loc) · 21.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
import os
from typing import Callable, Optional, Dict, Union, List, Type, Set
from functools import cached_property
from pydantic import BaseModel
import lazyllm
from lazyllm import ModuleBase, ServerModule, DynamicDescriptor, deprecated, OnlineChatModule, TrainableModule
from lazyllm.module import LLMBase
from lazyllm.launcher import LazyLLMLaunchersBase as Launcher
from lazyllm.tools.sql.sql_manager import SqlManager, DBStatus
from lazyllm.common.bind import _MetaBind
from .doc_manager import DocManager
from .doc_impl import DocImpl, StorePlaceholder, EmbedPlaceholder, BuiltinGroups, DocumentProcessor, NodeGroupType
from .doc_node import DocNode
from .doc_to_db import DocInfoSchema, DocToDbProcessor, extract_db_schema_from_files, SchemaExtractor
from .store import LAZY_ROOT_NAME, EMBED_DEFAULT_KEY
from .store.store_base import DEFAULT_KB_ID
from .index_base import IndexBase
from .utils import DocListManager, ensure_call_endpoint
from .global_metadata import GlobalMetadataDesc as DocField
from .web import DocWebModule
import copy
import functools
import weakref
class CallableDict(dict):
    """A dict of callables that is itself callable.

    ``d(key, *args, **kw)`` looks up the callable stored under ``key`` and
    invokes it with the remaining arguments, returning its result.
    """

    def __call__(self, cls, *args, **kw):
        target = self[cls]
        return target(*args, **kw)
class _MetaDocument(_MetaBind):
    """Metaclass for ``Document``.

    Makes ``isinstance(x, Document)`` also accept ``UrlDocument`` proxies,
    so remote documents are interchangeable with local ones for type checks.
    """

    def __instancecheck__(self, __instance):
        return isinstance(__instance, UrlDocument) or super().__instancecheck__(__instance)
class Document(ModuleBase, BuiltinGroups, metaclass=_MetaDocument):
    """User-facing knowledge-base module for RAG.

    A ``Document`` binds one knowledge-base group (``self._curr_group``) of a
    shared ``Document._Manager``; several ``Document`` instances may share one
    manager (and hence one dataset / embedding setup) via ``manager=``.
    Constructing with a ``url=`` keyword returns a ``UrlDocument`` proxy
    instead (see ``__new__``); the ``_MetaDocument`` metaclass makes such
    proxies pass ``isinstance(x, Document)``.
    """

    class _Manager(ModuleBase):
        """Holds per-dataset shared state: embeddings, the ``DocListManager``
        and one ``DocImpl`` per knowledge-base group (kept in ``_kbs``)."""

        def __init__(self, dataset_path: Optional[str], embed: Optional[Union[Callable, Dict[str, Callable]]] = None,
                     manager: Union[bool, str] = False, server: Union[bool, int] = False, name: Optional[str] = None,
                     launcher: Optional[Launcher] = None, store_conf: Optional[Dict] = None,
                     doc_fields: Optional[Dict[str, DocField]] = None, cloud: bool = False,
                     doc_files: Optional[List[str]] = None, processor: Optional[DocumentProcessor] = None,
                     display_name: Optional[str] = '', description: Optional[str] = 'algorithm description',
                     schema_extractor: Optional[Union[LLMBase, SchemaExtractor]] = None):
            super().__init__()
            self._origin_path, self._doc_files, self._cloud = dataset_path, doc_files, cloud
            # Resolve a non-existing path first against the configured
            # data_path, then (if it exists nowhere) against the cwd.
            if dataset_path and not os.path.exists(dataset_path):
                # NOTE(review): 'defatult_path' is a typo for 'default_path' (local name only).
                defatult_path = os.path.join(lazyllm.config['data_path'], dataset_path)
                if os.path.exists(defatult_path):
                    dataset_path = defatult_path
            elif dataset_path:
                dataset_path = os.path.join(os.getcwd(), dataset_path)
            self._launcher: Launcher = launcher if launcher else lazyllm.launchers.remote(sync=False)
            self._dataset_path = dataset_path
            self._embed = self._get_embeds(embed)
            self._processor = processor
            self._schema_extractor = self._register_submodules(schema_extractor)
            name = name or DocListManager.DEFAULT_GROUP_NAME
            if not display_name: display_name = name
            # No local file-list manager in cloud mode or for temp-file documents;
            # path monitoring is disabled when an external manager is used.
            self._dlm = None if (self._cloud or self._doc_files is not None) else DocListManager(
                dataset_path, name, enable_path_monitoring=False if manager else True)
            self._kbs = CallableDict({name: DocImpl(
                embed=self._embed, dlm=self._dlm, doc_files=doc_files, global_metadata_desc=doc_fields,
                store=store_conf, processor=processor, algo_name=name, display_name=display_name,
                description=description, schema_extractor=schema_extractor)})
            # Optional document-manager service and (for manager == 'ui') a web front-end.
            if manager: self._manager = ServerModule(DocManager(self._dlm), launcher=self._launcher)
            if manager == 'ui': self._docweb = DocWebModule(doc_server=self._manager)
            # Optionally expose the kb dict itself as a server (int selects the port).
            if server: self._kbs = ServerModule(self._kbs, port=(None if isinstance(server, bool) else int(server)))
            self._global_metadata_desc = doc_fields

        @property
        def url(self):
            # URL of the DocManager service, or None when no manager was started.
            if hasattr(self, '_manager'): return self._manager._url
            return None

        @property
        @deprecated('Document.manager.url')
        def _url(self):
            return self.url

        @property
        def web_url(self):
            # URL of the web UI, or None when manager != 'ui'.
            if hasattr(self, '_docweb'): return self._docweb.url
            return None

        def _get_embeds(self, embed):
            # Normalize a single embed callable into a one-entry dict keyed by
            # the default embed key, then register any ModuleBase submodules.
            embeds = embed if isinstance(embed, dict) else {EMBED_DEFAULT_KEY: embed} if embed else {}
            return self._register_submodules(embeds)

        def _register_submodules(self, m):
            """Register every ModuleBase found in *m* (dict values, list/tuple
            items, or *m* itself) as a submodule; returns *m* unchanged."""
            if not m: return m
            for embed in (m.values() if isinstance(m, dict) else m if isinstance(m, (tuple, list)) else [m]):
                if isinstance(embed, ModuleBase): self._submodules.append(embed)
            return m

        def add_kb_group(self, name, doc_fields: Optional[Dict[str, DocField]] = None, store_conf: Optional[Dict] = None,
                         embed: Optional[Union[Callable, Dict[str, Callable]]] = None,
                         schema_extractor: Optional[Union[LLMBase, SchemaExtractor]] = None):
            """Create a new knowledge-base group sharing this manager's dlm.

            Falls back to the manager-level embed / schema_extractor when the
            corresponding argument is not given.
            """
            embed = self._get_embeds(embed) if embed else self._embed
            schema_extractor = self._register_submodules(schema_extractor) or self._schema_extractor
            impl = DocImpl(dlm=self._dlm, embed=embed, kb_group_name=name, global_metadata_desc=doc_fields,
                           store=store_conf, schema_extractor=schema_extractor)
            # When _kbs is wrapped in a ServerModule, reach through to the raw dict.
            (self._kbs._impl._m if isinstance(self._kbs, ServerModule) else self._kbs)[name] = impl
            self._dlm.add_kb_group(name=name)

        def get_doc_by_kb_group(self, name):
            # Return the DocImpl for *name*, unwrapping the server if present.
            return self._kbs._impl._m[name] if isinstance(self._kbs, ServerModule) else self._kbs[name]

        def stop(self):
            # Shut down the web UI (if any) and release launcher resources.
            if hasattr(self, '_docweb'):
                self._docweb.stop()
            self._launcher.cleanup()

        def __call__(self, *args, **kw):
            # Dispatch to the kb dict (CallableDict) or its server wrapper.
            return self._kbs(*args, **kw)

    def __new__(cls, *args, **kw):
        # `url=` short-circuits construction: return a remote proxy instead of
        # a local Document (only `name` may accompany `url`).
        if url := kw.pop('url', None):
            name = kw.pop('name', None)
            assert not args and not kw, 'Only `name` is supported with `url`'
            return UrlDocument(url, name)
        else:
            return super().__new__(cls)

    def __init__(self, dataset_path: Optional[str] = None, embed: Optional[Union[Callable, Dict[str, Callable]]] = None,
                 create_ui: bool = False, manager: Union[bool, str, 'Document._Manager', DocumentProcessor] = False,
                 server: Union[bool, int] = False, name: Optional[str] = None, launcher: Optional[Launcher] = None,
                 doc_files: Optional[List[str]] = None, doc_fields: Dict[str, DocField] = None,
                 store_conf: Optional[Dict] = None, display_name: Optional[str] = '',
                 description: Optional[str] = 'algorithm description',
                 schema_extractor: Optional[Union[LLMBase, SchemaExtractor]] = None):
        super().__init__()
        if create_ui:
            lazyllm.LOG.warning('`create_ui` for Document is deprecated, use `manager` instead')
            manager = create_ui
        if isinstance(dataset_path, (tuple, list)):
            # NOTE(review): assigning a list of paths to `doc_fields` looks like
            # a typo for `doc_files` — confirm intended behavior.
            doc_fields = dataset_path
            dataset_path = None
        if doc_files is not None:
            # Temp-file mode: in-memory only, no dataset dir, no manager.
            assert dataset_path is None and not manager, (
                'Manager and dataset_path are not supported for Document with temp-files')
            assert store_conf is None or store_conf['type'] == 'map', (
                'Only map store is supported for Document with temp-files')
        name = name or DocListManager.DEFAULT_GROUP_NAME
        if isinstance(manager, Document._Manager):
            # Share an existing manager: only add a new kb group to it.
            # NOTE(review): 'infomation' typos below are in runtime assert
            # messages, left untouched here.
            assert not server, 'Server infomation is already set to by manager'
            assert not launcher, 'Launcher infomation is already set to by manager'
            assert not manager._cloud, 'manager is not allowed to share in cloud mode'
            assert manager._doc_files is None, 'manager is not allowed to share with temp files'
            if dataset_path != manager._dataset_path and dataset_path != manager._origin_path:
                raise RuntimeError(f'Document path mismatch, expected `{manager._dataset_path}`'
                                   f'while received `{dataset_path}`')
            manager.add_kb_group(name=name, doc_fields=doc_fields, store_conf=store_conf, embed=embed,
                                 schema_extractor=schema_extractor)
            self._manager = manager
            self._curr_group = name
        else:
            if isinstance(manager, DocumentProcessor):
                # Cloud mode: documents are managed by a remote processor service.
                processor, cloud = manager, True
                processor.start()
                manager = False
                assert name, '`Name` of Document is necessary when using cloud service'
                # NOTE(review): store_conf.get raises AttributeError when
                # store_conf is None — confirm store_conf is required here.
                assert store_conf.get('type') != 'map', 'Cloud manager is not supported when using map store'
                assert not dataset_path, 'Cloud manager is not supported with local dataset path'
            else:
                cloud, processor = False, None
            self._manager = Document._Manager(dataset_path, embed, manager, server, name, launcher, store_conf,
                                              doc_fields, cloud=cloud, doc_files=doc_files, processor=processor,
                                              display_name=display_name, description=description,
                                              schema_extractor=schema_extractor)
            self._curr_group = name
        # Lazily-initialized helpers for doc-to-db extraction / graph usage.
        self._doc_to_db_processor: DocToDbProcessor = None
        self._graph_document: weakref.ref = None

    @staticmethod
    def list_all_files_in_directory(
        dataset_path: str, skip_hidden_path: bool = True, recursive: bool = True
    ) -> List[str]:
        """List file paths under *dataset_path*.

        Returns [] for a non-existent path; a single-element list when the
        path is itself a file. Hidden files/dirs (dot-prefixed) are skipped
        when *skip_hidden_path* is True; *recursive* controls os.walk vs a
        flat listdir.
        """
        files_list = []
        if not os.path.exists(dataset_path):
            return files_list
        if not os.path.isdir(dataset_path):
            return [dataset_path] if os.path.isfile(dataset_path) else files_list
        if recursive:
            for root, dirs, files in os.walk(os.path.abspath(dataset_path)):
                # Skip hidden directories
                if skip_hidden_path:
                    path_parts = root.split(os.sep)
                    if any(part.startswith('.') for part in path_parts if part):
                        continue
                    # Filter out hidden directories
                    dirs[:] = [d for d in dirs if not d.startswith('.')]
                # Skip hidden files
                if skip_hidden_path:
                    files = [file_path for file_path in files if not file_path.startswith('.')]
                files = [os.path.join(root, file_path) for file_path in files]
                files_list.extend(files)
        else:
            items = os.listdir(dataset_path)
            for item in items:
                item_path = os.path.join(dataset_path, item)
                # Skip hidden files/directories
                if skip_hidden_path and item.startswith('.'):
                    continue
                # Only add files, not directories
                if os.path.isfile(item_path):
                    files_list.append(item_path)
        return files_list

    def _list_all_files_in_dataset(self, skip_hidden_path: bool = True) -> List[str]:
        # Convenience wrapper over the static lister for this manager's dataset.
        return self.list_all_files_in_directory(self._manager._dataset_path, skip_hidden_path)

    @property
    def url(self):
        # Only meaningful when the kb dict was wrapped in a ServerModule.
        assert isinstance(self._manager._kbs, ServerModule), 'Document is not a service, please set `manager` to `True`'
        return self._manager._kbs._url

    def connect_sql_manager(
        self,
        sql_manager: SqlManager,
        schma: Optional[DocInfoSchema] = None,
        force_refresh: bool = True,
    ):
        """Attach (or re-attach) a SQL backend for doc-to-db extraction.

        NOTE(review): `schma` is a typo for `schema`, but it is part of the
        public signature and must stay. Changing the schema requires
        force_refresh=True and clears the existing elements table.
        """
        def format_schema_to_dict(schema: DocInfoSchema):
            # Split a schema (list of {key, desc, type} dicts) into two
            # key->desc and key->type dicts; (None, None) for no schema.
            if schema is None:
                return None, None
            desc_dict = {ele['key']: ele['desc'] for ele in schema}
            type_dict = {ele['key']: ele['type'] for ele in schema}
            return desc_dict, type_dict

        def compare_schema(old_schema: DocInfoSchema, new_schema: DocInfoSchema):
            # Order-insensitive equality of two schemas.
            old_desc_dict, old_type_dict = format_schema_to_dict(old_schema)
            new_desc_dict, new_type_dict = format_schema_to_dict(new_schema)
            return old_desc_dict == new_desc_dict and old_type_dict == new_type_dict

        # 1. Check valid arguments
        if sql_manager.check_connection().status != DBStatus.SUCCESS:
            raise RuntimeError(f'Failed to connect to sql manager: {sql_manager._gen_conn_url()}')
        pre_doc_table_schema = None
        if self._doc_to_db_processor:
            pre_doc_table_schema = self._doc_to_db_processor.doc_info_schema
        assert pre_doc_table_schema or schma, 'doc_table_schma must be given'
        schema_equal = compare_schema(pre_doc_table_schema, schma)
        assert (
            schema_equal or force_refresh is True
        ), 'When changing doc_table_schema, force_refresh should be set to True'
        # 2. Init handler if needed
        need_init_processor = False
        if self._doc_to_db_processor is None:
            need_init_processor = True
        else:
            # avoid reinit for the same db
            if sql_manager != self._doc_to_db_processor.sql_manager:
                need_init_processor = True
        if need_init_processor:
            self._doc_to_db_processor = DocToDbProcessor(sql_manager)
        # 3. Reset doc_table_schema if needed
        if schma and not schema_equal:
            # This api call will clear existing db table 'lazyllm_doc_elements'
            self._doc_to_db_processor._reset_doc_info_schema(schma)

    def get_sql_manager(self):
        """Return the SqlManager set via connect_sql_manager; raises if unset."""
        if self._doc_to_db_processor is None:
            raise ValueError('Please call connect_sql_manager to init handler first')
        return self._doc_to_db_processor.sql_manager

    def extract_db_schema(
        self, llm: Union[OnlineChatModule, TrainableModule], print_schema: bool = False
    ) -> DocInfoSchema:
        """Use *llm* to infer a doc-info schema from every file in the dataset."""
        file_paths = self._list_all_files_in_dataset()
        schema = extract_db_schema_from_files(file_paths, llm)
        if print_schema:
            lazyllm.LOG.info(f'Extracted Schema:\n\t{schema}\n')
        return schema

    def update_database(self, llm: Union[OnlineChatModule, TrainableModule]):
        """Extract structured info from all dataset files with *llm* and export
        it to the connected database (connect_sql_manager must run first)."""
        assert self._doc_to_db_processor, 'Please call connect_db to init handler first'
        file_paths = self._list_all_files_in_dataset()
        info_dicts = self._doc_to_db_processor.extract_info_from_docs(llm, file_paths)
        self._doc_to_db_processor.export_info_to_db(info_dicts)

    @deprecated('Document(dataset_path, manager=doc.manager, name=xx, doc_fields=xx, store_conf=xx)')
    def create_kb_group(self, name: str, doc_fields: Optional[Dict[str, DocField]] = None,
                        store_conf: Optional[Dict] = None) -> 'Document':
        # Deprecated: returns a shallow copy of this Document bound to the new group.
        self._manager.add_kb_group(name=name, doc_fields=doc_fields, store_conf=store_conf)
        doc = copy.copy(self)
        doc._curr_group = name
        return doc

    @property
    @deprecated('Document._manager')
    def _impls(self): return self._manager

    @property
    def _impl(self) -> DocImpl: return self._manager.get_doc_by_kb_group(self._curr_group)

    @property
    def manager(self): return self._manager._processor or self._manager

    def activate_group(self, group_name: str, embed_keys: Optional[Union[str, List[str]]] = None,
                       enable_embed: bool = True):
        """Activate one node group, optionally restricting which embed keys it uses."""
        if embed_keys and not enable_embed:
            raise ValueError('`enable_embed` must be set to True when `embed_keys` is provided')
        # if embed_keys is None, use default embed keys
        if (enable_embed and not embed_keys) and self._manager._embed:
            embed_keys = self._manager._embed.keys()
        if isinstance(embed_keys, str): embed_keys = [embed_keys]
        self._impl.activate_group(group_name, embed_keys, enable_embed)

    def activate_groups(self, groups: Union[str, List[str]], **kwargs):
        """Activate several node groups with the same options."""
        if isinstance(groups, str): groups = [groups]
        for group in groups:
            self.activate_group(group, **kwargs)

    @DynamicDescriptor
    def create_node_group(self, name: str = None, *, transform: Callable, parent: str = LAZY_ROOT_NAME,
                          trans_node: bool = None, num_workers: int = 0, display_name: str = None,
                          ref: str = None, group_type: NodeGroupType = NodeGroupType.CHUNK, **kwargs) -> None:
        """Register a node group.

        Via DynamicDescriptor this works both as a classmethod (global group
        for every Document) and as an instance method (this kb group only);
        `self` is the class itself in the former case.
        """
        assert ref is None or parent != ref, 'parent and ref must be different'
        if isinstance(self, type):
            DocImpl.create_global_node_group(name, transform=transform, parent=parent, trans_node=trans_node,
                                             num_workers=num_workers, display_name=display_name,
                                             group_type=group_type, ref=ref, **kwargs)
        else:
            self._impl.create_node_group(name, transform=transform, parent=parent, trans_node=trans_node,
                                         num_workers=num_workers, display_name=display_name, group_type=group_type,
                                         ref=ref, **kwargs)

    @DynamicDescriptor
    def add_reader(self, pattern: str, func: Optional[Callable] = None):
        """Register a file reader for *pattern*, globally (class call, usable
        as a decorator) or for this instance only."""
        if isinstance(self, type):
            return DocImpl.register_global_reader(pattern=pattern, func=func)
        else:
            self._impl.add_reader(pattern, func)

    @classmethod
    def register_global_reader(cls, pattern: str, func: Optional[Callable] = None):
        # Alias of the class-level add_reader.
        return cls.add_reader(pattern, func)

    def get_store(self):
        # Placeholder resolved later by the framework (bind-time substitution).
        return StorePlaceholder()

    def get_embed(self):
        # Placeholder resolved later by the framework (bind-time substitution).
        return EmbedPlaceholder()

    def register_index(self, index_type: str, index_cls: IndexBase, *args, **kwargs) -> None:
        """Register a custom retrieval index on the current kb group's DocImpl."""
        self._impl.register_index(index_type, index_cls, *args, **kwargs)

    def _forward(self, func_name: str, *args, **kw):
        # Route a call through the manager as (group, function, *args).
        return self._manager(self._curr_group, func_name, *args, **kw)

    def find_parent(self, target) -> Callable:
        # Returns a callable suited for pipeline binding, not the result itself.
        return functools.partial(self._forward, 'find_parent', group=target)

    def find_children(self, target) -> Callable:
        return functools.partial(self._forward, 'find_children', group=target)

    def find(self, target) -> Callable:
        return functools.partial(self._forward, 'find', group=target)

    def forward(self, *args, **kw) -> List[DocNode]:
        # Main entry point: retrieval over the current kb group.
        return self._forward('retrieve', *args, **kw)

    def clear_cache(self, group_names: Optional[List[str]] = None) -> None:
        return self._forward('clear_cache', group_names)

    def drop_algorithm(self):
        return self._forward('drop_algorithm')

    def analyze_schema_by_llm(self, kb_id: Optional[str] = None, doc_ids: Optional[List[str]] = None):
        return self._forward('_analyze_schema_by_llm', kb_id, doc_ids)

    def register_schema_set(self, schema_set: Type[BaseModel], kb_id: Optional[str] = DEFAULT_KB_ID,
                            force_refresh: bool = False) -> str:
        return self._forward('_register_schema_set', schema_set, kb_id, force_refresh)

    def get_nodes(self, uids: Optional[List[str]] = None, doc_ids: Optional[Set] = None,
                  group: Optional[str] = None, kb_id: Optional[str] = None, numbers: Optional[Set] = None
                  ) -> List[DocNode]:
        return self._forward('_get_nodes', uids, doc_ids, group, kb_id, numbers)

    def get_window_nodes(self, node: DocNode, span: tuple[int, int] = (-5, 5),
                         merge: bool = False) -> Union[List[DocNode], DocNode]:
        # *span* is a (before, after) offset window around *node*.
        return self._forward('_get_window_nodes', node, span, merge)

    def _get_post_process_tasks(self):
        # Triggers lazy initialization of the kb group after deployment.
        return lazyllm.pipeline(lambda *a: self._forward('_lazy_init'))

    def __repr__(self):
        return lazyllm.make_repr('Module', 'Document', manager=hasattr(self._manager, '_manager'),
                                 server=isinstance(self._manager._kbs, ServerModule))
class UrlDocument(ModuleBase):
    """Client-side proxy for a remote ``Document`` service.

    Every operation is forwarded to the server at *url*; only a subset of the
    ``Document`` API is available remotely, and accessing an unavailable
    attribute raises a descriptive ``RuntimeError``.
    """

    def __init__(self, url: str, name: str = None):
        super().__init__()
        # Attributes that exist on Document but not here — used by __getattr__
        # to produce a helpful error instead of a silent miss.
        self._missing_keys = set(dir(Document)).difference(dir(UrlDocument))
        self._manager = lazyllm.UrlModule(url=ensure_call_endpoint(url))
        self._curr_group = name if name else DocListManager.DEFAULT_GROUP_NAME

    def _forward(self, func_name: str, *args, **kwargs):
        # Remote calls are routed as (group, function, *args).
        packed = (self._curr_group, func_name) + args
        return self._manager._call('__call__', *packed, **kwargs)

    def find(self, target) -> Callable:
        # Returns a bound callable (for pipeline use), not the result itself.
        return functools.partial(self._forward, 'find', group=target)

    def forward(self, *args, **kw):
        # Main entry point: remote retrieval over the current kb group.
        return self._forward('retrieve', *args, **kw)

    def get_nodes(self, uids: Optional[List[str]] = None, doc_ids: Optional[Set] = None,
                  group: Optional[str] = None, kb_id: Optional[str] = None, numbers: Optional[Set] = None
                  ) -> List[DocNode]:
        return self._forward('_get_nodes', uids, doc_ids, group, kb_id, numbers)

    def get_window_nodes(self, node: DocNode, span: tuple[int, int] = (-5, 5),
                         merge: bool = False) -> Union[List[DocNode], DocNode]:
        # *span* is a (before, after) offset window around *node*.
        return self._forward('_get_window_nodes', node, span, merge)

    @cached_property
    def active_node_groups(self):
        # Fetched once from the server, then cached on the instance.
        return self._forward('active_node_groups')

    def __getattr__(self, name):
        # Attributes known to exist on Document but unsupported remotely get a
        # descriptive error; anything else yields None (as in the original).
        if name not in self._missing_keys:
            return None
        raise RuntimeError(f'Document generated with url and name has no attribute `{name}`')