Skip to content

Commit 15c32bf

Browse files
authored
fix ocrbench v2 scoring regressions (#1229)
1 parent ac5ad5b commit 15c32bf

4 files changed

Lines changed: 179 additions & 114 deletions

File tree

lmms_eval/tasks/ocrbench_v2/spotting_metric.py

Lines changed: 50 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
import ast
22
import os
33
import re
4-
import shutil
4+
import tempfile
55
import zipfile
66

7-
import ipdb
8-
97
import lmms_eval.tasks.ocrbench_v2.spotting_eval.rrc_evaluation_funcs_1_1 as rrc_evaluation_funcs
108
from lmms_eval.tasks.ocrbench_v2.spotting_eval.script import (
119
default_evaluation_params,
@@ -127,58 +125,53 @@ def zip_folder(source_folder, destination_zip):
127125
def spotting_evaluation(prediction_list, img_metas):
128126
score = 0
129127

130-
submit_path = "./lmms_eval/tasks/ocrbench_v2/spotting_eval/submit"
131-
gt_path = "./lmms_eval/tasks/ocrbench_v2/spotting_eval/gt"
132-
submit_zip_path = "./lmms_eval/tasks/ocrbench_v2/spotting_eval/submit.zip"
133-
gt_zip_path = "./lmms_eval/tasks/ocrbench_v2/spotting_eval/gt.zip"
134-
for file_path in [submit_path, gt_path, submit_zip_path, gt_zip_path]:
135-
if "zip" in file_path:
136-
if os.path.exists(file_path):
137-
os.remove(file_path)
138-
else:
139-
if os.path.exists(file_path):
140-
shutil.rmtree(file_path)
141-
os.makedirs(file_path)
142-
143-
res_submit_list = []
144-
for item in prediction_list:
145-
if len(item) != 5:
146-
ipdb.set_trace()
147-
x1, y1, x2, y2, rec = item
148-
if x1 >= x2 or y1 >= y2:
149-
continue
150-
151-
res_submit_list.append(",".join([str(x1), str(y1), str(x2), str(y1), str(x2), str(y2), str(x1), str(y2), rec]))
152-
153-
res_gt_list = []
154-
for bbox, rec in zip(img_metas["bbox_list"], img_metas["content"]):
155-
x_coords = bbox[0::2]
156-
y_coords = bbox[1::2]
157-
158-
x1, y1 = min(x_coords), min(y_coords)
159-
x2, y2 = max(x_coords), max(y_coords)
160-
161-
res_gt_list.append(",".join([str(x1), str(y1), str(x2), str(y1), str(x2), str(y2), str(x1), str(y2), rec]))
162-
163-
if len(res_submit_list) == 0 or len(res_gt_list) == 0:
164-
return 0
165-
166-
with open(os.path.join(submit_path, "res_img_0.txt"), "w") as f:
167-
for item in res_submit_list[:-1]:
168-
f.write(item + "\n")
169-
f.write(res_submit_list[-1])
170-
171-
with open(os.path.join(gt_path, "gt_img_0.txt"), "w") as f:
172-
for item in res_gt_list[:-1]:
173-
f.write(item + "\n")
174-
f.write(res_gt_list[-1])
175-
176-
zip_folder(submit_path, submit_zip_path)
177-
zip_folder(gt_path, gt_zip_path)
178-
179-
command = {"g": gt_zip_path, "s": submit_zip_path, "o": "./", "p": '{"IOU_CONSTRAINT":0.5}'}
180-
181-
# run rrc_evaluation_funcs
182-
result = rrc_evaluation_funcs.main_evaluation(command, default_evaluation_params, validate_data, evaluate_method)
183-
score = result["method"]["hmean"]
128+
with tempfile.TemporaryDirectory(prefix="ocrbench-v2-spotting-") as temp_dir:
129+
submit_path = os.path.join(temp_dir, "submit")
130+
gt_path = os.path.join(temp_dir, "gt")
131+
submit_zip_path = os.path.join(temp_dir, "submit.zip")
132+
gt_zip_path = os.path.join(temp_dir, "gt.zip")
133+
os.makedirs(submit_path, exist_ok=True)
134+
os.makedirs(gt_path, exist_ok=True)
135+
os.makedirs(os.path.join(temp_dir, "lmms_eval", "tasks", "ocrbench_v2", "spotting_eval"), exist_ok=True)
136+
137+
res_submit_list = []
138+
for item in prediction_list:
139+
if len(item) != 5:
140+
continue
141+
x1, y1, x2, y2, rec = item
142+
if x1 >= x2 or y1 >= y2:
143+
continue
144+
145+
res_submit_list.append(",".join([str(x1), str(y1), str(x2), str(y1), str(x2), str(y2), str(x1), str(y2), rec]))
146+
147+
res_gt_list = []
148+
for bbox, rec in zip(img_metas["bbox_list"], img_metas["content"]):
149+
x_coords = bbox[0::2]
150+
y_coords = bbox[1::2]
151+
152+
x1, y1 = min(x_coords), min(y_coords)
153+
x2, y2 = max(x_coords), max(y_coords)
154+
155+
res_gt_list.append(",".join([str(x1), str(y1), str(x2), str(y1), str(x2), str(y2), str(x1), str(y2), rec]))
156+
157+
if len(res_submit_list) == 0 or len(res_gt_list) == 0:
158+
return 0
159+
160+
with open(os.path.join(submit_path, "res_img_0.txt"), "w") as f:
161+
for item in res_submit_list[:-1]:
162+
f.write(item + "\n")
163+
f.write(res_submit_list[-1])
164+
165+
with open(os.path.join(gt_path, "gt_img_0.txt"), "w") as f:
166+
for item in res_gt_list[:-1]:
167+
f.write(item + "\n")
168+
f.write(res_gt_list[-1])
169+
170+
zip_folder(submit_path, submit_zip_path)
171+
zip_folder(gt_path, gt_zip_path)
172+
173+
command = {"g": gt_zip_path, "s": submit_zip_path, "o": temp_dir, "p": '{"IOU_CONSTRAINT":0.5}'}
174+
175+
result = rrc_evaluation_funcs.main_evaluation(command, default_evaluation_params, validate_data, evaluate_method)
176+
score = result["method"]["hmean"]
184177
return score

lmms_eval/tasks/ocrbench_v2/utils.py

Lines changed: 41 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -34,22 +34,23 @@
3434
vqa_evaluation_case_sensitive,
3535
)
3636

37-
# Add the following functions to your existing utils.py file
38-
OCRBench_v2_score = {
39-
"text_recognition_en": [],
40-
"text_detection_en": [],
41-
"text_spotting_en": [],
42-
"relationship_extraction_en": [],
43-
"element_parsing_en": [],
44-
"mathematical_calculation_en": [],
45-
"visual_text_understanding_en": [],
46-
"knowledge_reasoning_en": [],
47-
"text_recognition_cn": [],
48-
"relationship_extraction_cn": [],
49-
"element_parsing_cn": [],
50-
"visual_text_understanding_cn": [],
51-
"knowledge_reasoning_cn": [],
52-
}
37+
38+
def _make_score_buckets():
39+
return {
40+
"text_recognition_en": [],
41+
"text_detection_en": [],
42+
"text_spotting_en": [],
43+
"relationship_extraction_en": [],
44+
"element_parsing_en": [],
45+
"mathematical_calculation_en": [],
46+
"visual_text_understanding_en": [],
47+
"knowledge_reasoning_en": [],
48+
"text_recognition_cn": [],
49+
"relationship_extraction_cn": [],
50+
"element_parsing_cn": [],
51+
"visual_text_understanding_cn": [],
52+
"knowledge_reasoning_cn": [],
53+
}
5354

5455

5556
teds = TEDS(n_jobs=32)
@@ -253,7 +254,7 @@ def ocrbench_v2_process_results(doc, results):
253254
else:
254255
pred_chart_html = dict_to_html(pred_chart_dict)
255256
if isinstance(answer, str):
256-
answer = convert_str_to_multi_dict(pred)
257+
answer = convert_str_to_multi_dict(answer)
257258
gt_chart_html = dict_to_html(answer)
258259
score = teds.evaluate(pred_chart_html, gt_chart_html)
259260
else:
@@ -332,7 +333,7 @@ def ocrbench_v2_process_results(doc, results):
332333
score = (get_value_or_zero(ocr_metric["bleu"]) + get_value_or_zero(ocr_metric["meteor"]) + get_value_or_zero(ocr_metric["f_measure"]) + (1 - get_value_or_zero(ocr_metric["edit_dist"]))) / 4
333334
elif data_type == "full-page OCR en":
334335
if not pred:
335-
score == 0
336+
score = 0
336337
else:
337338
ocr_metric = cal_per_metrics(pred, gt_ans[0])
338339
score = (get_value_or_zero(ocr_metric["bleu"]) + get_value_or_zero(ocr_metric["meteor"]) + get_value_or_zero(ocr_metric["f_measure"]) + (1 - get_value_or_zero(ocr_metric["edit_dist"]))) / 4
@@ -372,12 +373,13 @@ def ocrbench_v2_process_results(doc, results):
372373
}
373374

