Skip to content

Commit 635c1be

Browse files
committed
style(data profile): format the LLMCallManager
1 parent 35d5a8f commit 635c1be

File tree

6 files changed

+77
-81
lines changed

6 files changed

+77
-81
lines changed

alias/src/alias/agent/agents/data_source/_data_profiler_factory.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,14 @@
99
from loguru import logger
1010
import pandas as pd
1111
from sqlalchemy import inspect, text, create_engine
12+
from agentscope.message import Msg
13+
1214
from alias.agent.agents.data_source._typing import SourceType
1315
from alias.agent.agents.ds_agent_utils import (
1416
get_prompt_from_file,
1517
)
16-
from alias.agent.utils.unified_model_call_interface import (
17-
UnifiedModelCallInterface,
18+
from alias.agent.utils.llm_call_manager import (
19+
LLMCallManager,
1820
)
1921

2022

@@ -32,7 +34,7 @@ def __init__(
3234
self,
3335
path: str,
3436
source_type: SourceType,
35-
model_interface: UnifiedModelCallInterface,
37+
llm_call_manager: LLMCallManager,
3638
):
3739
"""Initialize the data profiler with API key, data path and type.
3840
@@ -44,7 +46,7 @@ def __init__(
4446
self.path = path
4547
self.file_name = os.path.basename(path)
4648
self.source_type = source_type
47-
self.model_interface = model_interface
49+
self.llm_call_manager = llm_call_manager
4850

4951
self.source_types_2_prompts = {
5052
SourceType.CSV: "_profile_csv_prompt.md",
@@ -57,8 +59,8 @@ def __init__(
5759
raise ValueError(f"Unsupported source type: {source_type}")
5860
self.prompt = self._load_prompt(source_type)
5961

60-
base_model_name = self.model_interface.get_base_model_name()
61-
vl_model_name = self.model_interface.get_vl_model_name()
62+
base_model_name = self.llm_call_manager.get_base_model_name()
63+
vl_model_name = self.llm_call_manager.get_vl_model_name()
6264

6365
self.source_types_2_models = {
6466
SourceType.CSV: base_model_name,
@@ -159,9 +161,14 @@ async def _call_model(
159161
self,
160162
content: Any,
161163
) -> Dict[str, Any]:
162-
response = await self.model_interface.unified_model_call_interface(
164+
sys_prompt = "You are a helpful AI assistant for database management."
165+
msgs = [
166+
Msg("system", sys_prompt, "system"),
167+
Msg("user", content, "user"),
168+
]
169+
response = await self.llm_call_manager(
163170
model_name=self.model_name,
164-
user_content=content,
171+
messages=msgs,
165172
)
166173
response = BaseDataProfiler.tool_clean_json(response)
167174
return response
@@ -654,7 +661,7 @@ class DataProfilerFactory:
654661

655662
@staticmethod
656663
def get_profiler(
657-
model_interface: UnifiedModelCallInterface,
664+
llm_call_manager: LLMCallManager,
658665
path: str,
659666
source_type: SourceType,
660667
) -> BaseDataProfiler:
@@ -676,25 +683,25 @@ def get_profiler(
676683
return ImageProfiler(
677684
path=path,
678685
source_type=source_type,
679-
model_interface=model_interface,
686+
llm_call_manager=llm_call_manager,
680687
)
681688
elif source_type == SourceType.CSV:
682689
return CsvProfiler(
683690
path=path,
684691
source_type=source_type,
685-
model_interface=model_interface,
692+
llm_call_manager=llm_call_manager,
686693
)
687694
elif source_type == SourceType.EXCEL:
688695
return ExcelProfiler(
689696
path=path,
690697
source_type=source_type,
691-
model_interface=model_interface,
698+
llm_call_manager=llm_call_manager,
692699
)
693700
elif source_type == SourceType.RELATIONAL_DB:
694701
return RelationalDatabaseProfiler(
695702
path=path,
696703
source_type=source_type,
697-
model_interface=model_interface,
704+
llm_call_manager=llm_call_manager,
698705
)
699706
else:
700707
raise ValueError(f"Unsupported source type: {source_type}")

alias/src/alias/agent/agents/data_source/data_profile.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
get_workspace_file,
1616
)
1717
from alias.runtime.alias_sandbox.alias_sandbox import AliasSandbox
18-
from alias.agent.utils.unified_model_call_interface import (
19-
UnifiedModelCallInterface,
18+
from alias.agent.utils.llm_call_manager import (
19+
LLMCallManager,
2020
)
2121

2222

@@ -74,7 +74,7 @@ async def data_profile(
7474
sandbox: AliasSandbox,
7575
sandbox_path: str,
7676
source_type: SourceType,
77-
model_interface: UnifiedModelCallInterface,
77+
llm_call_manager: LLMCallManager,
7878
) -> Dict[str, Any]:
7979
"""
8080
Generates a detailed profile and summary for data source using LLMs.
@@ -104,7 +104,7 @@ async def data_profile(
104104
raise ValueError(f"Unsupported source type {source_type}")
105105

106106
profiler = DataProfilerFactory.get_profiler(
107-
model_interface=model_interface,
107+
llm_call_manager=llm_call_manager,
108108
path=local_path,
109109
source_type=source_type,
110110
)

alias/src/alias/agent/agents/data_source/data_source.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@
2424
from alias.agent.tools.sandbox_util import (
2525
copy_local_file_to_workspace,
2626
)
27-
from alias.agent.utils.unified_model_call_interface import (
28-
UnifiedModelCallInterface,
27+
from alias.agent.utils.llm_call_manager import (
28+
LLMCallManager,
2929
)
3030

3131

@@ -186,16 +186,16 @@ def get_coarse_desc(self):
186186
async def prepare_profile(
187187
self,
188188
sandbox: Sandbox,
189-
model_interface: UnifiedModelCallInterface,
189+
llm_call_manager: LLMCallManager,
190190
) -> Optional[Dict[str, Any]]:
191191
"""Run type-specific profiling."""
192-
if model_interface and not self.profile:
192+
if llm_call_manager and not self.profile:
193193
try:
194194
self.profile = await data_profile(
195195
sandbox=sandbox,
196196
sandbox_path=self.source_access,
197197
source_type=self.source_type,
198-
model_interface=model_interface,
198+
llm_call_manager=llm_call_manager,
199199
)
200200
logger.info(
201201
"Profiling successfully: "
@@ -252,7 +252,7 @@ class DataSourceManager:
252252
def __init__(
253253
self,
254254
sandbox: Sandbox,
255-
model_interface: UnifiedModelCallInterface,
255+
llm_call_manager: LLMCallManager,
256256
):
257257
"""Initialize an empty data source manager."""
258258
self._data_sources: Dict[str, DataSource] = {}
@@ -265,7 +265,7 @@ def __init__(
265265

266266
self.toolkit = AliasToolkit(sandbox=sandbox)
267267

268-
self.model_interface = model_interface
268+
self.llm_call_manager = llm_call_manager
269269

270270
def add_data_source(
271271
self,
@@ -338,7 +338,7 @@ async def prepare_data_sources(self) -> None:
338338
await data_source.prepare(self.toolkit)
339339
await data_source.prepare_profile(
340340
self.toolkit.sandbox,
341-
self.model_interface,
341+
self.llm_call_manager,
342342
)
343343

344344
def _generate_name(self, endpoint: str) -> str:

alias/src/alias/agent/run.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@
4343
init_ds_toolkit,
4444
)
4545

46-
from alias.agent.utils.unified_model_call_interface import (
47-
UnifiedModelCallInterface,
46+
from alias.agent.utils.llm_call_manager import (
47+
LLMCallManager,
4848
)
4949

5050
MODEL_FORMATTER_MAPPING = {
@@ -116,15 +116,15 @@ async def arun_meta_planner(
116116
ds_toolkit = init_ds_toolkit(worker_full_toolkit)
117117

118118
# Initialize data source manager
119-
model_interface = UnifiedModelCallInterface(
119+
llm_call_manager = LLMCallManager(
120120
base_model_name=MODEL_CONFIG_NAME,
121121
vl_model_name=VL_MODEL_NAME,
122122
model_formatter_mapping=MODEL_FORMATTER_MAPPING,
123123
)
124124
data_manager = await prepare_data_sources(
125125
session_service=session_service,
126126
sandbox=sandbox,
127-
model_interface=model_interface,
127+
llm_call_manager=llm_call_manager,
128128
)
129129
add_data_source_tools(
130130
data_manager,
@@ -362,7 +362,7 @@ async def arun_datascience_agent(
362362

363363
global_toolkit = AliasToolkit(sandbox, add_all=True)
364364
worker_toolkit = init_ds_toolkit(global_toolkit)
365-
model_interface = UnifiedModelCallInterface(
365+
llm_call_manager = LLMCallManager(
366366
base_model_name=MODEL_CONFIG_NAME,
367367
vl_model_name=VL_MODEL_NAME,
368368
model_formatter_mapping=MODEL_FORMATTER_MAPPING,
@@ -371,7 +371,7 @@ async def arun_datascience_agent(
371371
session_service=session_service,
372372
sandbox=sandbox,
373373
binded_toolkit=worker_toolkit,
374-
model_interface=model_interface,
374+
llm_call_manager=llm_call_manager,
375375
)
376376

377377
try:
Lines changed: 33 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
# -*- coding: utf-8 -*-
22
import asyncio
3-
from typing import Any, Dict, Literal
3+
from typing import Any, Dict, Literal, Type, AsyncGenerator
44
from agentscope.message import Msg
5-
from agentscope.model import DashScopeChatModel
6-
from agentscope.formatter import DashScopeChatFormatter
5+
from agentscope.model import DashScopeChatModel, ChatResponse
76

87
from tenacity import retry, stop_after_attempt, wait_fixed
8+
from pydantic import BaseModel
99

1010

1111
@retry(
@@ -14,28 +14,21 @@
1414
reraise=True,
1515
# before_sleep=_print_exc_on_retry
1616
)
17-
async def _model_call_with_retry(
17+
async def model_call_with_retry(
1818
model: DashScopeChatModel = None,
19-
formatter: DashScopeChatFormatter = None,
20-
sys_content: Any = None,
21-
user_content: Any = None,
19+
messages: list[dict[str, Any]] = None,
2220
tool_json_schemas: list[dict] | None = None,
2321
tool_choice: Literal["auto", "none", "required"] | str | None = None,
22+
structured_model: Type[BaseModel] | None = None,
2423
msg_name: str = "model_call",
25-
structured_model=None,
24+
**kwargs: Any,
2625
) -> Msg:
27-
msgs = [
28-
Msg("system", sys_content, "system"),
29-
Msg("user", user_content, "user"),
30-
]
31-
32-
format_msgs = await formatter.format(msgs=msgs)
33-
3426
res = await model(
35-
format_msgs,
27+
messages,
3628
tools=tool_json_schemas,
3729
tool_choice=tool_choice,
3830
structured_model=structured_model,
31+
kwargs=kwargs,
3932
)
4033

4134
if model.stream:
@@ -52,7 +45,7 @@ async def _model_call_with_retry(
5245
return msg
5346

5447

55-
class UnifiedModelCallInterface:
48+
class LLMCallManager:
5649
def __init__(
5750
self,
5851
base_model_name: str,
@@ -63,35 +56,31 @@ def __init__(
6356
self.vl_model_name = vl_model_name
6457
self.model_formatter_mapping = model_formatter_mapping
6558

66-
async def unified_model_call_interface(
67-
self,
68-
model_name: str = None,
69-
user_content: Any = None,
70-
sys_content: Any = None,
71-
) -> Msg:
72-
model, formatter = self._load_model_and_formatter(
73-
model_name=model_name,
74-
)
75-
if sys_content is None:
76-
sys_content = (
77-
"You are a helpful AI assistant for database management."
78-
)
79-
80-
raw_response = await _model_call_with_retry(
81-
model=model,
82-
formatter=formatter,
83-
sys_content=sys_content,
84-
user_content=user_content,
85-
)
86-
response = raw_response.content[0]["text"]
87-
return response
88-
89-
def _load_model_and_formatter(self, model_name: str):
90-
model, formatter = self.model_formatter_mapping[model_name]
91-
return model, formatter
92-
9359
def get_base_model_name(self) -> str:
    """Return the configured name of the base (text) chat model."""
    return self.base_model_name
9561

9662
def get_vl_model_name(self) -> str:
    """Return the configured name of the vision-language (VL) model."""
    return self.vl_model_name
64+
65+
async def __call__(
    self,
    model_name: str,
    messages: list[dict[str, Any]],
    tools: list[dict] | None = None,
    tool_choice: Literal["auto", "none", "required"] | str | None = None,
    structured_model: Type[BaseModel] | None = None,
    **kwargs: Any,
) -> str:
    """Format ``messages`` for the named model, call it with retry, and
    return the text of the first content block of the response.

    Args:
        model_name: Key into ``self.model_formatter_mapping`` selecting
            the (model, formatter) pair to use.
        messages: Chat messages to send (``Msg`` objects, per callers —
            the declared ``list[dict]`` type looks stale; TODO confirm).
        tools: Optional JSON schemas of tools the model may invoke.
        tool_choice: Tool-choice policy forwarded to the model.
        structured_model: Optional pydantic model for structured output.
        **kwargs: Extra keyword arguments forwarded to the model call.

    Returns:
        The ``"text"`` field of the first content item of the model's
        response.  (Return annotation fixed: the original declared
        ``ChatResponse | AsyncGenerator[...]`` but a ``str`` is returned.)

    Raises:
        KeyError: If ``model_name`` is not in ``model_formatter_mapping``.
    """
    model, formatter = self.model_formatter_mapping[model_name]
    format_msgs = await formatter.format(msgs=messages)
    raw_response = await model_call_with_retry(
        model=model,
        messages=format_msgs,
        tool_json_schemas=tools,
        tool_choice=tool_choice,
        structured_model=structured_model,
        msg_name="model_call",
        # BUG FIX: extras must be unpacked; the original ``kwargs=kwargs``
        # sent a single keyword argument literally named "kwargs" instead
        # of forwarding the caller's options.
        **kwargs,
    )
    # raw_response is a Msg; extract the text of its first content item.
    response = raw_response.content[0]["text"]
    return response

alias/src/alias/agent/utils/prepare_data_source.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55

66
from alias.agent.agents.data_source.data_source import DataSourceManager
77
from alias.agent.tools import AliasToolkit, share_tools
8-
from alias.agent.utils.unified_model_call_interface import (
9-
UnifiedModelCallInterface,
8+
from alias.agent.utils.llm_call_manager import (
9+
LLMCallManager,
1010
)
1111

1212
if os.getenv("TEST_MODE") not in ["local", "runtime-test"]:
@@ -21,12 +21,12 @@ async def prepare_data_sources(
2121
session_service: SessionService,
2222
sandbox: Sandbox,
2323
binded_toolkit: AliasToolkit = None,
24-
model_interface: UnifiedModelCallInterface = None,
24+
llm_call_manager: LLMCallManager = None,
2525
):
2626
data_manager = await build_data_manager(
2727
session_service,
2828
sandbox,
29-
model_interface,
29+
llm_call_manager,
3030
)
3131
if len(data_manager):
3232
await add_user_data_message(session_service, data_manager)
@@ -40,9 +40,9 @@ async def prepare_data_sources(
4040
async def build_data_manager(
4141
session_service: SessionService,
4242
sandbox: Sandbox,
43-
model_interface: UnifiedModelCallInterface,
43+
llm_call_manager: LLMCallManager,
4444
):
45-
data_manager = DataSourceManager(sandbox, model_interface)
45+
data_manager = DataSourceManager(sandbox, llm_call_manager)
4646
if (
4747
hasattr(session_service.session_entity, "data_config")
4848
and session_service.session_entity.data_config

0 commit comments

Comments
 (0)