13 changes: 12 additions & 1 deletion server/api.py
@@ -1,6 +1,6 @@
from fastapi import FastAPI, HTTPException

from .schemas import ChatMessage, ChatCompletionRequest, StartRequest, downloadRequest
from .schemas import ChatMessage, ChatCompletionRequest, StartRequest, downloadRequest, ResponseRequest
from .config import SYSTEM_PROMPT
import logging
import sys
@@ -79,3 +79,14 @@ async def create_chat_completion(request: ChatCompletionRequest):
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))


@app.post("/v1/responses")
async def create_response(request: ResponseRequest):
"""Create a non-streaming completion response."""
try:
response = await runtime.backend.generate_response(request)
return response
except Exception as e:
logger.exception("Error in generate_response")
raise HTTPException(status_code=500, detail=str(e)) from e
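
For context, a minimal sketch of how a client might call the new endpoint, assuming the server runs locally on port 8000; the base URL and model name are placeholders, not values from this PR:

import requests

payload = {
    "model": "my-local-model",  # placeholder model identifier
    "messages": [{"role": "user", "content": "List three primary colors as JSON."}],
    # Optional structured output, matching the response_format handling
    # added to the MLX backend in this PR.
    "response_format": {
        "type": "json_schema",
        "json_schema": {"schema": {"type": "object"}},
    },
}

resp = requests.post("http://localhost:8000/v1/responses", json=payload, timeout=120)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])
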
86 changes: 85 additions & 1 deletion server/backend/mlx.py → server/backend/mlx_backend.py
@@ -1,10 +1,11 @@
from .mlx_runner import MLXRunner
from ..cache_utils import get_model_path
from fastapi import HTTPException
from ..schemas import ChatMessage, ChatCompletionRequest, downloadRequest
from ..schemas import ChatMessage, ChatCompletionRequest, downloadRequest, ResponseRequest
from ..hf_downloader import pull_model

import logging
import asyncio
import json
import time
import uuid
@@ -114,6 +115,15 @@ async def generate_chat_stream(

# Stream tokens
try:
json_schema = None
if request.response_format:
if request.response_format.get("type") == "json_schema":
schema_info = request.response_format.get("json_schema", {})
json_schema = json.dumps(schema_info.get("schema", {}))
elif request.response_format.get("type") == "json_object":
# Fallback for json_object type
json_schema = "{}"

for token in runner.generate_streaming(
prompt=prompt,
max_tokens=runner.get_effective_max_tokens(
@@ -124,6 +134,7 @@ async def generate_chat_stream(
repetition_penalty=request.repetition_penalty,
use_chat_template=False, # Already applied in _format_conversation
use_chat_stop_tokens=False, # Server mode shouldn't stop on chat markers
json_schema=json_schema,
):
chunk_response = {
"id": completion_id,
@@ -168,6 +179,79 @@ async def generate_chat_stream(
yield f"data: {json.dumps(final_response)}\n\n"
yield "data: [DONE]\n\n"


async def generate_response(request: ResponseRequest) -> Dict[str, Any]:
"""Generate complete non-streaming chat completion response."""
completion_id = f"chatcmpl-{uuid.uuid4()}"
created = int(time.time())
runner = get_or_load_model(request.model)

# Convert messages to dict format for runner
message_dicts = format_chat_messages_for_runner(request.messages)

# Let the runner format with chat templates
prompt = runner._format_conversation(message_dicts, use_chat_template=True)

json_schema = None
if request.response_format:
if request.response_format.get("type") == "json_schema":
schema_info = request.response_format.get("json_schema", {})
json_schema = json.dumps(schema_info.get("schema", {}))
elif request.response_format.get("type") == "json_object":
# Fallback for json_object type
json_schema = "{}"

response_text = await asyncio.to_thread(
runner.generate_batch,
prompt=prompt,
max_tokens=runner.get_effective_max_tokens(
request.max_tokens or _default_max_tokens, interactive=False
),
temperature=request.temperature,
top_p=request.top_p,
repetition_penalty=request.repetition_penalty,
use_chat_template=False,
json_schema=json_schema,
)

# Handle stop sequences if provided
if request.stop:
stop_sequences = (
request.stop if isinstance(request.stop, list) else [request.stop]
)
min_index = len(response_text)
found_stop = False
for stop in stop_sequences:
index = response_text.find(stop)
if index != -1:
min_index = min(min_index, index)
found_stop = True

if found_stop:
response_text = response_text[:min_index]

prompt_tokens = count_tokens(prompt)
completion_tokens = count_tokens(response_text)

return {
"id": completion_id,
"object": "chat.completion",
"created": created,
"model": request.model,
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": response_text},
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
},
}

def format_chat_messages_for_runner(
messages: List[ChatMessage],
) -> List[Dict[str, str]]:
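
As a side note on the stop handling in generate_response above: truncation happens at the earliest match across all provided stop strings. A small self-contained illustration (the generated text and stop strings are invented for the example):

# Mirrors the stop-sequence truncation added in generate_response.
response_text = "The answer is 42.\nUser: anything else?"
stop_sequences = ["User:", "###"]

min_index = len(response_text)
found_stop = False
for stop in stop_sequences:
    index = response_text.find(stop)
    if index != -1:
        min_index = min(min_index, index)
        found_stop = True

if found_stop:
    response_text = response_text[:min_index]

print(repr(response_text))  # 'The answer is 42.\n' -- text from "User:" onward is dropped
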
52 changes: 52 additions & 0 deletions server/backend/mlx_engine/__init__.py
@@ -0,0 +1,52 @@
"""
`mlx_engine` is LM Studio's LLM inferencing engine for Apple MLX
"""

__all__ = [
"load_model",
"load_draft_model",
"is_draft_model_compatible",
"unload_draft_model",
"create_generator",
"tokenize",
]

from pathlib import Path
import os

from .utils.disable_hf_download import patch_huggingface_hub
from .utils.register_models import register_models
from .utils.logger import setup_logging


from .generate import (
load_model,
load_draft_model,
is_draft_model_compatible,
unload_draft_model,
create_generator,
tokenize,
)

patch_huggingface_hub()
register_models()
setup_logging()


def _set_outlines_cache_dir(cache_dir: Path | str):
"""
Set the cache dir for Outlines.

Outlines reads the OUTLINES_CACHE_DIR environment variable to
determine where to read/write its cache files
"""
if "OUTLINES_CACHE_DIR" in os.environ:
return

cache_dir = Path(cache_dir).expanduser().resolve()
os.environ["OUTLINES_CACHE_DIR"] = str(cache_dir)


_set_outlines_cache_dir(
os.getenv("TILES_OUTLINES_CACHE", "~/.cache/tiles/.internal/outlines")
)
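
As a usage note, the Outlines cache location can be redirected before the package is imported; a minimal sketch, assuming the package is importable as server.backend.mlx_engine (the cache path below is an arbitrary example):

import os

# TILES_OUTLINES_CACHE is only consulted when OUTLINES_CACHE_DIR is not already
# set, because _set_outlines_cache_dir returns early in that case.
os.environ["TILES_OUTLINES_CACHE"] = "/tmp/outlines-cache"  # example path

# Importing the package runs _set_outlines_cache_dir(...) at module load time.
from server.backend import mlx_engine  # assumed import path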