Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 17 additions & 2 deletions server/backend/mlx.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .mlx_runner import MLXRunner
from ..cache_utils import get_model_path
from fastapi import HTTPException
from ..schemas import ChatMessage, ChatCompletionRequest, downloadRequest
from ..schemas import ChatMessage, ChatCompletionRequest, downloadRequest, GenerationMetrics
from ..hf_downloader import pull_model

import logging
Expand Down Expand Up @@ -113,6 +113,7 @@ async def generate_chat_stream(
yield f"data: {json.dumps(initial_response)}\n\n"

# Stream tokens
metrics = None
try:
for token in runner.generate_streaming(
prompt=prompt,
Expand All @@ -125,6 +126,11 @@ async def generate_chat_stream(
use_chat_template=False, # Already applied in _format_conversation
use_chat_stop_tokens=False, # Server mode shouldn't stop on chat markers
):
# Check if this is metrics object (last item yielded)
if isinstance(token, GenerationMetrics):
metrics = token
continue

chunk_response = {
"id": completion_id,
"object": "chat.completion.chunk",
Expand Down Expand Up @@ -156,7 +162,7 @@ async def generate_chat_stream(
}
yield f"data: {json.dumps(error_response)}\n\n"

# Final response
# Final response with metrics
final_response = {
"id": completion_id,
"object": "chat.completion.chunk",
Expand All @@ -165,6 +171,15 @@ async def generate_chat_stream(
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
}

# Include benchmarking metrics if available
if metrics:
final_response["metrics"] = {
"ttft_ms": metrics.ttft_ms,
"total_tokens": metrics.total_tokens,
"tokens_per_second": metrics.tokens_per_second,
"total_latency_s": metrics.total_latency_s,
}

yield f"data: {json.dumps(final_response)}\n\n"
yield "data: [DONE]\n\n"

Expand Down
44 changes: 44 additions & 0 deletions server/backend/mlx_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from mlx_lm.sample_utils import make_repetition_penalty, make_sampler

from ..reasoning_utils import ReasoningExtractor, StreamingReasoningParser
from ..schemas import GenerationMetrics


def get_model_context_length(model_path: str) -> int:
Expand Down Expand Up @@ -475,6 +476,7 @@ def generate_streaming(
# Track generation metrics
start_time = time.time()
tokens_generated = 0
ttft = None # Time to first token

# Create sampler with our parameters
sampler = make_sampler(temp=temperature, top_p=top_p)
Expand Down Expand Up @@ -567,6 +569,19 @@ def generate_streaming(
yield formatted_token
else:
yield new_part_before_stop

# Yield metrics before returning
if reasoning_parser:
yield from reasoning_parser.finalize()
total_latency = time.time() - start_time
tokens_per_second = tokens_generated / total_latency if total_latency > 0 else 0
ttft_ms = (ttft * 1000) if ttft is not None else 0
yield GenerationMetrics(
ttft_ms=ttft_ms,
total_tokens=tokens_generated,
tokens_per_second=tokens_per_second,
total_latency_s=total_latency
)
Comment thread
madclaws marked this conversation as resolved.
return # Stop generation without yielding stop token

# Only check chat stop tokens if no native stop token found (fallback)
Expand Down Expand Up @@ -597,9 +612,26 @@ def generate_streaming(
yield formatted_token
else:
yield new_part_before_stop

# Yield metrics before returning
if reasoning_parser:
yield from reasoning_parser.finalize()
Comment on lines +617 to +618
Copy link
Copy Markdown
Member

@madclaws madclaws Jan 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need this check?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The StreamingReasoningParser buffers and formats this content as it streams. When generation stops (either naturally or via stop token), we need to:

  1. Flush any buffered content - The parser may have partial reasoning text that hasn't been yielded yet
  2. Finalize formatting - Add closing markers or format the final output properly. If we skip finalize() when a stop token triggers early return:
    Buffered reasoning content would be lost
    The response might be incomplete or malformed
    User would see truncated output
    The condition just ensures we only call it when the model is a reasoning model (the parser is only created for those models).

total_latency = time.time() - start_time
tokens_per_second = tokens_generated / total_latency if total_latency > 0 else 0
ttft_ms = (ttft * 1000) if ttft is not None else 0
yield GenerationMetrics(
ttft_ms=ttft_ms,
total_tokens=tokens_generated,
tokens_per_second=tokens_per_second,
total_latency_s=total_latency
)
Comment on lines +615 to +627
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Same token-count issue for chat stop tokens.

This branch has the same token-counting issue as the native stop token branch (lines 572-584). See the earlier comment for details and fix.

return # Stop generation without yielding stop token

# No stop token found, process the new text
# Capture time to first token
if ttft is None:
ttft = time.time() - start_time
Comment thread
madclaws marked this conversation as resolved.

if reasoning_parser:
# Process through reasoning parser for formatting
for formatted_token in reasoning_parser.process_token(new_text):
Expand All @@ -617,6 +649,18 @@ def generate_streaming(
if reasoning_parser:
yield from reasoning_parser.finalize()

# Yield metrics at the end
total_latency = time.time() - start_time
tokens_per_second = tokens_generated / total_latency if total_latency > 0 else 0
ttft_ms = (ttft * 1000) if ttft is not None else 0
metrics = GenerationMetrics(
ttft_ms=ttft_ms,
total_tokens=tokens_generated,
tokens_per_second=tokens_per_second,
total_latency_s=total_latency
)
yield metrics

# Print generation statistics if verbose
if self.verbose:
generation_time = time.time() - start_time
Expand Down
10 changes: 10 additions & 0 deletions server/schemas.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pydantic import BaseModel, Field
from typing import Any, Dict, List, Optional, Union
from dataclasses import dataclass

class CompletionRequest(BaseModel):
model: str
Expand Down Expand Up @@ -63,3 +64,12 @@ class StartRequest(BaseModel):

class downloadRequest(BaseModel):
model: str


@dataclass
class GenerationMetrics:
    """Benchmarking metrics emitted at the end of a generation run.

    Attributes:
        ttft_ms: Time to first token, in milliseconds.
        total_tokens: Number of tokens produced during generation.
        tokens_per_second: Generation throughput.
        total_latency_s: End-to-end generation latency, in seconds.
    """
    ttft_ms: float
    total_tokens: int
    tokens_per_second: float
    total_latency_s: float
1 change: 1 addition & 0 deletions tiles/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ tokio = { version = "1" , features = ["macros", "rt-multi-thread"]}
owo-colors = "4"
futures-util = "0.3"
hf-hub = {version = "0.4", features = ["tokio"]}
chrono = "0.4"
4 changes: 4 additions & 0 deletions tiles/src/commands/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,7 @@ pub async fn start_server(runtime: &Runtime) {
/// Stops the running server daemon via the runtime.
/// The result is deliberately discarded (`let _ =`) — shutdown is best-effort
/// and errors from an already-stopped daemon are ignored.
pub async fn stop_server(runtime: &Runtime) {
    let _ = runtime.stop_server_daemon().await;
}

/// Runs a benchmark with the given run configuration, delegating the
/// actual work to the runtime's `bench` implementation.
pub async fn bench(runtime: &Runtime, run_args: RunArgs) {
    let args = run_args;
    runtime.bench(args).await;
}
13 changes: 13 additions & 0 deletions tiles/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ enum Commands {
flags: RunFlags,
},

/// Runs a benchmark and saves results to log file
Bench {
/// Path to the Modelfile (uses default model if not provided)
modelfile_path: Option<String>,
},

/// Checks the status of dependencies
Health,

Expand Down Expand Up @@ -70,6 +76,13 @@ pub async fn main() -> Result<(), Box<dyn Error>> {
};
commands::run(&runtime, run_args).await;
}
Commands::Bench { modelfile_path } => {
let run_args = RunArgs {
modelfile_path,
relay_count: 0, // unused by bench
};
commands::bench(&runtime, run_args).await;
}
Commands::Health => {
commands::check_health();
}
Expand Down
4 changes: 4 additions & 0 deletions tiles/src/runtime/cpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,8 @@ impl CPURuntime {
pub async fn stop_server_daemon(&self) -> Result<()> {
unimplemented!()
}

    /// Benchmark entry point for the CPU runtime.
    /// Not implemented yet — calling this panics via `unimplemented!()`;
    /// the `_run_args` parameter is accepted but unused.
    pub async fn bench(&self, _run_args: super::RunArgs) {
        unimplemented!()
    }
}
Loading