Skip to content

Commit 22db32b

Browse files
committed
update xycut_enhanced and add region detection
1 parent 1b904cd commit 22db32b

File tree

7 files changed

+1100
-624
lines changed

7 files changed

+1100
-624
lines changed

paddlex/inference/models/table_structure_recognition/processors.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -130,12 +130,7 @@ def __call__(self, pred, img_size, ori_img_size):
130130
structure_probs, bbox_preds, img_size, ori_img_size
131131
)
132132
structure_str_list = [
133-
(
134-
["<html>", "<body>", "<table>"]
135-
+ structure
136-
+ ["</table>", "</body>", "</html>"]
137-
)
138-
for structure in structure_str_list
133+
(["<table>"] + structure + ["</table>"]) for structure in structure_str_list
139134
]
140135
return [
141136
{"bbox": bbox, "structure": structure, "structure_score": structure_score}

paddlex/inference/pipelines/layout_parsing/pipeline_v2.py

Lines changed: 365 additions & 139 deletions
Large diffs are not rendered by default.

paddlex/inference/pipelines/layout_parsing/result_v2.py

Lines changed: 143 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from __future__ import annotations
1515

1616
import copy
17+
import math
1718
import re
1819
from pathlib import Path
1920
from typing import List
@@ -73,6 +74,9 @@ def _to_img(self) -> dict[str, np.ndarray]:
7374
res_img_dict[key] = value
7475
res_img_dict["layout_det_res"] = self["layout_det_res"].img["res"]
7576

77+
if model_settings["use_region_detection"]:
78+
res_img_dict["region_det_res"] = self["region_det_res"].img["res"]
79+
7680
if model_settings["use_general_ocr"] or model_settings["use_table_recognition"]:
7781
res_img_dict["overall_ocr_res"] = self["overall_ocr_res"].img["ocr_res_img"]
7882

@@ -283,22 +287,33 @@ def format_title(title):
283287
" ",
284288
)
285289

290+
# def format_centered_text():
291+
# return (
292+
# f'<div style="text-align: center;">{block.content}</div>'.replace(
293+
# "-\n",
294+
# "",
295+
# ).replace("\n", " ")
296+
# + "\n"
297+
# )
298+
286299
def format_centered_text():
287-
return (
288-
f'<div style="text-align: center;">{block.content}</div>'.replace(
289-
"-\n",
290-
"",
291-
).replace("\n", " ")
292-
+ "\n"
293-
)
300+
return block.content
301+
302+
# def format_image():
303+
# img_tags = []
304+
# image_path = "".join(block.image.keys())
305+
# img_tags.append(
306+
# '<div style="text-align: center;"><img src="{}" alt="Image" /></div>'.format(
307+
# image_path.replace("-\n", "").replace("\n", " "),
308+
# ),
309+
# )
310+
# return "\n".join(img_tags)
294311

295312
def format_image():
296313
img_tags = []
297314
image_path = "".join(block.image.keys())
298315
img_tags.append(
299-
'<div style="text-align: center;"><img src="{}" alt="Image" /></div>'.format(
300-
image_path.replace("-\n", "").replace("\n", " "),
301-
),
316+
"![]({})".format(image_path.replace("-\n", "").replace("\n", " "))
302317
)
303318
return "\n".join(img_tags)
304319

@@ -332,7 +347,7 @@ def get_seg_flag(block: LayoutParsingBlock, prev_block: LayoutParsingBlock):
332347
num_of_prev_lines = prev_block.num_of_lines
333348
pre_block_seg_end_coordinate = prev_block.seg_end_coordinate
334349
prev_end_space_small = (
335-
context_right_coordinate - pre_block_seg_end_coordinate < 10
350+
abs(prev_block_bbox[2] - pre_block_seg_end_coordinate) < 10
336351
)
337352
prev_lines_more_than_one = num_of_prev_lines > 1
338353

@@ -347,8 +362,12 @@ def get_seg_flag(block: LayoutParsingBlock, prev_block: LayoutParsingBlock):
347362
prev_block_bbox[2], context_right_coordinate
348363
)
349364
prev_end_space_small = (
350-
prev_block_bbox[2] - pre_block_seg_end_coordinate < 10
365+
abs(context_right_coordinate - pre_block_seg_end_coordinate)
366+
< 10
351367
)
368+
edge_distance = 0
369+
else:
370+
edge_distance = abs(block_box[0] - prev_block_bbox[2])
352371

353372
current_start_space_small = (
354373
seg_start_coordinate - context_left_coordinate < 10
@@ -358,6 +377,7 @@ def get_seg_flag(block: LayoutParsingBlock, prev_block: LayoutParsingBlock):
358377
prev_end_space_small
359378
and current_start_space_small
360379
and prev_lines_more_than_one
380+
and edge_distance < max(prev_block.width, block.width)
361381
):
362382
seg_start_flag = False
363383
else:
@@ -371,14 +391,19 @@ def get_seg_flag(block: LayoutParsingBlock, prev_block: LayoutParsingBlock):
371391

