Skip to content

update xycut_enhanced and add region detection #3930

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -130,12 +130,7 @@ def __call__(self, pred, img_size, ori_img_size):
structure_probs, bbox_preds, img_size, ori_img_size
)
structure_str_list = [
(
["<html>", "<body>", "<table>"]
+ structure
+ ["</table>", "</body>", "</html>"]
)
for structure in structure_str_list
(["<table>"] + structure + ["</table>"]) for structure in structure_str_list
]
return [
{"bbox": bbox, "structure": structure, "structure_score": structure_score}
Expand Down
459 changes: 323 additions & 136 deletions paddlex/inference/pipelines/layout_parsing/pipeline_v2.py

Large diffs are not rendered by default.

220 changes: 175 additions & 45 deletions paddlex/inference/pipelines/layout_parsing/result_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,23 @@
from __future__ import annotations

import copy
import math
import re
from pathlib import Path
from typing import List

import numpy as np
from PIL import Image, ImageDraw
from PIL import Image, ImageDraw, ImageFont

from ....utils.fonts import PINGFANG_FONT_FILE_PATH
from ...common.result import (
BaseCVResult,
HtmlMixin,
JsonMixin,
MarkdownMixin,
XlsxMixin,
)
from .setting import BLOCK_LABEL_MAP


class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
Expand Down Expand Up @@ -73,6 +76,9 @@ def _to_img(self) -> dict[str, np.ndarray]:
res_img_dict[key] = value
res_img_dict["layout_det_res"] = self["layout_det_res"].img["res"]

if model_settings["use_region_detection"]:
res_img_dict["region_det_res"] = self["region_det_res"].img["res"]

if model_settings["use_general_ocr"] or model_settings["use_table_recognition"]:
res_img_dict["overall_ocr_res"] = self["overall_ocr_res"].img["ocr_res_img"]

Expand Down Expand Up @@ -103,16 +109,23 @@ def _to_img(self) -> dict[str, np.ndarray]:
# for layout ordering image
image = Image.fromarray(self["doc_preprocessor_res"]["output_img"][:, :, ::-1])
draw = ImageDraw.Draw(image, "RGBA")
font_size = int(0.018 * int(image.width)) + 2
font = ImageFont.truetype(PINGFANG_FONT_FILE_PATH, font_size, encoding="utf-8")
parsing_result: List[LayoutParsingBlock] = self["parsing_res_list"]
for block in parsing_result:
bbox = block.bbox
index = block.index
label = block.label
fill_color = get_show_color(label)
index = block.order_index
label = block.order_label
fill_color = get_show_color(label, True)
draw.rectangle(bbox, fill=fill_color)
if index is not None:
text_position = (bbox[2] + 2, bbox[1] - 10)
draw.text(text_position, str(index), fill="red")
text_position = (bbox[2] + 2, bbox[1] - font_size // 2)
if int(image.width) - bbox[2] < font_size:
text_position = (
int(bbox[2] - font_size * 1.1),
bbox[1] - font_size // 2,
)
draw.text(text_position, str(index), font=font, fill="red")

res_img_dict["layout_order_res"] = image

Expand Down Expand Up @@ -283,22 +296,33 @@ def format_title(title):
" ",
)

# def format_centered_text():
# return (
# f'<div style="text-align: center;">{block.content}</div>'.replace(
# "-\n",
# "",
# ).replace("\n", " ")
# + "\n"
# )

def format_centered_text():
return (
f'<div style="text-align: center;">{block.content}</div>'.replace(
"-\n",
"",
).replace("\n", " ")
+ "\n"
)
return block.content

# def format_image():
# img_tags = []
# image_path = "".join(block.image.keys())
# img_tags.append(
# '<div style="text-align: center;"><img src="{}" alt="Image" /></div>'.format(
# image_path.replace("-\n", "").replace("\n", " "),
# ),
# )
# return "\n".join(img_tags)

def format_image():
img_tags = []
image_path = "".join(block.image.keys())
img_tags.append(
'<div style="text-align: center;"><img src="{}" alt="Image" /></div>'.format(
image_path.replace("-\n", "").replace("\n", " "),
),
"![]({})".format(image_path.replace("-\n", "").replace("\n", " "))
)
return "\n".join(img_tags)

Expand Down Expand Up @@ -332,7 +356,7 @@ def get_seg_flag(block: LayoutParsingBlock, prev_block: LayoutParsingBlock):
num_of_prev_lines = prev_block.num_of_lines
pre_block_seg_end_coordinate = prev_block.seg_end_coordinate
prev_end_space_small = (
context_right_coordinate - pre_block_seg_end_coordinate < 10
abs(prev_block_bbox[2] - pre_block_seg_end_coordinate) < 10
)
prev_lines_more_than_one = num_of_prev_lines > 1

Expand All @@ -347,8 +371,12 @@ def get_seg_flag(block: LayoutParsingBlock, prev_block: LayoutParsingBlock):
prev_block_bbox[2], context_right_coordinate
)
prev_end_space_small = (
prev_block_bbox[2] - pre_block_seg_end_coordinate < 10
abs(context_right_coordinate - pre_block_seg_end_coordinate)
< 10
)
edge_distance = 0
else:
edge_distance = abs(block_box[0] - prev_block_bbox[2])

