Skip to content

Commit 9f0d0bf

Browse files
committed
fix(data profile): refine the image content; remove assert; remove api_key
1 parent 6f825fb commit 9f0d0bf

File tree

2 files changed

+32
-21
lines changed

2 files changed

+32
-21
lines changed

alias/src/alias/agent/agents/data_source/_data_profiler_factory.py

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,6 @@ def _load_prompt_and_model(source_type: Any = None, api_key: str = None):
102102
"IRREGULAR": MODEL_CONFIG_NAME,
103103
}
104104

105-
if not api_key:
106-
api_key = os.environ.get("DASHSCOPE_API_KEY")
107-
108105
models_2_model_and_formatter = {
109106
MODEL_CONFIG_NAME: [
110107
DashScopeChatModel(
@@ -360,16 +357,19 @@ def _wrap_data_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
360357
# they contain tables. Each table contains columns and description
361358
if "tables" in self.data and "tables" in response:
362359
new_schema["tables"] = []
363-
for i, table in enumerate(self.data["tables"]):
364-
# Ensure alignment between schema tables and resp tables
365-
# TODO: It matches by order, by name would be more robust.
366-
if i >= len(response["tables"]):
367-
# LLM returns less tables than the original schema
368-
break
369-
assert response["tables"][i]["name"] == table["name"]
360+
# Build a map for response tables and descriptions
361+
res_des_map = {
362+
table["name"]: table["description"]
363+
for table in response["tables"]
364+
}
365+
for table in self.data["tables"]:
366+
table_name = table["name"]
367+
if table_name not in res_des_map:
368+
continue
370369
new_table = {}
371-
new_table["name"] = table["name"]
372-
new_table["description"] = response["tables"][i]["description"]
370+
new_table["name"] = table_name
371+
# Retain the description from the LLM response
372+
new_table["description"] = res_des_map[table_name]
373373
if "columns" in table:
374374
new_table["columns"] = table["columns"]
375375
if "irregular_judgment" in table:
@@ -687,15 +687,20 @@ def _generate_content(self, prompt, data):
687687
Returns:
688688
List containing image and text components for the LLM call
689689
"""
690-
contents = []
691690
# Convert image paths according to the model requirements
692-
contents.append(
691+
contents = [
693692
{
694-
"image": data,
693+
"text": prompt,
694+
"type": "text",
695695
},
696-
)
697-
# append text
698-
contents.append({"text": prompt})
696+
{
697+
"source": {
698+
"url": data,
699+
"type": "url",
700+
},
701+
"type": "image",
702+
},
703+
]
699704
return contents
700705

701706
def _wrap_data_response(self, response: Dict[str, Any]) -> Dict[str, Any]:

alias/src/alias/agent/agents/data_source/data_profile.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,10 @@ def _get_binary_buffer(
3232
return buffer
3333

3434

35-
def _copy_file_from_sandbox(sandbox: AliasSandbox, file_path: str) -> str:
35+
def _copy_file_from_sandbox_with_original_name(
36+
sandbox: AliasSandbox,
37+
file_path: str,
38+
) -> str:
3639
"""
3740
Copies a file from the sandbox environment
3841
or a URL to a local temporary file.
@@ -61,7 +64,7 @@ def _copy_file_from_sandbox(sandbox: AliasSandbox, file_path: str) -> str:
6164
with open(full_path, "wb") as f:
6265
f.write(file_buffer.getvalue())
6366
file_source = full_path
64-
return file_source
67+
return str(file_source)
6568

6669

6770
async def data_profile(
@@ -87,7 +90,10 @@ async def data_profile(
8790
"""
8891

8992
if source_type in [SourceType.CSV, SourceType.EXCEL, SourceType.IMAGE]:
90-
local_path = _copy_file_from_sandbox(sandbox, sandbox_path)
93+
local_path = _copy_file_from_sandbox_with_original_name(
94+
sandbox,
95+
sandbox_path,
96+
)
9197
elif source_type == SourceType.RELATIONAL_DB:
9298
local_path = sandbox_path
9399
else:

0 commit comments

Comments
 (0)