11"""memU Server - FastAPI application entry point."""
22
3+ import asyncio
34import json
45import logging
6+ import re
57import uuid
68from collections .abc import AsyncIterator
79from contextlib import asynccontextmanager
810from pathlib import Path
9- from typing import Any
11+ from typing import Any , cast
1012
1113from fastapi import FastAPI , HTTPException , Request
1214from fastapi .responses import JSONResponse
15+ from temporalio .client import Client
16+ from temporalio .service import RPCError , RPCStatusCode
1317
1418from app .schemas .memory import (
1519 CategoryObject ,
1620 ClearMemoriesRequest ,
1721 ClearMemoriesResponse ,
1822 ListCategoriesRequest ,
1923 ListCategoriesResponse ,
24+ MemorizeRequest ,
25+ MemorizeResponse ,
26+ TaskStatusResponse ,
2027)
2128from app .services .memu import create_memory_service
29+ from app .workers .memorize_workflow import MemorizeWorkflow
30+ from app .workers .worker import TASK_QUEUE
2231from config .settings import Settings
2332
2433logger = logging .getLogger (__name__ )
3847storage_dir = Path (settings .STORAGE_PATH )
3948
4049
50+ async def _get_temporal_client (app : FastAPI ) -> Client :
51+ """Return the cached Temporal client, connecting lazily on first call."""
52+ # Treat any non-None value as the cached client to support mocking/DI.
53+ client = getattr (app .state , "temporal" , None )
54+ if client is not None :
55+ return cast (Client , client )
56+ # Create the lock lazily on app.state so it's bound to the running event loop
57+ # (module-level asyncio.Lock() can raise RuntimeError in Python 3.13+).
58+ lock : asyncio .Lock = getattr (app .state , "_temporal_lock" , None ) or asyncio .Lock ()
59+ app .state ._temporal_lock = lock
60+ async with lock :
61+ # Double-check after acquiring the lock
62+ client = getattr (app .state , "temporal" , None )
63+ if client is not None :
64+ return cast (Client , client )
65+ client = await Client .connect (
66+ settings .temporal_url ,
67+ namespace = settings .TEMPORAL_NAMESPACE ,
68+ )
69+ app .state .temporal = client
70+ logger .info ("Connected to Temporal at %s" , settings .temporal_url )
71+ return client
72+
73+
4174@asynccontextmanager
4275async def lifespan (_app : FastAPI ) -> AsyncIterator [None ]:
43- """Initialise MemoryService on startup (defers DB connection until the app runs) ."""
76+ """Initialise MemoryService on startup. Temporal connects lazily on first use ."""
4477 try :
4578 storage_dir .mkdir (parents = True , exist_ok = True )
4679 _app .state .service = create_memory_service (settings )
4780 except Exception as exc :
48- # Log full traceback for operators and wrap in a clearer startup error
4981 msg = "Failed to initialize MemoryService during application startup"
5082 logger .exception (msg )
5183 raise RuntimeError (msg ) from exc
@@ -56,27 +88,121 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
5688
5789
5890@app .post ("/memorize" )
59- async def memorize (request : Request , payload : dict [str , Any ]):
91+ async def memorize (request : Request , body : MemorizeRequest ):
92+ """Submit an async memorization task via Temporal workflow."""
93+ file_path : Path | None = None
94+ workflow_started = False
6095 try :
61- service = request .app .state .service
62- file_path = storage_dir / f"conversation-{ uuid .uuid4 ().hex } .json"
63- with file_path .open ("w" , encoding = "utf-8" ) as f :
64- json .dump (payload , f , ensure_ascii = False )
96+ # 1. Save conversation to local storage (offload sync I/O to threadpool)
97+ task_id = uuid .uuid4 ().hex
98+ file_path = storage_dir / f"conversation-{ task_id } .json"
99+ data = json .dumps (body .conversation , ensure_ascii = False )
100+ await asyncio .to_thread (file_path .write_text , data , "utf-8" )
65101
66- result = await service .memorize (resource_url = str (file_path ), modality = "conversation" )
67- return JSONResponse (content = {"status" : "success" , "result" : result })
102+ # 2. Build workflow spec
103+ # Pass the filename only; the worker reconstructs the full path
104+ # from its own STORAGE_PATH, so it works across containers/hosts.
105+ spec = {
106+ "task_id" : task_id ,
107+ "resource_url" : file_path .name ,
108+ "user_id" : body .user_id ,
109+ "agent_id" : body .agent_id ,
110+ "override_config" : body .override_config ,
111+ }
112+
113+ # 3. Start Temporal workflow
114+ temporal = await _get_temporal_client (request .app )
115+ workflow_id = f"memorize-{ task_id } "
116+
117+ await temporal .start_workflow (
118+ MemorizeWorkflow .run ,
119+ spec ,
120+ id = workflow_id ,
121+ task_queue = TASK_QUEUE ,
122+ )
123+ workflow_started = True
124+
125+ logger .info ("Memorize workflow started: %s" , workflow_id )
126+
127+ result = MemorizeResponse (
128+ task_id = workflow_id ,
129+ status = "PENDING" ,
130+ message = f"Memorization task submitted for user { body .user_id } " ,
131+ )
132+ return JSONResponse (content = {"status" : "success" , "result" : result .model_dump ()})
133+ except Exception as exc :
134+ # Only clean up the conversation file if the workflow has NOT started,
135+ # because a running workflow still needs its input file.
136+ if not workflow_started and file_path is not None and file_path .exists ():
137+ try :
138+ file_path .unlink (missing_ok = True )
139+ except Exception :
140+ logger .warning (
141+ "Failed to clean up conversation file %s during error handling" ,
142+ file_path ,
143+ exc_info = True ,
144+ )
145+ logger .exception ("Failed to submit memorize task" )
146+ raise HTTPException (status_code = 500 , detail = "Failed to submit memorization task" ) from exc
147+
148+
149+ # Regex for valid memorize workflow IDs: memorize-<32 hex chars>
150+ _MEMORIZE_WORKFLOW_ID_RE = re .compile (r"^memorize-[0-9a-f]{32}$" )
151+
152+
153+ @app .get ("/memorize/status/{task_id}" )
154+ async def get_memorize_status (request : Request , task_id : str ):
155+ """Get the status of a memorization task."""
156+ if not _MEMORIZE_WORKFLOW_ID_RE .match (task_id ):
157+ raise HTTPException (
158+ status_code = 422 ,
159+ detail = "task_id must match the format 'memorize-<uuid4hex>' (e.g. memorize-abc123def456...)" ,
160+ )
161+ try :
162+ temporal = await _get_temporal_client (request .app )
163+ handle = temporal .get_workflow_handle (task_id )
164+
165+ describe = await handle .describe ()
166+ status = describe .status .name if describe .status else "UNKNOWN"
167+
168+ detail = None
169+ if status == "COMPLETED" :
170+ result = await handle .result ()
171+ if isinstance (result , dict ):
172+ detail = result .get ("status" , "SUCCESS" )
173+ elif result is not None :
174+ detail = str (result )
175+ else :
176+ detail = "SUCCESS"
177+ elif status == "FAILED" :
178+ detail = "Task execution failed"
179+
180+ task_status = TaskStatusResponse (
181+ task_id = task_id ,
182+ status = status ,
183+ detail = detail ,
184+ )
185+ return JSONResponse (content = {"status" : "success" , "result" : task_status .model_dump ()})
186+ except RPCError as exc :
187+ if exc .status == RPCStatusCode .NOT_FOUND :
188+ raise HTTPException (status_code = 404 , detail = f"Task { task_id } not found" ) from exc
189+ logger .exception ("Temporal RPC error for task %s" , task_id )
190+ raise HTTPException (status_code = 500 , detail = "Internal server error" ) from exc
68191 except Exception as exc :
69- logger .exception ("Memorize request failed" )
192+ logger .exception ("Failed to get task status for %s" , task_id )
70193 raise HTTPException (status_code = 500 , detail = "Internal server error" ) from exc
71194
72195
73196@app .post ("/retrieve" )
74197async def retrieve (request : Request , payload : dict [str , Any ]):
75198 if "query" not in payload :
76199 raise HTTPException (status_code = 400 , detail = "Missing 'query' in request body" )
200+ query = payload ["query" ]
201+ if not isinstance (query , str ) or not query .strip ():
202+ raise HTTPException (status_code = 400 , detail = "'query' must be a non-empty string" )
77203 try :
78204 service = request .app .state .service
79- result = await service .retrieve ([payload [ " query" ] ])
205+ result = await service .retrieve ([query . strip () ])
80206 return JSONResponse (content = {"status" : "success" , "result" : result })
81207 except Exception as exc :
82208 logger .exception ("Retrieve request failed" )
0 commit comments