374375

375-
def calculate_average_score(categories):
376-
return sum(sum(OCRBench_v2_score[cat]) / len(OCRBench_v2_score[cat]) if len(OCRBench_v2_score[cat]) > 0 else 0 for cat in categories) / len(categories)
376+
def calculate_average_score(categories, score_buckets):
377+
return sum(sum(score_buckets[cat]) / len(score_buckets[cat]) if len(score_buckets[cat]) > 0 else 0 for cat in categories) / len(categories)
377378

378379

379380
def ocrbench_v2_aggregate_accuracy(results, args):
380381
question_type_scores = {}
382+
score_buckets = _make_score_buckets()
381383

382384
for result in results:
383385
if "ignore" in result.keys() and result["ignore"] == "True":
@@ -391,43 +393,43 @@ def ocrbench_v2_aggregate_accuracy(results, args):
391393
question_type_scores[question_type].append(score)
392394

393395
if question_type in ["text recognition en", "fine-grained text recognition en", "full-page OCR en"]:
394-
OCRBench_v2_score["text_recognition_en"].append(score)
396+
score_buckets["text_recognition_en"].append(score)
395397

396398
elif question_type in ["text grounding en", "VQA with position en"]:
397-
OCRBench_v2_score["text_detection_en"].append(score)
399+
score_buckets["text_detection_en"].append(score)
398400

