|
| 1 | +import base64 |
| 2 | +import io |
| 3 | +import os |
| 4 | + |
| 5 | +from PIL import Image |
| 6 | + |
| 7 | +from lmms_eval.tasks.worldqa.utils import ( |
| 8 | + MultiChoiceRegexFilter, |
| 9 | + worldq_gen_gpt_eval, |
| 10 | + worldqa_aggregate_gen, |
| 11 | + worldqa_aggregate_mc, |
| 12 | + worldqa_aggregate_mc_eval, |
| 13 | + worldqa_aggregate_mc_ppl, |
| 14 | + worldqa_doc_to_answer, |
| 15 | + worldqa_doc_to_answer_mc, |
| 16 | + worldqa_doc_to_answer_mc_ppl, |
| 17 | + worldqa_doc_to_choice, |
| 18 | + worldqa_doc_to_text, |
| 19 | + worldqa_doc_to_visual, |
| 20 | + worldqa_process_results, |
| 21 | + worldqa_process_results_mc, |
| 22 | +) |
| 23 | + |
| 24 | + |
| 25 | +def worldvqa_doc_to_visual(doc): |
| 26 | + if "image" in doc and doc["image"] is not None: |
| 27 | + image = doc["image"] |
| 28 | + if isinstance(image, Image.Image): |
| 29 | + return [image.convert("RGB")] |
| 30 | + if isinstance(image, str): |
| 31 | + if os.path.exists(image): |
| 32 | + return [Image.open(image).convert("RGB")] |
| 33 | + decoded = Image.open(io.BytesIO(base64.b64decode(image))).convert("RGB") |
| 34 | + return [decoded] |
| 35 | + if isinstance(image, dict): |
| 36 | + image_path = image.get("path") |
| 37 | + if image_path and os.path.exists(image_path): |
| 38 | + return [Image.open(image_path).convert("RGB")] |
| 39 | + image_bytes = image.get("bytes") |
| 40 | + if image_bytes is not None: |
| 41 | + return [Image.open(io.BytesIO(image_bytes)).convert("RGB")] |
| 42 | + |
| 43 | + video = doc.get("video") |
| 44 | + if isinstance(video, str) and video: |
| 45 | + return [video] |
| 46 | + if isinstance(video, dict): |
| 47 | + video_path = video.get("path") |
| 48 | + if video_path: |
| 49 | + return [video_path] |
| 50 | + |
| 51 | + try: |
| 52 | + return worldqa_doc_to_visual(doc) |
| 53 | + except SystemExit: |
| 54 | + video_idx = doc.get("video_idx") |
| 55 | + if not video_idx: |
| 56 | + return [] |
| 57 | + hf_home = os.path.expanduser(os.getenv("HF_HOME", "~/.cache/huggingface/")) |
| 58 | + return [os.path.join(hf_home, "multi-hop-reasoning", "videos", f"{video_idx}.mp4")] |
| 59 | + |
| 60 | + |
| 61 | +def worldvqa_doc_to_text(doc, lmms_eval_specific_kwargs=None): |
| 62 | + if "option" in doc or "video_idx" in doc: |
| 63 | + return worldqa_doc_to_text(doc, lmms_eval_specific_kwargs=lmms_eval_specific_kwargs) |
| 64 | + |
| 65 | + if lmms_eval_specific_kwargs is None: |
| 66 | + lmms_eval_specific_kwargs = {} |
| 67 | + |
| 68 | + pre_prompt = lmms_eval_specific_kwargs.get("pre_prompt", "") |
| 69 | + post_prompt = lmms_eval_specific_kwargs.get("post_prompt", "") |
| 70 | + return f"{pre_prompt}{doc['question'].strip()}{post_prompt}" |
| 71 | + |
| 72 | + |
| 73 | +worldvqa_doc_to_answer = worldqa_doc_to_answer |
| 74 | +worldvqa_doc_to_answer_mc = worldqa_doc_to_answer_mc |
| 75 | +worldvqa_doc_to_answer_mc_ppl = worldqa_doc_to_answer_mc_ppl |
| 76 | +worldvqa_doc_to_choice = worldqa_doc_to_choice |
| 77 | +worldvqa_process_results = worldqa_process_results |
| 78 | +worldvqa_process_results_mc = worldqa_process_results_mc |
| 79 | +worldvqa_aggregate_gen = worldqa_aggregate_gen |
| 80 | +worldvqa_aggregate_mc = worldqa_aggregate_mc |
| 81 | +worldvqa_aggregate_mc_eval = worldqa_aggregate_mc_eval |
| 82 | +worldvqa_aggregate_mc_ppl = worldqa_aggregate_mc_ppl |
| 83 | +worldvqa_gen_gpt_eval = worldq_gen_gpt_eval |
| 84 | + |
| 85 | +__all__ = [ |
| 86 | + "MultiChoiceRegexFilter", |
| 87 | + "worldvqa_doc_to_visual", |
| 88 | + "worldvqa_doc_to_text", |
| 89 | + "worldvqa_doc_to_answer", |
| 90 | + "worldvqa_doc_to_answer_mc", |
| 91 | + "worldvqa_doc_to_answer_mc_ppl", |
| 92 | + "worldvqa_doc_to_choice", |
| 93 | + "worldvqa_process_results", |
| 94 | + "worldvqa_process_results_mc", |
| 95 | + "worldvqa_aggregate_gen", |
| 96 | + "worldvqa_aggregate_mc", |
| 97 | + "worldvqa_aggregate_mc_eval", |
| 98 | + "worldvqa_aggregate_mc_ppl", |
| 99 | + "worldvqa_gen_gpt_eval", |
| 100 | +] |
0 commit comments