Skip to content

Commit 35d5a8f

Browse files
committed
fix(data profile): add unified model interface (initialized in run.py) for data profiling
1 parent b6ff393 commit 35d5a8f

File tree

6 files changed

+253
-178
lines changed

6 files changed

+253
-178
lines changed

alias/src/alias/agent/agents/data_source/_data_profiler_factory.py

Lines changed: 92 additions & 153 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,12 @@
99
from loguru import logger
1010
import pandas as pd
1111
from sqlalchemy import inspect, text, create_engine
12-
from agentscope.message import Msg
13-
from agentscope.model import DashScopeChatModel
14-
from agentscope.formatter import DashScopeChatFormatter
15-
1612
from alias.agent.agents.data_source._typing import SourceType
17-
from alias.agent.agents.ds_agent_utils.ds_config import (
18-
MODEL_CONFIG_NAME,
19-
VL_MODEL_NAME,
20-
)
2113
from alias.agent.agents.ds_agent_utils import (
2214
get_prompt_from_file,
23-
model_call_with_retry,
15+
)
16+
from alias.agent.utils.unified_model_call_interface import (
17+
UnifiedModelCallInterface,
2418
)
2519

2620

@@ -29,112 +23,84 @@ class BaseDataProfiler(ABC):
2923
sources like csv, excel, db, etc.
3024
"""
3125

32-
def __init__(self, api_key: str, path: str, source_type: SourceType):
26+
_PROFILE_PROMPT_BASE_PATH = os.path.join(
27+
os.path.dirname(__file__),
28+
"built_in_prompt",
29+
)
30+
31+
def __init__(
32+
self,
33+
path: str,
34+
source_type: SourceType,
35+
model_interface: UnifiedModelCallInterface,
36+
):
3337
"""Initialize the data profiler with API key, data path and type.
3438
3539
Args:
3640
api_key: Authentication key for LLM service
3741
path: Path to the data source file or connection string
3842
source_type: Enum indicating the type of data source
3943
"""
40-
self.api_key = api_key
4144
self.path = path
45+
self.file_name = os.path.basename(path)
4246
self.source_type = source_type
43-
(
44-
self.prompt,
45-
self.model,
46-
self.formatter,
47-
) = BaseDataProfiler._load_prompt_and_model(source_type)
48-
49-
async def generate_profile(self) -> Dict[str, Any]:
50-
"""Generate a complete data profile
51-
by reading data, generating content,
52-
calling the LLM, and wrapping the response.
47+
self.model_interface = model_interface
5348

