diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..f0d5c4c --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,89 @@ +[project] +name = "kubeflow-mcp" +version = "0.1.0-dev" +requires-python = ">=3.10" +description = "Model Context Protocol server for AI-assisted development with Kubeflow" +readme = "README.md" +license = {text = "Apache-2.0"} +authors = [{name = "Kubeflow Authors"}] +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", +] +dependencies = [ + "fastmcp>=2.0.0", + "pydantic>=2.0", + "pyyaml>=6.0", + "click>=8.0", +] + +[project.optional-dependencies] +trainer = ["kubeflow>=0.4.0"] +optimizer = ["kubeflow>=0.4.0"] +hub = ["kubeflow>=0.4.0", "model-registry>=0.3.6"] +agents = [ + "llama-index-core>=0.12", + "llama-index-llms-ollama>=0.4", + "rich>=13.0", +] +all = ["kubeflow>=0.4.0"] +dev = [ + "pytest>=8.0", + "pytest-asyncio>=0.23", + "pytest-cov>=4.0", + "pytest-benchmark>=4.0", + "ruff>=0.9", + "mypy>=1.10", +] +docs = [ + "sphinx>=7.0", + "furo>=2024.1.29", + "sphinx-copybutton>=0.5", + "sphinx-design>=0.5", +] + +[project.scripts] +kubeflow-mcp = "kubeflow_mcp.cli:cli" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["src/kubeflow_mcp"] + +[tool.pytest.ini_options] +asyncio_mode = "auto" +testpaths = ["tests"] +addopts = "-v --tb=short" + +[tool.ruff] +target-version = "py310" +line-length = 100 +src = ["src", "tests"] + +[tool.ruff.lint] +select = [ + "E", # pycodestyle errors + "W", # pycodestyle warnings + "F", # pyflakes + "I", # isort + "B", # flake8-bugbear + "UP", # pyupgrade +] +ignore = [ + "E501", # line too long (handled by formatter) +] + +[tool.ruff.lint.isort] +known-first-party = ["kubeflow_mcp"] + +[tool.mypy] +python_version = "3.10" +warn_return_any = true +warn_unused_configs = true +ignore_missing_imports = true diff --git a/src/kubeflow_mcp/agents/__init__.py b/src/kubeflow_mcp/agents/__init__.py new file mode 100644 index 0000000..f2e5302 --- /dev/null +++ b/src/kubeflow_mcp/agents/__init__.py @@ -0,0 +1 @@ +"""Sample agents for kubeflow-mcp.""" diff --git a/src/kubeflow_mcp/agents/dynamic_tools.py b/src/kubeflow_mcp/agents/dynamic_tools.py new file mode 100644 index 0000000..979ac74 --- /dev/null +++ b/src/kubeflow_mcp/agents/dynamic_tools.py @@ -0,0 +1,500 @@ +# Copyright 2026 The Kubeflow Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Dynamic toolsets for token-efficient tool discovery. + +Implements two approaches from https://www.speakeasy.com/blog/100x-token-reduction-dynamic-toolsets: +1. Progressive Search - Hierarchical discovery with prefix-based lookup +2. 
Semantic Search - Embeddings-based natural language discovery + +Usage: + # Progressive mode (~3K initial tokens) + agent = OllamaAgent(model="qwen2.5:7b", tool_mode="progressive") + + # Semantic mode (~2K initial tokens, requires sentence-transformers) + agent = OllamaAgent(model="qwen2.5:7b", tool_mode="semantic") + + # Full mode (all tools, ~200 tokens with compact descriptions) - default + agent = OllamaAgent(model="qwen2.5:7b", tool_mode="full") +""" + +import inspect +import warnings +from collections.abc import Callable +from typing import Any + +try: + from kubeflow_mcp.trainer import TOOLS +except ImportError: + TOOLS = [] # type: ignore[assignment] # trainer API not available (skeleton branch) + +try: + from kubeflow_mcp.common.constants import TOOL_TO_PHASE +except ImportError: + TOOL_TO_PHASE: dict[str, str] = {} # type: ignore[assignment] + +try: + from kubeflow_mcp.core.server import TOOL_DESCRIPTIONS +except ImportError: + TOOL_DESCRIPTIONS: dict[str, str] = {} # type: ignore[assignment] + +# Build tool registry driven by TOOL_TO_PHASE and TOOL_DESCRIPTIONS from constants/server +# so adding a new tool only requires updating those two central maps. +TOOL_REGISTRY: dict[str, dict[str, Any]] = {} +TOOL_HIERARCHY: dict[str, list[str]] = { + "planning": [], + "training": [], + "discovery": [], + "monitoring": [], + "lifecycle": [], +} + +for _tool_func in TOOLS: + _name = _tool_func.__name__ + _doc = _tool_func.__doc__ or "" + _category = TOOL_TO_PHASE.get(_name, "other") + _short_desc = TOOL_DESCRIPTIONS.get(_name, _doc.split("\n")[0] if _doc else _name) + + TOOL_REGISTRY[_name] = { + "name": _name, + "category": _category, + "description": _short_desc, + "full_doc": _doc, + "func": _tool_func, + } + TOOL_HIERARCHY.setdefault(_category, []).append(_name) + + +# ============================================================================= +# Progressive Search Implementation +# ============================================================================= + + +def list_tools(prefix: str = "") -> dict[str, Any]: + """List available tools by category or prefix. + + Use this to discover what tools are available. Start with no prefix to see + categories, then drill down with specific prefixes. + + Args: + prefix: Filter prefix. Examples: + - "" → List all categories + - "planning" → List planning tools + - "training" → List training tools + - "discovery" → List discovery tools + + Returns: + {categories: [...], tools: [...]} based on prefix + + Example workflow: + 1. list_tools() → See categories: planning, training, discovery, monitoring, lifecycle + 2. list_tools("training") → See: fine_tune, run_custom_training, run_container_training + 3. describe_tools(["fine_tune"]) → Get full schema for fine_tune + 4. 
execute_tool("fine_tune", {model: "...", dataset: "..."}) + """ + if not prefix: + return { + "categories": list(TOOL_HIERARCHY.keys()), + "category_tools": {cat: len(tools) for cat, tools in TOOL_HIERARCHY.items()}, + "hint": "Use list_tools('category_name') to see tools in a category", + } + + if prefix in TOOL_HIERARCHY: + tools = TOOL_HIERARCHY[prefix] + return { + "category": prefix, + "tools": [{"name": t, "description": TOOL_REGISTRY[t]["description"]} for t in tools], + "hint": "Use describe_tools(['tool_name']) to get full schema", + } + + matching = [ + {"name": name, "description": info["description"]} + for name, info in TOOL_REGISTRY.items() + if name.startswith(prefix) or prefix in name + ] + return { + "prefix": prefix, + "matching_tools": matching, + "hint": "Use describe_tools(['tool_name']) to get full schema", + } + + +def describe_tools(tool_names: list[str]) -> dict[str, Any]: + """Get detailed schema for specific tools. + + Call this after list_tools() to get full parameter information before executing. + + Args: + tool_names: List of tool names to describe (max 5 at a time) + + Returns: + {tools: [{name, description, parameters, returns}]} + """ + if len(tool_names) > 5: + return {"error": "Max 5 tools at a time to conserve tokens"} + + results: list[dict[str, Any]] = [] + for name in tool_names: + if name not in TOOL_REGISTRY: + results.append({"name": name, "error": "Tool not found"}) + continue + + tool = TOOL_REGISTRY[name] + sig = inspect.signature(tool["func"]) + params: dict[str, Any] = {} + for param_name, param in sig.parameters.items(): + param_info: dict[str, Any] = {"type": "any"} + if param.annotation != inspect.Parameter.empty: + param_info["type"] = str(param.annotation) + if param.default != inspect.Parameter.empty: + param_info["default"] = param.default + params[param_name] = param_info + + results.append( + { + "name": name, + "category": tool["category"], + "description": tool["full_doc"], + "parameters": params, + } + ) + + return {"tools": results} + + +def _format_friendly_error(result: dict[str, Any]) -> dict[str, Any]: + """Convert technical errors to user-friendly messages. + + Checks both the top-level error string and the details dict (which + contains the exception cause chain from exception_details()). + """ + if result.get("success") is not False: + return result + + error = result.get("error", "") + error_code = result.get("error_code", "") + # SDK wraps K8s HTTP errors; the cause chain is in details + details = result.get("details") or {} + detail_str = " ".join(str(v) for v in details.values()) + combined = f"{error} {detail_str}" + + if "401" in combined or "Unauthorized" in combined: + result["friendly_error"] = "Not authorized to access the cluster. Check your kubeconfig." + result["hint"] = "Run: kubectl config current-context && kubectl auth can-i list trainjobs" + elif "403" in combined or "Forbidden" in combined: + result["friendly_error"] = "Permission denied. Your account lacks RBAC access." + result["hint"] = "Check RBAC: kubectl auth can-i list trainjobs -n " + elif "404" in combined or "not found" in combined.lower(): + result["friendly_error"] = "Resource not found." + elif "Connection refused" in combined or "connection refused" in combined.lower(): + result["friendly_error"] = "Cannot connect to Kubernetes cluster." + result["hint"] = "Is the cluster running? Check: kubectl cluster-info" + elif "timeout" in combined.lower(): + result["friendly_error"] = "Request timed out. The cluster may be slow or unreachable." 
+ elif error_code == "SDK_ERROR" and "HuggingFace" in combined: + result["friendly_error"] = "Could not fetch model info from HuggingFace." + result["hint"] = "Check the model ID format (e.g., 'meta-llama/Llama-3.2-1B')" + + return result + + +def execute_tool(tool_name: str, arguments: dict[str, Any] | None = None) -> dict[str, Any]: + """Execute a discovered tool. + + Call this after using list_tools() and describe_tools() to run the actual tool. + + Args: + tool_name: Name of the tool to execute + arguments: Tool arguments as key-value pairs + + Returns: + Tool execution result + """ + if tool_name not in TOOL_REGISTRY: + return {"error": f"Tool '{tool_name}' not found", "available": list(TOOL_REGISTRY.keys())} + + func = TOOL_REGISTRY[tool_name]["func"] + args = arguments or {} + + try: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=Warning, module="urllib3") + result = func(**args) + if isinstance(result, dict): + return _format_friendly_error(result) + return {"result": result} + except Exception as e: + return {"error": str(e), "tool": tool_name, "arguments": args} + + +# Progressive search meta-tools (3 tools instead of 16) +PROGRESSIVE_TOOLS = [list_tools, describe_tools, execute_tool] + + +# ============================================================================= +# Semantic Search Implementation +# ============================================================================= + +# Pre-computed tool descriptions for embedding +TOOL_DESCRIPTIONS_FOR_EMBEDDING = { + name: f"{info['description']}. Category: {info['category']}. {info['full_doc'][:200]}" + for name, info in TOOL_REGISTRY.items() +} + + +class _EmbeddingCache: + """Lazy-loaded embedding cache. Holds model + per-tool vectors. + + Centralising state here (vs module globals) makes cache.reset() safe + to call from tests without touching module-level names. + """ + + def __init__(self): + self._embeddings: dict[str, list[float]] | None = None + self._model = None + + def get(self) -> tuple[dict[str, list[float]] | None, Any]: + if self._embeddings is not None: + return self._embeddings, self._model + + try: + from sentence_transformers import SentenceTransformer + + self._model = SentenceTransformer("all-MiniLM-L6-v2") + descriptions = list(TOOL_DESCRIPTIONS_FOR_EMBEDDING.values()) + embeddings = self._model.encode(descriptions) + self._embeddings = { + name: emb.tolist() + for name, emb in zip( + TOOL_DESCRIPTIONS_FOR_EMBEDDING.keys(), embeddings, strict=True + ) + } + return self._embeddings, self._model + except ImportError: + return None, None + + def reset(self) -> None: + self._embeddings = None + self._model = None + + +_embedding_cache = _EmbeddingCache() + + +def find_tools(query: str, top_k: int = 5) -> dict[str, Any]: + """Find relevant tools using semantic search. + + Describe what you want to accomplish in natural language, and this will + return the most relevant tools. + + Args: + query: Natural language description. 
Examples: + - "all" - LIST ALL 16 AVAILABLE KUBEFLOW TOOLS (use when user asks what tools exist) + - "check GPU availability in the cluster" + - "fine-tune a language model" + - "view logs from a training job" + - "delete a failed job" + top_k: Number of results to return (default 5, ignored when query="all") + + Returns: + {tools: [{name, description, category}], hint: "Use execute_tool(name, args)"} + """ + query_lower = query.strip().lower() + _list_all = { + "*", + "all", + "list", + "list all", + "all tools", + "available tools", + "what tools", + "show tools", + "show all", + "every tool", + "everything", + "available", + "what's available", + "whats available", + } + if query_lower in _list_all or "all tool" in query_lower or "available tool" in query_lower: + return { + "query": query, + "total": len(TOOL_REGISTRY), + "tools": [ + {"name": name, "description": info["description"], "category": info["category"]} + for name, info in TOOL_REGISTRY.items() + ], + "hint": "Use execute_tool(tool_name, {args}) to run a tool", + } + + embeddings, model = _embedding_cache.get() + + if embeddings is None: + return _keyword_search(query, top_k) + + import numpy as np + + query_embedding = model.encode([query])[0] + scores = { + name: float( + np.dot(query_embedding, tool_emb) + / (np.linalg.norm(query_embedding) * np.linalg.norm(tool_emb)) + ) + for name, tool_emb in embeddings.items() + } + sorted_tools = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:top_k] + + return { + "query": query, + "tools": [ + { + "name": name, + "description": TOOL_REGISTRY[name]["description"], + "category": TOOL_REGISTRY[name]["category"], + "relevance": f"{score:.2f}", + } + for name, score in sorted_tools + ], + "hint": "Use execute_tool(tool_name, {args}) to run a tool", + } + + +def _keyword_search(query: str, top_k: int = 5) -> dict[str, Any]: + """Fallback keyword search when embeddings unavailable.""" + query_lower = query.lower() + keywords = query_lower.split() + + scores = {} + for name, info in TOOL_REGISTRY.items(): + text = f"{name} {info['description']} {info['category']}".lower() + score = sum(1 for kw in keywords if kw in text) + if score > 0: + scores[name] = score + + sorted_tools = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:top_k] + return { + "query": query, + "mode": "keyword_fallback", + "tools": [ + { + "name": name, + "description": TOOL_REGISTRY[name]["description"], + "category": TOOL_REGISTRY[name]["category"], + } + for name, _ in sorted_tools + ], + "hint": "Use execute_tool(tool_name, {args}) to run a tool", + } + + +# Semantic search meta-tools (2 tools instead of 16) +SEMANTIC_TOOLS = [find_tools, execute_tool] + + +# ============================================================================= +# Factory Functions +# ============================================================================= + +# Pre-built schema snippet for the 5 most-common tools, injected into the +# progressive system prompt so the agent skips list_tools→describe_tools for +# the happy path (saves 2 LLM round-trips on ~80% of real queries). +_COMMON_TOOL_HINTS = "\n".join( + f" {name}: {TOOL_REGISTRY[name]['description']}" + for name in [ + "get_cluster_resources", + "estimate_resources", + "list_runtimes", + "list_training_jobs", + "fine_tune", + "get_training_logs", + ] + if name in TOOL_REGISTRY +) + +_PROGRESSIVE_SYSTEM_PROMPT = f"""You are a Kubeflow training assistant. Help users manage ML training jobs on Kubernetes. 
+ +When greeted, introduce yourself briefly and offer these options: +- Check cluster resources (GPUs, nodes) +- Fine-tune a model (e.g., Llama, Gemma) +- List training jobs or runtimes +- Monitor a running job + +Common tools you can call directly via execute_tool: +{_COMMON_TOOL_HINTS} + +Categories for discovery: + planning → get_cluster_resources, estimate_resources + training → fine_tune, run_custom_training, run_container_training + discovery → list_runtimes, get_runtime, list_training_jobs, get_training_job + monitoring → get_training_logs, get_training_events, wait_for_training + lifecycle → delete_training_job, suspend_training_job, resume_training_job + +For less-common tasks: list_tools("category") → describe_tools(["tool_name"]) → execute_tool("tool_name", {{args}}) + +When the user asks to train or fine-tune: +1. execute_tool("get_cluster_resources") → check GPUs +2. execute_tool("estimate_resources", {{"model": "google/gemma-2b"}}) → check memory +3. execute_tool("list_runtimes") → check available runtimes +4. execute_tool("fine_tune", {{..., "confirmed": false}}) → show preview +5. Wait for user confirmation, then resubmit with confirmed=true + +Use hf:// prefix for model/dataset URIs. If errors occur, explain them clearly. +""" + +_SEMANTIC_SYSTEM_PROMPT = f"""You are a Kubeflow training assistant. Help users manage ML training jobs on Kubernetes. + +When greeted, introduce yourself briefly and offer these options: +- Check cluster resources (GPUs, nodes) +- Fine-tune a model (e.g., Llama, Gemma) +- List training jobs or runtimes +- Monitor a running job + +Common tools you can call directly via execute_tool: +{_COMMON_TOOL_HINTS} + +IMPORTANT: When user asks "what tools are available", call find_tools("all"). + +Categories: planning (resources), training (fine_tune, custom, container), discovery (list_runtimes, list_training_jobs, get_runtime), monitoring (logs, events), lifecycle (delete, suspend, resume). + +For other tasks, use find_tools("natural language query") to discover tools, then execute_tool(). + +When the user asks to train or fine-tune: +1. execute_tool("get_cluster_resources") → check GPUs +2. execute_tool("estimate_resources", {{"model": "google/gemma-2b"}}) → check memory +3. execute_tool("fine_tune", {{..., "confirmed": false}}) → show preview +4. Wait for user confirmation, then resubmit with confirmed=true + +Use hf:// prefix for model/dataset URIs. If errors occur, explain them clearly. +""" + + +def get_dynamic_tools(mode: str = "progressive") -> list[Callable[..., Any]]: + """Get meta-tools for dynamic discovery. + + Args: + mode: "progressive" or "semantic" + + Returns: + List of meta-tool functions + """ + if mode == "semantic": + return SEMANTIC_TOOLS # type: ignore[return-value] + return PROGRESSIVE_TOOLS # type: ignore[return-value] + + +def get_dynamic_system_prompt(mode: str = "progressive") -> str: + """Get system prompt for dynamic tool mode.""" + if mode == "semantic": + return _SEMANTIC_SYSTEM_PROMPT + return _PROGRESSIVE_SYSTEM_PROMPT diff --git a/src/kubeflow_mcp/agents/ollama.py b/src/kubeflow_mcp/agents/ollama.py new file mode 100644 index 0000000..05d65ce --- /dev/null +++ b/src/kubeflow_mcp/agents/ollama.py @@ -0,0 +1,870 @@ +# Copyright 2026 The Kubeflow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Ollama agent using LlamaIndex FunctionAgent with native tool calling. + +Requires optional dependencies: + uv sync --extra agents + pip install kubeflow-mcp[agents] + +Usage: + ollama serve + uv run python -m kubeflow_mcp.agents.ollama + uv run python -m kubeflow_mcp.agents.ollama --model qwen2.5:7b +""" + +import io +import json +import logging +import re +import sys +from contextlib import redirect_stderr +from typing import Any + +# Suppress noisy loggers +logging.getLogger("httpx").setLevel(logging.ERROR) +logging.getLogger("llama_index").setLevel(logging.ERROR) + +# Check if being imported by Sphinx for documentation +_SPHINX_BUILD = "sphinx" in sys.modules + +try: + from llama_index.core.agent.workflow import FunctionAgent + from llama_index.core.memory import ChatMemoryBuffer + from llama_index.core.tools import FunctionTool + from llama_index.llms.ollama import Ollama + from rich.console import Console + from rich.markdown import Markdown + from rich.panel import Panel + from rich.table import Table + from rich.text import Text +except ImportError: + if not _SPHINX_BUILD: + sys.exit("Error: required packages not installed\nRun: uv sync --extra agents") + # Allow import to continue for autodoc even without dependencies + FunctionAgent = None # type: ignore[misc, assignment] + ChatMemoryBuffer = None # type: ignore[misc, assignment] + FunctionTool = None # type: ignore[misc, assignment] + Ollama = None # type: ignore[misc, assignment] + Console = None # type: ignore[misc, assignment] + Markdown = None # type: ignore[misc, assignment] + Panel = None # type: ignore[misc, assignment] + Table = None # type: ignore[misc, assignment] + Text = None # type: ignore[misc, assignment] + +from kubeflow_mcp.agents.dynamic_tools import ( # noqa: E402 + PROGRESSIVE_TOOLS, + SEMANTIC_TOOLS, + get_dynamic_system_prompt, + get_dynamic_tools, +) + +try: + from kubeflow_mcp.core.server import SERVER_INSTRUCTIONS, TOOL_DESCRIPTIONS # noqa: E402 +except ImportError: + SERVER_INSTRUCTIONS = "You are a Kubeflow training assistant." 
+ TOOL_DESCRIPTIONS: dict[str, str] = {} # type: ignore[assignment] +try: + from kubeflow_mcp.trainer import TOOLS # noqa: E402 +except ImportError: + TOOLS = [] # type: ignore[assignment] # trainer API not available (skeleton branch) + +console = Console() + +# Agent configuration defaults +DEFAULT_MODEL = "qwen3:8b" +DEFAULT_URL = "http://localhost:11434" +DEFAULT_REQUEST_TIMEOUT = 180.0 # LLM request timeout in seconds +DEFAULT_MEMORY_TOKEN_LIMIT = 16000 # Chat memory token limit (qwen3:8b has 32K context) + +# Tool modes - counts computed dynamically from actual registries +_NUM_TOOLS = len(TOOLS) +_NUM_PROGRESSIVE = len(PROGRESSIVE_TOOLS) +_NUM_SEMANTIC = len(SEMANTIC_TOOLS) + +# User-facing tool modes +# "full" uses in-process tools (efficient for local agent) +# "progressive" and "semantic" reduce token usage via meta-tools +TOOL_MODES = { + "full": f"All {_NUM_TOOLS} tools loaded", + "progressive": f"{_NUM_PROGRESSIVE} meta-tools with hierarchical discovery", + "semantic": f"{_NUM_SEMANTIC} meta-tools with embedding search", +} + +# Legacy aliases for backward compatibility +_MODE_ALIASES = {"static": "full", "mcp": "full"} + +# Agent-specific additions to server instructions +AGENT_HINTS = """ +AGENT-SPECIFIC: +- When greeted, introduce yourself briefly and offer to help with training tasks +- Model ID formats: estimate_resources() uses "google/gemma-2b", fine_tune() uses "hf://google/gemma-2b" +- Execute planning steps (1-4) together, only pause after showing the preview +- If no GPUs (gpu_total=0), suggest CPU training or inform user +""" + +# System prompt combining server instructions with agent-specific hints +SYSTEM_PROMPT = SERVER_INSTRUCTIONS + AGENT_HINTS + + +def _create_tools(mode: str = "full") -> list[FunctionTool]: + """Create LlamaIndex FunctionTools for the given mode. + + Uses compact TOOL_DESCRIPTIONS from server.py for full mode (~200 tokens) + instead of raw docstrings (~5K tokens). + + Args: + mode: "full" | "progressive" | "semantic" + + Returns: + List of FunctionTool objects + """ + if mode in ("progressive", "semantic"): + tool_funcs = get_dynamic_tools(mode) + else: + tool_funcs = TOOLS # type: ignore[assignment] + + tools = [] + for tool_func in tool_funcs: + doc = tool_func.__doc__ or "" + desc = TOOL_DESCRIPTIONS.get( + tool_func.__name__, doc.split("\n")[0] if doc else tool_func.__name__ + ) + tools.append( + FunctionTool.from_defaults( + fn=tool_func, + name=tool_func.__name__, + description=desc, + ) + ) + return tools + + +def _format_tool_result(result: Any, max_lines: int = 15) -> str: + """Format tool result for display, truncating if needed.""" + if isinstance(result, dict): + formatted = json.dumps(result, indent=2, default=str) + else: + formatted = str(result) + + lines = formatted.split("\n") + if len(lines) > max_lines: + return "\n".join(lines[:max_lines]) + f"\n... ({len(lines) - max_lines} more lines)" + return formatted + + +class OllamaAgent: + """Ollama agent using LlamaIndex FunctionAgent with thinking support. 
+ + Supports multiple tool modes for different context budgets: + - "full": All tools loaded (~200 tokens) - default, best accuracy + - "progressive": 3 meta-tools (~85 tokens) - hierarchical discovery + - "semantic": 2 meta-tools (~69 tokens) - embedding-based discovery + """ + + _agent: "FunctionAgent | None" + _tools: "list[FunctionTool] | None" + _thinking_supported: bool | None + _thinking_notified: bool + memory: "ChatMemoryBuffer | None" + llm: "Ollama | None" + + def __init__( + self, + model: str = DEFAULT_MODEL, + base_url: str = DEFAULT_URL, + tool_mode: str = "full", + ): + self.model = model + self.base_url = base_url + # Resolve legacy aliases (static, mcp -> full) + self.tool_mode = _MODE_ALIASES.get(tool_mode, tool_mode) + self._agent = None + self._tools = None + self._thinking_supported = None # None = unknown, True/False = tested + self._thinking_notified = False + self._use_thinking = True + self._awaiting_confirmation = False # True after agent shows a preview + self.memory = None + self.llm = None + + # Set system prompt based on mode + if tool_mode in ("progressive", "semantic"): + self._system_prompt = get_dynamic_system_prompt(tool_mode) + else: + # For static and mcp modes, use full prompt (mcp may override) + self._system_prompt = SYSTEM_PROMPT + + # Dedicated event loop in background thread (prevents "Event loop is closed" errors) + import asyncio + import threading + + self._loop = asyncio.new_event_loop() + self._loop_thread = threading.Thread(target=self._loop.run_forever, daemon=True) + self._loop_thread.start() + + def _create_llm(self, with_thinking: bool) -> "Ollama": + """Create Ollama LLM with or without thinking mode.""" + return Ollama( + model=self.model, + base_url=self.base_url, + request_timeout=DEFAULT_REQUEST_TIMEOUT, + is_function_calling_model=True, + thinking=with_thinking, + ) + + def _ensure_agent(self, with_thinking: bool | None = None): + """Lazy initialization of agent.""" + if self._agent is not None: + return + + if with_thinking is None: + with_thinking = self._use_thinking + + with redirect_stderr(io.StringIO()): + self._tools = _create_tools(mode=self.tool_mode) + self.llm = self._create_llm(with_thinking) + self.memory = ChatMemoryBuffer.from_defaults(token_limit=DEFAULT_MEMORY_TOKEN_LIMIT) + self._agent = FunctionAgent( + tools=self._tools, + llm=self.llm, + memory=self.memory, + system_prompt=self._system_prompt, + ) + + def set_thinking_mode(self, enabled: bool): + """Toggle thinking mode - recreates LLM but preserves memory.""" + if self._use_thinking == enabled: + return + + self._use_thinking = enabled + + if self._agent is not None: + with redirect_stderr(io.StringIO()): + use_thinking = enabled and (self._thinking_supported is not False) + self.llm = self._create_llm(use_thinking) + self._agent = FunctionAgent( + tools=self._tools, # type: ignore[arg-type] + llm=self.llm, + memory=self.memory, + system_prompt=self._system_prompt, + ) + + def set_mode(self, mode: str) -> int: + """Switch tool mode at runtime. Returns number of tools loaded.""" + # Handle legacy aliases (static, mcp -> full) + resolved_mode = _MODE_ALIASES.get(mode, mode) + + if resolved_mode not in TOOL_MODES: + raise ValueError(f"Unknown mode: {mode}. 
Choose from: {list(TOOL_MODES.keys())}")
+
+        self.tool_mode = resolved_mode
+
+        # Update system prompt based on mode
+        if resolved_mode in ("progressive", "semantic"):
+            self._system_prompt = get_dynamic_system_prompt(resolved_mode)
+        else:
+            self._system_prompt = SYSTEM_PROMPT
+
+        # Force agent recreation with new tools
+        self._agent = None
+        self._tools = None
+        self._ensure_agent(
+            with_thinking=self._use_thinking and self._thinking_supported is not False
+        )
+
+        return len(self._tools) if self._tools else 0
+
+    async def _chat_async(
+        self,
+        message: str,
+        on_thinking=None,
+        on_tool_call=None,
+        on_tool_result=None,
+    ) -> tuple[str, list[dict]]:
+        """Async chat implementation with thinking support."""
+        from llama_index.core.agent.workflow.workflow_events import (
+            AgentOutput,
+            AgentStream,
+            ToolCallResult,
+        )
+
+        # Initialize with thinking if not yet tested
+        if self._thinking_supported is None:
+            self._ensure_agent(with_thinking=True)
+        else:
+            self._ensure_agent(with_thinking=self._thinking_supported and self._use_thinking)
+
+        tool_calls = []
+        seen_tools = set()
+
+        try:
+            assert self._agent is not None
+            handler = self._agent.run(user_msg=message, memory=self.memory)
+
+            async for event in handler.stream_events():
+                if isinstance(event, AgentStream):
+                    # Stream thinking output (attribute may not exist in all SDK versions)
+                    thinking_delta = getattr(event, "thinking_delta", None)
+                    if thinking_delta and on_thinking:
+                        on_thinking(thinking_delta)
+
+                    # Collect tool calls
+                    if event.tool_calls:
+                        for tc in event.tool_calls:
+                            key = f"{tc.tool_name}:{json.dumps(tc.tool_kwargs, sort_keys=True)}"
+                            if key not in seen_tools:
+                                seen_tools.add(key)
+                                tool_info = {"name": tc.tool_name, "args": tc.tool_kwargs}
+                                tool_calls.append(tool_info)
+                                if on_tool_call:
+                                    on_tool_call(tool_info)
+
+                elif isinstance(event, ToolCallResult):
+                    if on_tool_result:
+                        result_info = {
+                            "name": event.tool_name,
+                            "result": event.tool_output.content if event.tool_output else None,
+                        }
+                        on_tool_result(result_info)
+
+            result = await handler
+            if isinstance(result, AgentOutput):
+                response = result.response.content or ""
+            else:
+                response = str(result)
+
+            # qwen3/deepseek sometimes puts the entire intent inside <think>...</think> and
+            # emits nothing after it, leaving response empty even though the model
+            # reasoned correctly. Strip residual <think> tags so the retry logic fires.
+            response = re.sub(r"<think>.*?</think>", "", response, flags=re.DOTALL).strip()
+
+            # Mark thinking as supported if we got here
+            if self._thinking_supported is None:
+                self._thinking_supported = True
+
+        except Exception as e:
+            error_msg = str(e)
+            # Handle thinking mode not supported
+            if "does not support thinking" in error_msg and self._thinking_supported is None:
+                self._thinking_supported = False
+                # Recreate agent without thinking
+                self._agent = None
+                self._ensure_agent(with_thinking=False)
+                return await self._chat_async(message, on_thinking, on_tool_call, on_tool_result)
+            raise
+
+        return response, tool_calls
+
+    def chat(
+        self,
+        message: str,
+        on_thinking=None,
+        on_tool_call=None,
+        on_tool_result=None,
+    ) -> tuple[str, list[dict]]:
+        """Synchronous chat wrapper using dedicated event loop.
+
+        Uses short polling intervals to allow Ctrl+C to interrupt.
+        Includes retry logic for empty responses.
+ """ + import asyncio + + def run_chat(msg: str) -> tuple[str, list[dict]]: + future = asyncio.run_coroutine_threadsafe( + self._chat_async(msg, on_thinking, on_tool_call, on_tool_result), + self._loop, + ) + while True: + try: + return future.result(timeout=0.5) + except TimeoutError: + continue + except KeyboardInterrupt: + future.cancel() + raise + + response, tool_calls = run_chat(message) + + # Track whether agent just showed a preview (confirmed=False tool call) + if tool_calls: + last_tool = tool_calls[-1].get("name", "") + last_args = tool_calls[-1].get("args", {}) + if last_tool in ("fine_tune", "run_custom_training", "run_container_training"): + self._awaiting_confirmation = not last_args.get("confirmed", False) + + if not response.strip() and not tool_calls: + console.print("[dim yellow]⚠ Empty response, retrying...[/dim yellow]") + + # Thinking mode causes qwen3/deepseek to reason but not emit a tool call. + # Disable it first so the model outputs an action on the retry. + if self._use_thinking: + self.set_thinking_mode(False) + + if self._awaiting_confirmation: + retry_msg = "User confirmed. Call the appropriate tool to complete the task." + else: + retry_msg = ( + f"You know what to do. Call execute_tool() now. Original request: {message}" + ) + + response, tool_calls = run_chat(retry_msg) + + if not response.strip() and not tool_calls: + response, tool_calls = run_chat(f"Execute the action now: {message}") + + if not response.strip() and not tool_calls: + response = ( + "I couldn't generate a response. Try:\n" + "- `/think` to toggle thinking mode\n" + "- Be more specific about what you want\n" + "- Use `/mode full` for more reliable responses" + ) + + return response, tool_calls + + def close(self): + """Clean up agent resources.""" + try: + if self._loop and self._loop.is_running(): + self._loop.call_soon_threadsafe(self._loop.stop) + if self._loop_thread and self._loop_thread.is_alive(): + self._loop_thread.join(timeout=2) + except Exception: + pass # Ignore cleanup errors + + +def _check_ollama_model(model: str, url: str) -> tuple[bool, str]: + """Check if model exists on Ollama server.""" + import httpx + + try: + response = httpx.get(f"{url}/api/tags", timeout=10.0) + response.raise_for_status() + available = [m["name"] for m in response.json().get("models", [])] + + if model in available: + return True, "Model ready" + + # Check for similar models + base = model.split(":")[0] + similar = [m for m in available if m.startswith(base)] + if similar: + return False, f"Not found. Try: {', '.join(similar[:3])}" + return False, f"Not found. Pull with: ollama pull {model}" + + except httpx.ConnectError: + return False, f"Cannot connect to {url}" + except Exception as e: + return False, str(e) + + +def run_chat( + model: str = DEFAULT_MODEL, + url: str = DEFAULT_URL, + tool_mode: str = "static", +): + """Run interactive chat loop with rich UI. 
+ + Args: + model: Ollama model name + url: Ollama server URL + tool_mode: Tool loading mode: + - "full": All tools loaded (~200 tokens) - default + - "progressive": 3 meta-tools, hierarchical discovery (~85 tokens) + - "semantic": 2 meta-tools, embedding search (~69 tokens) + """ + # Welcome panel + welcome = Table.grid(padding=(0, 1)) + welcome.add_column(justify="left") + welcome.add_row(Text("Kubeflow AI Agent", style="bold bright_cyan")) + welcome.add_row(Text(f"Model: {model}", style="bright_green")) + welcome.add_row(Text(f"Ollama: {url}", style="bright_white")) + mode_desc = TOOL_MODES.get(tool_mode, tool_mode) + welcome.add_row(Text(f"Tools: {mode_desc}", style="bright_yellow")) + welcome.add_row() + welcome.add_row(Text("Commands:", style="bright_yellow")) + welcome.add_row(Text(" /tools - List available tools", style="white")) + welcome.add_row( + Text(" /mode - Switch tool mode (static/progressive/semantic)", style="white") + ) + welcome.add_row(Text(" /think - Toggle thinking output", style="white")) + welcome.add_row(Text(" /file - Read file and analyze it", style="white")) + welcome.add_row(Text(" /clear - Clear conversation memory", style="white")) + welcome.add_row(Text(" exit - Quit the agent", style="white")) + + console.print() + console.print( + Panel( + welcome, + title="[bold bright_white]🚀 Ollama Agent[/bold bright_white]", + border_style="bright_blue", + padding=(1, 2), + ) + ) + + # Check model availability + console.print("[bright_cyan]Checking model...[/bright_cyan]", end="\r") + model_ok, model_msg = _check_ollama_model(model, url) + if model_ok: + console.print(f"[bright_green]✓ {model_msg}[/bright_green] ") + else: + console.print(f"[bright_red]✗ {model_msg}[/bright_red]") + return + + agent = OllamaAgent(model=model, base_url=url, tool_mode=tool_mode) + + # Pre-load agent + console.print("[bright_cyan]Loading tools...[/bright_cyan]", end="\r") + try: + agent._ensure_agent() + tools_count = len(agent._tools) if agent._tools else 0 + console.print(f"[bright_green]✓ Loaded {tools_count} tools[/bright_green]") + except Exception as e: + console.print(f"[bright_red]✗ Failed to initialize: {e}[/bright_red]") + return + + console.print() + console.print( + "[bright_yellow]💡 Try: 'list training jobs' or 'check cluster resources'[/bright_yellow]" + ) + + # Enable readline for command history (up/down arrow navigation) + try: + import atexit + import os + import readline # noqa: F401 - import enables history for input() + + # Optional: persist history across sessions + history_file = os.path.expanduser("~/.kubeflow_mcp_history") + try: + readline.read_history_file(history_file) + except FileNotFoundError: + pass + atexit.register(readline.write_history_file, history_file) + except ImportError: + pass # readline not available on some platforms + + # State - thinking OFF by default, auto-enables after first message if model supports it + show_thinking = False + thinking_buffer: list[str] = [] + + while True: + try: + console.print() + console.print("[bold bright_blue]You →[/bold bright_blue] ", end="") + user_input = input().strip() # Use raw input() for readline history support + + if not user_input: + continue + + if user_input.lower() in ("exit", "quit", "q"): + agent.close() + console.print("[dim italic]Goodbye![/dim italic]") + break + + if user_input.lower() == "/tools": + tools = agent._tools or [] + console.print(f"\n[bold]Available tools ({len(tools)}):[/bold]") + for t in tools: + console.print(f" [bright_cyan]{t.metadata.name}[/bright_cyan]") + continue + + 
if user_input.lower() == "/think": + show_thinking = not show_thinking + agent.set_thinking_mode(show_thinking) + status = "ON" if show_thinking else "OFF" + console.print(f"[bright_yellow]Thinking mode: {status}[/bright_yellow]") + if show_thinking: + console.print("[dim]Model reasoning will be shown during responses.[/dim]") + continue + + if user_input.lower() == "/clear": + if agent.memory: + agent.memory.reset() + agent._awaiting_confirmation = False + console.print("[bright_green]✓ Conversation memory cleared[/bright_green]") + console.print("[dim]Context reset - start fresh![/dim]") + else: + console.print("[dim]No memory to clear[/dim]") + if show_thinking: + console.print( + "[dim]Note: Only reasoning models (deepseek-r1, qwq, etc.) show thinking output[/dim]" + ) + continue + + if user_input.lower().startswith("/mode"): + parts = user_input.split() + if len(parts) == 1: + # Show current mode and options + console.print(f"\n[bold]Current mode:[/bold] {agent.tool_mode}") + console.print("\n[bold]Available modes:[/bold]") + for mode_name, mode_desc in TOOL_MODES.items(): + marker = "→" if mode_name == agent.tool_mode else " " + console.print( + f" {marker} [bright_cyan]{mode_name}[/bright_cyan]: {mode_desc}" + ) + console.print("\n[dim]Usage: /mode [/dim]") + else: + new_mode = parts[1].lower() + try: + console.print(f"[bright_cyan]Switching to {new_mode} mode...[/bright_cyan]") + num_tools = agent.set_mode(new_mode) + console.print( + f"[bright_green]✓ Switched to {new_mode} ({num_tools} tools)[/bright_green]" + ) + except ValueError as e: + console.print(f"[bright_red]✗ {e}[/bright_red]") + continue + + # /file command - read local file and include in message + if user_input.lower().startswith("/file"): + # Handle /file without path + if user_input.lower() == "/file" or user_input[5:].strip() == "": + console.print("[bright_yellow]Usage: /file [/bright_yellow]") + console.print("[dim]Example: /file examples/mnist_train.py[/dim]") + console.print("[dim] /file ~/scripts/train.py[/dim]") + continue + + file_path = user_input[5:].strip() + # Remove leading space if present + if file_path.startswith(" "): + file_path = file_path[1:] + + try: + from pathlib import Path + + path = Path(file_path).expanduser() + if not path.exists(): + console.print(f"[bright_red]✗ File not found: {file_path}[/bright_red]") + console.print("[dim]Check the path and try again[/dim]") + continue + + if not path.is_file(): + console.print(f"[bright_red]✗ Not a file: {file_path}[/bright_red]") + continue + + content = path.read_text() + lines = len(content.splitlines()) + console.print( + f"[bright_green]✓ Read {path.name} ({lines} lines)[/bright_green]" + ) + + # Detect file type for syntax highlighting + ext = path.suffix.lower() + lang = { + "py": "python", + "js": "javascript", + "ts": "typescript", + "yaml": "yaml", + "yml": "yaml", + "json": "json", + }.get(ext.lstrip("."), "") + + # Include file content in next message + user_input = f"Here is the contents of `{path.name}`:\n\n```{lang}\n{content}\n```\n\nPlease analyze this file and tell me what it does." 
+ # Fall through to normal processing + except Exception as e: + console.print(f"[bright_red]Error reading file: {e}[/bright_red]") + continue + + # Show user message + console.print() + console.print( + Panel( + Text(user_input, style="white"), + title="[bold bright_blue]You[/bold bright_blue]", + border_style="bright_blue", + padding=(0, 1), + ) + ) + + # Processing indicator + console.print("[bright_cyan]⏳ Thinking...[/bright_cyan]", end="\r") + thinking_buffer.clear() + first_output = [True] + + def on_thinking(delta): + if show_thinking and delta: # noqa: B023 + if first_output[0]: # noqa: B023 + console.print(" " * 20, end="\r") # Clear status + first_output[0] = False # noqa: B023 + thinking_buffer.append(delta) + console.print( + f"[bright_magenta italic]{delta}[/bright_magenta italic]", + end="", + highlight=False, + ) + + def on_tool_call(tool_info): + if first_output[0]: # noqa: B023 + console.print(" " * 20, end="\r") # Clear "Thinking..." + first_output[0] = False # noqa: B023 + if thinking_buffer: + console.print() # Newline after thinking + thinking_buffer.clear() + console.print() + + tool_name = tool_info.get("name", "unknown") + tool_args = tool_info.get("args") or {} + + # Always show tool name + console.print(f" [bright_yellow]🔧 {tool_name}[/bright_yellow]") + + # Show arguments + if tool_args: + args_str = json.dumps(tool_args, indent=2, default=str) + for line in args_str.split("\n"): + console.print(f" [bright_white]{line}[/bright_white]") + else: + console.print(" [dim](no arguments)[/dim]") + + console.print("[bright_cyan] ⏳ Executing...[/bright_cyan]", end="\r") + + def on_tool_result(result_info): + console.print(" " * 30, end="\r") # Clear "Executing..." + if result_info.get("result"): + result_str = _format_tool_result(result_info["result"]) + console.print( + Panel( + Text(result_str, style="white"), + title="[bright_green]Result[/bright_green]", + border_style="green", + padding=(0, 1), + ) + ) + + try: + response, _ = agent.chat( + user_input, + on_thinking=on_thinking if show_thinking else None, + on_tool_call=on_tool_call, + on_tool_result=on_tool_result, + ) + except Exception as e: + console.print() + error_msg = str(e) + console.print( + Panel( + Text(f"{type(e).__name__}: {error_msg}", style="bright_red"), + title="[bright_red bold]❌ Error[/bright_red bold]", + border_style="red", + padding=(0, 1), + ) + ) + # Show helpful hints based on error type + if "does not support tools" in error_msg: + console.print( + "[yellow]💡 This model doesn't support function calling.[/yellow]" + ) + console.print( + "[yellow] Try: qwen2.5:7b, llama3.2, or mistral (with tools, no thinking)[/yellow]" + ) + console.print("[yellow] Or: qwq:32b (has both thinking AND tools)[/yellow]") + elif "connection" in error_msg.lower(): + console.print("[yellow]💡 Check if Ollama is running: ollama serve[/yellow]") + elif "timeout" in error_msg.lower(): + console.print("[yellow]💡 Request timed out. Try a simpler query.[/yellow]") + continue + + # Clear any pending status + console.print(" " * 40, end="\r") + + # Notify user if thinking is available (but don't auto-enable - keeps output clean) + if agent._thinking_supported is True and not agent._thinking_notified: + agent._thinking_notified = True + console.print( + "[dim]💭 Thinking supported. 
Use /think to see model reasoning.[/dim]" + ) + + # Newline after thinking + if thinking_buffer: + console.print() + + # Only show assistant panel if there's actual response content + if response and response.strip(): + console.print() + console.print( + Panel( + Markdown(response), + title="[bold bright_green]Assistant[/bold bright_green]", + border_style="bright_green", + padding=(0, 2), + ) + ) + + except KeyboardInterrupt: + console.print( + "\n[yellow]Interrupted. Press Ctrl+C again to quit, or continue typing.[/yellow]" + ) + try: + # Wait briefly for another Ctrl+C + import time + + time.sleep(0.5) + except KeyboardInterrupt: + agent.close() + console.print("[dim italic]Goodbye![/dim italic]") + break + continue + except EOFError: + agent.close() + console.print("\n[dim italic]Goodbye![/dim italic]") + break + + +def main(): + import argparse + + parser = argparse.ArgumentParser( + description="Kubeflow MCP Ollama Agent", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Tool modes: + full All tools loaded (~200 tokens) - default, best accuracy + progressive 3 meta-tools (~85 tokens) - hierarchical discovery + semantic 2 meta-tools (~69 tokens) - embedding search + +Examples: + # Default - all tools + python -m kubeflow_mcp.agents.ollama + + # Progressive mode (minimal tokens, hierarchical discovery) + python -m kubeflow_mcp.agents.ollama --mode progressive + + # Semantic mode (requires: pip install sentence-transformers) + python -m kubeflow_mcp.agents.ollama --mode semantic + """, + ) + parser.add_argument("--model", default=DEFAULT_MODEL, help="Ollama model") + parser.add_argument("--url", default=DEFAULT_URL, help="Ollama server URL") + parser.add_argument( + "--mode", + choices=[ + "full", + "progressive", + "semantic", + "static", + "mcp", + ], # static/mcp are legacy aliases + default="full", + help="Tool loading mode: full (all tools), progressive (hierarchical), semantic (embedding search)", + ) + args = parser.parse_args() + + run_chat(model=args.model, url=args.url, tool_mode=args.mode) + + +if __name__ == "__main__": + main()
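
A minimal sketch of the progressive discovery flow the meta-tools above implement, runnable from a plain Python shell once the package is installed with the trainer extra. The tool name get_cluster_resources is taken from the common-tool hints in the system prompt and is only illustrative; describe_tools() reports the real schema for whatever tools are registered.

    from kubeflow_mcp.agents.dynamic_tools import describe_tools, execute_tool, list_tools

    # 1. Discover categories, then drill into one of them
    print(list_tools())            # {"categories": [...], "category_tools": {...}, "hint": ...}
    print(list_tools("planning"))  # tools registered under the planning phase

    # 2. Fetch the full parameter schema for the tool to call (max 5 names per request)
    print(describe_tools(["get_cluster_resources"]))

    # 3. Execute it; failures come back as dicts with friendly_error/hint fields
    result = execute_tool("get_cluster_resources", {})
    print(result)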