Skip to content

Commit fec3cf2

Browse files
author
赵明俊
committed
feat(memory): 新增 /api/memory REST 端点并注册路由
提供 remember/recall/context/vector/stats/clear 能力,便于外部管理与检索记忆。 Made-with: Cursor
1 parent c4d9513 commit fec3cf2

2 files changed

Lines changed: 168 additions & 4 deletions

File tree

router/main.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
from router.network import router as network_router
1919
from router.database import router as database_router
2020
from router.tools import router as tools_router
21+
from router.memory import router as memory_router
2122
from router.dependencies import get_db_manager
23+
from utils.error_mapper import map_exception_to_client, redact_sensitive_text
2224
from utils.logger import logger
2325
from utils.log_context import log_context
2426

@@ -29,7 +31,7 @@ def create_app() -> FastAPI:
2931
application = FastAPI(
3032
title="Hackbot API",
3133
description="Hackbot AI 安全测试机器人 — REST + SSE 接口",
32-
version="1.0.0",
34+
version="2.0.0b1",
3335
docs_url="/docs",
3436
redoc_url="/redoc",
3537
)
@@ -60,15 +62,17 @@ async def request_logging_middleware(request: Request, call_next):
6062
).info(f"{request.method} {request.url.path} -> {response.status_code}")
6163
response.headers["X-Request-Id"] = request_id
6264
return response
63-
except Exception:
65+
except Exception as exc:
6466
duration_ms = int((time.perf_counter() - start) * 1000)
6567
logger.bind(
6668
event="system_error", duration_ms=duration_ms
6769
).exception(f"{request.method} {request.url.path} unhandled exception")
70+
mapped = map_exception_to_client(exc)
6871
return JSONResponse(
69-
status_code=500,
72+
status_code=mapped.status_code,
7073
content={
71-
"detail": "内部服务异常,请查看日志",
74+
"detail": mapped.message,
75+
"code": mapped.code.value,
7276
"request_id": request_id,
7377
},
7478
headers={"X-Request-Id": request_id},
@@ -85,6 +89,7 @@ async def request_logging_middleware(request: Request, call_next):
8589
application.include_router(network_router)
8690
application.include_router(database_router)
8791
application.include_router(tools_router)
92+
application.include_router(memory_router)
8893

8994
# ------------------------------------------------------------------
9095
# 启动时初始化数据库(确保 secbot.db 与表在首次请求前就存在)

router/memory.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
"""
2+
Memory REST API — 与 npm-release 的 MemoryController 对齐
3+
暴露 /api/memory 系列端点,对接 MemoryManager 和 VectorStoreManager
4+
"""
5+
6+
from typing import Any, Dict, List, Optional
7+
8+
from fastapi import APIRouter, HTTPException
9+
from pydantic import BaseModel
10+
11+
from router.dependencies import get_memory_manager, get_vector_store_manager
12+
from secbot_agent.core.context_assembler import text_to_vector, VECTOR_DIMENSION
13+
14+
router = APIRouter(prefix="/api/memory", tags=["Memory"])
15+
16+
17+
class RememberRequest(BaseModel):
18+
content: str
19+
memory_type: str = "short_term"
20+
importance: float = 0.5
21+
metadata: Dict[str, Any] = {}
22+
23+
24+
class RecallRequest(BaseModel):
25+
query: str = ""
26+
memory_type: Optional[str] = None
27+
limit: int = 10
28+
29+
30+
class VectorAddRequest(BaseModel):
31+
content: str
32+
collection: str = "episodic"
33+
metadata: Dict[str, Any] = {}
34+
35+
36+
class VectorSearchRequest(BaseModel):
37+
query: str
38+
collection: str = "episodic"
39+
limit: int = 10
40+
41+
42+
@router.post("/remember", summary="添加记忆")
43+
async def remember(body: RememberRequest):
44+
mgr = get_memory_manager()
45+
await mgr.remember(
46+
body.content, body.memory_type, body.importance, **body.metadata
47+
)
48+
return {"ok": True}
49+
50+
51+
@router.post("/recall", summary="召回记忆")
52+
async def recall(body: RecallRequest):
53+
mgr = get_memory_manager()
54+
items = await mgr.recall(body.query, body.memory_type, body.limit)
55+
return {
56+
"items": [
57+
{
58+
"id": item.id,
59+
"content": item.content,
60+
"type": item.type,
61+
"importance": item.importance,
62+
"created_at": item.created_at,
63+
"metadata": item.metadata,
64+
}
65+
for item in items
66+
]
67+
}
68+
69+
70+
@router.get("/context", summary="获取 agent 上下文记忆")
71+
async def context(query: str = ""):
72+
mgr = get_memory_manager()
73+
ctx = await mgr.get_context_for_agent(query)
74+
return {"context": ctx}
75+
76+
77+
@router.get("/stats", summary="获取记忆统计")
78+
async def stats():
79+
mgr = get_memory_manager()
80+
mem_stats = mgr.get_stats()
81+
vsm = get_vector_store_manager()
82+
vec_stats = vsm.get_stats()
83+
return {"memory": mem_stats, "vector": vec_stats}
84+
85+
86+
@router.post("/clear", summary="清空所有记忆")
87+
async def clear():
88+
mgr = get_memory_manager()
89+
await mgr.clear_all()
90+
return {"ok": True}
91+
92+
93+
@router.get("/list", summary="列出记忆")
94+
async def list_memories(
95+
memory_type: Optional[str] = None,
96+
limit: int = 20,
97+
):
98+
mgr = get_memory_manager()
99+
if memory_type == "short_term":
100+
items = await mgr.short_term.get(limit)
101+
elif memory_type == "episodic":
102+
items = await mgr.episodic.get(limit)
103+
elif memory_type == "long_term":
104+
items = await mgr.long_term.get(limit)
105+
else:
106+
st = await mgr.short_term.get(limit)
107+
ep = await mgr.episodic.get(limit)
108+
lt = await mgr.long_term.get(limit)
109+
items = st + ep + lt
110+
return {
111+
"items": [
112+
{
113+
"id": item.id,
114+
"content": item.content,
115+
"type": item.type,
116+
"importance": item.importance,
117+
"created_at": item.created_at,
118+
}
119+
for item in items[:limit]
120+
]
121+
}
122+
123+
124+
@router.post("/vector/add", summary="添加向量记忆")
125+
async def vector_add(body: VectorAddRequest):
126+
vsm = get_vector_store_manager()
127+
vec = text_to_vector(body.content)
128+
item_id = await vsm.add_memory(
129+
content=body.content,
130+
vector=vec,
131+
memory_type=body.collection,
132+
metadata=body.metadata,
133+
)
134+
return {"id": item_id}
135+
136+
137+
@router.post("/vector/search", summary="向量搜索")
138+
async def vector_search(body: VectorSearchRequest):
139+
vsm = get_vector_store_manager()
140+
vec = text_to_vector(body.query)
141+
store = vsm.get_store(body.collection, VECTOR_DIMENSION)
142+
results = store.search(vec, limit=body.limit, collection=body.collection, threshold=0.3)
143+
return {
144+
"results": [
145+
{
146+
"id": item.id,
147+
"content": item.content,
148+
"similarity": round(sim, 4),
149+
"metadata": item.metadata,
150+
}
151+
for item, sim in results
152+
]
153+
}
154+
155+
156+
@router.get("/vector/stats", summary="向量存储统计")
157+
async def vector_stats():
158+
vsm = get_vector_store_manager()
159+
return vsm.get_stats()

0 commit comments

Comments
 (0)