-
Notifications
You must be signed in to change notification settings - Fork 6
Added basic message support for /v1/responses api #82
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,7 +1,14 @@ | ||
| from .mlx_runner import MLXRunner | ||
| from ..cache_utils import get_model_path | ||
| from fastapi import HTTPException | ||
| from ..schemas import ChatMessage, ChatCompletionRequest, downloadRequest, GenerationMetrics | ||
| from ..schemas import ( | ||
| ChatMessage, | ||
| ChatCompletionRequest, | ||
| ResponsesResponse, | ||
| downloadRequest, | ||
| GenerationMetrics, | ||
| ResponsesRequest, | ||
| ) | ||
| from ..hf_downloader import pull_model | ||
|
|
||
| import logging | ||
|
|
@@ -17,6 +24,8 @@ | |
| _model_cache: Dict[str, MLXRunner] = {} | ||
| _default_max_tokens: Optional[int] = None # Use dynamic model-aware limits by default | ||
| _current_model_path: Optional[str] = None | ||
| # Store generated responses for follow-up support (previous_response_id) | ||
| _responses: Dict[str, ResponsesResponse] = {} | ||
|
Comment on lines
+27
to
+28
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

In-memory response cache grows unbounded. The `_responses` dict is never pruned, so a long-running server will accumulate every generated response in memory.

💾 Consider adding a cache size limit:

+from collections import OrderedDict
+
+_MAX_RESPONSES_CACHE = 1000  # Configurable limit
+
 # Store generated responses for follow-up support (previous_response_id)
-_responses: Dict[str, ResponsesResponse] = {}
+_responses: OrderedDict[str, ResponsesResponse] = OrderedDict()
+
+def _cache_response(response_id: str, response: ResponsesResponse) -> None:
+    """Add response to cache with LRU eviction."""
+    _responses[response_id] = response
+    while len(_responses) > _MAX_RESPONSES_CACHE:
+        _responses.popitem(last=False)  # Remove oldest

