99from loguru import logger
1010import pandas as pd
1111from sqlalchemy import inspect , text , create_engine
12- from agentscope .message import Msg
13- from agentscope .model import DashScopeChatModel
14- from agentscope .formatter import DashScopeChatFormatter
15-
1612from alias .agent .agents .data_source ._typing import SourceType
17- from alias .agent .agents .ds_agent_utils .ds_config import (
18- MODEL_CONFIG_NAME ,
19- VL_MODEL_NAME ,
20- )
2113from alias .agent .agents .ds_agent_utils import (
2214 get_prompt_from_file ,
23- model_call_with_retry ,
15+ )
16+ from alias .agent .utils .unified_model_call_interface import (
17+ UnifiedModelCallInterface ,
2418)
2519
2620
@@ -29,112 +23,84 @@ class BaseDataProfiler(ABC):
2923 sources like csv, excel, db, etc.
3024 """
3125
32- def __init__ (self , api_key : str , path : str , source_type : SourceType ):
26+ _PROFILE_PROMPT_BASE_PATH = os .path .join (
27+ os .path .dirname (__file__ ),
28+ "built_in_prompt" ,
29+ )
30+
31+ def __init__ (
32+ self ,
33+ path : str ,
34+ source_type : SourceType ,
35+ model_interface : UnifiedModelCallInterface ,
36+ ):
3337 """Initialize the data profiler with API key, data path and type.
3438
3539 Args:
3741 model_interface: Unified interface used to issue LLM calls
3741 path: Path to the data source file or connection string
3842 source_type: Enum indicating the type of data source
3943 """
40- self .api_key = api_key
4144 self .path = path
45+ self .file_name = os .path .basename (path )
4246 self .source_type = source_type
43- (
44- self .prompt ,
45- self .model ,
46- self .formatter ,
47- ) = BaseDataProfiler ._load_prompt_and_model (source_type )
48-
49- async def generate_profile (self ) -> Dict [str , Any ]:
50- """Generate a complete data profile
51- by reading data, generating content,
52- calling the LLM, and wrapping the response.
47+ self .model_interface = model_interface
5348
54- Returns:
55- Dictionary containing the complete data profile
56- """
57- try :
58- self .data = await self ._read_data ()
59- content = self ._generate_content (self .prompt , self .data )
60- # content = self.prompt.format(data=self.data)
61- res = await self ._call_model (content )
62- self .profile = self ._wrap_data_response (res )
63- except Exception as e :
64- logger .warning (f"Error generating profile: { e } " )
65- self .profile = {}
66- return self .profile
67-
68- @staticmethod
69- def _load_prompt_and_model (source_type : Any = None , api_key : str = None ):
70- """Load the appropriate prompt template, model name and LLM API
71- class
72- based on the source type.
73-
74- Args:
75- source_type: Type of data source (CSV, EXCEL, IMAGE, etc.)
76-
77- Returns:
78- Tuple of (prompt_template, model, formatter)
79-
80- Raises:
81- ValueError: If source_type is unsupported
82- """
83- PROFILE_PROMPT_BASE_PATH = os .path .join (
84- os .path .dirname (__file__ ),
85- "built_in_prompt" ,
86- )
87- source_types_2_prompts = {
49+ self .source_types_2_prompts = {
8850 SourceType .CSV : "_profile_csv_prompt.md" ,
8951 SourceType .EXCEL : "_profile_xlsx_prompt.md" ,
9052 SourceType .IMAGE : "_profile_image_prompt.md" ,
9153 SourceType .RELATIONAL_DB : "_profile_relationdb_prompt.md" ,
9254 "IRREGULAR" : "_profile_irregular_xlsx_prompt.md" ,
9355 }
94-
95- # For irregular excel files, load the prompt
96- if source_type not in source_types_2_prompts :
56+ if source_type not in self .source_types_2_prompts :
9757 raise ValueError (f"Unsupported source type: { source_type } " )
58+ self .prompt = self ._load_prompt (source_type )
9859
99- source_types_2_models = {
100- SourceType .CSV : MODEL_CONFIG_NAME ,
101- SourceType .EXCEL : MODEL_CONFIG_NAME ,
102- SourceType .IMAGE : VL_MODEL_NAME ,
103- SourceType .RELATIONAL_DB : MODEL_CONFIG_NAME ,
104- "IRREGULAR" : MODEL_CONFIG_NAME ,
105- }
60+ base_model_name = self .model_interface .get_base_model_name ()
61+ vl_model_name = self .model_interface .get_vl_model_name ()
10662
107- models_2_model_and_formatter = {
108- MODEL_CONFIG_NAME : [
109- DashScopeChatModel (
110- api_key = api_key ,
111- model_name = "qwen3-max-preview" ,
112- stream = True ,
113- ),
114- DashScopeChatFormatter (),
115- ],
116- VL_MODEL_NAME : [
117- DashScopeChatModel (
118- api_key = api_key ,
119- model_name = "qwen-vl-max-latest" ,
120- stream = True ,
121- ),
122- DashScopeChatFormatter (),
123- ],
63+ self .source_types_2_models = {
64+ SourceType .CSV : base_model_name ,
65+ SourceType .EXCEL : base_model_name ,
66+ SourceType .IMAGE : vl_model_name ,
67+ SourceType .RELATIONAL_DB : base_model_name ,
12468 }
69+ self .model_name = self .source_types_2_models [source_type ]
12570
126- prompt_file_name = source_types_2_prompts [source_type ]
71+ def _load_prompt (self , source_type : Any = None ):
72+ prompt_file_name = self .source_types_2_prompts [source_type ]
12773 prompt = get_prompt_from_file (
12874 os .path .join (
129- PROFILE_PROMPT_BASE_PATH ,
75+ self . _PROFILE_PROMPT_BASE_PATH ,
13076 prompt_file_name ,
13177 ),
13278 False ,
13379 )
80+ return prompt
81+
82+ async def generate_profile (self ) -> Dict [str , Any ]:
83+ """Generate a complete data profile
84+ by reading data, generating content,
85+ calling the LLM, and wrapping the response.
13486
135- model_name = source_types_2_models [source_type ]
136- model , formatter = models_2_model_and_formatter [model_name ]
137- return prompt , model , formatter
87+ Returns:
88+ Dictionary containing the complete data profile
89+ """
90+ try :
91+ self .data = await self ._read_data ()
92+ # different source types have different data building methods
93+ content = self ._build_content_with_prompt_and_data (
94+ self .prompt ,
95+ self .data ,
96+ )
97+ # content = self.prompt.format(data=self.data)
98+ res = await self ._call_model (content )
99+ self .profile = self ._wrap_data_response (res )
100+ except Exception as e :
101+ logger .warning (f"Error generating profile: { e } " )
102+ self .profile = {}
103+ return self .profile
138104
139105 @staticmethod
140106 def tool_clean_json (raw_response : str ):
@@ -157,8 +123,12 @@ def tool_clean_json(raw_response: str):
157123 return json .loads (cleaned_response )
158124
159125 @abstractmethod
160- def _generate_content (self , prompt : str , data : Any ) -> str :
161- """Abstract method to generate content for LLM based on prompt
126+ def _build_content_with_prompt_and_data (
127+ self ,
128+ prompt : str ,
129+ data : Any ,
130+ ) -> str :
131+ """Abstract method to build content for LLM based on prompt
162132 and data.
163133
164134 This method should be implemented by subclasses to format
@@ -188,46 +158,11 @@ async def _read_data(self):
188158 async def _call_model (
189159 self ,
190160 content : Any ,
191- model : DashScopeChatModel = None ,
192- formatter : DashScopeChatFormatter = None ,
193161 ) -> Dict [str , Any ]:
194- """Uses LLM to generate profile based on the content.
195- Makes multiple attempts to call the LLM service,
196- with retry logic in case of failures.
197- Handles both regular text and multimodal inputs.
198-
199- Args:
200- content: Content to send to the LLM (text or multimodal)
201- model: ChatModel to use
202- formatter: ChatFormatter to use for LLM input
203-
204- Returns:
205- Dictionary response parsed from LLM output
206-
207- Raises:
208- Exception: If all retry attempts fail
209- """
210- system_prompt = (
211- "You are a helpful AI assistant for database management."
162+ response = await self .model_interface .unified_model_call_interface (
163+ model_name = self .model_name ,
164+ user_content = content ,
212165 )
213- msgs = [
214- Msg ("system" , system_prompt , "system" ),
215- Msg ("user" , content , "user" ),
216- ]
217- if model is None :
218- raw_response = await model_call_with_retry (
219- model = self .model ,
220- formatter = self .formatter ,
221- msgs = msgs ,
222- )
223- else :
224- raw_response = await model_call_with_retry (
225- model = model ,
226- formatter = formatter ,
227- msgs = msgs ,
228- )
229- response = raw_response .content [0 ]["text" ]
230- # Clean and parse the JSON response from the LLM
231166 response = BaseDataProfiler .tool_clean_json (response )
232167 return response
233168
@@ -325,7 +260,11 @@ def _extract_schema_from_table(df: pd.DataFrame, df_name: str) -> dict:
325260 }
326261 return table_schema
327262
328- def _generate_content (self , prompt : str , data : Any ) -> str :
263+ def _build_content_with_prompt_and_data (
264+ self ,
265+ prompt : str ,
266+ data : Any ,
267+ ) -> str :
329268 """Format the prompt with data for structured data sources.
330269
331270 Args:
@@ -383,10 +322,6 @@ def _wrap_data_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
383322
384323
385324class ExcelProfiler (StructuredDataProfiler ):
386- def __init__ (self , api_key : str , path : str , source_type : SourceType ):
387- super ().__init__ (api_key , path , source_type )
388- self .file_name = os .path .basename (self .path )
389-
390325 async def _extract_irregular_table (
391326 self ,
392327 path : str ,
@@ -405,13 +340,9 @@ async def _extract_irregular_table(
405340 Returns:
406341 Schema dictionary for the irregular table structure
407342 """
408- prompt , model , formatter = self ._load_prompt_and_model ("IRREGULAR" )
343+ prompt = self ._load_prompt ("IRREGULAR" )
409344 content = prompt .format (raw_snippet_data = raw_data_snippet )
410- res = await self ._call_model (
411- content = content ,
412- model = model ,
413- formatter = formatter ,
414- )
345+ res = await self ._call_model (content = content )
415346
416347 if "is_extractable_table" in res and res ["is_extractable_table" ]:
417348 logger .debug (res ["reasoning" ])
@@ -634,10 +565,6 @@ async def _read_data(self):
634565
635566
636567class CsvProfiler (ExcelProfiler ):
637- def __init__ (self , api_key : str , path : str , source_type : SourceType ):
638- super ().__init__ (api_key , path , source_type )
639- self .file_name = os .path .basename (path )
640-
641568 async def _read_data (self ):
642569 """Handles schema extraction for CSV as single-table sources.
643570
@@ -664,10 +591,6 @@ async def _read_data(self):
664591class ImageProfiler (BaseDataProfiler ):
665592 """Profiler for image data sources that uses multimodal LLMs."""
666593
667- def __init__ (self , api_key : str , path : str , source_type : SourceType ):
668- super ().__init__ (api_key , path , source_type )
669- self .file_name = os .path .basename (self .path )
670-
671594 async def _read_data (self ):
672595 """
673596 For images, this simply returns the path since the LLM API
@@ -678,8 +601,8 @@ async def _read_data(self):
678601 """
679602 return self .path
680603
681- def _generate_content (self , prompt , data ):
682- """Generate multimodal content for image analysis.
604+ def _build_content_with_prompt_and_data (self , prompt , data ):
605+ """Build multimodal content for image analysis.
683606
684607 Creates content in the format required by multimodal LLM APIs
685608 with both image and text components.
@@ -731,7 +654,7 @@ class DataProfilerFactory:
731654
732655 @staticmethod
733656 def get_profiler (
734- api_key : str ,
657+ model_interface : UnifiedModelCallInterface ,
735658 path : str ,
736659 source_type : SourceType ,
737660 ) -> BaseDataProfiler :
@@ -750,12 +673,28 @@ def get_profiler(
750673 ValueError: If the source_type is unsupported
751674 """
752675 if source_type == SourceType .IMAGE :
753- return ImageProfiler (api_key , path , source_type )
676+ return ImageProfiler (
677+ path = path ,
678+ source_type = source_type ,
679+ model_interface = model_interface ,
680+ )
754681 elif source_type == SourceType .CSV :
755- return CsvProfiler (api_key , path , source_type )
682+ return CsvProfiler (
683+ path = path ,
684+ source_type = source_type ,
685+ model_interface = model_interface ,
686+ )
756687 elif source_type == SourceType .EXCEL :
757- return ExcelProfiler (api_key , path , source_type )
688+ return ExcelProfiler (
689+ path = path ,
690+ source_type = source_type ,
691+ model_interface = model_interface ,
692+ )
758693 elif source_type == SourceType .RELATIONAL_DB :
759- return RelationalDatabaseProfiler (api_key , path , source_type )
694+ return RelationalDatabaseProfiler (
695+ path = path ,
696+ source_type = source_type ,
697+ model_interface = model_interface ,
698+ )
760699 else :
761700 raise ValueError (f"Unsupported source type: { source_type } " )