399401
elif question_type == "text spotting en":
400-
OCRBench_v2_score["text_spotting_en"].append(score)
402+
score_buckets["text_spotting_en"].append(score)
401403

402404
elif question_type in ["key information extraction en", "key information mapping en"]:
403-
OCRBench_v2_score["relationship_extraction_en"].append(score)
405+
score_buckets["relationship_extraction_en"].append(score)
404406

405407
elif question_type in ["document parsing en", "chart parsing en", "table parsing en", "formula recognition en"]:
406-
OCRBench_v2_score["element_parsing_en"].append(score)
408+
score_buckets["element_parsing_en"].append(score)
407409

408410
elif question_type in ["math QA en", "text counting en"]:
409-
OCRBench_v2_score["mathematical_calculation_en"].append(score)
411+
score_buckets["mathematical_calculation_en"].append(score)
410412

411413
elif question_type in ["document classification en", "cognition VQA en", "diagram QA en"]:
412-
OCRBench_v2_score["visual_text_understanding_en"].append(score)
414+
score_buckets["visual_text_understanding_en"].append(score)
413415

414416
elif question_type in ["reasoning VQA en", "science QA en", "APP agent en", "ASCII art classification en"]:
415-
OCRBench_v2_score["knowledge_reasoning_en"].append(score)
417+
score_buckets["knowledge_reasoning_en"].append(score)
416418

417419
elif question_type == "full-page OCR cn":
418-
OCRBench_v2_score["text_recognition_cn"].append(score)
420+
score_buckets["text_recognition_cn"].append(score)
419421

420422
elif question_type in ["key information extraction cn", "handwritten answer extraction cn"]:
421-
OCRBench_v2_score["relationship_extraction_cn"].append(score)
423+
score_buckets["relationship_extraction_cn"].append(score)
422424

423425
elif question_type in ["document parsing cn", "table parsing cn", "formula recognition cn"]:
424-
OCRBench_v2_score["element_parsing_cn"].append(score)
426+
score_buckets["element_parsing_cn"].append(score)
425427

426428
elif question_type == "cognition VQA cn":
427-
OCRBench_v2_score["visual_text_understanding_cn"].append(score)
429+
score_buckets["visual_text_understanding_cn"].append(score)
428430

429431
elif question_type in ["reasoning VQA cn", "text translation cn"]:
430-
OCRBench_v2_score["knowledge_reasoning_cn"].append(score)
432+
score_buckets["knowledge_reasoning_cn"].append(score)
431433

432434
else:
433435
print("No such task!")
@@ -437,8 +439,8 @@ def ocrbench_v2_aggregate_accuracy(results, args):
437439

438440
chinese_tasks = ["text_recognition_cn", "relationship_extraction_cn", "element_parsing_cn", "visual_text_understanding_cn", "knowledge_reasoning_cn"]
439441

440-
OCRBench_v2_English_subset_score = calculate_average_score(english_tasks)
441-
OCRBench_v2_Chinese_subset_score = calculate_average_score(chinese_tasks)
442+
OCRBench_v2_English_subset_score = calculate_average_score(english_tasks, score_buckets)
443+
OCRBench_v2_Chinese_subset_score = calculate_average_score(chinese_tasks, score_buckets)
442444

