diff --git a/dream-server/extensions/services/dashboard-api/agent_monitor.py b/dream-server/extensions/services/dashboard-api/agent_monitor.py index e19b5e578..f531e5698 100644 --- a/dream-server/extensions/services/dashboard-api/agent_monitor.py +++ b/dream-server/extensions/services/dashboard-api/agent_monitor.py @@ -26,7 +26,7 @@ def to_dict(self) -> dict: "tokens_per_second": round(self.tokens_per_second, 2), "error_rate_1h": round(self.error_rate_1h, 2), "queue_depth": self.queue_depth, - "last_update": self.last_update.isoformat() + "last_update": self.last_update.isoformat(), } @@ -43,9 +43,11 @@ async def refresh(self): """Query cluster status from smart proxy""" try: proc = await asyncio.create_subprocess_exec( - "curl", "-s", f"http://localhost:{os.environ.get('CLUSTER_PROXY_PORT', '9199')}/status", + "curl", + "-s", + f"http://localhost:{os.environ.get('CLUSTER_PROXY_PORT', '9199')}/status", stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE + stderr=asyncio.subprocess.PIPE, ) stdout, _ = await asyncio.wait_for(proc.communicate(), timeout=5) @@ -63,7 +65,7 @@ def to_dict(self) -> dict: "nodes": self.nodes, "total_gpus": self.total_gpus, "active_gpus": self.active_gpus, - "failover_ready": self.failover_ready + "failover_ready": self.failover_ready, } @@ -76,16 +78,14 @@ def __init__(self, history_minutes: int = 15): def add_sample(self, tokens_per_sec: float): """Add a new throughput sample""" - self.data_points.append({ - "timestamp": datetime.now().isoformat(), - "tokens_per_sec": tokens_per_sec - }) + self.data_points.append( + {"timestamp": datetime.now().isoformat(), "tokens_per_sec": tokens_per_sec} + ) # Prune old data cutoff = datetime.now() - timedelta(minutes=self.history_minutes) self.data_points = [ - p for p in self.data_points - if datetime.fromisoformat(p["timestamp"]) > cutoff + p for p in self.data_points if datetime.fromisoformat(p["timestamp"]) > cutoff ] def get_stats(self) -> dict: @@ -98,7 +98,7 @@ def get_stats(self) -> dict: "current": values[-1] if values else 0, "average": sum(values) / len(values), "peak": max(values) if values else 0, - "history": self.data_points[-30:] # Last 30 points + "history": self.data_points[-30:], # Last 30 points } @@ -129,5 +129,5 @@ def get_full_agent_metrics() -> dict: "timestamp": datetime.now().isoformat(), "agent": agent_metrics.to_dict(), "cluster": cluster_status.to_dict(), - "throughput": throughput.get_stats() + "throughput": throughput.get_stats(), } diff --git a/dream-server/extensions/services/dashboard-api/config.py b/dream-server/extensions/services/dashboard-api/config.py index eabc6c8e0..212d51927 100644 --- a/dream-server/extensions/services/dashboard-api/config.py +++ b/dream-server/extensions/services/dashboard-api/config.py @@ -15,10 +15,7 @@ INSTALL_DIR = os.environ.get("DREAM_INSTALL_DIR", os.path.expanduser("~/dream-server")) DATA_DIR = os.environ.get("DREAM_DATA_DIR", os.path.expanduser("~/.dream-server")) EXTENSIONS_DIR = Path( - os.environ.get( - "DREAM_EXTENSIONS_DIR", - str(Path(INSTALL_DIR) / "extensions" / "services") - ) + os.environ.get("DREAM_EXTENSIONS_DIR", str(Path(INSTALL_DIR) / "extensions" / "services")) ) DEFAULT_SERVICE_HOST = os.environ.get("SERVICE_HOST", "host.docker.internal") @@ -39,7 +36,9 @@ def _read_manifest_file(path: Path) -> dict[str, Any]: return data -def load_extension_manifests(manifest_dir: Path, gpu_backend: str) -> tuple[dict[str, dict[str, Any]], list[dict[str, Any]]]: +def load_extension_manifests( + manifest_dir: Path, gpu_backend: str +) -> tuple[dict[str, dict[str, Any]], list[dict[str, Any]]]: """Load service and feature definitions from extension manifests.""" services: dict[str, dict[str, Any]] = {} features: list[dict[str, Any]] = [] @@ -86,7 +85,11 @@ def load_extension_manifests(manifest_dir: Path, gpu_backend: str) -> tuple[dict ext_port_env = service.get("external_port_env") ext_port_default = service.get("external_port_default", service.get("port", 0)) - external_port = int(os.environ.get(ext_port_env, str(ext_port_default))) if ext_port_env else int(ext_port_default) + external_port = ( + int(os.environ.get(ext_port_env, str(ext_port_default))) + if ext_port_env + else int(ext_port_default) + ) services[service_id] = { "host": host, @@ -103,7 +106,11 @@ def load_extension_manifests(manifest_dir: Path, gpu_backend: str) -> tuple[dict if not isinstance(feature, dict): continue supported = feature.get("gpu_backends", ["amd", "nvidia", "apple"]) - if gpu_backend != "apple" and gpu_backend not in supported and "all" not in supported: + if ( + gpu_backend != "apple" + and gpu_backend not in supported + and "all" not in supported + ): continue if feature.get("id") and feature.get("name"): features.append(feature) @@ -112,7 +119,12 @@ def load_extension_manifests(manifest_dir: Path, gpu_backend: str) -> tuple[dict except Exception as e: logger.warning("Failed loading manifest %s: %s", path, e) - logger.info("Loaded %d extension manifests (%d services, %d features)", loaded, len(services), len(features)) + logger.info( + "Loaded %d extension manifests (%d services, %d features)", + loaded, + len(services), + len(features), + ) return services, features @@ -121,7 +133,9 @@ def load_extension_manifests(manifest_dir: Path, gpu_backend: str) -> tuple[dict MANIFEST_SERVICES, MANIFEST_FEATURES = load_extension_manifests(EXTENSIONS_DIR, GPU_BACKEND) SERVICES = MANIFEST_SERVICES if not SERVICES: - logger.error("No services loaded from manifests in %s โ€” dashboard will have no services", EXTENSIONS_DIR) + logger.error( + "No services loaded from manifests in %s โ€” dashboard will have no services", EXTENSIONS_DIR + ) # --- Features --- @@ -147,12 +161,14 @@ def resolve_workflow_dir() -> Path: WORKFLOW_CATALOG_FILE = WORKFLOW_DIR / "catalog.json" DEFAULT_WORKFLOW_CATALOG = {"workflows": [], "categories": {}} + def _default_n8n_url() -> str: cfg = SERVICES.get("n8n", {}) host = cfg.get("host", "n8n") port = cfg.get("port", 5678) return f"http://{host}:{port}" + N8N_URL = os.environ.get("N8N_URL", _default_n8n_url()) N8N_API_KEY = os.environ.get("N8N_API_KEY", "") @@ -164,18 +180,18 @@ def _default_n8n_url() -> str: "general": { "name": "General Helper", "system_prompt": "You are a friendly and helpful AI assistant. You're knowledgeable, patient, and aim to be genuinely useful. Keep responses clear and conversational.", - "icon": "\U0001f4ac" + "icon": "\U0001f4ac", }, "coding": { "name": "Coding Buddy", "system_prompt": "You are a skilled programmer and technical assistant. You write clean, well-documented code and explain technical concepts clearly. You're precise, thorough, and love solving problems.", - "icon": "\U0001f4bb" + "icon": "\U0001f4bb", }, "creative": { "name": "Creative Writer", "system_prompt": "You are an imaginative creative writer and storyteller. You craft vivid descriptions, engaging narratives, and think outside the box. You're expressive and enjoy wordplay.", - "icon": "\U0001f3a8" - } + "icon": "\U0001f3a8", + }, } # --- Sidebar Icons --- diff --git a/dream-server/extensions/services/dashboard-api/gpu.py b/dream-server/extensions/services/dashboard-api/gpu.py index aa0cc3248..346c2251e 100644 --- a/dream-server/extensions/services/dashboard-api/gpu.py +++ b/dream-server/extensions/services/dashboard-api/gpu.py @@ -34,6 +34,7 @@ def _read_sysfs(path: str) -> Optional[str]: def _find_amd_gpu_sysfs() -> Optional[str]: """Find the sysfs base path for an AMD GPU device.""" import glob + for card_dir in sorted(glob.glob("/sys/class/drm/card*/device")): vendor = _read_sysfs(f"{card_dir}/vendor") if vendor == "0x1002": @@ -44,6 +45,7 @@ def _find_amd_gpu_sysfs() -> Optional[str]: def _find_hwmon_dir(device_path: str) -> Optional[str]: """Find the hwmon directory for an AMD GPU device.""" import glob + hwmon_dirs = sorted(glob.glob(f"{device_path}/hwmon/hwmon*")) return hwmon_dirs[0] if hwmon_dirs else None @@ -119,11 +121,13 @@ def get_gpu_info_nvidia() -> Optional[GPUInfo]: Handles multi-GPU systems by summing VRAM across all GPUs and reporting aggregate utilization and peak temperature. """ - success, output = run_command([ - "nvidia-smi", - "--query-gpu=name,memory.used,memory.total,utilization.gpu,temperature.gpu,power.draw", - "--format=csv,noheader,nounits" - ]) + success, output = run_command( + [ + "nvidia-smi", + "--query-gpu=name,memory.used,memory.total,utilization.gpu,temperature.gpu,power.draw", + "--format=csv,noheader,nounits", + ] + ) if not success or not output: return None @@ -140,19 +144,27 @@ def get_gpu_info_nvidia() -> Optional[GPUInfo]: if len(parts) < 5: continue power_w = None - if len(parts) >= 6 and parts[5] not in ("[N/A]", "[Not Supported]", "N/A", "Not Supported", ""): + if len(parts) >= 6 and parts[5] not in ( + "[N/A]", + "[Not Supported]", + "N/A", + "Not Supported", + "", + ): try: power_w = round(float(parts[5]), 1) except (ValueError, TypeError): pass - gpus.append({ - "name": parts[0], - "mem_used": int(parts[1]), - "mem_total": int(parts[2]), - "util": int(parts[3]), - "temp": int(parts[4]), - "power_w": power_w, - }) + gpus.append( + { + "name": parts[0], + "mem_used": int(parts[1]), + "mem_total": int(parts[2]), + "util": int(parts[3]), + "temp": int(parts[4]), + "power_w": power_w, + } + ) if not gpus: return None @@ -229,6 +241,7 @@ def get_gpu_info_apple() -> Optional[GPUInfo]: success, vm_output = run_command(["vm_stat"]) if success: import re + pages = {} for line in vm_output.splitlines(): match = re.match(r"(.+?):\s+(\d+)", line) diff --git a/dream-server/extensions/services/dashboard-api/helpers.py b/dream-server/extensions/services/dashboard-api/helpers.py index f37329580..1a30fbcc0 100644 --- a/dream-server/extensions/services/dashboard-api/helpers.py +++ b/dream-server/extensions/services/dashboard-api/helpers.py @@ -23,7 +23,7 @@ # connections every poll cycle and prevents file-descriptor exhaustion. _aio_session: Optional[aiohttp.ClientSession] = None -_HEALTH_TIMEOUT = aiohttp.ClientTimeout(total=5) # match Docker's own 5 s timeout +_HEALTH_TIMEOUT = aiohttp.ClientTimeout(total=5) # match Docker's own 5 s timeout async def _get_aio_session() -> aiohttp.ClientSession: @@ -72,6 +72,7 @@ def _get_lifetime_tokens() -> int: # --- LLM Metrics --- + async def get_llama_metrics(model_hint: Optional[str] = None) -> dict: """Get inference metrics from llama-server Prometheus /metrics endpoint. @@ -157,12 +158,13 @@ async def get_llama_context_size(model_hint: Optional[str] = None) -> Optional[i # --- Service Health --- + async def check_service_health(service_id: str, config: dict) -> ServiceStatus: """Check if a service is healthy by hitting its health endpoint.""" if config.get("type") == "host-systemd": return await _check_host_service_health(service_id, config) - host = config.get('host', 'localhost') + host = config.get("host", "localhost") url = f"http://{host}:{config['port']}{config['health']}" status = "unknown" response_time = None @@ -187,9 +189,12 @@ async def check_service_health(service_id: str, config: dict) -> ServiceStatus: status = "down" return ServiceStatus( - id=service_id, name=config["name"], port=config["port"], + id=service_id, + name=config["name"], + port=config["port"], external_port=config.get("external_port", config["port"]), - status=status, response_time_ms=round(response_time, 1) if response_time else None + status=status, + response_time_ms=round(response_time, 1) if response_time else None, ) @@ -212,9 +217,12 @@ async def _check_host_service_health(service_id: str, config: dict) -> ServiceSt logger.debug(f"Host health check failed for {service_id} at {url}: {e}") status = "down" return ServiceStatus( - id=service_id, name=config["name"], port=config["port"], + id=service_id, + name=config["name"], + port=config["port"], external_port=config.get("external_port", config["port"]), - status=status, response_time_ms=round(response_time, 1) if response_time else None, + status=status, + response_time_ms=round(response_time, 1) if response_time else None, ) @@ -231,11 +239,16 @@ async def get_all_services() -> list[ServiceStatus]: for (sid, cfg), result in zip(SERVICES.items(), results): if isinstance(result, BaseException): logger.warning("Health check for %s raised %s: %s", sid, type(result).__name__, result) - statuses.append(ServiceStatus( - id=sid, name=cfg["name"], port=cfg["port"], - external_port=cfg.get("external_port", cfg["port"]), - status="down", response_time_ms=None, - )) + statuses.append( + ServiceStatus( + id=sid, + name=cfg["name"], + port=cfg["port"], + external_port=cfg.get("external_port", cfg["port"]), + status="down", + response_time_ms=None, + ) + ) else: statuses.append(result) return statuses @@ -243,11 +256,17 @@ async def get_all_services() -> list[ServiceStatus]: # --- System Metrics --- + def get_disk_usage() -> DiskUsage: """Get disk usage for the Dream Server install directory.""" path = INSTALL_DIR if os.path.exists(INSTALL_DIR) else os.path.expanduser("~") total, used, free = shutil.disk_usage(path) - return DiskUsage(path=path, used_gb=round(used / (1024**3), 2), total_gb=round(total / (1024**3), 2), percent=round(used / total * 100, 1)) + return DiskUsage( + path=path, + used_gb=round(used / (1024**3), 2), + total_gb=round(total / (1024**3), 2), + percent=round(used / total * 100, 1), + ) def get_model_info() -> Optional[ModelInfo]: @@ -258,18 +277,31 @@ def get_model_info() -> Optional[ModelInfo]: with open(env_path) as f: for line in f: if line.startswith("LLM_MODEL="): - model_name = line.split("=", 1)[1].strip().strip('"\'') + model_name = line.split("=", 1)[1].strip().strip("\"'") size_gb, context, quant = 15.0, 32768, None import re as _re + name_lower = model_name.lower() - if _re.search(r'\b7b\b', name_lower): size_gb = 4.0 - elif _re.search(r'\b14b\b', name_lower): size_gb = 8.0 - elif _re.search(r'\b32b\b', name_lower): size_gb = 16.0 - elif _re.search(r'\b70b\b', name_lower): size_gb = 35.0 - if "awq" in name_lower: quant = "AWQ" - elif "gptq" in name_lower: quant = "GPTQ" - elif "gguf" in name_lower: quant = "GGUF" - return ModelInfo(name=model_name, size_gb=size_gb, context_length=context, quantization=quant) + if _re.search(r"\b7b\b", name_lower): + size_gb = 4.0 + elif _re.search(r"\b14b\b", name_lower): + size_gb = 8.0 + elif _re.search(r"\b32b\b", name_lower): + size_gb = 16.0 + elif _re.search(r"\b70b\b", name_lower): + size_gb = 35.0 + if "awq" in name_lower: + quant = "AWQ" + elif "gptq" in name_lower: + quant = "GPTQ" + elif "gguf" in name_lower: + quant = "GGUF" + return ModelInfo( + name=model_name, + size_gb=size_gb, + context_length=context, + quantization=quant, + ) except OSError as e: logger.warning("Failed to read .env for model info: %s", e) return None @@ -295,7 +327,11 @@ def get_bootstrap_status() -> BootstrapStatus: eta_seconds = None if eta_str and eta_str.strip() and eta_str.strip() != "calculating...": try: - parts = [p.strip() for p in eta_str.replace("m", "").replace("s", "").split() if p.strip()] + parts = [ + p.strip() + for p in eta_str.replace("m", "").replace("s", "").split() + if p.strip() + ] if len(parts) == 2: eta_seconds = int(parts[0]) * 60 + int(parts[1]) elif len(parts) == 1: @@ -316,11 +352,13 @@ def get_bootstrap_status() -> BootstrapStatus: pass return BootstrapStatus( - active=True, model_name=data.get("model"), percent=percent, + active=True, + model_name=data.get("model"), + percent=percent, downloaded_gb=bytes_downloaded / (1024**3) if bytes_downloaded else None, total_gb=bytes_total / (1024**3) if bytes_total else None, speed_mbps=speed_bps / (1024**2) if speed_bps else None, - eta_seconds=eta_seconds + eta_seconds=eta_seconds, ) except (json.JSONDecodeError, OSError, KeyError) as e: logger.warning("Failed to parse bootstrap status: %s", e) @@ -336,19 +374,25 @@ def get_uptime() -> int: return int(float(f.read().split()[0])) elif _system == "Darwin": import subprocess + result = subprocess.run( ["sysctl", "-n", "kern.boottime"], - capture_output=True, text=True, timeout=5, + capture_output=True, + text=True, + timeout=5, ) if result.returncode == 0: # Output: "{ sec = 1234567890, usec = 0 } ..." import re + match = re.search(r"sec\s*=\s*(\d+)", result.stdout) if match: import time as _time + return int(_time.time()) - int(match.group(1)) elif _system == "Windows": import ctypes + return ctypes.windll.kernel32.GetTickCount64() // 1000 except Exception as e: logger.debug("get_uptime failed on %s: %s", _system, e) @@ -377,6 +421,7 @@ def _get_cpu_metrics_linux() -> dict: try: import glob + for tz in sorted(glob.glob("/sys/class/thermal/thermal_zone*/type")): with open(tz) as f: zone_type = f.read().strip() @@ -402,12 +447,16 @@ def _get_cpu_metrics_darwin() -> dict: result = {"percent": 0, "temp_c": None} try: import subprocess + out = subprocess.run( ["top", "-l", "1", "-n", "0", "-stats", "cpu"], - capture_output=True, text=True, timeout=5, + capture_output=True, + text=True, + timeout=5, ) if out.returncode == 0: import re + match = re.search(r"CPU usage:\s+([\d.]+)%\s+user.*?([\d.]+)%\s+sys", out.stdout) if match: result["percent"] = round(float(match.group(1)) + float(match.group(2)), 1) @@ -453,20 +502,27 @@ def _get_ram_metrics_sysctl() -> dict: result = {"used_gb": 0, "total_gb": 0, "percent": 0} try: import subprocess + out = subprocess.run( ["sysctl", "-n", "hw.memsize"], - capture_output=True, text=True, timeout=5, + capture_output=True, + text=True, + timeout=5, ) if out.returncode == 0: total_bytes = int(out.stdout.strip()) - total_gb = total_bytes / (1024 ** 3) + total_gb = total_bytes / (1024**3) result["total_gb"] = round(total_gb, 1) # vm_stat for used memory vm = subprocess.run( - ["vm_stat"], capture_output=True, text=True, timeout=5, + ["vm_stat"], + capture_output=True, + text=True, + timeout=5, ) if vm.returncode == 0: import re + pages = {} for line in vm.stdout.splitlines(): match = re.match(r"(.+?):\s+(\d+)", line) @@ -480,7 +536,7 @@ def _get_ram_metrics_sysctl() -> dict: wired = pages.get("Pages wired down", 0) compressed = pages.get("Pages occupied by compressor", 0) used_bytes = (active + wired + compressed) * page_size - result["used_gb"] = round(used_bytes / (1024 ** 3), 1) + result["used_gb"] = round(used_bytes / (1024**3), 1) if total_bytes > 0: result["percent"] = round(used_bytes / total_bytes * 100, 1) except Exception as e: diff --git a/dream-server/extensions/services/dashboard-api/main.py b/dream-server/extensions/services/dashboard-api/main.py index 4588eedf7..1e3973dac 100644 --- a/dream-server/extensions/services/dashboard-api/main.py +++ b/dream-server/extensions/services/dashboard-api/main.py @@ -29,16 +29,27 @@ # --- Local modules --- from config import SERVICES, DATA_DIR, SIDEBAR_ICONS from models import ( - GPUInfo, ServiceStatus, DiskUsage, ModelInfo, BootstrapStatus, - FullStatus, PortCheckRequest, + GPUInfo, + ServiceStatus, + DiskUsage, + ModelInfo, + BootstrapStatus, + FullStatus, + PortCheckRequest, ) from security import verify_api_key from gpu import get_gpu_info from helpers import ( get_all_services, - get_disk_usage, get_model_info, get_bootstrap_status, - get_uptime, get_cpu_metrics, get_ram_metrics, - get_llama_metrics, get_loaded_model, get_llama_context_size, + get_disk_usage, + get_model_info, + get_bootstrap_status, + get_uptime, + get_cpu_metrics, + get_ram_metrics, + get_llama_metrics, + get_loaded_model, + get_llama_context_size, ) from agent_monitor import collect_metrics @@ -52,18 +63,21 @@ app = FastAPI( title="Dream Server Dashboard API", version="2.0.0", - description="System status API for Dream Server Dashboard" + description="System status API for Dream Server Dashboard", ) # --- CORS --- + def get_allowed_origins(): env_origins = os.environ.get("DASHBOARD_ALLOWED_ORIGINS", "") if env_origins: return env_origins.split(",") origins = [ - "http://localhost:3001", "http://127.0.0.1:3001", - "http://localhost:3000", "http://127.0.0.1:3000", + "http://localhost:3001", + "http://127.0.0.1:3001", + "http://localhost:3000", + "http://127.0.0.1:3000", ] try: hostname = socket.gethostname() @@ -76,6 +90,7 @@ def get_allowed_origins(): pass return origins + app.add_middleware( CORSMiddleware, allow_origins=get_allowed_origins(), @@ -98,6 +113,7 @@ def get_allowed_origins(): # Core Endpoints (health, status, preflight, services) # ================================================================ + @app.get("/health") async def health(): """API health check.""" @@ -106,16 +122,22 @@ async def health(): # --- Preflight --- + @app.get("/api/preflight/docker", dependencies=[Depends(verify_api_key)]) async def preflight_docker(): """Check if Docker is available.""" import subprocess + if os.path.exists("/.dockerenv"): return {"available": True, "version": "available (host)"} try: result = subprocess.run(["docker", "--version"], capture_output=True, text=True, timeout=5) if result.returncode == 0: - version = result.stdout.strip().split()[2].rstrip(",") if len(result.stdout.strip().split()) > 2 else "unknown" + version = ( + result.stdout.strip().split()[2].rstrip(",") + if len(result.stdout.strip().split()) > 2 + else "unknown" + ) return {"available": True, "version": version} return {"available": False, "error": "Docker command failed"} except FileNotFoundError: @@ -133,15 +155,27 @@ async def preflight_gpu(): gpu_info = get_gpu_info() if gpu_info: vram_gb = round(gpu_info.memory_total_mb / 1024, 1) - result = {"available": True, "name": gpu_info.name, "vram": vram_gb, "backend": gpu_info.gpu_backend, "memory_type": gpu_info.memory_type} + result = { + "available": True, + "name": gpu_info.name, + "vram": vram_gb, + "backend": gpu_info.gpu_backend, + "memory_type": gpu_info.memory_type, + } if gpu_info.memory_type == "unified": result["memory_label"] = f"{vram_gb} GB Unified" return result gpu_backend = os.environ.get("GPU_BACKEND", "").lower() if gpu_backend == "amd": - return {"available": False, "error": "AMD GPU not detected via sysfs. Check /dev/kfd and /dev/dri access."} - return {"available": False, "error": "No GPU detected. Ensure NVIDIA drivers or AMD amdgpu driver is loaded."} + return { + "available": False, + "error": "AMD GPU not detected via sysfs. Check /dev/kfd and /dev/dri access.", + } + return { + "available": False, + "error": "No GPU detected. Ensure NVIDIA drivers or AMD amdgpu driver is loaded.", + } @app.get("/api/preflight/required-ports") @@ -171,7 +205,9 @@ async def preflight_ports(request: PortCheckRequest): sock.settimeout(1) sock.bind(("0.0.0.0", port)) except socket.error: - conflicts.append({"port": port, "service": port_services.get(port, "Unknown"), "in_use": True}) + conflicts.append( + {"port": port, "service": port_services.get(port, "Unknown"), "in_use": True} + ) return {"conflicts": conflicts, "available": len(conflicts) == 0} @@ -181,7 +217,12 @@ async def preflight_disk(): try: check_path = DATA_DIR if os.path.exists(DATA_DIR) else Path.home() usage = shutil.disk_usage(check_path) - return {"free": usage.free, "total": usage.total, "used": usage.used, "path": str(check_path)} + return { + "free": usage.free, + "total": usage.total, + "used": usage.used, + "path": str(check_path), + } except Exception: logger.exception("Disk preflight check failed") return {"error": "Disk check failed", "free": 0, "total": 0, "used": 0, "path": ""} @@ -189,6 +230,7 @@ async def preflight_disk(): # --- Core Data --- + @app.get("/gpu", response_model=Optional[GPUInfo]) async def gpu(api_key: str = Depends(verify_api_key)): """Get GPU metrics.""" @@ -225,9 +267,12 @@ async def status(api_key: str = Depends(verify_api_key)): service_statuses = await get_all_services() return FullStatus( timestamp=datetime.now(timezone.utc).isoformat(), - gpu=get_gpu_info(), services=service_statuses, - disk=get_disk_usage(), model=get_model_info(), - bootstrap=get_bootstrap_status(), uptime_seconds=get_uptime() + gpu=get_gpu_info(), + services=service_statuses, + disk=get_disk_usage(), + model=get_model_info(), + bootstrap=get_bootstrap_status(), + uptime_seconds=get_uptime(), ) @@ -244,13 +289,21 @@ async def api_status(api_key: str = Depends(verify_api_key)): except Exception: logger.exception("/api/status handler failed โ€” returning safe fallback") return { - "gpu": None, "services": [], "model": None, - "bootstrap": None, "uptime": 0, - "version": app.version, "tier": "Unknown", + "gpu": None, + "services": [], + "model": None, + "bootstrap": None, + "uptime": 0, + "version": app.version, + "tier": "Unknown", "cpu": {"percent": 0, "temp_c": None}, "ram": {"used_gb": 0, "total_gb": 0, "percent": 0}, - "inference": {"tokensPerSecond": 0, "lifetimeTokens": 0, - "loadedModel": None, "contextSize": None}, + "inference": { + "tokensPerSecond": 0, + "lifetimeTokens": 0, + "loadedModel": None, + "contextSize": None, + }, } @@ -288,20 +341,29 @@ async def _build_api_status() -> dict: gpu_data["powerDraw"] = gpu_info.power_w gpu_data["memoryLabel"] = "VRAM Partition" if gpu_info.memory_type == "unified" else "VRAM" - services_data = [{"name": s.name, "status": s.status, "port": s.external_port, "uptime": None} for s in service_statuses] + services_data = [ + {"name": s.name, "status": s.status, "port": s.external_port, "uptime": None} + for s in service_statuses + ] model_data = None if model_info: - model_data = {"name": model_info.name, "tokensPerSecond": None, "contextLength": model_info.context_length} + model_data = { + "name": model_info.name, + "tokensPerSecond": None, + "contextLength": model_info.context_length, + } bootstrap_data = None if bootstrap_info.active: bootstrap_data = { - "active": True, "model": bootstrap_info.model_name or "Full Model", + "active": True, + "model": bootstrap_info.model_name or "Full Model", "percent": bootstrap_info.percent or 0, "bytesDownloaded": int((bootstrap_info.downloaded_gb or 0) * 1024**3), "bytesTotal": int((bootstrap_info.total_gb or 0) * 1024**3), - "eta": bootstrap_info.eta_seconds, "speedMbps": bootstrap_info.speed_mbps + "eta": bootstrap_info.eta_seconds, + "speedMbps": bootstrap_info.speed_mbps, } tier = "Unknown" @@ -309,17 +371,27 @@ async def _build_api_status() -> dict: vram_gb = gpu_info.memory_total_mb / 1024 if gpu_info.memory_type == "unified" and gpu_info.gpu_backend == "amd": tier = "Strix Halo 90+" if vram_gb >= 90 else "Strix Halo Compact" - elif vram_gb >= 80: tier = "Professional" - elif vram_gb >= 24: tier = "Prosumer" - elif vram_gb >= 16: tier = "Standard" - elif vram_gb >= 8: tier = "Entry" - else: tier = "Minimal" + elif vram_gb >= 80: + tier = "Professional" + elif vram_gb >= 24: + tier = "Prosumer" + elif vram_gb >= 16: + tier = "Standard" + elif vram_gb >= 8: + tier = "Entry" + else: + tier = "Minimal" return { - "gpu": gpu_data, "services": services_data, "model": model_data, - "bootstrap": bootstrap_data, "uptime": get_uptime(), - "version": app.version, "tier": tier, - "cpu": get_cpu_metrics(), "ram": get_ram_metrics(), + "gpu": gpu_data, + "services": services_data, + "model": model_data, + "bootstrap": bootstrap_data, + "uptime": get_uptime(), + "version": app.version, + "tier": tier, + "cpu": get_cpu_metrics(), + "ram": get_ram_metrics(), "inference": { "tokensPerSecond": llama_metrics_data.get("tokens_per_second", 0), "lifetimeTokens": llama_metrics_data.get("lifetime_tokens", 0), @@ -331,6 +403,7 @@ async def _build_api_status() -> dict: # --- Settings --- + @app.get("/api/service-tokens", dependencies=[Depends(verify_api_key)]) async def service_tokens(): """Return connection tokens for services that need browser-side auth.""" @@ -363,11 +436,15 @@ async def get_external_links(api_key: str = Depends(verify_api_key)): ext_port = cfg.get("external_port", cfg.get("port", 0)) if not ext_port or sid == "dashboard-api": continue - links.append({ - "id": sid, "label": cfg.get("name", sid), "port": ext_port, - "icon": SIDEBAR_ICONS.get(sid, "ExternalLink"), - "healthNeedles": [sid, cfg.get("name", sid).lower()], - }) + links.append( + { + "id": sid, + "label": cfg.get("name", sid), + "port": ext_port, + "icon": SIDEBAR_ICONS.get(sid, "ExternalLink"), + "healthNeedles": [sid, cfg.get("name", sid).lower()], + } + ) return links @@ -396,15 +473,34 @@ def dir_size_gb(path: Path) -> float: total_data_gb = dir_size_gb(data_dir) return { - "models": {"formatted": f"{models_gb:.1f} GB", "gb": models_gb, "percent": round(models_gb / disk_info.total_gb * 100, 1) if disk_info.total_gb else 0}, - "vector_db": {"formatted": f"{vector_gb:.1f} GB", "gb": vector_gb, "percent": round(vector_gb / disk_info.total_gb * 100, 1) if disk_info.total_gb else 0}, - "total_data": {"formatted": f"{total_data_gb:.1f} GB", "gb": total_data_gb, "percent": round(total_data_gb / disk_info.total_gb * 100, 1) if disk_info.total_gb else 0}, - "disk": {"used_gb": disk_info.used_gb, "total_gb": disk_info.total_gb, "percent": disk_info.percent} + "models": { + "formatted": f"{models_gb:.1f} GB", + "gb": models_gb, + "percent": round(models_gb / disk_info.total_gb * 100, 1) if disk_info.total_gb else 0, + }, + "vector_db": { + "formatted": f"{vector_gb:.1f} GB", + "gb": vector_gb, + "percent": round(vector_gb / disk_info.total_gb * 100, 1) if disk_info.total_gb else 0, + }, + "total_data": { + "formatted": f"{total_data_gb:.1f} GB", + "gb": total_data_gb, + "percent": ( + round(total_data_gb / disk_info.total_gb * 100, 1) if disk_info.total_gb else 0 + ), + }, + "disk": { + "used_gb": disk_info.used_gb, + "total_gb": disk_info.total_gb, + "percent": disk_info.percent, + }, } # --- Startup --- + @app.on_event("startup") async def startup_event(): """Start background metrics collection.""" @@ -413,4 +509,5 @@ async def startup_event(): if __name__ == "__main__": import uvicorn + uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("DASHBOARD_API_PORT", "3002"))) diff --git a/dream-server/extensions/services/dashboard-api/security.py b/dream-server/extensions/services/dashboard-api/security.py index dd1599f7e..805f4685b 100644 --- a/dream-server/extensions/services/dashboard-api/security.py +++ b/dream-server/extensions/services/dashboard-api/security.py @@ -19,7 +19,8 @@ key_file.chmod(0o600) logger.warning( "DASHBOARD_API_KEY not set. Generated temporary key and wrote to %s (mode 0600). " - "Set DASHBOARD_API_KEY in your .env file for production.", key_file + "Set DASHBOARD_API_KEY in your .env file for production.", + key_file, ) security_scheme = HTTPBearer(auto_error=False) @@ -31,7 +32,7 @@ async def verify_api_key(credentials: HTTPAuthorizationCredentials = Security(se raise HTTPException( status_code=401, detail="Authentication required. Provide Bearer token in Authorization header.", - headers={"WWW-Authenticate": "Bearer"} + headers={"WWW-Authenticate": "Bearer"}, ) if not secrets.compare_digest(credentials.credentials, DASHBOARD_API_KEY): raise HTTPException(status_code=403, detail="Invalid API key.") diff --git a/dream-server/extensions/services/privacy-shield/pii_scrubber.py b/dream-server/extensions/services/privacy-shield/pii_scrubber.py index dbda92427..ca65a3cac 100644 --- a/dream-server/extensions/services/privacy-shield/pii_scrubber.py +++ b/dream-server/extensions/services/privacy-shield/pii_scrubber.py @@ -28,22 +28,25 @@ class PIIDetector: # Regex patterns for PII detection PATTERNS = { - 'email': re.compile(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b'), - 'phone': re.compile(r'\b(?:\+?1[-.\s]?)?\(?[0-9]{3}\)?[-.\s]?[0-9]{3}[-.\s]?[0-9]{4}\b'), - 'ssn': re.compile(r'\b\d{3}[-.\s]?\d{2}[-.\s]?\d{4}\b'), - 'ip_address': re.compile( - r'\b(?:\d{1,3}\.){3}\d{1,3}\b' # IPv4 - r'|' - r'(?:[0-9a-fA-F]{1,4}:){7}[0-9a-fA-F]{1,4}' # Full IPv6 - r'|' - r'(?:[0-9a-fA-F]{1,4}:){1,7}:' # Trailing :: - r'|' - r'::(?:[0-9a-fA-F]{1,4}:){0,6}[0-9a-fA-F]{1,4}' # Leading :: - r'|' - r'(?:[0-9a-fA-F]{1,4}:){1,6}:[0-9a-fA-F]{1,4}' # Middle :: + "email": re.compile(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b"), + "phone": re.compile(r"\b(?:\+?1[-.\s]?)?\(?[0-9]{3}\)?[-.\s]?[0-9]{3}[-.\s]?[0-9]{4}\b"), + "ssn": re.compile(r"\b\d{3}[-.\s]?\d{2}[-.\s]?\d{4}\b"), + "ip_address": re.compile( + r"\b(?:\d{1,3}\.){3}\d{1,3}\b" # IPv4 + r"|" + r"(?:[0-9a-fA-F]{1,4}:){7}[0-9a-fA-F]{1,4}" # Full IPv6 + r"|" + r"(?:[0-9a-fA-F]{1,4}:){1,7}:" # Trailing :: + r"|" + r"::(?:[0-9a-fA-F]{1,4}:){0,6}[0-9a-fA-F]{1,4}" # Leading :: + r"|" + r"(?:[0-9a-fA-F]{1,4}:){1,6}:[0-9a-fA-F]{1,4}" # Middle :: ), - 'api_key': re.compile(r'\b(?:api[_-]?key|apikey|token)[\s]*[=:]\s*["\']?[a-zA-Z0-9_\-]{16,}["\']?\b', re.IGNORECASE), - 'credit_card': re.compile(r'\b(?:\d{4}[-\s]?){3}\d{4}\b'), + "api_key": re.compile( + r'\b(?:api[_-]?key|apikey|token)[\s]*[=:]\s*["\']?[a-zA-Z0-9_\-]{16,}["\']?\b', + re.IGNORECASE, + ), + "credit_card": re.compile(r"\b(?:\d{4}[-\s]?){3}\d{4}\b"), } def _generate_token(self, pii_type: str, original: str) -> str: @@ -97,10 +100,8 @@ def restore(self, text: str) -> str: def get_stats(self) -> Dict: """Return statistics about detected PII.""" return { - 'unique_pii_count': len(self.pii_map), - 'pii_types': list(set( - token.split('_')[1] for token in self.pii_map.keys() - )) + "unique_pii_count": len(self.pii_map), + "pii_types": list(set(token.split("_")[1] for token in self.pii_map.keys())), } @@ -123,9 +124,9 @@ def process_request(self, prompt: str) -> Tuple[str, Dict]: stats = self.detector.get_stats() metadata = { - 'scrubbed': scrubbed != prompt, - 'pii_count': stats['unique_pii_count'], - 'pii_types': stats['pii_types'] + "scrubbed": scrubbed != prompt, + "pii_count": stats["unique_pii_count"], + "pii_types": stats["pii_types"], } return scrubbed, metadata diff --git a/dream-server/extensions/services/privacy-shield/proxy.py b/dream-server/extensions/services/privacy-shield/proxy.py index a4e3d8d8a..4c12cd78c 100644 --- a/dream-server/extensions/services/privacy-shield/proxy.py +++ b/dream-server/extensions/services/privacy-shield/proxy.py @@ -23,11 +23,14 @@ SHIELD_API_KEY = os.environ.get("SHIELD_API_KEY") if not SHIELD_API_KEY: SHIELD_API_KEY = secrets.token_urlsafe(32) - logging.warning("SHIELD_API_KEY not set. Generated temporary key (not logging for security). " - "Set SHIELD_API_KEY in .env for production.") + logging.warning( + "SHIELD_API_KEY not set. Generated temporary key (not logging for security). " + "Set SHIELD_API_KEY in .env for production." + ) security_scheme = HTTPBearer() + async def verify_api_key(credentials: HTTPAuthorizationCredentials = Security(security_scheme)): """Verify API key for protected endpoints.""" if not secrets.compare_digest(credentials.credentials, SHIELD_API_KEY): @@ -48,7 +51,7 @@ async def verify_api_key(credentials: HTTPAuthorizationCredentials = Security(se # Connection pool for better performance http_client = httpx.AsyncClient( limits=httpx.Limits(max_keepalive_connections=100, max_connections=200), - timeout=httpx.Timeout(60.0, connect=5.0) + timeout=httpx.Timeout(60.0, connect=5.0), ) # Session store (TTL cache with auto-eviction to prevent unbounded growth) @@ -103,22 +106,19 @@ async def health(): "version": "0.2.0", "target_api": TARGET_API_BASE, "cache_enabled": CACHE_ENABLED, - "active_sessions": len(sessions) + "active_sessions": len(sessions), } @app.get("/stats") async def stats(): """Session statistics.""" - total_pii = sum( - s.detector.get_stats()['unique_pii_count'] - for s in sessions.values() - ) + total_pii = sum(s.detector.get_stats()["unique_pii_count"] for s in sessions.values()) return { "active_sessions": len(sessions), "total_pii_scrubbed": total_pii, "cache_enabled": CACHE_ENABLED, - "cache_size": CACHE_SIZE + "cache_size": CACHE_SIZE, } @@ -133,14 +133,16 @@ async def proxy(request: Request, path: str): # Read and process request body body = await request.body() - body_str = body.decode('utf-8') if body else "" + body_str = body.decode("utf-8") if body else "" # Scrub PII from request scrubbed_body, metadata = shield.process_request(body_str) # Forward to target API target_url = f"{TARGET_API_BASE}/{path}" - headers = {k: v for k, v in request.headers.items() if k.lower() not in ('host', 'content-length')} + headers = { + k: v for k, v in request.headers.items() if k.lower() not in ("host", "content-length") + } # Set host header for target host = TARGET_API_BASE.split("//")[-1].split("/")[0] @@ -153,18 +155,13 @@ async def proxy(request: Request, path: str): try: if request.method == "POST": resp = await http_client.post( - target_url, - headers=headers, - content=scrubbed_body.encode('utf-8') + target_url, headers=headers, content=scrubbed_body.encode("utf-8") ) else: - resp = await http_client.get( - target_url, - headers=headers - ) + resp = await http_client.get(target_url, headers=headers) # Read response - response_body = resp.content.decode('utf-8') + response_body = resp.content.decode("utf-8") # Restore PII in response restored_body = shield.process_response(response_body) @@ -175,35 +172,34 @@ async def proxy(request: Request, path: str): # Add privacy headers response_headers = { "X-Privacy-Shield": "active", - "X-PII-Scrubbed": str(metadata.get('pii_count', 0)), + "X-PII-Scrubbed": str(metadata.get("pii_count", 0)), "X-Processing-Time-Ms": f"{overhead_ms:.2f}", - "Content-Type": resp.headers.get("Content-Type", "application/json") + "Content-Type": resp.headers.get("Content-Type", "application/json"), } return Response( - content=restored_body, - status_code=resp.status_code, - headers=response_headers + content=restored_body, status_code=resp.status_code, headers=response_headers ) except httpx.TimeoutException: return JSONResponse( - status_code=504, - content={"error": "Gateway timeout", "shield": "active"} + status_code=504, content={"error": "Gateway timeout", "shield": "active"} ) except Exception as e: import logging import re + logger = logging.getLogger("privacy-shield") # Sanitize error message to prevent PII token leakage error_str = str(e) # Strip PII tokens and their original values - error_str = re.sub(r'', '[REDACTED]', error_str) - error_str = re.sub(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b', '[EMAIL]', error_str) + error_str = re.sub(r"", "[REDACTED]", error_str) + error_str = re.sub( + r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b", "[EMAIL]", error_str + ) logger.error(f"Privacy shield error: {error_str}") return JSONResponse( - status_code=500, - content={"error": "Privacy check failed", "shield": "active"} + status_code=500, content={"error": "Privacy check failed", "shield": "active"} ) @@ -216,6 +212,8 @@ async def shutdown(): if __name__ == "__main__": print(f"๐Ÿ”’ API Privacy Shield starting on port {PORT}") print(f"๐Ÿ“ก Proxying to: {TARGET_API_BASE}") - print(f"๐Ÿ’พ Cache: {'enabled' if CACHE_ENABLED else 'disabled'} (size={CACHE_SIZE}, ttl={CACHE_TTL}s)") + print( + f"๐Ÿ’พ Cache: {'enabled' if CACHE_ENABLED else 'disabled'} (size={CACHE_SIZE}, ttl={CACHE_TTL}s)" + ) print(f"๐Ÿงช Test with: curl http://localhost:{PORT}/health") uvicorn.run(app, host="0.0.0.0", port=PORT) diff --git a/dream-server/extensions/services/token-spy/db.py b/dream-server/extensions/services/token-spy/db.py index d04b56bc7..ec6987537 100644 --- a/dream-server/extensions/services/token-spy/db.py +++ b/dream-server/extensions/services/token-spy/db.py @@ -83,18 +83,34 @@ def init_db(): def log_usage(entry: dict): conn = _get_conn() cols = [ - "agent", "model", - "request_body_bytes", "message_count", "user_message_count", - "assistant_message_count", "tool_count", + "agent", + "model", + "request_body_bytes", + "message_count", + "user_message_count", + "assistant_message_count", + "tool_count", "system_prompt_total_chars", - "workspace_agents_chars", "workspace_soul_chars", "workspace_tools_chars", - "workspace_identity_chars", "workspace_user_chars", "workspace_heartbeat_chars", + "workspace_agents_chars", + "workspace_soul_chars", + "workspace_tools_chars", + "workspace_identity_chars", + "workspace_user_chars", + "workspace_heartbeat_chars", "workspace_bootstrap_chars", - "skill_injection_chars", "base_prompt_chars", + "skill_injection_chars", + "base_prompt_chars", "conversation_history_chars", - "input_tokens", "output_tokens", "cache_read_tokens", "cache_write_tokens", - "estimated_cost_usd", "duration_ms", "stop_reason", - "filter_chars_saved", "filter_tokens_saved", "filter_tools_removed", + "input_tokens", + "output_tokens", + "cache_read_tokens", + "cache_write_tokens", + "estimated_cost_usd", + "duration_ms", + "stop_reason", + "filter_chars_saved", + "filter_tokens_saved", + "filter_tools_removed", ] values = [entry.get(c) for c in cols] placeholders = ", ".join(["?"] * len(cols)) @@ -120,7 +136,8 @@ def query_usage(agent: str | None = None, hours: int = 24, limit: int = 200) -> def query_summary(hours: int = 24) -> list[dict]: conn = _get_conn() conn.row_factory = sqlite3.Row - rows = conn.execute(""" + rows = conn.execute( + """ SELECT agent, COUNT(*) as turns, @@ -138,7 +155,9 @@ def query_summary(hours: int = 24) -> list[dict]: FROM usage WHERE timestamp > datetime('now', ?) GROUP BY agent - """, [f"-{hours} hours"]).fetchall() + """, + [f"-{hours} hours"], + ).fetchall() return [dict(r) for r in rows] @@ -153,13 +172,16 @@ def query_session_status(agent: str, char_limit: int = 200_000) -> dict: conn.row_factory = sqlite3.Row # Get all recent turns for this agent, ordered chronologically - rows = conn.execute(""" + rows = conn.execute( + """ SELECT conversation_history_chars, cache_read_tokens, cache_write_tokens, estimated_cost_usd, timestamp FROM usage WHERE agent = ? AND timestamp > datetime('now', '-24 hours') ORDER BY timestamp ASC - """, [agent]).fetchall() + """, + [agent], + ).fetchall() if not rows: return { @@ -192,7 +214,9 @@ def query_session_status(agent: str, char_limit: int = 200_000) -> dict: # Last 5 turns for rolling averages last_5 = session_rows[-5:] avg_cost_5 = sum(r["estimated_cost_usd"] or 0 for r in last_5) / max(len(last_5), 1) - total_cache_5 = sum((r["cache_read_tokens"] or 0) + (r["cache_write_tokens"] or 0) for r in last_5) + total_cache_5 = sum( + (r["cache_read_tokens"] or 0) + (r["cache_write_tokens"] or 0) for r in last_5 + ) total_write_5 = sum(r["cache_write_tokens"] or 0 for r in last_5) cache_write_pct = total_write_5 / max(total_cache_5, 1) @@ -239,7 +263,7 @@ def query_recent_events(limit: int = 100, after_id: str = None): ORDER BY timestamp DESC LIMIT ? """, - (after_id, limit) + (after_id, limit), ).fetchall() else: rows = conn.execute( @@ -253,7 +277,7 @@ def query_recent_events(limit: int = 100, after_id: str = None): ORDER BY timestamp DESC LIMIT ? """, - (limit,) + (limit,), ).fetchall() return [dict(r) for r in rows] diff --git a/dream-server/extensions/services/token-spy/db_postgres.py b/dream-server/extensions/services/token-spy/db_postgres.py index 9df5f306a..72dcfbca1 100644 --- a/dream-server/extensions/services/token-spy/db_postgres.py +++ b/dream-server/extensions/services/token-spy/db_postgres.py @@ -71,7 +71,7 @@ def init_db(): # Check if tenant exists cur.execute( "SELECT id FROM tenants WHERE slug = %s AND deleted_at IS NULL", - (SINGLE_TENANT_SLUG,) + (SINGLE_TENANT_SLUG,), ) row = cur.fetchone() @@ -86,7 +86,7 @@ def init_db(): VALUES (%s, %s, 'free') RETURNING id """, - (SINGLE_TENANT_SLUG.replace("-", " ").title(), SINGLE_TENANT_SLUG) + (SINGLE_TENANT_SLUG.replace("-", " ").title(), SINGLE_TENANT_SLUG), ) _tenant_id = cur.fetchone()["id"] logger.info(f"Created tenant: {SINGLE_TENANT_SLUG} ({_tenant_id})") @@ -111,8 +111,7 @@ def _get_or_create_agent(agent_name: str) -> UUID: slug = agent_name.lower().replace(" ", "-") cur.execute( - "SELECT id FROM agents WHERE tenant_id = %s AND slug = %s", - (_tenant_id, slug) + "SELECT id FROM agents WHERE tenant_id = %s AND slug = %s", (_tenant_id, slug) ) row = cur.fetchone() @@ -125,7 +124,7 @@ def _get_or_create_agent(agent_name: str) -> UUID: VALUES (%s, %s, %s) RETURNING id """, - (_tenant_id, agent_name, slug) + (_tenant_id, agent_name, slug), ) agent_id = cur.fetchone()["id"] logger.info(f"Created agent: {agent_name} ({agent_id})") @@ -176,7 +175,9 @@ def log_usage(entry: dict): ) """, ( - uuid4(), _tenant_id, agent_id, + uuid4(), + _tenant_id, + agent_id, _detect_provider(entry.get("model", "")), entry.get("model", "unknown"), entry.get("request_body_bytes", 0), @@ -203,7 +204,7 @@ def log_usage(entry: dict): entry.get("estimated_cost_usd", 0), entry.get("duration_ms", 0), entry.get("stop_reason"), - ) + ), ) conn.commit() except Exception: @@ -315,7 +316,7 @@ def query_summary(hours: int = 24) -> list[dict]: AND r.timestamp > NOW() - INTERVAL '%s hours' GROUP BY a.name """, - (_tenant_id, hours) + (_tenant_id, hours), ) rows = cur.fetchall() @@ -362,7 +363,7 @@ def query_session_status(agent: str, char_limit: int = 200_000) -> dict: AND r.timestamp > NOW() - INTERVAL '24 hours' ORDER BY r.timestamp ASC """, - (_tenant_id, agent) + (_tenant_id, agent), ) rows = cur.fetchall() @@ -396,8 +397,12 @@ def query_session_status(agent: str, char_limit: int = 200_000) -> dict: # Last 5 turns for rolling averages last_5 = session_rows[-5:] - avg_cost_5 = sum(float(r["estimated_cost_usd"] or 0) for r in last_5) / max(len(last_5), 1) - total_cache_5 = sum((r["cache_read_tokens"] or 0) + (r["cache_write_tokens"] or 0) for r in last_5) + avg_cost_5 = sum(float(r["estimated_cost_usd"] or 0) for r in last_5) / max( + len(last_5), 1 + ) + total_cache_5 = sum( + (r["cache_read_tokens"] or 0) + (r["cache_write_tokens"] or 0) for r in last_5 + ) total_write_5 = sum(r["cache_write_tokens"] or 0 for r in last_5) cache_write_pct = total_write_5 / max(total_cache_5, 1) @@ -454,7 +459,7 @@ def query_recent_events(limit: int = 100, after_id: Optional[UUID] = None): ORDER BY r.timestamp DESC LIMIT %s """, - (_tenant_id, after_id, limit) + (_tenant_id, after_id, limit), ) else: cur.execute( @@ -476,7 +481,7 @@ def query_recent_events(limit: int = 100, after_id: Optional[UUID] = None): ORDER BY r.timestamp DESC LIMIT %s """, - (_tenant_id, limit) + (_tenant_id, limit), ) rows = cur.fetchall() # Convert datetime objects to ISO format strings for JSON serialization diff --git a/dream-server/extensions/services/token-spy/filters.py b/dream-server/extensions/services/token-spy/filters.py index 68fdeb0aa..973349269 100644 --- a/dream-server/extensions/services/token-spy/filters.py +++ b/dream-server/extensions/services/token-spy/filters.py @@ -17,6 +17,7 @@ @dataclass class FilterResult: """Metrics captured during filtering.""" + tools_removed: int = 0 tools_kept: int = 0 system_chars_removed: int = 0 @@ -89,8 +90,9 @@ def apply_filters(body: dict, filter_settings: dict) -> tuple[dict, FilterResult # โ”€โ”€ Filter 1: Tool Filtering โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ -def _filter_tools(body: dict, cfg: dict, result: FilterResult, - log_details: bool) -> tuple[dict, FilterResult]: +def _filter_tools( + body: dict, cfg: dict, result: FilterResult, log_details: bool +) -> tuple[dict, FilterResult]: """Filter tool schemas by blocklist or allowlist.""" tools = body.get("tools", []) if not tools: @@ -134,8 +136,9 @@ def _filter_tools(body: dict, cfg: dict, result: FilterResult, # โ”€โ”€ Filter 2: System Prompt Trimming โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ -def _filter_system_prompt(body: dict, cfg: dict, result: FilterResult, - log_details: bool) -> tuple[dict, FilterResult]: +def _filter_system_prompt( + body: dict, cfg: dict, result: FilterResult, log_details: bool +) -> tuple[dict, FilterResult]: """Trim system/developer role messages.""" messages = body.get("messages", []) mode = cfg.get("mode", "strip_sections") @@ -168,7 +171,11 @@ def _filter_system_prompt(body: dict, cfg: dict, result: FilterResult, if log_details and result.system_chars_removed > 0: log.info( f"[FILTER] System prompt trimmed by {result.system_chars_removed} chars" - + (f" (sections: {result.system_sections_stripped})" if result.system_sections_stripped else "") + + ( + f" (sections: {result.system_sections_stripped})" + if result.system_sections_stripped + else "" + ) ) return body, result @@ -185,7 +192,7 @@ def _strip_markdown_sections(text: str, section_headings: list[str]) -> tuple[st stripped = [] for heading in section_headings: # Determine heading level from the heading string - m = re.match(r'^(#{1,6})\s+', heading) + m = re.match(r"^(#{1,6})\s+", heading) if not m: continue level = len(m.group(1)) @@ -193,12 +200,12 @@ def _strip_markdown_sections(text: str, section_headings: list[str]) -> tuple[st # at the same or higher level (fewer or equal #), or end of string escaped = re.escape(heading) pattern = re.compile( - rf'^{escaped}\s*\n' # the heading line - rf'(.*?)' # content (non-greedy) - rf'(?=^#{{1,{level}}}\s|\Z)', # lookahead: next heading at same/higher level or EOF - re.MULTILINE | re.DOTALL + rf"^{escaped}\s*\n" # the heading line + rf"(.*?)" # content (non-greedy) + rf"(?=^#{{1,{level}}}\s|\Z)", # lookahead: next heading at same/higher level or EOF + re.MULTILINE | re.DOTALL, ) - new_text, count = pattern.subn('', text) + new_text, count = pattern.subn("", text) if count > 0: stripped.append(heading) text = new_text @@ -209,8 +216,9 @@ def _strip_markdown_sections(text: str, section_headings: list[str]) -> tuple[st # โ”€โ”€ Filter 3: Conversation History โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ -def _filter_history(body: dict, cfg: dict, result: FilterResult, - log_details: bool) -> tuple[dict, FilterResult]: +def _filter_history( + body: dict, cfg: dict, result: FilterResult, log_details: bool +) -> tuple[dict, FilterResult]: """Manage conversation history size.""" messages = body.get("messages", []) if not messages: @@ -290,7 +298,9 @@ def _filter_history(body: dict, cfg: dict, result: FilterResult, # Check if tail messages are already in filtered_conv # by comparing the last N messages tail_ids = {id(m) for m in tail} - existing_ids = {id(m) for m in filtered_conv[-always_keep_last_n:]} if filtered_conv else set() + existing_ids = ( + {id(m) for m in filtered_conv[-always_keep_last_n:]} if filtered_conv else set() + ) if not tail_ids.issubset(existing_ids): # Ensure tail messages are present โ€” they may have been modified by # truncation but should still be in the list since we keep recent units diff --git a/dream-server/extensions/services/token-spy/main.py b/dream-server/extensions/services/token-spy/main.py index c6ed5321e..c9c501f69 100644 --- a/dream-server/extensions/services/token-spy/main.py +++ b/dream-server/extensions/services/token-spy/main.py @@ -25,9 +25,23 @@ DB_BACKEND = os.environ.get("DB_BACKEND", "sqlite").lower() if DB_BACKEND == "postgres": - from db_postgres import init_db, log_usage, query_session_status, query_summary, query_usage, query_recent_events + from db_postgres import ( + init_db, + log_usage, + query_session_status, + query_summary, + query_usage, + query_recent_events, + ) else: - from db import init_db, log_usage, query_session_status, query_summary, query_usage, query_recent_events + from db import ( + init_db, + log_usage, + query_session_status, + query_summary, + query_usage, + query_recent_events, + ) from filters import apply_filters from providers import ProviderRegistry @@ -110,8 +124,15 @@ "enabled": False, "mode": "allowlist", "allowlist": [ - "exec", "read", "write", "edit", "apply_patch", - "web_fetch", "web_search", "process", "memory_search", + "exec", + "read", + "write", + "edit", + "apply_patch", + "web_fetch", + "web_search", + "process", + "memory_search", "memory_get", ], "blocklist": [], @@ -120,9 +141,13 @@ "enabled": False, "mode": "strip_sections", "strip_sections": [ - "## Heartbeats", "## Silent Replies", "## OpenClaw Self-Update", - "## OpenClaw CLI Quick Reference", "## Reactions", - "## Sandbox", "## Model Aliases", + "## Heartbeats", + "## Silent Replies", + "## OpenClaw Self-Update", + "## OpenClaw CLI Quick Reference", + "## Reactions", + "## Sandbox", + "## Model Aliases", ], "custom_replacement": None, "max_chars": None, @@ -297,6 +322,7 @@ def get_moonshot_client() -> httpx.AsyncClient: _db_available = True + @app.on_event("startup") def on_startup(): global _db_available @@ -305,12 +331,17 @@ def on_startup(): _db_available = True except Exception as e: _db_available = False - log.error(f"Database unavailable -- running in degraded mode (file-based session monitoring only): {e}") + log.error( + f"Database unavailable -- running in degraded mode (file-based session monitoring only): {e}" + ) db_status = "connected" if _db_available else "DEGRADED" - log.info(f"Token monitor started for agent={AGENT_NAME}, provider={API_PROVIDER}, anthropic_upstream={ANTHROPIC_UPSTREAM}, openai_upstream={OPENAI_UPSTREAM}, db={db_status}") + log.info( + f"Token monitor started for agent={AGENT_NAME}, provider={API_PROVIDER}, anthropic_upstream={ANTHROPIC_UPSTREAM}, openai_upstream={OPENAI_UPSTREAM}, db={db_status}" + ) # Start background polling for remote agents (A16 etc.) # Only the first instance (port 9110) runs the poller to avoid duplicates. import asyncio + asyncio.get_event_loop().create_task(_poll_remote_agents()) @@ -330,12 +361,18 @@ async def _poll_remote_agents(): tool_results = status.get("tool_results", 0) needs_reset = chars >= limit or rec == "reset_recommended" if needs_reset: - reason = f"tool loop ({tool_results} calls)" if tool_results >= 480 else f"history {chars:,} >= {limit:,}" + reason = ( + f"tool loop ({tool_results} calls)" + if tool_results >= 480 + else f"history {chars:,} >= {limit:,}" + ) log.warning(f"[REMOTE-POLL] {agent}: auto-reset โ€” {reason}") _kill_session(agent, reason=f"auto-reset ({reason})") _last_auto_reset[agent] = time.time() elif chars > 0: - log.info(f"[REMOTE-POLL] {agent}: {chars:,} / {limit:,} chars ({chars*100//limit}%)") + log.info( + f"[REMOTE-POLL] {agent}: {chars:,} / {limit:,} chars ({chars*100//limit}%)" + ) # Poll local-model agents (file-based, no proxy traffic) for agent in AGENT_SESSION_DIRS: if agent == AGENT_NAME or agent in REMOTE_AGENTS: @@ -351,12 +388,18 @@ async def _poll_remote_agents(): tool_results = status.get("tool_results", 0) needs_reset = chars >= limit or rec == "reset_recommended" if needs_reset: - reason = f"tool loop ({tool_results} calls)" if tool_results >= 480 else f"history {chars:,} >= {limit:,}" + reason = ( + f"tool loop ({tool_results} calls)" + if tool_results >= 480 + else f"history {chars:,} >= {limit:,}" + ) log.warning(f"[LOCAL-POLL] {agent}: auto-reset โ€” {reason}") _kill_session(agent, reason=f"auto-reset ({reason})") _last_auto_reset[agent] = time.time() elif chars > 0: - log.info(f"[LOCAL-POLL] {agent}: {chars:,} / {limit:,} chars ({chars*100//limit}%)") + log.info( + f"[LOCAL-POLL] {agent}: {chars:,} / {limit:,} chars ({chars*100//limit}%)" + ) except Exception as e: log.error(f"[POLL] Error: {e}") await asyncio.sleep(60) @@ -390,10 +433,7 @@ def analyze_system_prompt(system_blocks: list) -> dict: return {"system_prompt_total_chars": 0, "base_prompt_chars": 0} # Combine all system text blocks - text = "\n".join( - b.get("text", "") if isinstance(b, dict) else str(b) - for b in system_blocks - ) + text = "\n".join(b.get("text", "") if isinstance(b, dict) else str(b) for b in system_blocks) result = {"system_prompt_total_chars": len(text)} # Initialize all workspace columns to 0 @@ -407,7 +447,7 @@ def analyze_system_prompt(system_blocks: list) -> dict: # Instead, find each "## KNOWNFILE.md" marker and measure until the next known marker. ctx_match = re.search(r"^# Project Context\b", text, re.MULTILINE) if ctx_match: - after_ctx = text[ctx_match.start():] + after_ctx = text[ctx_match.start() :] # Build list of all known file markers: ## AGENTS.md, ## SOUL.md, etc. # Also include ## Silent Replies, ## Heartbeats, ## Runtime as end markers all_file_names = list(WORKSPACE_FILE_MAP.keys()) @@ -447,7 +487,9 @@ def analyze_system_prompt(system_blocks: list) -> dict: result[col] += len(content) else: result.setdefault("workspace_other_chars", 0) - result["workspace_other_chars"] = result.get("workspace_other_chars", 0) + len(content) + result["workspace_other_chars"] = result.get("workspace_other_chars", 0) + len( + content + ) # Extract skills section (## Skills (mandatory) ... until next ## at same level) skills_match = re.search( @@ -458,8 +500,7 @@ def analyze_system_prompt(system_blocks: list) -> dict: # Base prompt = total minus workspace files and skills accounted = sum( - v for k, v in result.items() - if k.startswith("workspace_") or k == "skill_injection_chars" + v for k, v in result.items() if k.startswith("workspace_") or k == "skill_injection_chars" ) result["base_prompt_chars"] = max(0, result["system_prompt_total_chars"] - accounted) @@ -493,8 +534,14 @@ def analyze_messages(messages: list) -> dict: } -def estimate_cost(model: str, input_tokens: int, output_tokens: int, - cache_read: int, cache_write: int, provider_name: str = "anthropic") -> float: +def estimate_cost( + model: str, + input_tokens: int, + output_tokens: int, + cache_read: int, + cache_write: int, + provider_name: str = "anthropic", +) -> float: """Estimate USD cost based on model and token counts. Uses the provider plugin system for pricing data. Falls back to hardcoded @@ -534,6 +581,7 @@ def estimate_cost(model: str, input_tokens: int, output_tokens: int, # โ”€โ”€ Proxy Endpoint โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + @app.post("/v1/messages", dependencies=[Depends(verify_api_key)]) async def proxy_messages(request: Request): """Transparent proxy for Anthropic /v1/messages with metrics capture.""" @@ -567,15 +615,27 @@ async def proxy_messages(request: Request): # Build upstream headers โ€” forward everything relevant forward_headers = {} - for key in ("x-api-key", "anthropic-version", "content-type", "anthropic-beta", - "anthropic-dangerous-direct-browser-access", "user-agent", "x-app", - "accept", "authorization"): + for key in ( + "x-api-key", + "anthropic-version", + "content-type", + "anthropic-beta", + "anthropic-dangerous-direct-browser-access", + "user-agent", + "x-app", + "accept", + "authorization", + ): val = request.headers.get(key) if val: forward_headers[key] = val # Inject environment API key if not provided in request (for external deployments) - if UPSTREAM_API_KEY and "x-api-key" not in forward_headers and "authorization" not in forward_headers: + if ( + UPSTREAM_API_KEY + and "x-api-key" not in forward_headers + and "authorization" not in forward_headers + ): if API_PROVIDER == "anthropic": forward_headers["x-api-key"] = UPSTREAM_API_KEY else: @@ -585,18 +645,31 @@ async def proxy_messages(request: Request): if is_streaming: return await _handle_streaming( - client, raw_body, forward_headers, model, sys_analysis, msg_analysis, - tools, start, + client, + raw_body, + forward_headers, + model, + sys_analysis, + msg_analysis, + tools, + start, ) else: return await _handle_non_streaming( - client, raw_body, forward_headers, model, sys_analysis, msg_analysis, - tools, start, + client, + raw_body, + forward_headers, + model, + sys_analysis, + msg_analysis, + tools, + start, ) -async def _handle_streaming(client, raw_body, headers, model, sys_analysis, - msg_analysis, tools, start_time): +async def _handle_streaming( + client, raw_body, headers, model, sys_analysis, msg_analysis, tools, start_time +): """Stream SSE response through while capturing token metrics.""" # State for capturing usage from SSE events @@ -613,7 +686,8 @@ async def stream_and_capture(): logged = False try: async with client.stream( - "POST", "/v1/messages", + "POST", + "/v1/messages", content=raw_body, headers=headers, ) as upstream: @@ -635,10 +709,12 @@ async def stream_and_capture(): continue if current_event == "message_start": - msg_usage = (data.get("message", {}).get("usage", {})) + msg_usage = data.get("message", {}).get("usage", {}) usage["input_tokens"] = msg_usage.get("input_tokens", 0) usage["cache_read_tokens"] = msg_usage.get("cache_read_input_tokens", 0) - usage["cache_write_tokens"] = msg_usage.get("cache_creation_input_tokens", 0) + usage["cache_write_tokens"] = msg_usage.get( + "cache_creation_input_tokens", 0 + ) elif current_event == "message_delta": delta_usage = data.get("usage", {}) @@ -651,8 +727,13 @@ async def stream_and_capture(): elif current_event == "message_stop": # Stream complete โ€” log metrics _log_entry( - model, sys_analysis, msg_analysis, tools, - raw_body, usage, start_time, + model, + sys_analysis, + msg_analysis, + tools, + raw_body, + usage, + start_time, provider_name="anthropic", ) logged = True @@ -666,8 +747,13 @@ async def stream_and_capture(): # (which is a BaseException and bypasses 'except Exception') if not logged and usage["input_tokens"] > 0: _log_entry( - model, sys_analysis, msg_analysis, tools, - raw_body, usage, start_time, + model, + sys_analysis, + msg_analysis, + tools, + raw_body, + usage, + start_time, provider_name="anthropic", ) @@ -682,12 +768,14 @@ async def stream_and_capture(): ) -async def _handle_non_streaming(client, raw_body, headers, model, sys_analysis, - msg_analysis, tools, start_time): +async def _handle_non_streaming( + client, raw_body, headers, model, sys_analysis, msg_analysis, tools, start_time +): """Handle non-streaming requests (rare for OpenClaw, but support anyway).""" try: resp = await client.request( - "POST", "/v1/messages", + "POST", + "/v1/messages", content=raw_body, headers=headers, ) @@ -712,7 +800,16 @@ async def _handle_non_streaming(client, raw_body, headers, model, sys_analysis, "stop_reason": data.get("stop_reason"), } - _log_entry(model, sys_analysis, msg_analysis, tools, raw_body, usage, start_time, provider_name="anthropic") + _log_entry( + model, + sys_analysis, + msg_analysis, + tools, + raw_body, + usage, + start_time, + provider_name="anthropic", + ) return Response( content=resp.content, @@ -723,6 +820,7 @@ async def _handle_non_streaming(client, raw_body, headers, model, sys_analysis, # โ”€โ”€ OpenAI-Compatible Proxy (Moonshot/Kimi) โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + def _analyze_openai_messages(messages: list) -> dict: """Analyze OpenAI-format messages for metrics.""" if not messages: @@ -821,18 +919,41 @@ async def proxy_chat_completions(request: Request): if is_streaming: return await _handle_openai_streaming( - client, raw_body, forward_headers, model, sys_analysis, msg_analysis, - tools, start, filter_result=filter_result, + client, + raw_body, + forward_headers, + model, + sys_analysis, + msg_analysis, + tools, + start, + filter_result=filter_result, ) else: return await _handle_openai_non_streaming( - client, raw_body, forward_headers, model, sys_analysis, msg_analysis, - tools, start, filter_result=filter_result, + client, + raw_body, + forward_headers, + model, + sys_analysis, + msg_analysis, + tools, + start, + filter_result=filter_result, ) -async def _handle_openai_streaming(client, raw_body, headers, model, sys_analysis, - msg_analysis, tools, start_time, filter_result=None): +async def _handle_openai_streaming( + client, + raw_body, + headers, + model, + sys_analysis, + msg_analysis, + tools, + start_time, + filter_result=None, +): """Stream OpenAI SSE response through while capturing token metrics.""" usage = { "input_tokens": 0, @@ -846,7 +967,8 @@ async def stream_and_capture(): logged = False try: async with client.stream( - "POST", "/v1/chat/completions", + "POST", + "/v1/chat/completions", content=raw_body, headers=headers, ) as upstream: @@ -854,7 +976,9 @@ async def stream_and_capture(): err_body = b"" async for chunk in upstream.aiter_bytes(): err_body += chunk - log.error(f"Upstream {upstream.status_code}: {err_body[:2000].decode(errors='replace')}") + log.error( + f"Upstream {upstream.status_code}: {err_body[:2000].decode(errors='replace')}" + ) yield f"data: {err_body.decode(errors='replace')}\n\n" return async for line in upstream.aiter_lines(): @@ -866,8 +990,13 @@ async def stream_and_capture(): data_str = stripped[5:].strip() if data_str == "[DONE]": _log_entry( - model, sys_analysis, msg_analysis, tools, - raw_body, usage, start_time, + model, + sys_analysis, + msg_analysis, + tools, + raw_body, + usage, + start_time, provider_name="openai", filter_result=filter_result, ) @@ -883,7 +1012,9 @@ async def stream_and_capture(): if chunk_usage: usage["input_tokens"] = chunk_usage.get("prompt_tokens", 0) usage["output_tokens"] = chunk_usage.get("completion_tokens", 0) - usage["cache_read_tokens"] = chunk_usage.get("prompt_tokens_details", {}).get("cached_tokens", 0) + usage["cache_read_tokens"] = chunk_usage.get( + "prompt_tokens_details", {} + ).get("cached_tokens", 0) choices = data.get("choices", []) if choices: @@ -899,7 +1030,17 @@ async def stream_and_capture(): finally: # Guarantee billing metrics are logged even on CancelledError if not logged and usage["input_tokens"] > 0: - _log_entry(model, sys_analysis, msg_analysis, tools, raw_body, usage, start_time, provider_name="openai", filter_result=filter_result) + _log_entry( + model, + sys_analysis, + msg_analysis, + tools, + raw_body, + usage, + start_time, + provider_name="openai", + filter_result=filter_result, + ) return StreamingResponse( stream_and_capture(), @@ -912,12 +1053,22 @@ async def stream_and_capture(): ) -async def _handle_openai_non_streaming(client, raw_body, headers, model, sys_analysis, - msg_analysis, tools, start_time, filter_result=None): +async def _handle_openai_non_streaming( + client, + raw_body, + headers, + model, + sys_analysis, + msg_analysis, + tools, + start_time, + filter_result=None, +): """Handle non-streaming OpenAI-format requests.""" try: resp = await client.request( - "POST", "/v1/chat/completions", + "POST", + "/v1/chat/completions", content=raw_body, headers=headers, ) @@ -939,10 +1090,22 @@ async def _handle_openai_non_streaming(client, raw_body, headers, model, sys_ana "output_tokens": resp_usage.get("completion_tokens", 0), "cache_read_tokens": resp_usage.get("prompt_tokens_details", {}).get("cached_tokens", 0), "cache_write_tokens": 0, - "stop_reason": (data.get("choices", [{}])[0].get("finish_reason") if data.get("choices") else None), + "stop_reason": ( + data.get("choices", [{}])[0].get("finish_reason") if data.get("choices") else None + ), } - _log_entry(model, sys_analysis, msg_analysis, tools, raw_body, usage, start_time, provider_name="openai", filter_result=filter_result) + _log_entry( + model, + sys_analysis, + msg_analysis, + tools, + raw_body, + usage, + start_time, + provider_name="openai", + filter_result=filter_result, + ) return Response( content=resp.content, @@ -985,7 +1148,6 @@ async def _handle_openai_non_streaming(client, raw_body, headers, model, sys_ana _last_auto_reset: dict[str, float] = {} - def _get_local_session_status(agent: str) -> dict: """Get session status for a local agent by reading JSONL files directly. Used for agents whose traffic doesn't pass through the token monitor proxy @@ -995,7 +1157,10 @@ def _get_local_session_status(agent: str) -> dict: return None import glob - files = sorted(glob.glob(os.path.join(sessions_dir, "*.jsonl")), key=os.path.getmtime, reverse=True) + + files = sorted( + glob.glob(os.path.join(sessions_dir, "*.jsonl")), key=os.path.getmtime, reverse=True + ) if not files: return None @@ -1080,6 +1245,7 @@ def _get_local_accumulated_turns(agent: str) -> int: # Use user turns if available; fall back to assistant turns for agents # whose OpenClaw gateway doesn't log user messages in the JSONL. import glob + files = glob.glob(os.path.join(sessions_dir, "*.jsonl")) user_turns = 0 assistant_turns = 0 @@ -1117,7 +1283,7 @@ def _get_local_accumulated_turns(agent: str) -> int: if current_file_turns >= last_file_turns: # Normal growth or no change โ€” add the delta - total += (current_file_turns - last_file_turns) + total += current_file_turns - last_file_turns else: # Session files were purged (current < last) โ€” add what's on disk now total += current_file_turns @@ -1136,18 +1302,27 @@ def _get_local_accumulated_turns(agent: str) -> int: def _get_remote_session_status(agent: str) -> dict: """Get session status for a remote agent via SSH.""" import subprocess + remote = REMOTE_AGENTS.get(agent) if not remote: - return {"agent": agent, "recommendation": "no_data", "current_session_turns": 0, - "current_history_chars": 0, "last_turn_cost": 0, "avg_cost_last_5": 0, - "cache_write_pct_last_5": 0, "cost_since_last_reset": 0, "turns_since_last_reset": 0} + return { + "agent": agent, + "recommendation": "no_data", + "current_session_turns": 0, + "current_history_chars": 0, + "last_turn_cost": 0, + "avg_cost_last_5": 0, + "cache_write_pct_last_5": 0, + "cost_since_last_reset": 0, + "turns_since_last_reset": 0, + } ssh_target = f"{remote['user']}@{remote['host']}" sessions_dir = remote["sessions_dir"] script = ( "import json, os, glob\n" - f"sdir = \"{sessions_dir}\"\n" + f'sdir = "{sessions_dir}"\n' "files = sorted(glob.glob(os.path.join(sdir, '*.jsonl')), key=os.path.getmtime, reverse=True)\n" "if not files:\n" " print(json.dumps({'turns': 0, 'chars': 0, 'files': 0}))\n" @@ -1178,15 +1353,34 @@ def _get_remote_session_status(agent: str) -> dict: ) try: result = subprocess.run( - ["ssh", "-o", "ConnectTimeout=3", "-o", "StrictHostKeyChecking=accept-new", - ssh_target, "python3", "-"], - input=script, capture_output=True, text=True, timeout=10, + [ + "ssh", + "-o", + "ConnectTimeout=3", + "-o", + "StrictHostKeyChecking=accept-new", + ssh_target, + "python3", + "-", + ], + input=script, + capture_output=True, + text=True, + timeout=10, ) if result.returncode != 0: log.warning(f"[REMOTE] SSH to {agent} failed: {result.stderr[:200]}") - return {"agent": agent, "recommendation": "no_data", "current_session_turns": 0, - "current_history_chars": 0, "last_turn_cost": 0, "avg_cost_last_5": 0, - "cache_write_pct_last_5": 0, "cost_since_last_reset": 0, "turns_since_last_reset": 0} + return { + "agent": agent, + "recommendation": "no_data", + "current_session_turns": 0, + "current_history_chars": 0, + "last_turn_cost": 0, + "avg_cost_last_5": 0, + "cache_write_pct_last_5": 0, + "cost_since_last_reset": 0, + "turns_since_last_reset": 0, + } data = json.loads(result.stdout.strip()) history_chars = data.get("chars", 0) @@ -1196,7 +1390,9 @@ def _get_remote_session_status(agent: str) -> dict: limit = get_agent_setting(agent, "session_char_limit") or AUTO_RESET_HISTORY_CHARS if tool_results >= 480: rec = "reset_recommended" - log.warning(f"[REMOTE] {agent}: tool loop detected ({tool_results} tool results in session)") + log.warning( + f"[REMOTE] {agent}: tool loop detected ({tool_results} tool results in session)" + ) elif history_chars > limit: rec = "reset_recommended" elif history_chars > limit * 0.8: @@ -1221,14 +1417,23 @@ def _get_remote_session_status(agent: str) -> dict: } except Exception as e: log.warning(f"[REMOTE] Failed to get session status for {agent}: {e}") - return {"agent": agent, "recommendation": "no_data", "current_session_turns": 0, - "current_history_chars": 0, "last_turn_cost": 0, "avg_cost_last_5": 0, - "cache_write_pct_last_5": 0, "cost_since_last_reset": 0, "turns_since_last_reset": 0} + return { + "agent": agent, + "recommendation": "no_data", + "current_session_turns": 0, + "current_history_chars": 0, + "last_turn_cost": 0, + "avg_cost_last_5": 0, + "cache_write_pct_last_5": 0, + "cost_since_last_reset": 0, + "turns_since_last_reset": 0, + } def _kill_remote_session(agent: str, reason: str = "dashboard") -> dict: """Kill the largest session for a remote agent via SSH.""" import subprocess + remote = REMOTE_AGENTS.get(agent) if not remote: return {"agent": agent, "action": "none", "reason": f"unknown remote agent: {agent}"} @@ -1238,7 +1443,7 @@ def _kill_remote_session(agent: str, reason: str = "dashboard") -> dict: script = ( "import os, glob, json\n" - f"sdir = \"{sessions_dir}\"\n" + f'sdir = "{sessions_dir}"\n' "files = sorted(glob.glob(os.path.join(sdir, '*.jsonl')), key=os.path.getsize, reverse=True)\n" "if not files:\n" " print(json.dumps({'action': 'none', 'reason': 'no sessions'}))\n" @@ -1258,24 +1463,43 @@ def _kill_remote_session(agent: str, reason: str = "dashboard") -> dict: ) try: result = subprocess.run( - ["ssh", "-o", "ConnectTimeout=3", "-o", "StrictHostKeyChecking=accept-new", - ssh_target, "python3", "-"], - input=script, capture_output=True, text=True, timeout=10, + [ + "ssh", + "-o", + "ConnectTimeout=3", + "-o", + "StrictHostKeyChecking=accept-new", + ssh_target, + "python3", + "-", + ], + input=script, + capture_output=True, + text=True, + timeout=10, ) if result.returncode != 0: - return {"agent": agent, "action": "none", "reason": f"SSH failed: {result.stderr[:100]}"} + return { + "agent": agent, + "action": "none", + "reason": f"SSH failed: {result.stderr[:100]}", + } data = json.loads(result.stdout.strip()) data["agent"] = agent if data.get("action") == "killed": - log.warning(f"[RESET] Remote killed session {data.get('session_id')} for {agent} ({data.get('size_bytes')} bytes) โ€” {reason}") + log.warning( + f"[RESET] Remote killed session {data.get('session_id')} for {agent} ({data.get('size_bytes')} bytes) โ€” {reason}" + ) return data except Exception as e: log.error(f"Remote session check failed for {agent}: {e}") return {"agent": agent, "action": "none", "reason": "Remote check failed"} + def _kill_session(agent: str, reason: str = "manual") -> dict: """Kill the largest active session for an agent. Returns result dict.""" import subprocess + if agent in REMOTE_AGENTS: return _kill_remote_session(agent, reason) @@ -1285,7 +1509,8 @@ def _kill_session(agent: str, reason: str = "manual") -> dict: result = subprocess.run( ["ls", "-S", f"{sessions_dir}/"], - capture_output=True, text=True, + capture_output=True, + text=True, ) largest = None for line in result.stdout.strip().split("\n"): @@ -1309,7 +1534,9 @@ def _kill_session(agent: str, reason: str = "manual") -> dict: try: with open(sessions_json, "r") as f: data = json.load(f) - to_remove = [k for k, v in data.items() if isinstance(v, dict) and v.get("sessionId") == largest] + to_remove = [ + k for k, v in data.items() if isinstance(v, dict) and v.get("sessionId") == largest + ] for k in to_remove: del data[k] with open(sessions_json, "w") as f: @@ -1349,8 +1576,17 @@ def _auto_reset_check(agent: str, history_chars: int): log.warning(f"[AUTO-RESET] {agent} session killed: {result.get('session_id')}") -def _log_entry(model, sys_analysis, msg_analysis, tools, raw_body, usage, start_time, - provider_name: str = None, filter_result=None): +def _log_entry( + model, + sys_analysis, + msg_analysis, + tools, + raw_body, + usage, + start_time, + provider_name: str = None, + filter_result=None, +): """Write a usage entry to SQLite. Args: @@ -1418,6 +1654,7 @@ def _log_entry(model, sys_analysis, msg_analysis, tools, raw_body, usage, start_ # โ”€โ”€ Health โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + @app.get("/health") def health(): uptime = int(time.time() - START_TIME) @@ -1459,7 +1696,9 @@ def api_filter_stats(): "history": { "enabled": f_settings.get("history", {}).get("enabled", False), "max_pairs": f_settings.get("history", {}).get("max_pairs"), - "truncate_tool_results_chars": f_settings.get("history", {}).get("truncate_tool_results_chars"), + "truncate_tool_results_chars": f_settings.get("history", {}).get( + "truncate_tool_results_chars" + ), }, } @@ -1469,8 +1708,12 @@ def api_get_settings(): """Current settings. Per-agent values of null inherit the global default.""" settings = load_settings() for agent_name, agent_cfg in settings.get("agents", {}).items(): - agent_cfg["_effective_session_char_limit"] = get_agent_setting(agent_name, "session_char_limit") - agent_cfg["_effective_poll_interval_minutes"] = get_agent_setting(agent_name, "poll_interval_minutes") + agent_cfg["_effective_session_char_limit"] = get_agent_setting( + agent_name, "session_char_limit" + ) + agent_cfg["_effective_poll_interval_minutes"] = get_agent_setting( + agent_name, "poll_interval_minutes" + ) return settings @@ -1491,7 +1734,9 @@ async def api_update_settings(request: Request): if val is not None: val = int(val) if val < 10000: - return JSONResponse({"error": "session_char_limit must be >= 10000"}, status_code=400) + return JSONResponse( + {"error": "session_char_limit must be >= 10000"}, status_code=400 + ) settings["session_char_limit"] = val if "poll_interval_minutes" in body: @@ -1499,7 +1744,9 @@ async def api_update_settings(request: Request): if val is not None: val = int(val) if val < 1 or val > 60: - return JSONResponse({"error": "poll_interval_minutes must be 1-60"}, status_code=400) + return JSONResponse( + {"error": "poll_interval_minutes must be 1-60"}, status_code=400 + ) settings["poll_interval_minutes"] = val # Deep-merge filter settings (hot-reloadable) @@ -1539,11 +1786,15 @@ async def api_update_settings(request: Request): def _update_timer_interval(minutes: int): """Best-effort update of the systemd timer interval.""" import subprocess - timer_path = os.environ.get("SESSION_TIMER_PATH", "/etc/systemd/system/openclaw-session-cleanup.timer") + + timer_path = os.environ.get( + "SESSION_TIMER_PATH", "/etc/systemd/system/openclaw-session-cleanup.timer" + ) try: with open(timer_path, "r") as f: timer_content = f.read() import re as _re + new_content = _re.sub( r"OnUnitActiveSec=\d+min", f"OnUnitActiveSec={minutes}min", @@ -1553,11 +1804,14 @@ def _update_timer_interval(minutes: int): with open(timer_path, "w") as f: f.write(new_content) subprocess.run(["systemctl", "daemon-reload"], capture_output=True) - subprocess.run(["systemctl", "restart", "openclaw-session-cleanup.timer"], capture_output=True) + subprocess.run( + ["systemctl", "restart", "openclaw-session-cleanup.timer"], capture_output=True + ) log.info(f"[SETTINGS] Timer updated to {minutes}min") except Exception as e: log.warning(f"[SETTINGS] Could not update timer: {e} (may need sudo)") + @app.get("/api/usage", dependencies=[Depends(verify_api_key)]) def api_usage(agent: str | None = None, hours: int = 24, limit: int = 200): return query_usage(agent=agent, hours=hours, limit=limit) @@ -1591,17 +1845,19 @@ def api_summary(hours: int = 24): if accumulated_turns > 0 or (local and local.get("current_session_turns", 0) > 0): current_chars = local.get("current_history_chars", 0) if local else 0 is_local = agent_name in LOCAL_MODEL_AGENTS - result.append({ - "agent": agent_name, - "turns": accumulated_turns, - "total_input_tokens": current_chars // 4, - "total_output_tokens": 0, - "total_cost": 0, - "total_cache_read": 0, - "total_cache_write": 0, - "avg_input_tokens": (current_chars // 4) // max(accumulated_turns, 1), - "is_local_model": is_local, - }) + result.append( + { + "agent": agent_name, + "turns": accumulated_turns, + "total_input_tokens": current_chars // 4, + "total_output_tokens": 0, + "total_cost": 0, + "total_cache_read": 0, + "total_cache_write": 0, + "avg_input_tokens": (current_chars // 4) // max(accumulated_turns, 1), + "is_local_model": is_local, + } + ) return result @@ -1640,8 +1896,10 @@ def api_reset_session(agent: str): """Kill the largest active session for an agent (safety valve trigger).""" if not AGENT_SESSION_DIRS.get(agent) and agent not in REMOTE_AGENTS: return JSONResponse( - {"error": f"Session reset not configured for agent: {agent}. Set AGENT_SESSION_DIRS env var."}, - status_code=400 + { + "error": f"Session reset not configured for agent: {agent}. Set AGENT_SESSION_DIRS env var." + }, + status_code=400, ) return _kill_session(agent, reason="dashboard") @@ -1896,7 +2154,7 @@ def api_reset_session(agent: str): const showReset = ['reset_recommended', 'compact_soon', 'monitor'].includes(rec); const isLocal = s.is_local_model; const cardClass = 'session-card' + (isLocal ? ' local-model' : ''); - const agentLabel = s.agent + (isLocal ? '\u26A1 Self-Hosted' : ''); + const agentLabel = s.agent + (isLocal ? '\u26a1 Self-Hosted' : ''); const limit = s.session_char_limit || 200000; const pct = limit > 0 ? Math.round((s.current_history_chars / limit) * 100) : 0; const barColor = pct > 80 ? '#da3633' : pct > 60 ? '#9e6a03' : '#238636'; @@ -1908,7 +2166,7 @@ def api_reset_session(agent: str): '
~' + fmt(Math.round(s.current_history_chars / 4)) + ' / ' + fmt(Math.round(limit / 4)) + ' tokens
' + '
' + (isLocal ? - '
Inference\u26A1 Local GPU \u2014 $0.00/token
' + '
Inference\u26a1 Local GPU \u2014 $0.00/token
' : '
Last turn cost' + fmtCost(s.last_turn_cost) + '
' + '
Avg cost (last 5)' + fmtCost(s.avg_cost_last_5) + '
' + @@ -1945,7 +2203,7 @@ def api_reset_session(agent: str): '

