-
Notifications
You must be signed in to change notification settings - Fork 364
refactor: replace DocToDbProcessor with SchemaExtractor #1026
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
df6cf9f
79d5882
ad14e6d
14323da
7a26eeb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -65,15 +65,30 @@ class SchemaExtractor: | |
| 'map': dict, | ||
| } | ||
|
|
||
| def __init__(self, db_config: Dict[str, Any], llm: LLMBase, *, table_prefix: Optional[str] = None, | ||
| force_refresh: bool = False, extraction_mode: ExtractionMode = ExtractionMode.TEXT, | ||
| max_len: int = ONE_DOC_LENGTH_LIMIT, num_workers: int = 4): | ||
| if not isinstance(llm, LLMBase): | ||
| def __init__(self, db_config: Optional[Dict[str, Any]] = None, llm: LLMBase = None, *, | ||
| table_prefix: Optional[str] = None, force_refresh: bool = False, | ||
| extraction_mode: ExtractionMode = ExtractionMode.TEXT, max_len: int = ONE_DOC_LENGTH_LIMIT, | ||
| num_workers: int = 4, sql_manager: Optional[SqlManager] = None): | ||
| if (db_config is None) == (sql_manager is None): | ||
| raise ValueError('Exactly one of db_config or sql_manager must be provided') | ||
| if llm is None or not isinstance(llm, LLMBase): | ||
| raise TypeError('llm must be an instance of LLMBase') | ||
| self._llm = llm | ||
| self._table_prefix = table_prefix or self.TABLE_PREFIX | ||
| self._sql_manager = None | ||
| self._db_config = db_config | ||
| if sql_manager is not None: | ||
| self._sql_manager = sql_manager | ||
| self._db_config = { | ||
| 'db_type': sql_manager._db_type, | ||
| 'user': getattr(sql_manager, '_user', None), | ||
| 'password': getattr(sql_manager, '_password', None), | ||
| 'host': getattr(sql_manager, '_host', None), | ||
| 'port': getattr(sql_manager, '_port', None), | ||
| 'db_name': getattr(sql_manager, '_db_name', None), | ||
| 'options_str': getattr(sql_manager, '_options_str', None), | ||
| } | ||
|
Comment on lines
+81
to
+89
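There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Accessing private attributes of `sql_manager` (for example `_db_type` and `_user`) breaks encapsulation and tightly couples `SchemaExtractor` to the internal implementation of `SqlManager`, which makes future changes to `SqlManager` difficult and error-prone. Consider adding a public method to `SqlManager` (for example `get_config()`) that exposes its configuration details. |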
||
| else: | ||
| self._db_config = db_config | ||
| self._table_cache: Dict[str, Type[_TableBase]] = {} | ||
| self._schema_registry: Dict[str, Type[BaseModel]] = {} | ||
| self._force_refresh = force_refresh | ||
|
|
@@ -179,7 +194,8 @@ def _schema_table_desc(model: Type[BaseModel]) -> str: | |
|
|
||
| @once_wrapper | ||
| def _lazy_init(self): | ||
| self._sql_manager = self._init_sql_manager(self._db_config) if self._db_config else None | ||
| if self._sql_manager is None: | ||
| self._sql_manager = self._init_sql_manager(self._db_config) if self._db_config else None | ||
| if self._sql_manager: | ||
| self._ensure_management_tables() | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -12,11 +12,12 @@ | |||||
| from .doc_manager import DocManager | ||||||
| from .doc_impl import DocImpl, StorePlaceholder, EmbedPlaceholder, BuiltinGroups, DocumentProcessor, NodeGroupType | ||||||
| from .doc_node import DocNode | ||||||
| from .doc_to_db import DocInfoSchema, DocToDbProcessor, extract_db_schema_from_files, SchemaExtractor | ||||||
| from .doc_to_db import SchemaExtractor | ||||||
| from lazyllm.tools.rag.doc_to_db.model import SchemaSetInfo, Table_ALGO_KB_SCHEMA | ||||||
| from .store import LAZY_ROOT_NAME, EMBED_DEFAULT_KEY | ||||||
| from .store.store_base import DEFAULT_KB_ID | ||||||
| from .index_base import IndexBase | ||||||
| from .utils import DocListManager, ensure_call_endpoint | ||||||
| from .utils import DocListManager, ensure_call_endpoint, _get_default_db_config | ||||||
| from .global_metadata import GlobalMetadataDesc as DocField | ||||||
| from .web import DocWebModule | ||||||
| import copy | ||||||
|
|
@@ -147,6 +148,17 @@ def __init__(self, dataset_path: Optional[str] = None, embed: Optional[Union[Cal | |||||
| 'Only map store is supported for Document with temp-files') | ||||||
|
|
||||||
| name = name or DocListManager.DEFAULT_GROUP_NAME | ||||||
| if schema_extractor is not None and not isinstance(schema_extractor, SchemaExtractor): | ||||||
| if isinstance(schema_extractor, LLMBase): | ||||||
| metadata_store_config = None | ||||||
| if isinstance(store_conf, dict): | ||||||
| metadata_store_config = store_conf.get('metadata_store') | ||||||
| metadata_store_config = metadata_store_config or _get_default_db_config( | ||||||
| db_name=f'{name}_metadata' | ||||||
| ) | ||||||
| schema_extractor = SchemaExtractor(db_config=metadata_store_config, llm=schema_extractor) | ||||||
| else: | ||||||
| raise ValueError(f'Invalid type for schema extractor: {type(schema_extractor)}') | ||||||
| self._schema_extractor: SchemaExtractor = schema_extractor | ||||||
|
|
||||||
| if isinstance(manager, Document._Manager): | ||||||
|
|
@@ -175,7 +187,6 @@ def __init__(self, dataset_path: Optional[str] = None, embed: Optional[Union[Cal | |||||
| display_name=display_name, description=description, | ||||||
| schema_extractor=self._schema_extractor) | ||||||
| self._curr_group = name | ||||||
| self._doc_to_db_processor: DocToDbProcessor = None | ||||||
| self._graph_document: weakref.ref = None | ||||||
|
|
||||||
| @staticmethod | ||||||
|
|
@@ -227,63 +238,75 @@ def url(self): | |||||
| assert isinstance(self._manager._kbs, ServerModule), 'Document is not a service, please set `manager` to `True`' | ||||||
| return self._manager._kbs._url | ||||||
|
|
||||||
| def _resolve_schema_extractor(self, sql_manager: SqlManager) -> SchemaExtractor: | ||||||
| if self._schema_extractor is None: | ||||||
| raise ValueError('schema_extractor is required to connect sql manager') | ||||||
| if not isinstance(self._schema_extractor, SchemaExtractor): | ||||||
| raise ValueError(f'Invalid type for schema extractor: {type(self._schema_extractor)}') | ||||||
| if sql_manager == self._schema_extractor.sql_manager: | ||||||
| return self._schema_extractor | ||||||
| return SchemaExtractor(sql_manager=sql_manager, llm=self._schema_extractor._llm) | ||||||
|
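There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Accessing the private attribute `_llm` of `SchemaExtractor` from outside the class breaks encapsulation. For example, you could add a public `llm` property to `SchemaExtractor` and use it here instead.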
Suggested change
|
||||||
|
|
||||||
| @staticmethod | ||||||
| def _compare_schema_rows(rows: Union[List, object, None], schma: BaseModel, | ||||||
| extractor: SchemaExtractor) -> bool: | ||||||
| if schma is None: | ||||||
| return False | ||||||
| sid = extractor.register_schema_set(schma) | ||||||
| if not rows: | ||||||
| return False | ||||||
| if not isinstance(rows, (list, tuple, set)): | ||||||
| rows = [rows] | ||||||
| return any(getattr(row, 'schema_set_id', row) == sid for row in rows) | ||||||
|
|
||||||
| def _get_schema_bind_rows(self, extractor: SchemaExtractor) -> List: | ||||||
| mgr = extractor.sql_manager | ||||||
| bind_cls = mgr.get_table_orm_class(Table_ALGO_KB_SCHEMA['name']) | ||||||
| if bind_cls is None: | ||||||
| return [] | ||||||
| with mgr.get_session() as s: | ||||||
| return s.query(bind_cls).filter_by(algo_id=self._impl.algo_name).all() | ||||||
|
|
||||||
| def connect_sql_manager( | ||||||
| self, | ||||||
| sql_manager: SqlManager, | ||||||
| schma: Optional[DocInfoSchema] = None, | ||||||
| self, sql_manager: SqlManager, schma: Optional[BaseModel] = None, | ||||||
| force_refresh: bool = True, | ||||||
| ): | ||||||
| def format_schema_to_dict(schema: DocInfoSchema): | ||||||
| if schema is None: | ||||||
| return None, None | ||||||
| desc_dict = {ele['key']: ele['desc'] for ele in schema} | ||||||
| type_dict = {ele['key']: ele['type'] for ele in schema} | ||||||
| return desc_dict, type_dict | ||||||
|
|
||||||
| def compare_schema(old_schema: DocInfoSchema, new_schema: DocInfoSchema): | ||||||
| old_desc_dict, old_type_dict = format_schema_to_dict(old_schema) | ||||||
| new_desc_dict, new_type_dict = format_schema_to_dict(new_schema) | ||||||
| return old_desc_dict == new_desc_dict and old_type_dict == new_type_dict | ||||||
|
|
||||||
| # 1. Check valid arguments | ||||||
| if sql_manager.check_connection().status != DBStatus.SUCCESS: | ||||||
| raise RuntimeError(f'Failed to connect to sql manager: {sql_manager._gen_conn_url()}') | ||||||
| pre_doc_table_schema = None | ||||||
| if self._doc_to_db_processor: | ||||||
| pre_doc_table_schema = self._doc_to_db_processor.doc_info_schema | ||||||
| assert pre_doc_table_schema or schma, 'doc_table_schma must be given' | ||||||
|
|
||||||
| schema_equal = compare_schema(pre_doc_table_schema, schma) | ||||||
| extractor = self._resolve_schema_extractor(sql_manager) | ||||||
| rows = self._get_schema_bind_rows(extractor) | ||||||
| assert rows or schma, 'doc_table_schma must be given' | ||||||
|
|
||||||
| schema_equal = self._compare_schema_rows(rows, schma, extractor) | ||||||
| assert ( | ||||||
| schema_equal or force_refresh is True | ||||||
| ), 'When changing doc_table_schema, force_refresh should be set to True' | ||||||
|
|
||||||
| # 2. Init handler if needed | ||||||
| need_init_processor = False | ||||||
| if self._doc_to_db_processor is None: | ||||||
| need_init_processor = True | ||||||
| else: | ||||||
| # avoid reinit for the same db | ||||||
| if sql_manager != self._doc_to_db_processor.sql_manager: | ||||||
| need_init_processor = True | ||||||
| if need_init_processor: | ||||||
| self._doc_to_db_processor = DocToDbProcessor(sql_manager) | ||||||
| if extractor is not self._schema_extractor: | ||||||
| # reuse the extractor instance used for schema comparison/registration | ||||||
| self._schema_extractor = extractor | ||||||
| self._impl._schema_extractor = extractor | ||||||
|
|
||||||
| # 3. Reset doc_table_schema if needed | ||||||
| if schma and not schema_equal: | ||||||
| # This api call will clear existing db table 'lazyllm_doc_elements' | ||||||
| self._doc_to_db_processor._reset_doc_info_schema(schma) | ||||||
| schema_set_id = self._schema_extractor.register_schema_set(schma) | ||||||
| self._schema_extractor.register_schema_set_to_kb( | ||||||
| algo_id=self._impl._algo_name, schema_set_id=schema_set_id, force_refresh=True) | ||||||
|
|
||||||
| def get_sql_manager(self): | ||||||
| if self._doc_to_db_processor is None: | ||||||
| if self._schema_extractor is None: | ||||||
| raise ValueError('Please call connect_sql_manager to init handler first') | ||||||
| return self._doc_to_db_processor.sql_manager | ||||||
| return self._schema_extractor.sql_manager | ||||||
|
|
||||||
| def extract_db_schema( | ||||||
| self, llm: Union[OnlineChatModule, TrainableModule], print_schema: bool = False | ||||||
| ) -> DocInfoSchema: | ||||||
| file_paths = self._list_all_files_in_dataset() | ||||||
| schema = extract_db_schema_from_files(file_paths, llm) | ||||||
| ) -> SchemaSetInfo: | ||||||
|
Comment on lines
305
to
+307
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||||||
|
|
||||||
| schema = self._forward('_analyze_schema_by_llm') | ||||||
| if print_schema: | ||||||
| lazyllm.LOG.info(f'Extracted Schema:\n\t{schema}\n') | ||||||
| return schema | ||||||
|
|
||||||
This file was deleted.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
直接访问
sql_manager的私有成员(例如_db_type,_user)破坏了封装性,并使SchemaExtractor与SqlManager的内部实现紧密耦合。这将使得未来对SqlManager的修改变得困难且容易出错。更好的做法是让SqlManager提供一个公共方法(例如get_config())来暴露其配置细节。