2424from openai import AzureOpenAI , OpenAI
2525from PIL import Image
2626
27+ from lmms_eval .models .model_utils .gen_metrics import log_metrics
2728from lmms_eval .models .simple .openai_compatible import (
2829 OpenAICompatible as OpenAICompatibleSimple ,
2930)
@@ -40,6 +41,8 @@ def generate_until(self, requests) -> List[str]:
4041 res = []
4142 pbar = tqdm (total = len (requests ), disable = (self .rank != 0 ), desc = "Model Responding" )
4243
44+ e2e_latency = 0
45+ total_tokens = 0
4346 for ctx , doc_to_messages , gen_kwargs , doc_id , task , split in [reg .args for reg in requests ]:
4447 if self .continual_mode is True and self .cache_mode == "resume" :
4548 doc_uuid = f"{ task } ___{ split } ___{ doc_id } "
@@ -88,29 +91,14 @@ def generate_until(self, requests) -> List[str]:
8891 response_text = response .choices [0 ].message .content
8992
9093 # Calculate timing metrics
91- e2e_latency = end_time - start_time
94+ e2e_latency + = end_time - start_time
9295
9396 # Get token counts from response if available
9497 if hasattr (response , "usage" ):
95- completion_tokens = response .usage .completion_tokens
96- prompt_tokens = response .usage .prompt_tokens
98+ total_tokens += response .usage .completion_tokens
9799 else :
98100 # Approximate token count if not provided
99- completion_tokens = len (response_text .split ())
100- prompt_tokens = len (str (payload ["messages" ]).split ())
101-
102- # Calculate TPOT and inference speed
103- if completion_tokens > 1 :
104- # Assuming TTFT is negligible for API calls, estimate it as a small fraction
105- ttft = e2e_latency * 0.1 # Rough estimate
106- tpot = (e2e_latency - ttft ) / (completion_tokens - 1 )
107- inference_speed = 1 / tpot if tpot > 0 else 0
108- else :
109- tpot = e2e_latency
110- inference_speed = 0
111-
112- # Log throughput metrics
113- eval_logger .info (f"Inference metrics - E2E: { e2e_latency :.3f} s, TPOT: { tpot :.3f} s, Speed: { inference_speed :.1f} tokens/s, Output tokens: { completion_tokens } " )
101+ total_tokens += len (response_text .split ())
114102
115103 break # If successful, break out of the loop
116104
@@ -134,5 +122,15 @@ def generate_until(self, requests) -> List[str]:
134122 with open (self .response_persistent_file , "w" ) as f :
135123 json .dump (self .response_cache , f )
136124
125+ # Calculate average speed
126+ avg_speed = total_tokens / e2e_latency if e2e_latency > 0 else 0
127+ # Log metrics
128+ metric_dict = {
129+ "total_tokens" : total_tokens ,
130+ "e2e_latency" : e2e_latency ,
131+ "avg_speed" : avg_speed ,
132+ }
133+ log_metrics (** metric_dict )
134+
137135 pbar .close ()
138136 return res
0 commit comments