diff --git a/.env.example b/.env.example index 78a3b72c07..30f49777e5 100644 --- a/.env.example +++ b/.env.example @@ -4,13 +4,54 @@ LLM_API_KEY=your_api_key_here LLM_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 LLM_MODEL_NAME=qwen-plus +LLM_CONTEXT_WINDOW=8192 +LLM_MAX_CONCURRENCY=1 +LLM_JSON_MAX_RETRIES=2 +ONTOLOGY_MAX_OUTPUT_TOKENS=2048 +ONTOLOGY_PROMPT_MARGIN_TOKENS=512 +ONTOLOGY_MAX_CHUNKS=8 +LOCAL_ZEP_EXTRACT_MAX_OUTPUT_TOKENS=2048 +LOCAL_ZEP_EXTRACT_MAX_RETRIES=2 -# ===== ZEP记忆图谱配置 ===== -# 每月免费额度即可支撑简单使用:https://app.getzep.com/ -ZEP_API_KEY=your_zep_api_key_here +# ===== Local Zep / Embeddings 配置 ===== +# 使用 OpenAI-compatible embeddings 接口,vLLM 可作为首选后端 +EMBEDDING_API_KEY=local-embedding-key +EMBEDDING_BASE_URL=http://localhost:8001/v1 +EMBEDDING_MODEL_NAME=your_embedding_model_here + +# 可选:cross_encoder reranker 接口;不配置时本地图谱会自动回退到 RRF +RERANKER_API_KEY=local-reranker-key +RERANKER_BASE_URL=http://localhost:8002/v1 +RERANKER_MODEL_NAME=your_reranker_model_here +LOCAL_ZEP_RERANK_TOP_K=50 + +# 本地图谱 SQLite 路径(可选) +LOCAL_ZEP_DB_PATH=backend/data/local_zep.sqlite3 + +# Optional: Python executable for the original OASIS/CAMEL simulation runner. +# Use this when the backend runs on Python 3.13 but OASIS is installed in a Python 3.11 venv. +# OASIS_PYTHON=/absolute/path/to/oasis-venv/bin/python + +# ===== Tailscale / Remote Access(可选)===== +# 前端开发模式默认会通过 /api 代理访问后端,适合直接用 Tailscale 访问 http://:3000 +VITE_API_BASE_URL= +VITE_DEV_HOST=0.0.0.0 +VITE_DEV_PORT=3000 +VITE_DEV_PROXY_TARGET=http://127.0.0.1:5001 +VITE_ALLOWED_HOSTS=localhost,127.0.0.1,.ts.net,.beta.tailscale.net + +# 如果通过 MagicDNS 访问时 HMR 需要显式指定,可取消注释 +# VITE_HMR_HOST=your-machine.your-tailnet.ts.net +# VITE_HMR_CLIENT_PORT=3000 + +FLASK_HOST=0.0.0.0 +FLASK_PORT=5001 +ENABLE_PROXY_FIX=false +PUBLIC_BASE_URL= +TAILSCALE_URL= # ===== 加速 LLM 配置(可选)===== # 注意如果不使用加速配置,env文件中就不要出现下面的配置项 LLM_BOOST_API_KEY=your_api_key_here LLM_BOOST_BASE_URL=your_base_url_here -LLM_BOOST_MODEL_NAME=your_model_name_here \ No newline at end of file +LLM_BOOST_MODEL_NAME=your_model_name_here diff --git a/.gitignore b/.gitignore index 55d3ef197b..d39439e912 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ Thumbs.db .env.development .env.test .env.production +start.sh # Python __pycache__/ @@ -18,6 +19,7 @@ __pycache__/ .Python .venv/ venv/ +venv11/ ENV/ .eggs/ *.egg-info/ @@ -57,4 +59,4 @@ backend/logs/ backend/uploads/ # Docker 数据 -data/ \ No newline at end of file +data/ diff --git a/README-ZH.md b/README-ZH.md index 0b20424d3a..5a428acad8 100644 --- a/README-ZH.md +++ b/README-ZH.md @@ -100,8 +100,8 @@ MiroFish 致力于打造映射现实的群体智能镜像,通过捕捉个体 | 工具 | 版本要求 | 说明 | 安装检查 | |------|---------|------|---------| | **Node.js** | 18+ | 前端运行环境,包含 npm | `node -v` | -| **Python** | ≥3.11, ≤3.12 | 后端运行环境 | `python --version` | -| **uv** | 最新版 | Python 包管理器 | `uv --version` | +| **Python** | ≥3.11 | 后端运行环境。API 与本地图谱后端支持 Python 3.13。 | `python --version` | +| **Python 3.11** | 3.11.x | 可选但推荐给 OASIS/CAMEL 使用,因为 `camel-oasis==0.2.5` 声明 Python `<3.12`。 | `python3.11 --version` | #### 1. 配置环境变量 @@ -121,10 +121,31 @@ cp .env.example .env LLM_API_KEY=your_api_key LLM_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 LLM_MODEL_NAME=qwen-plus - -# Zep Cloud 配置 -# 每月免费额度即可支撑简单使用:https://app.getzep.com/ -ZEP_API_KEY=your_zep_api_key +LLM_CONTEXT_WINDOW=8192 +LLM_MAX_CONCURRENCY=1 +LLM_JSON_MAX_RETRIES=2 +ONTOLOGY_MAX_OUTPUT_TOKENS=2048 +ONTOLOGY_PROMPT_MARGIN_TOKENS=512 +LOCAL_ZEP_EXTRACT_MAX_OUTPUT_TOKENS=2048 +LOCAL_ZEP_EXTRACT_MAX_RETRIES=2 + +# 本地图谱 / Embeddings 配置 +# 使用 OpenAI-compatible embeddings 接口,推荐优先接 vLLM +EMBEDDING_API_KEY=local-embedding-key +EMBEDDING_BASE_URL=http://localhost:8001/v1 +EMBEDDING_MODEL_NAME=your_embedding_model + +# 可选 cross-encoder reranker 接口;不配置时本地图谱搜索自动回退到 RRF +RERANKER_API_KEY=local-reranker-key +RERANKER_BASE_URL=http://localhost:8002/v1 +RERANKER_MODEL_NAME=your_reranker_model +LOCAL_ZEP_RERANK_TOP_K=50 + +# 可选:本地图谱 SQLite 存储路径 +LOCAL_ZEP_DB_PATH=backend/data/local_zep.sqlite3 + +# 可选:为原版 OASIS/CAMEL 脚本指定独立 Python 3.11 venv +OASIS_PYTHON=/absolute/path/to/venv11/bin/python ``` #### 2. 安装依赖 @@ -144,6 +165,23 @@ npm run setup npm run setup:backend ``` +如果系统默认 `python3` 低于 3.11,可以手动用目标解释器创建后端 venv: + +```bash +python3.13 -m venv venv +./venv/bin/python -m pip install --upgrade pip +./venv/bin/python -m pip install -r backend/requirements.txt +``` + +如果需要运行原版 OASIS/CAMEL 模拟引擎,建议单独创建 Python 3.11 venv,并在 `.env` 中设置 `OASIS_PYTHON`: + +```bash +python3.11 -m venv venv11 +./venv11/bin/python -m pip install --upgrade pip +./venv11/bin/python -m pip install -r backend/requirements.txt +echo "OASIS_PYTHON=$(pwd)/venv11/bin/python" >> .env +``` + #### 3. 启动服务 ```bash @@ -162,8 +200,130 @@ npm run backend # 仅启动后端 npm run frontend # 仅启动前端 ``` +### 本地模型运行与参数 + +MiroFish 不再依赖 Zep Cloud。本地图谱模块会把图数据存入 SQLite,并使用你提供的 OpenAI-compatible embedding 接口。Reranker 是可选项,只在图谱搜索请求 `cross_encoder` 时用于重排候选结果;未配置时会自动回退到本地 hybrid ranking / RRF。 + +本地端点示例: + +```bash +# 主 LLM,llama.cpp 示例 +llama-server \ + -m /path/to/model.gguf \ + --alias Qwen3.5-9B-VL \ + --host 127.0.0.1 \ + --port 8000 \ + -c 16384 \ + --jinja + +# Embeddings,vLLM 示例 +vllm serve /path/to/Qwen3-Embedding-0.6B \ + --host 127.0.0.1 \ + --port 8001 \ + --runner pooling \ + --max-model-len 8192 + +# 可选 reranker,vLLM 示例 +vllm serve /path/to/Qwen3-Reranker-0.6B \ + --host 127.0.0.1 \ + --port 8002 \ + --runner pooling \ + --max-model-len 8192 +``` + +已验证可用于 24GB 显存 GPU(RTX 3090 / RTX 4090 级别)的配置。路径已做匿名化处理: + +```bash +# 主 llama.cpp LLM 端点 +./llama-server \ + --mmproj /path/to/models/Qwen3.5-9B-gguf/mmproj-Qwen3.5-9B-BF16.gguf \ + --alias Qwen3.5-9B-VL \ + --host 127.0.0.1 \ + --port 8000 \ + -c 280000 \ + -ngl auto \ + --temp 0.7 \ + --top-p 0.8 \ + --top-k 20 \ + --min-p 0.0 \ + --jinja \ + -m /path/to/models/Qwen3.5-9B-gguf/Qwen3.5-9B-Q6_K.gguf + +# Embedding 端点,实测峰值约 2-3GB 显存 +vllm serve /path/to/models/embedding/Qwen3-Embedding-0.6B \ + --host 127.0.0.1 \ + --port 8001 \ + --runner pooling \ + --gpu-memory-utilization 0.08 \ + --max-model-len 8192 \ + --quantization fp8 \ + --kv-cache-dtype fp8 + +# Reranker 端点,实测峰值约 2-3GB 显存 +vllm serve /path/to/models/embedding/Qwen3-Reranker-0.6B \ + --host 127.0.0.1 \ + --port 8002 \ + --runner pooling \ + --gpu-memory-utilization 0.09 \ + --max-model-len 8192 \ + --quantization fp8 \ + --kv-cache-dtype fp8 +``` + +对应 `.env`: + +```env +LLM_BASE_URL=http://127.0.0.1:8000/v1 +LLM_MODEL_NAME=Qwen3.5-9B-VL +LLM_CONTEXT_WINDOW=280000 +LLM_MAX_CONCURRENCY=1 +ONTOLOGY_MAX_OUTPUT_TOKENS=8192 +LOCAL_ZEP_EXTRACT_MAX_OUTPUT_TOKENS=8192 + +EMBEDDING_BASE_URL=http://127.0.0.1:8001/v1 +EMBEDDING_MODEL_NAME=/path/to/Qwen3-Embedding-0.6B + +RERANKER_BASE_URL=http://127.0.0.1:8002/v1 +RERANKER_MODEL_NAME=/path/to/Qwen3-Reranker-0.6B +LOCAL_ZEP_RERANK_TOP_K=50 +``` + +参数规则: +- `LLM_CONTEXT_WINDOW` 必须和主 LLM 服务暴露的上下文长度一致,确保 `prompt tokens + max output tokens <= LLM_CONTEXT_WINDOW`。 +- 使用上面的 24GB 配置时,`LLM_CONTEXT_WINDOW=280000` 应与 `-c 280000` 保持一致。16k 上下文模型建议 `ONTOLOGY_MAX_OUTPUT_TOKENS=8192`,大约保留一半窗口给输出;8k 端点建议使用 `2048` 到 `4096`。 +- 单个 llama.cpp/vLLM 主端点建议 `LLM_MAX_CONCURRENCY=1`,除非确认服务端可以稳定并发处理 chat completions。 +- Embedding 维度不需要手动配置。本地图谱会记录模型返回的向量;同一个图谱数据库应保持使用同一个 embedding 模型。 +- Reranker 与 embedding 是不同模型。Embedding 负责图谱索引和语义检索;reranker 只负责对候选搜索结果重排。 +- `.env`、`start.sh`、`venv/`、`venv11/`、上传的模拟数据和本地图谱数据库都会被 git 忽略。 + +### Tailscale 访问 + +发布说明:Tailscale 支持主要面向源码/开发部署,即后端和 Vite 代理运行在同一台主机上。它不会自动公开本地 LLM、embedding 或 reranker 端点;除非你明确需要暴露,否则建议这些模型端点继续绑定在 `127.0.0.1`。 + +前端现在默认走同源 `/api`,不再默认写死 `http://localhost:5001`。因此你从另一台 Tailscale 设备访问 Vite 开发服务器时,前端请求会继续命中宿主机后端,而不会错误地访问调用方自己的 localhost。 + +建议在项目根目录 `.env` 中加入: + +```env +VITE_DEV_HOST=0.0.0.0 +VITE_DEV_PORT=3000 +VITE_DEV_PROXY_TARGET=http://127.0.0.1:5001 +VITE_ALLOWED_HOSTS=localhost,127.0.0.1,.ts.net,.beta.tailscale.net +FLASK_HOST=0.0.0.0 +FLASK_PORT=5001 +``` + +然后从另一台 Tailscale 设备访问: + +- 前端:`http://<你的 tailnet 主机名>:3000` +- 后端:`http://<你的 tailnet 主机名>:5001` + +如果你通过 `tailscale serve` 或 `tailscale funnel` 暴露后端,建议设置 `ENABLE_PROXY_FIX=true`。如果通过 MagicDNS 访问时 HMR 连接不稳定,再额外设置 `VITE_HMR_HOST` 为你的 Tailscale 主机名。 + ### 二、Docker 部署 +发布限制:Docker 部署属于旧路径,目前未针对本地 Zep 替代、外部 vLLM/llama.cpp 端点、以及拆分的 Python 3.11 OASIS 运行时做完整验证。需要本地优先工作流时,请使用源码部署。 + ```bash # 1. 配置环境变量(同源码部署) cp .env.example .env diff --git a/README.md b/README.md index 4b8369f4cf..00feb9f04b 100644 --- a/README.md +++ b/README.md @@ -100,8 +100,8 @@ Click the image to watch MiroFish's deep prediction of the lost ending based on | Tool | Version | Description | Check Installation | |------|---------|-------------|-------------------| | **Node.js** | 18+ | Frontend runtime, includes npm | `node -v` | -| **Python** | ≥3.11, ≤3.12 | Backend runtime | `python --version` | -| **uv** | Latest | Python package manager | `uv --version` | +| **Python** | ≥3.11 | Backend runtime. Python 3.13 is supported for the API and local graph backend. | `python --version` | +| **Python 3.11** | 3.11.x | Optional but recommended for OASIS/CAMEL, because `camel-oasis==0.2.5` declares Python `<3.12`. | `python3.11 --version` | #### 1. Configure Environment Variables @@ -121,10 +121,31 @@ cp .env.example .env LLM_API_KEY=your_api_key LLM_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 LLM_MODEL_NAME=qwen-plus - -# Zep Cloud Configuration -# Free monthly quota is sufficient for simple usage: https://app.getzep.com/ -ZEP_API_KEY=your_zep_api_key +LLM_CONTEXT_WINDOW=8192 +LLM_MAX_CONCURRENCY=1 +LLM_JSON_MAX_RETRIES=2 +ONTOLOGY_MAX_OUTPUT_TOKENS=2048 +ONTOLOGY_PROMPT_MARGIN_TOKENS=512 +LOCAL_ZEP_EXTRACT_MAX_OUTPUT_TOKENS=2048 +LOCAL_ZEP_EXTRACT_MAX_RETRIES=2 + +# Local Graph / Embeddings Configuration +# Use an OpenAI-compatible embeddings endpoint. vLLM is a good first option. +EMBEDDING_API_KEY=local-embedding-key +EMBEDDING_BASE_URL=http://localhost:8001/v1 +EMBEDDING_MODEL_NAME=your_embedding_model + +# Optional cross-encoder reranker endpoint. If omitted, local graph search falls back to RRF. +RERANKER_API_KEY=local-reranker-key +RERANKER_BASE_URL=http://localhost:8002/v1 +RERANKER_MODEL_NAME=your_reranker_model +LOCAL_ZEP_RERANK_TOP_K=50 + +# Optional: choose where the local graph SQLite file is stored +LOCAL_ZEP_DB_PATH=backend/data/local_zep.sqlite3 + +# Optional: use a separate Python 3.11 venv for original OASIS/CAMEL scripts +OASIS_PYTHON=/absolute/path/to/venv11/bin/python ``` #### 2. Install Dependencies @@ -144,6 +165,23 @@ npm run setup npm run setup:backend ``` +If your default `python3` is older than 3.11, create the backend venv manually with the desired interpreter before running the backend: + +```bash +python3.13 -m venv venv +./venv/bin/python -m pip install --upgrade pip +./venv/bin/python -m pip install -r backend/requirements.txt +``` + +For the original OASIS/CAMEL simulation engine, use a Python 3.11 venv and point `OASIS_PYTHON` at it: + +```bash +python3.11 -m venv venv11 +./venv11/bin/python -m pip install --upgrade pip +./venv11/bin/python -m pip install -r backend/requirements.txt +echo "OASIS_PYTHON=$(pwd)/venv11/bin/python" >> .env +``` + #### 3. Start Services ```bash @@ -162,8 +200,130 @@ npm run backend # Start backend only npm run frontend # Start frontend only ``` +### Local Model Runtime and Parameters + +MiroFish no longer requires Zep Cloud. The local graph module stores graph data in SQLite and uses your OpenAI-compatible embedding endpoint. A reranker endpoint is optional and is used only when graph search requests `cross_encoder`; otherwise search falls back to local hybrid ranking/RRF. + +Example local endpoints: + +```bash +# Main LLM endpoint, llama.cpp example +llama-server \ + -m /path/to/model.gguf \ + --alias Qwen3.5-9B-VL \ + --host 127.0.0.1 \ + --port 8000 \ + -c 16384 \ + --jinja + +# Embeddings endpoint, vLLM example +vllm serve /path/to/Qwen3-Embedding-0.6B \ + --host 127.0.0.1 \ + --port 8001 \ + --runner pooling \ + --max-model-len 8192 + +# Optional reranker endpoint, vLLM example +vllm serve /path/to/Qwen3-Reranker-0.6B \ + --host 127.0.0.1 \ + --port 8002 \ + --runner pooling \ + --max-model-len 8192 +``` + +Tested 24GB GPU profile (RTX 3090 / RTX 4090 class). Paths are intentionally anonymized: + +```bash +# Main llama.cpp LLM endpoint +./llama-server \ + --mmproj /path/to/models/Qwen3.5-9B-gguf/mmproj-Qwen3.5-9B-BF16.gguf \ + --alias Qwen3.5-9B-VL \ + --host 127.0.0.1 \ + --port 8000 \ + -c 280000 \ + -ngl auto \ + --temp 0.7 \ + --top-p 0.8 \ + --top-k 20 \ + --min-p 0.0 \ + --jinja \ + -m /path/to/models/Qwen3.5-9B-gguf/Qwen3.5-9B-Q6_K.gguf + +# Embedding endpoint. Observed peak is about 2-3GB VRAM. +vllm serve /path/to/models/embedding/Qwen3-Embedding-0.6B \ + --host 127.0.0.1 \ + --port 8001 \ + --runner pooling \ + --gpu-memory-utilization 0.08 \ + --max-model-len 8192 \ + --quantization fp8 \ + --kv-cache-dtype fp8 + +# Reranker endpoint. Observed peak is about 2-3GB VRAM. +vllm serve /path/to/models/embedding/Qwen3-Reranker-0.6B \ + --host 127.0.0.1 \ + --port 8002 \ + --runner pooling \ + --gpu-memory-utilization 0.09 \ + --max-model-len 8192 \ + --quantization fp8 \ + --kv-cache-dtype fp8 +``` + +Matching `.env` settings: + +```env +LLM_BASE_URL=http://127.0.0.1:8000/v1 +LLM_MODEL_NAME=Qwen3.5-9B-VL +LLM_CONTEXT_WINDOW=280000 +LLM_MAX_CONCURRENCY=1 +ONTOLOGY_MAX_OUTPUT_TOKENS=8192 +LOCAL_ZEP_EXTRACT_MAX_OUTPUT_TOKENS=8192 + +EMBEDDING_BASE_URL=http://127.0.0.1:8001/v1 +EMBEDDING_MODEL_NAME=/path/to/Qwen3-Embedding-0.6B + +RERANKER_BASE_URL=http://127.0.0.1:8002/v1 +RERANKER_MODEL_NAME=/path/to/Qwen3-Reranker-0.6B +LOCAL_ZEP_RERANK_TOP_K=50 +``` + +Parameter rules: +- `LLM_CONTEXT_WINDOW` must match the context length exposed by the main LLM server. Keep `prompt tokens + max output tokens <= LLM_CONTEXT_WINDOW`. +- For the 24GB profile above, set `LLM_CONTEXT_WINDOW=280000` to match `-c 280000`. For a `16k` context model, `ONTOLOGY_MAX_OUTPUT_TOKENS=8192` leaves roughly half the window for output. For an `8k` endpoint, use `2048` to `4096`. +- Set `LLM_MAX_CONCURRENCY=1` for single llama.cpp/vLLM main endpoints unless you know the server can handle parallel chat completions. +- Embedding dimension is not configured manually. The local graph records vectors from the embedding model; keep the same embedding model for one graph database. +- The reranker is different from the embedding model. Embeddings are required for graph indexing and semantic search. The reranker is optional and only reorders candidate search results. +- `.env`, `start.sh`, `venv/`, `venv11/`, uploaded simulations, and local graph databases are intentionally ignored by git. + +### Tailscale Access + +Release note: Tailscale support is intended for source-code/dev deployments where the backend and Vite proxy run on the same host. It does not make the local LLM, embedding, or reranker endpoints public by itself; keep those endpoints bound to `127.0.0.1` unless you intentionally expose them. + +The frontend now defaults to same-origin `/api` instead of `http://localhost:5001`, so opening the Vite dev server from another device over Tailscale will still reach the backend running on the host machine. + +Recommended root `.env` settings: + +```env +VITE_DEV_HOST=0.0.0.0 +VITE_DEV_PORT=3000 +VITE_DEV_PROXY_TARGET=http://127.0.0.1:5001 +VITE_ALLOWED_HOSTS=localhost,127.0.0.1,.ts.net,.beta.tailscale.net +FLASK_HOST=0.0.0.0 +FLASK_PORT=5001 +``` + +Then access from another Tailscale device at: + +- Frontend: `http://:3000` +- Backend: `http://:5001` + +If you expose the backend through `tailscale serve` or `tailscale funnel`, set `ENABLE_PROXY_FIX=true`. If HMR does not reconnect correctly over MagicDNS, set `VITE_HMR_HOST` to your Tailscale hostname. + ### Option 2: Docker Deployment +Release constraint: Docker deployment is legacy/not validated for the local Zep replacement, external vLLM/llama.cpp endpoints, and split Python 3.11 OASIS runtime. Use source deployment for the tested local-first workflow. + ```bash # 1. Configure environment variables (same as source deployment) cp .env.example .env @@ -200,4 +360,4 @@ MiroFish's simulation engine is powered by **[OASIS (Open Agent Social Interacti Star History Chart - \ No newline at end of file + diff --git a/backend/app/__init__.py b/backend/app/__init__.py index aba624bba9..50fad46e2c 100644 --- a/backend/app/__init__.py +++ b/backend/app/__init__.py @@ -20,6 +20,10 @@ def create_app(config_class=Config): """Flask应用工厂函数""" app = Flask(__name__) app.config.from_object(config_class) + + if app.config.get('ENABLE_PROXY_FIX'): + from werkzeug.middleware.proxy_fix import ProxyFix + app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1, x_proto=1, x_host=1, x_port=1) # 设置JSON编码:确保中文直接显示(而不是 \uXXXX 格式) # Flask >= 2.3 使用 app.json.ensure_ascii,旧版本使用 JSON_AS_ASCII 配置 @@ -38,6 +42,8 @@ def create_app(config_class=Config): logger.info("=" * 50) logger.info("MiroFish Backend 启动中...") logger.info("=" * 50) + if app.config.get('ENABLE_PROXY_FIX'): + logger.info("已启用 ProxyFix(适用于 Tailscale Serve/Funnel 等反向代理)") # 启用CORS CORS(app, resources={r"/api/*": {"origins": "*"}}) @@ -77,4 +83,3 @@ def health(): logger.info("MiroFish Backend 启动完成") return app - diff --git a/backend/app/api/graph.py b/backend/app/api/graph.py index 759ff48b0e..2f80668d27 100644 --- a/backend/app/api/graph.py +++ b/backend/app/api/graph.py @@ -283,17 +283,6 @@ def build_graph(): try: logger.info("=== 开始构建图谱 ===") - # 检查配置 - errors = [] - if not Config.ZEP_API_KEY: - errors.append(t('api.zepApiKeyMissing')) - if errors: - logger.error(f"配置错误: {errors}") - return jsonify({ - "success": False, - "error": t('api.configError', details="; ".join(errors)) - }), 500 - # 解析请求 data = request.get_json() or {} project_id = data.get('project_id') @@ -440,7 +429,7 @@ def add_progress_callback(msg, progress_ratio): episode_uuids = builder.add_text_batches( graph_id, chunks, - batch_size=3, + batch_size=1, progress_callback=add_progress_callback ) @@ -572,12 +561,6 @@ def get_graph_data(graph_id: str): 获取图谱数据(节点和边) """ try: - if not Config.ZEP_API_KEY: - return jsonify({ - "success": False, - "error": t('api.zepApiKeyMissing') - }), 500 - builder = GraphBuilderService(api_key=Config.ZEP_API_KEY) graph_data = builder.get_graph_data(graph_id) @@ -600,12 +583,6 @@ def delete_graph(graph_id: str): 删除Zep图谱 """ try: - if not Config.ZEP_API_KEY: - return jsonify({ - "success": False, - "error": t('api.zepApiKeyMissing') - }), 500 - builder = GraphBuilderService(api_key=Config.ZEP_API_KEY) builder.delete_graph(graph_id) diff --git a/backend/app/api/simulation.py b/backend/app/api/simulation.py index 3a8e1e3fc8..1c5665f4bc 100644 --- a/backend/app/api/simulation.py +++ b/backend/app/api/simulation.py @@ -57,12 +57,6 @@ def get_graph_entities(graph_id: str): enrich: 是否获取相关边信息(默认true) """ try: - if not Config.ZEP_API_KEY: - return jsonify({ - "success": False, - "error": t('api.zepApiKeyMissing') - }), 500 - entity_types_str = request.args.get('entity_types', '') entity_types = [t.strip() for t in entity_types_str.split(',') if t.strip()] if entity_types_str else None enrich = request.args.get('enrich', 'true').lower() == 'true' @@ -94,12 +88,6 @@ def get_graph_entities(graph_id: str): def get_entity_detail(graph_id: str, entity_uuid: str): """获取单个实体的详细信息""" try: - if not Config.ZEP_API_KEY: - return jsonify({ - "success": False, - "error": t('api.zepApiKeyMissing') - }), 500 - reader = ZepEntityReader() entity = reader.get_entity_with_context(graph_id, entity_uuid) @@ -127,12 +115,6 @@ def get_entity_detail(graph_id: str, entity_uuid: str): def get_entities_by_type(graph_id: str, entity_type: str): """获取指定类型的所有实体""" try: - if not Config.ZEP_API_KEY: - return jsonify({ - "success": False, - "error": t('api.zepApiKeyMissing') - }), 500 - enrich = request.args.get('enrich', 'true').lower() == 'true' reader = ZepEntityReader() @@ -466,7 +448,7 @@ def prepare_simulation(): entity_types_list = data.get('entity_types') use_llm_for_profiles = data.get('use_llm_for_profiles', True) - parallel_profile_count = data.get('parallel_profile_count', 5) + parallel_profile_count = data.get('parallel_profile_count', 1) # ========== 同步获取实体数量(在后台任务启动前) ========== # 这样前端在调用prepare后立即就能获取到预期Agent总数 diff --git a/backend/app/config.py b/backend/app/config.py index 953dfa50a2..d8964588a0 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -31,9 +31,31 @@ class Config: LLM_API_KEY = os.environ.get('LLM_API_KEY') LLM_BASE_URL = os.environ.get('LLM_BASE_URL', 'https://api.openai.com/v1') LLM_MODEL_NAME = os.environ.get('LLM_MODEL_NAME', 'gpt-4o-mini') + LLM_CONTEXT_WINDOW = int(os.environ.get('LLM_CONTEXT_WINDOW', '8192')) + LLM_MAX_CONCURRENCY = int(os.environ.get('LLM_MAX_CONCURRENCY', '1')) + LLM_JSON_MAX_RETRIES = int(os.environ.get('LLM_JSON_MAX_RETRIES', '2')) + ONTOLOGY_MAX_OUTPUT_TOKENS = int(os.environ.get('ONTOLOGY_MAX_OUTPUT_TOKENS', '2048')) + ONTOLOGY_PROMPT_MARGIN_TOKENS = int(os.environ.get('ONTOLOGY_PROMPT_MARGIN_TOKENS', '512')) + ONTOLOGY_MAX_CHUNKS = int(os.environ.get('ONTOLOGY_MAX_CHUNKS', '8')) + LOCAL_ZEP_EXTRACT_MAX_RETRIES = int(os.environ.get('LOCAL_ZEP_EXTRACT_MAX_RETRIES', '2')) + LOCAL_ZEP_EXTRACT_MAX_OUTPUT_TOKENS = int(os.environ.get('LOCAL_ZEP_EXTRACT_MAX_OUTPUT_TOKENS', '2048')) - # Zep配置 - ZEP_API_KEY = os.environ.get('ZEP_API_KEY') + # Local graph / embeddings 配置 + ZEP_API_KEY = os.environ.get('ZEP_API_KEY') # deprecated, ignored by the local graph backend + EMBEDDING_API_KEY = os.environ.get('EMBEDDING_API_KEY', 'local-embedding-key') + EMBEDDING_BASE_URL = os.environ.get('EMBEDDING_BASE_URL') + EMBEDDING_MODEL_NAME = os.environ.get('EMBEDDING_MODEL_NAME') + RERANKER_API_KEY = os.environ.get('RERANKER_API_KEY', 'local-reranker-key') + RERANKER_BASE_URL = os.environ.get('RERANKER_BASE_URL') + RERANKER_MODEL_NAME = os.environ.get('RERANKER_MODEL_NAME') + LOCAL_ZEP_RERANK_TOP_K = int(os.environ.get('LOCAL_ZEP_RERANK_TOP_K', '50')) + LOCAL_ZEP_DB_PATH = os.environ.get( + 'LOCAL_ZEP_DB_PATH', + os.path.join(os.path.dirname(__file__), '../data/local_zep.sqlite3') + ) + PUBLIC_BASE_URL = os.environ.get('PUBLIC_BASE_URL') + TAILSCALE_URL = os.environ.get('TAILSCALE_URL') + ENABLE_PROXY_FIX = os.environ.get('ENABLE_PROXY_FIX', 'False').lower() == 'true' # 文件上传配置 MAX_CONTENT_LENGTH = 50 * 1024 * 1024 # 50MB @@ -47,6 +69,7 @@ class Config: # OASIS模拟配置 OASIS_DEFAULT_MAX_ROUNDS = int(os.environ.get('OASIS_DEFAULT_MAX_ROUNDS', '10')) OASIS_SIMULATION_DATA_DIR = os.path.join(os.path.dirname(__file__), '../uploads/simulations') + OASIS_PYTHON = os.environ.get('OASIS_PYTHON') # OASIS平台可用动作配置 OASIS_TWITTER_ACTIONS = [ @@ -69,7 +92,8 @@ def validate(cls): errors = [] if not cls.LLM_API_KEY: errors.append("LLM_API_KEY 未配置") - if not cls.ZEP_API_KEY: - errors.append("ZEP_API_KEY 未配置") + if not cls.EMBEDDING_BASE_URL: + errors.append("EMBEDDING_BASE_URL 未配置") + if not cls.EMBEDDING_MODEL_NAME: + errors.append("EMBEDDING_MODEL_NAME 未配置") return errors - diff --git a/backend/app/services/graph_builder.py b/backend/app/services/graph_builder.py index 37c9969c79..50b118b7ab 100644 --- a/backend/app/services/graph_builder.py +++ b/backend/app/services/graph_builder.py @@ -45,9 +45,6 @@ class GraphBuilderService: def __init__(self, api_key: Optional[str] = None): self.api_key = api_key or Config.ZEP_API_KEY - if not self.api_key: - raise ValueError("ZEP_API_KEY 未配置") - self.client = Zep(api_key=self.api_key) self.task_manager = TaskManager() @@ -503,4 +500,3 @@ def get_graph_data(self, graph_id: str) -> Dict[str, Any]: def delete_graph(self, graph_id: str): """删除图谱""" self.client.graph.delete(graph_id=graph_id) - diff --git a/backend/app/services/oasis_profile_generator.py b/backend/app/services/oasis_profile_generator.py index 7704a627eb..89b69963ff 100644 --- a/backend/app/services/oasis_profile_generator.py +++ b/backend/app/services/oasis_profile_generator.py @@ -20,6 +20,7 @@ from ..config import Config from ..utils.logger import get_logger +from ..utils.llm_gate import main_llm_slot from ..utils.locale import get_language_instruction, get_locale, set_locale, t from .zep_entity_reader import EntityNode, ZepEntityReader @@ -203,11 +204,10 @@ def __init__( self.zep_client = None self.graph_id = graph_id - if self.zep_api_key: - try: - self.zep_client = Zep(api_key=self.zep_api_key) - except Exception as e: - logger.warning(f"Zep客户端初始化失败: {e}") + try: + self.zep_client = Zep(api_key=self.zep_api_key) + except Exception as e: + logger.warning(f"Zep客户端初始化失败: {e}") def generate_profile_from_entity( self, @@ -527,16 +527,17 @@ def _generate_profile_with_llm( for attempt in range(max_attempts): try: - response = self.client.chat.completions.create( - model=self.model_name, - messages=[ - {"role": "system", "content": self._get_system_prompt(is_individual)}, - {"role": "user", "content": prompt} - ], - response_format={"type": "json_object"}, - temperature=0.7 - (attempt * 0.1) # 每次重试降低温度 - # 不设置max_tokens,让LLM自由发挥 - ) + with main_llm_slot(): + response = self.client.chat.completions.create( + model=self.model_name, + messages=[ + {"role": "system", "content": self._get_system_prompt(is_individual)}, + {"role": "user", "content": prompt} + ], + response_format={"type": "json_object"}, + temperature=0.7 - (attempt * 0.1) # 每次重试降低温度 + # 不设置max_tokens,让LLM自由发挥 + ) content = response.choices[0].message.content @@ -950,6 +951,7 @@ def generate_single_profile(idx: int, entity: EntityNode) -> tuple: ) return idx, fallback_profile, str(e) + parallel_count = max(1, min(parallel_count, Config.LLM_MAX_CONCURRENCY)) logger.info(f"开始并行生成 {total} 个Agent人设(并行数: {parallel_count})...") print(f"\n{'='*60}") print(f"开始生成Agent人设 - 共 {total} 个实体,并行数: {parallel_count}") @@ -1202,4 +1204,3 @@ def save_profiles_to_json( """[已废弃] 请使用 save_profiles() 方法""" logger.warning("save_profiles_to_json已废弃,请使用save_profiles方法") self.save_profiles(profiles, file_path, platform) - diff --git a/backend/app/services/ontology_generator.py b/backend/app/services/ontology_generator.py index 01a3d799a5..7042ea3138 100644 --- a/backend/app/services/ontology_generator.py +++ b/backend/app/services/ontology_generator.py @@ -4,13 +4,24 @@ """ import json -import logging +import math import re +import traceback from typing import Dict, Any, List, Optional from ..utils.llm_client import LLMClient from ..utils.locale import get_language_instruction +from ..utils.logger import get_logger +from ..config import Config -logger = logging.getLogger(__name__) +logger = get_logger('mirofish.ontology') + + +def _estimate_tokens(text: str) -> int: + """Conservative token estimate for local vLLM context budgeting.""" + text = text or "" + cjk_chars = len(re.findall(r'[\u3400-\u9fff\uf900-\ufaff]', text)) + non_cjk_chars = len(text) - cjk_chars + return cjk_chars + math.ceil(non_cjk_chars / 4) def _to_pascal_case(name: str) -> str: @@ -199,51 +210,499 @@ def generate( Returns: 本体定义(entity_types, edge_types等) """ - # 构建用户消息 - user_message = self._build_user_message( - document_texts, - simulation_requirement, - additional_context - ) - lang_instruction = get_language_instruction() system_prompt = f"{ONTOLOGY_SYSTEM_PROMPT}\n\n{lang_instruction}\nIMPORTANT: Entity type names MUST be in English PascalCase (e.g., 'PersonEntity', 'MediaOrganization'). Relationship type names MUST be in English UPPER_SNAKE_CASE (e.g., 'WORKS_FOR'). Attribute names MUST be in English snake_case. Only description fields and analysis_summary should use the specified language above." - messages = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_message} - ] - - # 调用LLM - result = self.llm_client.chat_json( - messages=messages, - temperature=0.3, - max_tokens=4096 + + chunks = self._build_document_chunks( + document_texts=document_texts, + simulation_requirement=simulation_requirement, + additional_context=additional_context, + system_prompt=system_prompt, ) + fallback_ontology = self._document_aware_fallback(document_texts, simulation_requirement) + logger.info("Ontology generation split into %s LLM chunk(s)", len(chunks)) + + partial_results = [] + for index, chunk in enumerate(chunks, start=1): + user_message = self._build_user_message( + [chunk], + simulation_requirement, + additional_context, + chunk_index=index, + chunk_count=len(chunks), + ) + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_message} + ] + + try: + raw_result = self.llm_client.chat_json( + messages=messages, + temperature=0.3, + max_tokens=Config.ONTOLOGY_MAX_OUTPUT_TOKENS, + max_retries=Config.LLM_JSON_MAX_RETRIES, + ) + result = self._coerce_ontology_result(raw_result) + if self._has_usable_ontology(result): + partial_results.append(result) + else: + logger.warning( + "Ontology LLM chunk %s/%s returned JSON without usable entity_types/edge_types: %s", + index, + len(chunks), + str(raw_result)[:1000], + ) + except Exception as exc: + logger.error("Ontology LLM chunk %s/%s failed: %s", index, len(chunks), exc) + logger.debug(traceback.format_exc()) + + if partial_results: + result = self._merge_ontologies(partial_results, simulation_requirement) + else: + logger.error("All ontology LLM chunks failed, using document-aware fallback ontology") + result = fallback_ontology # 验证和后处理 - result = self._validate_and_process(result) + result = self._validate_and_process(result, fill_ontology=fallback_ontology) return result + + def _coerce_ontology_result(self, result: Dict[str, Any]) -> Dict[str, Any]: + """Accept common local-LLM schema variants and normalize to our shape.""" + if not isinstance(result, dict): + return {} + + for wrapper_key in ("ontology", "data", "result"): + wrapped = result.get(wrapper_key) + if isinstance(wrapped, dict): + result = wrapped + break + + normalized = dict(result) + if "entity_types" not in normalized: + for key in ("entities", "entityTypes", "node_types", "nodeTypes", "nodes"): + if isinstance(normalized.get(key), list): + normalized["entity_types"] = normalized[key] + break + + if "edge_types" not in normalized: + for key in ("relationships", "relations", "edges", "edgeTypes", "relation_types", "relationship_types"): + if isinstance(normalized.get(key), list): + normalized["edge_types"] = normalized[key] + break + + normalized["entity_types"] = [ + self._coerce_entity_type(item) + for item in normalized.get("entity_types", []) + if isinstance(item, dict) + ] + normalized["edge_types"] = [ + self._coerce_edge_type(item) + for item in normalized.get("edge_types", []) + if isinstance(item, dict) + ] + return normalized + + def _coerce_entity_type(self, item: Dict[str, Any]) -> Dict[str, Any]: + name = item.get("name") or item.get("type") or item.get("label") + attributes = item.get("attributes") if isinstance(item.get("attributes"), list) else [] + return { + "name": name, + "description": item.get("description") or item.get("summary") or f"{name} entity.", + "attributes": attributes, + "examples": item.get("examples") if isinstance(item.get("examples"), list) else [], + } + + def _coerce_edge_type(self, item: Dict[str, Any]) -> Dict[str, Any]: + name = item.get("name") or item.get("type") or item.get("relation") + source_targets = item.get("source_targets") + if not isinstance(source_targets, list): + source = item.get("source") or item.get("source_type") + target = item.get("target") or item.get("target_type") + source_targets = [{"source": source or "Person", "target": target or "Organization"}] + return { + "name": name, + "description": item.get("description") or item.get("summary") or f"{name} relation.", + "source_targets": source_targets, + "attributes": item.get("attributes") if isinstance(item.get("attributes"), list) else [], + } + + def _has_usable_ontology(self, result: Dict[str, Any]) -> bool: + return bool(result.get("entity_types")) and bool(result.get("edge_types")) + + def _document_aware_fallback( + self, + document_texts: List[str], + simulation_requirement: str, + ) -> Dict[str, Any]: + corpus = "\n".join(document_texts or [])[:50000] + lower = corpus.lower() + fate_markers = [ + "fate/grand order", + "lostbelt", + "阿瓦隆", + "勒", + "妖精", + "摩根", + "迦勒底", + "不列顛", + "奧伯龍", + "科爾努諾斯", + ] + if any(marker in lower or marker in corpus for marker in fate_markers): + return { + "entity_types": [ + { + "name": "FictionalCharacter", + "description": "Named story character or role in the Lostbelt narrative.", + "attributes": [{"name": "role", "type": "text", "description": "Narrative role"}], + "examples": ["Morgan", "Artoria Caster"], + }, + { + "name": "Faction", + "description": "Political, military, or social group in the setting.", + "attributes": [{"name": "alignment", "type": "text", "description": "Faction alignment"}], + "examples": ["Chaldea", "Round Table"], + }, + { + "name": "FairyClan", + "description": "Fairy clan or species group in Britain.", + "attributes": [{"name": "clan_role", "type": "text", "description": "Clan role"}], + "examples": ["Wind clan", "Fang clan"], + }, + { + "name": "Kingdom", + "description": "Realm, court, or governing power.", + "attributes": [{"name": "ruler", "type": "text", "description": "Known ruler"}], + "examples": ["Camelot", "Fairy Britain"], + }, + { + "name": "Deity", + "description": "Godlike or mythic entity affecting events.", + "attributes": [{"name": "domain", "type": "text", "description": "Mythic domain"}], + "examples": ["Cernunnos"], + }, + { + "name": "Location", + "description": "Named place or region in the chronology.", + "attributes": [{"name": "region_type", "type": "text", "description": "Place category"}], + "examples": ["Avalon", "Britain"], + }, + { + "name": "NarrativeEvent", + "description": "Major battle, calamity, or turning point.", + "attributes": [{"name": "era", "type": "text", "description": "Era or time marker"}], + "examples": ["Great Calamity", "Queen Morgan battle"], + }, + { + "name": "SourceMaterial", + "description": "Official or community source cited by the report.", + "attributes": [{"name": "source_type", "type": "text", "description": "Source category"}], + "examples": ["Road to 7", "official soundtrack"], + }, + { + "name": "Person", + "description": "Any individual not fitting other specific person types.", + "attributes": [{"name": "full_name", "type": "text", "description": "Full name"}], + "examples": ["writer", "commentator"], + }, + { + "name": "Organization", + "description": "Any organization not fitting other specific organization types.", + "attributes": [{"name": "org_name", "type": "text", "description": "Organization name"}], + "examples": ["publisher", "studio"], + }, + ], + "edge_types": [ + { + "name": "APPEARS_IN", + "description": "Entity appears in a source, era, or event.", + "source_targets": [{"source": "FictionalCharacter", "target": "NarrativeEvent"}], + "attributes": [], + }, + { + "name": "RULES", + "description": "Character or power rules a realm or group.", + "source_targets": [{"source": "FictionalCharacter", "target": "Kingdom"}], + "attributes": [], + }, + { + "name": "ALLIED_WITH", + "description": "Entity is allied or cooperating with another.", + "source_targets": [{"source": "Faction", "target": "Faction"}, {"source": "FictionalCharacter", "target": "Faction"}], + "attributes": [], + }, + { + "name": "OPPOSES", + "description": "Entity opposes another entity or faction.", + "source_targets": [{"source": "FictionalCharacter", "target": "FictionalCharacter"}, {"source": "Faction", "target": "Faction"}], + "attributes": [], + }, + { + "name": "LOCATED_IN", + "description": "Entity or event is located in a place.", + "source_targets": [{"source": "NarrativeEvent", "target": "Location"}, {"source": "Kingdom", "target": "Location"}], + "attributes": [], + }, + { + "name": "CAUSES", + "description": "Entity or event causes another event.", + "source_targets": [{"source": "NarrativeEvent", "target": "NarrativeEvent"}, {"source": "Deity", "target": "NarrativeEvent"}], + "attributes": [], + }, + { + "name": "DOCUMENTS", + "description": "Source material documents an entity or event.", + "source_targets": [{"source": "SourceMaterial", "target": "NarrativeEvent"}, {"source": "SourceMaterial", "target": "FictionalCharacter"}], + "attributes": [], + }, + { + "name": "TRANSFORMS_INTO", + "description": "Entity changes form, role, or state.", + "source_targets": [{"source": "FictionalCharacter", "target": "FictionalCharacter"}, {"source": "NarrativeEvent", "target": "NarrativeEvent"}], + "attributes": [], + }, + ], + "analysis_summary": f"Document-aware fallback ontology generated for Fate/Lostbelt content: {simulation_requirement[:200]}", + } + + return self._fallback_ontology(simulation_requirement) + + def _fallback_ontology(self, simulation_requirement: str) -> Dict[str, Any]: + """Deterministic ontology used when a local LLM fails JSON generation.""" + return { + "entity_types": [ + { + "name": "Journalist", + "description": "Reporter or editor participating in public discourse.", + "attributes": [{"name": "role", "type": "text", "description": "Media role"}], + "examples": ["reporter", "editor"], + }, + { + "name": "MediaOutlet", + "description": "Media organization publishing news or commentary.", + "attributes": [{"name": "org_name", "type": "text", "description": "Outlet name"}], + "examples": ["newspaper", "online media"], + }, + { + "name": "Company", + "description": "Business organization involved in the issue.", + "attributes": [{"name": "industry", "type": "text", "description": "Industry"}], + "examples": ["company", "platform"], + }, + { + "name": "GovernmentAgency", + "description": "Government or regulator relevant to the event.", + "attributes": [{"name": "jurisdiction", "type": "text", "description": "Jurisdiction"}], + "examples": ["regulator", "department"], + }, + { + "name": "Official", + "description": "Public official or authority figure.", + "attributes": [{"name": "title", "type": "text", "description": "Official title"}], + "examples": ["mayor", "spokesperson"], + }, + { + "name": "Expert", + "description": "Analyst, scholar, or professional commentator.", + "attributes": [{"name": "specialty", "type": "text", "description": "Expertise"}], + "examples": ["researcher", "lawyer"], + }, + { + "name": "CommunityGroup", + "description": "Grassroots group or collective actor.", + "attributes": [{"name": "focus", "type": "text", "description": "Group focus"}], + "examples": ["local group", "advocacy group"], + }, + { + "name": "Influencer", + "description": "Online personality with audience influence.", + "attributes": [{"name": "platform", "type": "text", "description": "Main platform"}], + "examples": ["blogger", "creator"], + }, + { + "name": "Person", + "description": "Any individual person not fitting specific person types.", + "attributes": [{"name": "full_name", "type": "text", "description": "Full name"}], + "examples": ["ordinary citizen", "witness"], + }, + { + "name": "Organization", + "description": "Any organization not fitting specific organization types.", + "attributes": [{"name": "org_name", "type": "text", "description": "Organization name"}], + "examples": ["association", "small organization"], + }, + ], + "edge_types": [ + { + "name": "WORKS_FOR", + "description": "Employment or affiliation relationship.", + "source_targets": [{"source": "Person", "target": "Organization"}, {"source": "Journalist", "target": "MediaOutlet"}], + "attributes": [], + }, + { + "name": "REPORTS_ON", + "description": "Publishes or reports about an actor.", + "source_targets": [{"source": "MediaOutlet", "target": "Organization"}, {"source": "Journalist", "target": "Person"}], + "attributes": [], + }, + { + "name": "RESPONDS_TO", + "description": "Publicly responds to another actor.", + "source_targets": [{"source": "Person", "target": "Person"}, {"source": "Organization", "target": "Organization"}], + "attributes": [], + }, + { + "name": "SUPPORTS", + "description": "Expresses support for another actor.", + "source_targets": [{"source": "Person", "target": "Organization"}, {"source": "Organization", "target": "Person"}], + "attributes": [], + }, + { + "name": "OPPOSES", + "description": "Expresses opposition to another actor.", + "source_targets": [{"source": "Person", "target": "Organization"}, {"source": "Organization", "target": "Person"}], + "attributes": [], + }, + { + "name": "COLLABORATES_WITH", + "description": "Cooperates with another actor.", + "source_targets": [{"source": "Organization", "target": "Organization"}, {"source": "Person", "target": "Person"}], + "attributes": [], + }, + { + "name": "INFLUENCES", + "description": "Influences opinions or decisions.", + "source_targets": [{"source": "Influencer", "target": "Person"}, {"source": "MediaOutlet", "target": "Person"}], + "attributes": [], + }, + { + "name": "REGULATES", + "description": "Regulatory or oversight relation.", + "source_targets": [{"source": "GovernmentAgency", "target": "Company"}, {"source": "Official", "target": "Organization"}], + "attributes": [], + }, + ], + "analysis_summary": f"Fallback ontology generated for: {simulation_requirement[:200]}", + } - # 传给 LLM 的文本最大长度(5万字) - MAX_TEXT_LENGTH_FOR_LLM = 50000 + def _context_input_budget( + self, + system_prompt: str, + simulation_requirement: str, + additional_context: Optional[str], + ) -> int: + empty_user_message = self._build_user_message( + [""], + simulation_requirement, + additional_context, + chunk_index=1, + chunk_count=1, + ) + reserved = ( + _estimate_tokens(system_prompt) + + _estimate_tokens(empty_user_message) + + Config.ONTOLOGY_MAX_OUTPUT_TOKENS + + Config.ONTOLOGY_PROMPT_MARGIN_TOKENS + ) + return max(512, Config.LLM_CONTEXT_WINDOW - reserved) + + def _build_document_chunks( + self, + document_texts: List[str], + simulation_requirement: str, + additional_context: Optional[str], + system_prompt: str, + ) -> List[str]: + budget = self._context_input_budget(system_prompt, simulation_requirement, additional_context) + chunks: List[str] = [] + + for text in document_texts: + normalized = text or "" + current_parts: List[str] = [] + current_tokens = 0 + for part in self._iter_text_parts(normalized, budget): + part_tokens = _estimate_tokens(part) + if current_parts and current_tokens + part_tokens > budget: + chunks.append("\n\n".join(current_parts)) + current_parts = [] + current_tokens = 0 + current_parts.append(part) + current_tokens += part_tokens + if current_parts: + chunks.append("\n\n".join(current_parts)) + + if not chunks: + return [""] + + max_chunks = max(1, Config.ONTOLOGY_MAX_CHUNKS) + if len(chunks) > max_chunks: + logger.warning( + "Ontology input produced %s chunks; keeping first %s to avoid endpoint overload", + len(chunks), + max_chunks, + ) + chunks = chunks[:max_chunks] + + return chunks + + def _iter_text_parts(self, text: str, budget: int): + paragraphs = [part.strip() for part in re.split(r'\n{2,}', text) if part.strip()] + if not paragraphs: + paragraphs = [text.strip()] if text.strip() else [""] + + for paragraph in paragraphs: + if _estimate_tokens(paragraph) <= budget: + yield paragraph + continue + + sentences = [part.strip() for part in re.split(r'(?<=[。!?.!?])\s*', paragraph) if part.strip()] + if not sentences: + sentences = [paragraph] + + current = "" + for sentence in sentences: + if _estimate_tokens(sentence) > budget: + if current: + yield current + current = "" + yield from self._split_oversize_text(sentence, budget) + continue + candidate = f"{current}\n{sentence}".strip() if current else sentence + if current and _estimate_tokens(candidate) > budget: + yield current + current = sentence + else: + current = candidate + if current: + yield current + + def _split_oversize_text(self, text: str, budget: int): + # Conservative char window: CJK can be close to one token per char. + window = max(800, min(len(text), budget * 2)) + start = 0 + while start < len(text): + end = min(len(text), start + window) + chunk = text[start:end] + while _estimate_tokens(chunk) > budget and len(chunk) > 500: + end = start + max(500, (end - start) // 2) + chunk = text[start:end] + yield chunk + start = end def _build_user_message( self, document_texts: List[str], simulation_requirement: str, - additional_context: Optional[str] + additional_context: Optional[str], + chunk_index: int = 1, + chunk_count: int = 1, ) -> str: """构建用户消息""" # 合并文本 combined_text = "\n\n---\n\n".join(document_texts) - original_length = len(combined_text) - - # 如果文本超过5万字,截断(仅影响传给LLM的内容,不影响图谱构建) - if len(combined_text) > self.MAX_TEXT_LENGTH_FOR_LLM: - combined_text = combined_text[:self.MAX_TEXT_LENGTH_FOR_LLM] - combined_text += f"\n\n...(原文共{original_length}字,已截取前{self.MAX_TEXT_LENGTH_FOR_LLM}字用于本体分析)..." message = f"""## 模拟需求 @@ -271,10 +730,55 @@ def _build_user_message( 4. 所有实体类型必须是现实中可以发声的主体,不能是抽象概念 5. 属性名不能使用 name、uuid、group_id 等保留字,用 full_name、org_name 等替代 """ + if chunk_count > 1: + message += f"\n当前是第 {chunk_index}/{chunk_count} 个文本分片。请只基于当前分片生成候选本体,后续系统会合并去重。\n" return message + + def _merge_ontologies(self, results: List[Dict[str, Any]], simulation_requirement: str) -> Dict[str, Any]: + merged = { + "entity_types": [], + "edge_types": [], + "analysis_summary": "", + } + seen_entities = set() + seen_edges = set() + summaries = [] + + for result in results: + for entity in result.get("entity_types", []): + name = _to_pascal_case(str(entity.get("name", ""))) + if not name or name in seen_entities: + continue + seen_entities.add(name) + entity = dict(entity) + entity["name"] = name + merged["entity_types"].append(entity) + + for edge in result.get("edge_types", []): + name = str(edge.get("name", "")).upper() + if not name or name in seen_edges: + continue + seen_edges.add(name) + edge = dict(edge) + edge["name"] = name + merged["edge_types"].append(edge) + + summary = str(result.get("analysis_summary", "")).strip() + if summary: + summaries.append(summary) + + if not merged["entity_types"] or not merged["edge_types"]: + return self._fallback_ontology(simulation_requirement) + + merged["analysis_summary"] = " ".join(summaries[:3]) or f"Ontology generated for: {simulation_requirement[:200]}" + return merged - def _validate_and_process(self, result: Dict[str, Any]) -> Dict[str, Any]: + def _validate_and_process( + self, + result: Dict[str, Any], + fill_ontology: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: """验证和后处理结果""" # 确保必要字段存在 @@ -387,6 +891,21 @@ def _validate_and_process(self, result: Dict[str, Any]) -> Dict[str, Any]: # 添加兜底类型 result["entity_types"].extend(fallbacks_to_add) + + # Local LLMs sometimes return too few types. Fill from deterministic + # social-simulation defaults so downstream graph setup always has a + # usable ontology instead of failing later. + fill_ontology = fill_ontology or self._fallback_ontology("") + + if len(result["entity_types"]) < MAX_ENTITY_TYPES: + entity_names = {e["name"] for e in result["entity_types"]} + for fallback_entity in fill_ontology.get("entity_types", []): + if len(result["entity_types"]) >= MAX_ENTITY_TYPES: + break + if fallback_entity["name"] in entity_names: + continue + result["entity_types"].append(fallback_entity) + entity_names.add(fallback_entity["name"]) # 最终确保不超过限制(防御性编程) if len(result["entity_types"]) > MAX_ENTITY_TYPES: @@ -394,6 +913,16 @@ def _validate_and_process(self, result: Dict[str, Any]) -> Dict[str, Any]: if len(result["edge_types"]) > MAX_EDGE_TYPES: result["edge_types"] = result["edge_types"][:MAX_EDGE_TYPES] + + if len(result["edge_types"]) < 6: + edge_names = {edge.get("name") for edge in result["edge_types"]} + for fallback_edge in fill_ontology.get("edge_types", []): + if len(result["edge_types"]) >= 6: + break + if fallback_edge["name"] in edge_names: + continue + result["edge_types"].append(fallback_edge) + edge_names.add(fallback_edge["name"]) return result @@ -503,4 +1032,3 @@ def generate_python_code(self, ontology: Dict[str, Any]) -> str: code_lines.append('}') return '\n'.join(code_lines) - diff --git a/backend/app/services/simulation_config_generator.py b/backend/app/services/simulation_config_generator.py index cb77f6b6cd..5f5fdb111c 100644 --- a/backend/app/services/simulation_config_generator.py +++ b/backend/app/services/simulation_config_generator.py @@ -20,6 +20,7 @@ from ..config import Config from ..utils.logger import get_logger +from ..utils.llm_gate import main_llm_slot from ..utils.locale import get_language_instruction, t from .zep_entity_reader import EntityNode, ZepEntityReader @@ -440,16 +441,17 @@ def _call_llm_with_retry(self, prompt: str, system_prompt: str) -> Dict[str, Any for attempt in range(max_attempts): try: - response = self.client.chat.completions.create( - model=self.model_name, - messages=[ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": prompt} - ], - response_format={"type": "json_object"}, - temperature=0.7 - (attempt * 0.1) # 每次重试降低温度 - # 不设置max_tokens,让LLM自由发挥 - ) + with main_llm_slot(): + response = self.client.chat.completions.create( + model=self.model_name, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt} + ], + response_format={"type": "json_object"}, + temperature=0.7 - (attempt * 0.1) # 每次重试降低温度 + # 不设置max_tokens,让LLM自由发挥 + ) content = response.choices[0].message.content finish_reason = response.choices[0].finish_reason @@ -988,4 +990,3 @@ def _generate_agent_config_by_rule(self, entity: EntityNode) -> Dict[str, Any]: "influence_weight": 1.0 } - diff --git a/backend/app/services/simulation_manager.py b/backend/app/services/simulation_manager.py index 0d161a9095..73fde73453 100644 --- a/backend/app/services/simulation_manager.py +++ b/backend/app/services/simulation_manager.py @@ -235,7 +235,7 @@ def prepare_simulation( defined_entity_types: Optional[List[str]] = None, use_llm_for_profiles: bool = True, progress_callback: Optional[callable] = None, - parallel_profile_count: int = 3 + parallel_profile_count: int = 1 ) -> SimulationState: """ 准备模拟环境(全程自动化) diff --git a/backend/app/services/simulation_runner.py b/backend/app/services/simulation_runner.py index e86021f808..f9eeb18426 100644 --- a/backend/app/services/simulation_runner.py +++ b/backend/app/services/simulation_runner.py @@ -414,7 +414,7 @@ def start_simulation( # simulation.log - 主进程日志 cmd = [ - sys.executable, # Python解释器 + Config.OASIS_PYTHON or sys.executable, # Python解释器 script_path, "--config", config_path, # 使用完整配置文件路径 ] @@ -1765,4 +1765,3 @@ def get_interview_history( results = results[:limit] return results - diff --git a/backend/app/services/zep_entity_reader.py b/backend/app/services/zep_entity_reader.py index 71661be499..25172f4110 100644 --- a/backend/app/services/zep_entity_reader.py +++ b/backend/app/services/zep_entity_reader.py @@ -80,9 +80,6 @@ class ZepEntityReader: def __init__(self, api_key: Optional[str] = None): self.api_key = api_key or Config.ZEP_API_KEY - if not self.api_key: - raise ValueError("ZEP_API_KEY 未配置") - self.client = Zep(api_key=self.api_key) def _call_with_retry( @@ -434,4 +431,3 @@ def get_entities_by_type( ) return result.entities - diff --git a/backend/app/services/zep_graph_memory_updater.py b/backend/app/services/zep_graph_memory_updater.py index e034fee2b2..5fe4a001c4 100644 --- a/backend/app/services/zep_graph_memory_updater.py +++ b/backend/app/services/zep_graph_memory_updater.py @@ -239,10 +239,6 @@ def __init__(self, graph_id: str, api_key: Optional[str] = None): """ self.graph_id = graph_id self.api_key = api_key or Config.ZEP_API_KEY - - if not self.api_key: - raise ValueError("ZEP_API_KEY未配置") - self.client = Zep(api_key=self.api_key) # 活动队列 diff --git a/backend/app/services/zep_tools.py b/backend/app/services/zep_tools.py index 3bc8a57abb..78278d8083 100644 --- a/backend/app/services/zep_tools.py +++ b/backend/app/services/zep_tools.py @@ -424,9 +424,6 @@ class ZepToolsService: def __init__(self, api_key: Optional[str] = None, llm_client: Optional[LLMClient] = None): self.api_key = api_key or Config.ZEP_API_KEY - if not self.api_key: - raise ValueError("ZEP_API_KEY 未配置") - self.client = Zep(api_key=self.api_key) # LLM客户端用于InsightForge生成子问题 self._llm_client = llm_client diff --git a/backend/app/utils/llm_client.py b/backend/app/utils/llm_client.py index 6c1a81f49b..05ec6ef573 100644 --- a/backend/app/utils/llm_client.py +++ b/backend/app/utils/llm_client.py @@ -5,10 +5,15 @@ import json import re +import time from typing import Optional, Dict, Any, List from openai import OpenAI from ..config import Config +from .logger import get_logger +from .llm_gate import main_llm_slot + +logger = get_logger('mirofish.llm') class LLMClient: @@ -61,17 +66,20 @@ def chat( if response_format: kwargs["response_format"] = response_format - response = self.client.chat.completions.create(**kwargs) + with main_llm_slot(): + response = self.client.chat.completions.create(**kwargs) content = response.choices[0].message.content # 部分模型(如MiniMax M2.5)会在content中包含思考内容,需要移除 content = re.sub(r'[\s\S]*?', '', content).strip() + content = re.sub(r'^Thinking Process:[\s\S]*?(?=\{|\[)', '', content).strip() return content def chat_json( self, messages: List[Dict[str, str]], temperature: float = 0.3, - max_tokens: int = 4096 + max_tokens: int = 4096, + max_retries: Optional[int] = None ) -> Dict[str, Any]: """ 发送聊天请求并返回JSON @@ -84,20 +92,132 @@ def chat_json( Returns: 解析后的JSON对象 """ - response = self.chat( - messages=messages, - temperature=temperature, - max_tokens=max_tokens, - response_format={"type": "json_object"} - ) - # 清理markdown代码块标记 - cleaned_response = response.strip() - cleaned_response = re.sub(r'^```(?:json)?\s*\n?', '', cleaned_response, flags=re.IGNORECASE) - cleaned_response = re.sub(r'\n?```\s*$', '', cleaned_response) - cleaned_response = cleaned_response.strip() + retries = Config.LLM_JSON_MAX_RETRIES if max_retries is None else max_retries + last_error: Exception | None = None + last_response = "" + + for attempt in range(retries + 1): + attempt_messages = list(messages) + if attempt > 0: + attempt_messages.append({ + "role": "user", + "content": ( + "The previous answer was not valid JSON. Return exactly one valid JSON object. " + "Do not include markdown fences, explanation, comments, or thinking text." + ) + }) + + try: + response = self.chat( + messages=attempt_messages, + temperature=temperature, + max_tokens=max_tokens, + response_format={"type": "json_object"} + ) + last_response = response + parsed = self._parse_json_lenient(response) + if parsed is not None: + if attempt > 0: + logger.info("LLM JSON recovered after retry %s/%s", attempt, retries) + return parsed + + last_error = ValueError("LLM returned unparsable JSON") + logger.warning( + "LLM returned invalid JSON on attempt %s/%s: %s", + attempt + 1, + retries + 1, + self._cleanup_response(response)[:1000], + ) + except Exception as exc: + last_error = exc + logger.warning( + "LLM JSON call failed on attempt %s/%s: %s", + attempt + 1, + retries + 1, + str(exc)[:1000], + ) + + if attempt < retries: + time.sleep(min(2.0, 0.5 * (attempt + 1))) + + cleaned = self._cleanup_response(last_response) + if cleaned: + logger.error("LLM returned invalid JSON after retries: %s", cleaned[:2000]) + raise ValueError(f"LLM返回的JSON格式无效: {cleaned[:1000]}") + raise last_error or ValueError("LLM JSON call failed") + + def _parse_json_lenient(self, text: str) -> Dict[str, Any] | None: + cleaned = self._cleanup_response(text) + candidates = [cleaned] + + for block in re.findall(r"```(?:json)?\s*([\s\S]*?)```", text or "", flags=re.IGNORECASE): + candidates.append(self._cleanup_response(block)) + + candidates.extend(self._balanced_json_objects(cleaned)) + + start = cleaned.find('{') + end = cleaned.rfind('}') + if start >= 0 and end > start: + candidates.append(cleaned[start:end + 1]) + + for candidate in candidates: + candidate = candidate.strip() + if not candidate: + continue + repaired = self._repair_json(candidate) + try: + parsed = json.loads(repaired) + return parsed if isinstance(parsed, dict) else None + except json.JSONDecodeError: + pass + + return None + + def _cleanup_response(self, text: str) -> str: + cleaned = text or "" + cleaned = cleaned.replace("\ufeff", "").replace("\u200b", "") + cleaned = re.sub(r"[\s\S]*?", "", cleaned).strip() + cleaned = re.sub(r"^Thinking Process:[\s\S]*?(?=\{|\[)", "", cleaned).strip() + cleaned = re.sub(r"^```(?:json)?\s*\n?", "", cleaned, flags=re.IGNORECASE) + cleaned = re.sub(r"\n?```\s*$", "", cleaned) + return cleaned.strip() + + def _repair_json(self, text: str) -> str: + repaired = text.strip() + repaired = re.sub(r",(\s*[}\]])", r"\1", repaired) + repaired = re.sub(r"\bNone\b", "null", repaired) + repaired = re.sub(r"\bTrue\b", "true", repaired) + repaired = re.sub(r"\bFalse\b", "false", repaired) + return repaired + + def _balanced_json_objects(self, text: str) -> List[str]: + objects: List[str] = [] + start = None + depth = 0 + in_string = False + escape = False + + for index, char in enumerate(text or ""): + if in_string: + if escape: + escape = False + elif char == "\\": + escape = True + elif char == '"': + in_string = False + continue - try: - return json.loads(cleaned_response) - except json.JSONDecodeError: - raise ValueError(f"LLM返回的JSON格式无效: {cleaned_response}") + if char == '"': + in_string = True + elif char == "{": + if depth == 0: + start = index + depth += 1 + elif char == "}" and depth: + depth -= 1 + if depth == 0 and start is not None: + objects.append(text[start:index + 1]) + start = None + # Prefer larger objects first; noisy LLM text may contain small examples. + return sorted(objects, key=len, reverse=True) diff --git a/backend/app/utils/llm_gate.py b/backend/app/utils/llm_gate.py new file mode 100644 index 0000000000..29bd367781 --- /dev/null +++ b/backend/app/utils/llm_gate.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +from contextlib import contextmanager +from threading import BoundedSemaphore + +from ..config import Config + + +_MAIN_LLM_GATE = BoundedSemaphore(max(1, Config.LLM_MAX_CONCURRENCY)) + + +@contextmanager +def main_llm_slot(): + """Limit concurrent calls to the local main LLM endpoint.""" + _MAIN_LLM_GATE.acquire() + try: + yield + finally: + _MAIN_LLM_GATE.release() diff --git a/backend/local_zep/__init__.py b/backend/local_zep/__init__.py new file mode 100644 index 0000000000..93e8d83415 --- /dev/null +++ b/backend/local_zep/__init__.py @@ -0,0 +1,9 @@ +from .client import Zep +from .models import EpisodeData, EntityEdgeSourceTarget, InternalServerError + +__all__ = [ + "EntityEdgeSourceTarget", + "EpisodeData", + "InternalServerError", + "Zep", +] diff --git a/backend/local_zep/client.py b/backend/local_zep/client.py new file mode 100644 index 0000000000..01eebec936 --- /dev/null +++ b/backend/local_zep/client.py @@ -0,0 +1,218 @@ +from __future__ import annotations + +from typing import Any + +from .models import EntityEdgeSourceTarget, EpisodeData, GraphSearchResults +from .store import LocalZepStore + + +_DEFAULT_STORE: LocalZepStore | None = None + + +def get_default_store() -> LocalZepStore: + global _DEFAULT_STORE + if _DEFAULT_STORE is None: + _DEFAULT_STORE = LocalZepStore() + return _DEFAULT_STORE + + +def _model_fields_to_attributes(model_cls: Any) -> list[dict[str, str]]: + fields = getattr(model_cls, "model_fields", {}) or {} + result = [] + for field_name, field_info in fields.items(): + result.append( + { + "name": field_name, + "description": getattr(field_info, "description", None) or field_name, + } + ) + return result + + +def _compat_ontology_from_models( + entities: dict[str, Any] | None, + edges: dict[str, Any] | None, +) -> dict[str, Any]: + ontology = {"entity_types": [], "edge_types": []} + + for entity_name, entity_cls in (entities or {}).items(): + ontology["entity_types"].append( + { + "name": entity_name, + "description": getattr(entity_cls, "__doc__", "") or f"A {entity_name} entity.", + "attributes": _model_fields_to_attributes(entity_cls), + } + ) + + for edge_name, edge_value in (edges or {}).items(): + edge_cls, source_targets = edge_value + formatted_targets = [] + for pair in source_targets or []: + if isinstance(pair, EntityEdgeSourceTarget): + formatted_targets.append({"source": pair.source, "target": pair.target}) + else: + formatted_targets.append( + { + "source": getattr(pair, "source", "Entity"), + "target": getattr(pair, "target", "Entity"), + } + ) + ontology["edge_types"].append( + { + "name": edge_name, + "description": getattr(edge_cls, "__doc__", "") or f"A {edge_name} relationship.", + "attributes": _model_fields_to_attributes(edge_cls), + "source_targets": formatted_targets, + } + ) + + return ontology + + +class EpisodeManager: + def __init__(self, store: LocalZepStore) -> None: + self.store = store + + def get(self, uuid_: str): + return self.store.get_episode(uuid_) + + def get_by_graph_id(self, graph_id: str, lastn: int | None = None): + return self.store.get_episodes_by_graph_id(graph_id=graph_id, lastn=lastn) + + +class NodeManager: + def __init__(self, store: LocalZepStore) -> None: + self.store = store + + def get_by_graph_id(self, graph_id: str, limit: int = 100, uuid_cursor: str | None = None): + return self.store.get_nodes_page(graph_id=graph_id, limit=limit, uuid_cursor=uuid_cursor) + + def get_by_user_id(self, user_id: str, limit: int = 100, uuid_cursor: str | None = None): + return self.get_by_graph_id(graph_id=user_id, limit=limit, uuid_cursor=uuid_cursor) + + def get(self, uuid_: str): + return self.store.get_node(uuid_) + + def get_entity_edges(self, node_uuid: str): + return self.store.get_entity_edges(node_uuid) + + def get_edges(self, node_uuid: str): + return self.get_entity_edges(node_uuid) + + +class EdgeManager: + def __init__(self, store: LocalZepStore) -> None: + self.store = store + + def get_by_graph_id(self, graph_id: str, limit: int = 100, uuid_cursor: str | None = None): + return self.store.get_edges_page(graph_id=graph_id, limit=limit, uuid_cursor=uuid_cursor) + + def get_by_user_id(self, user_id: str, limit: int = 100, uuid_cursor: str | None = None): + return self.get_by_graph_id(graph_id=user_id, limit=limit, uuid_cursor=uuid_cursor) + + def get(self, uuid_: str): + return self.store.get_edge(uuid_) + + +class GraphManager: + def __init__(self, store: LocalZepStore) -> None: + self.store = store + self.node = NodeManager(store) + self.edge = EdgeManager(store) + self.episode = EpisodeManager(store) + + def create(self, graph_id: str, name: str = "", description: str = ""): + return self.store.create_graph(graph_id=graph_id, name=name, description=description) + + def get(self, graph_id: str): + return self.store.get_graph(graph_id=graph_id) + + def delete(self, graph_id: str): + self.store.delete_graph(graph_id=graph_id) + + def set_ontology( + self, + graph_ids: list[str] | None = None, + user_ids: list[str] | None = None, + entities: dict[str, Any] | None = None, + edges: dict[str, Any] | None = None, + ontology: dict[str, Any] | None = None, + **_: Any, + ): + graph_ids = graph_ids or user_ids or [] + parsed = ontology or _compat_ontology_from_models(entities, edges) + for graph_id in graph_ids: + self.store.set_ontology(graph_id=graph_id, ontology=parsed) + + def add( + self, + graph_id: str | None = None, + data: str = "", + type: str = "text", + user_id: str | None = None, + created_at: str | None = None, + metadata: dict[str, Any] | None = None, + source_description: str | None = None, + **extra: Any, + ): + graph_id = graph_id or user_id or extra.get("graphId") or extra.get("userId") + if not graph_id: + raise ValueError("graph_id or user_id is required") + return self.store.add( + graph_id=graph_id, + data=data, + type_=type, + created_at=created_at or extra.get("createdAt"), + metadata=metadata, + source_description=source_description or extra.get("sourceDescription"), + ) + + def add_batch(self, graph_id: str | None = None, episodes: list[EpisodeData] | None = None, user_id: str | None = None, **extra: Any): + graph_id = graph_id or user_id or extra.get("graphId") or extra.get("userId") + if not graph_id: + raise ValueError("graph_id or user_id is required") + return self.store.add_batch(graph_id=graph_id, episodes=episodes or []) + + def search( + self, + graph_id: str | None = None, + query: str = "", + limit: int = 10, + scope: str = "edges", + user_id: str | None = None, + reranker: str = "rrf", + mmr_lambda: float | None = None, + center_node_uuid: str | None = None, + search_filters: Any = None, + bfs_origin_node_uuids: list[str] | None = None, + **extra: Any, + ) -> GraphSearchResults: + graph_id = graph_id or user_id or extra.get("graphId") or extra.get("userId") + if not graph_id: + return GraphSearchResults() + if mmr_lambda is None: + mmr_lambda = extra.get("mmrLambda") + if center_node_uuid is None: + center_node_uuid = extra.get("centerNodeUuid") + if search_filters is None: + search_filters = extra.get("searchFilters") + if bfs_origin_node_uuids is None: + bfs_origin_node_uuids = extra.get("bfsOriginNodeUuids") + return self.store.search( + graph_id=graph_id, + query=query, + limit=limit, + scope=scope, + reranker=reranker, + mmr_lambda=mmr_lambda, + center_node_uuid=center_node_uuid, + search_filters=search_filters, + bfs_origin_node_uuids=bfs_origin_node_uuids, + ) + + +class Zep: + def __init__(self, api_key: str | None = None, **_: Any) -> None: + del api_key + self._store = get_default_store() + self.graph = GraphManager(self._store) diff --git a/backend/local_zep/embeddings.py b/backend/local_zep/embeddings.py new file mode 100644 index 0000000000..054f8c59f7 --- /dev/null +++ b/backend/local_zep/embeddings.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +import math +from typing import Iterable + +from .settings import settings + + +def cosine_similarity(left: list[float], right: list[float]) -> float: + if not left or not right or len(left) != len(right): + return 0.0 + + left_norm = math.sqrt(sum(value * value for value in left)) + right_norm = math.sqrt(sum(value * value for value in right)) + if left_norm == 0 or right_norm == 0: + return 0.0 + + dot = sum(a * b for a, b in zip(left, right)) + return dot / (left_norm * right_norm) + + +class EmbeddingClient: + """OpenAI-compatible embeddings client with vLLM-friendly defaults.""" + + def __init__( + self, + api_key: str | None = None, + base_url: str | None = None, + model_name: str | None = None, + ) -> None: + self.api_key = api_key or settings.embedding_api_key or "local-embedding-key" + self.base_url = (base_url or settings.embedding_base_url or "").strip() + self.model_name = (model_name or settings.embedding_model_name or "").strip() + + if not self.base_url: + raise ValueError("EMBEDDING_BASE_URL 未配置") + if not self.model_name: + raise ValueError("EMBEDDING_MODEL_NAME 未配置") + + from openai import OpenAI + + self.client = OpenAI( + api_key=self.api_key, + base_url=self.base_url, + ) + + def embed_texts(self, texts: Iterable[str], batch_size: int = 32) -> list[list[float]]: + normalized = [text.strip() if text else "" for text in texts] + if not normalized: + return [] + + results: list[list[float]] = [] + for start in range(0, len(normalized), batch_size): + batch = normalized[start:start + batch_size] + response = self.client.embeddings.create( + model=self.model_name, + input=batch, + ) + ordered = sorted(response.data, key=lambda item: item.index) + results.extend([list(item.embedding) for item in ordered]) + return results + + def embed_text(self, text: str) -> list[float]: + embeddings = self.embed_texts([text]) + return embeddings[0] if embeddings else [] diff --git a/backend/local_zep/extraction.py b/backend/local_zep/extraction.py new file mode 100644 index 0000000000..e09f7fba6f --- /dev/null +++ b/backend/local_zep/extraction.py @@ -0,0 +1,339 @@ +from __future__ import annotations + +import json +import re +import time +from typing import Any + +from .settings import settings + +try: + from app.utils.logger import get_logger + from app.utils.llm_gate import main_llm_slot + logger = get_logger("mirofish.local_zep.extraction") +except Exception: # pragma: no cover - fallback for direct package use + import logging + from contextlib import nullcontext + main_llm_slot = nullcontext + logger = logging.getLogger(__name__) + + +def _clean_string(value: Any) -> str: + if value is None: + return "" + return str(value).strip() + + +def _coerce_mapping(value: Any) -> dict[str, str]: + if isinstance(value, dict): + result = {} + for key, item in value.items(): + cleaned = _clean_string(item) + if cleaned: + result[str(key)] = cleaned + return result + return {} + + +def _normalize_type_name(value: str) -> str: + cleaned = _clean_string(value) + if not cleaned: + return "Entity" + return re.sub(r"\s+", "", cleaned) + + +def _clean_timestamp(value: Any) -> str | None: + cleaned = _clean_string(value) + return cleaned or None + + +def _cleanup_model_json(text: str) -> str: + cleaned = text or "" + cleaned = cleaned.replace("\ufeff", "").replace("\u200b", "") + cleaned = re.sub(r"[\s\S]*?", "", cleaned).strip() + cleaned = re.sub(r"^Thinking Process:[\s\S]*?(?=\{|\[)", "", cleaned).strip() + cleaned = re.sub(r"^```(?:json)?\s*\n?", "", cleaned, flags=re.IGNORECASE) + cleaned = re.sub(r"\n?```\s*$", "", cleaned) + return cleaned.strip() + + +def _repair_json(text: str) -> str: + repaired = text.strip() + repaired = re.sub(r",(\s*[}\]])", r"\1", repaired) + repaired = re.sub(r"\bNone\b", "null", repaired) + repaired = re.sub(r"\bTrue\b", "true", repaired) + repaired = re.sub(r"\bFalse\b", "false", repaired) + return repaired + + +def _balanced_json_objects(text: str) -> list[str]: + objects: list[str] = [] + start = None + depth = 0 + in_string = False + escape = False + + for index, char in enumerate(text or ""): + if in_string: + if escape: + escape = False + elif char == "\\": + escape = True + elif char == '"': + in_string = False + continue + + if char == '"': + in_string = True + elif char == "{": + if depth == 0: + start = index + depth += 1 + elif char == "}" and depth: + depth -= 1 + if depth == 0 and start is not None: + objects.append(text[start:index + 1]) + start = None + + return sorted(objects, key=len, reverse=True) + + +def _extract_balanced_array_after_key(text: str, key: str) -> list[Any] | None: + match = re.search(rf'"{re.escape(key)}"\s*:\s*\[', text) + if not match: + return None + + start = match.end() - 1 + depth = 0 + in_string = False + escape = False + for index in range(start, len(text)): + char = text[index] + if in_string: + if escape: + escape = False + elif char == "\\": + escape = True + elif char == '"': + in_string = False + continue + + if char == '"': + in_string = True + elif char == "[": + depth += 1 + elif char == "]" and depth: + depth -= 1 + if depth == 0: + try: + parsed = json.loads(_repair_json(text[start:index + 1])) + return parsed if isinstance(parsed, list) else None + except json.JSONDecodeError: + return None + + return None + + +def _parse_payload_lenient(text: str) -> dict[str, Any] | None: + cleaned = _cleanup_model_json(text) + candidates = [cleaned] + candidates.extend(_cleanup_model_json(block) for block in re.findall(r"```(?:json)?\s*([\s\S]*?)```", text or "", flags=re.IGNORECASE)) + candidates.extend(_balanced_json_objects(cleaned)) + + start = cleaned.find("{") + end = cleaned.rfind("}") + if start >= 0 and end > start: + candidates.append(cleaned[start:end + 1]) + + for candidate in candidates: + candidate = candidate.strip() + if not candidate: + continue + try: + parsed = json.loads(_repair_json(candidate)) + if isinstance(parsed, dict): + return parsed + except json.JSONDecodeError: + pass + + loose: dict[str, Any] = {} + entities = _extract_balanced_array_after_key(cleaned, "entities") + edges = _extract_balanced_array_after_key(cleaned, "edges") + if entities is not None: + loose["entities"] = entities + if edges is not None: + loose["edges"] = edges + return loose or None + + +class GraphExtractor: + """LLM-backed entity and relation extraction constrained by ontology.""" + + def __init__( + self, + api_key: str | None = None, + base_url: str | None = None, + model_name: str | None = None, + ) -> None: + self.api_key = api_key or settings.llm_api_key + self.base_url = base_url or settings.llm_base_url + self.model_name = model_name or settings.llm_model_name + + if not self.api_key: + raise ValueError("LLM_API_KEY 未配置") + + from openai import OpenAI + + self.client = OpenAI(api_key=self.api_key, base_url=self.base_url) + + def extract(self, text: str, ontology: dict[str, Any] | None) -> dict[str, list[dict[str, Any]]]: + if not text or not text.strip(): + return {"entities": [], "edges": []} + + ontology = ontology or {"entity_types": [], "edge_types": []} + entity_types = ontology.get("entity_types", []) + edge_types = ontology.get("edge_types", []) + + entity_type_names = [item.get("name", "Entity") for item in entity_types] + edge_type_names = [item.get("name", "RELATED_TO") for item in edge_types] + + messages = [ + { + "role": "system", + "content": ( + "You are an information extraction engine for a local temporal knowledge graph. " + "Return strict JSON with keys entities and edges. " + "Only use ontology entity types and edge types provided by the user. " + "Do not invent unsupported types. Keep summaries concise." + ), + }, + { + "role": "user", + "content": ( + "Extract graph updates from the text below.\n\n" + "Ontology:\n" + f"{ontology}\n\n" + "JSON schema:\n" + "{\n" + ' "entities": [\n' + ' {"name": "entity name", "type": "one of ontology entity types", ' + '"summary": "short summary", "attributes": {"attr": "value"}}\n' + " ],\n" + ' "edges": [\n' + ' {"name": "one of ontology edge types", "source": "entity name", ' + '"target": "entity name", "fact": "atomic factual sentence", ' + '"attributes": {"attr": "value"}, "valid_at": "optional RFC3339 time"}\n' + " ]\n" + "}\n\n" + "Rules:\n" + f"- Allowed entity types: {entity_type_names or ['Entity']}\n" + f"- Allowed edge types: {edge_type_names or ['RELATED_TO']}\n" + "- Use exact entity names from the text when possible.\n" + "- Omit any item you are not reasonably confident about.\n" + "- Every edge source and target must reference an entity name.\n\n" + "- If a fact includes an explicit real-world start time, put it in valid_at. " + "Otherwise omit valid_at and the episode created_at will be used.\n\n" + "Text:\n" + f"{text}" + ), + }, + ] + + payload: dict[str, Any] | None = None + last_content = "" + max_retries = max(0, settings.local_zep_extract_max_retries) + for attempt in range(max_retries + 1): + attempt_messages = list(messages) + if attempt > 0: + attempt_messages.append({ + "role": "user", + "content": ( + "Retry: return only compact valid JSON with top-level keys entities and edges. " + "No markdown, no comments, no thinking text. Keep at most 20 entities and 20 edges." + ), + }) + + try: + with main_llm_slot(): + response = self.client.chat.completions.create( + model=self.model_name, + messages=attempt_messages, + temperature=0.1, + max_tokens=settings.local_zep_extract_max_output_tokens, + response_format={"type": "json_object"}, + ) + last_content = response.choices[0].message.content or "{}" + payload = _parse_payload_lenient(last_content) + if payload is not None: + if attempt > 0: + logger.info("Graph extraction JSON recovered after retry %s/%s", attempt, max_retries) + break + logger.warning( + "Graph extraction returned invalid JSON on attempt %s/%s: %s", + attempt + 1, + max_retries + 1, + _cleanup_model_json(last_content)[:1000], + ) + except Exception as exc: + logger.warning( + "Graph extraction LLM call failed on attempt %s/%s: %s", + attempt + 1, + max_retries + 1, + str(exc)[:1000], + ) + + if attempt < max_retries: + time.sleep(min(2.0, 0.5 * (attempt + 1))) + + if payload is None: + logger.error( + "Graph extraction failed after retries; skipping text chunk. Last response: %s", + _cleanup_model_json(last_content)[:2000], + ) + return {"entities": [], "edges": []} + + entities = [] + for raw_entity in payload.get("entities", []): + if not isinstance(raw_entity, dict): + continue + name = _clean_string(raw_entity.get("name")) + entity_type = _normalize_type_name(raw_entity.get("type")) + if not name: + continue + if entity_types and entity_type not in entity_type_names: + continue + entities.append( + { + "name": name, + "type": entity_type, + "summary": _clean_string(raw_entity.get("summary")), + "attributes": _coerce_mapping(raw_entity.get("attributes")), + } + ) + + edges = [] + for raw_edge in payload.get("edges", []): + if not isinstance(raw_edge, dict): + continue + name = _normalize_type_name(raw_edge.get("name")) + source = _clean_string(raw_edge.get("source")) + target = _clean_string(raw_edge.get("target")) + fact = _clean_string(raw_edge.get("fact")) + if not name or not source or not target: + continue + if edge_types and name not in edge_type_names: + continue + if not fact: + fact = f"{source} {name} {target}" + edges.append( + { + "name": name, + "source": source, + "target": target, + "fact": fact, + "attributes": _coerce_mapping(raw_edge.get("attributes")), + "valid_at": _clean_timestamp(raw_edge.get("valid_at")), + } + ) + + return {"entities": entities, "edges": edges} diff --git a/backend/local_zep/models.py b/backend/local_zep/models.py new file mode 100644 index 0000000000..d4d4045548 --- /dev/null +++ b/backend/local_zep/models.py @@ -0,0 +1,112 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + + +class InternalServerError(RuntimeError): + """Compatibility error used by paging helpers.""" + + +@dataclass +class EpisodeData: + data: str + type: str = "text" + created_at: str | None = None + metadata: dict[str, Any] = field(default_factory=dict) + source_description: str | None = None + + +@dataclass +class EntityEdgeSourceTarget: + source: str = "Entity" + target: str = "Entity" + + +@dataclass +class GraphRecord: + graph_id: str + name: str + description: str = "" + created_at: str | None = None + + +@dataclass +class GraphNode: + uuid_: str + graph_id: str + name: str + labels: list[str] = field(default_factory=list) + summary: str = "" + attributes: dict[str, Any] = field(default_factory=dict) + created_at: str | None = None + score: float | None = None + relevance: float | None = None + + @property + def uuid(self) -> str: + return self.uuid_ + + +@dataclass +class GraphEdge: + uuid_: str + graph_id: str + name: str + fact: str + source_node_uuid: str + target_node_uuid: str + attributes: dict[str, Any] = field(default_factory=dict) + created_at: str | None = None + valid_at: str | None = None + invalid_at: str | None = None + expired_at: str | None = None + episodes: list[str] = field(default_factory=list) + score: float | None = None + relevance: float | None = None + + @property + def uuid(self) -> str: + return self.uuid_ + + @property + def episode_ids(self) -> list[str]: + return self.episodes + + +@dataclass +class GraphEpisode: + uuid_: str + graph_id: str + data: str + type: str = "text" + processed: bool = False + created_at: str | None = None + error: str | None = None + metadata: dict[str, Any] = field(default_factory=dict) + source_description: str | None = None + role: str | None = None + role_type: str | None = None + thread_id: str | None = None + task_id: str | None = None + score: float | None = None + relevance: float | None = None + + @property + def uuid(self) -> str: + return self.uuid_ + + @property + def content(self) -> str: + return self.data + + @property + def source(self) -> str: + return self.type + + +@dataclass +class GraphSearchResults: + nodes: list[GraphNode] = field(default_factory=list) + edges: list[GraphEdge] = field(default_factory=list) + episodes: list[GraphEpisode] = field(default_factory=list) diff --git a/backend/local_zep/reranker.py b/backend/local_zep/reranker.py new file mode 100644 index 0000000000..d93a5458ef --- /dev/null +++ b/backend/local_zep/reranker.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +import json +import logging +import urllib.error +import urllib.request +from typing import Any + +from .settings import settings + + +logger = logging.getLogger("mirofish.local_zep") + + +class RerankerClient: + """Small OpenAI/vLLM-friendly reranker client. + + The local graph uses this only for Zep's ``cross_encoder`` reranker. If no + reranker endpoint is configured, graph search falls back to local hybrid + ranking rather than failing the application request. + """ + + def __init__( + self, + api_key: str | None = None, + base_url: str | None = None, + model_name: str | None = None, + ) -> None: + self.api_key = api_key or settings.reranker_api_key or "local-reranker-key" + self.base_url = (base_url or settings.reranker_base_url or "").rstrip("/") + self.model_name = (model_name or settings.reranker_model_name or "").strip() + + @property + def is_configured(self) -> bool: + return bool(self.base_url and self.model_name) + + def rerank(self, query: str, documents: list[str]) -> list[float] | None: + if not self.is_configured or not query or not documents: + return None + + payload = { + "model": self.model_name, + "query": query, + "documents": documents, + "top_n": len(documents), + "return_documents": False, + } + + for url in self._candidate_urls(): + try: + response = self._post_json(url, payload) + scores = self._extract_scores(response, len(documents)) + if scores is not None: + return scores + except (OSError, urllib.error.URLError, json.JSONDecodeError, ValueError) as exc: + logger.warning("Reranker request failed for %s: %s", url, exc) + + return None + + def _candidate_urls(self) -> list[str]: + if self.base_url.endswith("/v1"): + return [f"{self.base_url}/rerank"] + return [f"{self.base_url}/v1/rerank", f"{self.base_url}/rerank"] + + def _post_json(self, url: str, payload: dict[str, Any]) -> dict[str, Any]: + body = json.dumps(payload).encode("utf-8") + request = urllib.request.Request( + url, + data=body, + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {self.api_key}", + }, + method="POST", + ) + with urllib.request.urlopen(request, timeout=30) as response: + return json.loads(response.read().decode("utf-8")) + + def _extract_scores(self, response: dict[str, Any], expected_count: int) -> list[float] | None: + if isinstance(response.get("scores"), list): + scores = [float(value) for value in response["scores"]] + return scores[:expected_count] + [0.0] * max(0, expected_count - len(scores)) + + rows = response.get("results") + if not isinstance(rows, list): + rows = response.get("data") + if not isinstance(rows, list): + return None + + scores = [0.0] * expected_count + found = False + for rank, item in enumerate(rows): + if not isinstance(item, dict): + continue + index = item.get("index", item.get("document_index", item.get("documentIndex", rank))) + try: + index = int(index) + except (TypeError, ValueError): + continue + if index < 0 or index >= expected_count: + continue + score = item.get("relevance_score", item.get("relevanceScore", item.get("score", item.get("relevance")))) + try: + scores[index] = float(score) + except (TypeError, ValueError): + continue + found = True + + return scores if found else None diff --git a/backend/local_zep/settings.py b/backend/local_zep/settings.py new file mode 100644 index 0000000000..44b6b93cfc --- /dev/null +++ b/backend/local_zep/settings.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +import os +from dataclasses import dataclass +from pathlib import Path + +_BACKEND_ROOT = Path(__file__).resolve().parents[1] +_PROJECT_ROOT = _BACKEND_ROOT.parent +_ENV_PATH = _PROJECT_ROOT / ".env" + +try: + from dotenv import load_dotenv +except ImportError: + load_dotenv = None + +if load_dotenv is not None: + if _ENV_PATH.exists(): + load_dotenv(_ENV_PATH, override=True) + else: + load_dotenv(override=True) + + +def _project_path(value: str) -> str: + path = Path(value) + if path.is_absolute(): + return str(path) + return str((_PROJECT_ROOT / path).resolve()) + + +@dataclass(frozen=True) +class LocalZepSettings: + llm_api_key: str | None = os.environ.get("LLM_API_KEY") + llm_base_url: str = os.environ.get("LLM_BASE_URL", "https://api.openai.com/v1") + llm_model_name: str = os.environ.get("LLM_MODEL_NAME", "gpt-4o-mini") + embedding_api_key: str = os.environ.get("EMBEDDING_API_KEY", "local-embedding-key") + embedding_base_url: str | None = os.environ.get("EMBEDDING_BASE_URL") + embedding_model_name: str | None = os.environ.get("EMBEDDING_MODEL_NAME") + reranker_api_key: str = os.environ.get("RERANKER_API_KEY", "local-reranker-key") + reranker_base_url: str | None = os.environ.get("RERANKER_BASE_URL") + reranker_model_name: str | None = os.environ.get("RERANKER_MODEL_NAME") + local_zep_rerank_top_k: int = int(os.environ.get("LOCAL_ZEP_RERANK_TOP_K", "50")) + local_zep_extract_max_retries: int = int(os.environ.get("LOCAL_ZEP_EXTRACT_MAX_RETRIES", "2")) + local_zep_extract_max_output_tokens: int = int(os.environ.get("LOCAL_ZEP_EXTRACT_MAX_OUTPUT_TOKENS", "2048")) + local_zep_db_path: str = _project_path( + os.environ.get( + "LOCAL_ZEP_DB_PATH", + str(_BACKEND_ROOT / "data" / "local_zep.sqlite3"), + ) + ) + + +settings = LocalZepSettings() diff --git a/backend/local_zep/store.py b/backend/local_zep/store.py new file mode 100644 index 0000000000..78b883a419 --- /dev/null +++ b/backend/local_zep/store.py @@ -0,0 +1,1642 @@ +from __future__ import annotations + +import json +import logging +import math +import os +import re +import sqlite3 +import threading +import uuid +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Any + +from .embeddings import EmbeddingClient, cosine_similarity +from .extraction import GraphExtractor +from .models import GraphEdge, GraphEpisode, GraphNode, GraphRecord, GraphSearchResults +from .reranker import RerankerClient +from .settings import settings + +logger = logging.getLogger("mirofish.local_zep") +_TOKEN_RE = re.compile(r"[\w\u4e00-\u9fff]+", re.UNICODE) +_CONFLICTING_EDGE_NAMES = { + "SUPPORTS": {"OPPOSES"}, + "OPPOSES": {"SUPPORTS"}, + "APPROVES": {"REJECTS", "OPPOSES"}, + "REJECTS": {"APPROVES", "SUPPORTS"}, + "LIKES": {"DISLIKES"}, + "DISLIKES": {"LIKES"}, +} + + +@dataclass +class _SearchCandidate: + kind: str + uuid: str + text: str + item: GraphNode | GraphEdge | GraphEpisode + embedding: list[float] = field(default_factory=list) + semantic_score: float = 0.0 + lexical_score: float = 0.0 + score: float = 0.0 + relevance: float | None = None + episode_count: int = 0 + distance: int | None = None + + +def _now_iso() -> str: + return datetime.now(timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z") + + +def _coerce_iso(value: str | None) -> str: + value = (value or "").strip() + return value or _now_iso() + + +def _normalize_name(value: str) -> str: + return " ".join((value or "").strip().lower().split()) + + +def _normalize_fact(value: str) -> str: + return re.sub(r"\s+", " ", (value or "").strip().lower().rstrip(".")) + + +def _primary_label(labels: list[str]) -> str: + for label in labels: + if label not in {"Entity", "Node"}: + return label + return "Entity" + + +def _json_dumps(value: Any) -> str: + return json.dumps(value or {}, ensure_ascii=False, sort_keys=True) + + +def _json_loads(value: str | None, default: Any) -> Any: + if not value: + return default + try: + return json.loads(value) + except json.JSONDecodeError: + return default + + +def _tokenize(text: str) -> list[str]: + return _TOKEN_RE.findall((text or "").lower()) + + +def _camel_case(value: str) -> str: + parts = value.split("_") + return parts[0] + "".join(part[:1].upper() + part[1:] for part in parts[1:]) + + +def _get_value(source: Any, key: str, default: Any = None) -> Any: + if source is None: + return default + + keys = [key, _camel_case(key)] + for candidate in keys: + if isinstance(source, dict) and candidate in source: + return source[candidate] + if hasattr(source, candidate): + return getattr(source, candidate) + + return default + + +def _as_list(value: Any) -> list[Any]: + if value is None: + return [] + if isinstance(value, list): + return value + if isinstance(value, tuple) or isinstance(value, set): + return list(value) + return [value] + + +def _bm25_scores(query: str, documents: list[str]) -> list[float]: + query_terms = _tokenize(query[:400]) + if not query_terms or not documents: + return [0.0] * len(documents) + + tokenized_docs = [_tokenize(document) for document in documents] + doc_count = len(tokenized_docs) + avg_len = sum(len(tokens) for tokens in tokenized_docs) / max(doc_count, 1) + if avg_len <= 0: + avg_len = 1.0 + + document_frequency: dict[str, int] = {} + for tokens in tokenized_docs: + for token in set(tokens): + document_frequency[token] = document_frequency.get(token, 0) + 1 + + scores: list[float] = [] + k1 = 1.5 + b = 0.75 + for tokens in tokenized_docs: + term_counts: dict[str, int] = {} + for token in tokens: + term_counts[token] = term_counts.get(token, 0) + 1 + + score = 0.0 + doc_len = max(len(tokens), 1) + for token in query_terms: + tf = term_counts.get(token, 0) + if tf <= 0: + continue + df = document_frequency.get(token, 0) + idf = math.log(1.0 + (doc_count - df + 0.5) / (df + 0.5)) + denominator = tf + k1 * (1.0 - b + b * doc_len / avg_len) + score += idf * (tf * (k1 + 1.0)) / denominator + + if query.lower().strip() and query.lower().strip() in (documents[len(scores)] or "").lower(): + score += 1.5 + scores.append(score) + + return scores + + +def _rank_positions(candidates: list[_SearchCandidate], attr: str) -> dict[str, int]: + ranked = sorted(candidates, key=lambda candidate: getattr(candidate, attr), reverse=True) + return { + candidate.uuid: rank + for rank, candidate in enumerate(ranked, start=1) + if getattr(candidate, attr) > 0 + } + + +def _matches_labels(labels: list[str], include: list[str], exclude: list[str]) -> bool: + label_set = set(labels) + if include and not label_set.intersection(include): + return False + if exclude and label_set.intersection(exclude): + return False + return True + + +def _compare_value(value: Any, operator: str, expected: Any = None) -> bool: + operator = (operator or "=").upper() + if operator == "IS NULL": + return value is None + if operator == "IS NOT NULL": + return value is not None + + if value is None: + return False + + left = value + right = expected + try: + left = float(value) + right = float(expected) + except (TypeError, ValueError): + left = str(value) + right = str(expected) + + if operator == "=": + return left == right + if operator == "<>": + return left != right + if operator == ">": + return left > right + if operator == "<": + return left < right + if operator == ">=": + return left >= right + if operator == "<=": + return left <= right + return True + + +class LocalZepStore: + def __init__(self, db_path: str | None = None) -> None: + self.db_path = db_path or settings.local_zep_db_path + self._lock = threading.RLock() + self._embedding_client: EmbeddingClient | None = None + self._extractor: GraphExtractor | None = None + self._reranker_client: RerankerClient | None = None + + db_dir = os.path.dirname(self.db_path) + if db_dir: + os.makedirs(db_dir, exist_ok=True) + + self._ensure_schema() + + def _connect(self) -> sqlite3.Connection: + conn = sqlite3.connect(self.db_path, timeout=30, check_same_thread=False) + conn.row_factory = sqlite3.Row + conn.execute("PRAGMA foreign_keys = ON") + return conn + + def _ensure_schema(self) -> None: + with self._connect() as conn: + conn.executescript( + """ + CREATE TABLE IF NOT EXISTS graphs ( + graph_id TEXT PRIMARY KEY, + name TEXT NOT NULL, + description TEXT DEFAULT '', + ontology_json TEXT DEFAULT '{}', + created_at TEXT NOT NULL + ); + + CREATE TABLE IF NOT EXISTS episodes ( + uuid TEXT PRIMARY KEY, + graph_id TEXT NOT NULL, + data TEXT NOT NULL, + type TEXT NOT NULL, + processed INTEGER NOT NULL DEFAULT 0, + error TEXT, + metadata_json TEXT DEFAULT '{}', + source_description TEXT, + role TEXT, + role_type TEXT, + thread_id TEXT, + task_id TEXT, + created_at TEXT NOT NULL, + FOREIGN KEY(graph_id) REFERENCES graphs(graph_id) ON DELETE CASCADE + ); + + CREATE TABLE IF NOT EXISTS nodes ( + uuid TEXT PRIMARY KEY, + graph_id TEXT NOT NULL, + name TEXT NOT NULL, + normalized_name TEXT NOT NULL, + primary_label TEXT NOT NULL, + labels_json TEXT NOT NULL, + summary TEXT DEFAULT '', + attributes_json TEXT DEFAULT '{}', + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + FOREIGN KEY(graph_id) REFERENCES graphs(graph_id) ON DELETE CASCADE + ); + + CREATE UNIQUE INDEX IF NOT EXISTS idx_nodes_identity + ON nodes(graph_id, normalized_name, primary_label); + + CREATE TABLE IF NOT EXISTS edges ( + uuid TEXT PRIMARY KEY, + graph_id TEXT NOT NULL, + name TEXT NOT NULL, + fact TEXT NOT NULL, + source_node_uuid TEXT NOT NULL, + target_node_uuid TEXT NOT NULL, + attributes_json TEXT DEFAULT '{}', + created_at TEXT NOT NULL, + valid_at TEXT, + invalid_at TEXT, + expired_at TEXT, + FOREIGN KEY(graph_id) REFERENCES graphs(graph_id) ON DELETE CASCADE, + FOREIGN KEY(source_node_uuid) REFERENCES nodes(uuid) ON DELETE CASCADE, + FOREIGN KEY(target_node_uuid) REFERENCES nodes(uuid) ON DELETE CASCADE + ); + + CREATE UNIQUE INDEX IF NOT EXISTS idx_edges_identity + ON edges(graph_id, source_node_uuid, target_node_uuid, name, fact); + + CREATE TABLE IF NOT EXISTS edge_episodes ( + edge_uuid TEXT NOT NULL, + episode_uuid TEXT NOT NULL, + PRIMARY KEY(edge_uuid, episode_uuid), + FOREIGN KEY(edge_uuid) REFERENCES edges(uuid) ON DELETE CASCADE, + FOREIGN KEY(episode_uuid) REFERENCES episodes(uuid) ON DELETE CASCADE + ); + + CREATE TABLE IF NOT EXISTS node_embeddings ( + node_uuid TEXT PRIMARY KEY, + embedding_json TEXT NOT NULL, + updated_at TEXT NOT NULL, + FOREIGN KEY(node_uuid) REFERENCES nodes(uuid) ON DELETE CASCADE + ); + + CREATE TABLE IF NOT EXISTS edge_embeddings ( + edge_uuid TEXT PRIMARY KEY, + embedding_json TEXT NOT NULL, + updated_at TEXT NOT NULL, + FOREIGN KEY(edge_uuid) REFERENCES edges(uuid) ON DELETE CASCADE + ); + + CREATE TABLE IF NOT EXISTS episode_embeddings ( + episode_uuid TEXT PRIMARY KEY, + embedding_json TEXT NOT NULL, + updated_at TEXT NOT NULL, + FOREIGN KEY(episode_uuid) REFERENCES episodes(uuid) ON DELETE CASCADE + ); + """ + ) + self._ensure_column(conn, "episodes", "metadata_json", "TEXT DEFAULT '{}'") + self._ensure_column(conn, "episodes", "source_description", "TEXT") + self._ensure_column(conn, "episodes", "role", "TEXT") + self._ensure_column(conn, "episodes", "role_type", "TEXT") + self._ensure_column(conn, "episodes", "thread_id", "TEXT") + self._ensure_column(conn, "episodes", "task_id", "TEXT") + + def _ensure_column(self, conn: sqlite3.Connection, table: str, column: str, definition: str) -> None: + rows = conn.execute(f"PRAGMA table_info({table})").fetchall() + existing = {row["name"] for row in rows} + if column not in existing: + conn.execute(f"ALTER TABLE {table} ADD COLUMN {column} {definition}") + + def _get_embedding_client(self) -> EmbeddingClient: + if self._embedding_client is None: + self._embedding_client = EmbeddingClient() + return self._embedding_client + + def _get_extractor(self) -> GraphExtractor: + if self._extractor is None: + self._extractor = GraphExtractor() + return self._extractor + + def _get_reranker_client(self) -> RerankerClient: + if self._reranker_client is None: + self._reranker_client = RerankerClient() + return self._reranker_client + + def create_graph(self, graph_id: str, name: str, description: str = "") -> GraphRecord: + created_at = _now_iso() + with self._lock, self._connect() as conn: + conn.execute( + """ + INSERT INTO graphs(graph_id, name, description, created_at) + VALUES (?, ?, ?, ?) + ON CONFLICT(graph_id) DO UPDATE SET + name = excluded.name, + description = excluded.description + """, + (graph_id, name, description, created_at), + ) + return GraphRecord(graph_id=graph_id, name=name, description=description, created_at=created_at) + + def delete_graph(self, graph_id: str) -> None: + with self._lock, self._connect() as conn: + conn.execute("DELETE FROM graphs WHERE graph_id = ?", (graph_id,)) + + def set_ontology(self, graph_id: str, ontology: dict[str, Any]) -> None: + with self._lock, self._connect() as conn: + conn.execute( + "UPDATE graphs SET ontology_json = ? WHERE graph_id = ?", + (_json_dumps(ontology or {}), graph_id), + ) + + def get_ontology(self, graph_id: str) -> dict[str, Any]: + with self._connect() as conn: + row = conn.execute( + "SELECT ontology_json FROM graphs WHERE graph_id = ?", + (graph_id,), + ).fetchone() + return _json_loads(row["ontology_json"], {}) if row else {} + + def get_graph(self, graph_id: str) -> GraphRecord | None: + with self._connect() as conn: + row = conn.execute("SELECT * FROM graphs WHERE graph_id = ?", (graph_id,)).fetchone() + if not row: + return None + return GraphRecord( + graph_id=row["graph_id"], + name=row["name"], + description=row["description"] or "", + created_at=row["created_at"], + ) + + def add( + self, + graph_id: str, + data: str, + type_: str = "text", + created_at: str | None = None, + metadata: dict[str, Any] | None = None, + source_description: str | None = None, + ) -> GraphEpisode: + episode_created_at = _coerce_iso(created_at) + episode = GraphEpisode( + uuid_=uuid.uuid4().hex, + graph_id=graph_id, + data=data, + type=type_, + processed=False, + created_at=episode_created_at, + metadata=metadata or {}, + source_description=source_description, + ) + + with self._lock, self._connect() as conn: + conn.execute( + """ + INSERT INTO episodes( + uuid, graph_id, data, type, processed, error, metadata_json, + source_description, role, role_type, thread_id, task_id, created_at + ) + VALUES (?, ?, ?, ?, 0, NULL, ?, ?, NULL, NULL, NULL, NULL, ?) + """, + ( + episode.uuid_, + graph_id, + data, + type_, + _json_dumps(metadata or {}), + source_description, + episode.created_at, + ), + ) + + try: + ontology = self.get_ontology(graph_id) + extracted = self._get_extractor().extract(data, ontology) + touched_nodes, touched_edges = self._apply_extraction( + graph_id, + episode.uuid_, + extracted, + ontology, + episode.created_at or _now_iso(), + ) + episode.processed = True + with self._lock, self._connect() as conn: + conn.execute( + "UPDATE episodes SET processed = 1, error = NULL WHERE uuid = ?", + (episode.uuid_,), + ) + self._refresh_node_embeddings(graph_id, touched_nodes) + self._refresh_edge_embeddings(graph_id, touched_edges) + self._refresh_episode_embeddings(graph_id, {episode.uuid_}) + except Exception as exc: + logger.exception("Local graph episode processing failed: %s", exc) + with self._lock, self._connect() as conn: + conn.execute( + "UPDATE episodes SET processed = 0, error = ? WHERE uuid = ?", + (str(exc), episode.uuid_), + ) + episode.error = str(exc) + raise + + return self.get_episode(episode.uuid_) or episode + + def add_batch(self, graph_id: str, episodes: list[Any]) -> list[GraphEpisode]: + results = [] + for episode in episodes: + data = getattr(episode, "data", "") if episode is not None else "" + type_ = getattr(episode, "type", "text") if episode is not None else "text" + created_at = getattr(episode, "created_at", None) if episode is not None else None + metadata = getattr(episode, "metadata", None) if episode is not None else None + source_description = getattr(episode, "source_description", None) if episode is not None else None + results.append( + self.add( + graph_id=graph_id, + data=data, + type_=type_, + created_at=created_at, + metadata=metadata, + source_description=source_description, + ) + ) + return results + + def get_episode(self, uuid_: str) -> GraphEpisode | None: + with self._connect() as conn: + row = conn.execute( + "SELECT * FROM episodes WHERE uuid = ?", + (uuid_,), + ).fetchone() + return self._row_to_episode(row) if row else None + + def get_episodes_by_graph_id(self, graph_id: str, lastn: int | None = None): + query = "SELECT * FROM episodes WHERE graph_id = ? ORDER BY created_at DESC, uuid DESC" + params: list[Any] = [graph_id] + if lastn: + query += " LIMIT ?" + params.append(lastn) + with self._connect() as conn: + rows = conn.execute(query, params).fetchall() + return type("EpisodeList", (), {"episodes": [self._row_to_episode(row) for row in rows]})() + + def get_nodes_page(self, graph_id: str, limit: int = 100, uuid_cursor: str | None = None) -> list[GraphNode]: + query = "SELECT * FROM nodes WHERE graph_id = ?" + params: list[Any] = [graph_id] + if uuid_cursor: + query += " AND uuid > ?" + params.append(uuid_cursor) + query += " ORDER BY uuid LIMIT ?" + params.append(limit) + + with self._connect() as conn: + rows = conn.execute(query, params).fetchall() + return [self._row_to_node(row) for row in rows] + + def get_edges_page(self, graph_id: str, limit: int = 100, uuid_cursor: str | None = None) -> list[GraphEdge]: + query = "SELECT * FROM edges WHERE graph_id = ?" + params: list[Any] = [graph_id] + if uuid_cursor: + query += " AND uuid > ?" + params.append(uuid_cursor) + query += " ORDER BY uuid LIMIT ?" + params.append(limit) + + with self._connect() as conn: + rows = conn.execute(query, params).fetchall() + edge_ids = [row["uuid"] for row in rows] + episode_map = self._load_edge_episode_map(conn, edge_ids) + return [self._row_to_edge(row, episode_map.get(row["uuid"], [])) for row in rows] + + def get_node(self, uuid_: str) -> GraphNode | None: + with self._connect() as conn: + row = conn.execute("SELECT * FROM nodes WHERE uuid = ?", (uuid_,)).fetchone() + return self._row_to_node(row) if row else None + + def get_edge(self, uuid_: str) -> GraphEdge | None: + with self._connect() as conn: + row = conn.execute("SELECT * FROM edges WHERE uuid = ?", (uuid_,)).fetchone() + if not row: + return None + episode_map = self._load_edge_episode_map(conn, [uuid_]) + return self._row_to_edge(row, episode_map.get(uuid_, [])) + + def get_entity_edges(self, node_uuid: str) -> list[GraphEdge]: + with self._connect() as conn: + rows = conn.execute( + """ + SELECT * FROM edges + WHERE source_node_uuid = ? OR target_node_uuid = ? + ORDER BY created_at DESC, uuid + """, + (node_uuid, node_uuid), + ).fetchall() + edge_ids = [row["uuid"] for row in rows] + episode_map = self._load_edge_episode_map(conn, edge_ids) + return [self._row_to_edge(row, episode_map.get(row["uuid"], [])) for row in rows] + + def search( + self, + graph_id: str, + query: str, + limit: int = 10, + scope: str = "edges", + reranker: str = "rrf", + mmr_lambda: float | None = None, + center_node_uuid: str | None = None, + search_filters: Any = None, + bfs_origin_node_uuids: list[str] | None = None, + ) -> GraphSearchResults: + results = GraphSearchResults() + query = (query or "").strip()[:400] + if not query: + return results + + query_embedding: list[float] = [] + try: + query_embedding = self._get_embedding_client().embed_text(query) + except Exception as exc: + logger.warning("Embedding lookup failed, falling back to lexical search: %s", exc) + + with self._connect() as conn: + candidates = self._build_search_candidates( + conn=conn, + graph_id=graph_id, + query=query, + query_embedding=query_embedding, + scope=(scope or "edges").lower(), + search_filters=search_filters, + ) + if not candidates: + return results + + if bfs_origin_node_uuids: + distances = self._graph_distances(conn, graph_id, bfs_origin_node_uuids) + self._apply_distances(conn, candidates, distances) + + self._rank_candidates( + conn=conn, + graph_id=graph_id, + query=query, + query_embedding=query_embedding, + candidates=candidates, + reranker=reranker or "rrf", + mmr_lambda=mmr_lambda, + center_node_uuid=center_node_uuid, + ) + + ranked = sorted(candidates, key=lambda candidate: candidate.score, reverse=True)[: max(limit, 0)] + results.edges = [ + self._scored_item(candidate) + for candidate in ranked + if candidate.kind == "edge" + ] + results.nodes = [ + self._scored_item(candidate) + for candidate in ranked + if candidate.kind == "node" + ] + results.episodes = [ + self._scored_item(candidate) + for candidate in ranked + if candidate.kind == "episode" + ] + + return results + + def _build_search_candidates( + self, + conn: sqlite3.Connection, + graph_id: str, + query: str, + query_embedding: list[float], + scope: str, + search_filters: Any, + ) -> list[_SearchCandidate]: + candidates: list[_SearchCandidate] = [] + + if scope in {"edges", "both"}: + rows = conn.execute( + """ + SELECT + e.*, + ee.embedding_json, + src.name AS source_name, + src.labels_json AS source_labels_json, + dst.name AS target_name, + dst.labels_json AS target_labels_json + FROM edges e + JOIN nodes src ON src.uuid = e.source_node_uuid + JOIN nodes dst ON dst.uuid = e.target_node_uuid + LEFT JOIN edge_embeddings ee ON ee.edge_uuid = e.uuid + WHERE e.graph_id = ? + """, + (graph_id,), + ).fetchall() + edge_ids = [row["uuid"] for row in rows] + episode_map = self._load_edge_episode_map(conn, edge_ids) + for row in rows: + if not self._edge_matches_filters(row, search_filters): + continue + edge = self._row_to_edge(row, episode_map.get(row["uuid"], [])) + if not self._episode_metadata_matches_any(conn, edge.episodes, search_filters): + continue + text = " ".join(filter(None, [row["name"], row["fact"], row["source_name"], row["target_name"]])) + candidates.append( + _SearchCandidate( + kind="edge", + uuid=row["uuid"], + text=text, + item=edge, + embedding=_json_loads(row["embedding_json"], []), + episode_count=len(edge.episodes), + ) + ) + + if scope in {"nodes", "both"}: + rows = conn.execute( + """ + SELECT n.*, ne.embedding_json + FROM nodes n + LEFT JOIN node_embeddings ne ON ne.node_uuid = n.uuid + WHERE n.graph_id = ? + """, + (graph_id,), + ).fetchall() + episode_counts = self._node_episode_counts(conn, graph_id) + node_episode_ids = self._node_episode_ids(conn, graph_id) + for row in rows: + if not self._node_matches_filters(row, search_filters): + continue + if not self._episode_metadata_matches_any(conn, node_episode_ids.get(row["uuid"], []), search_filters): + continue + labels = _json_loads(row["labels_json"], []) + attributes = _json_loads(row["attributes_json"], {}) + text = " ".join( + filter( + None, + [ + row["name"], + row["summary"], + " ".join(labels), + json.dumps(attributes, ensure_ascii=False), + ], + ) + ) + candidates.append( + _SearchCandidate( + kind="node", + uuid=row["uuid"], + text=text, + item=self._row_to_node(row), + embedding=_json_loads(row["embedding_json"], []), + episode_count=episode_counts.get(row["uuid"], 0), + ) + ) + + if scope == "episodes": + rows = conn.execute( + """ + SELECT ep.*, ee.embedding_json + FROM episodes ep + LEFT JOIN episode_embeddings ee ON ee.episode_uuid = ep.uuid + WHERE ep.graph_id = ? + """, + (graph_id,), + ).fetchall() + for row in rows: + if not self._episode_metadata_matches(_json_loads(row["metadata_json"], {}), search_filters): + continue + candidates.append( + _SearchCandidate( + kind="episode", + uuid=row["uuid"], + text=row["data"] or "", + item=self._row_to_episode(row), + embedding=_json_loads(row["embedding_json"], []), + episode_count=1, + ) + ) + + lexical_scores = _bm25_scores(query, [candidate.text for candidate in candidates]) + for candidate, lexical_score in zip(candidates, lexical_scores): + candidate.lexical_score = lexical_score + if query_embedding and candidate.embedding: + candidate.semantic_score = cosine_similarity(query_embedding, candidate.embedding) + + return candidates + + def _rank_candidates( + self, + conn: sqlite3.Connection, + graph_id: str, + query: str, + query_embedding: list[float], + candidates: list[_SearchCandidate], + reranker: str, + mmr_lambda: float | None, + center_node_uuid: str | None, + ) -> None: + reranker = (reranker or "rrf").lower() + if reranker == "cross_encoder": + self._rank_rrf(candidates) + pool = sorted(candidates, key=lambda candidate: candidate.score, reverse=True)[: settings.local_zep_rerank_top_k] + scores = self._get_reranker_client().rerank(query, [candidate.text for candidate in pool]) + if scores is not None: + for candidate, score in zip(pool, scores): + candidate.score = float(score) + self._distance_boost(candidate) + candidate.relevance = max(0.0, min(1.0, float(score))) + pool_ids = {candidate.uuid for candidate in pool} + for candidate in candidates: + if candidate.uuid not in pool_ids: + candidate.score *= 0.01 + return + + logger.info("Cross-encoder reranker is not configured; using local RRF fallback") + return + + if reranker == "mmr": + self._rank_mmr(candidates, query_embedding, mmr_lambda if mmr_lambda is not None else 0.5) + return + + if reranker == "episode_mentions": + self._rank_rrf(candidates) + for candidate in candidates: + candidate.score += math.log1p(candidate.episode_count) * 0.1 + return + + if reranker == "node_distance" and center_node_uuid: + distances = self._graph_distances(conn, graph_id, [center_node_uuid]) + self._apply_distances(conn, candidates, distances) + self._rank_node_distance(candidates) + return + + self._rank_rrf(candidates) + + def _rank_rrf(self, candidates: list[_SearchCandidate]) -> None: + semantic_ranks = _rank_positions(candidates, "semantic_score") + lexical_ranks = _rank_positions(candidates, "lexical_score") + for candidate in candidates: + score = 0.0 + if candidate.uuid in semantic_ranks: + score += 1.0 / (60.0 + semantic_ranks[candidate.uuid]) + if candidate.uuid in lexical_ranks: + score += 1.0 / (60.0 + lexical_ranks[candidate.uuid]) + candidate.score = score + self._distance_boost(candidate) + + def _rank_mmr(self, candidates: list[_SearchCandidate], query_embedding: list[float], lambda_value: float) -> None: + lambda_value = max(0.0, min(1.0, lambda_value)) + remaining = candidates[:] + selected: list[_SearchCandidate] = [] + + while remaining: + best: _SearchCandidate | None = None + best_score = -float("inf") + for candidate in remaining: + relevance = candidate.semantic_score + (candidate.lexical_score * 0.05) + diversity_penalty = 0.0 + if query_embedding and candidate.embedding and selected: + similarities = [ + cosine_similarity(candidate.embedding, selected_candidate.embedding) + for selected_candidate in selected + if selected_candidate.embedding + ] + diversity_penalty = max(similarities) if similarities else 0.0 + mmr_score = lambda_value * relevance - (1.0 - lambda_value) * diversity_penalty + if mmr_score > best_score: + best_score = mmr_score + best = candidate + + if best is None: + break + remaining.remove(best) + selected.append(best) + best.score = best_score + self._distance_boost(best) + + rank_count = len(selected) + for rank, candidate in enumerate(selected): + candidate.score += (rank_count - rank) * 1e-6 + + def _rank_node_distance(self, candidates: list[_SearchCandidate]) -> None: + self._rank_rrf(candidates) + for candidate in candidates: + if candidate.distance is None: + candidate.score *= 0.01 + else: + candidate.score += 1.0 / (1.0 + candidate.distance) + + def _distance_boost(self, candidate: _SearchCandidate) -> float: + if candidate.distance is None: + return 0.0 + return 0.15 / (1.0 + candidate.distance) + + def _scored_item(self, candidate: _SearchCandidate): + candidate.item.score = candidate.score + candidate.item.relevance = candidate.relevance + return candidate.item + + def _node_matches_filters(self, row: sqlite3.Row, search_filters: Any) -> bool: + if not search_filters: + return True + + labels = _json_loads(row["labels_json"], []) + include_labels = [str(value) for value in _as_list(_get_value(search_filters, "node_labels"))] + exclude_labels = [str(value) for value in _as_list(_get_value(search_filters, "exclude_node_labels"))] + if not _matches_labels(labels, include_labels, exclude_labels): + return False + + attributes = _json_loads(row["attributes_json"], {}) + return self._properties_match(attributes, search_filters) + + def _edge_matches_filters(self, row: sqlite3.Row, search_filters: Any) -> bool: + if not search_filters: + return True + + include_edge_types = [str(value) for value in _as_list(_get_value(search_filters, "edge_types"))] + exclude_edge_types = [str(value) for value in _as_list(_get_value(search_filters, "exclude_edge_types"))] + if include_edge_types and row["name"] not in include_edge_types: + return False + if exclude_edge_types and row["name"] in exclude_edge_types: + return False + + source_labels = _json_loads(row["source_labels_json"], []) + target_labels = _json_loads(row["target_labels_json"], []) + labels = sorted({*source_labels, *target_labels}) + include_labels = [str(value) for value in _as_list(_get_value(search_filters, "node_labels"))] + exclude_labels = [str(value) for value in _as_list(_get_value(search_filters, "exclude_node_labels"))] + if not _matches_labels(labels, include_labels, exclude_labels): + return False + + attributes = _json_loads(row["attributes_json"], {}) + if not self._properties_match(attributes, search_filters): + return False + + for field_name in ("created_at", "valid_at", "invalid_at", "expired_at"): + if not self._date_filters_match(row[field_name], _get_value(search_filters, field_name)): + return False + + return True + + def _properties_match(self, attributes: dict[str, Any], search_filters: Any) -> bool: + for prop_filter in _as_list(_get_value(search_filters, "property_filters")): + property_name = _get_value(prop_filter, "property_name") + if not property_name: + continue + operator = _get_value(prop_filter, "comparison_operator", "=") + expected = _get_value(prop_filter, "property_value") + if not _compare_value(attributes.get(str(property_name)), str(operator), expected): + return False + return True + + def _episode_metadata_matches_any( + self, + conn: sqlite3.Connection, + episode_ids: list[str], + search_filters: Any, + ) -> bool: + metadata_filter = _get_value(search_filters, "episode_metadata_filters") + if not metadata_filter: + return True + if not episode_ids: + return False + + rows = conn.execute( + f""" + SELECT metadata_json + FROM episodes + WHERE uuid IN ({",".join("?" for _ in episode_ids)}) + """, + episode_ids, + ).fetchall() + return any( + self._episode_metadata_matches(_json_loads(row["metadata_json"], {}), search_filters) + for row in rows + ) + + def _episode_metadata_matches(self, metadata: dict[str, Any], search_filters: Any) -> bool: + metadata_filter = _get_value(search_filters, "episode_metadata_filters") + if not metadata_filter: + return True + return self._metadata_group_matches(metadata, metadata_filter) + + def _metadata_group_matches(self, metadata: dict[str, Any], group: Any) -> bool: + group_type = str(_get_value(group, "type", "and")).lower() + checks: list[bool] = [] + + for metadata_filter in _as_list(_get_value(group, "filters")): + property_name = _get_value(metadata_filter, "property_name") + if not property_name: + continue + operator = str(_get_value(metadata_filter, "comparison_operator", "=")) + expected = _get_value(metadata_filter, "property_value") + checks.append(_compare_value(metadata.get(str(property_name)), operator, expected)) + + for nested_group in _as_list(_get_value(group, "groups")): + checks.append(self._metadata_group_matches(metadata, nested_group)) + + if not checks: + return True + if group_type == "or": + return any(checks) + return all(checks) + + def _date_filters_match(self, value: str | None, filter_groups: Any) -> bool: + if not filter_groups: + return True + + groups = _as_list(filter_groups) + if groups and not isinstance(groups[0], (list, tuple, set)): + groups = [groups] + + for group in groups: + predicates = _as_list(group) + if all( + _compare_value( + value, + str(_get_value(predicate, "comparison_operator", "=")), + _get_value(predicate, "date"), + ) + for predicate in predicates + ): + return True + return False + + def _node_episode_counts(self, conn: sqlite3.Connection, graph_id: str) -> dict[str, int]: + episode_ids_by_node = self._node_episode_ids(conn, graph_id) + return {node_uuid: len(set(episode_ids)) for node_uuid, episode_ids in episode_ids_by_node.items()} + + def _node_episode_ids(self, conn: sqlite3.Connection, graph_id: str) -> dict[str, list[str]]: + rows = conn.execute( + """ + SELECT e.source_node_uuid AS node_uuid, ee.episode_uuid + FROM edges e + JOIN edge_episodes ee ON ee.edge_uuid = e.uuid + WHERE e.graph_id = ? + UNION ALL + SELECT e.target_node_uuid AS node_uuid, ee.episode_uuid + FROM edges e + JOIN edge_episodes ee ON ee.edge_uuid = e.uuid + WHERE e.graph_id = ? + """, + (graph_id, graph_id), + ).fetchall() + episodes_by_node: dict[str, list[str]] = {} + for row in rows: + episodes_by_node.setdefault(row["node_uuid"], []).append(row["episode_uuid"]) + return episodes_by_node + + def _graph_distances( + self, + conn: sqlite3.Connection, + graph_id: str, + origin_node_uuids: list[str] | None, + ) -> dict[str, int]: + origins = [origin for origin in (origin_node_uuids or []) if origin] + if not origins: + return {} + + rows = conn.execute( + "SELECT uuid, source_node_uuid, target_node_uuid FROM edges WHERE graph_id = ?", + (graph_id,), + ).fetchall() + adjacency: dict[str, set[str]] = {} + for row in rows: + adjacency.setdefault(row["source_node_uuid"], set()).add(row["target_node_uuid"]) + adjacency.setdefault(row["target_node_uuid"], set()).add(row["source_node_uuid"]) + + placeholders = ",".join("?" for _ in origins) + seed_rows = conn.execute( + f"SELECT uuid FROM nodes WHERE graph_id = ? AND uuid IN ({placeholders})", + [graph_id, *origins], + ).fetchall() + seed_nodes = {row["uuid"] for row in seed_rows} + + episode_rows = conn.execute( + f""" + SELECT e.source_node_uuid, e.target_node_uuid + FROM edge_episodes ee + JOIN edges e ON e.uuid = ee.edge_uuid + WHERE e.graph_id = ? AND ee.episode_uuid IN ({placeholders}) + """, + [graph_id, *origins], + ).fetchall() + for row in episode_rows: + seed_nodes.add(row["source_node_uuid"]) + seed_nodes.add(row["target_node_uuid"]) + + distances: dict[str, int] = {node_uuid: 0 for node_uuid in seed_nodes} + queue = list(seed_nodes) + cursor = 0 + while cursor < len(queue): + node_uuid = queue[cursor] + cursor += 1 + for neighbor in adjacency.get(node_uuid, set()): + if neighbor in distances: + continue + distances[neighbor] = distances[node_uuid] + 1 + queue.append(neighbor) + return distances + + def _apply_distances( + self, + conn: sqlite3.Connection, + candidates: list[_SearchCandidate], + distances: dict[str, int], + ) -> None: + if not distances: + return + + for candidate in candidates: + if candidate.kind == "node": + candidate.distance = distances.get(candidate.uuid) + continue + + if candidate.kind == "edge": + edge = candidate.item + candidate.distance = min( + ( + distance + for distance in [ + distances.get(edge.source_node_uuid), + distances.get(edge.target_node_uuid), + ] + if distance is not None + ), + default=None, + ) + continue + + rows = conn.execute( + """ + SELECT e.source_node_uuid, e.target_node_uuid + FROM edge_episodes ee + JOIN edges e ON e.uuid = ee.edge_uuid + WHERE ee.episode_uuid = ? + """, + (candidate.uuid,), + ).fetchall() + episode_distances = [ + distance + for row in rows + for distance in (distances.get(row["source_node_uuid"]), distances.get(row["target_node_uuid"])) + if distance is not None + ] + candidate.distance = min(episode_distances) if episode_distances else None + + def _apply_extraction( + self, + graph_id: str, + episode_uuid: str, + extracted: dict[str, list[dict[str, Any]]], + ontology: dict[str, Any], + episode_created_at: str, + ) -> tuple[set[str], set[str]]: + touched_nodes: set[str] = set() + touched_edges: set[str] = set() + entity_lookup: dict[tuple[str, str], GraphNode] = {} + + with self._lock, self._connect() as conn: + for entity in extracted.get("entities", []): + node = self._upsert_node(conn, graph_id, entity) + entity_lookup[(_normalize_name(node.name), _primary_label(node.labels))] = node + entity_lookup[(_normalize_name(node.name), "Entity")] = node + touched_nodes.add(node.uuid_) + + for edge in extracted.get("edges", []): + source_node = self._resolve_edge_node(conn, graph_id, edge.get("source", ""), ontology, edge.get("name", ""), True, entity_lookup) + target_node = self._resolve_edge_node(conn, graph_id, edge.get("target", ""), ontology, edge.get("name", ""), False, entity_lookup) + touched_nodes.update({source_node.uuid_, target_node.uuid_}) + stored_edge = self._upsert_edge( + conn, + graph_id, + episode_uuid, + edge, + source_node, + target_node, + episode_created_at, + ) + touched_edges.add(stored_edge.uuid_) + + return touched_nodes, touched_edges + + def _upsert_node(self, conn: sqlite3.Connection, graph_id: str, entity: dict[str, Any]) -> GraphNode: + name = (entity.get("name") or "").strip() + entity_type = (entity.get("type") or "Entity").strip() or "Entity" + summary = (entity.get("summary") or "").strip() + attributes = entity.get("attributes") or {} + labels = ["Entity"] if entity_type == "Entity" else ["Entity", entity_type] + normalized_name = _normalize_name(name) + label = _primary_label(labels) + row = conn.execute( + """ + SELECT * FROM nodes + WHERE graph_id = ? AND normalized_name = ? AND primary_label = ? + """, + (graph_id, normalized_name, label), + ).fetchone() + timestamp = _now_iso() + + if row: + existing_labels = _json_loads(row["labels_json"], []) + merged_labels = sorted({*existing_labels, *labels}) + existing_attributes = _json_loads(row["attributes_json"], {}) + merged_attributes = {**existing_attributes, **attributes} + merged_summary = self._merge_summary(row["summary"], summary, merged_attributes) + conn.execute( + """ + UPDATE nodes + SET labels_json = ?, summary = ?, attributes_json = ?, updated_at = ? + WHERE uuid = ? + """, + (_json_dumps(merged_labels), merged_summary, _json_dumps(merged_attributes), timestamp, row["uuid"]), + ) + updated_row = conn.execute("SELECT * FROM nodes WHERE uuid = ?", (row["uuid"],)).fetchone() + return self._row_to_node(updated_row) + + node_uuid = uuid.uuid4().hex + summary = summary or self._fallback_summary(name, attributes) + conn.execute( + """ + INSERT INTO nodes( + uuid, graph_id, name, normalized_name, primary_label, labels_json, + summary, attributes_json, created_at, updated_at + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + node_uuid, + graph_id, + name, + normalized_name, + label, + _json_dumps(labels), + summary, + _json_dumps(attributes), + timestamp, + timestamp, + ), + ) + return GraphNode( + uuid_=node_uuid, + graph_id=graph_id, + name=name, + labels=labels, + summary=summary, + attributes=attributes, + created_at=timestamp, + ) + + def _resolve_edge_node( + self, + conn: sqlite3.Connection, + graph_id: str, + node_name: str, + ontology: dict[str, Any], + edge_name: str, + is_source: bool, + entity_lookup: dict[tuple[str, str], GraphNode], + ) -> GraphNode: + normalized_name = _normalize_name(node_name) + preferred_labels = self._allowed_labels_for_edge(ontology, edge_name, is_source) + + for preferred_label in preferred_labels + ["Entity"]: + existing = entity_lookup.get((normalized_name, preferred_label)) + if existing: + return existing + + row = None + if preferred_labels: + placeholders = ",".join("?" for _ in preferred_labels) + row = conn.execute( + f""" + SELECT * FROM nodes + WHERE graph_id = ? AND normalized_name = ? AND primary_label IN ({placeholders}) + ORDER BY updated_at DESC + LIMIT 1 + """, + [graph_id, normalized_name, *preferred_labels], + ).fetchone() + if row is None: + row = conn.execute( + """ + SELECT * FROM nodes + WHERE graph_id = ? AND normalized_name = ? + ORDER BY updated_at DESC + LIMIT 1 + """, + (graph_id, normalized_name), + ).fetchone() + if row: + node = self._row_to_node(row) + entity_lookup[(normalized_name, _primary_label(node.labels))] = node + entity_lookup[(normalized_name, "Entity")] = node + return node + + fallback_type = preferred_labels[0] if preferred_labels else "Entity" + node = self._upsert_node( + conn, + graph_id, + { + "name": node_name, + "type": fallback_type, + "summary": node_name, + "attributes": {}, + }, + ) + entity_lookup[(normalized_name, _primary_label(node.labels))] = node + entity_lookup[(normalized_name, "Entity")] = node + return node + + def _allowed_labels_for_edge(self, ontology: dict[str, Any], edge_name: str, is_source: bool) -> list[str]: + for edge in ontology.get("edge_types", []): + if edge.get("name") != edge_name: + continue + labels = [] + for pair in edge.get("source_targets", []): + label = pair.get("source") if is_source else pair.get("target") + if label and label != "Entity" and label not in labels: + labels.append(label) + return labels + return [] + + def _upsert_edge( + self, + conn: sqlite3.Connection, + graph_id: str, + episode_uuid: str, + edge: dict[str, Any], + source_node: GraphNode, + target_node: GraphNode, + episode_created_at: str, + ) -> GraphEdge: + name = (edge.get("name") or "RELATED_TO").strip() or "RELATED_TO" + fact = (edge.get("fact") or f"{source_node.name} {name} {target_node.name}").strip() + attributes = edge.get("attributes") or {} + learned_at = _now_iso() + valid_at = _coerce_iso(edge.get("valid_at") or episode_created_at) + row = conn.execute( + """ + SELECT * FROM edges + WHERE graph_id = ? AND source_node_uuid = ? AND target_node_uuid = ? AND name = ? AND fact = ? + """, + (graph_id, source_node.uuid_, target_node.uuid_, name, fact), + ).fetchone() + + if row: + existing_attributes = _json_loads(row["attributes_json"], {}) + merged_attributes = {**existing_attributes, **attributes} + conn.execute( + """ + UPDATE edges + SET attributes_json = ?, valid_at = COALESCE(valid_at, ?) + WHERE uuid = ? + """, + (_json_dumps(merged_attributes), valid_at, row["uuid"]), + ) + edge_uuid = row["uuid"] + else: + self._invalidate_superseded_edges( + conn=conn, + graph_id=graph_id, + source_node_uuid=source_node.uuid_, + target_node_uuid=target_node.uuid_, + edge_name=name, + new_fact=fact, + invalid_at=valid_at, + expired_at=learned_at, + ) + edge_uuid = uuid.uuid4().hex + conn.execute( + """ + INSERT INTO edges( + uuid, graph_id, name, fact, source_node_uuid, target_node_uuid, + attributes_json, created_at, valid_at, invalid_at, expired_at + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, NULL, NULL) + """, + ( + edge_uuid, + graph_id, + name, + fact, + source_node.uuid_, + target_node.uuid_, + _json_dumps(attributes), + learned_at, + valid_at, + ), + ) + + conn.execute( + """ + INSERT OR IGNORE INTO edge_episodes(edge_uuid, episode_uuid) + VALUES (?, ?) + """, + (edge_uuid, episode_uuid), + ) + + row = conn.execute("SELECT * FROM edges WHERE uuid = ?", (edge_uuid,)).fetchone() + return self._row_to_edge(row, [episode_uuid]) + + def _invalidate_superseded_edges( + self, + conn: sqlite3.Connection, + graph_id: str, + source_node_uuid: str, + target_node_uuid: str, + edge_name: str, + new_fact: str, + invalid_at: str, + expired_at: str, + ) -> None: + """Approximate Zep/Graphiti temporal fact invalidation. + + If a new fact uses the same source/target and either the same relation + name or an explicitly conflicting relation name, treat the old fact as + superseded unless it is the same normalized fact. This preserves history + while keeping active facts current for typical single-user workflows. + """ + conflicting_names = self._conflicting_names(edge_name) + rows = conn.execute( + f""" + SELECT uuid, fact + FROM edges + WHERE graph_id = ? + AND source_node_uuid = ? + AND target_node_uuid = ? + AND name IN ({",".join("?" for _ in conflicting_names)}) + AND invalid_at IS NULL + AND expired_at IS NULL + """, + [graph_id, source_node_uuid, target_node_uuid, *conflicting_names], + ).fetchall() + normalized_new_fact = _normalize_fact(new_fact) + superseded_ids = [ + row["uuid"] + for row in rows + if _normalize_fact(row["fact"]) != normalized_new_fact + ] + if not superseded_ids: + return + conn.execute( + f""" + UPDATE edges + SET invalid_at = ?, expired_at = ? + WHERE uuid IN ({",".join("?" for _ in superseded_ids)}) + """, + [invalid_at, expired_at, *superseded_ids], + ) + + def _conflicting_names(self, edge_name: str) -> list[str]: + names = {edge_name} + names.update(_CONFLICTING_EDGE_NAMES.get(edge_name.upper(), set())) + return sorted(names) + + def _refresh_node_embeddings(self, graph_id: str, node_ids: set[str]) -> None: + if not node_ids: + return + with self._connect() as conn: + rows = conn.execute( + f""" + SELECT * FROM nodes + WHERE graph_id = ? AND uuid IN ({",".join("?" for _ in node_ids)}) + """, + [graph_id, *node_ids], + ).fetchall() + if not rows: + return + + texts = [] + ids = [] + for row in rows: + text = " ".join( + filter( + None, + [ + row["name"], + row["summary"], + " ".join(_json_loads(row["labels_json"], [])), + json.dumps(_json_loads(row["attributes_json"], {}), ensure_ascii=False), + ], + ) + ) + ids.append(row["uuid"]) + texts.append(text) + + try: + embeddings = self._get_embedding_client().embed_texts(texts) + except Exception as exc: + logger.warning("Failed to refresh node embeddings: %s", exc) + return + + now = _now_iso() + with self._lock, self._connect() as conn: + for node_id, embedding in zip(ids, embeddings): + conn.execute( + """ + INSERT INTO node_embeddings(node_uuid, embedding_json, updated_at) + VALUES (?, ?, ?) + ON CONFLICT(node_uuid) DO UPDATE SET + embedding_json = excluded.embedding_json, + updated_at = excluded.updated_at + """, + (node_id, _json_dumps(embedding), now), + ) + + def _refresh_edge_embeddings(self, graph_id: str, edge_ids: set[str]) -> None: + if not edge_ids: + return + with self._connect() as conn: + rows = conn.execute( + f""" + SELECT e.*, src.name AS source_name, dst.name AS target_name + FROM edges e + JOIN nodes src ON src.uuid = e.source_node_uuid + JOIN nodes dst ON dst.uuid = e.target_node_uuid + WHERE e.graph_id = ? AND e.uuid IN ({",".join("?" for _ in edge_ids)}) + """, + [graph_id, *edge_ids], + ).fetchall() + if not rows: + return + + ids = [] + texts = [] + for row in rows: + ids.append(row["uuid"]) + texts.append(" ".join(filter(None, [row["name"], row["fact"], row["source_name"], row["target_name"]]))) + + try: + embeddings = self._get_embedding_client().embed_texts(texts) + except Exception as exc: + logger.warning("Failed to refresh edge embeddings: %s", exc) + return + + now = _now_iso() + with self._lock, self._connect() as conn: + for edge_id, embedding in zip(ids, embeddings): + conn.execute( + """ + INSERT INTO edge_embeddings(edge_uuid, embedding_json, updated_at) + VALUES (?, ?, ?) + ON CONFLICT(edge_uuid) DO UPDATE SET + embedding_json = excluded.embedding_json, + updated_at = excluded.updated_at + """, + (edge_id, _json_dumps(embedding), now), + ) + + def _refresh_episode_embeddings(self, graph_id: str, episode_ids: set[str]) -> None: + if not episode_ids: + return + with self._connect() as conn: + rows = conn.execute( + f""" + SELECT * FROM episodes + WHERE graph_id = ? AND uuid IN ({",".join("?" for _ in episode_ids)}) + """, + [graph_id, *episode_ids], + ).fetchall() + if not rows: + return + + ids = [row["uuid"] for row in rows] + texts = [row["data"] or "" for row in rows] + + try: + embeddings = self._get_embedding_client().embed_texts(texts) + except Exception as exc: + logger.warning("Failed to refresh episode embeddings: %s", exc) + return + + now = _now_iso() + with self._lock, self._connect() as conn: + for episode_id, embedding in zip(ids, embeddings): + conn.execute( + """ + INSERT INTO episode_embeddings(episode_uuid, embedding_json, updated_at) + VALUES (?, ?, ?) + ON CONFLICT(episode_uuid) DO UPDATE SET + embedding_json = excluded.embedding_json, + updated_at = excluded.updated_at + """, + (episode_id, _json_dumps(embedding), now), + ) + + def _load_edge_episode_map(self, conn: sqlite3.Connection, edge_ids: list[str]) -> dict[str, list[str]]: + if not edge_ids: + return {} + rows = conn.execute( + f""" + SELECT edge_uuid, episode_uuid + FROM edge_episodes + WHERE edge_uuid IN ({",".join("?" for _ in edge_ids)}) + ORDER BY episode_uuid + """, + edge_ids, + ).fetchall() + episode_map: dict[str, list[str]] = {} + for row in rows: + episode_map.setdefault(row["edge_uuid"], []).append(row["episode_uuid"]) + return episode_map + + def _load_edge_endpoint_names(self, conn: sqlite3.Connection, row: sqlite3.Row) -> tuple[str, str]: + source = conn.execute("SELECT name FROM nodes WHERE uuid = ?", (row["source_node_uuid"],)).fetchone() + target = conn.execute("SELECT name FROM nodes WHERE uuid = ?", (row["target_node_uuid"],)).fetchone() + return (source["name"] if source else "", target["name"] if target else "") + + def _row_to_node(self, row: sqlite3.Row) -> GraphNode: + return GraphNode( + uuid_=row["uuid"], + graph_id=row["graph_id"], + name=row["name"], + labels=_json_loads(row["labels_json"], []), + summary=row["summary"] or "", + attributes=_json_loads(row["attributes_json"], {}), + created_at=row["created_at"], + ) + + def _row_to_edge(self, row: sqlite3.Row, episodes: list[str]) -> GraphEdge: + return GraphEdge( + uuid_=row["uuid"], + graph_id=row["graph_id"], + name=row["name"], + fact=row["fact"], + source_node_uuid=row["source_node_uuid"], + target_node_uuid=row["target_node_uuid"], + attributes=_json_loads(row["attributes_json"], {}), + created_at=row["created_at"], + valid_at=row["valid_at"], + invalid_at=row["invalid_at"], + expired_at=row["expired_at"], + episodes=episodes, + ) + + def _row_to_episode(self, row: sqlite3.Row) -> GraphEpisode: + return GraphEpisode( + uuid_=row["uuid"], + graph_id=row["graph_id"], + data=row["data"], + type=row["type"], + processed=bool(row["processed"]), + created_at=row["created_at"], + error=row["error"], + metadata=_json_loads(row["metadata_json"], {}), + source_description=row["source_description"], + role=row["role"], + role_type=row["role_type"], + thread_id=row["thread_id"], + task_id=row["task_id"], + ) + + def _merge_summary(self, existing: str, new_value: str, attributes: dict[str, Any]) -> str: + existing = (existing or "").strip() + new_value = (new_value or "").strip() + if existing and new_value: + if new_value in existing: + return existing + if existing in new_value: + return new_value + return f"{existing} {new_value}".strip() + if new_value: + return new_value + if existing: + return existing + return self._fallback_summary("", attributes) + + def _fallback_summary(self, name: str, attributes: dict[str, Any]) -> str: + if attributes: + pairs = [f"{key}: {value}" for key, value in attributes.items() if value] + if pairs: + prefix = f"{name} - " if name else "" + return prefix + ", ".join(pairs[:4]) + return name or "" diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 4f5361d537..6a58838b47 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -16,12 +16,10 @@ dependencies = [ # LLM 相关 "openai>=1.0.0", - # Zep Cloud - "zep-cloud==3.13.0", - - # OASIS 社交媒体模拟 - "camel-oasis==0.2.5", - "camel-ai==0.2.78", + # OASIS 社交媒体模拟. + # These packages currently publish metadata for Python <3.12 only. + "camel-oasis==0.2.5; python_version < '3.12'", + "camel-ai==0.2.78; python_version < '3.12'", # 文件处理 "PyMuPDF>=1.24.0", @@ -35,6 +33,10 @@ dependencies = [ ] [project.optional-dependencies] +oasis = [ + "camel-oasis==0.2.5; python_version < '3.12'", + "camel-ai==0.2.78; python_version < '3.12'", +] dev = [ "pytest>=8.0.0", "pytest-asyncio>=0.23.0", @@ -52,4 +54,4 @@ dev = [ ] [tool.hatch.build.targets.wheel] -packages = ["app"] +packages = ["app", "local_zep", "zep_cloud"] diff --git a/backend/requirements-oasis.txt b/backend/requirements-oasis.txt new file mode 100644 index 0000000000..cef69da561 --- /dev/null +++ b/backend/requirements-oasis.txt @@ -0,0 +1,7 @@ +# Optional OASIS runtime. +# +# These third-party packages currently declare Python <3.12 support. +# Use a Python 3.11 venv if you need the original OASIS simulation scripts. + +camel-oasis==0.2.5; python_version < "3.12" +camel-ai==0.2.78; python_version < "3.12" diff --git a/backend/requirements.txt b/backend/requirements.txt index 4f146296ba..4261ab5ebe 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -1,7 +1,7 @@ # =========================================== # MiroFish Backend Dependencies # =========================================== -# Python 3.11+ required +# Python 3.11+ required. Python 3.13 is supported for the API/local-Zep backend. # Install: pip install -r requirements.txt # =========================================== @@ -13,13 +13,11 @@ flask-cors>=6.0.0 # OpenAI SDK(统一使用 OpenAI 格式调用 LLM) openai>=1.0.0 -# ============= Zep Cloud ============= -zep-cloud==3.13.0 - # ============= OASIS 社交媒体模拟 ============= -# OASIS 社交模拟框架 -camel-oasis==0.2.5 -camel-ai==0.2.78 +# camel-oasis 0.2.5 declares Python <3.12, so keep it optional on newer Python. +# The backend server and local-Zep graph work on Python 3.13 without these packages. +camel-oasis==0.2.5; python_version < "3.12" +camel-ai==0.2.78; python_version < "3.12" # ============= 文件处理 ============= PyMuPDF>=1.24.0 diff --git a/backend/run.py b/backend/run.py index 4e3b04fa96..74937ec105 100644 --- a/backend/run.py +++ b/backend/run.py @@ -40,6 +40,11 @@ def main(): host = os.environ.get('FLASK_HOST', '0.0.0.0') port = int(os.environ.get('FLASK_PORT', 5001)) debug = Config.DEBUG + + if Config.PUBLIC_BASE_URL: + print(f"Public backend URL: {Config.PUBLIC_BASE_URL}") + if Config.TAILSCALE_URL: + print(f"Tailscale URL: {Config.TAILSCALE_URL}") # 启动服务 app.run(host=host, port=port, debug=debug, threaded=True) @@ -47,4 +52,3 @@ def main(): if __name__ == '__main__': main() - diff --git a/backend/scripts/fake_openai_server.py b/backend/scripts/fake_openai_server.py new file mode 100644 index 0000000000..35b2a2277e --- /dev/null +++ b/backend/scripts/fake_openai_server.py @@ -0,0 +1,364 @@ +""" +Minimal OpenAI-compatible fake server for local graph E2E tests. + +Implements: +- POST /v1/chat/completions +- POST /v1/embeddings +- POST /v1/rerank +""" + +from __future__ import annotations + +import hashlib +import json +import threading +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer + + +FIXED_ONTOLOGY = { + "entity_types": [ + { + "name": "Journalist", + "description": "A news reporter or editor active in public discourse.", + "attributes": [{"name": "role", "type": "text", "description": "Journalism role"}], + "examples": ["Alice"], + }, + { + "name": "MediaOutlet", + "description": "A media organization publishing reports and commentary.", + "attributes": [{"name": "org_name", "type": "text", "description": "Media brand"}], + "examples": ["DailyNews"], + }, + { + "name": "Ngo", + "description": "A civil society or advocacy organization.", + "attributes": [{"name": "mission", "type": "text", "description": "Primary mission"}], + "examples": ["GreenFuture"], + }, + { + "name": "Company", + "description": "A commercial company participating in the issue.", + "attributes": [{"name": "industry", "type": "text", "description": "Industry sector"}], + "examples": ["PolluteCorp"], + }, + { + "name": "Official", + "description": "A government or public official making responses.", + "attributes": [{"name": "title", "type": "text", "description": "Official title"}], + "examples": ["Mayor Lee"], + }, + { + "name": "Citizen", + "description": "A member of the public commenting on the issue.", + "attributes": [{"name": "location", "type": "text", "description": "Location"}], + "examples": ["Concerned resident"], + }, + { + "name": "Expert", + "description": "An expert or scholar adding analysis.", + "attributes": [{"name": "specialty", "type": "text", "description": "Area of expertise"}], + "examples": ["Policy researcher"], + }, + { + "name": "CommunityGroup", + "description": "A community or activist group involved in the issue.", + "attributes": [{"name": "focus", "type": "text", "description": "Group focus"}], + "examples": ["Local advocates"], + }, + { + "name": "Person", + "description": "Any individual person not fitting other specific person types.", + "attributes": [{"name": "full_name", "type": "text", "description": "Full name"}], + "examples": ["ordinary citizen"], + }, + { + "name": "Organization", + "description": "Any organization not fitting other specific organization types.", + "attributes": [{"name": "org_name", "type": "text", "description": "Organization name"}], + "examples": ["generic organization"], + }, + ], + "edge_types": [ + { + "name": "WORKS_FOR", + "description": "Employment or affiliation relation.", + "source_targets": [{"source": "Journalist", "target": "MediaOutlet"}], + "attributes": [], + }, + { + "name": "REPORTS_ON", + "description": "Coverage relation for public reporting.", + "source_targets": [ + {"source": "MediaOutlet", "target": "Ngo"}, + {"source": "MediaOutlet", "target": "Company"}, + {"source": "Journalist", "target": "Ngo"}, + ], + "attributes": [], + }, + { + "name": "SUPPORTS", + "description": "Supportive stance relation.", + "source_targets": [ + {"source": "Journalist", "target": "Ngo"}, + {"source": "Citizen", "target": "Ngo"}, + {"source": "Official", "target": "Organization"}, + ], + "attributes": [], + }, + { + "name": "OPPOSES", + "description": "Opposing stance relation.", + "source_targets": [ + {"source": "Company", "target": "Ngo"}, + {"source": "Organization", "target": "Organization"}, + ], + "attributes": [], + }, + { + "name": "RESPONDS_TO", + "description": "A public response to another actor.", + "source_targets": [ + {"source": "Official", "target": "Journalist"}, + {"source": "Organization", "target": "MediaOutlet"}, + ], + "attributes": [], + }, + { + "name": "COLLABORATES_WITH", + "description": "Cooperation relation between actors.", + "source_targets": [ + {"source": "Organization", "target": "Organization"}, + {"source": "Ngo", "target": "CommunityGroup"}, + ], + "attributes": [], + }, + ], + "analysis_summary": "Fake ontology for local graph integration testing.", +} + + +def _embedding_for_text(text: str) -> list[float]: + digest = hashlib.sha256((text or "").encode("utf-8")).digest() + values = [] + for index in range(8): + chunk = digest[index * 4:(index + 1) * 4] + raw = int.from_bytes(chunk, "big") + values.append((raw % 1000) / 1000.0) + return values + + +def _rerank_score(query: str, document: str) -> float: + query_tokens = set((query or "").lower().replace(".", " ").replace(",", " ").split()) + document_tokens = set((document or "").lower().replace(".", " ").replace(",", " ").split()) + if not query_tokens or not document_tokens: + return 0.0 + overlap = len(query_tokens.intersection(document_tokens)) / len(query_tokens) + exact_bonus = 0.4 if (query or "").lower() in (document or "").lower() else 0.0 + return min(1.0, overlap + exact_bonus) + + +def _build_extraction_payload(text: str) -> dict: + normalized = (text or "").lower() + entities = [] + edges = [] + + if "alice" in normalized: + entities.append({ + "name": "Alice", + "type": "Journalist", + "summary": "Journalist covering the environmental dispute.", + "attributes": {"role": "journalist"}, + }) + if "dailynews" in normalized: + entities.append({ + "name": "DailyNews", + "type": "MediaOutlet", + "summary": "Media outlet reporting on the issue.", + "attributes": {"org_name": "DailyNews"}, + }) + if "greenfuture" in normalized: + entities.append({ + "name": "GreenFuture", + "type": "Ngo", + "summary": "Environmental organization active in the conflict.", + "attributes": {"mission": "environmental advocacy"}, + }) + if "pollutecorp" in normalized: + entities.append({ + "name": "PolluteCorp", + "type": "Company", + "summary": "Company opposing GreenFuture's campaign.", + "attributes": {"industry": "manufacturing"}, + }) + if "mayor lee" in normalized: + entities.append({ + "name": "Mayor Lee", + "type": "Official", + "summary": "Public official responding to media coverage.", + "attributes": {"title": "Mayor"}, + }) + + if {"alice", "dailynews"} <= set(normalized.replace(".", " ").replace(",", " ").split()): + edges.append({ + "name": "WORKS_FOR", + "source": "Alice", + "target": "DailyNews", + "fact": "Alice works for DailyNews.", + "attributes": {}, + }) + if "dailynews" in normalized and "greenfuture" in normalized: + edges.append({ + "name": "REPORTS_ON", + "source": "DailyNews", + "target": "GreenFuture", + "fact": "DailyNews reports on GreenFuture.", + "attributes": {}, + }) + if "alice supports greenfuture" in normalized or ("alice" in normalized and "supports greenfuture" in normalized): + edges.append({ + "name": "SUPPORTS", + "source": "Alice", + "target": "GreenFuture", + "fact": "Alice supports GreenFuture.", + "attributes": {}, + }) + if "alice opposes greenfuture" in normalized: + edges.append({ + "name": "OPPOSES", + "source": "Alice", + "target": "GreenFuture", + "fact": "Alice opposes GreenFuture.", + "attributes": {}, + }) + if "pollutecorp opposes greenfuture" in normalized or ("pollutecorp" in normalized and "opposes greenfuture" in normalized): + edges.append({ + "name": "OPPOSES", + "source": "PolluteCorp", + "target": "GreenFuture", + "fact": "PolluteCorp opposes GreenFuture.", + "attributes": {}, + }) + if "mayor lee responds to alice" in normalized or ("mayor lee" in normalized and "responds to alice" in normalized): + edges.append({ + "name": "RESPONDS_TO", + "source": "Mayor Lee", + "target": "Alice", + "fact": "Mayor Lee responds to Alice.", + "attributes": {}, + }) + + return {"entities": entities, "edges": edges} + + +def _chat_response(messages: list[dict]) -> str: + system = "\n".join(message.get("content", "") for message in messages if message.get("role") == "system") + user = "\n".join(message.get("content", "") for message in messages if message.get("role") == "user") + + if "本体设计专家" in system or "正好输出10个实体类型" in user: + return json.dumps(FIXED_ONTOLOGY, ensure_ascii=False) + + if "Extract graph updates from the text below" in user: + return json.dumps(_build_extraction_payload(user), ensure_ascii=False) + + return json.dumps({"message": "fake-server-default-response"}, ensure_ascii=False) + + +class FakeOpenAIHandler(BaseHTTPRequestHandler): + server_version = "FakeOpenAI/0.1" + + def do_POST(self): # noqa: N802 + length = int(self.headers.get("Content-Length", "0")) + payload = json.loads(self.rfile.read(length).decode("utf-8") or "{}") + + if self.path == "/v1/chat/completions": + response = { + "id": "chatcmpl-fake", + "object": "chat.completion", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": _chat_response(payload.get("messages", [])), + }, + "finish_reason": "stop", + } + ], + } + self._write_json(200, response) + return + + if self.path == "/v1/embeddings": + inputs = payload.get("input", []) + if not isinstance(inputs, list): + inputs = [inputs] + response = { + "object": "list", + "data": [ + { + "object": "embedding", + "index": index, + "embedding": _embedding_for_text(text), + } + for index, text in enumerate(inputs) + ], + "model": payload.get("model", "fake-embedding-model"), + } + self._write_json(200, response) + return + + if self.path == "/v1/rerank": + documents = payload.get("documents", []) + if not isinstance(documents, list): + documents = [str(documents)] + query = payload.get("query", "") + scored = [ + { + "index": index, + "relevance_score": _rerank_score(query, document), + } + for index, document in enumerate(documents) + ] + scored.sort(key=lambda item: item["relevance_score"], reverse=True) + response = { + "model": payload.get("model", "fake-reranker-model"), + "results": scored[: int(payload.get("top_n") or len(scored))], + } + self._write_json(200, response) + return + + self._write_json(404, {"error": f"Unknown path: {self.path}"}) + + def log_message(self, format, *args): # noqa: A003 + return + + def _write_json(self, status: int, payload: dict) -> None: + body = json.dumps(payload, ensure_ascii=False).encode("utf-8") + self.send_response(status) + self.send_header("Content-Type", "application/json; charset=utf-8") + self.send_header("Content-Length", str(len(body))) + self.end_headers() + self.wfile.write(body) + + +def start_server(host: str = "127.0.0.1", port: int = 18080) -> tuple[ThreadingHTTPServer, threading.Thread]: + server = ThreadingHTTPServer((host, port), FakeOpenAIHandler) + thread = threading.Thread(target=server.serve_forever, daemon=True) + thread.start() + return server, thread + + +def stop_server(server: ThreadingHTTPServer, thread: threading.Thread | None = None) -> None: + server.shutdown() + server.server_close() + if thread is not None: + thread.join(timeout=2) + + +if __name__ == "__main__": + httpd, worker = start_server() + print("fake_openai_server listening on http://127.0.0.1:18080/v1") + try: + worker.join() + except KeyboardInterrupt: + stop_server(httpd, worker) diff --git a/backend/scripts/test_local_zep_e2e.py b/backend/scripts/test_local_zep_e2e.py new file mode 100644 index 0000000000..35c7e78ab1 --- /dev/null +++ b/backend/scripts/test_local_zep_e2e.py @@ -0,0 +1,419 @@ +""" +Zero-dependency end-to-end test for the local graph replacement. + +This test: +1. Starts a fake OpenAI-compatible server +2. Injects minimal runtime shims for missing third-party packages +3. Runs ontology generation -> graph build -> entity read -> graph search +""" + +from __future__ import annotations + +import importlib.util +import json +import os +import shutil +import sys +import tempfile +import time +import types +import urllib.request +from pathlib import Path +from types import SimpleNamespace + +from fake_openai_server import start_server, stop_server + + +ROOT = Path(__file__).resolve().parents[2] +BACKEND_ROOT = ROOT / "backend" +APP_ROOT = BACKEND_ROOT / "app" + +SAMPLE_TEXT = ( + "Alice is a journalist at DailyNews. " + "DailyNews reports on GreenFuture, an environmental organization. " + "Alice supports GreenFuture. " + "PolluteCorp opposes GreenFuture. " + "Mayor Lee responds to Alice." +) + + +def _to_namespace(value): + if isinstance(value, dict): + return SimpleNamespace(**{key: _to_namespace(item) for key, item in value.items()}) + if isinstance(value, list): + return [_to_namespace(item) for item in value] + return value + + +def install_runtime_shims() -> None: + flask_module = types.ModuleType("flask") + flask_module.request = SimpleNamespace(headers={}) + flask_module.has_request_context = lambda: False + flask_module.Flask = type("Flask", (), {}) + sys.modules["flask"] = flask_module + + flask_cors_module = types.ModuleType("flask_cors") + flask_cors_module.CORS = lambda *args, **kwargs: None + sys.modules["flask_cors"] = flask_cors_module + + dotenv_module = types.ModuleType("dotenv") + dotenv_module.load_dotenv = lambda *args, **kwargs: False + sys.modules["dotenv"] = dotenv_module + + pydantic_module = types.ModuleType("pydantic") + + class FieldInfo: + def __init__(self, default=None, description=None): + self.default = default + self.description = description + + def Field(*, default=None, description=None): + return FieldInfo(default=default, description=description) + + class ConfigDict(dict): + pass + + class BaseModel: + model_fields = {} + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + annotations = getattr(cls, "__annotations__", {}) + model_fields = {} + for name in annotations: + value = getattr(cls, name, None) + if isinstance(value, FieldInfo): + model_fields[name] = value + cls.model_fields = model_fields + + pydantic_module.Field = Field + pydantic_module.BaseModel = BaseModel + pydantic_module.ConfigDict = ConfigDict + sys.modules["pydantic"] = pydantic_module + + openai_module = types.ModuleType("openai") + + class _ChatCompletions: + def __init__(self, client): + self.client = client + + def create(self, **kwargs): + return self.client._post("/chat/completions", kwargs) + + class _Embeddings: + def __init__(self, client): + self.client = client + + def create(self, **kwargs): + return self.client._post("/embeddings", kwargs) + + class OpenAI: + def __init__(self, api_key=None, base_url=None): + self.api_key = api_key or "" + self.base_url = (base_url or "").rstrip("/") + self.chat = SimpleNamespace(completions=_ChatCompletions(self)) + self.embeddings = _Embeddings(self) + + def _post(self, path: str, payload: dict): + body = json.dumps(payload).encode("utf-8") + request = urllib.request.Request( + self.base_url + path, + data=body, + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {self.api_key}", + }, + method="POST", + ) + with urllib.request.urlopen(request, timeout=10) as response: + data = json.loads(response.read().decode("utf-8")) + return _to_namespace(data) + + openai_module.OpenAI = OpenAI + sys.modules["openai"] = openai_module + + +def ensure_package(name: str, path: Path) -> None: + module = types.ModuleType(name) + module.__path__ = [str(path)] + sys.modules[name] = module + + +def load_module(name: str, file_path: Path): + spec = importlib.util.spec_from_file_location(name, file_path) + module = importlib.util.module_from_spec(spec) + sys.modules[name] = module + assert spec.loader is not None + spec.loader.exec_module(module) + return module + + +def load_backend_modules(temp_upload_root: Path): + sys.path.insert(0, str(BACKEND_ROOT)) + + ensure_package("app", APP_ROOT) + ensure_package("app.utils", APP_ROOT / "utils") + ensure_package("app.models", APP_ROOT / "models") + ensure_package("app.services", APP_ROOT / "services") + + config_module = load_module("app.config", APP_ROOT / "config.py") + config_module.Config.UPLOAD_FOLDER = str(temp_upload_root) + config_module.Config.OASIS_SIMULATION_DATA_DIR = str(temp_upload_root / "simulations") + + load_module("app.utils.logger", APP_ROOT / "utils" / "logger.py") + load_module("app.utils.locale", APP_ROOT / "utils" / "locale.py") + load_module("app.utils.file_parser", APP_ROOT / "utils" / "file_parser.py") + load_module("app.utils.llm_client", APP_ROOT / "utils" / "llm_client.py") + load_module("app.utils.zep_paging", APP_ROOT / "utils" / "zep_paging.py") + load_module("app.models.task", APP_ROOT / "models" / "task.py") + load_module("app.models.project", APP_ROOT / "models" / "project.py") + load_module("app.services.text_processor", APP_ROOT / "services" / "text_processor.py") + load_module("app.services.ontology_generator", APP_ROOT / "services" / "ontology_generator.py") + load_module("app.services.graph_builder", APP_ROOT / "services" / "graph_builder.py") + load_module("app.services.zep_entity_reader", APP_ROOT / "services" / "zep_entity_reader.py") + load_module("app.services.zep_tools", APP_ROOT / "services" / "zep_tools.py") + + return { + "config": sys.modules["app.config"], + "file_parser": sys.modules["app.utils.file_parser"], + "project": sys.modules["app.models.project"], + "task": sys.modules["app.models.task"], + "text_processor": sys.modules["app.services.text_processor"], + "ontology": sys.modules["app.services.ontology_generator"], + "graph_builder": sys.modules["app.services.graph_builder"], + "entity_reader": sys.modules["app.services.zep_entity_reader"], + "zep_tools": sys.modules["app.services.zep_tools"], + } + + +def wait_for_task(task_manager, task_id: str, timeout: int = 20): + deadline = time.time() + timeout + while time.time() < deadline: + task = task_manager.get_task(task_id) + if task and task.status.value in {"completed", "failed"}: + return task + time.sleep(0.2) + raise TimeoutError(f"Timed out waiting for task {task_id}") + + +class FakeFileStorage: + def __init__(self, source_path: Path): + self.source_path = source_path + + def save(self, destination: str) -> None: + shutil.copyfile(self.source_path, destination) + + +def run_e2e() -> dict: + with tempfile.TemporaryDirectory(prefix="mirofish-local-zep-") as temp_dir: + temp_root = Path(temp_dir) + upload_root = temp_root / "uploads" + sample_path = temp_root / "seed.txt" + db_path = temp_root / "local_zep.sqlite3" + sample_path.write_text(SAMPLE_TEXT, encoding="utf-8") + + os.environ["LLM_API_KEY"] = "fake-key" + os.environ["LLM_BASE_URL"] = "http://127.0.0.1:18080/v1" + os.environ["LLM_MODEL_NAME"] = "fake-chat-model" + os.environ["EMBEDDING_API_KEY"] = "fake-key" + os.environ["EMBEDDING_BASE_URL"] = "http://127.0.0.1:18080/v1" + os.environ["EMBEDDING_MODEL_NAME"] = "fake-embedding-model" + os.environ["RERANKER_API_KEY"] = "fake-key" + os.environ["RERANKER_BASE_URL"] = "http://127.0.0.1:18080/v1" + os.environ["RERANKER_MODEL_NAME"] = "fake-reranker-model" + os.environ["LOCAL_ZEP_DB_PATH"] = str(db_path) + + install_runtime_shims() + modules = load_backend_modules(upload_root) + + ProjectManager = modules["project"].ProjectManager + OntologyGenerator = modules["ontology"].OntologyGenerator + GraphBuilderService = modules["graph_builder"].GraphBuilderService + ZepEntityReader = modules["entity_reader"].ZepEntityReader + ZepToolsService = modules["zep_tools"].ZepToolsService + TaskManager = modules["task"].TaskManager + FileParser = modules["file_parser"].FileParser + TextProcessor = modules["text_processor"].TextProcessor + + project = ProjectManager.create_project(name="Local Zep E2E") + project.simulation_requirement = "Simulate public sentiment around the GreenFuture environmental campaign." + + file_info = ProjectManager.save_file_to_project( + project.project_id, + FakeFileStorage(sample_path), + "seed.txt", + ) + project.files.append({"filename": file_info["original_filename"], "size": file_info["size"]}) + + extracted_text = TextProcessor.preprocess_text(FileParser.extract_text(file_info["path"])) + ProjectManager.save_extracted_text(project.project_id, extracted_text) + project.total_text_length = len(extracted_text) + ProjectManager.save_project(project) + + ontology_generator = OntologyGenerator() + ontology = ontology_generator.generate( + document_texts=[extracted_text], + simulation_requirement=project.simulation_requirement, + additional_context=None, + ) + assert len(ontology["entity_types"]) == 10, ontology["entity_types"] + assert any(item["name"] == "Person" for item in ontology["entity_types"]) + assert any(item["name"] == "Organization" for item in ontology["entity_types"]) + + project.ontology = { + "entity_types": ontology["entity_types"], + "edge_types": ontology["edge_types"], + } + ProjectManager.save_project(project) + + builder = GraphBuilderService(api_key="ignored-for-local") + task_id = builder.build_graph_async( + text=extracted_text, + ontology=project.ontology, + graph_name="Local Zep E2E Graph", + chunk_size=400, + chunk_overlap=40, + batch_size=2, + ) + task = wait_for_task(TaskManager(), task_id) + assert task.status.value == "completed", task.error + + graph_id = task.result["graph_id"] + graph_data = builder.get_graph_data(graph_id) + assert graph_data["node_count"] >= 5, graph_data + assert graph_data["edge_count"] >= 5, graph_data + + reader = ZepEntityReader() + filtered = reader.filter_defined_entities(graph_id=graph_id, enrich_with_edges=True) + assert filtered.filtered_count >= 5, filtered.to_dict() + + tools = ZepToolsService(api_key="ignored-for-local") + search = tools.search_graph(graph_id=graph_id, query="Alice supports GreenFuture", limit=5, scope="edges") + assert any("Alice supports GreenFuture" in fact for fact in search.facts), search.to_dict() + cross_encoder_results = tools.client.graph.search( + graph_id=graph_id, + query="Alice supports GreenFuture", + limit=3, + scope="edges", + reranker="cross_encoder", + ) + assert cross_encoder_results.edges, cross_encoder_results + assert cross_encoder_results.edges[0].score is not None + assert cross_encoder_results.edges[0].relevance is not None + + rrf_results = tools.client.graph.search( + graph_id=graph_id, + query="GreenFuture", + limit=3, + scope="both", + reranker="rrf", + ) + assert rrf_results.edges or rrf_results.nodes, rrf_results + + mmr_results = tools.client.graph.search( + graph_id=graph_id, + query="GreenFuture", + limit=3, + scope="edges", + reranker="mmr", + mmr_lambda=0.5, + ) + assert mmr_results.edges, mmr_results + + episode_results = tools.client.graph.search( + graph_id=graph_id, + query="Mayor Lee", + limit=2, + scope="episodes", + reranker="rrf", + ) + assert episode_results.episodes, episode_results + + temporal_episode = tools.client.graph.add( + graph_id=graph_id, + data="Alice opposes GreenFuture.", + type="text", + created_at="2025-01-01T00:00:00Z", + metadata={"source": "temporal_test"}, + source_description="Temporal update test", + ) + assert temporal_episode.metadata["source"] == "temporal_test" + + temporal_edges = tools.client.graph.edge.get_by_graph_id(graph_id=graph_id, limit=20) + old_supports = [ + edge + for edge in temporal_edges + if edge.name == "SUPPORTS" and edge.source_node_uuid and "Alice supports GreenFuture" in edge.fact + ] + new_opposes = [ + edge + for edge in temporal_edges + if edge.name == "OPPOSES" and "Alice opposes GreenFuture" in edge.fact + ] + assert old_supports and old_supports[0].invalid_at is not None, [edge.fact for edge in temporal_edges] + assert new_opposes and new_opposes[0].valid_at == "2025-01-01T00:00:00Z", new_opposes + + active_results = tools.client.graph.search( + graph_id=graph_id, + query="Alice GreenFuture", + limit=10, + scope="edges", + reranker="rrf", + search_filters={"invalid_at": [[{"comparison_operator": "IS NULL"}]]}, + ) + assert any("Alice opposes GreenFuture" in edge.fact for edge in active_results.edges) + assert not any("Alice supports GreenFuture" in edge.fact for edge in active_results.edges) + + metadata_results = tools.client.graph.search( + graph_id=graph_id, + query="Alice opposes GreenFuture", + limit=5, + scope="episodes", + search_filters={ + "episode_metadata_filters": { + "type": "and", + "filters": [ + { + "property_name": "source", + "comparison_operator": "=", + "property_value": "temporal_test", + } + ], + } + }, + ) + assert metadata_results.episodes and metadata_results.episodes[0].uuid_ == temporal_episode.uuid_ + + stats = tools.get_graph_statistics(graph_id) + assert stats["total_nodes"] >= 5 + assert stats["total_edges"] >= 5 + + summary = tools.get_entity_summary(graph_id=graph_id, entity_name="Alice") + assert summary["entity_info"] is not None, summary + assert summary["total_relations"] >= 2, summary + + return { + "project_id": project.project_id, + "graph_id": graph_id, + "node_count": graph_data["node_count"], + "edge_count": graph_data["edge_count"], + "entity_count": filtered.filtered_count, + "search_facts": search.facts[:3], + "cross_encoder_relevance": cross_encoder_results.edges[0].relevance, + "episode_hits": len(episode_results.episodes), + "alice_relations": summary["total_relations"], + } + + +def main() -> int: + server, thread = start_server() + try: + result = run_e2e() + print(json.dumps({"status": "ok", "result": result}, ensure_ascii=False, indent=2)) + return 0 + finally: + stop_server(server, thread) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/backend/zep_cloud/__init__.py b/backend/zep_cloud/__init__.py new file mode 100644 index 0000000000..64a13ee43b --- /dev/null +++ b/backend/zep_cloud/__init__.py @@ -0,0 +1,7 @@ +from local_zep import EntityEdgeSourceTarget, EpisodeData, InternalServerError + +__all__ = [ + "EntityEdgeSourceTarget", + "EpisodeData", + "InternalServerError", +] diff --git a/backend/zep_cloud/client.py b/backend/zep_cloud/client.py new file mode 100644 index 0000000000..46e5154535 --- /dev/null +++ b/backend/zep_cloud/client.py @@ -0,0 +1,3 @@ +from local_zep.client import Zep + +__all__ = ["Zep"] diff --git a/backend/zep_cloud/external_clients/__init__.py b/backend/zep_cloud/external_clients/__init__.py new file mode 100644 index 0000000000..ee8e0a0362 --- /dev/null +++ b/backend/zep_cloud/external_clients/__init__.py @@ -0,0 +1,3 @@ +from .ontology import EdgeModel, EntityModel, EntityText + +__all__ = ["EdgeModel", "EntityModel", "EntityText"] diff --git a/backend/zep_cloud/external_clients/ontology.py b/backend/zep_cloud/external_clients/ontology.py new file mode 100644 index 0000000000..b92e1c9e9e --- /dev/null +++ b/backend/zep_cloud/external_clients/ontology.py @@ -0,0 +1,11 @@ +from pydantic import BaseModel, ConfigDict + +EntityText = str + + +class EntityModel(BaseModel): + model_config = ConfigDict(extra="allow") + + +class EdgeModel(BaseModel): + model_config = ConfigDict(extra="allow") diff --git a/frontend/src/api/index.js b/frontend/src/api/index.js index e840e1166a..14ebb66ef2 100644 --- a/frontend/src/api/index.js +++ b/frontend/src/api/index.js @@ -1,9 +1,20 @@ import axios from 'axios' import i18n from '../i18n' +function resolveApiBaseURL() { + const configured = (import.meta.env.VITE_API_BASE_URL || '').trim() + if (configured) { + return configured + } + + // API modules already call /api/... paths. Keeping baseURL empty makes + // remote browsers use the Vite same-origin proxy instead of their localhost. + return '' +} + // 创建axios实例 const service = axios.create({ - baseURL: import.meta.env.VITE_API_BASE_URL || 'http://localhost:5001', + baseURL: resolveApiBaseURL(), timeout: 300000, // 5分钟超时(本体生成可能需要较长时间) headers: { 'Content-Type': 'application/json' diff --git a/frontend/src/views/Process.vue b/frontend/src/views/Process.vue index 2d2d3cc1ac..068ae89479 100644 --- a/frontend/src/views/Process.vue +++ b/frontend/src/views/Process.vue @@ -317,7 +317,7 @@
接口说明
- 基于生成的本体,将文档分块后调用 Zep API 构建知识图谱,提取实体和关系 + 基于生成的本体,将文档分块后调用本地图谱引擎构建知识图谱,提取实体和关系
@@ -2065,4 +2065,4 @@ onUnmounted(() => { display: none; } } - \ No newline at end of file + diff --git a/frontend/vite.config.js b/frontend/vite.config.js index 8f1e4c11b5..7ab864214c 100644 --- a/frontend/vite.config.js +++ b/frontend/vite.config.js @@ -1,24 +1,58 @@ -import { defineConfig } from 'vite' +import { defineConfig, loadEnv } from 'vite' import vue from '@vitejs/plugin-vue' import path from 'path' +function parseAllowedHosts(value) { + if (!value) { + return ['localhost', '127.0.0.1', '.ts.net', '.beta.tailscale.net'] + } + + return value + .split(',') + .map(item => item.trim()) + .filter(Boolean) +} + +function buildHmrConfig(env) { + const hmr = {} + + if (env.VITE_HMR_HOST) hmr.host = env.VITE_HMR_HOST + if (env.VITE_HMR_PROTOCOL) hmr.protocol = env.VITE_HMR_PROTOCOL + if (env.VITE_HMR_PORT) hmr.port = Number(env.VITE_HMR_PORT) + if (env.VITE_HMR_CLIENT_PORT) hmr.clientPort = Number(env.VITE_HMR_CLIENT_PORT) + if (env.VITE_HMR_PATH) hmr.path = env.VITE_HMR_PATH + + return Object.keys(hmr).length ? hmr : undefined +} + // https://vite.dev/config/ -export default defineConfig({ - plugins: [vue()], - resolve: { - alias: { - '@': path.resolve(__dirname, 'src'), - '@locales': path.resolve(__dirname, '../locales') - } - }, - server: { - port: 3000, - open: true, - proxy: { - '/api': { - target: 'http://localhost:5001', - changeOrigin: true, - secure: false +export default defineConfig(({ mode }) => { + const repoRoot = path.resolve(__dirname, '..') + const rootEnv = loadEnv(mode, repoRoot, '') + const frontendEnv = loadEnv(mode, __dirname, '') + const env = { ...rootEnv, ...frontendEnv, ...process.env } + + return { + plugins: [vue()], + resolve: { + alias: { + '@': path.resolve(__dirname, 'src'), + '@locales': path.resolve(__dirname, '../locales') + } + }, + server: { + host: env.VITE_DEV_HOST || '0.0.0.0', + port: Number(env.VITE_DEV_PORT || 3000), + strictPort: true, + open: env.VITE_OPEN_BROWSER ? env.VITE_OPEN_BROWSER === 'true' : true, + allowedHosts: parseAllowedHosts(env.VITE_ALLOWED_HOSTS), + hmr: buildHmrConfig(env), + proxy: { + '/api': { + target: env.VITE_DEV_PROXY_TARGET || 'http://127.0.0.1:5001', + changeOrigin: true, + secure: false + } } } } diff --git a/locales/en.json b/locales/en.json index 544c68b1f6..419b1204e1 100644 --- a/locales/en.json +++ b/locales/en.json @@ -85,7 +85,7 @@ "ontologyDesc": "LLM analyzes document content and simulation requirements, extracts reality seeds, and auto-generates a suitable ontology structure", "analyzingDocs": "Analyzing documents...", "graphRagBuild": "GraphRAG Build", - "graphRagDesc": "Based on the generated ontology, documents are auto-chunked and sent to Zep to build a knowledge graph, extracting entities and relations, forming temporal memory and community summaries", + "graphRagDesc": "Based on the generated ontology, documents are auto-chunked and sent to the local graph engine to build a knowledge graph, extracting entities and relations, forming temporal memory and community summaries", "entityNodes": "Entity Nodes", "relationEdges": "Relation Edges", "schemaTypes": "Schema Types", @@ -328,7 +328,7 @@ "noDocProcessed": "No documents were processed successfully. Please check file formats.", "requireProjectId": "Please provide project_id", "configError": "Configuration error: {details}", - "zepApiKeyMissing": "ZEP_API_KEY not configured", + "zepApiKeyMissing": "Local graph embeddings are not configured", "ontologyNotGenerated": "Ontology not yet generated. Please call /ontology/generate first.", "graphBuilding": "Graph build in progress. Do not resubmit. To force rebuild, add force: true.", "textNotFound": "Extracted text content not found", @@ -393,10 +393,10 @@ "progress": { "initGraphService": "Initializing graph build service...", "textChunking": "Chunking text...", - "creatingZepGraph": "Creating Zep graph...", + "creatingZepGraph": "Creating local graph...", "settingOntology": "Setting ontology definition...", "addingChunks": "Adding {count} text chunks...", - "waitingZepProcess": "Waiting for Zep to process data...", + "waitingZepProcess": "Waiting for the local graph engine to process data...", "fetchingGraphData": "Fetching graph data...", "graphBuildComplete": "Graph build complete", "buildFailed": "Build failed: {error}", @@ -410,12 +410,12 @@ "noEpisodesWait": "No episodes to wait for", "waitingEpisodes": "Waiting for {count} text chunks to process...", "episodesTimeout": "Some chunks timed out, {completed}/{total} completed", - "zepProcessing": "Zep processing... {completed}/{total} done, {pending} pending ({elapsed}s)", + "zepProcessing": "Local graph processing... {completed}/{total} done, {pending} pending ({elapsed}s)", "processingComplete": "Processing complete: {completed}/{total}", "taskComplete": "Task complete", "taskFailed": "Task failed", "startPreparingEnv": "Preparing simulation environment...", - "connectingZepGraph": "Connecting to Zep graph...", + "connectingZepGraph": "Connecting to local graph...", "readingNodeData": "Reading node data...", "readingComplete": "Done, {count} entities found", "startGenerating": "Starting generation...", @@ -492,7 +492,7 @@ "detectedExistingPrep": "Detected existing preparation, using it directly", "prepareTaskStarted": "Preparation task started", "prepareTaskId": " └─ Task ID: {taskId}", - "zepEntitiesFound": "Found {count} entities from Zep graph", + "zepEntitiesFound": "Found {count} entities from the local graph", "entityTypes": " └─ Entity types: {types}", "startPollingProgress": "Polling preparation progress...", "prepareFailed": "Preparation failed: {error}", @@ -610,13 +610,13 @@ "redirectToInsightForge": "get_simulation_context redirected to insight_forge" }, "console": { - "zepToolsInitialized": "ZepToolsService initialized", - "zepRetryAttempt": "Zep {operation} attempt {attempt} failed: {error}, retrying in {delay}s...", - "zepAllRetriesFailed": "Zep {operation} failed after {retries} attempts: {error}", + "zepToolsInitialized": "Local graph tools initialized", + "zepRetryAttempt": "Local graph {operation} attempt {attempt} failed: {error}, retrying in {delay}s...", + "zepAllRetriesFailed": "Local graph {operation} failed after {retries} attempts: {error}", "graphSearch": "Graph search: graph_id={graphId}, query={query}...", "graphSearchOp": "Graph search (graph={graphId})", "searchComplete": "Search complete: found {count} relevant facts", - "zepSearchApiFallback": "Zep Search API failed, falling back to local search: {error}", + "zepSearchApiFallback": "Semantic search failed, falling back to local keyword search: {error}", "usingLocalSearch": "Using local search: query={query}...", "localSearchComplete": "Local search complete: found {count} relevant facts", "localSearchFailed": "Local search failed: {error}", diff --git a/locales/zh.json b/locales/zh.json index cd747e2fa7..d99d478413 100644 --- a/locales/zh.json +++ b/locales/zh.json @@ -85,7 +85,7 @@ "ontologyDesc": "LLM分析文档内容与模拟需求,提取出现实种子,自动生成合适的本体结构", "analyzingDocs": "正在分析文档...", "graphRagBuild": "GraphRAG构建", - "graphRagDesc": "基于生成的本体,将文档自动分块后调用 Zep 构建知识图谱,提取实体和关系,并形成时序记忆与社区摘要", + "graphRagDesc": "基于生成的本体,将文档自动分块后送入本地图谱引擎构建知识图谱,提取实体和关系,并形成时序记忆与社区摘要", "entityNodes": "实体节点", "relationEdges": "关系边", "schemaTypes": "SCHEMA类型", @@ -328,7 +328,7 @@ "noDocProcessed": "没有成功处理任何文档,请检查文件格式", "requireProjectId": "请提供 project_id", "configError": "配置错误: {details}", - "zepApiKeyMissing": "ZEP_API_KEY未配置", + "zepApiKeyMissing": "本地图谱 Embeddings 配置未完成", "ontologyNotGenerated": "项目尚未生成本体,请先调用 /ontology/generate", "graphBuilding": "图谱正在构建中,请勿重复提交。如需强制重建,请添加 force: true", "textNotFound": "未找到提取的文本内容", @@ -393,10 +393,10 @@ "progress": { "initGraphService": "初始化图谱构建服务...", "textChunking": "文本分块中...", - "creatingZepGraph": "创建Zep图谱...", + "creatingZepGraph": "创建本地图谱...", "settingOntology": "设置本体定义...", "addingChunks": "开始添加 {count} 个文本块...", - "waitingZepProcess": "等待Zep处理数据...", + "waitingZepProcess": "等待本地图谱引擎处理数据...", "fetchingGraphData": "获取图谱数据...", "graphBuildComplete": "图谱构建完成", "buildFailed": "构建失败: {error}", @@ -410,12 +410,12 @@ "noEpisodesWait": "无需等待(没有 episode)", "waitingEpisodes": "开始等待 {count} 个文本块处理...", "episodesTimeout": "部分文本块超时,已完成 {completed}/{total}", - "zepProcessing": "Zep处理中... {completed}/{total} 完成, {pending} 待处理 ({elapsed}秒)", + "zepProcessing": "本地图谱处理中... {completed}/{total} 完成, {pending} 待处理 ({elapsed}秒)", "processingComplete": "处理完成: {completed}/{total}", "taskComplete": "任务完成", "taskFailed": "任务失败", "startPreparingEnv": "开始准备模拟环境...", - "connectingZepGraph": "正在连接Zep图谱...", + "connectingZepGraph": "正在连接本地图谱...", "readingNodeData": "正在读取节点数据...", "readingComplete": "完成,共 {count} 个实体", "startGenerating": "开始生成...", @@ -492,7 +492,7 @@ "detectedExistingPrep": "检测到已有完成的准备工作,直接使用", "prepareTaskStarted": "准备任务已启动", "prepareTaskId": " └─ Task ID: {taskId}", - "zepEntitiesFound": "从Zep图谱读取到 {count} 个实体", + "zepEntitiesFound": "从本地图谱读取到 {count} 个实体", "entityTypes": " └─ 实体类型: {types}", "startPollingProgress": "开始轮询准备进度...", "prepareFailed": "准备失败: {error}", @@ -610,13 +610,13 @@ "redirectToInsightForge": "get_simulation_context 已重定向到 insight_forge" }, "console": { - "zepToolsInitialized": "ZepToolsService 初始化完成", - "zepRetryAttempt": "Zep {operation} 第 {attempt} 次尝试失败: {error}, {delay}秒后重试...", - "zepAllRetriesFailed": "Zep {operation} 在 {retries} 次尝试后仍失败: {error}", + "zepToolsInitialized": "本地图谱工具初始化完成", + "zepRetryAttempt": "本地图谱 {operation} 第 {attempt} 次尝试失败: {error}, {delay}秒后重试...", + "zepAllRetriesFailed": "本地图谱 {operation} 在 {retries} 次尝试后仍失败: {error}", "graphSearch": "图谱搜索: graph_id={graphId}, query={query}...", "graphSearchOp": "图谱搜索(graph={graphId})", "searchComplete": "搜索完成: 找到 {count} 条相关事实", - "zepSearchApiFallback": "Zep Search API失败,降级为本地搜索: {error}", + "zepSearchApiFallback": "语义检索失败,降级为本地关键词搜索: {error}", "usingLocalSearch": "使用本地搜索: query={query}...", "localSearchComplete": "本地搜索完成: 找到 {count} 条相关事实", "localSearchFailed": "本地搜索失败: {error}", diff --git a/package.json b/package.json index 63ace21a99..489c731336 100644 --- a/package.json +++ b/package.json @@ -4,10 +4,10 @@ "description": "MiroFish - 简洁通用的群体智能引擎,预测万物", "scripts": { "setup": "npm install && cd frontend && npm install", - "setup:backend": "cd backend && uv sync", + "setup:backend": "python3 -m venv venv && ./venv/bin/python -m pip install --upgrade pip && ./venv/bin/python -m pip install -r backend/requirements.txt", "setup:all": "npm run setup && npm run setup:backend", "dev": "concurrently --kill-others -n \"backend,frontend\" -c \"green,cyan\" \"npm run backend\" \"npm run frontend\"", - "backend": "cd backend && uv run python run.py", + "backend": "./venv/bin/python backend/run.py", "frontend": "cd frontend && npm run dev", "build": "cd frontend && npm run build" },