15 changes: 14 additions & 1 deletion server/backend/mlx.py
@@ -1,7 +1,7 @@
from .mlx_runner import MLXRunner
from ..cache_utils import get_model_path
from fastapi import HTTPException
from ..schemas import ChatMessage, ChatCompletionRequest, downloadRequest
from ..schemas import ChatMessage, ChatCompletionRequest, downloadRequest, GenerationMetrics
from ..hf_downloader import pull_model

import logging
@@ -113,6 +113,7 @@ async def generate_chat_stream(
yield f"data: {json.dumps(initial_response)}\n\n"

# Stream tokens
metrics = None
try:
for token in runner.generate_streaming(
prompt=prompt,
@@ -125,6 +126,10 @@
use_chat_template=False, # Already applied in _format_conversation
use_chat_stop_tokens=False, # Server mode shouldn't stop on chat markers
):
if isinstance(token, GenerationMetrics):
metrics = token
continue

chunk_response = {
"id": completion_id,
"object": "chat.completion.chunk",
@@ -165,6 +170,14 @@
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
}

# Include benchmarking metrics if available
if metrics:
final_response["metrics"] = {
"ttft_ms": metrics.ttft_ms,
"total_tokens": metrics.total_tokens,
"tokens_per_second": metrics.tokens_per_second,
"total_latency_s": metrics.total_latency_s,
}
yield f"data: {json.dumps(final_response)}\n\n"
yield "data: [DONE]\n\n"

41 changes: 39 additions & 2 deletions server/backend/mlx_runner.py
@@ -18,7 +18,7 @@
from mlx_lm import load
from mlx_lm.generate import generate_step
from mlx_lm.sample_utils import make_repetition_penalty, make_sampler

from ..schemas import GenerationMetrics
from ..reasoning_utils import ReasoningExtractor, StreamingReasoningParser


@@ -474,7 +474,7 @@ def generate_streaming(
# Track generation metrics
start_time = time.time()
tokens_generated = 0

ttft = None
# Create sampler with our parameters
sampler = make_sampler(temp=temperature, top_p=top_p)

@@ -566,6 +566,17 @@
yield formatted_token
else:
yield new_part_before_stop
if reasoning_parser:
yield from reasoning_parser.finalize()
total_latency = time.time() - start_time
tokens_per_second = tokens_generated / total_latency if total_latency > 0 else 0
ttft_ms = (ttft * 1000) if ttft is not None else 0
yield GenerationMetrics(
ttft_ms=ttft_ms,
total_tokens=tokens_generated,
tokens_per_second=tokens_per_second,
total_latency_s=total_latency
)
return # Stop generation without yielding stop token

# Only check chat stop tokens if no native stop token found (fallback)
@@ -596,8 +607,22 @@
yield formatted_token
else:
yield new_part_before_stop
if reasoning_parser:
yield from reasoning_parser.finalize()
total_latency = time.time() - start_time
tokens_per_second = tokens_generated / total_latency if total_latency > 0 else 0
ttft_ms = (ttft * 1000) if ttft is not None else 0
yield GenerationMetrics(
ttft_ms=ttft_ms,
total_tokens=tokens_generated,
tokens_per_second=tokens_per_second,
total_latency_s=total_latency
)
return # Stop generation without yielding stop token

if ttft is None:
ttft = time.time() - start_time

# No stop token found, process the new text
if reasoning_parser:
# Process through reasoning parser for formatting
@@ -616,6 +641,18 @@
if reasoning_parser:
yield from reasoning_parser.finalize()

# Yield metrics at the end
total_latency = time.time() - start_time
tokens_per_second = tokens_generated / total_latency if total_latency > 0 else 0
ttft_ms = (ttft * 1000) if ttft is not None else 0
metrics = GenerationMetrics(
ttft_ms=ttft_ms,
total_tokens=tokens_generated,
tokens_per_second=tokens_per_second,
total_latency_s=total_latency
)
yield metrics

# Print generation statistics if verbose
if self.verbose:
generation_time = time.time() - start_time
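
Since generate_streaming now ends by yielding a GenerationMetrics object on the same stream as the text tokens, every caller has to separate the two by type, exactly as generate_chat_stream does above. A minimal consumption sketch, assuming a loaded runner and leaving all other generation parameters at their defaults:

chunks = []
metrics = None
for item in runner.generate_streaming(prompt=prompt):
    if isinstance(item, GenerationMetrics):
        metrics = item              # yielded once, after the last text token
        continue
    chunks.append(item)             # ordinary str tokens

text = "".join(chunks)
if metrics is not None:
    print(f"{metrics.tokens_per_second:.1f} tok/s, TTFT {metrics.ttft_ms:.0f} ms")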
9 changes: 9 additions & 0 deletions server/schemas.py
@@ -1,5 +1,6 @@
from pydantic import BaseModel, Field
from typing import Any, Dict, List, Optional, Union
from dataclasses import dataclass

class CompletionRequest(BaseModel):
model: str
@@ -63,3 +64,11 @@ class StartRequest(BaseModel):

class downloadRequest(BaseModel):
model: str

@dataclass
class GenerationMetrics:
"""Benchmarking metrics for token generation."""
ttft_ms: float # Time to first token in milliseconds
total_tokens: int # Total tokens generated
tokens_per_second: float # Throughput
total_latency_s: float # End-to-end latency in seconds
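
Note the mixed units: ttft_ms is in milliseconds while total_latency_s stays in seconds, matching how mlx_runner fills the dataclass above. A quick worked example, assuming a 0.12 s first-token delay and 256 tokens generated in 3.2 s overall:

metrics = GenerationMetrics(
    ttft_ms=0.12 * 1000,            # measured in seconds, reported in ms -> 120.0
    total_tokens=256,
    tokens_per_second=256 / 3.2,    # -> 80.0
    total_latency_s=3.2,
)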
67 changes: 65 additions & 2 deletions tiles/src/runtime/mlx.rs
@@ -14,6 +14,7 @@ use rustyline::hint::Hinter;
use rustyline::history::DefaultHistory;
use rustyline::validate::Validator;
use rustyline::{Config, Editor, Helper};
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use std::fs;
use std::fs::File;
@@ -23,13 +24,34 @@ use std::time::Duration;
use std::{io, process::Command};
use tilekit::modelfile::Modelfile;
use tokio::time::sleep;

#[derive(Debug, Deserialize, Serialize)]
pub struct BenchmarkMetrics {
ttft_ms: f64,
total_tokens: i32,
tokens_per_second: f64,
total_latency_s: f64,
}

impl BenchmarkMetrics {
fn update(&mut self, metrics: BenchmarkMetrics) -> &Self {
if self.ttft_ms == 0.0 {
self.ttft_ms += metrics.ttft_ms;
}
self.total_tokens += metrics.total_tokens;
self.tokens_per_second += metrics.tokens_per_second;
self.total_latency_s += metrics.total_latency_s;
self
}
}
pub struct MLXRuntime {}

impl MLXRuntime {}
pub struct ChatResponse {
// think: String,
reply: String,
code: String,
metrics: Option<BenchmarkMetrics>,
}

impl Default for MLXRuntime {
@@ -388,6 +410,12 @@ async fn start_repl(mlx_runtime: &MLXRuntime, modelname: &str, run_args: &RunArg
}
let mut remaining_count = run_args.relay_count;
let mut python_code: String = "".to_owned();
let mut bench_metrics: BenchmarkMetrics = BenchmarkMetrics {
ttft_ms: 0.0,
total_tokens: 0,
tokens_per_second: 0.0,
total_latency_s: 0.0,
};
loop {
if remaining_count > 0 {
let chat_start = remaining_count == run_args.relay_count;
@@ -405,6 +433,9 @@
if !response.code.is_empty() {
python_code = response.code;
}
if let Some(metrics) = response.metrics {
bench_metrics.update(metrics);
}
remaining_count -= 1;
} else {
g_reply = response.reply.clone();
@@ -413,6 +444,23 @@
} else {
println!("\n");
}
// Display benchmark metrics if available
if let Some(metrics) = response.metrics {
bench_metrics.update(metrics);
println!(
"{}",
format!(
"\n{} {:.1} tok/s | {} tokens | {:.0}ms TTFT",
"💡".yellow(),
bench_metrics.total_tokens as f64
/ bench_metrics.total_latency_s,
bench_metrics.total_tokens,
bench_metrics.ttft_ms
)
.dimmed()
);
}

break;
}
} else {
@@ -505,6 +553,7 @@ async fn chat(
let mut stream = res.bytes_stream();
let mut accumulated = String::new();
println!();
let mut metrics: Option<BenchmarkMetrics> = None;
let mut is_answer_start = false;
while let Some(chunk) = stream.next().await {
let chunk = chunk.unwrap();
@@ -517,10 +566,19 @@
let data = line.trim_start_matches("data: ");

if data == "[DONE]" {
return Ok(convert_to_chat_response(&accumulated, run_args.memory));
return Ok(convert_to_chat_response(
&accumulated,
run_args.memory,
metrics,
));
}

// Parse JSON
let v: Value = serde_json::from_str(data).unwrap();
// Check for metrics in the response
if let Some(metrics_obj) = v.get("metrics") {
metrics = serde_json::from_value(metrics_obj.clone()).ok();
}
if let Some(delta) = v["choices"][0]["delta"]["content"].as_str() {
accumulated.push_str(delta);
if !run_args.memory && delta.contains("**[Answer]**") {
Expand All @@ -539,10 +597,15 @@ async fn chat(
Err(String::from("request failed"))
}

fn convert_to_chat_response(content: &str, memory_mode: bool) -> ChatResponse {
fn convert_to_chat_response(
content: &str,
memory_mode: bool,
metrics: Option<BenchmarkMetrics>,
) -> ChatResponse {
ChatResponse {
reply: extract_reply(content, memory_mode),
code: extract_python(content),
metrics,
}
}
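
On the Rust side, BenchmarkMetrics::update only records a TTFT while the running value is still zero (effectively the first turn's) and sums the remaining fields, and the REPL footer then recomputes throughput from those sums rather than averaging the per-turn tok/s figures. A small worked example of that arithmetic, sketched in Python for brevity with illustrative per-turn numbers:

# Two relay turns, illustrative values only.
turns = [
    {"ttft_ms": 120.0, "total_tokens": 256, "total_latency_s": 3.2},
    {"ttft_ms": 95.0, "total_tokens": 128, "total_latency_s": 1.6},
]
ttft_ms = turns[0]["ttft_ms"]                               # later TTFTs are ignored
total_tokens = sum(t["total_tokens"] for t in turns)        # 384
total_latency_s = sum(t["total_latency_s"] for t in turns)  # 4.8
print(f"{total_tokens / total_latency_s:.1f} tok/s | {total_tokens} tokens | {ttft_ms:.0f}ms TTFT")
# -> 80.0 tok/s | 384 tokens | 120ms TTFT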