current_start_space_small = (
seg_start_coordinate - context_left_coordinate < 10
Expand All @@ -358,6 +386,7 @@ def get_seg_flag(block: LayoutParsingBlock, prev_block: LayoutParsingBlock):
prev_end_space_small
and current_start_space_small
and prev_lines_more_than_one
and edge_distance < max(prev_block.width, block.width)
):
seg_start_flag = False
else:
Expand All @@ -371,14 +400,19 @@ def get_seg_flag(block: LayoutParsingBlock, prev_block: LayoutParsingBlock):

handlers = {
"paragraph_title": lambda: format_title(block.content),
"abstract_title": lambda: format_title(block.content),
"reference_title": lambda: format_title(block.content),
"content_title": lambda: format_title(block.content),
"doc_title": lambda: f"# {block.content}".replace(
"-\n",
"",
).replace("\n", " "),
"table_title": lambda: format_centered_text(),
"figure_title": lambda: format_centered_text(),
"chart_title": lambda: format_centered_text(),
"text": lambda: block.content.replace("-\n", " ").replace("\n", " "),
"text": lambda: block.content.replace("\n\n", "\n").replace(
"\n", "\n\n"
),
"abstract": lambda: format_first_line(
["摘要", "abstract"], lambda l: f"## {l}\n", " "
),
Expand Down Expand Up @@ -416,24 +450,7 @@ def get_seg_flag(block: LayoutParsingBlock, prev_block: LayoutParsingBlock):
if handler:
prev_block = block
if label == last_label == "text" and seg_start_flag == False:
last_char_of_markdown = (
markdown_content[-1] if markdown_content else ""
)
first_char_of_handler = handler()[0] if handler() else ""
last_is_chinese_char = (
re.match(r"[\u4e00-\u9fff]", last_char_of_markdown)
if last_char_of_markdown
else False
)
first_is_chinese_char = (
re.match(r"[\u4e00-\u9fff]", first_char_of_handler)
if first_char_of_handler
else False
)
if not (last_is_chinese_char or first_is_chinese_char):
markdown_content += " " + handler()
else:
markdown_content += handler()
markdown_content += handler()
else:
markdown_content += (
"\n\n" + handler() if markdown_content else handler()
Expand Down Expand Up @@ -467,8 +484,8 @@ class LayoutParsingBlock:

def __init__(self, label, bbox, content="") -> None:
self.label = label
self.region_label = "other"
self.bbox = [int(item) for item in bbox]
self.order_label = None
self.bbox = list(map(int, bbox))
self.content = content
self.seg_start_coordinate = float("inf")
self.seg_end_coordinate = float("-inf")
Expand All @@ -478,7 +495,9 @@ def __init__(self, label, bbox, content="") -> None:
self.num_of_lines = 1
self.image = None
self.index = None
self.visual_index = None
self.order_index = None
self.text_line_width = 1
self.text_line_height = 1
self.direction = self.get_bbox_direction()
self.child_blocks = []
self.update_direction_info()
Expand All @@ -487,14 +506,14 @@ def __str__(self) -> str:
return f"{self.__dict__}"

def __repr__(self) -> str:
    """Return a multi-line, banner-delimited summary of the block.

    The summary lists the block's ``index``, ``label``, ``order_label``,
    ``bbox`` and ``content``, one per line, as tab-separated ``name:\tvalue``
    pairs between two ``#`` banners.
    """
    # The field printed here is ``order_label``; the label text is kept in
    # sync with the attribute name so the dump is not misleading.
    return (
        "\n\n#################\n"
        f"index:\t{self.index}\n"
        f"label:\t{self.label}\n"
        f"order_label:\t{self.order_label}\n"
        f"bbox:\t{self.bbox}\n"
        f"content:\t{self.content}\n"
        "#################"
    )

# Returns the instance's own attribute dict (self.__dict__) directly, not a
# copy, so mutating the returned dict mutates the block's attributes.
def to_dict(self) -> dict:
return self.__dict__

def update_direction_info(self) -> None:
if self.region_label == "vision":
if self.order_label == "vision":
self.direction = "horizontal"
if self.direction == "horizontal":
self.secondary_direction = "vertical"
Expand Down Expand Up @@ -542,19 +561,130 @@ def get_centroid(self) -> tuple:
centroid = ((x1 + x2) / 2, (y1 + y2) / 2)
return centroid

def get_bbox_direction(self, direction_ratio: float = 1.0) -> str:
    """Classify this block's box as horizontal or vertical.

    The decision compares ``self.width * direction_ratio`` against
    ``self.height``; ties count as horizontal.

    Args:
        direction_ratio (float): Multiplier applied to the width before the
            comparison. Values above 1.0 bias the result towards
            "horizontal". Default is 1.0.

    Returns:
        str: "horizontal" or "vertical".
    """
    # NOTE(review): ``width`` / ``height`` are read from the instance; their
    # definitions are outside this view (presumably derived from ``bbox``).
    return (
        "horizontal" if self.width * direction_ratio >= self.height else "vertical"
    )


class LayoutParsingRegion:
    """A rectangular region holding layout blocks bucketed by label category.

    On construction every block is given a positional index, recorded in
    ``block_map``, and its index is appended to one per-category list chosen
    by looking the block's label up in ``BLOCK_LABEL_MAP``. Blocks matching no
    category are treated as "normal text"; they alone determine the region's
    text direction and its mean text-line width/height.
    """

    def __init__(self, bbox, blocks: "List[LayoutParsingBlock] | None" = None) -> None:
        """Build the region and derive its per-category indexes and metrics.

        Args:
            bbox: Region box, unpacked as ``x1, y1, x2, y2``.
            blocks: Blocks that belong to the region. ``None`` (the default)
                is treated as "no blocks". Each block's ``index`` attribute
                is overwritten with its position in this list.
        """
        # Use a fresh list per call rather than a shared mutable default.
        blocks = [] if blocks is None else blocks

        self.bbox = bbox
        self.block_map = {}
        self.direction = "horizontal"
        self.calculate_bbox_metrics()
        self.doc_title_block_idxes = []
        self.paragraph_title_block_idxes = []
        self.vision_block_idxes = []
        self.unordered_block_idxes = []
        self.vision_title_block_idxes = []
        self.normal_text_block_idxes = []
        self.header_block_idxes = []
        self.footer_block_idxes = []
        # Fallback line metrics; replaced below when normal-text blocks exist.
        self.text_line_width = 20
        self.text_line_height = 10
        self.init_region_info_from_layout(blocks)
        self.init_direction_info()

    def init_region_info_from_layout(self, blocks: "List[LayoutParsingBlock]"):
        """Index ``blocks``, bucket them by label and compute text metrics.

        Sets ``self.direction`` to "horizontal" when at least half of the
        normal-text blocks are horizontal (this also holds when there are
        none), otherwise "vertical". ``text_line_width`` and
        ``text_line_height`` become the mean over normal-text blocks, or
        20 / 10 when there are none.
        """
        horizontal_normal_text_block_num = 0
        text_line_height_list = []
        text_line_width_list = []
        for idx, block in enumerate(blocks):
            self.block_map[idx] = block
            block.index = idx
            if block.label in BLOCK_LABEL_MAP["header_labels"]:
                self.header_block_idxes.append(idx)
            elif block.label in BLOCK_LABEL_MAP["doc_title_labels"]:
                self.doc_title_block_idxes.append(idx)
            elif block.label in BLOCK_LABEL_MAP["paragraph_title_labels"]:
                self.paragraph_title_block_idxes.append(idx)
            elif block.label in BLOCK_LABEL_MAP["vision_labels"]:
                self.vision_block_idxes.append(idx)
            elif block.label in BLOCK_LABEL_MAP["vision_title_labels"]:
                self.vision_title_block_idxes.append(idx)
            elif block.label in BLOCK_LABEL_MAP["footer_labels"]:
                self.footer_block_idxes.append(idx)
            elif block.label in BLOCK_LABEL_MAP["unordered_labels"]:
                self.unordered_block_idxes.append(idx)
            else:
                # Anything uncategorised counts as normal text and feeds the
                # direction vote and the line-size averages.
                self.normal_text_block_idxes.append(idx)
                text_line_height_list.append(block.text_line_height)
                text_line_width_list.append(block.text_line_width)
                if block.direction == "horizontal":
                    horizontal_normal_text_block_num += 1
        self.direction = (
            "horizontal"
            if horizontal_normal_text_block_num
            >= len(self.normal_text_block_idxes) * 0.5
            else "vertical"
        )
        self.text_line_width = (
            np.mean(text_line_width_list) if text_line_width_list else 20
        )
        self.text_line_height = (
            np.mean(text_line_height_list) if text_line_height_list else 10
        )

    def init_direction_info(self):
        """Set bbox index pairs and centre coordinates for both axes.

        For a horizontal region the primary axis uses bbox indexes 0 and 2
        and the secondary axis 1 and 3; a vertical region swaps them. The
        centre coordinates are the midpoints of those index pairs.
        """
        if self.direction == "horizontal":
            self.direction_start_index = 0
            self.direction_end_index = 2
            self.secondary_direction_start_index = 1
            self.secondary_direction_end_index = 3
            self.secondary_direction = "vertical"
        else:
            self.direction_start_index = 1
            self.direction_end_index = 3
            self.secondary_direction_start_index = 0
            self.secondary_direction_end_index = 2
            self.secondary_direction = "horizontal"

        self.direction_center_coordinate = (
            self.bbox[self.direction_start_index] + self.bbox[self.direction_end_index]
        ) / 2
        self.secondary_direction_center_coordinate = (
            self.bbox[self.secondary_direction_start_index]
            + self.bbox[self.secondary_direction_end_index]
        ) / 2

    def calculate_bbox_metrics(self):
        """Compute origin-relative distance and angle metrics for ``bbox``.

        Sets ``euclidean_distance`` (origin to the ``(x1, y1)`` corner),
        ``center_euclidean_distance`` (origin to the box centre) and
        ``angle_rad`` (``atan2`` of the box centre).
        """
        x1, y1, x2, y2 = self.bbox
        x_center, y_center = (x1 + x2) / 2, (y1 + y2) / 2
        self.euclidean_distance = math.sqrt(((x1) ** 2 + (y1) ** 2))
        self.center_euclidean_distance = math.sqrt(((x_center) ** 2 + (y_center) ** 2))
        self.angle_rad = math.atan2(y_center, x_center)

    def sort_normal_blocks(self, blocks):
        """Sort ``blocks`` in place on a grid quantised by text-line size.

        Horizontal regions order by row bucket (``y1 // text_line_height``),
        then column bucket (``x1 // text_line_width``), then squared distance
        of ``(x1, y1)`` from the origin. Vertical regions order by negated
        ``x1`` bucket first (larger ``x1`` sorts earlier), then row bucket,
        then descending ``x2**2 + y1**2``.
        """
        if self.direction == "horizontal":
            blocks.sort(
                key=lambda x: (
                    x.bbox[1] // self.text_line_height,
                    x.bbox[0] // self.text_line_width,
                    x.bbox[1] ** 2 + x.bbox[0] ** 2,
                ),
            )
        else:
            blocks.sort(
                key=lambda x: (
                    # Unary minus binds tighter than ``//``: this floors
                    # (-x1) / width, it is not -(x1 // width).
                    -x.bbox[0] // self.text_line_width,
                    x.bbox[1] // self.text_line_height,
                    -(x.bbox[2] ** 2 + x.bbox[1] ** 2),
                ),
            )

    def sort(self):
        """Return the result of running ``xycut_enhanced`` on this region."""
        # Imported at call time rather than module import time
        # (presumably to avoid a circular import; confirm against the package).
        from .xycut_enhanced import xycut_enhanced

        return xycut_enhanced(self)
Loading