@@ -187,6 +187,14 @@ def batch_elements_kwargs(self):
 
   def run_inference(
       self, batch: List[Dict[str, Any]], model, inference_args=None):
+
+    if model is not None:
+      self._model = model
+      self._model.to(self.device)
+      self._model.eval()
+    if self._processor is None:
+      from transformers import BlipProcessor
+      self._processor = BlipProcessor.from_pretrained(self.model_name)
     if self._model is None:
       self._model = self.load_model()
 
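Both run_inference overrides now prefer a model handed in by the framework over lazy self-loading. For context, a minimal sketch of how such a handler might be attached to a pipeline, assuming Beam's RunInference API; the handler class name BlipCaptionModelHandler and its model_name constructor argument are assumptions, not confirmed by this diff:

import apache_beam as beam
from apache_beam.ml.inference.base import RunInference

# Illustrative wiring only: BlipCaptionModelHandler stands in for the
# handler class edited in this diff; its constructor is assumed.
with beam.Pipeline() as pipeline:
  _ = (
      pipeline
      | beam.Create([{"image_bytes": open("example.jpg", "rb").read()}])
      | RunInference(
          BlipCaptionModelHandler(
              model_name="Salesforce/blip-image-captioning-base"))
      | beam.Map(print))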
@@ -275,72 +283,127 @@ def batch_elements_kwargs(self):
 
   def run_inference(
       self, batch: List[Dict[str, Any]], model, inference_args=None):
+
+    if model is not None:
+      self._model = model
+      self._model.to(self.device)
+      self._model.eval()
+    if self._processor is None:
+      from transformers import CLIPProcessor
+      self._processor = CLIPProcessor.from_pretrained(self.model_name)
     if self._model is None:
       self._model = self.load_model()
 
-    start = now_millis()
+    start_batch = now_millis()
+
+    # Flat lists for a single batched CLIP forward pass.
+    images: List[PILImage.Image] = []
+    texts: List[str] = []
+    # Per element: [start, end) range into the flat lists.
+    offsets: List[Tuple[int, int]] = []
+    candidates_list: List[List[str]] = []
+    blip_ms_list: List[Optional[int]] = []
+
+    for x in batch:
+      image_bytes = x["image_bytes"]
+      candidates = [str(c) for c in (x.get("candidates", []) or [])]
+      candidates_list.append(candidates)
+      blip_ms_list.append(x.get("blip_ms", None))
+
+      try:
+        img = decode_pil(image_bytes)
+      except Exception:
+        # Fall back to a blank image so one bad record cannot fail the batch.
+        img = PILImage.new("RGB", (224, 224), color=(0, 0, 0))
+
+      start_i = len(texts)
+      for c in candidates:
+        images.append(img)
+        texts.append(c)
+      end_i = len(texts)
+      offsets.append((start_i, end_i))
+
+    results: List[Dict[str, Any]] = []
+
+    # Fast path: no candidates anywhere in the batch.
+    if not texts:
+      for blip_ms in blip_ms_list:
+        total_ms = int(blip_ms) if blip_ms is not None else None
+        results.append({
+            "best_caption": "",
+            "best_score": None,
+            "candidates": [],
+            "scores": [],
+            "blip_ms": blip_ms,
+            "clip_ms": 0,
+            "total_ms": total_ms,
+        })
+      return results
 
-    results = []
     with torch.no_grad():
-      for x in batch:
-        image_bytes = x["image_bytes"]
-        candidates = x.get("candidates", [])
-        blip_ms = x.get("blip_ms", None)
-
-        # Decode image
-        try:
-          image = decode_pil(image_bytes)
-        except Exception:
-          image = PILImage.new("RGB", (224, 224), color=(0, 0, 0))
-
-        if not candidates:
-          clip_ms = now_millis() - start
-          results.append({
-              "best_caption": "",
-              "best_score": None,
-              "candidates": [],
-              "scores": [],
-              "blip_ms": blip_ms,
-              "clip_ms": clip_ms,
-              "total_ms": None,
-          })
-          continue
-
-        # CLIPProcessor can accept a single image and a list of texts
-        inputs = self._processor(
-            text=candidates, images=image, return_tensors="pt", padding=True)
-        inputs = {k: v.to(self.device) for k, v in inputs.items()}
-
-        outputs = self._model(**inputs)
-        # logits_per_image shape: [1, num_texts]
-        logits = outputs.logits_per_image[0]
-
-        if self.score_normalize:
-          # optional normalization to [0..1] via softmax
-          probs = torch.softmax(logits, dim=-1)
-          scores_t = probs
-        else:
-          scores_t = logits
-
-        scores = scores_t.detach().cpu().tolist()
-        best_idx = int(torch.argmax(scores_t).item())
-        best_caption = candidates[best_idx]
-        best_score = float(scores[best_idx])
-
-        clip_ms = now_millis() - start
-        total_ms = None
-        if blip_ms is not None:
-          total_ms = int(blip_ms) + int(clip_ms)
+      inputs = self._processor(
+          text=texts,
+          images=images,
+          return_tensors="pt",
+          padding=True,
+          truncation=True,
+      )
+      inputs = {k: (v.to(self.device) if torch.is_tensor(v) else v)
+                for k, v in inputs.items()}
+
+      # Compute pairwise features directly; this avoids the NxN
+      # image-text logit matrix built inside CLIPModel.forward().
+      img = self._model.get_image_features(pixel_values=inputs["pixel_values"])  # [N, D]
+      txt = self._model.get_text_features(
+          input_ids=inputs["input_ids"],
+          attention_mask=inputs.get("attention_mask"),
+      )  # [N, D]
 
+      img = img / img.norm(dim=-1, keepdim=True)
+      txt = txt / txt.norm(dim=-1, keepdim=True)
+
+      logit_scale = self._model.logit_scale.exp()  # scalar tensor
+      pair_scores = (img * txt).sum(dim=-1) * logit_scale  # [N]
+      pair_scores_cpu = pair_scores.detach().cpu().tolist()
+
+    batch_ms = now_millis() - start_batch
+    total_pairs = len(texts)
+
+    for (start_i, end_i), candidates, blip_ms in zip(
+        offsets, candidates_list, blip_ms_list):
+      if start_i == end_i:
+        total_ms = int(blip_ms) if blip_ms is not None else None
         results.append({
-            "best_caption": best_caption,
-            "best_score": best_score,
-            "candidates": candidates,
-            "scores": scores,
+            "best_caption": "",
+            "best_score": None,
+            "candidates": [],
+            "scores": [],
             "blip_ms": blip_ms,
-            "clip_ms": clip_ms,
+            "clip_ms": 0,
             "total_ms": total_ms,
         })
+        continue
+
+      scores = [float(pair_scores_cpu[j]) for j in range(start_i, end_i)]
+
+      if self.score_normalize:
+        # Optional normalization to [0..1] via softmax over this
+        # element's candidates.
+        scores_t = torch.tensor(scores, dtype=torch.float32)
+        scores = torch.softmax(scores_t, dim=0).tolist()
+
+      best_idx = max(range(len(scores)), key=lambda i: scores[i])
+
+      # Attribute batch latency to each element in proportion to its share
+      # of the scored pairs; round up to 1ms so work is never reported as 0.
+      pairs = end_i - start_i
+      clip_ms_elem = int(batch_ms * (pairs / max(1, total_pairs)))
+      if pairs > 0:
+        clip_ms_elem = max(1, clip_ms_elem)
+
+      total_ms = int(blip_ms) + clip_ms_elem if blip_ms is not None else None
+      results.append({
+          "best_caption": candidates[best_idx],
+          "best_score": float(scores[best_idx]),
+          "candidates": candidates,
+          "scores": scores,
+          "blip_ms": blip_ms,
+          "clip_ms": clip_ms_elem,
+          "total_ms": total_ms,
+      })
 
     return results
 
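For reference, the pairwise-feature scoring above can be checked against CLIPModel.forward() outside the pipeline. A minimal sketch, assuming the public transformers CLIP API and a stock checkpoint rather than this PR's handler: the per-pair scores equal the diagonal of the full NxN logits_per_image matrix, computed with O(N) rather than O(N^2) logits.

import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

name = "openai/clip-vit-base-patch32"  # stock checkpoint, assumed for the demo
model = CLIPModel.from_pretrained(name).eval()
processor = CLIPProcessor.from_pretrained(name)

images = [Image.new("RGB", (224, 224)), Image.new("RGB", (224, 224))]
texts = ["a photo of a cat", "a photo of a dog"]
inputs = processor(text=texts, images=images, return_tensors="pt", padding=True)

with torch.no_grad():
  img = model.get_image_features(pixel_values=inputs["pixel_values"])
  txt = model.get_text_features(input_ids=inputs["input_ids"],
                                attention_mask=inputs["attention_mask"])
  img = img / img.norm(dim=-1, keepdim=True)
  txt = txt / txt.norm(dim=-1, keepdim=True)
  # One score per (image, text) pair: O(N).
  pair = (img * txt).sum(dim=-1) * model.logit_scale.exp()
  # Full cross-product logits: O(N^2).
  full = model(**inputs).logits_per_image

# The pairwise scores are the diagonal of the full logit matrix.
assert torch.allclose(pair, full.diagonal(), atol=1e-4)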