Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions lazyllm/tools/rag/doc_to_db/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class SchemaExtractor:

def __init__(self, db_config: Dict[str, Any], llm: LLMBase, *, table_prefix: Optional[str] = None,
force_refresh: bool = False, extraction_mode: ExtractionMode = ExtractionMode.TEXT,
max_len: int = ONE_DOC_LENGTH_LIMIT, num_workers: int = 4):
max_len: int = ONE_DOC_LENGTH_LIMIT, num_workers: int = 4, sql_manager: Optional[SqlManager] = None):
if not isinstance(llm, LLMBase):
raise TypeError('llm must be an instance of LLMBase')

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

llm 参数现在是可选的,默认值为 None,但是这里的检查 if not isinstance(llm, LLMBase): 会在 llm 为 None 时引发 TypeError。这会阻止在没有 LLM 的情况下初始化 SchemaExtractor,而这似乎是本次重构的一个预期用例。这个检查应该被更新以处理 None 的情况。

Suggested change
if not isinstance(llm, LLMBase):
raise TypeError('llm must be an instance of LLMBase')
if llm is not None and not isinstance(llm, LLMBase):

self._llm = llm
Expand All @@ -80,6 +80,7 @@ def __init__(self, db_config: Dict[str, Any], llm: LLMBase, *, table_prefix: Opt
self._extraction_mode = extraction_mode
self._max_len = max_len
self._num_workers = num_workers
self._sql_manager = sql_manager

@property
def sql_manager(self) -> SqlManager:
Expand Down Expand Up @@ -179,7 +180,8 @@ def _schema_table_desc(model: Type[BaseModel]) -> str:

@once_wrapper
def _lazy_init(self):
self._sql_manager = self._init_sql_manager(self._db_config) if self._db_config else None
if self._sql_manager is None:
self._sql_manager = self._init_sql_manager(self._db_config) if self._db_config else None
if self._sql_manager:
self._ensure_management_tables()

Expand Down
68 changes: 39 additions & 29 deletions lazyllm/tools/rag/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from .doc_manager import DocManager
from .doc_impl import DocImpl, StorePlaceholder, EmbedPlaceholder, BuiltinGroups, DocumentProcessor, NodeGroupType
from .doc_node import DocNode
from .doc_to_db import DocInfoSchema, DocToDbProcessor, extract_db_schema_from_files, SchemaExtractor
from .doc_to_db import SchemaExtractor
from lazyllm.tools.rag.doc_to_db.model import SchemaSetInfo, Table_ALGO_KB_SCHEMA
from .store import LAZY_ROOT_NAME, EMBED_DEFAULT_KEY
from .store.store_base import DEFAULT_KB_ID
from .index_base import IndexBase
Expand Down Expand Up @@ -175,7 +176,6 @@ def __init__(self, dataset_path: Optional[str] = None, embed: Optional[Union[Cal
display_name=display_name, description=description,
schema_extractor=self._schema_extractor)
self._curr_group = name
self._doc_to_db_processor: DocToDbProcessor = None
self._graph_document: weakref.ref = None

@staticmethod
Expand Down Expand Up @@ -230,60 +230,70 @@ def url(self):
def connect_sql_manager(
self,
sql_manager: SqlManager,
schma: Optional[DocInfoSchema] = None,
schma: Optional[BaseModel] = None, #basemodel
force_refresh: bool = True,
):
def format_schema_to_dict(schema: DocInfoSchema):
if schema is None:
return None, None
desc_dict = {ele['key']: ele['desc'] for ele in schema}
type_dict = {ele['key']: ele['type'] for ele in schema}
return desc_dict, type_dict

def compare_schema(old_schema: DocInfoSchema, new_schema: DocInfoSchema):
old_desc_dict, old_type_dict = format_schema_to_dict(old_schema)
new_desc_dict, new_type_dict = format_schema_to_dict(new_schema)
return old_desc_dict == new_desc_dict and old_type_dict == new_type_dict

def compare_schema(rows: Union[List, object, None], schma: BaseModel, extractor: SchemaExtractor):
    """Tell whether ``schma`` is already among the given binding ``rows``.

    Args:
        rows: A single row, a list/tuple/set of rows, or ``None``. Each entry is
            matched through its ``schema_set_id`` attribute; an entry without
            that attribute is compared directly as the id itself.
        schma: The schema to look for; ``None`` always yields ``False``.
        extractor: Object whose ``register_schema_set`` returns the id for ``schma``.

    Returns:
        ``True`` if any row carries the id returned for ``schma``, else ``False``.
    """
    if schma is None:
        return False

    # The schema is registered before the rows are inspected, so registration
    # happens even when there is nothing to compare against.
    target_id = extractor.register_schema_set(schma)

    if not rows:
        return False

    # Normalise a lone row into a one-element list.
    candidates = rows if isinstance(rows, (list, tuple, set)) else [rows]
    for candidate in candidates:
        if getattr(candidate, 'schema_set_id', candidate) == target_id:
            return True
    return False

# 1. Check valid arguments
if sql_manager.check_connection().status != DBStatus.SUCCESS:
raise RuntimeError(f'Failed to connect to sql manager: {sql_manager._gen_conn_url()}')
pre_doc_table_schema = None
if self._doc_to_db_processor:
pre_doc_table_schema = self._doc_to_db_processor.doc_info_schema
assert pre_doc_table_schema or schma, 'doc_table_schma must be given'

schema_equal = compare_schema(pre_doc_table_schema, schma)

rows: List = []
if self._schema_extractor:
mgr = self._schema_extractor.sql_manager
Bind = mgr.get_table_orm_class(Table_ALGO_KB_SCHEMA['name'])
with mgr.get_session() as s:
rows = s.query(Bind).filter_by(algo_id=self._impl._algo_name).all()
# algoid + kbid (self._impl.algo_name)

assert rows or schma, 'doc_table_schma must be given'

extractor = self._schema_extractor or SchemaExtractor(sql_manager)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

This line introduces a critical issue with two failure modes:

  1. If self._schema_extractor is None, the code attempts to instantiate SchemaExtractor(sql_manager). This will fail with a TypeError because the SchemaExtractor constructor requires db_config (a dictionary) and llm as its first two positional arguments. The current call incorrectly passes an SqlManager object for db_config and omits the required llm.

  2. If self._schema_extractor was initialized with an LLMBase instance, extractor here would be that LLMBase instance. The subsequent call to compare_schema would then fail with an AttributeError because it expects a SchemaExtractor instance and calls methods like register_schema_set.

To fix this, the function must ensure it has a valid SchemaExtractor instance. This might involve ensuring self._impl._create_schema_extractor() is called beforehand and then consistently using the resulting extractor instance. The state management of _schema_extractor between Document and DocImpl may also need to be reviewed for consistency.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

SchemaExtractor(sql_manager) 这个调用是错误的。在 SchemaExtractor.__init__ 中,sql_manager 是一个仅关键字参数(keyword-only argument),所以这个调用会引发 TypeError。即使它是一个位置参数,它也会被错误地赋给 db_config,从而在后续导致错误。这个调用应该使用关键字参数,例如 SchemaExtractor(sql_manager=sql_manager, ...)。此外,SchemaExtractor 初始化可能需要的 llm 实例没有被提供,这个问题也需要解决。

schema_equal = compare_schema(rows, schma, extractor)
assert (
schema_equal or force_refresh is True
), 'When changing doc_table_schema, force_refresh should be set to True'

# 2. Init handler if needed
need_init_processor = False
if self._doc_to_db_processor is None:
if self._schema_extractor is None:
need_init_processor = True
else:
# avoid reinit for the same db
if sql_manager != self._doc_to_db_processor.sql_manager:
if sql_manager != self._schema_extractor.sql_manager:
need_init_processor = True
if need_init_processor:
self._doc_to_db_processor = DocToDbProcessor(sql_manager)
# reuse the extractor instance used for schema comparison/registration
self._schema_extractor = extractor

# 3. Reset doc_table_schema if needed
if schma and not schema_equal:
# This api call will clear existing db table 'lazyllm_doc_elements'
self._doc_to_db_processor._reset_doc_info_schema(schma)
schema_set_id = self._schema_extractor.register_schema_set(schma)
self._schema_extractor.register_schema_set_to_kb(
algo_id = self._impl._algo_name, schema_set_id= schema_set_id, force_refresh=True)

def get_sql_manager(self):
if self._doc_to_db_processor is None:
if self._schema_extractor is None:
raise ValueError('Please call connect_sql_manager to init handler first')
return self._doc_to_db_processor.sql_manager
return self._schema_extractor.sql_manager

def extract_db_schema(
self, llm: Union[OnlineChatModule, TrainableModule], print_schema: bool = False
) -> DocInfoSchema:
file_paths = self._list_all_files_in_dataset()
schema = extract_db_schema_from_files(file_paths, llm)
) -> SchemaSetInfo:
Comment on lines 305 to +307

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

重构后,extract_db_schema 方法中的 llm 参数不再被使用。该方法现在调用 _analyze_schema_by_llm,它依赖于在 Document 初始化时在 SchemaExtractor 中配置的 llm 实例。这对 API 用户来说是有误导性的,并且相比之前直接使用传入 llm 的行为是一种功能退步。请从方法签名中移除 llm 参数,或者更新实现来使用它。

schema = self._forward('_analyze_schema_by_llm')
if print_schema:
lazyllm.LOG.info(f'Extracted Schema:\n\t{schema}\n')
return schema
Expand Down
217 changes: 173 additions & 44 deletions tests/charge_tests/Tools/test_doc_to_db.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,185 @@
import unittest
import json
import os
import tempfile

import lazyllm
from lazyllm.tools import SqlManager
import pytest
import os
from pydantic import BaseModel, Field

from lazyllm.tools.rag import SchemaExtractor
from lazyllm.tools.rag.doc_to_db.model import Table_ALGO_KB_SCHEMA


class ReadingReportSchema(BaseModel):
    # Schema used by the tests below: ten string fields describing a book
    # reading report, every one defaulting to the empty string.

    # When the book was read.
    reading_time: str = Field(default="", description="The date or time period when the book was read.")

    # Bibliographic details of the reviewed book.
    document_title: str = Field(default="", description="The title of the book being reviewed.")
    author_name: str = Field(default="", description="The name of the author of the book.")
    publication_type: str = Field(default="", description="The type of publication (e.g., book, journal, etc.).")
    publisher_name: str = Field(default="", description="The name of the publisher of the book.")
    publication_date: str = Field(default="", description="The date when the book was published.")

    # Content of the book and the reader's response to it.
    keywords: str = Field(default="", description="Key terms or themes discussed in the book.")
    content_summary: str = Field(default="", description="A brief summary of the book's main content or arguments.")
    insights: str = Field(default="", description="The reader's insights on the book's content.")
    reflections: str = Field(default="", description="The reader's reflections on the book's content.")


# Column names the tests expect to find on the table created for the schema;
# the entries match the field names declared on ReadingReportSchema.
EXPECTED_FIELDS = {
"reading_time",
"document_title",
"author_name",
"publication_type",
"publisher_name",
"publication_date",
"keywords",
"content_summary",
"insights",
"reflections",
}


def _fetch_bind_row(sql_manager, algo_id):
    """Return the first binding row stored for ``algo_id``, or ``None`` if there is none.

    The row is read from the table named by ``Table_ALGO_KB_SCHEMA["name"]``
    through ``sql_manager.execute_query``, whose result is decoded as JSON and
    must be a list.
    """
    table = Table_ALGO_KB_SCHEMA["name"]
    query = f"select * from {table} where algo_id='{algo_id}' limit 1"
    rows = json.loads(sql_manager.execute_query(query))
    assert isinstance(rows, list)
    if not rows:
        return None
    return rows[0]

class DocToDbTester(unittest.TestCase):

def _get_table_name(schema_extractor, schema_set_id):
return schema_extractor._table_name(schema_set_id)


def _get_count(sql_manager, table_name):
count_result = json.loads(
sql_manager.execute_query(f"select count(*) as cnt from {table_name}")
)
return count_result[0]["cnt"] if count_result else 0


def _connect_and_get_table(documents, schema_extractor, algo_id, *, force_refresh):
    """Bind ``ReadingReportSchema`` to ``documents`` and locate the resulting table.

    Calls ``documents.connect_sql_manager`` with the extractor's SQL manager,
    then looks up the binding row for ``algo_id`` and asserts that it exists.

    Returns:
        A ``(sql_manager, table_name)`` tuple, where ``table_name`` is derived
        from the binding row's ``schema_set_id``.
    """
    connect_kwargs = {
        "sql_manager": schema_extractor.sql_manager,
        "schma": ReadingReportSchema,
        "force_refresh": force_refresh,
    }
    documents.connect_sql_manager(**connect_kwargs)

    manager = schema_extractor.sql_manager
    binding = _fetch_bind_row(manager, algo_id)
    assert binding is not None
    return manager, _get_table_name(schema_extractor, binding["schema_set_id"])


class TestDocToDb:
@classmethod
def setUpClass(cls):
cls.llm = lazyllm.OnlineChatModule(source='qwen')
data_root_dir = os.getenv('LAZYLLM_DATA_PATH')
def setup_class(cls):
cls.llm = lazyllm.OnlineChatModule(source="qwen")
data_root_dir = os.getenv("LAZYLLM_DATA_PATH")
assert data_root_dir
cls.pdf_root = os.path.join(data_root_dir, 'rag_master/default/__data/pdfs')
cls.pdf_root = os.path.join(data_root_dir, "rag_master/default/__data/pdfs")
fd, cls.db_path = tempfile.mkstemp(suffix=".db")
os.close(fd)
cls.db_config = {
"db_type": "sqlite",
"user": None,
"password": None,
"host": None,
"port": None,
"db_name": cls.db_path,
}
cls.schema_extractor = SchemaExtractor(
db_config=cls.db_config,
llm=cls.llm,
force_refresh=True,
)

@pytest.mark.skip(reason='Skip for now, will be fixed in v0.6')
def test_doc_to_db_sop(self):
sql_manager = SqlManager('SQLite', None, None, None, None, db_name=':memory:')
documents = lazyllm.Document(dataset_path=self.pdf_root, create_ui=False)
@classmethod
def teardown_class(cls):
    """Delete the database file at ``cls.db_path`` if it is still present."""
    db_file = cls.db_path
    if not os.path.exists(db_file):
        return
    os.remove(db_file)

# Test-1: Use llm to extract schema
schema_by_llm = documents.extract_db_schema(llm=self.llm, print_schema=True)
assert schema_by_llm
def setup_method(self, method):
    """Give each test its own algo id and a ``lazyllm.Document`` built for it.

    The algo id is derived from the test method's name; the document uses
    ``self.pdf_root`` as its dataset and the shared ``self.schema_extractor``.
    """
    self.algo_id = f"doc_to_db_test_{method.__name__}"
    document_kwargs = {
        "dataset_path": self.pdf_root,
        "name": self.algo_id,
        "schema_extractor": self.schema_extractor,
    }
    self.documents = lazyllm.Document(**document_kwargs)

# Test-2: set without schema, assert failed
def test_connect_sql_manager_requires_schema(self):
# 未提供 schema 时应抛出错误
with pytest.raises(AssertionError) as excinfo:
documents.connect_sql_manager(sql_manager=sql_manager, schma=None)
assert 'doc_table_schma must be given' in str(excinfo.value)

refined_schema = [
{'key': 'reading_time', 'desc': 'The date or time period when the book was read.', 'type': 'text'},
{'key': 'document_title', 'desc': 'The title of the book being reviewed.', 'type': 'text'},
{'key': 'author_name', 'desc': 'The name of the author of the book.', 'type': 'text'},
{'key': 'publication_type', 'desc': 'The type of publication (e.g., book, journal, etc.).', 'type': 'text'},
{'key': 'publisher_name', 'desc': 'The name of the publisher of the book.', 'type': 'text'},
{'key': 'publication_date', 'desc': 'The date when the book was published.', 'type': 'text'},
{'key': 'keywords', 'desc': 'Key terms or themes discussed in the book.', 'type': 'text'},
{
'key': 'content_summary',
'desc': "A brief summary of the book's main content or arguments.",
'type': 'text',
},
{'key': 'insights', 'desc': "The reader's insights on the book's content.", 'type': 'text'},
{'key': 'reflections', 'desc': "The reader's reflections on the book's content.", 'type': 'text'},
]
# Test-3: set sqlmanager, llm, with schema
documents.connect_sql_manager(
sql_manager=sql_manager,
schma=refined_schema,
self.documents.connect_sql_manager(
sql_manager=self.schema_extractor.sql_manager,
schma=None,
)
assert "doc_table_schma must be given" in str(excinfo.value)

def test_connect_sql_manager_creates_bind_and_table(self):
    # When a schema is supplied, the binding mapping should be written and the table created.
    manager, table = _connect_and_get_table(
        self.documents, self.schema_extractor, self.algo_id, force_refresh=True
    )

    orm_cls = manager.get_table_orm_class(table)
    assert orm_cls is not None

    # Collect the column names of the mapped table and check the expected ones are present.
    actual_columns = set()
    for column in orm_cls.__table__.columns:
        actual_columns.add(column.name)
    assert EXPECTED_FIELDS.issubset(actual_columns)

def test_start_triggers_extraction_and_writes_rows(self):
    # After the pipeline is started, structured data should be extracted and written to the table.
    manager, table = _connect_and_get_table(
        self.documents, self.schema_extractor, self.algo_id, force_refresh=True
    )

    self.documents.start()
    self.documents.extract_db_schema(llm=self.llm, print_schema=True)

    # At least one row must have been written.
    row_count = _get_count(manager, table)
    assert row_count > 0

    # Print one row for manual inspection of the test output.
    sample_row_str = manager.execute_query(f"select * from {table} limit 1")
    print(f"sample_row: {sample_row_str}")

def test_extract_db_schema_returns_schema_info(self):
    # extract_db_schema should return a result that carries a truthy schema_set_id.
    _connect_and_get_table(
        self.documents, self.schema_extractor, self.algo_id, force_refresh=True
    )

    self.documents.start()
    result = self.documents.extract_db_schema(llm=self.llm, print_schema=True)
    assert result is not None

    set_id = getattr(result, "schema_set_id", None)
    assert set_id

def test_connect_same_schema_no_refresh(self):
# 相同 schema 再次绑定且 force_refresh=False 时不应刷新数据
sql_manager, table_name = _connect_and_get_table(
self.documents,
self.schema_extractor,
self.algo_id,
force_refresh=True,
)
# Test-4: check update run success (The extracted row exists in db means it definitely fits schema)
documents.update_database(llm=self.llm)
str_result = sql_manager.execute_query(f'select * from {documents._doc_to_db_processor.doc_table_name}')
print(f'str_result: {str_result}')
assert 'reading_report_p1' in str_result

self.documents.start()
self.documents.extract_db_schema(llm=self.llm, print_schema=True)
count_before = _get_count(sql_manager, table_name)

self.documents.connect_sql_manager(
sql_manager=sql_manager,
schma=ReadingReportSchema,
force_refresh=False,
)
count_after = _get_count(sql_manager, table_name)
assert count_after == count_before
Loading