|
5 | 5 | import json |
6 | 6 | from abc import ABC, abstractmethod |
7 | 7 | from typing import Any, Dict |
| 8 | + |
| 9 | +from loguru import logger |
8 | 10 | import pandas as pd |
9 | 11 | from sqlalchemy import inspect, text, create_engine |
10 | 12 | from agentscope.message import Msg |
@@ -59,7 +61,7 @@ async def generate_profile(self) -> Dict[str, Any]: |
59 | 61 | res = await self._call_model(content) |
60 | 62 | self.profile = self._wrap_data_response(res) |
61 | 63 | except Exception as e: |
62 | | - print(f"Error generating profile: {e}") |
| 64 | + logger.warning(f"Error generating profile: {e}") |
63 | 65 | self.profile = {} |
64 | 66 | return self.profile |
65 | 67 |
|
@@ -102,9 +104,6 @@ def _load_prompt_and_model(source_type: Any = None, api_key: str = None): |
102 | 104 | "IRREGULAR": MODEL_CONFIG_NAME, |
103 | 105 | } |
104 | 106 |
|
105 | | - if not api_key: |
106 | | - api_key = os.environ.get("DASHSCOPE_API_KEY") |
107 | | - |
108 | 107 | models_2_model_and_formatter = { |
109 | 108 | MODEL_CONFIG_NAME: [ |
110 | 109 | DashScopeChatModel( |
@@ -360,16 +359,19 @@ def _wrap_data_response(self, response: Dict[str, Any]) -> Dict[str, Any]: |
360 | 359 | # they contain tables. Each table contains columns and description |
361 | 360 | if "tables" in self.data and "tables" in response: |
362 | 361 | new_schema["tables"] = [] |
363 | | - for i, table in enumerate(self.data["tables"]): |
364 | | - # Ensure alignment between schema tables and resp tables |
365 | | - # TODO: It matches by order, by name would be more robust. |
366 | | - if i >= len(response["tables"]): |
367 | | - # LLM returns less tables than the original schema |
368 | | - break |
369 | | - assert response["tables"][i]["name"] == table["name"] |
| 362 | + # Build a map for response tables and descriptions |
| 363 | + res_des_map = { |
| 364 | + table["name"]: table["description"] |
| 365 | + for table in response["tables"] |
| 366 | + } |
| 367 | + for table in self.data["tables"]: |
| 368 | + table_name = table["name"] |
| 369 | + if table_name not in res_des_map: |
| 370 | + continue |
370 | 371 | new_table = {} |
371 | | - new_table["name"] = table["name"] |
372 | | - new_table["description"] = response["tables"][i]["description"] |
| 372 | + new_table["name"] = table_name |
| 373 | + # Retain the description from the LLM response |
| 374 | + new_table["description"] = res_des_map[table_name] |
373 | 375 | if "columns" in table: |
374 | 376 | new_table["columns"] = table["columns"] |
375 | 377 | if "irregular_judgment" in table: |
@@ -412,7 +414,7 @@ async def _extract_irregular_table( |
412 | 414 | ) |
413 | 415 |
|
414 | 416 | if "is_extractable_table" in res and res["is_extractable_table"]: |
415 | | - print(res["reasoning"]) |
| 417 | + logger.debug(res["reasoning"]) |
416 | 418 | skiprows = res["row_start_index"] + 1 |
417 | 419 | cols_range = res["col_ranges"] |
418 | 420 | df = pd.read_excel( |
@@ -533,7 +535,7 @@ async def _read_data(self): |
533 | 535 | try: |
534 | 536 | connection = engine.connect() |
535 | 537 | except Exception as e: |
536 | | - print(f"Connection to {self.path} failed: {e}") |
| 538 | + logger.error(f"Connection to {self.path} failed: {e}") |
537 | 539 | raise ConnectionError(f"Failed to connect to database: {e}") from e |
538 | 540 |
|
539 | 541 | # Use DSN as the db identifier (can be parsed more cleanly) |
@@ -581,7 +583,9 @@ async def _read_data(self): |
581 | 583 | lines.append(", ".join(row_values)) |
582 | 584 | raw_data_snippet = "\n".join(lines) |
583 | 585 | except Exception as e: |
584 | | - print(f"Error fetching {table_name} data: {str(e)}") |
| 586 | + logger.warning( |
| 587 | + f"Error fetching {table_name} data: {str(e)}", |
| 588 | + ) |
585 | 589 | raw_data_snippet = None |
586 | 590 | # 4. detailed column info (types and samples) |
587 | 591 | column_details = [] |
@@ -618,8 +622,8 @@ async def _read_data(self): |
618 | 622 |
|
619 | 623 | except Exception as e: |
620 | 624 | # If one table fails, log it and continue to the next |
621 | | - print(f"Error processing {table_name}: {str(e)}") |
622 | | - return {} |
| 625 | + logger.warning(f"Error processing {table_name}: {str(e)}") |
| 626 | + continue |
623 | 627 | # Construct the final schema |
624 | 628 | schema = { |
625 | 629 | "name": database_name, |
@@ -687,15 +691,20 @@ def _generate_content(self, prompt, data): |
687 | 691 | Returns: |
688 | 692 | List containing image and text components for the LLM call |
689 | 693 | """ |
690 | | - contents = [] |
691 | 694 | # Convert image paths according to the model requirements |
692 | | - contents.append( |
| 695 | + contents = [ |
693 | 696 | { |
694 | | - "image": data, |
| 697 | + "text": prompt, |
| 698 | + "type": "text", |
695 | 699 | }, |
696 | | - ) |
697 | | - # append text |
698 | | - contents.append({"text": prompt}) |
| 700 | + { |
| 701 | + "source": { |
| 702 | + "url": data, |
| 703 | + "type": "url", |
| 704 | + }, |
| 705 | + "type": "image", |
| 706 | + }, |
| 707 | + ] |
699 | 708 | return contents |
700 | 709 |
|
701 | 710 | def _wrap_data_response(self, response: Dict[str, Any]) -> Dict[str, Any]: |
|
0 commit comments