@@ -118,6 +118,7 @@ class PerformanceTracker:
118118 def __init__ (self ):
119119 self .prefill_metrics = StepMetrics ()
120120 self .decode_metrics = StepMetrics ()
121+ self .padded_decode_metrics = StepMetrics ()
121122 self ._registered_cleanup = False
122123
123124 def register_cleanup (self ):
@@ -145,10 +146,13 @@ def record_decode(
145146 host_time : Optional [int ] = None ,
146147 device_time : Optional [int ] = None ,
147148 ccl_time : Optional [int ] = None ,
149+ padded_decode : bool = False ,
148150 ):
149151 """Record decode step metrics."""
150- self .decode_metrics .add_measurement (latency , token_count , host_time ,
151- device_time , ccl_time )
152+ metrics = self .padded_decode_metrics if padded_decode \
153+ else self .decode_metrics
154+ metrics .add_measurement (latency , token_count , host_time , device_time ,
155+ ccl_time )
152156
153157 def print_final_stats (self ):
154158 logger .info ("=" * 80 )
@@ -205,4 +209,30 @@ def print_final_stats(self):
205209 else :
206210 logger .info ("DECODE METRICS: No data recorded" )
207211
208- logger .info ("=" * 80 )
212+ logger .info ("-" * 40 )
213+
214+ # Padded decode stats
215+ if self .padded_decode_metrics .get_call_counts () > 0 :
216+ logger .info ("PADDED DECODE METRICS:" )
217+ logger .info (" Total call counts: %d" ,
218+ self .padded_decode_metrics .get_call_counts ())
219+ logger .info (" Total tokens processed: %d" ,
220+ sum (self .padded_decode_metrics .token_counts ))
221+ logger .info (" Average latency: %.2f ms" ,
222+ self .padded_decode_metrics .get_avg_latency ())
223+ logger .info (" Average throughput: %.2f tokens/sec" ,
224+ self .padded_decode_metrics .get_avg_throughput ())
225+ if self .padded_decode_metrics .host_times :
226+ logger .info (" Average host time: %.2f us" ,
227+ self .padded_decode_metrics .get_avg_host_time ())
228+ if self .padded_decode_metrics .device_times :
229+ logger .info (" Average device time: %.2f us" ,
230+ self .padded_decode_metrics .get_avg_device_time ())
231+ if self .padded_decode_metrics .ccl_times :
232+ logger .info (" Average ccl time: %.2f us" ,
233+ self .padded_decode_metrics .get_avg_ccl_time ())
234+
235+ else :
236+ logger .info ("PADDED DECODE METRICS: No data recorded" )
237+
238+ logger .info ("=" * 80 )
0 commit comments