Skip to content

Commit d325329

Browse files
authored
feat: async memorize via Temporal + task status endpoint (#23)
1 parent 0002c21 commit d325329

6 files changed

Lines changed: 799 additions & 30 deletions

File tree

app/main.py

Lines changed: 138 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,33 @@
11
"""memU Server - FastAPI application entry point."""
22

3+
import asyncio
34
import json
45
import logging
6+
import re
57
import uuid
68
from collections.abc import AsyncIterator
79
from contextlib import asynccontextmanager
810
from pathlib import Path
9-
from typing import Any
11+
from typing import Any, cast
1012

1113
from fastapi import FastAPI, HTTPException, Request
1214
from fastapi.responses import JSONResponse
15+
from temporalio.client import Client
16+
from temporalio.service import RPCError, RPCStatusCode
1317

1418
from app.schemas.memory import (
1519
CategoryObject,
1620
ClearMemoriesRequest,
1721
ClearMemoriesResponse,
1822
ListCategoriesRequest,
1923
ListCategoriesResponse,
24+
MemorizeRequest,
25+
MemorizeResponse,
26+
TaskStatusResponse,
2027
)
2128
from app.services.memu import create_memory_service
29+
from app.workers.memorize_workflow import MemorizeWorkflow
30+
from app.workers.worker import TASK_QUEUE
2231
from config.settings import Settings
2332

2433
logger = logging.getLogger(__name__)
@@ -38,14 +47,37 @@
3847
storage_dir = Path(settings.STORAGE_PATH)
3948

4049

50+
async def _get_temporal_client(app: FastAPI) -> Client:
    """Return the Temporal client cached on ``app.state``, connecting on first use.

    Any non-None ``app.state.temporal`` is treated as the client, which lets
    tests or dependency injection pre-populate it. Otherwise a connection is
    opened to ``settings.temporal_url`` in ``settings.TEMPORAL_NAMESPACE`` and
    stored on ``app.state.temporal`` for later calls.
    """
    state = app.state
    cached = getattr(state, "temporal", None)
    if cached is None:
        # The lock lives on app.state and is created here, inside a coroutine,
        # rather than at import time.
        guard = getattr(state, "_temporal_lock", None) or asyncio.Lock()
        state._temporal_lock = guard
        async with guard:
            # Another coroutine may have connected while we waited for the lock.
            cached = getattr(state, "temporal", None)
            if cached is None:
                cached = await Client.connect(
                    settings.temporal_url,
                    namespace=settings.TEMPORAL_NAMESPACE,
                )
                state.temporal = cached
                logger.info("Connected to Temporal at %s", settings.temporal_url)
    return cast(Client, cached)
72+
73+
4174
@asynccontextmanager
4275
async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
43-
"""Initialise MemoryService on startup (defers DB connection until the app runs)."""
76+
"""Initialise MemoryService on startup. Temporal connects lazily on first use."""
4477
try:
4578
storage_dir.mkdir(parents=True, exist_ok=True)
4679
_app.state.service = create_memory_service(settings)
4780
except Exception as exc:
48-
# Log full traceback for operators and wrap in a clearer startup error
4981
msg = "Failed to initialize MemoryService during application startup"
5082
logger.exception(msg)
5183
raise RuntimeError(msg) from exc
@@ -56,27 +88,121 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
5688

5789

5890
@app.post("/memorize")
async def memorize(request: Request, body: MemorizeRequest):
    """Submit an async memorization task via a Temporal workflow.

    The conversation is written to ``storage_dir`` as
    ``conversation-<task_id>.json``, a ``MemorizeWorkflow`` with ID
    ``memorize-<task_id>`` is started on ``TASK_QUEUE``, and the workflow ID is
    returned to the caller with status ``PENDING``.

    Args:
        request: Incoming request; ``request.app`` is used to obtain the Temporal client.
        body: Validated memorize payload.

    Returns:
        A JSON response wrapping a ``MemorizeResponse``.

    Raises:
        HTTPException: 500 if any step of the submission fails.
    """
    file_path: Path | None = None
    workflow_started = False
    try:
        # 1. Save conversation to local storage (offload sync I/O to threadpool)
        task_id = uuid.uuid4().hex
        file_path = storage_dir / f"conversation-{task_id}.json"
        data = json.dumps(body.conversation, ensure_ascii=False)
        await asyncio.to_thread(file_path.write_text, data, "utf-8")

        # 2. Build workflow spec
        # Pass the filename only; the worker reconstructs the full path
        # from its own STORAGE_PATH, so it works across containers/hosts.
        spec = {
            "task_id": task_id,
            "resource_url": file_path.name,
            "user_id": body.user_id,
            "agent_id": body.agent_id,
            "override_config": body.override_config,
        }

        # 3. Start Temporal workflow
        temporal = await _get_temporal_client(request.app)
        workflow_id = f"memorize-{task_id}"

        await temporal.start_workflow(
            MemorizeWorkflow.run,
            spec,
            id=workflow_id,
            task_queue=TASK_QUEUE,
        )
        workflow_started = True

        logger.info("Memorize workflow started: %s", workflow_id)

        result = MemorizeResponse(
            task_id=workflow_id,
            status="PENDING",
            message=f"Memorization task submitted for user {body.user_id}",
        )
        return JSONResponse(content={"status": "success", "result": result.model_dump()})
    except Exception as exc:
        # Only clean up the conversation file if the workflow has NOT started,
        # because a running workflow still needs its input file.
        if not workflow_started and file_path is not None:
            try:
                # missing_ok=True already covers "file was never created", so no
                # separate exists() check is needed; the unlink is offloaded to a
                # thread like the write above to keep the event loop unblocked.
                await asyncio.to_thread(file_path.unlink, missing_ok=True)
            except Exception:
                # Best-effort cleanup: never mask the original failure.
                logger.warning(
                    "Failed to clean up conversation file %s during error handling",
                    file_path,
                    exc_info=True,
                )
        logger.exception("Failed to submit memorize task")
        raise HTTPException(status_code=500, detail="Failed to submit memorization task") from exc
147+
148+
149+
# Regex for valid memorize workflow IDs: memorize-<32 lowercase hex chars>.
# Anchored with \Z rather than $, because $ also matches just before a trailing
# newline and would let "memorize-<hex>\n" through.
_MEMORIZE_WORKFLOW_ID_RE = re.compile(r"^memorize-[0-9a-f]{32}\Z")
151+
152+
153+
@app.get("/memorize/status/{task_id}")
async def get_memorize_status(request: Request, task_id: str):
    """Get the status of a memorization task.

    Args:
        request: Incoming request; ``request.app`` is used to obtain the Temporal client.
        task_id: Workflow ID as returned by ``POST /memorize`` (``memorize-<32 hex chars>``).

    Returns:
        A JSON response wrapping a ``TaskStatusResponse``.

    Raises:
        HTTPException: 422 if ``task_id`` is malformed, 404 if Temporal reports
            the workflow as not found, 500 for any other failure.
    """
    # fullmatch() instead of match(): a "$" anchor also matches just before a
    # trailing newline, so match() would accept "memorize-<hex>\n".
    if not _MEMORIZE_WORKFLOW_ID_RE.fullmatch(task_id):
        raise HTTPException(
            status_code=422,
            detail="task_id must match the format 'memorize-<uuid4hex>' (e.g. memorize-abc123def456...)",
        )
    try:
        temporal = await _get_temporal_client(request.app)
        handle = temporal.get_workflow_handle(task_id)

        describe = await handle.describe()
        status = describe.status.name if describe.status else "UNKNOWN"

        detail = None
        if status == "COMPLETED":
            result = await handle.result()
            if isinstance(result, dict):
                # The response field is ``str | None``; coerce so a non-string
                # status in the workflow result cannot fail response validation.
                outcome = result.get("status", "SUCCESS")
                detail = None if outcome is None else str(outcome)
            elif result is not None:
                detail = str(result)
            else:
                detail = "SUCCESS"
        elif status == "FAILED":
            detail = "Task execution failed"

        task_status = TaskStatusResponse(
            task_id=task_id,
            status=status,
            detail=detail,
        )
        return JSONResponse(content={"status": "success", "result": task_status.model_dump()})
    except RPCError as exc:
        if exc.status == RPCStatusCode.NOT_FOUND:
            raise HTTPException(status_code=404, detail=f"Task {task_id} not found") from exc
        logger.exception("Temporal RPC error for task %s", task_id)
        raise HTTPException(status_code=500, detail="Internal server error") from exc
    except Exception as exc:
        logger.exception("Failed to get task status for %s", task_id)
        raise HTTPException(status_code=500, detail="Internal server error") from exc
71194

72195

73196
@app.post("/retrieve")
74197
async def retrieve(request: Request, payload: dict[str, Any]):
75198
if "query" not in payload:
76199
raise HTTPException(status_code=400, detail="Missing 'query' in request body")
200+
query = payload["query"]
201+
if not isinstance(query, str) or not query.strip():
202+
raise HTTPException(status_code=400, detail="'query' must be a non-empty string")
77203
try:
78204
service = request.app.state.service
79-
result = await service.retrieve([payload["query"]])
205+
result = await service.retrieve([query.strip()])
80206
return JSONResponse(content={"status": "success", "result": result})
81207
except Exception as exc:
82208
logger.exception("Retrieve request failed")

app/schemas/memory.py

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,47 @@
11
"""Request/response schemas for memory endpoints."""
22

3-
from pydantic import BaseModel, Field, model_validator
3+
from pydantic import BaseModel, Field, field_validator, model_validator
44

5-
# ── Clear ──
5+
6+
# ── Memorize (async) ──
7+
class MemorizeRequest(BaseModel):
    """Payload for submitting a conversation to be memorized."""

    conversation: dict | list = Field(
        ...,
        description="Conversation data to memorize",
    )
    # Required; stripped by the validator below before min_length is enforced,
    # so a whitespace-only value is rejected.
    user_id: str = Field(
        ...,
        min_length=1,
        description="User ID (non-empty)",
    )
    agent_id: str = Field(
        default="",
        description="Agent ID",
    )
    override_config: dict | None = Field(
        default=None,
        description="Override MemU config",
    )

    @field_validator("user_id", mode="before")
    @classmethod
    def strip_user_id(cls, v: str) -> str:
        """Trim surrounding whitespace from string input; pass other values through unchanged."""
        return v.strip() if isinstance(v, str) else v
22+
23+
24+
class MemorizeResponse(BaseModel):
    """Acknowledgement returned once an async memorize task has been submitted."""

    # Identifier the caller uses to track the task.
    task_id: str = Field(
        ...,
        description="Task ID for tracking (Temporal workflow ID)",
    )
    status: str = Field(
        default="PENDING",
        description="Initial task status",
    )
    message: str = Field(
        default="Memorization task submitted",
        description="Response message",
    )
30+
31+
32+
# ── Task Status ──
33+
class TaskStatusResponse(BaseModel):
    """Response for task status query."""

    task_id: str = Field(..., description="Task ID")
    # The description lists every Temporal workflow execution status name,
    # including CONTINUED_AS_NEW and TIMED_OUT, so API consumers can handle
    # all values this field may carry.
    status: str = Field(
        ...,
        description=(
            "Task status from Temporal: RUNNING, COMPLETED, FAILED, CANCELED, TERMINATED, "
            "CONTINUED_AS_NEW, TIMED_OUT, or UNKNOWN when Temporal reports no status. "
            "PENDING is returned only by the initial POST /memorize response before Temporal picks up the task."
        ),
    )
    detail: str | None = Field(default=None, description="Status detail or error message")
645

746

847
class ClearMemoriesRequest(BaseModel):
@@ -11,6 +50,15 @@ class ClearMemoriesRequest(BaseModel):
1150
user_id: str | None = Field(default=None, description="User ID")
1251
agent_id: str | None = Field(default=None, description="Agent ID")
1352

53+
@field_validator("user_id", "agent_id", mode="before")
@classmethod
def strip_whitespace(cls, v: str | None) -> str | None:
    """Trim string input and map blank results to None; other values pass through unchanged."""
    if not isinstance(v, str):
        return v
    # An empty string after stripping is falsy, so it collapses to None.
    return v.strip() or None
61+
1462
@model_validator(mode="after")
1563
def check_user_or_agent(self) -> "ClearMemoriesRequest":
1664
if self.user_id is None and self.agent_id is None:
@@ -33,9 +81,17 @@ class ClearMemoriesResponse(BaseModel):
3381
class ListCategoriesRequest(BaseModel):
    """Payload for listing memory categories."""

    # Required; stripped by the validator below before min_length is enforced,
    # so a whitespace-only value is rejected.
    user_id: str = Field(
        ...,
        min_length=1,
        description="User ID (non-empty)",
    )
    agent_id: str | None = Field(
        default=None,
        description="Agent ID",
    )

    @field_validator("user_id", mode="before")
    @classmethod
    def strip_user_id(cls, v: str) -> str:
        """Trim surrounding whitespace from string input; pass other values through unchanged."""
        return v.strip() if isinstance(v, str) else v
94+
3995

4096
class CategoryObject(BaseModel):
4197
"""A single memory category."""

app/workers/memorize_activity.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import json
44
import logging
55
from datetime import UTC, datetime
6+
from pathlib import Path
67
from typing import Any
78

89
from temporalio import activity
@@ -60,6 +61,19 @@ async def task_memorize(spec: dict) -> dict[str, Any]:
6061
msg = f"override_config must be a dict, got {type(override_config).__name__}"
6162
raise ApplicationError(msg, non_retryable=True)
6263

64+
# Validate and resolve resource_url BEFORE building the service so
65+
# invalid specs fail fast without opening DB connections or other resources.
66+
raw_url = spec["resource_url"]
67+
candidate = Path(raw_url)
68+
# Reject absolute paths, path traversal, and any directory components.
69+
# candidate.name != raw_url catches inputs like "subdir/file.json".
70+
if candidate.is_absolute() or ".." in candidate.parts or candidate.name != raw_url:
71+
raise ApplicationError(
72+
"Invalid resource_url: must be a bare filename without path separators",
73+
non_retryable=True,
74+
)
75+
resource_url = str(Path(settings.STORAGE_PATH).resolve() / candidate.name)
76+
6377
# Build MemoryService with optional config override
6478
if override_config:
6579
service = create_memory_service(
@@ -71,7 +85,7 @@ async def task_memorize(spec: dict) -> dict[str, Any]:
7185

7286
# Execute memorization
7387
result = await service.memorize(
74-
resource_url=spec["resource_url"],
88+
resource_url=resource_url,
7589
modality="conversation",
7690
user={
7791
"user_id": spec["user_id"],

0 commit comments

Comments
 (0)