Skip to content

Commit e23a69c

Browse files
changdazhou authored and TingquanGao committed
support sort by different text line
1 parent 9f0fa73 commit e23a69c

File tree

6 files changed

+827
-624
lines changed

6 files changed

+827
-624
lines changed

paddlex/inference/pipelines/layout_parsing/pipeline_v2.py

Lines changed: 58 additions & 97 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,7 @@
3030
from ..base import BasePipeline
3131
from ..ocr.result import OCRResult
3232
from .result_v2 import LayoutParsingBlock, LayoutParsingRegion, LayoutParsingResultV2
33+
from .setting import BLOCK_LABEL_MAP, BLOCK_SETTINGS, LINE_SETTINGS, REGION_SETTINGS
3334
from .utils import (
3435
caculate_bbox_area,
3536
calculate_minimum_enclosing_bbox,
@@ -260,8 +261,6 @@ def check_model_settings_valid(self, input_params: dict) -> bool:
260261
def standardized_data(
261262
self,
262263
image: list,
263-
parameters_config: dict,
264-
block_label_mapping: dict,
265264
region_det_res: DetResult,
266265
layout_det_res: DetResult,
267266
overall_ocr_res: OCRResult,
@@ -360,7 +359,7 @@ def standardized_data(
360359
paragraph_title_block_area = caculate_bbox_area(
361360
layout_det_res["boxes"][paragraph_title_list[0]]["coordinate"]
362361
)
363-
title_area_max_block_threshold = parameters_config["block"].get(
362+
title_area_max_block_threshold = BLOCK_SETTINGS.get(
364363
"title_conversion_area_ratio_threshold", 0.3
365364
)
366365
if (
@@ -441,7 +440,7 @@ def standardized_data(
441440
break
442441
if not has_text and layout_det_res["boxes"][layout_box_idx][
443442
"label"
444-
] not in block_label_mapping.get("vision_labels", []):
443+
] not in BLOCK_LABEL_MAP.get("vision_labels", []):
445444
crop_box = layout_det_res["boxes"][layout_box_idx]["coordinate"]
446445
x1, y1, x2, y2 = [int(i) for i in crop_box]
447446
crop_img = np.array(image)[y1:y2, x1:x2]
@@ -506,7 +505,7 @@ def standardized_data(
506505
overlap_ratio = calculate_overlap_ratio(
507506
region_bbox, block_bboxes[block_idx], mode="small"
508507
)
509-
if overlap_ratio > parameters_config["region"].get(
508+
if overlap_ratio > REGION_SETTINGS.get(
510509
"match_block_overlap_ratio_threshold", 0.8
511510
):
512511
region_to_block_map[region_idx].append(block_idx)
@@ -540,7 +539,6 @@ def standardized_data(
540539
image.shape[0],
541540
block_idxes_set,
542541
block_bboxes,
543-
parameters_config,
544542
)
545543
)
546544
if len(matched_idxes) == 0:
@@ -570,7 +568,7 @@ def sort_line_by_projection(
570568
input_img: np.ndarray,
571569
text_rec_model: Any,
572570
text_rec_score_thresh: Union[float, None] = None,
573-
orientation: str = "vertical",
571+
direction: str = "vertical",
574572
) -> None:
575573
"""
576574
Sort a line of text spans based on their vertical position within the layout bounding box.
@@ -583,8 +581,8 @@ def sort_line_by_projection(
583581
Returns:
584582
list: The sorted line of text spans.
585583
"""
586-
sort_index = 0 if orientation == "horizontal" else 1
587-
splited_boxes = split_boxes_by_projection(line, orientation)
584+
sort_index = 0 if direction == "horizontal" else 1
585+
splited_boxes = split_boxes_by_projection(line, direction)
588586
splited_lines = []
589587
if len(line) != len(splited_boxes):
590588
splited_boxes.sort(key=lambda span: span[0][sort_index])
@@ -614,7 +612,6 @@ def sort_line_by_projection(
614612
def get_block_rec_content(
615613
self,
616614
image: list,
617-
line_parameters_config: dict,
618615
ocr_rec_res: dict,
619616
block: LayoutParsingBlock,
620617
text_rec_model: Any,
@@ -625,37 +622,49 @@ def get_block_rec_content(
625622
block.content = ""
626623
return block
627624

628-
lines, text_orientation = group_boxes_into_lines(
625+
lines, text_direction = group_boxes_into_lines(
629626
ocr_rec_res,
630-
line_parameters_config.get("line_height_iou_threshold", 0.8),
627+
LINE_SETTINGS.get("line_height_iou_threshold", 0.8),
631628
)
632629

633630
if block.label == "reference":
634631
rec_boxes = ocr_rec_res["boxes"]
635632
block_right_coordinate = max([box[2] for box in rec_boxes])
636-
last_line_span_limit = 20
637633
else:
638634
block_right_coordinate = block.bbox[2]
639-
last_line_span_limit = 10
640635

641636
# format line
642637
text_lines = []
643638
need_new_line_num = 0
644-
sort_index = 0 if text_orientation == "horizontal" else 1
639+
start_index = 0 if text_direction == "horizontal" else 1
640+
secondary_direction_start_index = 1 if text_direction == "horizontal" else 0
641+
line_height_list, line_width_list = [], []
645642
for idx, line in enumerate(lines):
646-
line.sort(key=lambda span: span[0][sort_index])
643+
line.sort(key=lambda span: span[0][start_index])
647644

645+
text_bboxes_height = [
646+
span[0][secondary_direction_start_index + 2]
647+
- span[0][secondary_direction_start_index]
648+
for span in line
649+
]
650+
text_bboxes_width = [
651+
span[0][start_index + 2] - span[0][start_index] for span in line
652+
]
653+
654+
line_height = np.mean(text_bboxes_height)
655+
line_height_list.append(line_height)
656+
line_width_list.append(np.mean(text_bboxes_width))
648657
# merge formula and text
649658
ocr_labels = [span[2] for span in line]
650659
if "formula" in ocr_labels:
651660
line = self.sort_line_by_projection(
652-
line, image, text_rec_model, text_rec_score_thresh, text_orientation
661+
line, image, text_rec_model, text_rec_score_thresh, text_direction
653662
)
654663

655664
line_text, need_new_line = format_line(
656665
line,
657666
block_right_coordinate,
658-
last_line_span_limit=last_line_span_limit,
667+
last_line_span_limit=line_height * 1.5,
659668
block_label=block.label,
660669
)
661670
if need_new_line:
@@ -668,21 +677,21 @@ def get_block_rec_content(
668677
block.seg_end_coordinate = line_end_coordinate
669678
text_lines.append(line_text)
670679

671-
delim = line_parameters_config["delimiter_map"].get(block.label, "")
680+
delim = LINE_SETTINGS["delimiter_map"].get(block.label, "")
672681
if need_new_line_num > len(text_lines) * 0.5 and delim == "":
673682
delim = "\n"
674683
content = delim.join(text_lines)
675684
block.content = content
676685
block.num_of_lines = len(text_lines)
677-
block.orientation = text_orientation
686+
block.direction = text_direction
687+
block.text_line_height = np.mean(line_height_list)
688+
block.text_line_width = np.mean(line_width_list)
678689

679690
return block
680691

681692
def get_layout_parsing_blocks(
682693
self,
683694
image: list,
684-
parameters_config: dict,
685-
block_label_mapping: dict,
686695
region_block_ocr_idx_map: dict,
687696
region_det_res: DetResult,
688697
overall_ocr_res: OCRResult,
@@ -759,7 +768,6 @@ def get_layout_parsing_blocks(
759768
block = self.get_block_rec_content(
760769
image=image,
761770
block=block,
762-
line_parameters_config=parameters_config["line"],
763771
ocr_rec_res=rec_res,
764772
text_rec_model=text_rec_model,
765773
text_rec_score_thresh=text_rec_score_thresh,
@@ -781,9 +789,8 @@ def get_layout_parsing_blocks(
781789
for idx in region_block_ocr_idx_map["region_to_block_map"][region_idx]
782790
]
783791
region = LayoutParsingRegion(
784-
region_bbox=region_bbox,
792+
bbox=region_bbox,
785793
blocks=region_blocks,
786-
block_label_mapping=block_label_mapping,
787794
)
788795
region_list.append(region)
789796

@@ -818,14 +825,11 @@ def get_layout_parsing_res(
818825
Returns:
819826
list: A list of dictionaries representing the layout parsing result.
820827
"""
821-
from .setting import block_label_mapping, parameters_config
822828

823829
# Standardize data
824830
region_block_ocr_idx_map, region_det_res, layout_det_res = (
825831
self.standardized_data(
826832
image=image,
827-
parameters_config=parameters_config,
828-
block_label_mapping=block_label_mapping,
829833
region_det_res=region_det_res,
830834
layout_det_res=layout_det_res,
831835
overall_ocr_res=overall_ocr_res,
@@ -838,8 +842,6 @@ def get_layout_parsing_res(
838842
# Format layout parsing block
839843
region_list = self.get_layout_parsing_blocks(
840844
image=image,
841-
parameters_config=parameters_config,
842-
block_label_mapping=block_label_mapping,
843845
region_block_ocr_idx_map=region_block_ocr_idx_map,
844846
region_det_res=region_det_res,
845847
overall_ocr_res=overall_ocr_res,
@@ -854,11 +856,10 @@ def get_layout_parsing_res(
854856
for region in region_list:
855857
parsing_res_list.extend(region.sort())
856858

857-
visualize_index_labels = block_label_mapping["visualize_index_labels"]
858859
index = 1
859860
for block in parsing_res_list:
860-
if block.label in visualize_index_labels:
861-
block.index = index
861+
if block.label in BLOCK_LABEL_MAP["visualize_index_labels"]:
862+
block.order_index = index
862863
index += 1
863864

864865
return parsing_res_list
@@ -956,8 +957,6 @@ def predict(
956957
use_e2e_wired_table_rec_model: bool = False,
957958
use_e2e_wireless_table_rec_model: bool = True,
958959
is_pretty_markdown: Union[bool, None] = None,
959-
use_layout_gt: bool = False,
960-
layout_gt_dir: Union[str, None] = None,
961960
**kwargs,
962961
) -> LayoutParsingResultV2:
963962
"""
@@ -1032,65 +1031,16 @@ def predict(
10321031

10331032
doc_preprocessor_image = doc_preprocessor_res["output_img"]
10341033

1035-
use_layout_gt = use_layout_gt
1036-
if not use_layout_gt:
1037-
layout_det_res = next(
1038-
self.layout_det_model(
1039-
doc_preprocessor_image,
1040-
threshold=layout_threshold,
1041-
layout_nms=layout_nms,
1042-
layout_unclip_ratio=layout_unclip_ratio,
1043-
layout_merge_bboxes_mode=layout_merge_bboxes_mode,
1044-
)
1034+
layout_det_res = next(
1035+
self.layout_det_model(
1036+
doc_preprocessor_image,
1037+
threshold=layout_threshold,
1038+
layout_nms=layout_nms,
1039+
layout_unclip_ratio=layout_unclip_ratio,
1040+
layout_merge_bboxes_mode=layout_merge_bboxes_mode,
10451041
)
1046-
else:
1047-
import json
1048-
import os
1049-
1050-
from ...models.object_detection.result import DetResult
1051-
1052-
label_dir = layout_gt_dir
1053-
notes_path = f"{label_dir}/notes.json"
1054-
labels = f"{label_dir}/labels"
1055-
gt_file = os.path.basename(input)[:-4] + ".txt"
1056-
gt_path = f"{labels}/{gt_file}"
1057-
with open(notes_path, "r") as f:
1058-
notes = json.load(f)
1059-
categories_map = {}
1060-
for categories in notes["categories"]:
1061-
id = int(categories["id"])
1062-
name = categories["name"]
1063-
categories_map[id] = name
1064-
with open(gt_path, "r") as f:
1065-
lines = f.readlines()
1066-
layout_det_res_dic = {
1067-
"input_img": doc_preprocessor_image,
1068-
"page_index": None,
1069-
"boxes": [],
1070-
}
1071-
for line in lines:
1072-
line = line.strip().split(" ")
1073-
category_id = int(line[0])
1074-
label = categories_map[category_id]
1075-
img_h, img_w = doc_preprocessor_image.shape[:2]
1076-
center_x = float(line[1]) * img_w
1077-
center_y = float(line[2]) * img_h
1078-
w = float(line[3]) * img_w
1079-
h = float(line[4]) * img_h
1080-
x0 = center_x - w / 2
1081-
y0 = center_y - h / 2
1082-
x1 = center_x + w / 2
1083-
y1 = center_y + h / 2
1084-
box = [x0, y0, x1, y1]
1085-
layout_det_res_dic["boxes"].append(
1086-
{
1087-
"cls_id": category_id,
1088-
"label": label,
1089-
"coordinate": box,
1090-
"score": 1.0,
1091-
}
1092-
)
1093-
layout_det_res = DetResult(layout_det_res_dic)
1042+
)
1043+
10941044
imgs_in_doc = gather_imgs(doc_preprocessor_image, layout_det_res["boxes"])
10951045

10961046
if model_settings["use_region_detection"]:
@@ -1139,7 +1089,13 @@ def predict(
11391089
),
11401090
)
11411091
else:
1142-
overall_ocr_res = {}
1092+
overall_ocr_res = {
1093+
"dt_polys": [],
1094+
"rec_texts": [],
1095+
"rec_scores": [],
1096+
"rec_polys": [],
1097+
"rec_boxes": np.array([]),
1098+
}
11431099

11441100
overall_ocr_res["rec_labels"] = ["text"] * len(overall_ocr_res["rec_texts"])
11451101

@@ -1157,9 +1113,14 @@ def predict(
11571113
table_contents["rec_texts"].append(
11581114
f"${formula_res['rec_formula']}$"
11591115
)
1160-
table_contents["rec_boxes"] = np.vstack(
1161-
(table_contents["rec_boxes"], [formula_res["dt_polys"]])
1162-
)
1116+
if table_contents["rec_boxes"].size == 0:
1117+
table_contents["rec_boxes"] = np.array(
1118+
[formula_res["dt_polys"]]
1119+
)
1120+
else:
1121+
table_contents["rec_boxes"] = np.vstack(
1122+
(table_contents["rec_boxes"], [formula_res["dt_polys"]])
1123+
)
11631124
table_contents["rec_polys"].append(poly_points)
11641125
table_contents["rec_scores"].append(1)
11651126

0 commit comments

Comments
 (0)