Commit 1bb5c36

Fix clip blip
1 parent 6c01f24

File tree: 1 file changed, +119 -56 lines
sdks/python/apache_beam/examples/inference/pytorch_image_captioning.py

Lines changed: 119 additions & 56 deletions
@@ -187,6 +187,14 @@ def batch_elements_kwargs(self):
 
   def run_inference(
       self, batch: List[Dict[str, Any]], model, inference_args=None):
+
+    if model is not None:
+      self._model = model
+      self._model.to(self.device)
+      self._model.eval()
+    if self._processor is None:
+      from transformers import BlipProcessor
+      self._processor = BlipProcessor.from_pretrained(self.model_name)
     if self._model is None:
       self._model = self.load_model()

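Note on this hunk (and the matching CLIP hunk below): it follows Beam's ModelHandler contract, in which RunInference calls load_model() once per worker and then passes that object into run_inference() as the model argument, so honoring that argument avoids reloading weights per bundle, and the Hugging Face processor is created lazily on first use. A minimal sketch of that contract follows; the handler class, default checkpoint name, and captioning details are illustrative stand-ins, not code from pytorch_image_captioning.py.

# Sketch of the Beam ModelHandler contract the hunk relies on.
# BlipCaptionHandler and "Salesforce/blip-image-captioning-base" are
# illustrative; only the load_model / run_inference signatures follow
# apache_beam.ml.inference.base.ModelHandler.
from typing import Any, Dict, Iterable, Optional, Sequence

import torch
from apache_beam.ml.inference.base import ModelHandler


class BlipCaptionHandler(
    ModelHandler[Dict[str, Any], Dict[str, Any], torch.nn.Module]):
  def __init__(
      self,
      model_name: str = "Salesforce/blip-image-captioning-base",
      device: str = "cpu"):
    self.model_name = model_name
    self.device = device
    self._model = None
    self._processor = None

  def load_model(self) -> torch.nn.Module:
    # RunInference calls this once per worker; the returned object is what
    # later arrives as the `model` argument of run_inference().
    from transformers import BlipForConditionalGeneration
    model = BlipForConditionalGeneration.from_pretrained(self.model_name)
    model.to(self.device)
    model.eval()
    return model

  def run_inference(
      self,
      batch: Sequence[Dict[str, Any]],
      model: torch.nn.Module,
      inference_args: Optional[Dict[str, Any]] = None
  ) -> Iterable[Dict[str, Any]]:
    # Prefer the worker-level model handed in by RunInference; fall back to
    # lazy loading only if it is missing, as in the diff above.
    if model is not None:
      self._model = model
    if self._model is None:
      self._model = self.load_model()
    if self._processor is None:
      from transformers import BlipProcessor
      self._processor = BlipProcessor.from_pretrained(self.model_name)
    ...  # decode images, call self._model.generate(), yield captions

Under RunInference the model argument is normally the object returned by load_model(), so the lazy branch mainly guards direct calls to the handler.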
@@ -275,72 +283,127 @@ def batch_elements_kwargs(self):
 
   def run_inference(
       self, batch: List[Dict[str, Any]], model, inference_args=None):
+
+    if model is not None:
+      self._model = model
+      self._model.to(self.device)
+      self._model.eval()
+    if self._processor is None:
+      from transformers import CLIPProcessor
+      self._processor = CLIPProcessor.from_pretrained(self.model_name)
     if self._model is None:
       self._model = self.load_model()
 
-    start = now_millis()
+    start_batch = now_millis()
+
+    # Flat lists for a single batched CLIP forward pass
+    images: List[PILImage.Image] = []
+    texts: List[str] = []
+    offsets: List[Tuple[int, int]] = []
+    # per element -> [start, end) in flat arrays
+    candidates_list: List[List[str]] = []
+    blip_ms_list: List[Optional[int]] = []
+
+    for x in batch:
+      image_bytes = x["image_bytes"]
+      candidates = [str(c) for c in (x.get("candidates", []) or [])]
+      candidates_list.append(candidates)
+      blip_ms_list.append(x.get("blip_ms", None))
+
+      try:
+        img = decode_pil(image_bytes)
+      except Exception:
+        img = PILImage.new("RGB", (224, 224), color=(0, 0, 0))
+
+      start_i = len(texts)
+      for c in candidates:
+        images.append(img)
+        texts.append(c)
+      end_i = len(texts)
+      offsets.append((start_i, end_i))
+
+    results: List[Dict[str, Any]] = []
+
+    # Fast path: no candidates at all
+    if not texts:
+      for blip_ms in blip_ms_list:
+        total_ms = int(blip_ms) if blip_ms is not None else None
+        results.append({
+            "best_caption": "",
+            "best_score": None,
+            "candidates": [],
+            "scores": [],
+            "blip_ms": blip_ms,
+            "clip_ms": 0,
+            "total_ms": total_ms,
+        })
+      return results
 
-    results = []
     with torch.no_grad():
-      for x in batch:
-        image_bytes = x["image_bytes"]
-        candidates = x.get("candidates", [])
-        blip_ms = x.get("blip_ms", None)
-
-        # Decode image
-        try:
-          image = decode_pil(image_bytes)
-        except Exception:
-          image = PILImage.new("RGB", (224, 224), color=(0, 0, 0))
-
-        if not candidates:
-          clip_ms = now_millis() - start
-          results.append({
-              "best_caption": "",
-              "best_score": None,
-              "candidates": [],
-              "scores": [],
-              "blip_ms": blip_ms,
-              "clip_ms": clip_ms,
-              "total_ms": None,
-          })
-          continue
-
-        # CLIPProcessor can accept a single image and list of texts
-        inputs = self._processor(
-            text=candidates, images=image, return_tensors="pt", padding=True)
-        inputs = {k: v.to(self.device) for k, v in inputs.items()}
-
-        outputs = self._model(**inputs)
-        # logits_per_image shape: [1, num_texts]
-        logits = outputs.logits_per_image[0]
-
-        if self.score_normalize:
-          # optional normalization to [0..1] via softmax
-          probs = torch.softmax(logits, dim=-1)
-          scores_t = probs
-        else:
-          scores_t = logits
-
-        scores = scores_t.detach().cpu().tolist()
-        best_idx = int(torch.argmax(scores_t).item())
-        best_caption = candidates[best_idx]
-        best_score = float(scores[best_idx])
-
-        clip_ms = now_millis() - start
-        total_ms = None
-        if blip_ms is not None:
-          total_ms = int(blip_ms) + int(clip_ms)
+      inputs = self._processor(
+          text=texts,
+          images=images,
+          return_tensors="pt",
+          padding=True,
+          truncation=True,
+      )
+      inputs = {k: (v.to(self.device) if torch.is_tensor(v) else v)
+                for k, v in inputs.items()}
+
+      # avoid NxN logits inside CLIPModel.forward()
+      img = self._model.get_image_features(pixel_values=inputs["pixel_values"])  # [N, D]
+      txt = self._model.get_text_features(
+          input_ids=inputs["input_ids"],
+          attention_mask=inputs.get("attention_mask"),
+      )  # [N, D]
 
+      img = img / img.norm(dim=-1, keepdim=True)
+      txt = txt / txt.norm(dim=-1, keepdim=True)
+
+      logit_scale = self._model.logit_scale.exp()  # scalar tensor
+      pair_scores = (img * txt).sum(dim=-1) * logit_scale  # [N]
+      pair_scores_cpu = pair_scores.detach().cpu().tolist()
+
+    batch_ms = now_millis() - start_batch
+    total_pairs = len(texts)
+
+    for (start_i, end_i), candidates, blip_ms in zip(offsets, candidates_list, blip_ms_list):
+      if start_i == end_i:
+        total_ms = int(blip_ms) if blip_ms is not None else None
         results.append({
-            "best_caption": best_caption,
-            "best_score": best_score,
-            "candidates": candidates,
-            "scores": scores,
+            "best_caption": "",
+            "best_score": None,
+            "candidates": [],
+            "scores": [],
             "blip_ms": blip_ms,
-            "clip_ms": clip_ms,
+            "clip_ms": 0,
             "total_ms": total_ms,
         })
+        continue
+
+      scores = [float(pair_scores_cpu[j]) for j in range(start_i, end_i)]
+
+      if self.score_normalize:
+        scores_t = torch.tensor(scores, dtype=torch.float32)
+        scores = torch.softmax(scores_t, dim=0).tolist()
+
+      best_idx = max(range(len(scores)), key=lambda i: scores[i])
+
+      pairs = end_i - start_i
+      clip_ms_elem = int(batch_ms * (pairs / max(1, total_pairs)))
+      if pairs > 0:
+        clip_ms_elem = max(1, clip_ms_elem)
+
+      total_ms = int(blip_ms) + clip_ms_elem if blip_ms is not None else None
+      results.append({
+          "best_caption": candidates[best_idx],
+          "best_score": float(scores[best_idx]),
+          "candidates": candidates,
+          "scores": scores,
+          "blip_ms": blip_ms,
+          "clip_ms": clip_ms_elem,
+          "total_ms": total_ms,
+      })
 
     return results

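The core of the fix is in this second hunk: instead of one CLIPModel.forward() call per element (which returns logits_per_image of shape [1, num_texts]), all (image, candidate) pairs in the bundle are flattened into a single batch and scored through get_image_features / get_text_features, so only the N matching pairs are computed rather than the N x N similarity matrix that CLIPModel.forward() would build over the flattened batch. A standalone sketch of why the two paths agree; the checkpoint name "openai/clip-vit-base-patch32" and the toy inputs are assumptions for illustration, not taken from this example.

# Sketch: per-pair scores from get_image_features / get_text_features equal the
# diagonal of the N x N logits_per_image matrix from CLIPModel.forward(),
# without materializing the full matrix. Checkpoint name is an assumption.
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

model_name = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(model_name).eval()
processor = CLIPProcessor.from_pretrained(model_name)

# Toy flat batch of N = 3 (image, text) pairs.
images = [Image.new("RGB", (224, 224)) for _ in range(3)]
texts = ["a dog", "a cat", "a car"]
inputs = processor(text=texts, images=images, return_tensors="pt", padding=True)

with torch.no_grad():
  img = model.get_image_features(pixel_values=inputs["pixel_values"])
  txt = model.get_text_features(
      input_ids=inputs["input_ids"],
      attention_mask=inputs["attention_mask"])
  img = img / img.norm(dim=-1, keepdim=True)
  txt = txt / txt.norm(dim=-1, keepdim=True)
  pair_scores = (img * txt).sum(dim=-1) * model.logit_scale.exp()  # shape [3]

  full = model(**inputs).logits_per_image  # shape [3, 3]

print(torch.allclose(pair_scores, full.diagonal(), atol=1e-4))  # expect True

Scoring only the matched pairs keeps the work at N embedding rows instead of an N x N logits matrix over the flattened batch, which is what the "# avoid NxN logits inside CLIPModel.forward()" comment in the hunk refers to.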
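One consequence of batching is that the CLIP pass is no longer timed per element, so the new code attributes clip_ms to each element in proportion to its share of the scored pairs, with a 1 ms floor for elements that had any candidates. A quick illustration with made-up numbers:

# Made-up numbers: one batched CLIP pass took 120 ms over 6 (image, text)
# pairs coming from three elements with 1, 2 and 3 candidates respectively.
batch_ms = 120
pair_counts = [1, 2, 3]
total_pairs = sum(pair_counts)

# Every element here has candidates, so the 1 ms floor can be applied directly.
clip_ms = [
    max(1, int(batch_ms * (pairs / max(1, total_pairs))))
    for pairs in pair_counts
]
print(clip_ms)  # [20, 40, 60]; because of int() truncation the per-element
                # values need not sum exactly to batch_ms for other inputs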