Skip to content

Commit 621a7e0

Browse files
committed
feat: add benchmarking harness with TTFT, throughput, and latency metrics
- Add `tiles bench` command for running benchmarks
- Track TTFT, tokens/sec, total tokens, and latency
- Display metrics after each REPL response
- Save benchmark results to ~/.config/tiles/benchmark_log.jsonl
- Add GenerationMetrics dataclass in Python server
1 parent fe14a70 commit 621a7e0

10 files changed

Lines changed: 331 additions & 20 deletions

File tree

Cargo.lock

Lines changed: 97 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

server/backend/mlx.py

Lines changed: 17 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
from .mlx_runner import MLXRunner
22
from ..cache_utils import get_model_path
33
from fastapi import HTTPException
4-
from ..schemas import ChatMessage, ChatCompletionRequest, downloadRequest
4+
from ..schemas import ChatMessage, ChatCompletionRequest, downloadRequest, GenerationMetrics
55
from ..hf_downloader import pull_model
66

77
import logging
@@ -113,6 +113,7 @@ async def generate_chat_stream(
113113
yield f"data: {json.dumps(initial_response)}\n\n"
114114

115115
# Stream tokens
116+
metrics = None
116117
try:
117118
for token in runner.generate_streaming(
118119
prompt=prompt,
@@ -125,6 +126,11 @@ async def generate_chat_stream(
125126
use_chat_template=False, # Already applied in _format_conversation
126127
use_chat_stop_tokens=False, # Server mode shouldn't stop on chat markers
127128
):
129+
# Check if this is metrics object (last item yielded)
130+
if isinstance(token, GenerationMetrics):
131+
metrics = token
132+
continue
133+
128134
chunk_response = {
129135
"id": completion_id,
130136
"object": "chat.completion.chunk",
@@ -156,7 +162,7 @@ async def generate_chat_stream(
156162
}
157163
yield f"data: {json.dumps(error_response)}\n\n"
158164

159-
# Final response
165+
# Final response with metrics
160166
final_response = {
161167
"id": completion_id,
162168
"object": "chat.completion.chunk",
@@ -165,6 +171,15 @@ async def generate_chat_stream(
165171
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
166172
}
167173

174+
# Include benchmarking metrics if available
175+
if metrics:
176+
final_response["metrics"] = {
177+
"ttft_ms": metrics.ttft_ms,
178+
"total_tokens": metrics.total_tokens,
179+
"tokens_per_second": metrics.tokens_per_second,
180+
"total_latency_s": metrics.total_latency_s,
181+
}
182+
168183
yield f"data: {json.dumps(final_response)}\n\n"
169184
yield "data: [DONE]\n\n"
170185

server/backend/mlx_runner.py

Lines changed: 44 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020
from mlx_lm.sample_utils import make_repetition_penalty, make_sampler
2121

2222
from ..reasoning_utils import ReasoningExtractor, StreamingReasoningParser
23+
from ..schemas import GenerationMetrics
2324

2425

2526
def get_model_context_length(model_path: str) -> int:
@@ -475,6 +476,7 @@ def generate_streaming(
475476
# Track generation metrics
476477
start_time = time.time()
477478
tokens_generated = 0
479+
ttft = None # Time to first token
478480

479481
# Create sampler with our parameters
480482
sampler = make_sampler(temp=temperature, top_p=top_p)
@@ -567,6 +569,19 @@ def generate_streaming(
567569
yield formatted_token
568570
else:
569571
yield new_part_before_stop
572+
573+
# Yield metrics before returning
574+
if reasoning_parser:
575+
yield from reasoning_parser.finalize()
576+
total_latency = time.time() - start_time
577+
tokens_per_second = tokens_generated / total_latency if total_latency > 0 else 0
578+
ttft_ms = (ttft * 1000) if ttft is not None else 0
579+
yield GenerationMetrics(
580+
ttft_ms=ttft_ms,
581+
total_tokens=tokens_generated,
582+
tokens_per_second=tokens_per_second,
583+
total_latency_s=total_latency
584+
)
570585
return # Stop generation without yielding stop token
571586

572587
# Only check chat stop tokens if no native stop token found (fallback)
@@ -597,9 +612,26 @@ def generate_streaming(
597612
yield formatted_token
598613
else:
599614
yield new_part_before_stop
615+
616+
# Yield metrics before returning
617+
if reasoning_parser:
618+
yield from reasoning_parser.finalize()
619+
total_latency = time.time() - start_time
620+
tokens_per_second = tokens_generated / total_latency if total_latency > 0 else 0
621+
ttft_ms = (ttft * 1000) if ttft is not None else 0
622+
yield GenerationMetrics(
623+
ttft_ms=ttft_ms,
624+
total_tokens=tokens_generated,
625+
tokens_per_second=tokens_per_second,
626+
total_latency_s=total_latency
627+
)
600628
return # Stop generation without yielding stop token
601629

602630
# No stop token found, process the new text
631+
# Capture time to first token
632+
if ttft is None:
633+
ttft = time.time() - start_time
634+
603635
if reasoning_parser:
604636
# Process through reasoning parser for formatting
605637
for formatted_token in reasoning_parser.process_token(new_text):
@@ -617,6 +649,18 @@ def generate_streaming(
617649
if reasoning_parser:
618650
yield from reasoning_parser.finalize()
619651

652+
# Yield metrics at the end
653+
total_latency = time.time() - start_time
654+
tokens_per_second = tokens_generated / total_latency if total_latency > 0 else 0
655+
ttft_ms = (ttft * 1000) if ttft is not None else 0
656+
metrics = GenerationMetrics(
657+
ttft_ms=ttft_ms,
658+
total_tokens=tokens_generated,
659+
tokens_per_second=tokens_per_second,
660+
total_latency_s=total_latency
661+
)
662+
yield metrics
663+
620664
# Print generation statistics if verbose
621665
if self.verbose:
622666
generation_time = time.time() - start_time

server/schemas.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
from pydantic import BaseModel, Field
22
from typing import Any, Dict, List, Optional, Union
3+
from dataclasses import dataclass
34

45
class CompletionRequest(BaseModel):
56
model: str
@@ -63,3 +64,12 @@ class StartRequest(BaseModel):
6364

6465
class downloadRequest(BaseModel):
6566
model: str
67+
68+
69+
@dataclass
70+
class GenerationMetrics:
71+
"""Benchmarking metrics for token generation."""
72+
ttft_ms: float # Time to first token in milliseconds
73+
total_tokens: int # Total tokens generated
74+
tokens_per_second: float # Throughput
75+
total_latency_s: float # End-to-end latency in seconds

tiles/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -14,3 +14,4 @@ tokio = { version = "1" , features = ["macros", "rt-multi-thread"]}
1414
owo-colors = "4"
1515
futures-util = "0.3"
1616
hf-hub = {version = "0.4", features = ["tokio"]}
17+
chrono = "0.4"

tiles/src/commands/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,3 +18,7 @@ pub async fn start_server(runtime: &Runtime) {
1818
pub async fn stop_server(runtime: &Runtime) {
1919
let _ = runtime.stop_server_daemon().await;
2020
}
21+
22+
pub async fn bench(runtime: &Runtime, run_args: RunArgs) {
23+
runtime.bench(run_args).await;
24+
}

tiles/src/main.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,12 @@ enum Commands {
2222
flags: RunFlags,
2323
},
2424

25+
/// Runs a benchmark and saves results to log file
26+
Bench {
27+
/// Path to the Modelfile (uses default model if not provided)
28+
modelfile_path: Option<String>,
29+
},
30+
2531
/// Checks the status of dependencies
2632
Health,
2733

@@ -70,6 +76,13 @@ pub async fn main() -> Result<(), Box<dyn Error>> {
7076
};
7177
commands::run(&runtime, run_args).await;
7278
}
79+
Commands::Bench { modelfile_path } => {
80+
let run_args = RunArgs {
81+
modelfile_path,
82+
relay_count: 0, // unused by bench
83+
};
84+
commands::bench(&runtime, run_args).await;
85+
}
7386
Commands::Health => {
7487
commands::check_health();
7588
}

tiles/src/runtime/cpu.rs

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -23,4 +23,8 @@ impl CPURuntime {
2323
pub async fn stop_server_daemon(&self) -> Result<()> {
2424
unimplemented!()
2525
}
26+
27+
pub async fn bench(&self, _run_args: super::RunArgs) {
28+
unimplemented!()
29+
}
2630
}

0 commit comments

Comments (0)