🤖 Prompt for AI Agents
||
|
|
||
|
|
||
| def download_model(model_name: str): | ||
|
|
@@ -78,6 +87,7 @@ def get_or_load_model(model_spec: str, verbose: bool = False) -> MLXRunner: | |
|
|
||
| return _model_cache[model_path_str] | ||
|
|
||
|
|
||
| async def generate_chat_stream( | ||
| messages: List[ChatMessage], request: ChatCompletionRequest | ||
| ) -> AsyncGenerator[str, None]: | ||
|
|
@@ -181,6 +191,7 @@ async def generate_chat_stream( | |
| yield f"data: {json.dumps(final_response)}\n\n" | ||
| yield "data: [DONE]\n\n" | ||
|
|
||
|
|
||
| def format_chat_messages_for_runner( | ||
| messages: List[ChatMessage], | ||
| ) -> List[Dict[str, str]]: | ||
|
|
@@ -191,7 +202,292 @@ def format_chat_messages_for_runner( | |
| return [{"role": msg.role, "content": msg.content} for msg in messages] | ||
|
|
||
|
|
||
| def _prepend_previous_response(user_input: str, prev_id: Optional[str]) -> str: | ||
| """If prev_id points to a stored response, prepend its output text as context.""" | ||
| if not prev_id: | ||
| return user_input | ||
| prev = _responses.get(prev_id) | ||
| if not prev or not getattr(prev, "output", None): | ||
| return user_input | ||
| prev_text_parts: List[str] = [] | ||
| for out in prev.output: | ||
| for c in out.get("content", []): | ||
| if c.get("type") == "output_text": | ||
| prev_text_parts.append(c.get("text", "")) | ||
| if prev_text_parts: | ||
| return "\n".join(prev_text_parts) + "\n\n" + user_input | ||
| return user_input | ||
|
|
||
|
|
||
| def _calc_usage(runner: MLXRunner, input_text: str, generated_text: str) -> Dict[str, int]: | ||
| """Calculate token usage using the runner tokenizer; fall back to zeros on error.""" | ||
| try: | ||
| input_tokens = len(runner.tokenizer.encode(input_text)) | ||
| output_tokens = len(runner.tokenizer.encode(generated_text)) | ||
| return {"input_tokens": input_tokens, "output_tokens": output_tokens} | ||
| except Exception: | ||
| return {"input_tokens": 0, "output_tokens": 0} | ||
|
|
||
|
|
||
| def _store_response( | ||
| response_id: str, | ||
| created: int, | ||
| completed_at: Optional[int], | ||
| model: str, | ||
| status: str, | ||
| output: List[Dict[str, Any]], | ||
| usage: Dict[str, int], | ||
| metrics: Optional[Dict[str, Any]] = None, | ||
| error: Optional[Dict[str, Any]] = None, | ||
| ) -> ResponsesResponse: | ||
| """Create a ResponsesResponse, attach metrics to metadata and store it in `_responses`.""" | ||
| resp = ResponsesResponse( | ||
| id=response_id, | ||
| created_at=created, | ||
| completed_at=completed_at, | ||
| model=model, | ||
| status=status, | ||
| object="response", | ||
| error=error, | ||
| output=output, | ||
| usage=usage, | ||
| ) | ||
| if metrics: | ||
| try: | ||
| resp.metadata["metrics"] = metrics | ||
| except Exception: | ||
| pass | ||
| try: | ||
| _responses[response_id] = resp | ||
| except Exception: | ||
| pass | ||
| return resp | ||
|
Comment on lines
+255
to
+264
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

Silent exception handling hides potential bugs. The bare `except Exception: pass` blocks in `_store_response` swallow all errors, so failures to attach metrics or store a response are invisible at runtime.

🔧 Add logging for debugging:

     if metrics:
         try:
             resp.metadata["metrics"] = metrics
-        except Exception:
-            pass
+        except Exception as e:
+            logger.warning(f"Failed to attach metrics to response {response_id}: {e}")
     try:
         _responses[response_id] = resp
-    except Exception:
-        pass
+    except Exception as e:
+        logger.warning(f"Failed to store response {response_id}: {e}")
     return resp

🧰 Tools
🪛 Ruff (0.14.14)
[error] 258-259: (S110)
[warning] 258-258: Do not catch blind exception: (BLE001)
[error] 262-263: (S110)
[warning] 262-262: Do not catch blind exception: (BLE001)

🤖 Prompt for AI Agents
||
|
|
||
|
|
||
| def count_tokens(text: str) -> int: | ||
| """Rough token count estimation.""" | ||
| return int(len(text.split()) * 1.3) # Approximation, convert to int | ||
|
|
||
|
|
||
| async def generate_response_chat_stream( | ||
| request: ResponsesRequest | ||
| ) -> AsyncGenerator[str, None]: | ||
| """Generate streaming chat responses for Responses API.""" | ||
|
|
||
| model = request.model or "mlx-community/gpt-oss-20b-MXFP4-Q4" | ||
| user_input = request.input or "" | ||
| response_id = f"resp-{uuid.uuid4()}" | ||
| msg_id = f"msg_{uuid.uuid4()}" | ||
| created = int(time.time()) | ||
| runner = get_or_load_model(model) | ||
| metrics = None | ||
| # If a previous_response_id is provided, prepend its text to the prompt | ||
| prev_id = getattr(request, "previous_response_id", None) | ||
| user_input = _prepend_previous_response(user_input, prev_id) | ||
|
|
||
| # Calculate input tokens once | ||
| input_tokens = len(runner.tokenizer.encode(user_input)) | ||
|
|
||
| # Initial chunk | ||
| initial_chunk = { | ||
| "id": response_id, | ||
| "object": "response.chunk", | ||
| "created_at": created, | ||
| "model": model, | ||
| "status": "in_progress", | ||
| "output": [ | ||
| { | ||
| "type": "message", | ||
| "id": msg_id, | ||
| "status": "in_progress", | ||
| "role": "assistant", | ||
| "content": [], | ||
| } | ||
| ], | ||
| "usage": {"input_tokens": input_tokens, "output_tokens": 0}, | ||
| } | ||
| yield f"data: {json.dumps(initial_chunk)}\n\n" | ||
|
|
||
| # Stream tokens | ||
| accumulated_text = "" | ||
| output_tokens = 0 | ||
| try: | ||
| for token in runner.generate_streaming( | ||
| prompt=user_input, | ||
| max_tokens=runner.get_effective_max_tokens(request.max_output_tokens), | ||
| temperature=request.temperature or 1, | ||
| top_p=request.top_p or 1, | ||
| use_chat_template=True, | ||
| ): | ||
| if isinstance(token, GenerationMetrics): | ||
| metrics = token | ||
| continue | ||
|
|
||
| accumulated_text += token | ||
| output_tokens += 1 # Each yield is one token | ||
|
|
||
| chunk = { | ||
| "id": response_id, | ||
| "object": "response.chunk", | ||
| "created_at": created, | ||
| "model": model, | ||
| "status": "in_progress", | ||
| "output": [ | ||
| { | ||
| "type": "message", | ||
| "id": msg_id, | ||
| "status": "in_progress", | ||
| "role": "assistant", | ||
| "content": [ | ||
| { | ||
| "type": "output_text", | ||
| "text": token, | ||
| "annotations": [], | ||
| } | ||
| ], | ||
| } | ||
| ], | ||
| "usage": {"input_tokens": input_tokens, "output_tokens": output_tokens}, | ||
| } | ||
| yield f"data: {json.dumps(chunk)}\n\n" | ||
|
|
||
| except Exception as e: | ||
| error_chunk = { | ||
| "id": response_id, | ||
| "object": "response.chunk", | ||
| "created_at": created, | ||
| "model": model, | ||
| "status": "failed", | ||
| "error": {"message": str(e), "type": "internal_error"}, | ||
| "output": [], | ||
| "usage": {"input_tokens": input_tokens, "output_tokens": output_tokens}, | ||
| } | ||
| yield f"data: {json.dumps(error_chunk)}\n\n" | ||
| return | ||
|
|
||
| # Final chunk | ||
| completed_at = int(time.time()) | ||
| # Build final chunk with accumulated text and store response for follow-ups | ||
| final_chunk = { | ||
| "id": response_id, | ||
| "object": "response.chunk", | ||
| "created_at": created, | ||
| "completed_at": completed_at, | ||
| "model": model, | ||
| "status": "completed", | ||
| "output": [ | ||
| { | ||
| "type": "message", | ||
| "id": msg_id, | ||
| "status": "completed", | ||
| "role": "assistant", | ||
| "content": [ | ||
| { | ||
| "type": "output_text", | ||
| "text": "", | ||
| "annotations": [], | ||
| } | ||
| ], | ||
| } | ||
| ], | ||
| "usage": {"input_tokens": input_tokens, "output_tokens": output_tokens}, | ||
| } | ||
| # Store and return a typed ResponsesResponse for follow-ups | ||
| metrics_obj = None | ||
| if metrics: | ||
| metrics_obj = { | ||
| "ttft_ms": metrics.ttft_ms, | ||
| "total_tokens": metrics.total_tokens, | ||
| "tokens_per_second": metrics.tokens_per_second, | ||
| "total_latency_s": metrics.total_latency_s, | ||
| } | ||
| final_chunk["metrics"] = metrics_obj | ||
|
|
||
| _store_response( | ||
| response_id=response_id, | ||
| created=created, | ||
| completed_at=completed_at, | ||
| model=model, | ||
| status="completed", | ||
| output=final_chunk["output"], | ||
| usage={"input_tokens": input_tokens, "output_tokens": output_tokens}, | ||
| metrics=metrics_obj, | ||
| ) | ||
| yield f"data: {json.dumps(final_chunk)}\n\n" | ||
| yield "data: [DONE]\n\n" | ||
|
|
||
|
|
||
| async def generate_response_chat(request: ResponsesRequest): | ||
| """Generate chat responses""" | ||
|
|
||
| model = request.model or "mlx-community/gpt-oss-20b-MXFP4-Q4" | ||
| user_input = request.input or "" | ||
| response_id = f"resp-{uuid.uuid4()}" | ||
| msg_id = f"msg_{uuid.uuid4()}" | ||
| created = int(time.time()) | ||
| runner = get_or_load_model(model) | ||
|
|
||
| # If a previous_response_id is provided, prepend its text to the prompt | ||
| prev_id = getattr(request, "previous_response_id", None) | ||
| user_input = _prepend_previous_response(user_input, prev_id) | ||
|
|
||
| metrics_obj = None | ||
| try: | ||
| start_time = time.time() | ||
| generated_text = runner.generate_batch( | ||
| prompt=user_input, | ||
| max_tokens=runner.get_effective_max_tokens(request.max_output_tokens), | ||
| temperature=request.temperature or 1, | ||
| top_p=request.top_p or 1, | ||
| use_chat_template=True, | ||
| ) | ||
|
|
||
| # Metrics for batch generation (approximate) | ||
| generation_time = time.time() - start_time | ||
|
|
||
| completed_at = int(time.time()) | ||
| status = "completed" | ||
| error = None | ||
|
|
||
| # Calculate token usage | ||
| usage = _calc_usage(runner, user_input, generated_text) | ||
| output_tokens = usage.get("output_tokens", 0) | ||
| metrics_obj = { | ||
| "ttft_ms": generation_time * 1000.0, | ||
| "total_tokens": output_tokens, | ||
| "tokens_per_second": (output_tokens / generation_time) if generation_time > 0 else 0.0, | ||
| "total_latency_s": generation_time, | ||
| } | ||
|
|
||
| except Exception as e: | ||
| completed_at = None | ||
| status = "failed" | ||
| error = {"message": str(e), "type": "internal_error"} | ||
| generated_text = "" | ||
| usage = {"input_tokens": 0, "output_tokens": 0} | ||
|
|
||
| output_block = [ | ||
| { | ||
| "type": "message", | ||
| "id": msg_id, | ||
| "status": "completed" if status == "completed" else "failed", | ||
| "role": "assistant", | ||
| "content": [ | ||
| {"type": "output_text", "text": generated_text, "annotations": []} | ||
| ], | ||
| } | ||
| ] if status == "completed" else [] | ||
|
|
||
| resp = _store_response( | ||
| response_id=response_id, | ||
| created=created, | ||
| completed_at=completed_at, | ||
| model=model, | ||
| status=status, | ||
| output=output_block, | ||
| usage=usage, | ||
| metrics=(metrics_obj if status == "completed" else None), | ||
| error=error, | ||
| ) | ||
|
|
||
| return resp | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1 +1,3 @@ | ||
| backend = None | ||
| from typing import Any | ||
|
|
||
| backend: Any = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add error handling for `/v1/responses` to avoid broken streams. Unlike `/v1/chat/completions`, exceptions here will bubble up and may abruptly terminate streaming connections without a structured error.

🔧 Suggested fix (parity with chat completions):

 @app.post("/v1/responses")
 async def create_chat_response(request: ResponsesRequest):
     """
     Create a response with openResponse format
     """
     global _messages
-    if request.stream:
-        # Streaming response
-        return StreamingResponse(
-            runtime.backend.generate_response_chat_stream(request),
-            media_type="text/plain",
-            headers={"Cache-Control": "no-cache"},
-        )
-    else:
-        return await runtime.backend.generate_response_chat(request)
+    try:
+        if request.stream:
+            # Streaming response
+            return StreamingResponse(
+                runtime.backend.generate_response_chat_stream(request),
+                media_type="text/plain",
+                headers={"Cache-Control": "no-cache"},
+            )
+        return await runtime.backend.generate_response_chat(request)
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))

🤖 Prompt for AI Agents