Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 50 additions & 57 deletions lmms_eval/tasks/ocrbench_v2/spotting_metric.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import ast
import os
import re
import shutil
import tempfile
import zipfile

import ipdb

import lmms_eval.tasks.ocrbench_v2.spotting_eval.rrc_evaluation_funcs_1_1 as rrc_evaluation_funcs
from lmms_eval.tasks.ocrbench_v2.spotting_eval.script import (
default_evaluation_params,
Expand Down Expand Up @@ -127,58 +125,53 @@ def zip_folder(source_folder, destination_zip):
def spotting_evaluation(prediction_list, img_metas):
score = 0

submit_path = "./lmms_eval/tasks/ocrbench_v2/spotting_eval/submit"
gt_path = "./lmms_eval/tasks/ocrbench_v2/spotting_eval/gt"
submit_zip_path = "./lmms_eval/tasks/ocrbench_v2/spotting_eval/submit.zip"
gt_zip_path = "./lmms_eval/tasks/ocrbench_v2/spotting_eval/gt.zip"
for file_path in [submit_path, gt_path, submit_zip_path, gt_zip_path]:
if "zip" in file_path:
if os.path.exists(file_path):
os.remove(file_path)
else:
if os.path.exists(file_path):
shutil.rmtree(file_path)
os.makedirs(file_path)

res_submit_list = []
for item in prediction_list:
if len(item) != 5:
ipdb.set_trace()
x1, y1, x2, y2, rec = item
if x1 >= x2 or y1 >= y2:
continue

res_submit_list.append(",".join([str(x1), str(y1), str(x2), str(y1), str(x2), str(y2), str(x1), str(y2), rec]))

res_gt_list = []
for bbox, rec in zip(img_metas["bbox_list"], img_metas["content"]):
x_coords = bbox[0::2]
y_coords = bbox[1::2]

x1, y1 = min(x_coords), min(y_coords)
x2, y2 = max(x_coords), max(y_coords)

res_gt_list.append(",".join([str(x1), str(y1), str(x2), str(y1), str(x2), str(y2), str(x1), str(y2), rec]))

if len(res_submit_list) == 0 or len(res_gt_list) == 0:
return 0

with open(os.path.join(submit_path, "res_img_0.txt"), "w") as f:
for item in res_submit_list[:-1]:
f.write(item + "\n")
f.write(res_submit_list[-1])

with open(os.path.join(gt_path, "gt_img_0.txt"), "w") as f:
for item in res_gt_list[:-1]:
f.write(item + "\n")
f.write(res_gt_list[-1])

zip_folder(submit_path, submit_zip_path)
zip_folder(gt_path, gt_zip_path)

command = {"g": gt_zip_path, "s": submit_zip_path, "o": "./", "p": '{"IOU_CONSTRAINT":0.5}'}

# run rrc_evaluation_funcs
result = rrc_evaluation_funcs.main_evaluation(command, default_evaluation_params, validate_data, evaluate_method)
score = result["method"]["hmean"]
with tempfile.TemporaryDirectory(prefix="ocrbench-v2-spotting-") as temp_dir:
submit_path = os.path.join(temp_dir, "submit")
gt_path = os.path.join(temp_dir, "gt")
submit_zip_path = os.path.join(temp_dir, "submit.zip")
gt_zip_path = os.path.join(temp_dir, "gt.zip")
os.makedirs(submit_path, exist_ok=True)
os.makedirs(gt_path, exist_ok=True)
os.makedirs(os.path.join(temp_dir, "lmms_eval", "tasks", "ocrbench_v2", "spotting_eval"), exist_ok=True)

res_submit_list = []
for item in prediction_list:
if len(item) != 5:
continue
x1, y1, x2, y2, rec = item
if x1 >= x2 or y1 >= y2:
continue

res_submit_list.append(",".join([str(x1), str(y1), str(x2), str(y1), str(x2), str(y2), str(x1), str(y2), rec]))

res_gt_list = []
for bbox, rec in zip(img_metas["bbox_list"], img_metas["content"]):
x_coords = bbox[0::2]
y_coords = bbox[1::2]

x1, y1 = min(x_coords), min(y_coords)
x2, y2 = max(x_coords), max(y_coords)

res_gt_list.append(",".join([str(x1), str(y1), str(x2), str(y1), str(x2), str(y2), str(x1), str(y2), rec]))

if len(res_submit_list) == 0 or len(res_gt_list) == 0:
return 0

with open(os.path.join(submit_path, "res_img_0.txt"), "w") as f:
for item in res_submit_list[:-1]:
f.write(item + "\n")
f.write(res_submit_list[-1])

with open(os.path.join(gt_path, "gt_img_0.txt"), "w") as f:
for item in res_gt_list[:-1]:
f.write(item + "\n")
f.write(res_gt_list[-1])

zip_folder(submit_path, submit_zip_path)
zip_folder(gt_path, gt_zip_path)

command = {"g": gt_zip_path, "s": submit_zip_path, "o": temp_dir, "p": '{"IOU_CONSTRAINT":0.5}'}

result = rrc_evaluation_funcs.main_evaluation(command, default_evaluation_params, validate_data, evaluate_method)
score = result["method"]["hmean"]
return score
80 changes: 41 additions & 39 deletions lmms_eval/tasks/ocrbench_v2/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,23 @@
vqa_evaluation_case_sensitive,
)

