diff --git a/CHANGELOG.md b/CHANGELOG.md index 61be697..bf7d2f3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,9 @@ All notable changes to this project will be documented in this file. ## [Unreleased] +### Removed +- **Agent run_example tool**: Removed autonomous tool execution capability from agent. Agent now only recommends tools - all execution requires explicit user approval via approval buttons. This enforces consistent security/UX model where users maintain full control over tool execution. The underlying `gradio_space_tool.py` remains for UI-initiated demo execution. + ### Added - **New chat-based interface** (`ai_agent chat`) with conversational AI assistant - Chatbot component with rich media rendering (images, files, JSON, code blocks) @@ -34,6 +37,30 @@ All notable changes to this project will be documented in this file. - **YAML Model Configuration**: New `config.yaml` file for flexible model configuration supporting OpenAI, EPFL inference server, and any OpenAI-compatible API endpoints. - **Multi-Model Support**: Can now configure different models for agent (main reasoning & tool selection). - **Configuration Module**: New `utils/config.py` with Pydantic models for type-safe configuration loading and validation. +- **3D Lungs Segmentation Tool**: New MCP tool (`agent/tools/lungs_segmentation_tool.py`) that integrates with HuggingFace Space (https://qchapp-3d-lungs-segmentation.hf.space/) for 3D U-Net based lung segmentation in CT volumes. Supports DICOM, NIfTI, and TIFF stack inputs with robust file materialization strategy handling multiple Gradio output formats (FileData dict, URL string, local path, server path). +- **Tool Usage Analytics**: Real-time visualization in chat UI showing tool call frequency bar chart and timeline plot. Tracks all tool executions with timestamps and success/failure status. Provides users with transparency about which tools are being used during their session. +- **Downloadable Results Section**: New `download_files` component (gr.File with `file_count="multiple"`) positioned below chatbot for easy access to tool outputs. Files returned by tools (e.g., segmentation masks) are automatically extracted and presented as downloadable items separate from inline previews. +- **Inline Tool Approval Button**: Button-based tool execution approval appears dynamically in chat flow within a styled group box. Shows "šŸ¤– Tool Recommendation" header with contextual button label (e.g., "šŸš€ Run Lungs Segmentation") only when tool approval is pending. Replaces previous text-based approval pattern for better UX and extensibility. +- **Tool Registry System** (`agent/tools/mcp/registry.py`): Centralized tool registration pattern that eliminates tool-specific UI code + - `ToolConfig` dataclass for declarative tool configuration with field mappings + - `TOOL_REGISTRY` global dictionary for dynamic tool lookup + - `CATALOG_NAME_TO_TOOL` reverse mapping dict to handle catalog name → tool name resolution + - **Catalog name mapping**: Tools can specify `catalog_names` list to map dataset names (e.g., "lungs-segmentation") to internal tool names (e.g., "lungs_segmentation") + - **Clean registration pattern**: Single `ensure_tools_registered()` call in app.py replaces individual tool imports + - Generic extraction functions: `extract_preview()`, `extract_downloads()`, `extract_metadata()` + - Helper functions: `register_tool()`, `get_tool()`, `list_tools()`, `get_tool_display_name()`, `get_tool_icon()` + - `get_tool()` automatically resolves both registry names and catalog names for seamless integration with RAG recommendations + - Supports lazy loading to avoid loading heavy dependencies at import time +- **MCP Tools Subpackage** (`agent/tools/mcp/`): Organized separation of registered imaging tools (MCP protocol) from agent utilities. Base models, registry, and imaging tools (e.g., lungs_segmentation) now in dedicated subpackage for clarity. +- **Base Tool Models** (`agent/tools/mcp/base.py`): Standard Pydantic schemas for tool consistency + - `BaseToolOutput`: Standard fields across all tools (success, error, compute_time_seconds, result_preview, result_origin, metadata_text, notes) + - `BaseToolInput`: Minimal base class for tool inputs + - `ImageToolInput`: Common pattern for image-based tools with image_path and description fields +- **Tool Registration**: Lungs segmentation tool self-registers with complete field mappings + - Preview field: `result_preview` + - Download fields: `result_origin` + - Metadata field: `metadata_text` + - Notes field: `notes` ### Changed - CLI now supports `ai_agent chat` @@ -60,6 +87,31 @@ All notable changes to this project will be documented in this file. - **UI redesign**: File upload moved to dedicated right panel for cleaner workflow - **Visual hierarchy**: Header with gradient green banner and logo - **Button styling**: Primary actions use Imaging Plaza green theme colors +- **Tool Approval Workflow**: Replaced text-based approval (responding "yes"/"sure"/"ok" in chat) with explicit button-based approval. Tool execution now requires clicking a dedicated approval button that appears inline in the chat, improving clarity and preventing accidental tool execution. +- **Chat Output Structure**: Extended Gradio component outputs from 6 to 9 values to support new approval box and download files components. All event handlers (`submit_btn.click`, `msg_input.submit`, `approve_tool_btn.click`, `clear_btn.click`) now consistently yield/return all 9 outputs: chatbot history, state, 3 charts, state display, downloads, approval box visibility, and approval button label. +- **Tool Approval Button Position**: Moved approval button from standalone position below input controls to inline group box between chatbot and downloads section. Button now appears as part of "šŸ¤– Tool Recommendation" box with dynamic label showing tool name (e.g., "šŸš€ Run Lungs Segmentation"). +- **Generic Tool Execution** (`ui/handlers.py`): Replaced tool-specific `execute_lungs_segmentation()` with generic `execute_tool_with_approval()` + - Dynamic tool lookup via `get_tool(tool_name)` - NO hardcoded tool names + - Dynamic input construction: `tool_config.input_model(**params)` - works for any Pydantic schema + - Dynamic tool execution: `tool_config.executor(input_obj)` - calls registered executor + - Generic field extraction using registry field mappings - NO tool-specific code + - Works for ANY tool that registers in TOOL_REGISTRY + - Eliminates need for tool-specific if/else chains + - **Architectural benefit**: Adding 70+ tools requires ZERO changes to handlers.py +- **Dynamic Button Labels** (`ui/components.py`): Button text now uses registry helper functions + - `get_tool_display_name()` and `get_tool_icon()` provide consistent labels + - Replaces hardcoded string formatting: `tool_name.replace('_', ' ').title()` + - Button shows proper display name and icon from ToolConfig +- **Lazy Tool Loading** (`agent/tools/__init__.py`): Only export registry, not all tools + - Prevents loading heavy dependencies (nibabel, pydicom) at package import + - Tools imported explicitly where needed (e.g., in `ui/app.py`) + - Added `ensure_tools_registered()` function for explicit bulk loading + - Fixes import hangs caused by eager tool loading +- **Tool Import Location** (`ui/app.py`): Import lungs_segmentation_tool to trigger registration before UI launch +- **LungsSegmentationOutput Schema**: Enhanced with separate `result_origin` (original format file for download), `result_preview` (PNG preview for display), `metadata_text` (file metadata string), and `api_name` fields. Maintains backward compatibility with `result_path` field (now set to preview when available, else origin). +- **Download vs Display Separation**: Tool results now distinguish between files for download (`result_origin` - TIFF/NIfTI/DICOM) and inline display (`result_preview` - PNG). Downloads section shows original format files while chat shows converted previews for better compatibility. +- **HuggingFace Space Client Timeout**: Extended timeout to 300 seconds (5 minutes) via `httpx_kwargs={"timeout": 300.0}` parameter in `_make_gradio_client()` to handle slow Space cold starts and large medical imaging file uploads/downloads without timing out. Includes graceful fallback for older gradio_client versions without `httpx_kwargs` support. +- **Tool Execution Handler**: `execute_tool_with_approval()` in handlers.py now uses `result_preview` for inline images and `result_origin` for downloadable files, ensuring users get both viewable previews in chat and original format files for download. ### Removed - **VLMToolSelector**: Deleted unused `generator/generator.py` containing VLMToolSelector class. The pydantic-ai agent handles all tool selection directly. @@ -77,6 +129,10 @@ All notable changes to this project will be documented in this file. - **Clear Button**: Disabled during processing to prevent race conditions with ongoing requests. - **Alternative Tool Requests**: All recommended tools are now automatically added to the exclusion list (banlist) and properly passed to the agent through AgentState, ensuring follow-up requests like "I would like another tool" correctly return different tools. - **History Table**: Follow-up requests (without files) no longer create duplicate history entries. Only primary requests with files are logged to the History table. +- **Duplicate Function Definition**: Removed duplicate `clear_chat()` function definition in `components.py` that was causing syntax errors. +- **Text-Based Tool Approval Logic**: Removed legacy `_is_affirmative()` check in `handlers.py` that was conflicting with new button-based approval system. Tool execution now only triggered by explicit button click, preventing ambiguous user messages from unintentionally executing tools. +- **Gradio Component Compatibility**: Changed `gr.Box` to `gr.Group` for approval button container to ensure compatibility with Gradio 5.42.0 (gr.Box not available in this version). +- **Component Output Count**: Fixed inconsistent yield statements throughout `handle_chat()` generator function - all yields now consistently return 9 values (chatbot, state, 3 charts, state display, downloads, approval box visibility, button label) to match event handler output declarations. ## [0.1.3] - 2025-10-22 diff --git a/config.yaml b/config.yaml index f08138d..a6f8f0a 100644 --- a/config.yaml +++ b/config.yaml @@ -2,9 +2,12 @@ # Default/fallback model (used for CLI and initial startup) agent_model: - name: "gpt-5.1" - base_url: null # null for default OpenAI endpoint - api_key_env: "OPENAI_API_KEY" + # name: "gpt-5.1" + # base_url: null # null for default OpenAI endpoint + # api_key_env: "OPENAI_API_KEY" + name: "openai/gpt-oss-120b" + base_url: "https://inference-rcp.epfl.ch/v1" + api_key_env: "EPFL_API_KEY" # Available models for UI dropdown available_models: diff --git a/src/ai_agent/agent/agent.py b/src/ai_agent/agent/agent.py index 79e93e3..389308a 100644 --- a/src/ai_agent/agent/agent.py +++ b/src/ai_agent/agent/agent.py @@ -13,12 +13,11 @@ from ai_agent.generator.prompts import get_agent_system_prompt from ai_agent.generator.schema import ToolSelection, Conversation, ConversationStatus from ai_agent.utils.config import get_config -from .models import AgentToolSelection, ToolRunLog +from .models import AgentToolSelection, ToolRunLog, UsageStats from .tools.repo_info_tool import tool_repo_summary, RepoSummaryInput from ai_agent.agent.utils import coerce_github_url_or_none from .tools.search_tool import tool_search_tools, SearchToolsInput from .tools.search_alternative_tool import tool_search_alternative, SearchAlternativeInput -from .tools.gradio_space_tool import tool_run_example, RunExampleInput from .utils import AgentState, limit_tool_calls, cap_prepare from ai_agent.utils.image_meta import summarize_image_metadata, detect_ext_token @@ -216,39 +215,6 @@ async def repo_info(ctx: RunContext[AgentState], url: str, tool_name: str | None return out.model_dump(mode="python") -@agent.tool(retries=0, prepare=cap_prepare) -@limit_tool_calls("run_example", cap=1) -async def run_example( - ctx: RunContext[AgentState], - tool_name: str, - endpoint_url: str | None = None, - extra_text: str | None = None, -) -> dict: - """ - Run an example / demo for a given tool via its Gradio space. - - Thin wrapper around tools.gradio_space_tool.tool_run_example(). - """ - out = tool_run_example( - RunExampleInput( - tool_name=tool_name, - endpoint_url=endpoint_url, - extra_text=extra_text, - ) - ) - ctx.deps.tool_calls.append( - { - "tool": "run_example", - "tool_name": tool_name, - "ran": getattr(out, "ran", False), - "endpoint_url": getattr(out, "endpoint_url", endpoint_url), - "api_name": getattr(out, "api_name", None), - "timestamp": datetime.now().isoformat(), - } - ) - return out.model_dump(mode="python") - - # --------------------------------------------------------------------------- # High level entry point: run the agent on (text query + image) # --------------------------------------------------------------------------- @@ -387,7 +353,6 @@ def run_agent( agent_instance.tool(search_tools, retries=2, prepare=cap_prepare) agent_instance.tool(search_alternative, retries=2, prepare=cap_prepare) agent_instance.tool(repo_info, retries=2, prepare=cap_prepare) - agent_instance.tool(run_example, retries=0, prepare=cap_prepare) elif num_choices is not None and num_choices != 3: log.info( @@ -403,7 +368,6 @@ def run_agent( agent_instance.tool(search_tools, retries=2, prepare=cap_prepare) agent_instance.tool(search_alternative, retries=2, prepare=cap_prepare) agent_instance.tool(repo_info, retries=2, prepare=cap_prepare) - agent_instance.tool(run_example, retries=0, prepare=cap_prepare) else: log.info(f"ā™»ļø Using global agent (model: {effective_model}, num_choices: {effective_num_choices})") @@ -456,6 +420,7 @@ def run_agent( # Handle global tool quota limit (UsageLimitExceeded) and other errors gracefully error_msg = str(e) log.warning(f"āš ļø Agent execution encountered an error: {error_msg}") + run_result = None # Ensure run_result is defined for usage stats extraction # Check if this is a usage limit error (global tool quota) if "UsageLimitExceeded" in str(type(e).__name__) or "tool_calls_limit" in error_msg.lower(): @@ -490,13 +455,24 @@ def run_agent( ) ) - # ---- 8) Wrap into high-level AgentToolSelection ------------------------ + # ---- 8) Extract usage statistics if available ------------------------- + usage_stats = None + if run_result and hasattr(run_result, "usage") and run_result.usage: + usage = run_result.usage() + usage_stats = UsageStats( + total_tokens=usage.total_tokens, + input_tokens=usage.input_tokens, + output_tokens=usage.output_tokens, + ) + + # ---- 9) Wrap into high-level AgentToolSelection ------------------------ return AgentToolSelection( conversation=result.conversation, choices=result.choices, explanation=result.explanation, reason=result.reason, tool_calls=tool_logs, + usage=usage_stats, ) diff --git a/src/ai_agent/agent/models.py b/src/ai_agent/agent/models.py index 3e5d71c..80d7fed 100644 --- a/src/ai_agent/agent/models.py +++ b/src/ai_agent/agent/models.py @@ -11,8 +11,15 @@ class ToolRunLog(BaseModel): error: Optional[str] = None timestamp: Optional[str] = None +class UsageStats(BaseModel): + """Token usage statistics from the agent.""" + total_tokens: int = 0 + input_tokens: int = 0 + output_tokens: int = 0 + class AgentToolSelection(ToolSelection): tool_calls: List[ToolRunLog] = Field(default_factory=list) + usage: Optional[UsageStats] = None def to_legacy_dict(self) -> Dict[str, Any]: # Map to legacy pipeline result shape expected by UI (subset) @@ -22,10 +29,12 @@ def to_legacy_dict(self) -> Dict[str, Any]: "reason": self.reason, "explanation": self.explanation, "tool_calls": [c.model_dump(mode="python") for c in self.tool_calls], + "usage": self.usage.model_dump(mode="python") if self.usage else None, } __all__ = [ "AgentToolSelection", "ToolRunLog", + "UsageStats", "CandidateDoc", ] diff --git a/src/ai_agent/agent/tools/__init__.py b/src/ai_agent/agent/tools/__init__.py new file mode 100644 index 0000000..bdd8ff2 --- /dev/null +++ b/src/ai_agent/agent/tools/__init__.py @@ -0,0 +1,36 @@ +"""Agent tools package.""" + +# Only export registry - tools will self-register when imported explicitly +from .mcp import ( + TOOL_REGISTRY, + get_tool, + register_tool, + list_tools, + ensure_mcp_tools_registered, +) + +# Import tools lazily to avoid loading heavy dependencies at package import +# Tools should be imported explicitly where needed, e.g.: +# from ai_agent.agent.tools.mcp.lungs_segmentation_tool import tool_lungs_segmentation + +__all__ = [ + "TOOL_REGISTRY", + "get_tool", + "register_tool", + "list_tools", + "ensure_tools_registered", +] + + +def ensure_tools_registered(): + """ + Import all tools to trigger their registration. + Call this once at app startup. + """ + from .search_tool import tool_search_tools + from .search_alternative_tool import tool_search_alternative + from .repo_info_tool import tool_repo_summary + from .gradio_space_tool import tool_run_example + + # Import MCP tools + ensure_mcp_tools_registered() diff --git a/src/ai_agent/agent/tools/gradio_space_tool.py b/src/ai_agent/agent/tools/gradio_space_tool.py index 25c0a79..10fc902 100644 --- a/src/ai_agent/agent/tools/gradio_space_tool.py +++ b/src/ai_agent/agent/tools/gradio_space_tool.py @@ -6,6 +6,7 @@ from .utils import get_pipeline from ai_agent.utils.utils import _best_runnable_link from ai_agent.utils.previews import _build_preview_for_vlm +from ai_agent.utils.temp_file_manager import register_temp_file from gradio_client import Client, handle_file import tempfile from pathlib import Path @@ -69,7 +70,7 @@ def _download_to_temp(url: str) -> Optional[str]: with tempfile.NamedTemporaryFile(delete=False, prefix="demo_result_", suffix=ext) as fd: fd.write(r.content) fd.flush() - return fd.name + return register_temp_file(fd.name) except Exception: return None diff --git a/src/ai_agent/agent/tools/mcp/__init__.py b/src/ai_agent/agent/tools/mcp/__init__.py new file mode 100644 index 0000000..8551b17 --- /dev/null +++ b/src/ai_agent/agent/tools/mcp/__init__.py @@ -0,0 +1,51 @@ +""" +MCP (Model Context Protocol) tools package. + +This package contains registered imaging tools that require approval +and follow the tool registry pattern. +""" + +from .registry import ( + TOOL_REGISTRY, + CATALOG_NAME_TO_TOOL, + get_tool, + register_tool, + list_tools, + get_tool_display_name, + get_tool_icon, + extract_preview, + extract_downloads, + extract_metadata, + extract_output_field, + ToolConfig, +) + +from .base import BaseToolInput, BaseToolOutput, ImageToolInput + +__all__ = [ + # Registry + "TOOL_REGISTRY", + "CATALOG_NAME_TO_TOOL", + "get_tool", + "register_tool", + "list_tools", + "get_tool_display_name", + "get_tool_icon", + "extract_preview", + "extract_downloads", + "extract_metadata", + "extract_output_field", + "ToolConfig", + # Base models + "BaseToolInput", + "BaseToolOutput", + "ImageToolInput", +] + + +def ensure_mcp_tools_registered(): + """ + Import all MCP tools to trigger their registration. + Call this once at app startup. + """ + from . import lungs_segmentation_tool diff --git a/src/ai_agent/agent/tools/mcp/base.py b/src/ai_agent/agent/tools/mcp/base.py new file mode 100644 index 0000000..fcf356b --- /dev/null +++ b/src/ai_agent/agent/tools/mcp/base.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +from typing import Optional, Dict, Any +from pydantic import BaseModel, Field + + +class BaseToolInput(BaseModel): + """ + Base input model that tools can extend. + + Common patterns: + - image_path: Path to uploaded image/volume + - description: Optional context from agent + """ + pass # Intentionally minimal - tools define their own inputs + + +class BaseToolOutput(BaseModel): + """ + Base output model that all tools should follow. + + This ensures consistent handling in the UI layer without + needing tool-specific code. + + Standard fields: + - success: bool - Whether execution succeeded + - error: Optional[str] - Error message if failed + - compute_time_seconds: float - Time taken by tool + - notes: Optional[str] - Additional info for user + + File outputs (at least one should be provided on success): + - result_preview: Optional[str] - PNG/GIF preview for inline display + - result_origin: Optional[str] - Original format file for download + - result_path: Optional[str] - Backward compat field + + Metadata: + - metadata_text: Optional[str] - Structured info about result + - metadata: Dict[str, Any] - Machine-readable metadata + + Tracking: + - endpoint_url: str - API endpoint used + - api_name: str - API method called + """ + # Core status + success: bool = False + error: Optional[str] = None + compute_time_seconds: float = 0.0 + + # File outputs (tools should provide these for UI to display/download) + result_preview: Optional[str] = Field( + default=None, + description="Path to preview image (PNG/GIF) for inline display in chat" + ) + result_origin: Optional[str] = Field( + default=None, + description="Path to original format file (TIFF/NIfTI/DICOM) for download" + ) + result_path: Optional[str] = Field( + default=None, + description="Backward compatibility: primary result path" + ) + + # Metadata + metadata_text: Optional[str] = Field( + default=None, + description="Human-readable metadata about the result" + ) + metadata: Dict[str, Any] = Field( + default_factory=dict, + description="Machine-readable metadata" + ) + notes: Optional[str] = Field( + default=None, + description="Additional notes or context for user" + ) + + # Tracking/debugging + endpoint_url: str = "" + api_name: str = "" + + +class ImageToolInput(BaseToolInput): + """ + Common input pattern for image/volume processing tools. + """ + image_path: str = Field(description="Path to the image/volume file") + description: Optional[str] = Field( + default=None, + description="Optional context or notes from agent about the task" + ) diff --git a/src/ai_agent/agent/tools/mcp/lungs_segmentation_tool.py b/src/ai_agent/agent/tools/mcp/lungs_segmentation_tool.py new file mode 100644 index 0000000..c35f7e9 --- /dev/null +++ b/src/ai_agent/agent/tools/mcp/lungs_segmentation_tool.py @@ -0,0 +1,404 @@ +from __future__ import annotations + +from typing import Optional, Any, Dict, Tuple +from pydantic import BaseModel, Field +import os +import logging +import tempfile +from pathlib import Path +import time + +import requests +from gradio_client import Client, handle_file + +from ai_agent.utils.previews import _build_preview_for_vlm +from ai_agent.utils.temp_file_manager import register_temp_file +from ai_agent.agent.tools.mcp.registry import register_tool, ToolConfig +from ai_agent.agent.tools.mcp.base import BaseToolOutput, ImageToolInput + +log = logging.getLogger("agent.lungs_segmentation") + +# --------------------------------------------------------------------- +# Models +# --------------------------------------------------------------------- +class LungsSegmentationInput(ImageToolInput): + """Input for 3D lungs segmentation tool.""" + pass # Inherits image_path and description from ImageToolInput + + +class LungsSegmentationOutput(BaseToolOutput): + """Output from 3D lungs segmentation tool.""" + # All standard fields inherited from BaseToolOutput: + # - success, error, compute_time_seconds, notes + # - result_preview, result_origin, result_path + # - metadata_text, endpoint_url, api_name + pass + + +# --------------------------------------------------------------------- +# Config +# --------------------------------------------------------------------- +LUNGS_SEGMENTATION_ENDPOINT = "https://qchapp-3d-lungs-segmentation.hf.space/" +LUNGS_SEGMENTATION_API_NAME = "/segment" + +# Maximum file size for downloads (1GB for medical imaging) +MAX_DOWNLOAD_SIZE = 1024 * 1024 * 1024 # 1GB in bytes + + +# --------------------------------------------------------------------- +# Public tool +# --------------------------------------------------------------------- +def tool_lungs_segmentation(inp: LungsSegmentationInput) -> LungsSegmentationOutput: + """ + Run 3D lungs segmentation on a CT scan image via a Gradio Space. + + Materialization strategy (robust): + 1) If Space returns dict FileData (url/path/etc) -> download via URL. + 2) If Space returns URL string -> download. + 3) If Space returns local file -> use it. + 4) If Space returns server path (/tmp/...) -> try /gradio_api/file=... (may 403). + """ + start_time = time.time() + + if not os.path.exists(inp.image_path): + return LungsSegmentationOutput( + success=False, + error=f"Image file not found: {inp.image_path}", + endpoint_url=LUNGS_SEGMENTATION_ENDPOINT, + api_name=LUNGS_SEGMENTATION_API_NAME, + ) + + hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN") + + try: + log.info("Running lungs segmentation on %s (endpoint: %s)", inp.image_path, LUNGS_SEGMENTATION_ENDPOINT) + + client = _make_gradio_client(LUNGS_SEGMENTATION_ENDPOINT, hf_token) + + # Call API + try: + result = client.predict( + file_obj=handle_file(inp.image_path), + api_name=LUNGS_SEGMENTATION_API_NAME, + ) + log.info("API returned type=%s value=%r", type(result), result) + except Exception as e: + return LungsSegmentationOutput( + success=False, + error=f"API call failed: {e}", + compute_time_seconds=time.time() - start_time, + endpoint_url=LUNGS_SEGMENTATION_ENDPOINT, + api_name=LUNGS_SEGMENTATION_API_NAME, + ) + + # Materialize to local file + origin_path = _materialize_any(result, client=client, hf_token=hf_token) + + compute_time = time.time() - start_time + + if not origin_path or not os.path.exists(origin_path): + # This is the common case if the Space returns '/tmp/...' and Gradio blocks it (403). + return LungsSegmentationOutput( + success=False, + error="Could not materialize/download the result file.", + compute_time_seconds=compute_time, + endpoint_url=LUNGS_SEGMENTATION_ENDPOINT, + api_name=LUNGS_SEGMENTATION_API_NAME, + notes=( + f"API returned: {result!r}. If this is a '/tmp/...' path and you see HTTP 403, " + "the Space must return a FileData/url (recommended) or whitelist the output directory " + "via allowed_paths / GRADIO_TEMP_DIR." + ), + ) + + # Build preview + metadata using your shared function + preview_path, meta_text = _safe_build_preview(origin_path) + + # Back-compat: prefer preview in result_path + result_path = preview_path or origin_path + + return LungsSegmentationOutput( + success=True, + result_path=result_path, + result_origin=origin_path, + result_preview=preview_path, + metadata_text=meta_text, + compute_time_seconds=compute_time, + endpoint_url=LUNGS_SEGMENTATION_ENDPOINT, + api_name=LUNGS_SEGMENTATION_API_NAME, + notes=f"Successfully segmented lungs from {os.path.basename(inp.image_path)}", + ) + + except Exception as e: + log.exception("Lungs segmentation failed") + return LungsSegmentationOutput( + success=False, + error=str(e), + compute_time_seconds=time.time() - start_time, + endpoint_url=LUNGS_SEGMENTATION_ENDPOINT, + api_name=LUNGS_SEGMENTATION_API_NAME, + ) + + +# --------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------- +def _make_gradio_client(endpoint: str, hf_token: Optional[str]) -> Client: + """ + Create a gradio_client.Client with best compatibility across versions. + """ + # Set extended timeout for both connection and operations (5 minutes for large files) + httpx_kwargs = {"timeout": 300.0} + + # Newer versions use token=, older versions used hf_token= + if hf_token: + try: + return Client(endpoint, hf_token=hf_token, httpx_kwargs=httpx_kwargs) + except TypeError: + # Fallback for very old versions without httpx_kwargs support + try: + return Client(endpoint, hf_token=hf_token) + except TypeError: + return Client(endpoint) + + try: + return Client(endpoint, httpx_kwargs=httpx_kwargs) + except TypeError: + # Fallback for very old versions + return Client(endpoint) + + +def _safe_build_preview(origin_path: str) -> Tuple[Optional[str], Optional[str]]: + """ + Wrapper around _build_preview_for_vlm so preview failures never break the tool. + """ + try: + preview_path, meta_text = _build_preview_for_vlm([origin_path]) + return preview_path, meta_text + except Exception as e: + log.debug("Preview build failed for %s: %r", origin_path, e) + return None, None + + +def _materialize_any(obj: Any, client: Client, hf_token: Optional[str] = None, _depth: int = 0) -> Optional[str]: + """ + Convert common Gradio outputs into a local file path. + + Supported: + - local path string + - URL string + - dict (FileData-like) containing url/path/name/filepath + - list/tuple containing any of the above + - server path '/tmp/...' -> attempt Gradio file endpoint (may 403) + + Args: + obj: Object to materialize + client: Gradio client + hf_token: Optional HuggingFace token + _depth: Internal recursion depth counter (max 10) + """ + if obj is None or _depth > 10: + if _depth > 10: + log.warning("Recursion depth limit reached in _materialize_any, halting.") + return None + + # list/tuple: most Gradio outputs are single-element lists + if isinstance(obj, (list, tuple)) and obj: + return _materialize_any(obj[0], client=client, hf_token=hf_token, _depth=_depth + 1) + + # dict: FileData-like is best case (url provided) + if isinstance(obj, dict): + # Prefer URL if present + url = obj.get("url") + if isinstance(url, str) and url.startswith(("http://", "https://")): + log.info("Materialize: dict url=%s", url) + return _download_to_temp(url, hf_token=hf_token) + + # Fall back through common keys + for k in ("path", "filepath", "file", "name"): + v = obj.get(k) + if isinstance(v, str) and v: + return _materialize_any(v, client=client, hf_token=hf_token, _depth=_depth + 1) + + return None + + # string: local file, URL, or server path + if isinstance(obj, str): + s = obj.strip() + if not s: + return None + + # local file? + p = Path(s) + if p.exists() and p.is_file(): + log.info("Materialize: local file=%s", s) + return str(p) + + # URL? + if s.startswith(("http://", "https://")): + log.info("Materialize: url=%s", s) + return _download_to_temp(s, hf_token=hf_token) + + # server path? (e.g. /tmp/xxx_mask.tif) + if s.startswith("/"): + log.info("Materialize: server path=%s", s) + return _download_from_gradio_file_endpoint(client, s, hf_token=hf_token) + + return None + + +def _download_to_temp(url: str, hf_token: Optional[str] = None) -> Optional[str]: + """ + Download a URL to a temporary file (streaming) with size limit checks. + """ + headers: Dict[str, str] = {} + if hf_token: + headers["Authorization"] = f"Bearer {hf_token}" + + try: + with requests.get(url, headers=headers, timeout=120, stream=True, allow_redirects=True) as r: + if r.status_code != 200: + log.error("Download failed: url=%s status=%s", url, r.status_code) + return None + + # Check Content-Length if available + content_length = r.headers.get("content-length") + if content_length and int(content_length) > MAX_DOWNLOAD_SIZE: + log.error("File too large: %s bytes (max %s)", content_length, MAX_DOWNLOAD_SIZE) + return None + + ext = _guess_ext(url, r.headers.get("content-type", "")) + + with tempfile.NamedTemporaryFile(delete=False, prefix="lungs_seg_", suffix=ext) as f: + downloaded_size = 0 + for chunk in r.iter_content(chunk_size=1024 * 1024): + if chunk: + downloaded_size += len(chunk) + if downloaded_size > MAX_DOWNLOAD_SIZE: + log.error("Download exceeded size limit: %s bytes", downloaded_size) + f.close() + os.remove(f.name) + return None + f.write(chunk) + log.info("Downloaded %s bytes: %s -> %s", downloaded_size, url, f.name) + return register_temp_file(f.name) + except Exception as e: + log.error("Failed to download %s: %r", url, e) + return None + + +def _download_from_gradio_file_endpoint(client: Client, server_path: str, hf_token: Optional[str] = None) -> Optional[str]: + """ + Last-resort fallback when API returns '/tmp/...' but no URL. + Often blocked with 403 unless Space allows that directory or writes into Gradio temp/cache. + Includes size limit checks. + """ + base = (getattr(client, "src", None) or LUNGS_SEGMENTATION_ENDPOINT).rstrip("/") + file_url = f"{base}/gradio_api/file={server_path}" + + headers: Dict[str, str] = {} + if hf_token: + headers["Authorization"] = f"Bearer {hf_token}" + + params: Dict[str, str] = {} + session_hash = getattr(client, "session_hash", None) + if session_hash: + params["session_hash"] = session_hash + + try: + r = requests.get(file_url, headers=headers, params=params, timeout=60, stream=True) + if r.status_code == 403: + # Common: file exists but not allowed to be served + detail: Any + try: + detail = r.json() + except Exception: + detail = r.text[:200] + log.error("HTTP 403 from %s detail=%r", file_url, detail) + return None + + if r.status_code != 200: + log.error("HTTP %s from %s", r.status_code, file_url) + return None + + # Check Content-Length before downloading + content_length = r.headers.get("content-length") + if content_length and int(content_length) > MAX_DOWNLOAD_SIZE: + log.error("File too large: %s bytes (max %s)", content_length, MAX_DOWNLOAD_SIZE) + return None + + # Read content with size check + content = b"" + for chunk in r.iter_content(chunk_size=1024 * 1024): + if chunk: + content += chunk + if len(content) > MAX_DOWNLOAD_SIZE: + log.error("Download exceeded size limit: %s bytes", len(content)) + return None + + ct = r.headers.get("content-type", "") + if "html" in ct.lower() or content.startswith(b" %s", len(content), f.name) + return register_temp_file(f.name) + + except Exception as e: + log.error("Failed gradio file endpoint download: %r", e) + return None + + +def _guess_ext(url: str, content_type: str) -> str: + """ + Guess file extension from URL path or Content-Type. + """ + from urllib.parse import urlparse + path = urlparse(url).path.lower() + + if path.endswith(".nii.gz"): + return ".nii.gz" + + ext = os.path.splitext(path)[1] + if ext: + return ext + + ct = (content_type or "").lower() + if "tiff" in ct or "tif" in ct: + return ".tif" + if "png" in ct: + return ".png" + if "jpeg" in ct or "jpg" in ct: + return ".jpg" + if "gif" in ct: + return ".gif" + if "nifti" in ct or "nii" in ct: + return ".nii.gz" + return ".bin" + + +# --------------------------------------------------------------------- +# Tool Registration +# --------------------------------------------------------------------- +register_tool(ToolConfig( + name="lungs_segmentation", + display_name="3D Lungs Segmentation", + icon="🫁", + catalog_names=["lungs-segmentation"], # Catalog name from dataset/catalog.jsonl + input_model=LungsSegmentationInput, + output_model=LungsSegmentationOutput, + executor=tool_lungs_segmentation, + supports_images=True, + supports_files=True, + requires_approval=True, + preview_field="result_preview", + download_fields="result_origin", # Could also be ["result_origin", "other_file"] + metadata_field="metadata_text", + notes_field="notes", + success_field="success", + error_field="error", + compute_time_field="compute_time_seconds", +)) \ No newline at end of file diff --git a/src/ai_agent/agent/tools/mcp/registry.py b/src/ai_agent/agent/tools/mcp/registry.py new file mode 100644 index 0000000..b516c85 --- /dev/null +++ b/src/ai_agent/agent/tools/mcp/registry.py @@ -0,0 +1,195 @@ +from __future__ import annotations + +from typing import Dict, Type, Callable, Optional, List, Any +from pydantic import BaseModel +from dataclasses import dataclass + + +@dataclass +class ToolConfig: + """ + Declarative configuration for a tool. + + Tools register themselves with this config, and the UI uses it + to generically handle execution, display, and file management. + """ + # Core identification + name: str # Internal name (e.g., "lungs_segmentation") + display_name: str # User-facing name (e.g., "3D Lungs Segmentation") + icon: str # Emoji for UI display + + # Type information + input_model: Type[BaseModel] # Pydantic model for inputs + output_model: Type[BaseModel] # Pydantic model for outputs + executor: Callable # Function that takes input_model and returns output_model + + # Capability flags + catalog_names: Optional[List[str]] = None # Catalog names for this tool (e.g., ["lungs-segmentation"]) + supports_images: bool = True + supports_files: bool = True + requires_approval: bool = True # Whether to show approval button + + # Output field mappings (how to extract results from output_model) + # These map generic concepts to tool-specific field names + preview_field: str = "result_preview" # Field containing preview image path + download_fields: List[str] | str = "result_origin" # Field(s) for downloadable files + metadata_field: Optional[str] = "metadata_text" # Optional metadata text + notes_field: str = "notes" # Field containing execution notes + + # Success detection + success_field: str = "success" # Field indicating success/failure + error_field: str = "error" # Field containing error message + compute_time_field: str = "compute_time_seconds" # Field with timing info + + +# Global tool registry +TOOL_REGISTRY: Dict[str, ToolConfig] = {} + +# Reverse mapping from catalog names to tool names +CATALOG_NAME_TO_TOOL: Dict[str, str] = {} + + +def register_tool(config: ToolConfig) -> None: + """ + Register a tool with the global registry. + + Args: + config: Tool configuration + + Raises: + ValueError: If tool name already registered or catalog name collision + """ + if config.name in TOOL_REGISTRY: + raise ValueError(f"Tool '{config.name}' is already registered") + + # Check for catalog name collisions before registering + if config.catalog_names: + for catalog_name in config.catalog_names: + if catalog_name in CATALOG_NAME_TO_TOOL and CATALOG_NAME_TO_TOOL[catalog_name] != config.name: + raise ValueError( + f"Catalog name '{catalog_name}' already registered to " + f"'{CATALOG_NAME_TO_TOOL[catalog_name]}'" + ) + + TOOL_REGISTRY[config.name] = config + + # Register catalog name mappings + if config.catalog_names: + for catalog_name in config.catalog_names: + CATALOG_NAME_TO_TOOL[catalog_name] = config.name + + +def get_tool(name: str) -> Optional[ToolConfig]: + """ + Get tool configuration by name. + + Args: + name: Tool name (registry name or catalog name) + + Returns: + ToolConfig if found, None otherwise + + Note: + This function checks both the tool registry name and catalog names. + """ + # First try direct registry lookup + config = TOOL_REGISTRY.get(name) + if config: + return config + + # Try catalog name mapping + tool_name = CATALOG_NAME_TO_TOOL.get(name) + if tool_name: + return TOOL_REGISTRY.get(tool_name) + + return None + + +def list_tools() -> List[str]: + """Get list of all registered tool names.""" + return list(TOOL_REGISTRY.keys()) + + +def get_tool_display_name(name: str) -> str: + """ + Get display name for a tool, with fallback to name. + + Args: + name: Tool name + + Returns: + Display name or formatted version of name + """ + tool = get_tool(name) + if tool: + return tool.display_name + # Fallback: format name nicely + return name.replace("_", " ").title() + + +def get_tool_icon(name: str) -> str: + """ + Get icon for a tool, with fallback. + + Args: + name: Tool name + + Returns: + Icon emoji or default + """ + tool = get_tool(name) + if tool: + return tool.icon + return "šŸ”§" # Default tool icon + + +def extract_output_field(output: BaseModel, field_name: str) -> Any: + """ + Safely extract a field from tool output. + + Args: + output: Tool output object + field_name: Field name to extract + + Returns: + Field value or None if not found + """ + return getattr(output, field_name, None) + + +def extract_preview(output: BaseModel, tool_name: str) -> Optional[str]: + """Extract preview image path from tool output.""" + tool = get_tool(tool_name) + if not tool: + return None + return extract_output_field(output, tool.preview_field) + + +def extract_downloads(output: BaseModel, tool_name: str) -> List[str]: + """Extract downloadable file paths from tool output.""" + tool = get_tool(tool_name) + if not tool: + return [] + + download_fields = tool.download_fields + if isinstance(download_fields, str): + download_fields = [download_fields] + + downloads = [] + for field in download_fields: + value = extract_output_field(output, field) + if value: + if isinstance(value, list): + downloads.extend([v for v in value if v]) + elif isinstance(value, str): + downloads.append(value) + + return [d for d in downloads if d] # Filter None/empty + + +def extract_metadata(output: BaseModel, tool_name: str) -> Optional[str]: + """Extract metadata text from tool output.""" + tool = get_tool(tool_name) + if not tool or not tool.metadata_field: + return None + return extract_output_field(output, tool.metadata_field) diff --git a/src/ai_agent/generator/prompts.py b/src/ai_agent/generator/prompts.py index 807ad01..18cbb15 100644 --- a/src/ai_agent/generator/prompts.py +++ b/src/ai_agent/generator/prompts.py @@ -87,13 +87,11 @@ - search_tools(query, excluded=[], top_k=...): Semantic search with automatic query expansion and reranking - search_alternative(alternative_query, excluded=[], top_k=...): Try different query formulation (up to 3 times) - repo_info(url): Fetch GitHub repository info for verification (required for finalists) -- run_example(tool_name, endpoint_url=None, extra_text=None): Test tool functionality (optional) USAGE PATTERN: 1. search_tools(query) → Get initial candidates 2. [Optional] search_alternative(alternative_query) → Try different terms if needed 3. repo_info(url) → Verify each finalist before recommending -4. [Optional] run_example(tool_name) → Test if needed """ ) diff --git a/src/ai_agent/ui/app.py b/src/ai_agent/ui/app.py index 807f4a3..6544ec9 100644 --- a/src/ai_agent/ui/app.py +++ b/src/ai_agent/ui/app.py @@ -52,6 +52,11 @@ from ai_agent.retriever.software_doc import SoftwareDoc from ai_agent.ui.components import create_chat_interface +# ============================================================================ +# Tool registration +# ============================================================================ +from ai_agent.agent.tools import ensure_tools_registered +ensure_tools_registered() # ============================================================================ # Pipeline initialization diff --git a/src/ai_agent/ui/components.py b/src/ai_agent/ui/components.py index 4dfeae6..b5ebd38 100644 --- a/src/ai_agent/ui/components.py +++ b/src/ai_agent/ui/components.py @@ -11,6 +11,7 @@ from .handlers import respond from .visualizations import create_tool_usage_chart, create_tool_timeline, create_disabled_tools_display from .utils import get_available_models, get_default_model_display_name +from .state import format_stats_markdown log = logging.getLogger("chat_components") @@ -165,6 +166,25 @@ def create_chat_interface(doc_index: Dict[str, SoftwareDoc]): avatar_images=("šŸ‘¤", "šŸ¤–"), ) + # Tool approval box (appears inline when approval needed) + with gr.Group(visible=False) as approval_box: + gr.Markdown("### šŸ¤– Tool Recommendation") + approve_tool_btn = gr.Button( + "šŸš€ Run Tool", + variant="primary", + size="lg", + scale=1, + ) + + # File downloads section + download_files = gr.File( + label="šŸ“„ Download Results", + file_count="multiple", + type="filepath", + visible=True, + height=100, + ) + with gr.Row(): with gr.Column(scale=8): msg_input = gr.Textbox( @@ -252,7 +272,7 @@ def handle_chat(message: str, history: List[dict], files: List, state_dict: dict user_msg["content"] = file_list history.append(user_msg) - yield history, state_dict, gr.update(), gr.update(), gr.update(), gr.update() + yield history, state_dict, gr.update(), gr.update(), gr.update(), gr.update(), None, gr.update(visible=False), gr.update() # If files were uploaded, build and show preview immediately if files: @@ -281,14 +301,14 @@ def handle_chat(message: str, history: List[dict], files: List, state_dict: dict "content": {"path": preview_path}, } ) - yield history, state_dict, gr.update(), gr.update(), gr.update(), gr.update() + yield history, state_dict, gr.update(), gr.update(), gr.update(), gr.update(), None, gr.update(visible=False), gr.update() except Exception as e: log.warning("Preview generation failed: %r", e) # Show "thinking" indicator for agent processing thinking_msg = {"role": "assistant", "content": "šŸ¤” Finding tools..."} history.append(thinking_msg) - yield history, state_dict, gr.update(), gr.update(), gr.update(), gr.update() + yield history, state_dict, gr.update(), gr.update(), gr.update(), gr.update(), None, gr.update(visible=False), gr.update() # Call respond function with settings try: @@ -310,6 +330,9 @@ def handle_chat(message: str, history: List[dict], files: List, state_dict: dict # Build text content first text_content = reply.text + # Add stats if available + text_content += format_stats_markdown(reply.stats) + # Add file links if reply.files: text_content += "\n\n" + "\n".join( @@ -346,6 +369,19 @@ def handle_chat(message: str, history: List[dict], files: List, state_dict: dict timeline_chart = create_tool_timeline(state_dict_updated.get("tool_calls", [])) disabled_text = create_disabled_tools_display(state_dict_updated.get("tool_calls", [])) + # Extract downloadable files + downloaded_files = [path for path, _label in reply.files] if reply.files else None + + # Determine button visibility and label using registry + box_visible = new_state.pending_tool_approval is not None + if box_visible and new_state.pending_tool_approval: + from ai_agent.agent.tools.mcp import get_tool_display_name, get_tool_icon + display_name = get_tool_display_name(new_state.pending_tool_approval) + icon = get_tool_icon(new_state.pending_tool_approval) + button_label = f"{icon} Run {display_name}" + else: + button_label = "šŸš€ Run Tool" + yield ( history, state_dict_updated, @@ -353,6 +389,9 @@ def handle_chat(message: str, history: List[dict], files: List, state_dict: dict gr.update(value=timeline_chart), gr.update(value=disabled_text), gr.update(value=state_dict_updated), + downloaded_files, + gr.update(visible=box_visible), # approval_box + gr.update(value=button_label), # approve_tool_btn ) except Exception as e: @@ -367,19 +406,56 @@ def handle_chat(message: str, history: List[dict], files: List, state_dict: dict ), } history.append(error_msg) - yield history, state_dict, gr.update(), gr.update(), gr.update(), gr.update() + yield history, state_dict, gr.update(), gr.update(), gr.update(), gr.update(), None, gr.update(visible=False), gr.update() def clear_chat(): """Reset everything.""" empty_chart = create_tool_usage_chart([]) empty_timeline = create_tool_timeline([]) - return [], {}, empty_chart, empty_timeline, "āœ… No tools disabled", gr.update(value={}) + return [], {}, empty_chart, empty_timeline, "āœ… No tools disabled", gr.update(value={}), None, gr.update(visible=False), gr.update() + + def handle_tool_approval(history: List[dict], state_dict: dict): + """Handle tool approval button click - executes the pending tool.""" + from .handlers import execute_tool_with_approval + from .state import ChatState + + state = ChatState.from_dict(state_dict) + + if not state.pending_tool_approval: + return history, state_dict, None, gr.update(visible=False), gr.update() + + # Execute the tool + reply, new_state = execute_tool_with_approval( + state.pending_tool_approval, + state.pending_tool_params, + state + ) + + # Build response text with stats + text_content = reply.text + text_content += format_stats_markdown(reply.stats) + + # Add text message + history.append({"role": "assistant", "content": text_content}) + + # Add images + for img_path in reply.images: + if os.path.exists(img_path): + history.append({"role": "assistant", "content": {"path": img_path}}) + + # Extract downloadable files + downloaded_files = [path for path, _label in reply.files] if reply.files else None + + # Update state and hide button + state_dict_updated = new_state.to_dict() + + return history, state_dict_updated, downloaded_files, gr.update(visible=False), gr.update() # Wire up events submit_btn.click( handle_chat, inputs=[msg_input, chatbot, file_input, chat_state, model_dropdown, top_k_slider, num_choices_slider], - outputs=[chatbot, chat_state, tool_usage_plot, tool_timeline_plot, disabled_tools_text, state_display], + outputs=[chatbot, chat_state, tool_usage_plot, tool_timeline_plot, disabled_tools_text, state_display, download_files, approval_box, approve_tool_btn], ).then( lambda: ("", None), # Clear inputs inputs=None, @@ -389,17 +465,23 @@ def clear_chat(): msg_input.submit( handle_chat, inputs=[msg_input, chatbot, file_input, chat_state, model_dropdown, top_k_slider, num_choices_slider], - outputs=[chatbot, chat_state, tool_usage_plot, tool_timeline_plot, disabled_tools_text, state_display], + outputs=[chatbot, chat_state, tool_usage_plot, tool_timeline_plot, disabled_tools_text, state_display, download_files, approval_box, approve_tool_btn], ).then( lambda: ("", None), # Clear inputs inputs=None, outputs=[msg_input, file_input], ) + approve_tool_btn.click( + handle_tool_approval, + inputs=[chatbot, chat_state], + outputs=[chatbot, chat_state, download_files, approval_box, approve_tool_btn], + ) + clear_btn.click( clear_chat, inputs=None, - outputs=[chatbot, chat_state, tool_usage_plot, tool_timeline_plot, disabled_tools_text, state_display], + outputs=[chatbot, chat_state, tool_usage_plot, tool_timeline_plot, disabled_tools_text, state_display, download_files, approval_box, approve_tool_btn], ) return demo diff --git a/src/ai_agent/ui/handlers.py b/src/ai_agent/ui/handlers.py index e9aaeb5..4407e65 100644 --- a/src/ai_agent/ui/handlers.py +++ b/src/ai_agent/ui/handlers.py @@ -1,11 +1,22 @@ import logging import os +import time from datetime import datetime from typing import List, Dict, Any, Tuple from pathlib import Path from ai_agent.agent.agent import run_agent from ai_agent.agent.tools.gradio_space_tool import tool_run_example, RunExampleInput +from ai_agent.agent.tools.mcp import ( + get_tool, + get_tool_display_name, + get_tool_icon, + extract_preview, + extract_downloads, + extract_metadata, + extract_output_field, + TOOL_REGISTRY, +) from ai_agent.retriever.software_doc import SoftwareDoc from ai_agent.utils.file_validator import FileValidator from ai_agent.utils.tags import strip_tags, parse_exclusions @@ -19,6 +30,117 @@ log = logging.getLogger("chat_handlers") +def execute_tool_with_approval( + tool_name: str, + tool_params: Dict[str, Any], + state: ChatState, +) -> Tuple[ChatMessage, ChatState]: + """ + Generic tool execution handler - works for ANY registered tool. + + Uses the tool registry to dynamically dispatch to the correct tool + and extract results in a standardized way. No tool-specific code needed! + + Args: + tool_name: Name of the tool to execute + tool_params: Parameters for the tool + state: Current chat state + + Returns: + (ChatMessage with result, updated ChatState) + """ + reply = ChatMessage() + start_time = time.time() + + # Get tool configuration from registry + tool_config = get_tool(tool_name) + if not tool_config: + log.error(f"Unknown tool: {tool_name}") + reply.text = f"āŒ Error: Unknown tool '{tool_name}'" + state.pending_tool_approval = None + state.pending_tool_params = {} + return reply, state + + log.info(f"Executing {tool_name} tool with params: {tool_params}") + reply.text = f"{tool_config.icon} Running {tool_config.display_name}...\n\n" + + try: + # Augment params with state data if needed (e.g., image_path from last upload) + if "image_path" in tool_params and not tool_params["image_path"]: + if state.last_files: + tool_params["image_path"] = state.last_files[0] + + # Build input object dynamically using the tool's input model + input_obj = tool_config.input_model(**tool_params) + + # Execute the tool + result = tool_config.executor(input_obj) + + compute_time = time.time() - start_time + + # Extract standard fields using registry configuration + success = extract_output_field(result, tool_config.success_field) + error = extract_output_field(result, tool_config.error_field) + compute_time_seconds = extract_output_field(result, tool_config.compute_time_field) or 0.0 + notes = extract_output_field(result, tool_config.notes_field) + + # Track execution in state (generic) + state.tool_calls.append({ + "tool": tool_name, + "success": success, + "compute_time_seconds": compute_time_seconds, + "error": error, + "timestamp": datetime.now().isoformat(), + **tool_params # Store all params for debugging + }) + + # Add stats to reply + reply.stats = { + "compute_time": compute_time_seconds, + "total_time": compute_time, + } + + if success: + reply.text += f"āœ… {tool_config.display_name} completed!\n\n" + + # Extract and add preview image (generic) + preview_path = extract_preview(result, tool_name) + if preview_path and os.path.exists(preview_path): + reply.images.append(preview_path) + + # Extract and add downloadable files (generic) + download_paths = extract_downloads(result, tool_name) + for download_path in download_paths: + if os.path.exists(download_path): + reply.files.append((download_path, f"Download {tool_config.display_name} result")) + + # Add metadata if available + metadata = extract_metadata(result, tool_name) + if metadata: + reply.text += f"_{metadata}_\n\n" + + # Add notes if available + if notes: + reply.text += f"_{notes}_\n\n" + else: + reply.text += f"āŒ {tool_config.display_name} failed.\n\n" + if error: + reply.text += f"**Error:** {error}\n\n" + + except Exception as e: + log.exception(f"Tool {tool_name} execution failed") + reply.text += f"āŒ Error: {e}\n\n" + compute_time = time.time() - start_time + reply.stats = {"total_time": compute_time} + + # Clear pending approval + state.pending_tool_approval = None + state.pending_tool_params = {} + state.conversation_history.append(f"Assistant: {reply.text}") + + return reply, state + + def respond( message: str, files: List[Any], @@ -99,8 +221,10 @@ def respond( reply.text += f"āœ… Demo completed!\n\n" reply.images.append(preview_path) + # Add original result file for download if available if demo_result.result_origin: reply.files.append((demo_result.result_origin, "Download result")) + else: note = demo_result.notes or "No output image returned" reply.text += f"ā„¹ļø Demo ran but {note}" @@ -267,6 +391,17 @@ def respond( result_dict = agent_result.to_legacy_dict() + # Extract usage stats if available + usage_info = result_dict.get("usage") + if usage_info: + reply.stats = { + "tokens": { + "total": usage_info.get("total_tokens", 0), + "input": usage_info.get("input_tokens", 0), + "output": usage_info.get("output_tokens", 0), + } + } + # Record tool calls if "tool_calls" in result_dict: state.tool_calls.extend(result_dict["tool_calls"]) @@ -321,14 +456,30 @@ def respond( else: reply.text += f"**{i}. {tool_name}** — {accuracy:.1f}%\n\n_{why}_\n\n" - # Offer demo for top tool - demo_url = top_tool.get("demo_link", "") - if demo_url: + # Check if top tool is registered in registry and requires approval + tool_config = get_tool(top_tool["name"]) + demo_url = top_tool.get("demo_link") or "" + + if tool_config and tool_config.requires_approval: + # Tool is registered and requires approval - use registry-based execution + image_path = effective_paths[0] if effective_paths else None + state.pending_tool_approval = tool_config.name + state.pending_tool_params = { + "image_path": image_path, + "description": f"Recommended by agent: {top_tool.get('why', '')}", + } + reply.text += f"\nšŸš€ **Ready to run {tool_config.display_name}?**\n\n" + reply.text += f"šŸ“ **Image:** {os.path.basename(image_path) if image_path else 'Unknown'}\n" + if demo_url: + reply.text += f"šŸ”— **Endpoint:** {demo_url}\n\n" + reply.text += f"_Press the **'{tool_config.icon} Run Tool'** button below, or ask about other tools in the chat instead._" + elif demo_url: + # Tool has demo but not registered - use generic demo flow state.pending_demo_tool = top_tool["name"] state.pending_demo_url = demo_url reply.text += f"\nšŸ’” **Would you like me to run the demo for {top_tool['name']}?**\n" reply.text += f"šŸ”— Demo: {demo_url}\n\n" - reply.text += "_Reply 'yes' to run the demo, or continue with another request._" + reply.text += "_Press the **'šŸš€ Run Demo'** button to run the demo, or continue with another request._" else: # No suitable tools reason = result_dict.get("reason", "") diff --git a/src/ai_agent/ui/state.py b/src/ai_agent/ui/state.py index 017aada..73ad653 100644 --- a/src/ai_agent/ui/state.py +++ b/src/ai_agent/ui/state.py @@ -6,6 +6,34 @@ from dataclasses import dataclass, field +def format_stats_markdown(stats: Dict[str, Any]) -> str: + """ + Format performance stats as markdown. + + Args: + stats: Dictionary containing performance metrics + + Returns: + Formatted markdown string with stats, or empty string if no stats + """ + if not stats: + return "" + + parts = ["\n---\n**šŸ“Š Performance Stats:**\n"] + + if "compute_time" in stats: + parts.append(f"ā±ļø Compute time: {stats['compute_time']:.2f}s\n") + + if "total_time" in stats: + parts.append(f"ā±ļø Total time: {stats['total_time']:.2f}s\n") + + if "tokens" in stats: + tok = stats["tokens"] + parts.append(f"šŸŽ« Tokens: {tok.get('total', 0)} (in: {tok.get('input', 0)}, out: {tok.get('output', 0)})\n") + + return "".join(parts) + + @dataclass class ChatState: """Encapsulates all conversation state for the agent.""" @@ -19,6 +47,11 @@ class ChatState: last_files: List[str] = field(default_factory=list) last_image_meta: Optional[str] = None + # Tool approval system + pending_tool_approval: Optional[str] = None # Tool name waiting for approval + pending_tool_params: Dict[str, Any] = field(default_factory=dict) # Tool parameters + agent_result: Optional[Dict[str, Any]] = None # Cached agent result before tool execution + def to_dict(self) -> dict: """Serialize state for Gradio State component.""" return { @@ -31,6 +64,9 @@ def to_dict(self) -> dict: "last_preview_path": self.last_preview_path, "last_files": self.last_files, "last_image_meta": self.last_image_meta, + "pending_tool_approval": self.pending_tool_approval, + "pending_tool_params": self.pending_tool_params, + "agent_result": self.agent_result, } @staticmethod @@ -48,6 +84,9 @@ def from_dict(d: dict) -> 'ChatState': last_preview_path=d.get("last_preview_path"), last_files=d.get("last_files", []), last_image_meta=d.get("last_image_meta"), + pending_tool_approval=d.get("pending_tool_approval"), + pending_tool_params=d.get("pending_tool_params", {}), + agent_result=d.get("agent_result"), ) @@ -60,6 +99,7 @@ class ChatMessage: json_data: Optional[Dict[str, Any]] = None code_blocks: List[Tuple[str, str]] = field(default_factory=list) # (lang, code) tool_traces: List[Dict[str, Any]] = field(default_factory=list) + stats: Optional[Dict[str, Any]] = None # Performance stats (time, tokens, etc.) def to_markdown(self) -> str: """Convert message to markdown with media.""" @@ -68,6 +108,11 @@ def to_markdown(self) -> str: if self.text: parts.append(self.text) + # Render stats if available + stats_md = format_stats_markdown(self.stats) + if stats_md: + parts.append(stats_md) + # Render file links for file_path, label in self.files: if os.path.exists(file_path): diff --git a/src/ai_agent/utils/temp_file_manager.py b/src/ai_agent/utils/temp_file_manager.py new file mode 100644 index 0000000..22f2d16 --- /dev/null +++ b/src/ai_agent/utils/temp_file_manager.py @@ -0,0 +1,69 @@ +""" +Centralized temporary file management with automatic cleanup on shutdown. +""" +from __future__ import annotations + +import os +import logging +import atexit +import threading +from typing import Optional + +log = logging.getLogger("utils.temp_file_manager") + +# Global registry of all temporary files across all tools +_temp_files: list[str] = [] +_cleanup_registered = False +_lock = threading.Lock() + + +def register_temp_file(path: Optional[str]) -> Optional[str]: + """ + Register a temporary file for cleanup on shutdown. + Thread-safe for multi-user Gradio deployments. + + Args: + path: Path to temporary file + + Returns: + The same path (pass-through for convenience) + """ + global _cleanup_registered + + if not path: + return path + + with _lock: + if path not in _temp_files: + _temp_files.append(path) + + # Register cleanup on first use + if not _cleanup_registered: + atexit.register(cleanup_temp_files) + _cleanup_registered = True + + return path + + +def cleanup_temp_files() -> None: + """Clean up all registered temporary files. Thread-safe.""" + with _lock: + if not _temp_files: + return + + log.info(f"Cleaning up {len(_temp_files)} temporary file(s)") + + for path in _temp_files: + try: + if os.path.exists(path): + os.remove(path) + log.debug(f"Cleaned up temporary file: {path}") + except Exception as e: + log.warning(f"Failed to clean up {path}: {e}") + + _temp_files.clear() + + +def get_temp_file_count() -> int: + """Get the number of registered temporary files.""" + return len(_temp_files)