Skip to content

Commit b545577

Browse files
committed
fix(data profile): refine the image content; remove assert; remove api_key; use logger
1 parent 6f825fb commit b545577

File tree

2 files changed

+42
-27
lines changed

2 files changed

+42
-27
lines changed

alias/src/alias/agent/agents/data_source/_data_profiler_factory.py

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import json
66
from abc import ABC, abstractmethod
77
from typing import Any, Dict
8+
9+
from loguru import logger
810
import pandas as pd
911
from sqlalchemy import inspect, text, create_engine
1012
from agentscope.message import Msg
@@ -59,7 +61,7 @@ async def generate_profile(self) -> Dict[str, Any]:
5961
res = await self._call_model(content)
6062
self.profile = self._wrap_data_response(res)
6163
except Exception as e:
62-
print(f"Error generating profile: {e}")
64+
logger.warning(f"Error generating profile: {e}")
6365
self.profile = {}
6466
return self.profile
6567

@@ -102,9 +104,6 @@ def _load_prompt_and_model(source_type: Any = None, api_key: str = None):
102104
"IRREGULAR": MODEL_CONFIG_NAME,
103105
}
104106

105-
if not api_key:
106-
api_key = os.environ.get("DASHSCOPE_API_KEY")
107-
108107
models_2_model_and_formatter = {
109108
MODEL_CONFIG_NAME: [
110109
DashScopeChatModel(
@@ -360,16 +359,19 @@ def _wrap_data_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
360359
# they contain tables. Each table contains columns and description
361360
if "tables" in self.data and "tables" in response:
362361
new_schema["tables"] = []
363-
for i, table in enumerate(self.data["tables"]):
364-
# Ensure alignment between schema tables and resp tables
365-
# TODO: It matches by order, by name would be more robust.
366-
if i >= len(response["tables"]):
367-
# LLM returns less tables than the original schema
368-
break
369-
assert response["tables"][i]["name"] == table["name"]
362+
# Build a map for response tables and descriptions
363+
res_des_map = {
364+
table["name"]: table["description"]
365+
for table in response["tables"]
366+
}
367+
for table in self.data["tables"]:
368+
table_name = table["name"]
369+
if table_name not in res_des_map:
370+
continue
370371
new_table = {}
371-
new_table["name"] = table["name"]
372-
new_table["description"] = response["tables"][i]["description"]
372+
new_table["name"] = table_name
373+
# Retain the description from the LLM response
374+
new_table["description"] = res_des_map[table_name]
373375
if "columns" in table:
374376
new_table["columns"] = table["columns"]
375377
if "irregular_judgment" in table:
@@ -412,7 +414,7 @@ async def _extract_irregular_table(
412414
)
413415

414416
if "is_extractable_table" in res and res["is_extractable_table"]:
415-
print(res["reasoning"])
417+
logger.debug(res["reasoning"])
416418
skiprows = res["row_start_index"] + 1
417419
cols_range = res["col_ranges"]
418420
df = pd.read_excel(
@@ -533,7 +535,7 @@ async def _read_data(self):
533535
try:
534536
connection = engine.connect()
535537
except Exception as e:
536-
print(f"Connection to {self.path} failed: {e}")
538+
logger.error(f"Connection to {self.path} failed: {e}")
537539
raise ConnectionError(f"Failed to connect to database: {e}") from e
538540

539541
# Use DSN as the db identifier (can be parsed more cleanly)
@@ -581,7 +583,9 @@ async def _read_data(self):
581583
lines.append(", ".join(row_values))
582584
raw_data_snippet = "\n".join(lines)
583585
except Exception as e:
584-
print(f"Error fetching {table_name} data: {str(e)}")
586+
logger.warning(
587+
f"Error fetching {table_name} data: {str(e)}",
588+
)
585589
raw_data_snippet = None
586590
# 4. detailed column info (types and samples)
587591
column_details = []
@@ -618,8 +622,8 @@ async def _read_data(self):
618622

619623
except Exception as e:
620624
# If one table fails, log it and continue to the next
621-
print(f"Error processing {table_name}: {str(e)}")
622-
return {}
625+
logger.warning(f"Error processing {table_name}: {str(e)}")
626+
continue
623627
# Construct the final schema
624628
schema = {
625629
"name": database_name,
@@ -687,15 +691,20 @@ def _generate_content(self, prompt, data):
687691
Returns:
688692
List containing image and text components for the LLM call
689693
"""
690-
contents = []
691694
# Convert image paths according to the model requirements
692-
contents.append(
695+
contents = [
693696
{
694-
"image": data,
697+
"text": prompt,
698+
"type": "text",
695699
},
696-
)
697-
# append text
698-
contents.append({"text": prompt})
700+
{
701+
"source": {
702+
"url": data,
703+
"type": "url",
704+
},
705+
"type": "image",
706+
},
707+
]
699708
return contents
700709

701710
def _wrap_data_response(self, response: Dict[str, Any]) -> Dict[str, Any]:

alias/src/alias/agent/agents/data_source/data_profile.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,10 @@ def _get_binary_buffer(
3232
return buffer
3333

3434

35-
def _copy_file_from_sandbox(sandbox: AliasSandbox, file_path: str) -> str:
35+
def _copy_file_from_sandbox_with_original_name(
36+
sandbox: AliasSandbox,
37+
file_path: str,
38+
) -> str:
3639
"""
3740
Copies a file from the sandbox environment
3841
or a URL to a local temporary file.
@@ -61,7 +64,7 @@ def _copy_file_from_sandbox(sandbox: AliasSandbox, file_path: str) -> str:
6164
with open(full_path, "wb") as f:
6265
f.write(file_buffer.getvalue())
6366
file_source = full_path
64-
return file_source
67+
return str(file_source)
6568

6669

6770
async def data_profile(
@@ -87,7 +90,10 @@ async def data_profile(
8790
"""
8891

8992
if source_type in [SourceType.CSV, SourceType.EXCEL, SourceType.IMAGE]:
90-
local_path = _copy_file_from_sandbox(sandbox, sandbox_path)
93+
local_path = _copy_file_from_sandbox_with_original_name(
94+
sandbox,
95+
sandbox_path,
96+
)
9197
elif source_type == SourceType.RELATIONAL_DB:
9298
local_path = sandbox_path
9399
else:

0 commit comments

Comments
 (0)