# Add the following functions to your existing utils.py file
OCRBench_v2_score = {
"text_recognition_en": [],
"text_detection_en": [],
"text_spotting_en": [],
"relationship_extraction_en": [],
"element_parsing_en": [],
"mathematical_calculation_en": [],
"visual_text_understanding_en": [],
"knowledge_reasoning_en": [],
"text_recognition_cn": [],
"relationship_extraction_cn": [],
"element_parsing_cn": [],
"visual_text_understanding_cn": [],
"knowledge_reasoning_cn": [],
}

def _make_score_buckets():
return {
"text_recognition_en": [],
"text_detection_en": [],
"text_spotting_en": [],
"relationship_extraction_en": [],
"element_parsing_en": [],
"mathematical_calculation_en": [],
"visual_text_understanding_en": [],
"knowledge_reasoning_en": [],
"text_recognition_cn": [],
"relationship_extraction_cn": [],
"element_parsing_cn": [],
"visual_text_understanding_cn": [],
"knowledge_reasoning_cn": [],
}


teds = TEDS(n_jobs=32)
Expand Down Expand Up @@ -253,7 +254,7 @@ def ocrbench_v2_process_results(doc, results):
else:
pred_chart_html = dict_to_html(pred_chart_dict)
if isinstance(answer, str):
answer = convert_str_to_multi_dict(pred)
answer = convert_str_to_multi_dict(answer)
gt_chart_html = dict_to_html(answer)
score = teds.evaluate(pred_chart_html, gt_chart_html)
else:
Expand Down Expand Up @@ -332,7 +333,7 @@ def ocrbench_v2_process_results(doc, results):
score = (get_value_or_zero(ocr_metric["bleu"]) + get_value_or_zero(ocr_metric["meteor"]) + get_value_or_zero(ocr_metric["f_measure"]) + (1 - get_value_or_zero(ocr_metric["edit_dist"]))) / 4
elif data_type == "full-page OCR en":
if not pred:
score == 0
score = 0
else:
ocr_metric = cal_per_metrics(pred, gt_ans[0])
score = (get_value_or_zero(ocr_metric["bleu"]) + get_value_or_zero(ocr_metric["meteor"]) + get_value_or_zero(ocr_metric["f_measure"]) + (1 - get_value_or_zero(ocr_metric["edit_dist"]))) / 4
Expand Down Expand Up @@ -372,12 +373,13 @@ def ocrbench_v2_process_results(doc, results):
}


def calculate_average_score(categories):
return sum(sum(OCRBench_v2_score[cat]) / len(OCRBench_v2_score[cat]) if len(OCRBench_v2_score[cat]) > 0 else 0 for cat in categories) / len(categories)
def calculate_average_score(categories, score_buckets):
return sum(sum(score_buckets[cat]) / len(score_buckets[cat]) if len(score_buckets[cat]) > 0 else 0 for cat in categories) / len(categories)