54-
Returns:
55-
Dictionary containing the complete data profile
56-
"""
57-
try:
58-
self.data = await self._read_data()
59-
content = self._generate_content(self.prompt, self.data)
60-
# content = self.prompt.format(data=self.data)
61-
res = await self._call_model(content)
62-
self.profile = self._wrap_data_response(res)
63-
except Exception as e:
64-
logger.warning(f"Error generating profile: {e}")
65-
self.profile = {}
66-
return self.profile
67-
68-
@staticmethod
69-
def _load_prompt_and_model(source_type: Any = None, api_key: str = None):
70-
"""Load the appropriate prompt template, model name and LLM API
71-
class
72-
based on the source type.
73-
74-
Args:
75-
source_type: Type of data source (CSV, EXCEL, IMAGE, etc.)
76-
77-
Returns:
78-
Tuple of (prompt_template, model, formatter)
79-
80-
Raises:
81-
ValueError: If source_type is unsupported
82-
"""
83-
PROFILE_PROMPT_BASE_PATH = os.path.join(
84-
os.path.dirname(__file__),
85-
"built_in_prompt",
86-
)
87-
source_types_2_prompts = {
49+
self.source_types_2_prompts = {
8850
SourceType.CSV: "_profile_csv_prompt.md",
8951
SourceType.EXCEL: "_profile_xlsx_prompt.md",
9052
SourceType.IMAGE: "_profile_image_prompt.md",
9153
SourceType.RELATIONAL_DB: "_profile_relationdb_prompt.md",
9254
"IRREGULAR": "_profile_irregular_xlsx_prompt.md",
9355
}
94-
95-
# For irregular excel files, load the prompt
96-
if source_type not in source_types_2_prompts:
56+
if source_type not in self.source_types_2_prompts:
9757
raise ValueError(f"Unsupported source type: {source_type}")
58+
self.prompt = self._load_prompt(source_type)
9859

99-
source_types_2_models = {
100-
SourceType.CSV: MODEL_CONFIG_NAME,
101-
SourceType.EXCEL: MODEL_CONFIG_NAME,
102-
SourceType.IMAGE: VL_MODEL_NAME,
103-
SourceType.RELATIONAL_DB: MODEL_CONFIG_NAME,
104-
"IRREGULAR": MODEL_CONFIG_NAME,
105-
}
60+
base_model_name = self.model_interface.get_base_model_name()
61+
vl_model_name = self.model_interface.get_vl_model_name()
10662

107-
models_2_model_and_formatter = {
108-
MODEL_CONFIG_NAME: [
109-
DashScopeChatModel(
110-
api_key=api_key,
111-
model_name="qwen3-max-preview",
112-
stream=True,
113-
),
114-
DashScopeChatFormatter(),
115-
],
116-
VL_MODEL_NAME: [
117-
DashScopeChatModel(
118-
api_key=api_key,
119-
model_name="qwen-vl-max-latest",
120-
stream=True,
121-
),
122-
DashScopeChatFormatter(),
123-
],
63+
self.source_types_2_models = {
64+
SourceType.CSV: base_model_name,
65+
SourceType.EXCEL: base_model_name,
66+
SourceType.IMAGE: vl_model_name,
67+
SourceType.RELATIONAL_DB: base_model_name,
12468
}
69+
self.model_name = self.source_types_2_models[source_type]
12570

126-
prompt_file_name = source_types_2_prompts[source_type]
71+
def _load_prompt(self, source_type: Any = None):
72+
prompt_file_name = self.source_types_2_prompts[source_type]
12773
prompt = get_prompt_from_file(
12874
os.path.join(
129-
PROFILE_PROMPT_BASE_PATH,
75+
self._PROFILE_PROMPT_BASE_PATH,
13076
prompt_file_name,
13177
),
13278
False,
13379
)
80+
return prompt
81+
82+
async def generate_profile(self) -> Dict[str, Any]:
83+
"""Generate a complete data profile
84+
by reading data, generating content,
85+
calling the LLM, and wrapping the response.
13486
135-
model_name = source_types_2_models[source_type]
136-
model, formatter = models_2_model_and_formatter[model_name]
137-
return prompt, model, formatter
87+
Returns:
88+
Dictionary containing the complete data profile
89+
"""
90+
try:
91+
self.data = await self._read_data()
92+
# different source types have different data building methods
93+
content = self._build_content_with_prompt_and_data(
94+
self.prompt,
95+
self.data,
96+
)
97+
# content = self.prompt.format(data=self.data)
98+
res = await self._call_model(content)
99+
self.profile = self._wrap_data_response(res)
100+
except Exception as e:
101+
logger.warning(f"Error generating profile: {e}")
102+
self.profile = {}
103+
return self.profile
138104

139105
@staticmethod
140106
def tool_clean_json(raw_response: str):
@@ -157,8 +123,12 @@ def tool_clean_json(raw_response: str):
157123
return json.loads(cleaned_response)
158124

159125
@abstractmethod
160-
def _generate_content(self, prompt: str, data: Any) -> str:
161-
"""Abstract method to generate content for LLM based on prompt
126+
def _build_content_with_prompt_and_data(
127+
self,
128+
prompt: str,
129+
data: Any,
130+
) -> str:
131+
"""Abstract method to build content for LLM based on prompt
162132
and data.
163133
164134
This method should be implemented by subclasses to format
@@ -188,46 +158,11 @@ async def _read_data(self):
188158
async def _call_model(
189159
self,
190160
content: Any,
191-
model: DashScopeChatModel = None,
192-
formatter: DashScopeChatFormatter = None,
193161
) -> Dict[str, Any]:
194-
"""Uses LLM to generate profile based on the content.
195-
Makes multiple attempts to call the LLM service,
196-
with retry logic in case of failures.
197-
Handles both regular text and multimodal inputs.
198-
199-
Args:
200-
content: Content to send to the LLM (text or multimodal)
201-
model: ChatModel to use
202-
formatter: ChatFormatter to use for LLM input
203-
204-
Returns:
205-
Dictionary response parsed from LLM output
206-
207-
Raises:
208-
Exception: If all retry attempts fail
209-
"""
210-
system_prompt = (
211-
"You are a helpful AI assistant for database management."
162+
response = await self.model_interface.unified_model_call_interface(
163+
model_name=self.model_name,
164+
user_content=content,
212165
)
213-
msgs = [
214-
Msg("system", system_prompt, "system"),
215-
Msg("user", content, "user"),
216-
]
217-
if model is None:
218-
raw_response = await model_call_with_retry(
219-
model=self.model,
220-
formatter=self.formatter,
221-
msgs=msgs,
222-
)
223-
else:
224-
raw_response = await model_call_with_retry(
225-
model=model,
226-
formatter=formatter,
227-
msgs=msgs,
228-
)
229-
response = raw_response.content[0]["text"]
230-
# Clean and parse the JSON response from the LLM
231166
response = BaseDataProfiler.tool_clean_json(response)
232167
return response
233168

@@ -325,7 +260,11 @@ def _extract_schema_from_table(df: pd.DataFrame, df_name: str) -> dict:
325260
}
326261
return table_schema
327262

328-
def _generate_content(self, prompt: str, data: Any) -> str:
263+
def _build_content_with_prompt_and_data(
264+
self,
265+
prompt: str,
266+
data: Any,
267+
) -> str:
329268
"""Format the prompt with data for structured data sources.
330269
331270
Args:
@@ -383,10 +322,6 @@ def _wrap_data_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
383322

384323