Cache Efficiency

' + cacheReadPct + '%
' + fmt(totalCacheRead) + ' reads / ' + fmt(totalCacheWrite) + ' writes
'; data.forEach(d => { if (d.is_local_model) { - html += '

' + d.agent.toUpperCase() + ' \u26A1 SELF-HOSTED

' + d.turns + ' turns
$0.00 \u2014 local GPU | ~' + fmt(d.avg_input_tokens) + ' tokens/turn
'; + html += '

' + d.agent.toUpperCase() + ' \u26a1 SELF-HOSTED

' + d.turns + ' turns
$0.00 \u2014 local GPU | ~' + fmt(d.avg_input_tokens) + ' tokens/turn
'; } else { html += '

' + d.agent.toUpperCase() + '

' + d.turns + ' turns
' + fmtCost(d.total_cost) + ' | avg ' + fmt(d.avg_input_tokens) + ' in/turn
'; } @@ -2291,9 +2549,11 @@ def dashboard(): # โ”€โ”€ SSE Token Events Stream โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + @app.get("/token_events", dependencies=[Depends(verify_api_key)]) async def token_events(request: Request): """Stream token usage events as Server-Sent Events.""" + async def event_stream(): last_id = None while True: @@ -2313,7 +2573,7 @@ async def event_stream(): "total_tokens": event.get("total_tokens", 0), "cost_usd": float(event.get("cost_usd", 0) or 0), "timestamp": event.get("timestamp", ""), - "agent_name": event.get("agent_name", AGENT_NAME) + "agent_name": event.get("agent_name", AGENT_NAME), } yield f"data: {json.dumps(event_data)}\n\n" @@ -2343,6 +2603,7 @@ async def event_stream(): # โ”€โ”€ Catch-all for other endpoints โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"]) async def proxy_other(request: Request, path: str): """Forward any other requests to upstream transparently.""" @@ -2352,8 +2613,15 @@ async def proxy_other(request: Request, path: str): else: client = get_http_client() headers = {} - for key in ("x-api-key", "anthropic-version", "content-type", "anthropic-beta", - "authorization", "accept", "user-agent"): + for key in ( + "x-api-key", + "anthropic-version", + "content-type", + "anthropic-beta", + "authorization", + "accept", + "user-agent", + ): val = request.headers.get(key) if val: headers[key] = val diff --git a/dream-server/scripts/healthcheck.py b/dream-server/scripts/healthcheck.py index 9e2eaca7b..890ab3715 100644 --- a/dream-server/scripts/healthcheck.py +++ b/dream-server/scripts/healthcheck.py @@ -9,15 +9,17 @@ import urllib.error import socket + def check_http(url, timeout=5): """Check HTTP endpoint returns 200.""" try: - req = urllib.request.Request(url, method='HEAD') + req = urllib.request.Request(url, method="HEAD") with urllib.request.urlopen(req, timeout=timeout) as resp: return resp.status == 200 except (urllib.error.HTTPError, urllib.error.URLError, socket.timeout): return False + def check_tcp(host, port, timeout=5): """Check TCP port is open.""" try: @@ -26,6 +28,7 @@ def check_tcp(host, port, timeout=5): except (socket.timeout, ConnectionRefusedError, OSError): return False + if __name__ == "__main__": if len(sys.argv) < 2: print("Usage: healthcheck.py ") @@ -33,10 +36,10 @@ def check_tcp(host, port, timeout=5): target = sys.argv[1] - if target.startswith('http://') or target.startswith('https://'): + if target.startswith("http://") or target.startswith("https://"): ok = check_http(target) - elif ':' in target: - host, port = target.rsplit(':', 1) + elif ":" in target: + host, port = target.rsplit(":", 1) ok = check_tcp(host, int(port)) else: print(f"Invalid target: {target}") diff --git a/dream-server/scripts/validate-models.py b/dream-server/scripts/validate-models.py index b27caf948..3b7a49676 100644 --- a/dream-server/scripts/validate-models.py +++ b/dream-server/scripts/validate-models.py @@ -31,6 +31,7 @@ }, } + def check_model(service, config): """Check if a model exists and has reasonable size.""" # Resolve base path relative to script location (scripts/ -> parent -> dream-server root) @@ -45,7 +46,7 @@ def check_model(service, config): size_gb = model_path.stat().st_size / (1024**3) else: # Directory - sum all files - size_gb = sum(f.stat().st_size for f in model_path.rglob('*') if f.is_file()) / (1024**3) + size_gb = sum(f.stat().st_size for f in model_path.rglob("*") if f.is_file()) / (1024**3) min_size = config["size_gb"] * 0.5 # At least 50% of expected size if size_gb < min_size: @@ -53,6 +54,7 @@ def check_model(service, config): return True, f"OK: {size_gb:.2f}GB" + def main(): """Validate all required models are present.""" print("=" * 60) @@ -82,5 +84,6 @@ def main(): print(" ./scripts/download-models.sh") return 1 + if __name__ == "__main__": sys.exit(main())