Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/current_tasks.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ python -m lmms_eval --tasks list_with_num
- [MMVet v2](https://github.com/yuweihao/MM-Vet) (mmvetv2)
- [MMVU](https://mmvu-bench.github.io/) (mmvu)
- [MMWorld](https://mmworld-bench.github.io/) (mmworld)
- [MTVQA](https://huggingface.co/datasets/ByteDance/MTVQA) (mtvqa)
- [MMSI-Bench](https://github.com/MMSI-Bench/MMSI-Bench) (mmsi_bench)
- [MMSearch](https://mmsearch.github.io/) (mmsearch)

Expand Down
27 changes: 27 additions & 0 deletions lmms_eval/tasks/mtvqa/mtvqa.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Task configuration for the MTVQA benchmark: task name "mtvqa", evaluated on the "test" split.
dataset_path: ByteDance/MTVQA
dataset_kwargs:
token: false
task: mtvqa
test_split: test
# Dataset preprocessing and prompt/visual builders are the mtvqa_* functions named below.
process_docs: !function utils.mtvqa_process_docs
output_type: generate_until
doc_to_visual: !function utils.mtvqa_doc_to_visual
doc_to_text: !function utils.mtvqa_doc_to_text
doc_to_target: answer
# Decoding settings: sampling disabled, a single beam, at most 64 new tokens.
generation_kwargs:
max_new_tokens: 64
temperature: 0
top_p: 1.0
num_beams: 1
do_sample: false
# Scoring: per-sample results are combined by the aggregation function; a higher score is better.
process_results: !function utils.mtvqa_process_results
metric_list:
- metric: mtvqa_score
aggregation: !function utils.mtvqa_aggregate_results
higher_is_better: true
# Prompt affixes; utils.mtvqa_doc_to_text reads the "pre_prompt" and "post_prompt" keys.
lmms_eval_specific_kwargs:
default:
pre_prompt: ""
post_prompt: "\nAnswer the question using a word or phrase in the language of the question."
metadata:
- version: 0.0
130 changes: 130 additions & 0 deletions lmms_eval/tasks/mtvqa/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import ast
from collections import defaultdict

import datasets
from loguru import logger as eval_logger
from PIL import Image

# Fallback instruction appended after the question by mtvqa_doc_to_text when the
# caller supplies no "post_prompt".
MTVQA_PROMPT_SUFFIX = "\nAnswer the question using a word or phrase in the language of the question."


def _parse_qa_pairs(value):
if isinstance(value, list):
parsed_pairs = value
elif isinstance(value, str):
raw = value.strip()
if not raw:
return []
try:
parsed_value = ast.literal_eval(raw)
except (SyntaxError, ValueError) as exc:
eval_logger.warning("Failed to parse MTVQA qa_pairs: {}", exc)
return []
parsed_pairs = parsed_value if isinstance(parsed_value, list) else []
else:
return []

normalized_pairs = []
for pair in parsed_pairs:
if not isinstance(pair, dict):
continue
question = str(pair.get("question", "")).strip()
answer = str(pair.get("answer", "")).strip()
if question and answer:
normalized_pairs.append({"question": question, "answer": answer})
return normalized_pairs


def mtvqa_process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
    """Expand every dataset record into one row per question/answer pair.

    Each output row carries ``question_id`` (``"<id>_<pair index>"``), ``id``,
    ``category`` (taken from the record's ``lang`` field, ``"unknown"`` when
    blank), ``question``, ``answer`` and the record's ``image``. Records with a
    blank ``id`` are given ``"mtvqa_<record index>"``.

    Returns:
        A new dataset built from the flattened rows, or an empty selection of
        ``dataset`` when no QA pair could be extracted.
    """
    rows = []
    for position, record in enumerate(dataset):
        fields = dict(record)

        image_id = str(fields.get("id", "")).strip()
        if not image_id:
            image_id = f"mtvqa_{position}"

        language = str(fields.get("lang", "")).strip()
        if not language:
            language = "unknown"

        pairs = _parse_qa_pairs(fields.get("qa_pairs"))
        image = fields.get("image")
        rows.extend(
            {
                "question_id": f"{image_id}_{pair_no}",
                "id": image_id,
                "category": language,
                "question": pair["question"],
                "answer": pair["answer"],
                "image": image,
            }
            for pair_no, pair in enumerate(pairs)
        )

    if rows:
        eval_logger.info("[mtvqa] Loaded {} QA pairs from {} images.", len(rows), len(dataset))
        return datasets.Dataset.from_list(rows)

    eval_logger.warning("[mtvqa] No samples found after flattening qa_pairs.")
    return dataset.select(range(0))


def _to_rgb_image(image_value):
    """Return ``image_value`` as an RGB PIL image.

    Accepts a PIL image, or a dict carrying either raw encoded data under
    ``"bytes"`` (checked first) or a file location under ``"path"``.

    Raises:
        TypeError: For any other payload, including a dict that has neither
            usable ``"bytes"`` nor a non-empty ``"path"``.
    """
    if isinstance(image_value, Image.Image):
        return image_value.convert("RGB")

    source = None
    if isinstance(image_value, dict):
        encoded = image_value.get("bytes")
        if encoded is not None:
            from io import BytesIO

            source = BytesIO(encoded)
        elif image_value.get("path"):
            source = image_value["path"]

    if source is None:
        raise TypeError(f"Unsupported MTVQA image payload type: {type(image_value)}")
    return Image.open(source).convert("RGB")


def mtvqa_doc_to_visual(doc):
    """Return the document's image as a single-element list of RGB images.

    Raises:
        KeyError: When the document has no ``"image"`` value; the message
            includes the document's stripped ``"id"``.
    """
    payload = doc.get("image")
    if payload is not None:
        return [_to_rgb_image(payload)]

    missing_id = str(doc.get("id", "")).strip()
    raise KeyError(f"Missing MTVQA image payload for sample id: {missing_id}")


def mtvqa_doc_to_text(doc, lmms_eval_specific_kwargs=None):
    """Build the text prompt for one document.

    The prompt is ``pre_prompt + question + post_prompt``, where the question
    is the stripped ``doc["question"]``. ``pre_prompt`` defaults to an empty
    string and ``post_prompt`` to ``MTVQA_PROMPT_SUFFIX`` when absent from
    ``lmms_eval_specific_kwargs``.
    """
    prompt_kwargs = {} if lmms_eval_specific_kwargs is None else lmms_eval_specific_kwargs
    prefix = prompt_kwargs.get("pre_prompt", "")
    suffix = prompt_kwargs.get("post_prompt", MTVQA_PROMPT_SUFFIX)
    body = str(doc["question"]).strip()
    return f"{prefix}{body}{suffix}"


def mtvqa_process_results(doc, results):
    """Score one prediction against the document's reference answer.

    Both strings are stripped, lower-cased and have every ``"."`` removed; the
    score is 1.0 when the normalized answer occurs as a substring of the
    normalized prediction and 0.0 otherwise. An empty ``results`` is treated
    as an empty prediction.

    Returns:
        ``{"mtvqa_score": {"category": doc["category"], "score": <float>}}``.
    """

    def _canonical(text):
        return text.strip().lower().replace(".", "")

    raw_prediction = str(results[0]) if results else ""
    normalized_prediction = _canonical(raw_prediction)
    normalized_answer = _canonical(str(doc["answer"]))
    matched = normalized_answer in normalized_prediction

    payload = {"category": doc["category"], "score": 1.0 if matched else 0.0}
    return {"mtvqa_score": payload}


def mtvqa_aggregate_results(results):
    """Aggregate per-sample scores into an overall percentage.

    Scores are grouped by each result's ``category`` (``"Unknown"`` when
    absent) and additionally pooled under ``"Average"``. The mean of every
    group is logged as a percentage, in sorted group-name order.

    Returns:
        The mean of all scores times 100, or 0.0 when ``results`` is empty.
    """

    def _percent(values):
        return (sum(values) / len(values)) * 100

    grouped = defaultdict(list)
    for item in results:
        label = str(item.get("category", "Unknown"))
        value = float(item.get("score", 0.0))
        for bucket in (label, "Average"):
            grouped[bucket].append(value)

    # Every bucket exists only because a score was appended to it, so no
    # group can be empty here.
    for name in sorted(grouped):
        eval_logger.info("MTVQA {}: {:.2f}", name, _percent(grouped[name]))

    overall = grouped.get("Average")
    return _percent(overall) if overall else 0.0