372392
handlers = {
373393
"paragraph_title": lambda: format_title(block.content),
394+
"abstract_title": lambda: format_title(block.content),
395+
"reference_title": lambda: format_title(block.content),
396+
"content_title": lambda: format_title(block.content),
374397
"doc_title": lambda: f"# {block.content}".replace(
375398
"-\n",
376399
"",
377400
).replace("\n", " "),
378401
"table_title": lambda: format_centered_text(),
379402
"figure_title": lambda: format_centered_text(),
380403
"chart_title": lambda: format_centered_text(),
381-
"text": lambda: block.content.replace("-\n", " ").replace("\n", " "),
404+
"text": lambda: block.content.replace("\n\n", "\n").replace(
405+
"\n", "\n\n"
406+
),
382407
"abstract": lambda: format_first_line(
383408
["摘要", "abstract"], lambda l: f"## {l}\n", " "
384409
),
@@ -416,24 +441,7 @@ def get_seg_flag(block: LayoutParsingBlock, prev_block: LayoutParsingBlock):
416441
if handler:
417442
prev_block = block
418443
if label == last_label == "text" and seg_start_flag == False:
419-
last_char_of_markdown = (
420-
markdown_content[-1] if markdown_content else ""
421-
)
422-
first_char_of_handler = handler()[0] if handler() else ""
423-
last_is_chinese_char = (
424-
re.match(r"[\u4e00-\u9fff]", last_char_of_markdown)
425-
if last_char_of_markdown
426-
else False
427-
)
428-
first_is_chinese_char = (
429-
re.match(r"[\u4e00-\u9fff]", first_char_of_handler)
430-
if first_char_of_handler
431-
else False
432-
)
433-
if not (last_is_chinese_char or first_is_chinese_char):
434-
markdown_content += " " + handler()
435-
else:
436-
markdown_content += handler()
444+
markdown_content += handler()
437445
else:
438446
markdown_content += (
439447
"\n\n" + handler() if markdown_content else handler()
@@ -467,7 +475,7 @@ class LayoutParsingBlock:
467475

468476
def __init__(self, label, bbox, content="") -> None:
469477
self.label = label
470-
self.region_label = "other"
478+
self.order_label = "other"
471479
self.bbox = [int(item) for item in bbox]
472480
self.content = content
473481
self.seg_start_coordinate = float("inf")
@@ -479,39 +487,39 @@ def __init__(self, label, bbox, content="") -> None:
479487
self.image = None
480488
self.index = None
481489
self.visual_index = None
482-
self.direction = self.get_bbox_direction()
490+
self.orientation = self.get_bbox_orientation()
483491
self.child_blocks = []
484-
self.update_direction_info()
492+
self.update_orientation_info()
485493

486494
def __str__(self) -> str:
487495
return f"{self.__dict__}"
488496

489497
def __repr__(self) -> str:
490-
_str = f"\n\n#################\nlabel:\t{self.label}\nregion_label:\t{self.region_label}\nbbox:\t{self.bbox}\ncontent:\t{self.content}\n#################"
498+
_str = f"\n\n#################\nlabel:\t{self.label}\nregion_label:\t{self.order_label}\nbbox:\t{self.bbox}\ncontent:\t{self.content}\n#################"
491499
return _str
492500

493501
def to_dict(self) -> dict:
494502
return self.__dict__
495503

496-
def update_direction_info(self) -> None:
497-
if self.region_label == "vision":
498-
self.direction = "horizontal"
499-
if self.direction == "horizontal":
500-
self.secondary_direction = "vertical"
504+
def update_orientation_info(self) -> None:
505+
if self.order_label == "vision":
506+
self.orientation = "horizontal"
507+
if self.orientation == "horizontal":
508+
self.secondary_orientation = "vertical"
501509
self.short_side_length = self.height
502510
self.long_side_length = self.width
503511
self.start_coordinate = self.bbox[0]
504512
self.end_coordinate = self.bbox[2]
505-
self.secondary_direction_start_coordinate = self.bbox[1]
506-
self.secondary_direction_end_coordinate = self.bbox[3]
513+
self.secondary_orientation_start_coordinate = self.bbox[1]
514+
self.secondary_orientation_end_coordinate = self.bbox[3]
507515
else:
508-
self.secondary_direction = "horizontal"
516+
self.secondary_orientation = "horizontal"
509517
self.short_side_length = self.width
510518
self.long_side_length = self.height
511519
self.start_coordinate = self.bbox[1]
512520
self.end_coordinate = self.bbox[3]
513-
self.secondary_direction_start_coordinate = self.bbox[0]
514-
self.secondary_direction_end_coordinate = self.bbox[2]
521+
self.secondary_orientation_start_coordinate = self.bbox[0]
522+
self.secondary_orientation_end_coordinate = self.bbox[2]
515523

516524
def append_child_block(self, child_block: LayoutParsingBlock) -> None:
517525
if not self.child_blocks:
@@ -525,7 +533,7 @@ def append_child_block(self, child_block: LayoutParsingBlock) -> None:
525533
max(y2, y2_child),
526534
)
527535
self.bbox = union_bbox
528-
self.update_direction_info()
536+
self.update_orientation_info()
529537
child_blocks = [child_block]
530538
if child_block.child_blocks:
531539
child_blocks.extend(child_block.get_child_blocks())
@@ -542,7 +550,7 @@ def get_centroid(self) -> tuple:
542550
centroid = ((x1 + x2) / 2, (y1 + y2) / 2)
543551
return centroid
544552

545-
def get_bbox_direction(self, orientation_ratio: float = 1.0) -> bool:
553+
def get_bbox_orientation(self, orientation_ratio: float = 1.0) -> bool:
546554
"""
547555
Determine if a bounding box is horizontal or vertical.
548556
@@ -558,3 +566,91 @@ def get_bbox_direction(self, orientation_ratio: float = 1.0) -> bool:
558566
if self.width * orientation_ratio >= self.height
559567
else "vertical"
560568
)
569+
570+
571+
class LayoutParsingRegion:
572+
573+
def __init__(
574+
self, region_bbox, blocks: List[LayoutParsingBlock] = [], block_label_mapping={}
575+
) -> None:
576+
self.region_bbox = region_bbox
577+
self.blocks = blocks
578+
self.block_map = {}
579+
self.update_config(block_label_mapping)
580+
self.orientation = None
581+
self.calculate_bbox_metrics()
582+
583+
def update_config(self, block_label_mapping):
584+
self.block_map = {}
585+
self.config = copy.deepcopy(block_label_mapping)
586+
self.config["region_bbox"] = self.region_bbox
587+
horizontal_text_block_num = 0
588+
for idx, block in enumerate(self.blocks):
589+
label = block.label
590+
if (
591+
block.order_label not in ["vision", "vision_title"]
592+
and block.orientation == "horizontal"
593+
):
594+
horizontal_text_block_num += 1
595+
self.block_map[idx] = block
596+
self.update_layout_order_config_block_index(label, idx)
597+
text_block_num = (
598+
len(self.blocks)
599+
- len(self.config["vision_block_idxes"])
600+
- len(self.config["vision_title_block_idxes"])
601+
)
602+
self.orientation = (
603+
"horizontal"
604+
if horizontal_text_block_num >= text_block_num * 0.5
605+
else "vertical"
606+
)
607+
self.config["region_orientation"] = self.orientation
608+
609+
def calculate_bbox_metrics(self):
610+
x1, y1, x2, y2 = self.region_bbox
611+
x_center, y_center = (x1 + x2) / 2, (y1 + y2) / 2
612+
self.euclidean_distance = math.sqrt(((x1) ** 2 + (y1) ** 2))
613+
self.center_euclidean_distance = math.sqrt(((x_center) ** 2 + (y_center) ** 2))
614+
self.angle_rad = math.atan2(y_center, x_center)
615+
616+
def sort(self):
617+
from .xycut_enhanced import xycut_enhanced
618+
619+
return xycut_enhanced(self.blocks, self.config)
620+
621+
def update_layout_order_config_block_index(
622+
self, block_label: str, block_idx: int
623+
) -> None:
624+
doc_title_labels = self.config["doc_title_labels"]
625+
paragraph_title_labels = self.config["paragraph_title_labels"]
626+
vision_labels = self.config["vision_labels"]
627+
vision_title_labels = self.config["vision_title_labels"]
628+
header_labels = self.config["header_labels"]
629+
unordered_labels = self.config["unordered_labels"]
630+
footer_labels = self.config["footer_labels"]
631+
text_labels = self.config["text_labels"]
632+
self.config.setdefault("doc_title_block_idxes", [])
633+
self.config.setdefault("paragraph_title_block_idxes", [])
634+
self.config.setdefault("vision_block_idxes", [])
635+
self.config.setdefault("vision_title_block_idxes", [])
636+
self.config.setdefault("unordered_block_idxes", [])
637+
self.config.setdefault("text_block_idxes", [])
638+
self.config.setdefault("header_block_idxes", [])
639+
self.config.setdefault("footer_block_idxes", [])
640+
641+
if block_label in doc_title_labels:
642+
self.config["doc_title_block_idxes"].append(block_idx)
643+
if block_label in paragraph_title_labels:
644+
self.config["paragraph_title_block_idxes"].append(block_idx)
645+
if block_label in vision_labels:
646+
self.config["vision_block_idxes"].append(block_idx)
647+
if block_label in vision_title_labels:
648+
self.config["vision_title_block_idxes"].append(block_idx)
649+
if block_label in unordered_labels:
650+
self.config["unordered_block_idxes"].append(block_idx)
651+
if block_label in text_labels:
652+
self.config["text_block_idxes"].append(block_idx)
653+
if block_label in header_labels:
654+
self.config["header_block_idxes"].append(block_idx)
655+
if block_label in footer_labels:
656+
self.config["footer_block_idxes"].append(block_idx)

0 commit comments

Comments
 (0)