lazyllm/tools/rag/__init__.py (3 changes: 2 additions & 1 deletion)

@@ -11,7 +11,7 @@
                         CharacterSplitter, RecursiveSplitter, MarkdownSplitter, CodeSplitter,
                         JSONSplitter, YAMLSplitter, HTMLSplitter, XMLSplitter, GeneralCodeSplitter, JSONLSplitter)
 from .similarity import register_similarity
-from .doc_node import DocNode
+from .doc_node import DocNode, RichDocNode
 from .readers import (PDFReader, DocxReader, HWPReader, PPTXReader, ImageReader, IPYNBReader, EpubReader,
                       MarkdownReader, MboxReader, PandasCSVReader, PandasExcelReader, VideoAudioReader,
                       MineruPDFReader)

@@ -46,6 +46,7 @@
     'register_similarity',
     'register_reranker',
     'DocNode',
+    'RichDocNode',
     'PDFReader',
     'DocxReader',
     'HWPReader',
lazyllm/tools/rag/doc_to_db/extractor.py (8 changes: 4 additions & 4 deletions)

@@ -13,7 +13,7 @@

 from lazyllm import LOG, ThreadPoolExecutor, once_wrapper
 from lazyllm.components import JsonFormatter
-from lazyllm.module import LLMBase
+from lazyllm.module import LLMBase, ModuleBase

 from ...sql.sql_manager import DBStatus, SqlManager
 from ..doc_node import DocNode

@@ -33,7 +33,7 @@
 ONE_DOC_LENGTH_LIMIT = 102400


-class SchemaExtractor:
+class SchemaExtractor(ModuleBase):
     '''Schema aware extractor that materializes BaseModel schemas into database tables.'''

     TABLE_PREFIX = 'lazyllm_schema'

@@ -724,8 +724,8 @@ def _get_extract_data(self, algo_id: str, doc_ids: List[str], # noqa: C901
             results.append(ExtractResult(data=row_data, metadata=meta))
         return results

-    def __call__(self, data: Union[str, List[DocNode]],
-                 algo_id: str = DocListManager.DEFAULT_GROUP_NAME) -> ExtractResult:
+    def forward(self, data: Union[str, List[DocNode]],
+                algo_id: str = DocListManager.DEFAULT_GROUP_NAME) -> ExtractResult:
         # NOTE: data should be from single file source (kb_id, doc_id should be the same)
         self._lazy_init()
         res = self.extract_and_store(data=data, algo_id=algo_id)
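Why this change matters: moving the entry point from __call__ to forward lets SchemaExtractor inherit ModuleBase's call protocol, so it composes with other lazyllm modules (submodule registration, shared invocation hooks) instead of being a bare callable. A minimal sketch of the pattern, assuming a PyTorch-style base class whose __call__ dispatches to forward; the names below are illustrative, not lazyllm's actual internals:

    class ModuleBase:
        def __init__(self):
            self._submodules = []

        def __call__(self, *args, **kwargs):
            # Shared entry point: a real base class can hang tracing,
            # deployment routing, or bookkeeping here for every subclass.
            return self.forward(*args, **kwargs)

        def forward(self, *args, **kwargs):
            raise NotImplementedError


    class SchemaExtractor(ModuleBase):
        def forward(self, data, algo_id='default'):
            # Subclass logic now runs under the base class's __call__ wrapper.
            return f'extracted {data!r} for algo {algo_id!r}'


    extractor = SchemaExtractor()
    print(extractor('doc.pdf'))  # dispatches through ModuleBase.__call__

Callers that previously invoked extractor(...) directly keep working unchanged, since the inherited __call__ forwards to the renamed method.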
lazyllm/tools/rag/document.py (36 changes: 19 additions & 17 deletions)

@@ -58,6 +58,7 @@ def __init__(self, dataset_path: Optional[str], embed: Optional[Union[Callable,
         self._dataset_path = dataset_path
         self._embed = self._get_embeds(embed)
         self._processor = processor
+        self._schema_extractor = self._register_submodules(schema_extractor)
         name = name or DocListManager.DEFAULT_GROUP_NAME
         if not display_name: display_name = name

@@ -90,21 +91,22 @@ def web_url(self):

     def _get_embeds(self, embed):
         embeds = embed if isinstance(embed, dict) else {EMBED_DEFAULT_KEY: embed} if embed else {}
-        for embed in embeds.values():
-            if isinstance(embed, ModuleBase):
-                self._submodules.append(embed)
-        return embeds
+        return self._register_submodules(embeds)
+
+    def _register_submodules(self, m):
+        if not m: return m
+        for embed in (m.values() if isinstance(m, dict) else m if isinstance(m, (tuple, list)) else [m]):
+            if isinstance(embed, ModuleBase): self._submodules.append(embed)
+        return m

-    def add_kb_group(self, name, doc_fields: Optional[Dict[str, DocField]] = None,
-                     store_conf: Optional[Dict] = None,
-                     embed: Optional[Union[Callable, Dict[str, Callable]]] = None):
+    def add_kb_group(self, name, doc_fields: Optional[Dict[str, DocField]] = None, store_conf: Optional[Dict] = None,
+                     embed: Optional[Union[Callable, Dict[str, Callable]]] = None,
+                     schema_extractor: Optional[Union[LLMBase, SchemaExtractor]] = None):
         embed = self._get_embeds(embed) if embed else self._embed
-        if isinstance(self._kbs, ServerModule):
-            self._kbs._impl._m[name] = DocImpl(dlm=self._dlm, embed=embed, kb_group_name=name,
-                                               global_metadata_desc=doc_fields, store=store_conf)
-        else:
-            self._kbs[name] = DocImpl(dlm=self._dlm, embed=self._embed, kb_group_name=name,
-                                      global_metadata_desc=doc_fields, store=store_conf)
+        schema_extractor = self._register_submodules(schema_extractor) or self._schema_extractor
+        impl = DocImpl(dlm=self._dlm, embed=embed, kb_group_name=name, global_metadata_desc=doc_fields,
+                       store=store_conf, schema_extractor=schema_extractor)
+        (self._kbs._impl._m if isinstance(self._kbs, ServerModule) else self._kbs)[name] = impl
         self._dlm.add_kb_group(name=name)

     def get_doc_by_kb_group(self, name):

@@ -147,7 +149,6 @@ def __init__(self, dataset_path: Optional[str] = None, embed: Optional[Union[Cal
                 'Only map store is supported for Document with temp-files')

         name = name or DocListManager.DEFAULT_GROUP_NAME
-        self._schema_extractor: SchemaExtractor = schema_extractor

         if isinstance(manager, Document._Manager):
             assert not server, 'Server infomation is already set to by manager'

@@ -157,7 +158,8 @@
             if dataset_path != manager._dataset_path and dataset_path != manager._origin_path:
                 raise RuntimeError(f'Document path mismatch, expected `{manager._dataset_path}`'
                                    f'while received `{dataset_path}`')
-            manager.add_kb_group(name=name, doc_fields=doc_fields, store_conf=store_conf, embed=embed)
+            manager.add_kb_group(name=name, doc_fields=doc_fields, store_conf=store_conf, embed=embed,
+                                 schema_extractor=schema_extractor)
             self._manager = manager
             self._curr_group = name
         else:

@@ -173,7 +175,7 @@
             self._manager = Document._Manager(dataset_path, embed, manager, server, name, launcher, store_conf,
                                               doc_fields, cloud=cloud, doc_files=doc_files, processor=processor,
                                               display_name=display_name, description=description,
-                                              schema_extractor=self._schema_extractor)
+                                              schema_extractor=schema_extractor)
             self._curr_group = name
             self._doc_to_db_processor: DocToDbProcessor = None
             self._graph_document: weakref.ref = None
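The refactor consolidates submodule tracking: _get_embeds previously registered only dict values, while the new _register_submodules accepts a dict, a list/tuple, or a single object, appends any ModuleBase instances to self._submodules, and returns its input unchanged so it can be used inline in assignments. A standalone sketch of that normalization, using a stand-in ModuleBase rather than lazyllm's real class:

    class ModuleBase:  # stand-in for lazyllm.module.ModuleBase
        pass


    class Manager:
        def __init__(self):
            self._submodules = []

        def _register_submodules(self, m):
            if not m: return m
            # Normalize dict values, sequences, or a single object into one iterable.
            for item in (m.values() if isinstance(m, dict) else m if isinstance(m, (tuple, list)) else [m]):
                if isinstance(item, ModuleBase): self._submodules.append(item)
            return m


    mgr = Manager()
    mgr._register_submodules({'default': ModuleBase()})    # dict of embed modules
    mgr._register_submodules([ModuleBase(), lambda x: x])  # mixed list: only ModuleBase is tracked
    mgr._register_submodules(ModuleBase())                 # single module
    print(len(mgr._submodules))  # 3

Returning the input unchanged is what allows one-liners like `self._schema_extractor = self._register_submodules(schema_extractor)` in the constructor above.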
lazyllm/tools/rag/parsing_service/impl.py (10 changes: 4 additions & 6 deletions)

@@ -5,6 +5,7 @@
 from collections import defaultdict, deque
 from concurrent.futures import ThreadPoolExecutor
 from functools import cached_property
+from itertools import repeat

 from lazyllm import LOG

@@ -95,12 +96,9 @@ def add_doc(self, input_files: List[str], ids: Optional[List[str]] = None, # no
         try:
             if not input_files: return
             if not ids: ids = [gen_docid(path) for path in input_files]
-            if metadatas is None:
-                metadatas = [{} for _ in input_files]
-            for metadata, doc_id, path in zip(metadatas, ids, input_files):
-                metadata.setdefault(RAG_DOC_ID, doc_id)
-                metadata.setdefault(RAG_DOC_PATH, path)
-                metadata.setdefault(RAG_KB_ID, kb_id or DEFAULT_KB_ID)
+            temp_metas = [{RAG_DOC_ID: doc_id, RAG_DOC_PATH: path, RAG_KB_ID: kb_id or DEFAULT_KB_ID}
+                          for doc_id, path in zip(ids, input_files)]
+            metadatas = [{**temp, **(metadata)} for metadata, temp in zip(metadatas or repeat({}), temp_metas)]

Review comment (severity: high):

The current metadata merging logic might raise a TypeError if individual items within the metadatas list are None. For example, if metadatas is [None, {'user_key': 'user_value'}], the dictionary unpacking **(metadata) will fail for the None item. It's safer to ensure that metadata is always a dictionary before unpacking it.

Suggested change:

-            metadatas = [{**temp, **(metadata)} for metadata, temp in zip(metadatas or repeat({}), temp_metas)]
+            metadatas = [{**temp, **(metadata or {})} for metadata, temp in zip(metadatas or repeat({}), temp_metas)]

             kb_id = metadatas[0].get(RAG_KB_ID, DEFAULT_KB_ID) if kb_id is None else kb_id
             root_nodes = self._reader.load_data(input_files, metadatas, split_nodes_by_type=True)
             schema_futures = []
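A quick, runnable repro of the reviewer's concern; RAG_DOC_ID and the other constants are replaced with plain string keys for illustration:

    from itertools import repeat

    ids, paths = ['doc-1', 'doc-2'], ['a.pdf', 'b.pdf']
    metadatas = [None, {'user_key': 'user_value'}]  # a None item slips in

    temp_metas = [{'doc_id': doc_id, 'doc_path': path} for doc_id, path in zip(ids, paths)]

    try:
        # Unguarded merge, as written in the diff: **None is not a mapping.
        merged = [{**temp, **(metadata)} for metadata, temp in zip(metadatas or repeat({}), temp_metas)]
    except TypeError as e:
        print(f'unguarded merge fails: {e}')

    # With the reviewer's `metadata or {}` guard, None degrades to an empty dict
    # and user-supplied keys still override the generated defaults.
    merged = [{**temp, **(metadata or {})} for metadata, temp in zip(metadatas or repeat({}), temp_metas)]
    print(merged[1])  # {'doc_id': 'doc-2', 'doc_path': 'b.pdf', 'user_key': 'user_value'}

Note that `metadatas or repeat({})` only guards against metadatas being None or empty as a whole; `repeat({})` reuses one shared empty dict, which is safe here because it is only read during unpacking, never mutated.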