11import json
2+ import time
23import warnings
34from typing import List , Optional , Tuple , Union
45
@@ -26,7 +27,7 @@ class Sglang(lmms):
2627
2728 def __init__ (
2829 self ,
29- model_version : str = "Qwen/Qwen2.5-VL-3B-Instruct" ,
30+ model : str = "Qwen/Qwen2.5-VL-3B-Instruct" ,
3031 tensor_parallel_size : int = 1 ,
3132 gpu_memory_utilization : float = 0.8 ,
3233 batch_size : int = 1 ,
@@ -40,7 +41,7 @@ def __init__(
4041 # Manually set an image token for GPT4V so that we can search for it
4142 # and split the text and image.
4243 # Here we just use the same token as llava for convenience.
43- self .model_version = model_version
44+ self .model = model
4445 self .max_frame_num = max_frame_num
4546 self .threads = threads
4647 self .chat_template = chat_template
@@ -53,9 +54,9 @@ def __init__(
5354 except json .JSONDecodeError :
5455 eval_logger .warning (f"Failed to parse JSON-like string for argument '{ key } ': { value } " )
5556
56- # Set up vllm client
57- self .client = Engine (model_path = model_version , tp_size = tensor_parallel_size , mem_fraction_static = gpu_memory_utilization , ** kwargs )
58- self .processor = AutoProcessor .from_pretrained (model_version )
57+ # Set up sglang client
58+ self .client = Engine (model_path = model , tp_size = tensor_parallel_size , mem_fraction_static = gpu_memory_utilization , ** kwargs )
59+ self .processor = AutoProcessor .from_pretrained (model )
5960
6061 accelerator = Accelerator ()
6162 if accelerator .num_processes > 1 :
@@ -160,10 +161,46 @@ def generate_until(self, requests) -> List[str]:
160161 tokenize = False ,
161162 add_generation_prompt = True ,
162163 )
164+
165+ start_time = time .time ()
163166 outputs = self .client .generate (texts , params , image_data = image_data )
167+ end_time = time .time ()
164168
165169 response_text = [o ["text" ] for o in outputs ]
166170
171+ # Calculate timing metrics for batch
172+ e2e_latency = end_time - start_time
173+ total_tokens = 0
174+
175+ for idx , output in enumerate (outputs ):
176+ # Get token count from output
177+ if "meta_info" in output and "completion_tokens" in output ["meta_info" ]:
178+ output_tokens = output ["meta_info" ]["completion_tokens" ]
179+ else :
180+ output_tokens = len (output ["text" ].split ())
181+
182+ total_tokens += output_tokens
183+
184+ # Get TTFT if available
185+ if "meta_info" in output and "ttft" in output ["meta_info" ]:
186+ ttft = output ["meta_info" ]["ttft" ]
187+ else :
188+ # Estimate TTFT as a fraction of total time
189+ ttft = e2e_latency * 0.1 / len (outputs )
190+
191+ if output_tokens > 1 :
192+ tpot = (e2e_latency / len (outputs ) - ttft ) / (output_tokens - 1 )
193+ inference_speed = 1 / tpot if tpot > 0 else 0
194+ else :
195+ tpot = e2e_latency / len (outputs )
196+ inference_speed = 0
197+
198+ eval_logger .info (f"Batch { idx } - E2E: { e2e_latency / len (outputs ):.3f} s, TTFT: { ttft :.3f} s, TPOT: { tpot :.3f} s, Speed: { inference_speed :.1f} tokens/s, Output tokens: { output_tokens } " )
199+
200+ if len (outputs ) > 1 :
201+ avg_speed = total_tokens / e2e_latency if e2e_latency > 0 else 0
202+ eval_logger .info (f"Batch summary - Total time: { e2e_latency :.3f} s, Total tokens: { total_tokens } , Avg speed: { avg_speed :.1f} tokens/s" )
203+
167204 assert len (response_text ) == len (batch_requests )
168205 res .extend (response_text )
169206 pbar .update (len (batch_requests ))
0 commit comments