def ocrbench_v2_aggregate_accuracy(results, args):
question_type_scores = {}
score_buckets = _make_score_buckets()

for result in results:
if "ignore" in result.keys() and result["ignore"] == "True":
Expand All @@ -391,43 +393,43 @@ def ocrbench_v2_aggregate_accuracy(results, args):
question_type_scores[question_type].append(score)

if question_type in ["text recognition en", "fine-grained text recognition en", "full-page OCR en"]:
OCRBench_v2_score["text_recognition_en"].append(score)
score_buckets["text_recognition_en"].append(score)

elif question_type in ["text grounding en", "VQA with position en"]:
OCRBench_v2_score["text_detection_en"].append(score)
score_buckets["text_detection_en"].append(score)

elif question_type == "text spotting en":
OCRBench_v2_score["text_spotting_en"].append(score)
score_buckets["text_spotting_en"].append(score)

elif question_type in ["key information extraction en", "key information mapping en"]:
OCRBench_v2_score["relationship_extraction_en"].append(score)
score_buckets["relationship_extraction_en"].append(score)

elif question_type in ["document parsing en", "chart parsing en", "table parsing en", "formula recognition en"]:
OCRBench_v2_score["element_parsing_en"].append(score)
score_buckets["element_parsing_en"].append(score)

elif question_type in ["math QA en", "text counting en"]:
OCRBench_v2_score["mathematical_calculation_en"].append(score)
score_buckets["mathematical_calculation_en"].append(score)

elif question_type in ["document classification en", "cognition VQA en", "diagram QA en"]:
OCRBench_v2_score["visual_text_understanding_en"].append(score)
score_buckets["visual_text_understanding_en"].append(score)

elif question_type in ["reasoning VQA en", "science QA en", "APP agent en", "ASCII art classification en"]:
OCRBench_v2_score["knowledge_reasoning_en"].append(score)
score_buckets["knowledge_reasoning_en"].append(score)

elif question_type == "full-page OCR cn":
OCRBench_v2_score["text_recognition_cn"].append(score)
score_buckets["text_recognition_cn"].append(score)

elif question_type in ["key information extraction cn", "handwritten answer extraction cn"]:
OCRBench_v2_score["relationship_extraction_cn"].append(score)
score_buckets["relationship_extraction_cn"].append(score)

elif question_type in ["document parsing cn", "table parsing cn", "formula recognition cn"]:
OCRBench_v2_score["element_parsing_cn"].append(score)
score_buckets["element_parsing_cn"].append(score)

elif question_type == "cognition VQA cn":
OCRBench_v2_score["visual_text_understanding_cn"].append(score)
score_buckets["visual_text_understanding_cn"].append(score)

elif question_type in ["reasoning VQA cn", "text translation cn"]:
OCRBench_v2_score["knowledge_reasoning_cn"].append(score)
score_buckets["knowledge_reasoning_cn"].append(score)

else:
print("No such task!")
Expand All @@ -437,8 +439,8 @@ def ocrbench_v2_aggregate_accuracy(results, args):

chinese_tasks = ["text_recognition_cn", "relationship_extraction_cn", "element_parsing_cn", "visual_text_understanding_cn", "knowledge_reasoning_cn"]

OCRBench_v2_English_subset_score = calculate_average_score(english_tasks)
OCRBench_v2_Chinese_subset_score = calculate_average_score(chinese_tasks)
OCRBench_v2_English_subset_score = calculate_average_score(english_tasks, score_buckets)
OCRBench_v2_Chinese_subset_score = calculate_average_score(chinese_tasks, score_buckets)