385324
class ExcelProfiler(StructuredDataProfiler):
386-
def __init__(self, api_key: str, path: str, source_type: SourceType):
387-
super().__init__(api_key, path, source_type)
388-
self.file_name = os.path.basename(self.path)
389-
390325
async def _extract_irregular_table(
391326
self,
392327
path: str,
@@ -405,13 +340,9 @@ async def _extract_irregular_table(
405340
Returns:
406341
Schema dictionary for the irregular table structure
407342
"""
408-
prompt, model, formatter = self._load_prompt_and_model("IRREGULAR")
343+
prompt = self._load_prompt("IRREGULAR")
409344
content = prompt.format(raw_snippet_data=raw_data_snippet)
410-
res = await self._call_model(
411-
content=content,
412-
model=model,
413-
formatter=formatter,
414-
)
345+
res = await self._call_model(content=content)
415346

416347
if "is_extractable_table" in res and res["is_extractable_table"]:
417348
logger.debug(res["reasoning"])
@@ -634,10 +565,6 @@ async def _read_data(self):
634565

635566

636567
class CsvProfiler(ExcelProfiler):
637-
def __init__(self, api_key: str, path: str, source_type: SourceType):
638-
super().__init__(api_key, path, source_type)
639-
self.file_name = os.path.basename(path)
640-
641568
async def _read_data(self):
642569
"""Handles schema extraction for CSV as single-table sources.
643570
@@ -664,10 +591,6 @@ async def _read_data(self):
664591
class ImageProfiler(BaseDataProfiler):
665592
"""Profiler for image data sources that uses multimodal LLMs."""
666593

667-
def __init__(self, api_key: str, path: str, source_type: SourceType):
668-
super().__init__(api_key, path, source_type)
669-
self.file_name = os.path.basename(self.path)
670-
671594
async def _read_data(self):
672595
"""
673596
For images, this simply returns the path since the LLM API
@@ -678,8 +601,8 @@ async def _read_data(self):
678601
"""
679602
return self.path
680603

681-
def _generate_content(self, prompt, data):
682-
"""Generate multimodal content for image analysis.
604+
def _build_content_with_prompt_and_data(self, prompt, data):
605+
"""build multimodal content for image analysis.
683606
684607
Creates content in the format required by multimodal LLM APIs
685608
with both image and text components.
@@ -731,7 +654,7 @@ class DataProfilerFactory:
731654

732655
@staticmethod
733656
def get_profiler(
734-
api_key: str,
657+
model_interface: UnifiedModelCallInterface,
735658
path: str,
736659
source_type: SourceType,
737660
) -> BaseDataProfiler:
@@ -750,12 +673,28 @@ def get_profiler(
750673
ValueError: If the source_type is unsupported
751674
"""
752675
if source_type == SourceType.IMAGE:
753-
return ImageProfiler(api_key, path, source_type)
676+
return ImageProfiler(
677+
path=path,
678+
source_type=source_type,
679+
model_interface=model_interface,
680+
)
754681
elif source_type == SourceType.CSV:
755-
return CsvProfiler(api_key, path, source_type)
682+
return CsvProfiler(
683+
path=path,
684+
source_type=source_type,
685+
model_interface=model_interface,
686+
)
756687
elif source_type == SourceType.EXCEL:
757-
return ExcelProfiler(api_key, path, source_type)
688+
return ExcelProfiler(
689+
path=path,
690+
source_type=source_type,
691+
model_interface=model_interface,
692+
)
758693
elif source_type == SourceType.RELATIONAL_DB:
759-
return RelationalDatabaseProfiler(api_key, path, source_type)
694+
return RelationalDatabaseProfiler(
695+
path=path,
696+
source_type=source_type,
697+
model_interface=model_interface,
698+
)
760699
else:
761700
raise ValueError(f"Unsupported source type: {source_type}")

alias/src/alias/agent/agents/data_source/data_profile.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
get_workspace_file,
1616
)
1717
from alias.runtime.alias_sandbox.alias_sandbox import AliasSandbox
18+
from alias.agent.utils.unified_model_call_interface import (
19+
UnifiedModelCallInterface,
20+
)
1821

1922

2023
def _get_binary_buffer(
@@ -71,6 +74,7 @@ async def data_profile(
7174
sandbox: AliasSandbox,
7275
sandbox_path: str,
7376
source_type: SourceType,
77+
model_interface: UnifiedModelCallInterface,
7478
) -> Dict[str, Any]:
7579
"""
7680
Generates a detailed profile and summary for data source using LLMs.
@@ -97,12 +101,10 @@ async def data_profile(
97101
elif source_type == SourceType.RELATIONAL_DB:
98102
local_path = sandbox_path
99103
else:
100-
raise ValueError(f"Unsupported source type: {source_type}")
101-
102-
dashscope_api_key = os.getenv("DASHSCOPE_API_KEY", "")
104+
raise ValueError(f"Unsupported source type {source_type}")
103105

104106
profiler = DataProfilerFactory.get_profiler(
105-
api_key=dashscope_api_key,
107+
model_interface=model_interface,
106108
path=local_path,
107109
source_type=source_type,
108110
)

0 commit comments

Comments
 (0)