Skip to content

Commit a00a2a0

Browse files
committed
feat: Added aggregation to benchmark, needed for memory models
1 parent 48d1ea6 commit a00a2a0

1 file changed

Lines changed: 25 additions & 3 deletions

File tree

tiles/src/runtime/mlx.rs

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,17 @@ pub struct BenchmarkMetrics {
3333
total_latency_s: f64,
3434
}
3535

36+
impl BenchmarkMetrics {
    /// Folds the metrics of one completed run into this running aggregate.
    ///
    /// Aggregation rules:
    /// * `ttft_ms` keeps the first non-zero value that is recorded; later
    ///   runs never overwrite it.
    /// * `total_tokens` and `total_latency_s` are summed across runs.
    /// * `tokens_per_second` is recomputed from the aggregated totals
    ///   (`total_tokens / total_latency_s`). Per-run rates are deliberately
    ///   not added together: a sum of rates is not a rate and would grow
    ///   with the number of runs.
    ///
    /// Returns a shared reference to the updated aggregate.
    fn update(&mut self, metrics: BenchmarkMetrics) -> &Self {
        // 0.0 doubles as the "nothing recorded yet" marker, so only the first
        // run that reports a non-zero `ttft_ms` is kept.
        if self.ttft_ms == 0.0 {
            self.ttft_ms = metrics.ttft_ms;
        }
        self.total_tokens += metrics.total_tokens;
        self.total_latency_s += metrics.total_latency_s;
        // Guard the division so an aggregate with no accumulated latency
        // reports 0.0 instead of inf/NaN.
        self.tokens_per_second = if self.total_latency_s > 0.0 {
            f64::from(self.total_tokens) / self.total_latency_s
        } else {
            0.0
        };
        self
    }
}
3647
/// Field-less runtime type; it carries no state of its own in this definition.
pub struct MLXRuntime {}
3748

3849
// Empty inherent impl: no associated functions or methods are defined here.
impl MLXRuntime {}
@@ -399,6 +410,12 @@ async fn start_repl(mlx_runtime: &MLXRuntime, modelname: &str, run_args: &RunArg
399410
}
400411
let mut remaining_count = run_args.relay_count;
401412
let mut python_code: String = "".to_owned();
413+
let mut bench_metrics: BenchmarkMetrics = BenchmarkMetrics {
414+
ttft_ms: 0.0,
415+
total_tokens: 0,
416+
tokens_per_second: 0.0,
417+
total_latency_s: 0.0,
418+
};
402419
loop {
403420
if remaining_count > 0 {
404421
let chat_start = remaining_count == run_args.relay_count;
@@ -416,6 +433,9 @@ async fn start_repl(mlx_runtime: &MLXRuntime, modelname: &str, run_args: &RunArg
416433
if !response.code.is_empty() {
417434
python_code = response.code;
418435
}
436+
if let Some(metrics) = response.metrics {
437+
bench_metrics.update(metrics);
438+
}
419439
remaining_count -= 1;
420440
} else {
421441
g_reply = response.reply.clone();
@@ -426,14 +446,16 @@ async fn start_repl(mlx_runtime: &MLXRuntime, modelname: &str, run_args: &RunArg
426446
}
427447
// Display benchmark metrics if available
428448
if let Some(metrics) = response.metrics {
449+
bench_metrics.update(metrics);
429450
println!(
430451
"{}",
431452
format!(
432453
"\n{} {:.1} tok/s | {} tokens | {:.0}ms TTFT",
433454
"💡".yellow(),
434-
metrics.tokens_per_second,
435-
metrics.total_tokens,
436-
metrics.ttft_ms
455+
bench_metrics.total_tokens
456+
/ bench_metrics.total_latency_s as i32,
457+
bench_metrics.total_tokens,
458+
bench_metrics.ttft_ms
437459
)
438460
.dimmed()
439461
);

0 commit comments

Comments
 (0)