|
| 1 | +import sys |
| 2 | +import asyncio |
| 3 | +import json |
| 4 | +import os |
# Strip four path components from this file's absolute path and make the
# resulting directory importable. Presumably this is the repository root that
# the absolute ``app.*`` imports further down rely on -- confirm against the
# project layout. The path is no longer printed at import time (that was
# debug output on stdout) and is not appended again if already present.
_project_root = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
)
if _project_root not in sys.path:
    sys.path.append(_project_root)
| 7 | + |
| 8 | +from typing import Any, Hashable |
| 9 | + |
| 10 | +import pandas as pd |
| 11 | +from pydantic import Field, model_validator |
| 12 | + |
| 13 | +from app.config import config |
| 14 | +from app.llm import LLM |
| 15 | +from app.logger import logger |
| 16 | +from app.tool.base import BaseTool |
| 17 | + |
| 18 | + |
class AddInsights(BaseTool):
    """Tool that adds insight annotations to already-rendered charts.

    The tool reads a JSON file describing charts and the insight ids to
    apply, runs the ``src/chartVisualize.ts`` script (via ``npx ts-node``)
    once per chart, and reports which charts were updated and which failed.
    """

    name: str = "add_insights"
    description: str = (
        "Enhances charts by adding insights markers and annotations "
        "using JSON data generated by the insights_selection tool. "
        "This creates the final annotated visualization output."
    )

    parameters: dict = {
        "type": "object",
        "properties": {
            "json_path": {
                "type": "string",
                "description": """Path to the JSON file generated by insights_selection tool.
Contains chart insights data in format:
{
    "chartPath": string,
    "insights_id": number[]
}""",
            },
            "output_type": {
                "type": "string",
                "description": "Visualization output format selection",
                "default": "html",
                "enum": [
                    "png",  # Static image format
                    "html",  # Interactive web format (recommended)
                ],
            },
        },
        "required": ["json_path"],
    }
    llm: LLM = Field(default_factory=LLM, description="Language model instance")

    @model_validator(mode="after")
    def initialize_llm(self):
        """Initialize llm with default settings if not provided."""
        if self.llm is None or not isinstance(self.llm, LLM):
            self.llm = LLM(config_name=self.name.lower())
        return self

    def load_chart_with_css(self, chart_path):
        """Rewrite the HTML file at ``chart_path`` with full-viewport CSS.

        Backticks in the document are replaced with single quotes, then a
        ``<style>`` block is inserted right after the first ``<head>`` tag
        (or prepended when the document has no ``<head>``). The file is
        overwritten in place. Calling this again on the same file does not
        insert the style block a second time.

        Args:
            chart_path: Path of a UTF-8 encoded HTML file.
        """
        # Read the HTML file.
        with open(chart_path, "r", encoding="utf-8") as f:
            html_content = f.read()
        html_content = html_content.replace("`", "'")

        # CSS that removes page margins and stretches the chart container.
        css = """
        <style>
            body, html {
                margin: 0;
                padding: 0;
                height: 100%;
                overflow: hidden;
            }
            #chart-container {
                width: 100%;
                height: 100%;
            }
        </style>
        """

        # Skip the insertion when the block is already there so repeated
        # runs do not keep growing the file.
        if css not in html_content:
            if "<head>" in html_content:
                # Only the first <head> is a real document head.
                html_content = html_content.replace("<head>", "<head>" + css, 1)
            else:
                # No <head> in the file: put the CSS at the very beginning.
                html_content = css + html_content

        with open(chart_path, "w", encoding="utf-8") as f:
            f.write(html_content)

    def get_file_path(
        self,
        json_info: list[dict[str, str]],
        path_str: str,
        directory: str | None = None,
    ) -> list[str]:
        """Resolve the path stored under ``path_str`` for every item.

        A path is used as-is when it exists; otherwise it is joined onto
        ``directory`` (or ``config.workspace_root`` when ``directory`` is
        not given).

        Args:
            json_info: Items that each contain a ``path_str`` entry.
            path_str: Key under which the path is stored in each item.
            directory: Fallback base directory for relative paths.

        Returns:
            One resolved path per item, in input order.

        Raises:
            FileNotFoundError: If a path exists neither as given nor
                under the fallback directory.
        """
        base_dir = str(directory or config.workspace_root)
        res = []
        for item in json_info:
            raw_path = item[path_str]
            if os.path.exists(raw_path):
                res.append(raw_path)
                continue
            candidate = os.path.join(base_dir, raw_path)
            if not os.path.exists(candidate):
                raise FileNotFoundError(f"No such file or directory: {raw_path}")
            res.append(candidate)
        return res

    async def add_insights(
        self,
        json_info: list[dict[str, str]],
        output_type: str,
        language: str = "en",
    ) -> dict[str, Any]:
        """Apply the requested insights to every chart in ``json_info``.

        Items without an ``"insights_id"`` entry are skipped. The remaining
        charts are processed concurrently through ``invoke_vmind``.

        Args:
            json_info: Items with a ``"chartPath"`` and, optionally, an
                ``"insights_id"`` entry.
            output_type: Chart format; also the file extension stripped
                from the chart file name before it is passed on.
            language: Language code forwarded to ``invoke_vmind``.

        Returns:
            A dict with an ``"observation"`` message. When at least one
            chart failed it also carries ``"success": False``.
        """
        chart_file_path = self.get_file_path(
            json_info, "chartPath", os.path.join(config.workspace_root, "visualization")
        )
        # Keep each chart path together with its insight ids so results can
        # be matched back to the right chart even when some items are skipped.
        jobs = [
            (chart_path, item["insights_id"])
            for chart_path, item in zip(chart_file_path, json_info)
            if "insights_id" in item
        ]
        extension = f".{output_type}"
        tasks = [
            self.invoke_vmind(
                insights_id=insights_id,
                # Strip the extension only at the end of the name.
                file_name=os.path.basename(chart_path).removesuffix(extension),
                output_type=output_type,
                task_type="insight",
                language=language,
            )
            for chart_path, insights_id in jobs
        ]
        results = await asyncio.gather(*tasks)

        error_list = []
        success_list = []
        for (chart_path, _), result in zip(jobs, results):
            if "error" in result and "chart_path" not in result:
                error_list.append(f"Error in {chart_path}: {result['error']}")
            else:
                success_list.append(chart_path)
                # The CSS injection rewrites the file as UTF-8 text, which
                # only makes sense for HTML output (a PNG cannot be decoded).
                if output_type == "html":
                    self.load_chart_with_css(chart_path)

        success_template = (
            f"# Charts Update with Insights\n{','.join(success_list)}"
            if success_list
            else ""
        )
        if error_list:
            # Joined outside the f-string: a backslash inside a replacement
            # field is a syntax error before Python 3.12.
            error_text = "\n".join(error_list)
            return {
                "observation": f"# Error in chart insights:{error_text}\n{success_template}",
                "success": False,
            }
        return {"observation": f"{success_template}"}

    async def execute(
        self,
        json_path: str,
        output_type: str | None = "html",
        tool_type: str | None = "visualization",
        language: str | None = "en",
    ) -> dict[str, Any]:
        """Load the insights JSON file and annotate the charts it lists.

        Args:
            json_path: Path to the JSON file; either a list of chart
                entries or a single entry object.
            output_type: Chart format; ``None`` falls back to ``"html"``.
            tool_type: Only used in the log message.
            language: Language code passed on; ``None`` falls back to
                ``"en"``.

        Returns:
            The result dict of ``add_insights``, or a dict with
            ``"success": False`` and the error text when anything raises.
        """
        try:
            logger.info(f"📈 data_visualization with {json_path} in: {tool_type} ")
            with open(json_path, "r", encoding="utf-8") as file:
                json_info = json.load(file)
            # The documented file format is a single object; accept it as
            # well as a list of such objects.
            if isinstance(json_info, dict):
                json_info = [json_info]
            return await self.add_insights(
                json_info, output_type or "html", language or "en"
            )
        except Exception as e:
            # Tool boundary: report the failure instead of raising.
            return {
                "observation": f"Error: {e}",
                "success": False,
            }

    async def invoke_vmind(
        self,
        file_name: str,
        output_type: str,
        task_type: str,
        insights_id: list[str] | None = None,
        dict_data: list[dict[Hashable, Any]] | None = None,
        chart_description: str | None = None,
        language: str = "en",
    ):
        """Run ``src/chartVisualize.ts`` with the given parameters.

        The parameters, together with the LLM connection settings and the
        workspace directory, are sent to the script as JSON on stdin.

        Args:
            file_name: Chart file name passed to the script.
            output_type: Chart output format passed to the script.
            task_type: Task selector passed to the script.
            insights_id: Insight ids to apply, if any.
            dict_data: Dataset rows, if any.
            chart_description: Text sent as ``user_prompt``, if any.
            language: Language code passed to the script.

        Returns:
            The JSON value the script printed on stdout when it exits with
            code 0; otherwise a dict with an ``"error"`` message. Failures
            to start or talk to the process are reported the same way.
        """
        llm_config = {
            "base_url": self.llm.base_url,
            "model": self.llm.model,
            "api_key": self.llm.api_key,
        }
        vmind_params = {
            "llm_config": llm_config,
            "user_prompt": chart_description,
            "dataset": dict_data,
            "file_name": file_name,
            "output_type": output_type,
            "insights_id": insights_id,
            "task_type": task_type,
            "directory": str(config.workspace_root),
            "language": language,
        }
        # Serialize before spawning so a serialization error cannot leave
        # a started process behind.
        input_json = json.dumps(vmind_params, ensure_ascii=False).encode("utf-8")

        try:
            # Spawned inside the try block so that a missing ``npx`` becomes
            # an error result for this chart instead of aborting the batch.
            process = await asyncio.create_subprocess_exec(
                "npx",
                "ts-node",
                "src/chartVisualize.ts",
                stdin=asyncio.subprocess.PIPE,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
                cwd=os.path.dirname(__file__),
            )
            stdout, stderr = await process.communicate(input_json)
            stdout_str = stdout.decode("utf-8")
            stderr_str = stderr.decode("utf-8")
            if process.returncode == 0:
                return json.loads(stdout_str)
            return {"error": f"Node.js Error: {stderr_str}"}
        except Exception as e:
            return {"error": f"Subprocess Error: {str(e)}"}
| 228 | + |
0 commit comments