2020from mlx_lm .sample_utils import make_repetition_penalty , make_sampler
2121
2222from ..reasoning_utils import ReasoningExtractor , StreamingReasoningParser
23+ from ..schemas import GenerationMetrics
2324
2425
2526def get_model_context_length (model_path : str ) -> int :
@@ -475,6 +476,7 @@ def generate_streaming(
475476 # Track generation metrics
476477 start_time = time .time ()
477478 tokens_generated = 0
479+ ttft = None # Time to first token
478480
479481 # Create sampler with our parameters
480482 sampler = make_sampler (temp = temperature , top_p = top_p )
@@ -567,6 +569,19 @@ def generate_streaming(
567569 yield formatted_token
568570 else :
569571 yield new_part_before_stop
572+
573+ # Flush any pending reasoning-parser output, then yield metrics before returning
574+ if reasoning_parser :
575+ yield from reasoning_parser .finalize ()
576+ total_latency = time .time () - start_time
577+ tokens_per_second = tokens_generated / total_latency if total_latency > 0 else 0
578+ ttft_ms = (ttft * 1000 ) if ttft is not None else 0
579+ yield GenerationMetrics (
580+ ttft_ms = ttft_ms ,
581+ total_tokens = tokens_generated ,
582+ tokens_per_second = tokens_per_second ,
583+ total_latency_s = total_latency
584+ )
570585 return # Stop generation without yielding stop token
571586
572587 # Only check chat stop tokens if no native stop token found (fallback)
@@ -597,9 +612,26 @@ def generate_streaming(
597612 yield formatted_token
598613 else :
599614 yield new_part_before_stop
615+
616+ # Flush any pending reasoning-parser output, then yield metrics before returning
617+ if reasoning_parser :
618+ yield from reasoning_parser .finalize ()
619+ total_latency = time .time () - start_time
620+ tokens_per_second = tokens_generated / total_latency if total_latency > 0 else 0
621+ ttft_ms = (ttft * 1000 ) if ttft is not None else 0
622+ yield GenerationMetrics (
623+ ttft_ms = ttft_ms ,
624+ total_tokens = tokens_generated ,
625+ tokens_per_second = tokens_per_second ,
626+ total_latency_s = total_latency
627+ )
600628 return # Stop generation without yielding stop token
601629
602630 # No stop token found, process the new text
631+ # Capture time to first token once, on the first chunk that has no stop token
632+ if ttft is None :
633+ ttft = time .time () - start_time
634+
603635 if reasoning_parser :
604636 # Process through reasoning parser for formatting
605637 for formatted_token in reasoning_parser .process_token (new_text ):
@@ -617,6 +649,18 @@ def generate_streaming(
617649 if reasoning_parser :
618650 yield from reasoning_parser .finalize ()
619651
652+ # Yield metrics at the end
653+ total_latency = time .time () - start_time
654+ tokens_per_second = tokens_generated / total_latency if total_latency > 0 else 0
655+ ttft_ms = (ttft * 1000 ) if ttft is not None else 0
656+ metrics = GenerationMetrics (
657+ ttft_ms = ttft_ms ,
658+ total_tokens = tokens_generated ,
659+ tokens_per_second = tokens_per_second ,
660+ total_latency_s = total_latency
661+ )
662+ yield metrics
663+
620664 # Print generation statistics if verbose
621665 if self .verbose :
622666 generation_time = time .time () - start_time
0 commit comments