-
Notifications
You must be signed in to change notification settings - Fork 364
refactor: replace DocToDbProcessor with SchemaExtractor #1026
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
df6cf9f
79d5882
ad14e6d
14323da
7a26eeb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,7 +12,8 @@ | |
| from .doc_manager import DocManager | ||
| from .doc_impl import DocImpl, StorePlaceholder, EmbedPlaceholder, BuiltinGroups, DocumentProcessor, NodeGroupType | ||
| from .doc_node import DocNode | ||
| from .doc_to_db import DocInfoSchema, DocToDbProcessor, extract_db_schema_from_files, SchemaExtractor | ||
| from .doc_to_db import SchemaExtractor | ||
| from lazyllm.tools.rag.doc_to_db.model import SchemaSetInfo, Table_ALGO_KB_SCHEMA | ||
| from .store import LAZY_ROOT_NAME, EMBED_DEFAULT_KEY | ||
| from .store.store_base import DEFAULT_KB_ID | ||
| from .index_base import IndexBase | ||
|
|
@@ -175,7 +176,6 @@ def __init__(self, dataset_path: Optional[str] = None, embed: Optional[Union[Cal | |
| display_name=display_name, description=description, | ||
| schema_extractor=self._schema_extractor) | ||
| self._curr_group = name | ||
| self._doc_to_db_processor: DocToDbProcessor = None | ||
| self._graph_document: weakref.ref = None | ||
|
|
||
| @staticmethod | ||
|
|
@@ -230,60 +230,70 @@ def url(self): | |
| def connect_sql_manager( | ||
| self, | ||
| sql_manager: SqlManager, | ||
| schma: Optional[DocInfoSchema] = None, | ||
| schma: Optional[BaseModel] = None, #basemodel | ||
| force_refresh: bool = True, | ||
| ): | ||
| def format_schema_to_dict(schema: DocInfoSchema): | ||
| if schema is None: | ||
| return None, None | ||
| desc_dict = {ele['key']: ele['desc'] for ele in schema} | ||
| type_dict = {ele['key']: ele['type'] for ele in schema} | ||
| return desc_dict, type_dict | ||
|
|
||
| def compare_schema(old_schema: DocInfoSchema, new_schema: DocInfoSchema): | ||
| old_desc_dict, old_type_dict = format_schema_to_dict(old_schema) | ||
| new_desc_dict, new_type_dict = format_schema_to_dict(new_schema) | ||
| return old_desc_dict == new_desc_dict and old_type_dict == new_type_dict | ||
|
|
||
| def compare_schema(rows: Union[List, object, None], schma: BaseModel, extractor: SchemaExtractor): | ||
| if schma is None: | ||
| return False | ||
| sid = extractor.register_schema_set(schma) | ||
|
|
||
| if not rows: | ||
| return False | ||
| if not isinstance(rows, (list, tuple, set)): | ||
| rows = [rows] | ||
|
|
||
| return any(getattr(row, 'schema_set_id', row) == sid for row in rows) | ||
|
|
||
| # 1. Check valid arguments | ||
| if sql_manager.check_connection().status != DBStatus.SUCCESS: | ||
| raise RuntimeError(f'Failed to connect to sql manager: {sql_manager._gen_conn_url()}') | ||
| pre_doc_table_schema = None | ||
| if self._doc_to_db_processor: | ||
| pre_doc_table_schema = self._doc_to_db_processor.doc_info_schema | ||
| assert pre_doc_table_schema or schma, 'doc_table_schma must be given' | ||
|
|
||
| schema_equal = compare_schema(pre_doc_table_schema, schma) | ||
|
|
||
| rows: List = [] | ||
| if self._schema_extractor: | ||
| mgr = self._schema_extractor.sql_manager | ||
| Bind = mgr.get_table_orm_class(Table_ALGO_KB_SCHEMA['name']) | ||
| with mgr.get_session() as s: | ||
| rows = s.query(Bind).filter_by(algo_id=self._impl._algo_name).all() | ||
| # algoid + kbid (self._impl.algo_name) | ||
|
|
||
| assert rows or schma, 'doc_table_schma must be given' | ||
|
|
||
| extractor = self._schema_extractor or SchemaExtractor(sql_manager) | ||
|
||
| schema_equal = compare_schema(rows, schma, extractor) | ||
| assert ( | ||
| schema_equal or force_refresh is True | ||
| ), 'When changing doc_table_schema, force_refresh should be set to True' | ||
|
|
||
| # 2. Init handler if needed | ||
| need_init_processor = False | ||
| if self._doc_to_db_processor is None: | ||
| if self._schema_extractor is None: | ||
| need_init_processor = True | ||
| else: | ||
| # avoid reinit for the same db | ||
| if sql_manager != self._doc_to_db_processor.sql_manager: | ||
| if sql_manager != self._schema_extractor.sql_manager: | ||
| need_init_processor = True | ||
| if need_init_processor: | ||
| self._doc_to_db_processor = DocToDbProcessor(sql_manager) | ||
| # reuse the extractor instance used for schema comparison/registration | ||
| self._schema_extractor = extractor | ||
|
|
||
| # 3. Reset doc_table_schema if needed | ||
| if schma and not schema_equal: | ||
| # This api call will clear existing db table 'lazyllm_doc_elements' | ||
| self._doc_to_db_processor._reset_doc_info_schema(schma) | ||
| schema_set_id = self._schema_extractor.register_schema_set(schma) | ||
| self._schema_extractor.register_schema_set_to_kb( | ||
| algo_id = self._impl._algo_name, schema_set_id= schema_set_id, force_refresh=True) | ||
|
|
||
| def get_sql_manager(self): | ||
| if self._doc_to_db_processor is None: | ||
| if self._schema_extractor is None: | ||
| raise ValueError('Please call connect_sql_manager to init handler first') | ||
| return self._doc_to_db_processor.sql_manager | ||
| return self._schema_extractor.sql_manager | ||
|
|
||
| def extract_db_schema( | ||
| self, llm: Union[OnlineChatModule, TrainableModule], print_schema: bool = False | ||
| ) -> DocInfoSchema: | ||
| file_paths = self._list_all_files_in_dataset() | ||
| schema = extract_db_schema_from_files(file_paths, llm) | ||
| ) -> SchemaSetInfo: | ||
|
Comment on lines 305 to +307
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
| schema = self._forward('_analyze_schema_by_llm') | ||
| if print_schema: | ||
| lazyllm.LOG.info(f'Extracted Schema:\n\t{schema}\n') | ||
| return schema | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,56 +1,185 @@ | ||
| import unittest | ||
| import json | ||
| import os | ||
| import tempfile | ||
|
|
||
| import lazyllm | ||
| from lazyllm.tools import SqlManager | ||
| import pytest | ||
| import os | ||
| from pydantic import BaseModel, Field | ||
|
|
||
| from lazyllm.tools.rag import SchemaExtractor | ||
| from lazyllm.tools.rag.doc_to_db.model import Table_ALGO_KB_SCHEMA | ||
|
|
||
|
|
||
| class ReadingReportSchema(BaseModel): | ||
| reading_time: str = Field(description="The date or time period when the book was read.", default="") | ||
| document_title: str = Field(description="The title of the book being reviewed.", default="") | ||
| author_name: str = Field(description="The name of the author of the book.", default="") | ||
| publication_type: str = Field(description="The type of publication (e.g., book, journal, etc.).", default="") | ||
| publisher_name: str = Field(description="The name of the publisher of the book.", default="") | ||
| publication_date: str = Field(description="The date when the book was published.", default="") | ||
| keywords: str = Field(description="Key terms or themes discussed in the book.", default="") | ||
| content_summary: str = Field(description="A brief summary of the book's main content or arguments.", default="") | ||
| insights: str = Field(description="The reader's insights on the book's content.", default="") | ||
| reflections: str = Field(description="The reader's reflections on the book's content.", default="") | ||
|
|
||
|
|
||
| EXPECTED_FIELDS = { | ||
| "reading_time", | ||
| "document_title", | ||
| "author_name", | ||
| "publication_type", | ||
| "publisher_name", | ||
| "publication_date", | ||
| "keywords", | ||
| "content_summary", | ||
| "insights", | ||
| "reflections", | ||
| } | ||
|
|
||
|
|
||
| def _fetch_bind_row(sql_manager, algo_id): | ||
| bind_table = Table_ALGO_KB_SCHEMA["name"] | ||
| bind_rows = json.loads( | ||
| sql_manager.execute_query( | ||
| f"select * from {bind_table} where algo_id='{algo_id}' limit 1" | ||
| ) | ||
| ) | ||
| assert isinstance(bind_rows, list) | ||
| return bind_rows[0] if bind_rows else None | ||
|
|
||
| class DocToDbTester(unittest.TestCase): | ||
|
|
||
| def _get_table_name(schema_extractor, schema_set_id): | ||
| return schema_extractor._table_name(schema_set_id) | ||
|
|
||
|
|
||
| def _get_count(sql_manager, table_name): | ||
| count_result = json.loads( | ||
| sql_manager.execute_query(f"select count(*) as cnt from {table_name}") | ||
| ) | ||
| return count_result[0]["cnt"] if count_result else 0 | ||
|
|
||
|
|
||
| def _connect_and_get_table(documents, schema_extractor, algo_id, *, force_refresh): | ||
| documents.connect_sql_manager( | ||
| sql_manager=schema_extractor.sql_manager, | ||
| schma=ReadingReportSchema, | ||
| force_refresh=force_refresh, | ||
| ) | ||
| sql_manager = schema_extractor.sql_manager | ||
| bind_row = _fetch_bind_row(sql_manager, algo_id) | ||
| assert bind_row is not None | ||
| table_name = _get_table_name(schema_extractor, bind_row["schema_set_id"]) | ||
| return sql_manager, table_name | ||
|
|
||
|
|
||
| class TestDocToDb: | ||
| @classmethod | ||
| def setUpClass(cls): | ||
| cls.llm = lazyllm.OnlineChatModule(source='qwen') | ||
| data_root_dir = os.getenv('LAZYLLM_DATA_PATH') | ||
| def setup_class(cls): | ||
| cls.llm = lazyllm.OnlineChatModule(source="qwen") | ||
| data_root_dir = os.getenv("LAZYLLM_DATA_PATH") | ||
| assert data_root_dir | ||
| cls.pdf_root = os.path.join(data_root_dir, 'rag_master/default/__data/pdfs') | ||
| cls.pdf_root = os.path.join(data_root_dir, "rag_master/default/__data/pdfs") | ||
| fd, cls.db_path = tempfile.mkstemp(suffix=".db") | ||
| os.close(fd) | ||
| cls.db_config = { | ||
| "db_type": "sqlite", | ||
| "user": None, | ||
| "password": None, | ||
| "host": None, | ||
| "port": None, | ||
| "db_name": cls.db_path, | ||
| } | ||
| cls.schema_extractor = SchemaExtractor( | ||
| db_config=cls.db_config, | ||
| llm=cls.llm, | ||
| force_refresh=True, | ||
| ) | ||
|
|
||
| @pytest.mark.skip(reason='Skip for now, will be fixed in v0.6') | ||
| def test_doc_to_db_sop(self): | ||
| sql_manager = SqlManager('SQLite', None, None, None, None, db_name=':memory:') | ||
| documents = lazyllm.Document(dataset_path=self.pdf_root, create_ui=False) | ||
| @classmethod | ||
| def teardown_class(cls): | ||
| if os.path.exists(cls.db_path): | ||
| os.remove(cls.db_path) | ||
|
|
||
| # Test-1: Use llm to extract schema | ||
| schema_by_llm = documents.extract_db_schema(llm=self.llm, print_schema=True) | ||
| assert schema_by_llm | ||
| def setup_method(self, method): | ||
| self.algo_id = f"doc_to_db_test_{method.__name__}" | ||
| self.documents = lazyllm.Document( | ||
| dataset_path=self.pdf_root, | ||
| name=self.algo_id, | ||
| schema_extractor=self.schema_extractor, | ||
| ) | ||
|
|
||
| # Test-2: set without schema, assert failed | ||
| def test_connect_sql_manager_requires_schema(self): | ||
| # 未提供 schema 时应抛出错误 | ||
| with pytest.raises(AssertionError) as excinfo: | ||
| documents.connect_sql_manager(sql_manager=sql_manager, schma=None) | ||
| assert 'doc_table_schma must be given' in str(excinfo.value) | ||
|
|
||
| refined_schema = [ | ||
| {'key': 'reading_time', 'desc': 'The date or time period when the book was read.', 'type': 'text'}, | ||
| {'key': 'document_title', 'desc': 'The title of the book being reviewed.', 'type': 'text'}, | ||
| {'key': 'author_name', 'desc': 'The name of the author of the book.', 'type': 'text'}, | ||
| {'key': 'publication_type', 'desc': 'The type of publication (e.g., book, journal, etc.).', 'type': 'text'}, | ||
| {'key': 'publisher_name', 'desc': 'The name of the publisher of the book.', 'type': 'text'}, | ||
| {'key': 'publication_date', 'desc': 'The date when the book was published.', 'type': 'text'}, | ||
| {'key': 'keywords', 'desc': 'Key terms or themes discussed in the book.', 'type': 'text'}, | ||
| { | ||
| 'key': 'content_summary', | ||
| 'desc': "A brief summary of the book's main content or arguments.", | ||
| 'type': 'text', | ||
| }, | ||
| {'key': 'insights', 'desc': "The reader's insights on the book's content.", 'type': 'text'}, | ||
| {'key': 'reflections', 'desc': "The reader's reflections on the book's content.", 'type': 'text'}, | ||
| ] | ||
| # Test-3: set sqlmanager, llm, with schema | ||
| documents.connect_sql_manager( | ||
| sql_manager=sql_manager, | ||
| schma=refined_schema, | ||
| self.documents.connect_sql_manager( | ||
| sql_manager=self.schema_extractor.sql_manager, | ||
| schma=None, | ||
| ) | ||
| assert "doc_table_schma must be given" in str(excinfo.value) | ||
|
|
||
| def test_connect_sql_manager_creates_bind_and_table(self): | ||
| # 提供 schema 时应写入绑定映射并创建表 | ||
| sql_manager, table_name = _connect_and_get_table( | ||
| self.documents, | ||
| self.schema_extractor, | ||
| self.algo_id, | ||
| force_refresh=True, | ||
| ) | ||
|
|
||
| table_cls = sql_manager.get_table_orm_class(table_name) | ||
| assert table_cls is not None | ||
| column_names = {col.name for col in table_cls.__table__.columns} | ||
| assert EXPECTED_FIELDS.issubset(column_names) | ||
|
|
||
| def test_start_triggers_extraction_and_writes_rows(self): | ||
| # 启动流程后应抽取并写入结构化数据 | ||
| sql_manager, table_name = _connect_and_get_table( | ||
| self.documents, | ||
| self.schema_extractor, | ||
| self.algo_id, | ||
| force_refresh=True, | ||
| ) | ||
|
|
||
| self.documents.start() | ||
| self.documents.extract_db_schema(llm=self.llm, print_schema=True) | ||
|
|
||
| count_before = _get_count(sql_manager, table_name) | ||
| assert count_before > 0 | ||
|
|
||
| sample_row_str = sql_manager.execute_query(f"select * from {table_name} limit 1") | ||
| print(f"sample_row: {sample_row_str}") | ||
|
|
||
| def test_extract_db_schema_returns_schema_info(self): | ||
| # extract_db_schema 应返回带 schema_set_id 的结果 | ||
| _connect_and_get_table( | ||
| self.documents, | ||
| self.schema_extractor, | ||
| self.algo_id, | ||
| force_refresh=True, | ||
| ) | ||
|
|
||
| self.documents.start() | ||
| schema_info = self.documents.extract_db_schema(llm=self.llm, print_schema=True) | ||
| assert schema_info is not None | ||
| assert getattr(schema_info, "schema_set_id", None) | ||
|
|
||
| def test_connect_same_schema_no_refresh(self): | ||
| # 相同 schema 再次绑定且 force_refresh=False 时不应刷新数据 | ||
| sql_manager, table_name = _connect_and_get_table( | ||
| self.documents, | ||
| self.schema_extractor, | ||
| self.algo_id, | ||
| force_refresh=True, | ||
| ) | ||
| # Test-4: check update run success (The extracted row exists in db means it definitely fits schema) | ||
| documents.update_database(llm=self.llm) | ||
| str_result = sql_manager.execute_query(f'select * from {documents._doc_to_db_processor.doc_table_name}') | ||
| print(f'str_result: {str_result}') | ||
| assert 'reading_report_p1' in str_result | ||
|
|
||
| self.documents.start() | ||
| self.documents.extract_db_schema(llm=self.llm, print_schema=True) | ||
| count_before = _get_count(sql_manager, table_name) | ||
|
|
||
| self.documents.connect_sql_manager( | ||
| sql_manager=sql_manager, | ||
| schma=ReadingReportSchema, | ||
| force_refresh=False, | ||
| ) | ||
| count_after = _get_count(sql_manager, table_name) | ||
| assert count_after == count_before |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`llm` 参数现在是可选的,默认值为 `None`,但是这里的检查 `if not isinstance(llm, LLMBase):` 会在 `llm` 为 `None` 时引发 `TypeError`。这会阻止在没有 LLM 的情况下初始化 `SchemaExtractor`,而这似乎是本次重构的一个预期用例。这个检查应该被更新以处理 `None` 的情况。