diff --git a/lazyllm/tools/rag/doc_to_db/extractor.py b/lazyllm/tools/rag/doc_to_db/extractor.py index 089692250..6f58e5fcb 100644 --- a/lazyllm/tools/rag/doc_to_db/extractor.py +++ b/lazyllm/tools/rag/doc_to_db/extractor.py @@ -65,15 +65,30 @@ class SchemaExtractor: 'map': dict, } - def __init__(self, db_config: Dict[str, Any], llm: LLMBase, *, table_prefix: Optional[str] = None, - force_refresh: bool = False, extraction_mode: ExtractionMode = ExtractionMode.TEXT, - max_len: int = ONE_DOC_LENGTH_LIMIT, num_workers: int = 4): - if not isinstance(llm, LLMBase): + def __init__(self, db_config: Optional[Dict[str, Any]] = None, llm: LLMBase = None, *, + table_prefix: Optional[str] = None, force_refresh: bool = False, + extraction_mode: ExtractionMode = ExtractionMode.TEXT, max_len: int = ONE_DOC_LENGTH_LIMIT, + num_workers: int = 4, sql_manager: Optional[SqlManager] = None): + if (db_config is None) == (sql_manager is None): + raise ValueError('Exactly one of db_config or sql_manager must be provided') + if llm is None or not isinstance(llm, LLMBase): raise TypeError('llm must be an instance of LLMBase') self._llm = llm self._table_prefix = table_prefix or self.TABLE_PREFIX self._sql_manager = None - self._db_config = db_config + if sql_manager is not None: + self._sql_manager = sql_manager + self._db_config = { + 'db_type': sql_manager._db_type, + 'user': getattr(sql_manager, '_user', None), + 'password': getattr(sql_manager, '_password', None), + 'host': getattr(sql_manager, '_host', None), + 'port': getattr(sql_manager, '_port', None), + 'db_name': getattr(sql_manager, '_db_name', None), + 'options_str': getattr(sql_manager, '_options_str', None), + } + else: + self._db_config = db_config self._table_cache: Dict[str, Type[_TableBase]] = {} self._schema_registry: Dict[str, Type[BaseModel]] = {} self._force_refresh = force_refresh @@ -179,7 +194,8 @@ def _schema_table_desc(model: Type[BaseModel]) -> str: @once_wrapper def _lazy_init(self): - self._sql_manager = self._init_sql_manager(self._db_config) if self._db_config else None + if self._sql_manager is None: + self._sql_manager = self._init_sql_manager(self._db_config) if self._db_config else None if self._sql_manager: self._ensure_management_tables() diff --git a/lazyllm/tools/rag/document.py b/lazyllm/tools/rag/document.py index 9102400bf..1221388f3 100644 --- a/lazyllm/tools/rag/document.py +++ b/lazyllm/tools/rag/document.py @@ -12,11 +12,12 @@ from .doc_manager import DocManager from .doc_impl import DocImpl, StorePlaceholder, EmbedPlaceholder, BuiltinGroups, DocumentProcessor, NodeGroupType from .doc_node import DocNode -from .doc_to_db import DocInfoSchema, DocToDbProcessor, extract_db_schema_from_files, SchemaExtractor +from .doc_to_db import SchemaExtractor +from lazyllm.tools.rag.doc_to_db.model import SchemaSetInfo, Table_ALGO_KB_SCHEMA from .store import LAZY_ROOT_NAME, EMBED_DEFAULT_KEY from .store.store_base import DEFAULT_KB_ID from .index_base import IndexBase -from .utils import DocListManager, ensure_call_endpoint +from .utils import DocListManager, ensure_call_endpoint, _get_default_db_config from .global_metadata import GlobalMetadataDesc as DocField from .web import DocWebModule import copy @@ -147,6 +148,17 @@ def __init__(self, dataset_path: Optional[str] = None, embed: Optional[Union[Cal 'Only map store is supported for Document with temp-files') name = name or DocListManager.DEFAULT_GROUP_NAME + if schema_extractor is not None and not isinstance(schema_extractor, SchemaExtractor): + if isinstance(schema_extractor, LLMBase): + metadata_store_config = None + if isinstance(store_conf, dict): + metadata_store_config = store_conf.get('metadata_store') + metadata_store_config = metadata_store_config or _get_default_db_config( + db_name=f'{name}_metadata' + ) + schema_extractor = SchemaExtractor(db_config=metadata_store_config, llm=schema_extractor) + else: + raise ValueError(f'Invalid type for schema extractor: {type(schema_extractor)}') self._schema_extractor: SchemaExtractor = schema_extractor if isinstance(manager, Document._Manager): @@ -175,7 +187,6 @@ def __init__(self, dataset_path: Optional[str] = None, embed: Optional[Union[Cal display_name=display_name, description=description, schema_extractor=self._schema_extractor) self._curr_group = name - self._doc_to_db_processor: DocToDbProcessor = None self._graph_document: weakref.ref = None @staticmethod @@ -227,63 +238,75 @@ def url(self): assert isinstance(self._manager._kbs, ServerModule), 'Document is not a service, please set `manager` to `True`' return self._manager._kbs._url + def _resolve_schema_extractor(self, sql_manager: SqlManager) -> SchemaExtractor: + if self._schema_extractor is None: + raise ValueError('schema_extractor is required to connect sql manager') + if not isinstance(self._schema_extractor, SchemaExtractor): + raise ValueError(f'Invalid type for schema extractor: {type(self._schema_extractor)}') + if sql_manager == self._schema_extractor.sql_manager: + return self._schema_extractor + return SchemaExtractor(sql_manager=sql_manager, llm=self._schema_extractor._llm) + + @staticmethod + def _compare_schema_rows(rows: Union[List, object, None], schma: BaseModel, + extractor: SchemaExtractor) -> bool: + if schma is None: + return False + sid = extractor.register_schema_set(schma) + if not rows: + return False + if not isinstance(rows, (list, tuple, set)): + rows = [rows] + return any(getattr(row, 'schema_set_id', row) == sid for row in rows) + + def _get_schema_bind_rows(self, extractor: SchemaExtractor) -> List: + mgr = extractor.sql_manager + bind_cls = mgr.get_table_orm_class(Table_ALGO_KB_SCHEMA['name']) + if bind_cls is None: + return [] + with mgr.get_session() as s: + return s.query(bind_cls).filter_by(algo_id=self._impl.algo_name).all() + def connect_sql_manager( - self, - sql_manager: SqlManager, - schma: Optional[DocInfoSchema] = None, + self, sql_manager: SqlManager, schma: Optional[BaseModel] = None, force_refresh: bool = True, ): - def format_schema_to_dict(schema: DocInfoSchema): - if schema is None: - return None, None - desc_dict = {ele['key']: ele['desc'] for ele in schema} - type_dict = {ele['key']: ele['type'] for ele in schema} - return desc_dict, type_dict - - def compare_schema(old_schema: DocInfoSchema, new_schema: DocInfoSchema): - old_desc_dict, old_type_dict = format_schema_to_dict(old_schema) - new_desc_dict, new_type_dict = format_schema_to_dict(new_schema) - return old_desc_dict == new_desc_dict and old_type_dict == new_type_dict - # 1. Check valid arguments if sql_manager.check_connection().status != DBStatus.SUCCESS: raise RuntimeError(f'Failed to connect to sql manager: {sql_manager._gen_conn_url()}') - pre_doc_table_schema = None - if self._doc_to_db_processor: - pre_doc_table_schema = self._doc_to_db_processor.doc_info_schema - assert pre_doc_table_schema or schma, 'doc_table_schma must be given' - schema_equal = compare_schema(pre_doc_table_schema, schma) + extractor = self._resolve_schema_extractor(sql_manager) + rows = self._get_schema_bind_rows(extractor) + assert rows or schma, 'doc_table_schma must be given' + + schema_equal = self._compare_schema_rows(rows, schma, extractor) assert ( schema_equal or force_refresh is True ), 'When changing doc_table_schema, force_refresh should be set to True' # 2. Init handler if needed - need_init_processor = False - if self._doc_to_db_processor is None: - need_init_processor = True - else: - # avoid reinit for the same db - if sql_manager != self._doc_to_db_processor.sql_manager: - need_init_processor = True - if need_init_processor: - self._doc_to_db_processor = DocToDbProcessor(sql_manager) + if extractor is not self._schema_extractor: + # reuse the extractor instance used for schema comparison/registration + self._schema_extractor = extractor + self._impl._schema_extractor = extractor # 3. Reset doc_table_schema if needed if schma and not schema_equal: # This api call will clear existing db table 'lazyllm_doc_elements' - self._doc_to_db_processor._reset_doc_info_schema(schma) + schema_set_id = self._schema_extractor.register_schema_set(schma) + self._schema_extractor.register_schema_set_to_kb( + algo_id=self._impl._algo_name, schema_set_id=schema_set_id, force_refresh=True) def get_sql_manager(self): - if self._doc_to_db_processor is None: + if self._schema_extractor is None: raise ValueError('Please call connect_sql_manager to init handler first') - return self._doc_to_db_processor.sql_manager + return self._schema_extractor.sql_manager def extract_db_schema( self, llm: Union[OnlineChatModule, TrainableModule], print_schema: bool = False - ) -> DocInfoSchema: - file_paths = self._list_all_files_in_dataset() - schema = extract_db_schema_from_files(file_paths, llm) + ) -> SchemaSetInfo: + + schema = self._forward('_analyze_schema_by_llm') if print_schema: lazyllm.LOG.info(f'Extracted Schema:\n\t{schema}\n') return schema diff --git a/tests/charge_tests/Tools/test_doc_to_db.py b/tests/charge_tests/Tools/test_doc_to_db.py deleted file mode 100644 index 8ea9136e5..000000000 --- a/tests/charge_tests/Tools/test_doc_to_db.py +++ /dev/null @@ -1,56 +0,0 @@ -import unittest -import lazyllm -from lazyllm.tools import SqlManager -import pytest -import os - -class DocToDbTester(unittest.TestCase): - - @classmethod - def setUpClass(cls): - cls.llm = lazyllm.OnlineChatModule(source='qwen') - data_root_dir = os.getenv('LAZYLLM_DATA_PATH') - assert data_root_dir - cls.pdf_root = os.path.join(data_root_dir, 'rag_master/default/__data/pdfs') - - @pytest.mark.skip(reason='Skip for now, will be fixed in v0.6') - def test_doc_to_db_sop(self): - sql_manager = SqlManager('SQLite', None, None, None, None, db_name=':memory:') - documents = lazyllm.Document(dataset_path=self.pdf_root, create_ui=False) - - # Test-1: Use llm to extract schema - schema_by_llm = documents.extract_db_schema(llm=self.llm, print_schema=True) - assert schema_by_llm - - # Test-2: set without schema, assert failed - with pytest.raises(AssertionError) as excinfo: - documents.connect_sql_manager(sql_manager=sql_manager, schma=None) - assert 'doc_table_schma must be given' in str(excinfo.value) - - refined_schema = [ - {'key': 'reading_time', 'desc': 'The date or time period when the book was read.', 'type': 'text'}, - {'key': 'document_title', 'desc': 'The title of the book being reviewed.', 'type': 'text'}, - {'key': 'author_name', 'desc': 'The name of the author of the book.', 'type': 'text'}, - {'key': 'publication_type', 'desc': 'The type of publication (e.g., book, journal, etc.).', 'type': 'text'}, - {'key': 'publisher_name', 'desc': 'The name of the publisher of the book.', 'type': 'text'}, - {'key': 'publication_date', 'desc': 'The date when the book was published.', 'type': 'text'}, - {'key': 'keywords', 'desc': 'Key terms or themes discussed in the book.', 'type': 'text'}, - { - 'key': 'content_summary', - 'desc': "A brief summary of the book's main content or arguments.", - 'type': 'text', - }, - {'key': 'insights', 'desc': "The reader's insights on the book's content.", 'type': 'text'}, - {'key': 'reflections', 'desc': "The reader's reflections on the book's content.", 'type': 'text'}, - ] - # Test-3: set sqlmanager, llm, with schema - documents.connect_sql_manager( - sql_manager=sql_manager, - schma=refined_schema, - force_refresh=True, - ) - # Test-4: check update run success (The extracted row exists in db means it definitely fits schema) - documents.update_database(llm=self.llm) - str_result = sql_manager.execute_query(f'select * from {documents._doc_to_db_processor.doc_table_name}') - print(f'str_result: {str_result}') - assert 'reading_report_p1' in str_result diff --git a/tests/charge_tests/Tools/test_schema_extractor.py b/tests/charge_tests/Tools/test_schema_extractor.py new file mode 100644 index 000000000..c0dca419e --- /dev/null +++ b/tests/charge_tests/Tools/test_schema_extractor.py @@ -0,0 +1,185 @@ +import json +import os +import tempfile + +import lazyllm +import pytest +from pydantic import BaseModel, Field + +from lazyllm.tools.rag import SchemaExtractor +from lazyllm.tools.rag.doc_to_db.model import Table_ALGO_KB_SCHEMA + + +class ReadingReportSchema(BaseModel): + reading_time: str = Field(description='The date or time period when the book was read.', default='') + document_title: str = Field(description='The title of the book being reviewed.', default='') + author_name: str = Field(description='The name of the author of the book.', default='') + publication_type: str = Field(description='The type of publication (e.g., book, journal, etc.).', default='') + publisher_name: str = Field(description='The name of the publisher of the book.', default='') + publication_date: str = Field(description='The date when the book was published.', default='') + keywords: str = Field(description='Key terms or themes discussed in the book.', default='') + content_summary: str = Field(description='A brief summary of the book\'s main content or arguments.', default='') + insights: str = Field(description='The reader\'s insights on the book\'s content.', default='') + reflections: str = Field(description='The reader\'s reflections on the book\'s content.', default='') + + +EXPECTED_FIELDS = { + 'reading_time', + 'document_title', + 'author_name', + 'publication_type', + 'publisher_name', + 'publication_date', + 'keywords', + 'content_summary', + 'insights', + 'reflections', +} + + +def _fetch_bind_row(sql_manager, algo_id): + bind_table = Table_ALGO_KB_SCHEMA['name'] + bind_rows = json.loads( + sql_manager.execute_query( + f'select * from {bind_table} where algo_id=\'{algo_id}\' limit 1' + ) + ) + assert isinstance(bind_rows, list) + return bind_rows[0] if bind_rows else None + + +def _get_table_name(schema_extractor, schema_set_id): + return schema_extractor._table_name(schema_set_id) + + +def _get_count(sql_manager, table_name): + count_result = json.loads( + sql_manager.execute_query(f'select count(*) as cnt from {table_name}') + ) + return count_result[0]['cnt'] if count_result else 0 + + +def _connect_and_get_table(documents, schema_extractor, algo_id, *, force_refresh): + documents.connect_sql_manager( + sql_manager=schema_extractor.sql_manager, + schma=ReadingReportSchema, + force_refresh=force_refresh, + ) + sql_manager = schema_extractor.sql_manager + bind_row = _fetch_bind_row(sql_manager, algo_id) + assert bind_row is not None + table_name = _get_table_name(schema_extractor, bind_row['schema_set_id']) + return sql_manager, table_name + + +class TestSchemaExtractor: + @classmethod + def setup_class(cls): + cls.llm = lazyllm.OnlineChatModule(source='qwen') + data_root_dir = os.getenv('LAZYLLM_DATA_PATH') + assert data_root_dir + cls.pdf_root = os.path.join(data_root_dir, 'rag_master/default/__data/pdfs') + fd, cls.db_path = tempfile.mkstemp(suffix='.db') + os.close(fd) + cls.db_config = { + 'db_type': 'sqlite', + 'user': None, + 'password': None, + 'host': None, + 'port': None, + 'db_name': cls.db_path, + } + cls.schema_extractor = SchemaExtractor( + db_config=cls.db_config, + llm=cls.llm, + force_refresh=True, + ) + + @classmethod + def teardown_class(cls): + if os.path.exists(cls.db_path): + os.remove(cls.db_path) + + def setup_method(self, method): + self.algo_id = f'doc_to_db_test_{method.__name__}' + self.documents = lazyllm.Document( + dataset_path=self.pdf_root, + name=self.algo_id, + schema_extractor=self.schema_extractor, + ) + + def test_connect_sql_manager_requires_schema(self): + # 未提供 schema 时应抛出错误 + with pytest.raises(AssertionError) as excinfo: + self.documents.connect_sql_manager( + sql_manager=self.schema_extractor.sql_manager, + schma=None, + ) + assert 'doc_table_schma must be given' in str(excinfo.value) + + def test_connect_sql_manager_creates_bind_and_table(self): + # 提供 schema 时应写入绑定映射并创建表 + sql_manager, table_name = _connect_and_get_table( + self.documents, + self.schema_extractor, + self.algo_id, + force_refresh=True, + ) + + table_cls = sql_manager.get_table_orm_class(table_name) + assert table_cls is not None + column_names = {col.name for col in table_cls.__table__.columns} + assert EXPECTED_FIELDS.issubset(column_names) + + def test_start_triggers_extraction_and_writes_rows(self): + # 启动流程后应抽取并写入结构化数据 + sql_manager, table_name = _connect_and_get_table( + self.documents, + self.schema_extractor, + self.algo_id, + force_refresh=True, + ) + + self.documents.start() + self.documents.extract_db_schema(llm=self.llm, print_schema=True) + + count_before = _get_count(sql_manager, table_name) + assert count_before > 0 + + sample_row_str = sql_manager.execute_query(f'select * from {table_name} limit 1') + print(f'sample_row: {sample_row_str}') + + def test_extract_db_schema_returns_schema_info(self): + # extract_db_schema 应返回带 schema_set_id 的结果 + _connect_and_get_table( + self.documents, + self.schema_extractor, + self.algo_id, + force_refresh=True, + ) + + self.documents.start() + schema_info = self.documents.extract_db_schema(llm=self.llm, print_schema=True) + assert schema_info is not None + assert getattr(schema_info, 'schema_set_id', None) + + def test_connect_same_schema_no_refresh(self): + # 相同 schema 再次绑定且 force_refresh=False 时不应刷新数据 + sql_manager, table_name = _connect_and_get_table( + self.documents, + self.schema_extractor, + self.algo_id, + force_refresh=True, + ) + + self.documents.start() + self.documents.extract_db_schema(llm=self.llm, print_schema=True) + count_before = _get_count(sql_manager, table_name) + + self.documents.connect_sql_manager( + sql_manager=sql_manager, + schma=ReadingReportSchema, + force_refresh=False, + ) + count_after = _get_count(sql_manager, table_name) + assert count_after == count_before