@@ -44,6 +44,7 @@ def __init__(
4444 max_frames_num : int = 10 ,
4545 httpx_trust_env : bool = True ,
4646 batch_size : int = 64 ,
47+ max_concurrent_requests : int = 32 ,
4748 ** kwargs ,
4849 ) -> None :
4950 """
@@ -57,16 +58,22 @@ def __init__(
5758 self .model_version = model_version
5859 self .timeout = timeout
5960 self .max_retries = max_retries
60- self .max_size_in_mb = max_size_in_mb # some models have a limit on the size of the image
61+ self .max_size_in_mb = (
62+ max_size_in_mb # some models have a limit on the size of the image
63+ )
6164 self .continual_mode = continual_mode
6265 self .max_frames_num = max_frames_num
6366 if self .continual_mode :
6467 if response_persistent_folder is None :
65- raise ValueError ("Continual mode requires a persistent path for the response. Please provide a valid path." )
68+ raise ValueError (
69+ "Continual mode requires a persistent path for the response. Please provide a valid path."
70+ )
6671
6772 os .makedirs (response_persistent_folder , exist_ok = True )
6873 self .response_persistent_folder = response_persistent_folder
69- self .response_persistent_file = os .path .join (self .response_persistent_folder , f"{ self .model_version } _response.json" )
74+ self .response_persistent_file = os .path .join (
75+ self .response_persistent_folder , f"{ self .model_version } _response.json"
76+ )
7077
7178 if os .path .exists (self .response_persistent_file ):
7279 with open (self .response_persistent_file , "r" ) as f :
@@ -81,7 +88,11 @@ def __init__(
8188 # settings. openai-python uses a httpx.Client with trust_env set to True. Such a
8289 # httpx.Client uses macOS proxy server settings. Adding httpx_trust_env option
8390 # allows httpx to ignore proxy server settings set by VPN clients.
84- http_client = DefaultHttpxClient (trust_env = httpx_trust_env ) if not httpx_trust_env else None
91+ http_client = (
92+ DefaultHttpxClient (trust_env = httpx_trust_env )
93+ if not httpx_trust_env
94+ else None
95+ )
8596
8697 # Use provided parameters or fall back to environment variables
8798 api_key = api_key or os .getenv ("OPENAI_API_KEY" )
@@ -98,16 +109,27 @@ def __init__(
98109 self .client = (
99110 OpenAI (api_key = api_key , base_url = base_url , http_client = http_client )
100111 if not azure_openai
101- else AzureOpenAI (api_key = os .getenv ("AZURE_OPENAI_API_KEY" ), azure_endpoint = os .getenv ("AZURE_OPENAI_API_BASE" ), api_version = os .getenv ("AZURE_OPENAI_API_VERSION" ), http_client = http_client )
112+ else AzureOpenAI (
113+ api_key = os .getenv ("AZURE_OPENAI_API_KEY" ),
114+ azure_endpoint = os .getenv ("AZURE_OPENAI_API_BASE" ),
115+ api_version = os .getenv ("AZURE_OPENAI_API_VERSION" ),
116+ http_client = http_client ,
117+ )
102118 )
103119
104120 accelerator = Accelerator ()
105121 # assert self.batch_size_per_gpu == 1, "Llava currently does not support batched generation. See https://github.com/haotian-liu/LLaVA/issues/754. HF Llava also has this issue."
106122 if accelerator .num_processes > 1 :
107- assert accelerator .distributed_type in [DistributedType .FSDP , DistributedType .MULTI_GPU , DistributedType .DEEPSPEED ], "Unsupported distributed type provided. Only DDP and FSDP are supported."
123+ assert accelerator .distributed_type in [
124+ DistributedType .FSDP ,
125+ DistributedType .MULTI_GPU ,
126+ DistributedType .DEEPSPEED ,
127+ ], "Unsupported distributed type provided. Only DDP and FSDP are supported."
108128 self .accelerator = accelerator
109129 if self .accelerator .is_local_main_process :
110- eval_logger .info (f"Using { accelerator .num_processes } devices with data parallelism" )
130+ eval_logger .info (
131+ f"Using { accelerator .num_processes } devices with data parallelism"
132+ )
111133 self ._rank = self .accelerator .local_process_index
112134 self ._world_size = self .accelerator .num_processes
113135 else :
@@ -117,6 +139,7 @@ def __init__(
117139
118140 self .device = self .accelerator .device
119141 self .batch_size_per_gpu = int (batch_size )
142+ self .max_concurrent_requests = max_concurrent_requests
120143
121144 @property
122145 def batch_size (self ):
@@ -164,11 +187,15 @@ def encode_image(self, image: Union[Image.Image, str]):
164187 def encode_video (self , video_path , for_get_frames_num ):
165188 vr = VideoReader (video_path , ctx = cpu (0 ))
166189 total_frame_num = len (vr )
167- uniform_sampled_frames = np .linspace (0 , total_frame_num - 1 , for_get_frames_num , dtype = int )
190+ uniform_sampled_frames = np .linspace (
191+ 0 , total_frame_num - 1 , for_get_frames_num , dtype = int
192+ )
168193
169194 # Ensure the last frame is included
170195 if total_frame_num - 1 not in uniform_sampled_frames :
171- uniform_sampled_frames = np .append (uniform_sampled_frames , total_frame_num - 1 )
196+ uniform_sampled_frames = np .append (
197+ uniform_sampled_frames , total_frame_num - 1
198+ )
172199
173200 frame_idx = uniform_sampled_frames .tolist ()
174201 frames = vr .get_batch (frame_idx ).asnumpy ()
@@ -200,9 +227,15 @@ def _collate(x):
200227
201228 from lmms_eval import utils
202229
203- re_ords = utils .Collator ([reg .args for reg in requests ], _collate , grouping = True )
230+ re_ords = utils .Collator (
231+ [reg .args for reg in requests ], _collate , grouping = True
232+ )
204233 chunks = re_ords .get_batched (n = self .batch_size , batch_fn = None )
205- num_iters = len (requests ) // self .batch_size if len (requests ) % self .batch_size == 0 else len (requests ) // self .batch_size + 1
234+ num_iters = (
235+ len (requests ) // self .batch_size
236+ if len (requests ) % self .batch_size == 0
237+ else len (requests ) // self .batch_size + 1
238+ )
206239 pbar = tqdm (total = num_iters , disable = (self .rank != 0 ), desc = "Model Responding" )
207240
208241 for chunk in chunks :
@@ -234,10 +267,24 @@ def _collate(x):
234267 visuals = self .flatten (visuals )
235268 imgs = []
236269 for visual in visuals :
237- if isinstance (visual , str ) and (".mp4" in visual or ".avi" in visual or ".mov" in visual or ".flv" in visual or ".wmv" in visual ):
270+ if isinstance (visual , str ) and (
271+ ".mp4" in visual
272+ or ".avi" in visual
273+ or ".mov" in visual
274+ or ".flv" in visual
275+ or ".wmv" in visual
276+ ):
238277 frames = self .encode_video (visual , self .max_frames_num )
239278 imgs .extend (frames )
240- elif isinstance (visual , str ) and (".jpg" in visual or ".jpeg" in visual or ".png" in visual or ".gif" in visual or ".bmp" in visual or ".tiff" in visual or ".webp" in visual ):
279+ elif isinstance (visual , str ) and (
280+ ".jpg" in visual
281+ or ".jpeg" in visual
282+ or ".png" in visual
283+ or ".gif" in visual
284+ or ".bmp" in visual
285+ or ".tiff" in visual
286+ or ".webp" in visual
287+ ):
241288 img = self .encode_image (visual )
242289 imgs .append (img )
243290 elif isinstance (visual , Image .Image ):
@@ -248,9 +295,16 @@ def _collate(x):
248295 payload ["model" ] = self .model_version
249296
250297 payload ["messages" ].append ({"role" : "user" , "content" : []})
251- payload ["messages" ][0 ]["content" ].append ({"type" : "text" , "text" : context })
298+ payload ["messages" ][0 ]["content" ].append (
299+ {"type" : "text" , "text" : context }
300+ )
252301 for img in imgs :
253- payload ["messages" ][0 ]["content" ].append ({"type" : "image_url" , "image_url" : {"url" : f"data:image/png;base64,{ img } " }})
302+ payload ["messages" ][0 ]["content" ].append (
303+ {
304+ "type" : "image_url" ,
305+ "image_url" : {"url" : f"data:image/png;base64,{ img } " },
306+ }
307+ )
254308
255309 if "max_new_tokens" not in gen_kwargs :
256310 gen_kwargs ["max_new_tokens" ] = 1024
@@ -288,22 +342,33 @@ def process_single_request(payload, i):
288342
289343 except Exception as e :
290344 error_msg = str (e )
291- eval_logger .info (f"Attempt { attempt + 1 } /{ self .max_retries } failed with error: { error_msg } " )
345+ eval_logger .info (
346+ f"Attempt { attempt + 1 } /{ self .max_retries } failed with error: { error_msg } "
347+ )
292348
293349 if attempt == self .max_retries - 1 :
294- eval_logger .error (f"All { self .max_retries } attempts failed. Last error: { error_msg } " )
350+ eval_logger .error (
351+ f"All { self .max_retries } attempts failed. Last error: { error_msg } "
352+ )
295353 return "" , i
296354 else :
297355 time .sleep (self .timeout )
298356
299357 return "" , i
300358
301- tasks_to_run = [(payload , i ) for i , payload in enumerate (batch_payloads ) if batch_responses [i ] is None ]
359+ tasks_to_run = [
360+ (payload , i )
361+ for i , payload in enumerate (batch_payloads )
362+ if batch_responses [i ] is None
363+ ]
302364
303365 if tasks_to_run :
304- max_workers = min (len (tasks_to_run ), 32 )
366+ max_workers = min (len (tasks_to_run ), self . max_concurrent_requests )
305367 with ThreadPoolExecutor (max_workers = max_workers ) as executor :
306- future_to_index = {executor .submit (process_single_request , payload , i ): i for payload , i in tasks_to_run }
368+ future_to_index = {
369+ executor .submit (process_single_request , payload , i ): i
370+ for payload , i in tasks_to_run
371+ }
307372
308373 for future in as_completed (future_to_index ):
309374 response_text , i = future .result ()
@@ -313,17 +378,25 @@ def process_single_request(payload, i):
313378 for doc_uuid , response_text in zip (batch_doc_uuids , batch_responses ):
314379 if response_text is not None :
315380 self .response_cache [doc_uuid ] = response_text
316- with open (self .response_persistent_file , "w" ) as f :
317- json .dump (self .response_cache , f )
318381
319382 res .extend ([r for r in batch_responses if r is not None ])
320383 pbar .update (1 )
321384
322385 pbar .close ()
386+
387+ # Write cache once at the end if in continual mode
388+ if self .continual_mode is True :
389+ with open (self .response_persistent_file , "w" ) as f :
390+ json .dump (self .response_cache , f )
391+
323392 return res
324393
def generate_until_multi_round(self, requests) -> List[str]:
    """Placeholder for multi-round generation; not available in this backend.

    Args:
        requests: The generation requests. They are not inspected.

    Raises:
        NotImplementedError: Always, regardless of ``requests``.
    """
    message = "TODO: Implement multi-round generation for OpenAI compatible models"
    raise NotImplementedError(message)
327398
def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
    """Placeholder for log-likelihood scoring; not available in this backend.

    Args:
        requests: The scoring requests. They are not inspected.

    Raises:
        NotImplementedError: Always, regardless of ``requests``.
    """
    message = "TODO: Implement loglikelihood for OpenAI compatible models"
    raise NotImplementedError(message)
0 commit comments