-
Notifications
You must be signed in to change notification settings - Fork 364
refactor: replace DocToDbProcessor with SchemaExtractor #1026
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 3 commits
df6cf9f
79d5882
ad14e6d
14323da
7a26eeb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -65,15 +65,30 @@ class SchemaExtractor: | |
| 'map': dict, | ||
| } | ||
|
|
||
| def __init__(self, db_config: Dict[str, Any], llm: LLMBase, *, table_prefix: Optional[str] = None, | ||
| force_refresh: bool = False, extraction_mode: ExtractionMode = ExtractionMode.TEXT, | ||
| max_len: int = ONE_DOC_LENGTH_LIMIT, num_workers: int = 4): | ||
| def __init__(self, db_config: Optional[Dict[str, Any]] = None, llm: LLMBase = None, *, | ||
| table_prefix: Optional[str] = None, force_refresh: bool = False, | ||
| extraction_mode: ExtractionMode = ExtractionMode.TEXT, max_len: int = ONE_DOC_LENGTH_LIMIT, | ||
| num_workers: int = 4, sql_manager: Optional[SqlManager] = None): | ||
| if (db_config is None) == (sql_manager is None): | ||
| raise ValueError('Exactly one of db_config or sql_manager must be provided') | ||
| if not isinstance(llm, LLMBase): | ||
| raise TypeError('llm must be an instance of LLMBase') | ||
| self._llm = llm | ||
| self._table_prefix = table_prefix or self.TABLE_PREFIX | ||
| self._sql_manager = None | ||
| self._db_config = db_config | ||
| if sql_manager is not None: | ||
| self._sql_manager = sql_manager | ||
| self._db_config = { | ||
| 'db_type': sql_manager._db_type, | ||
| 'user': getattr(sql_manager, '_user', None), | ||
| 'password': getattr(sql_manager, '_password', None), | ||
| 'host': getattr(sql_manager, '_host', None), | ||
| 'port': getattr(sql_manager, '_port', None), | ||
| 'db_name': getattr(sql_manager, '_db_name', None), | ||
| 'options_str': getattr(sql_manager, '_options_str', None), | ||
| } | ||
|
Comment on lines
+81
to
+89
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Comment on lines
+81
to
+89
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Accessing private attributes of Consider adding a public method to |
||
| else: | ||
| self._db_config = db_config | ||
| self._table_cache: Dict[str, Type[_TableBase]] = {} | ||
| self._schema_registry: Dict[str, Type[BaseModel]] = {} | ||
| self._force_refresh = force_refresh | ||
|
|
@@ -179,7 +194,8 @@ def _schema_table_desc(model: Type[BaseModel]) -> str: | |
|
|
||
| @once_wrapper | ||
| def _lazy_init(self): | ||
| self._sql_manager = self._init_sql_manager(self._db_config) if self._db_config else None | ||
| if self._sql_manager is None: | ||
| self._sql_manager = self._init_sql_manager(self._db_config) if self._db_config else None | ||
| if self._sql_manager: | ||
| self._ensure_management_tables() | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,7 +12,8 @@ | |
| from .doc_manager import DocManager | ||
| from .doc_impl import DocImpl, StorePlaceholder, EmbedPlaceholder, BuiltinGroups, DocumentProcessor, NodeGroupType | ||
| from .doc_node import DocNode | ||
| from .doc_to_db import DocInfoSchema, DocToDbProcessor, extract_db_schema_from_files, SchemaExtractor | ||
| from .doc_to_db import SchemaExtractor | ||
| from lazyllm.tools.rag.doc_to_db.model import SchemaSetInfo, Table_ALGO_KB_SCHEMA | ||
| from .store import LAZY_ROOT_NAME, EMBED_DEFAULT_KEY | ||
| from .store.store_base import DEFAULT_KB_ID | ||
| from .index_base import IndexBase | ||
|
|
@@ -175,7 +176,6 @@ def __init__(self, dataset_path: Optional[str] = None, embed: Optional[Union[Cal | |
| display_name=display_name, description=description, | ||
| schema_extractor=self._schema_extractor) | ||
| self._curr_group = name | ||
| self._doc_to_db_processor: DocToDbProcessor = None | ||
| self._graph_document: weakref.ref = None | ||
|
|
||
| @staticmethod | ||
|
|
@@ -230,60 +230,70 @@ def url(self): | |
| def connect_sql_manager( | ||
| self, | ||
| sql_manager: SqlManager, | ||
| schma: Optional[DocInfoSchema] = None, | ||
| schma: Optional[BaseModel] = None, #basemodel | ||
| force_refresh: bool = True, | ||
| ): | ||
| def format_schema_to_dict(schema: DocInfoSchema): | ||
| if schema is None: | ||
| return None, None | ||
| desc_dict = {ele['key']: ele['desc'] for ele in schema} | ||
| type_dict = {ele['key']: ele['type'] for ele in schema} | ||
| return desc_dict, type_dict | ||
|
|
||
| def compare_schema(old_schema: DocInfoSchema, new_schema: DocInfoSchema): | ||
| old_desc_dict, old_type_dict = format_schema_to_dict(old_schema) | ||
| new_desc_dict, new_type_dict = format_schema_to_dict(new_schema) | ||
| return old_desc_dict == new_desc_dict and old_type_dict == new_type_dict | ||
|
|
||
| def compare_schema(rows: Union[List, object, None], schma: BaseModel, extractor: SchemaExtractor): | ||
| if schma is None: | ||
| return False | ||
| sid = extractor.register_schema_set(schma) | ||
|
|
||
| if not rows: | ||
| return False | ||
| if not isinstance(rows, (list, tuple, set)): | ||
| rows = [rows] | ||
|
|
||
| return any(getattr(row, 'schema_set_id', row) == sid for row in rows) | ||
|
|
||
| # 1. Check valid arguments | ||
| if sql_manager.check_connection().status != DBStatus.SUCCESS: | ||
| raise RuntimeError(f'Failed to connect to sql manager: {sql_manager._gen_conn_url()}') | ||
| pre_doc_table_schema = None | ||
| if self._doc_to_db_processor: | ||
| pre_doc_table_schema = self._doc_to_db_processor.doc_info_schema | ||
| assert pre_doc_table_schema or schma, 'doc_table_schma must be given' | ||
|
|
||
| schema_equal = compare_schema(pre_doc_table_schema, schma) | ||
|
|
||
| rows: List = [] | ||
| if self._schema_extractor: | ||
| mgr = self._schema_extractor.sql_manager | ||
| Bind = mgr.get_table_orm_class(Table_ALGO_KB_SCHEMA['name']) | ||
| with mgr.get_session() as s: | ||
| rows = s.query(Bind).filter_by(algo_id=self._impl._algo_name).all() | ||
| # algoid + kbid (self._impl.algo_name) | ||
|
|
||
| assert rows or schma, 'doc_table_schma must be given' | ||
|
|
||
| extractor = self._schema_extractor or SchemaExtractor(sql_manager) | ||
|
||
| schema_equal = compare_schema(rows, schma, extractor) | ||
| assert ( | ||
| schema_equal or force_refresh is True | ||
| ), 'When changing doc_table_schema, force_refresh should be set to True' | ||
|
|
||
| # 2. Init handler if needed | ||
| need_init_processor = False | ||
| if self._doc_to_db_processor is None: | ||
| if self._schema_extractor is None: | ||
| need_init_processor = True | ||
| else: | ||
| # avoid reinit for the same db | ||
| if sql_manager != self._doc_to_db_processor.sql_manager: | ||
| if sql_manager != self._schema_extractor.sql_manager: | ||
| need_init_processor = True | ||
| if need_init_processor: | ||
| self._doc_to_db_processor = DocToDbProcessor(sql_manager) | ||
| # reuse the extractor instance used for schema comparison/registration | ||
| self._schema_extractor = extractor | ||
|
|
||
| # 3. Reset doc_table_schema if needed | ||
| if schma and not schema_equal: | ||
| # This api call will clear existing db table 'lazyllm_doc_elements' | ||
| self._doc_to_db_processor._reset_doc_info_schema(schma) | ||
| schema_set_id = self._schema_extractor.register_schema_set(schma) | ||
| self._schema_extractor.register_schema_set_to_kb( | ||
| algo_id = self._impl._algo_name, schema_set_id= schema_set_id, force_refresh=True) | ||
|
|
||
| def get_sql_manager(self): | ||
| if self._doc_to_db_processor is None: | ||
| if self._schema_extractor is None: | ||
| raise ValueError('Please call connect_sql_manager to init handler first') | ||
| return self._doc_to_db_processor.sql_manager | ||
| return self._schema_extractor.sql_manager | ||
|
|
||
| def extract_db_schema( | ||
| self, llm: Union[OnlineChatModule, TrainableModule], print_schema: bool = False | ||
| ) -> DocInfoSchema: | ||
| file_paths = self._list_all_files_in_dataset() | ||
| schema = extract_db_schema_from_files(file_paths, llm) | ||
| ) -> SchemaSetInfo: | ||
|
Comment on lines
305
to
+307
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| schema = self._forward('_analyze_schema_by_llm') | ||
| if print_schema: | ||
| lazyllm.LOG.info(f'Extracted Schema:\n\t{schema}\n') | ||
| return schema | ||
|
|
||
This file was deleted.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
llm参数现在是可选的,默认值为None,但是这里的检查if not isinstance(llm, LLMBase):会在llm为None时引发TypeError。这会阻止在没有 LLM 的情况下初始化SchemaExtractor,而这似乎是本次重构的一个预期用例。这个检查应该被更新以处理None的情况。