443445
Final_score = (OCRBench_v2_English_subset_score + OCRBench_v2_Chinese_subset_score) / 2
444446
file_name = generate_submission_file("ocrbench_v2_results.txt", args, subpath="results")
@@ -450,14 +452,14 @@ def ocrbench_v2_aggregate_accuracy(results, args):
450452
print(f"{q_type} (sample number: {len(scores)}): {avg_score:.2f}", file=f)
451453
print("######################### English Subsets ######################", file=f)
452454
for task in english_tasks:
453-
num_samples = len(OCRBench_v2_score[task])
454-
avg_score = sum(OCRBench_v2_score[task]) / num_samples if num_samples > 0 else 0
455+
num_samples = len(score_buckets[task])
456+
avg_score = sum(score_buckets[task]) / num_samples if num_samples > 0 else 0
455457
print(f"{task.replace('_', ' ').title()} (Total {num_samples}): {avg_score:.2f}", file=f)
456458
print(f"Overall English Score: {OCRBench_v2_English_subset_score:.2f}", file=f)
457459
print("######################### Chinese Subsets ######################", file=f)
458460
for task in chinese_tasks:
459-
num_samples = len(OCRBench_v2_score[task])
460-
avg_score = sum(OCRBench_v2_score[task]) / num_samples if num_samples > 0 else 0
461+
num_samples = len(score_buckets[task])
462+
avg_score = sum(score_buckets[task]) / num_samples if num_samples > 0 else 0
461463
print(f"{task.replace('_', ' ').title()} (Total {num_samples}): {avg_score:.2f}", file=f)
462464
print(f"Overall Chinese Score: {OCRBench_v2_Chinese_subset_score:.2f}", file=f)
463465
print("######################### Final Score ##########################", file=f)

lmms_eval/tasks/ocrbench_v2/vqa_metric.py

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
import math
22
import re
33

4-
import ipdb
5-
64

75
def levenshtein_distance(s1, s2):
86
if len(s1) > len(s2):
@@ -26,10 +24,7 @@ def vqa_evaluation(predict, answers):
2624
for j in range(len(answers)):
2725
if isinstance(answers[j], (int, float)):
2826
answers[j] = str(answers[j])
29-
try:
30-
answer = answers[j].lower().strip().replace("\n", " ")
31-
except:
32-
ipdb.set_trace()
27+
answer = str(answers[j]).lower().strip().replace("\n", " ")
3328
if isinstance(predict, (int, float)):
3429
predict = str(predict)
3530
predict = predict.lower().strip().replace("\n", " ")
@@ -69,10 +64,7 @@ def cn_vqa_evaluation(predict, answers):
6964
for j in range(len(answers)):
7065
if isinstance(answers[j], (int, float)):
7166
answers[j] = str(answers[j])
72-
try:
73-
answer = answers[j].lower().strip().replace("\n", " ").replace(" ", "")
74-
except:
75-
ipdb.set_trace()
67+
answer = str(answers[j]).lower().strip().replace("\n", " ").replace(" ", "")
7668
if isinstance(predict, (int, float)):
7769
predict = str(predict)
7870
predict = predict.lower().strip().replace("\n", " ").replace(" ", "")
@@ -91,7 +83,7 @@ def cn_vqa_evaluation(predict, answers):
9183
else:
9284
answers = answers.lower().strip().replace("\n", " ").replace(" ", "")
9385
predict = predict.lower().strip().replace("\n", " ").replace(" ", "")
94-
if len(answer.split(",")) < 4:
86+
if len(answers.split(",")) < 4:
9587
if answers in predict:
9688
score = 1
9789
else:
@@ -112,10 +104,7 @@ def vqa_evaluation_case_sensitive(predict, answers):
112104
for j in range(len(answers)):
113105
if isinstance(answers[j], (int, float)):
114106
answers[j] = str(answers[j])
115-
try:
116-
answer = answers[j].strip().replace("\n", " ")
117-
except:
118-
ipdb.set_trace()
107+
answer = str(answers[j]).strip().replace("\n", " ")
119108
predict = predict.strip().replace("\n", " ")
120109
if len(answer.split()) < 5:
121110
if answer in predict:
@@ -195,16 +184,16 @@ def counting_evaluation(predict, answers, eval_method):
195184

196185
else:
197186
answers = answers.lower().strip().replace("\n", " ")
198-
predict = predict.lower().strip().replace("\n", " ")
187+
predict = predict_processed
199188
if eval_method == "exact match":
200-
if answer in predict:
189+
if answers in predict:
201190
score = 1
202191
else:
203192
score = 0
204193
elif eval_method == "regression":
205194
predict = extract_first_number(predict)
206195
if predict:
207-
answer = int(answer)
196+
answer = int(answers)
208197
if predict <= 0 or predict >= 2 * answer:
209198
score = 0
210199
else:

0 commit comments

Comments
 (0)