Final_score = (OCRBench_v2_English_subset_score + OCRBench_v2_Chinese_subset_score) / 2
file_name = generate_submission_file("ocrbench_v2_results.txt", args, subpath="results")
Expand All @@ -450,14 +452,14 @@ def ocrbench_v2_aggregate_accuracy(results, args):
print(f"{q_type} (sample number: {len(scores)}): {avg_score:.2f}", file=f)
print("######################### English Subsets ######################", file=f)
for task in english_tasks:
num_samples = len(OCRBench_v2_score[task])
avg_score = sum(OCRBench_v2_score[task]) / num_samples if num_samples > 0 else 0
num_samples = len(score_buckets[task])
avg_score = sum(score_buckets[task]) / num_samples if num_samples > 0 else 0
print(f"{task.replace('_', ' ').title()} (Total {num_samples}): {avg_score:.2f}", file=f)
print(f"Overall English Score: {OCRBench_v2_English_subset_score:.2f}", file=f)
print("######################### Chinese Subsets ######################", file=f)
for task in chinese_tasks:
num_samples = len(OCRBench_v2_score[task])
avg_score = sum(OCRBench_v2_score[task]) / num_samples if num_samples > 0 else 0
num_samples = len(score_buckets[task])
avg_score = sum(score_buckets[task]) / num_samples if num_samples > 0 else 0
print(f"{task.replace('_', ' ').title()} (Total {num_samples}): {avg_score:.2f}", file=f)
print(f"Overall Chinese Score: {OCRBench_v2_Chinese_subset_score:.2f}", file=f)
print("######################### Final Score ##########################", file=f)
Expand Down
25 changes: 7 additions & 18 deletions lmms_eval/tasks/ocrbench_v2/vqa_metric.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import math
import re

import ipdb


def levenshtein_distance(s1, s2):
if len(s1) > len(s2):
Expand All @@ -26,10 +24,7 @@ def vqa_evaluation(predict, answers):
for j in range(len(answers)):
if isinstance(answers[j], (int, float)):
answers[j] = str(answers[j])
try:
answer = answers[j].lower().strip().replace("\n", " ")
except:
ipdb.set_trace()
answer = str(answers[j]).lower().strip().replace("\n", " ")
if isinstance(predict, (int, float)):
predict = str(predict)
predict = predict.lower().strip().replace("\n", " ")
Expand Down Expand Up @@ -69,10 +64,7 @@ def cn_vqa_evaluation(predict, answers):
for j in range(len(answers)):
if isinstance(answers[j], (int, float)):
answers[j] = str(answers[j])
try:
answer = answers[j].lower().strip().replace("\n", " ").replace(" ", "")
except:
ipdb.set_trace()
answer = str(answers[j]).lower().strip().replace("\n", " ").replace(" ", "")
if isinstance(predict, (int, float)):
predict = str(predict)
predict = predict.lower().strip().replace("\n", " ").replace(" ", "")
Expand All @@ -91,7 +83,7 @@ def cn_vqa_evaluation(predict, answers):
else:
answers = answers.lower().strip().replace("\n", " ").replace(" ", "")
predict = predict.lower().strip().replace("\n", " ").replace(" ", "")
if len(answer.split(",")) < 4:
if len(answers.split(",")) < 4:
if answers in predict:
score = 1
else:
Expand All @@ -112,10 +104,7 @@ def vqa_evaluation_case_sensitive(predict, answers):
for j in range(len(answers)):
if isinstance(answers[j], (int, float)):
answers[j] = str(answers[j])
try:
answer = answers[j].strip().replace("\n", " ")
except:
ipdb.set_trace()
answer = str(answers[j]).strip().replace("\n", " ")
predict = predict.strip().replace("\n", " ")
if len(answer.split()) < 5:
if answer in predict:
Expand Down Expand Up @@ -195,16 +184,16 @@ def counting_evaluation(predict, answers, eval_method):

else:
answers = answers.lower().strip().replace("\n", " ")
predict = predict.lower().strip().replace("\n", " ")
predict = predict_processed
if eval_method == "exact match":
if answer in predict:
if answers in predict:
score = 1
else:
score = 0
elif eval_method == "regression":
predict = extract_first_number(predict)
if predict:
answer = int(answer)
answer = int(answers)
if predict <= 0 or predict >= 2 * answer:
score = 0
else:
Expand Down
Loading