diff --git a/paddlex/inference/models/table_structure_recognition/processors.py b/paddlex/inference/models/table_structure_recognition/processors.py
index 508222befe..e1048ca860 100644
--- a/paddlex/inference/models/table_structure_recognition/processors.py
+++ b/paddlex/inference/models/table_structure_recognition/processors.py
@@ -130,12 +130,7 @@ def __call__(self, pred, img_size, ori_img_size):
structure_probs, bbox_preds, img_size, ori_img_size
)
structure_str_list = [
- (
- ["", "
", "", "", ""]
- )
- for structure in structure_str_list
+ ([""]) for structure in structure_str_list
]
return [
{"bbox": bbox, "structure": structure, "structure_score": structure_score}
diff --git a/paddlex/inference/pipelines/layout_parsing/pipeline_v2.py b/paddlex/inference/pipelines/layout_parsing/pipeline_v2.py
index b5ca253050..a7afd3b54a 100644
--- a/paddlex/inference/pipelines/layout_parsing/pipeline_v2.py
+++ b/paddlex/inference/pipelines/layout_parsing/pipeline_v2.py
@@ -29,10 +29,12 @@
from ...utils.pp_option import PaddlePredictorOption
from ..base import BasePipeline
from ..ocr.result import OCRResult
-from .result_v2 import LayoutParsingBlock, LayoutParsingResultV2
+from .result_v2 import LayoutParsingBlock, LayoutParsingRegion, LayoutParsingResultV2
+from .setting import BLOCK_LABEL_MAP, BLOCK_SETTINGS, LINE_SETTINGS, REGION_SETTINGS
from .utils import (
caculate_bbox_area,
- calculate_text_orientation,
+ calculate_minimum_enclosing_bbox,
+ calculate_overlap_ratio,
convert_formula_res_to_ocr_format,
format_line,
gather_imgs,
@@ -40,11 +42,10 @@
get_sub_regions_ocr_res,
group_boxes_into_lines,
remove_overlap_blocks,
- split_boxes_if_x_contained,
- update_layout_order_config_block_index,
+ shrink_supplement_region_bbox,
+ split_boxes_by_projection,
update_region_box,
)
-from .xycut_enhanced import xycut_enhanced
@pipeline_requires_extra("ocr")
@@ -100,6 +101,10 @@ def inintial_predictor(self, config: dict) -> None:
self.use_general_ocr = config.get("use_general_ocr", True)
self.use_table_recognition = config.get("use_table_recognition", True)
self.use_seal_recognition = config.get("use_seal_recognition", True)
+ self.use_region_detection = config.get(
+ "use_region_detection",
+ False,
+ )
self.use_formula_recognition = config.get(
"use_formula_recognition",
True,
@@ -115,6 +120,16 @@ def inintial_predictor(self, config: dict) -> None:
self.doc_preprocessor_pipeline = self.create_pipeline(
doc_preprocessor_config,
)
+ if self.use_region_detection:
+ region_detection_config = config.get("SubModules", {}).get(
+ "RegionDetection",
+ {
+ "model_config_error": "config error for block_region_detection_model!"
+ },
+ )
+ self.region_detection_model = self.create_model(
+ region_detection_config,
+ )
layout_det_config = config.get("SubModules", {}).get(
"LayoutDetection",
@@ -246,7 +261,7 @@ def check_model_settings_valid(self, input_params: dict) -> bool:
def standardized_data(
self,
image: list,
- layout_order_config: dict,
+ region_det_res: DetResult,
layout_det_res: DetResult,
overall_ocr_res: OCRResult,
formula_res_list: list,
@@ -277,13 +292,16 @@ def standardized_data(
"""
matched_ocr_dict = {}
- layout_to_ocr_mapping = {}
+ region_to_block_map = {}
+ block_to_ocr_map = {}
object_boxes = []
footnote_list = []
+ paragraph_title_list = []
bottom_text_y_max = 0
max_block_area = 0.0
+ doc_title_num = 0
- region_box = [65535, 65535, 0, 0]
+ base_region_bbox = [65535, 65535, 0, 0]
layout_det_res = remove_overlap_blocks(
layout_det_res,
threshold=0.5,
@@ -301,22 +319,26 @@ def standardized_data(
_, _, _, y2 = box
# update the region box and max_block_area according to the layout boxes
- region_box = update_region_box(box, region_box)
+ base_region_bbox = update_region_box(box, base_region_bbox)
max_block_area = max(max_block_area, caculate_bbox_area(box))
- update_layout_order_config_block_index(layout_order_config, label, box_idx)
+ # update_layout_order_config_block_index(layout_order_config, label, box_idx)
# set the label of footnote to text, when it is above the text boxes
if label == "footnote":
footnote_list.append(box_idx)
+ elif label == "paragraph_title":
+ paragraph_title_list.append(box_idx)
if label == "text":
bottom_text_y_max = max(y2, bottom_text_y_max)
+ if label == "doc_title":
+ doc_title_num += 1
if label not in ["formula", "table", "seal"]:
_, matched_idxes = get_sub_regions_ocr_res(
overall_ocr_res, [box], return_match_idx=True
)
- layout_to_ocr_mapping[box_idx] = matched_idxes
+ block_to_ocr_map[box_idx] = matched_idxes
for matched_idx in matched_idxes:
if matched_ocr_dict.get(matched_idx, None) is None:
matched_ocr_dict[matched_idx] = [box_idx]
@@ -330,36 +352,21 @@ def standardized_data(
< bottom_text_y_max
):
layout_det_res["boxes"][footnote_idx]["label"] = "text"
- layout_order_config["text_block_idxes"].append(footnote_idx)
- layout_order_config["footer_block_idxes"].remove(footnote_idx)
- # fix the doc_title label
- doc_title_idxes = layout_order_config.get("doc_title_block_idxes", [])
- paragraph_title_idxes = layout_order_config.get(
- "paragraph_title_block_idxes", []
- )
# check if there is only one paragraph title and without doc_title
- only_one_paragraph_title = (
- len(paragraph_title_idxes) == 1 and len(doc_title_idxes) == 0
- )
+ only_one_paragraph_title = len(paragraph_title_list) == 1 and doc_title_num == 0
if only_one_paragraph_title:
paragraph_title_block_area = caculate_bbox_area(
- layout_det_res["boxes"][paragraph_title_idxes[0]]["coordinate"]
+ layout_det_res["boxes"][paragraph_title_list[0]]["coordinate"]
)
- title_area_max_block_threshold = layout_order_config.get(
- "title_area_max_block_threshold", 0.3
+ title_area_max_block_threshold = BLOCK_SETTINGS.get(
+ "title_conversion_area_ratio_threshold", 0.3
)
if (
paragraph_title_block_area
> max_block_area * title_area_max_block_threshold
):
- layout_det_res["boxes"][paragraph_title_idxes[0]]["label"] = "doc_title"
- layout_order_config["doc_title_block_idxes"].append(
- paragraph_title_idxes[0]
- )
- layout_order_config["paragraph_title_block_idxes"].remove(
- paragraph_title_idxes[0]
- )
+ layout_det_res["boxes"][paragraph_title_list[0]]["label"] = "doc_title"
# Replace the OCR information of the hurdles.
for overall_ocr_idx, layout_box_ids in matched_ocr_dict.items():
@@ -374,6 +381,11 @@ def standardized_data(
for box_idx in layout_box_ids:
layout_box = layout_det_res["boxes"][box_idx]["coordinate"]
crop_box = get_bbox_intersection(overall_ocr_box, layout_box)
+ for ocr_idx in block_to_ocr_map[box_idx]:
+ ocr_box = overall_ocr_res["rec_boxes"][ocr_idx]
+ iou = calculate_overlap_ratio(ocr_box, crop_box, "small")
+ if iou > 0.8:
+ overall_ocr_res["rec_texts"][ocr_idx] = ""
x1, y1, x2, y2 = [int(i) for i in crop_box]
crop_img = np.array(image)[y1:y2, x1:x2]
crop_img_rec_res = next(text_rec_model([crop_img]))
@@ -414,23 +426,149 @@ def standardized_data(
overall_ocr_res["rec_scores"].append(crop_img_rec_score)
overall_ocr_res["rec_texts"].append(crop_img_rec_text)
overall_ocr_res["rec_labels"].append("text")
- layout_to_ocr_mapping[box_idx].remove(overall_ocr_idx)
- layout_to_ocr_mapping[box_idx].append(
+ block_to_ocr_map[box_idx].remove(overall_ocr_idx)
+ block_to_ocr_map[box_idx].append(
len(overall_ocr_res["rec_texts"]) - 1
)
- layout_order_config["all_layout_region_box"] = region_box
- layout_order_config["layout_to_ocr_mapping"] = layout_to_ocr_mapping
- layout_order_config["matched_ocr_dict"] = matched_ocr_dict
+ # run text recognition on the whole layout bbox when a block has no non-empty matched OCR text
+ for layout_box_idx, overall_ocr_idxes in block_to_ocr_map.items():
+ has_text = False
+ for idx in overall_ocr_idxes:
+ if overall_ocr_res["rec_texts"][idx] != "":
+ has_text = True
+ break
+ if not has_text and layout_det_res["boxes"][layout_box_idx][
+ "label"
+ ] not in BLOCK_LABEL_MAP.get("vision_labels", []):
+ crop_box = layout_det_res["boxes"][layout_box_idx]["coordinate"]
+ x1, y1, x2, y2 = [int(i) for i in crop_box]
+ crop_img = np.array(image)[y1:y2, x1:x2]
+ crop_img_rec_res = next(text_rec_model([crop_img]))
+ crop_img_dt_poly = get_bbox_intersection(
+ crop_box, crop_box, return_format="poly"
+ )
+ crop_img_rec_score = crop_img_rec_res["rec_score"]
+ crop_img_rec_text = crop_img_rec_res["rec_text"]
+ text_rec_score_thresh = (
+ text_rec_score_thresh
+ if text_rec_score_thresh is not None
+ else (self.general_ocr_pipeline.text_rec_score_thresh)
+ )
+ if crop_img_rec_score >= text_rec_score_thresh:
+ overall_ocr_res["rec_boxes"] = np.vstack(
+ (overall_ocr_res["rec_boxes"], crop_box)
+ )
+ overall_ocr_res["rec_polys"].append(crop_img_dt_poly)
+ overall_ocr_res["rec_scores"].append(crop_img_rec_score)
+ overall_ocr_res["rec_texts"].append(crop_img_rec_text)
+ overall_ocr_res["rec_labels"].append("text")
+ block_to_ocr_map[layout_box_idx].append(
+ len(overall_ocr_res["rec_texts"]) - 1
+ )
+
+ # when there is no layout detection result but there is ocr result, convert ocr detection result to layout detection result
+ if len(layout_det_res["boxes"]) == 0 and len(overall_ocr_res["rec_boxes"]) > 0:
+ for idx, ocr_rec_box in enumerate(overall_ocr_res["rec_boxes"]):
+ base_region_bbox = update_region_box(ocr_rec_box, base_region_bbox)
+ layout_det_res["boxes"].append(
+ {
+ "label": "text",
+ "coordinate": ocr_rec_box,
+ "score": overall_ocr_res["rec_scores"][idx],
+ }
+ )
+ block_to_ocr_map[idx] = [idx]
+
+ block_bboxes = [box["coordinate"] for box in layout_det_res["boxes"]]
+ region_det_res["boxes"] = sorted(
+ region_det_res["boxes"],
+ key=lambda item: caculate_bbox_area(item["coordinate"]),
+ )
+ if len(region_det_res["boxes"]) == 0:
+ region_det_res["boxes"] = [
+ {
+ "coordinate": base_region_bbox,
+ "label": "SupplementaryBlock",
+ "score": 1,
+ }
+ ]
+ region_to_block_map[0] = range(len(block_bboxes))
+ else:
+ block_idxes_set = set(range(len(block_bboxes)))
+ # match block to region
+ for region_idx, region_info in enumerate(region_det_res["boxes"]):
+ matched_idxes = []
+ region_to_block_map[region_idx] = []
+ region_bbox = region_info["coordinate"]
+ for block_idx in block_idxes_set:
+ overlap_ratio = calculate_overlap_ratio(
+ region_bbox, block_bboxes[block_idx], mode="small"
+ )
+ if overlap_ratio > REGION_SETTINGS.get(
+ "match_block_overlap_ratio_threshold", 0.8
+ ):
+ region_to_block_map[region_idx].append(block_idx)
+ matched_idxes.append(block_idx)
+ if len(matched_idxes) > 0:
+ for block_idx in matched_idxes:
+ block_idxes_set.remove(block_idx)
+ matched_bboxes = [block_bboxes[idx] for idx in matched_idxes]
+ new_region_bbox = calculate_minimum_enclosing_bbox(matched_bboxes)
+ region_det_res["boxes"][region_idx]["coordinate"] = new_region_bbox
+ # Create supplementary regions for the blocks that were not matched to any region
+ if len(block_idxes_set) > 0:
+ while len(block_idxes_set) > 0:
+ matched_idxes = []
+ unmatched_bboxes = [block_bboxes[idx] for idx in block_idxes_set]
+ supplement_region_bbox = calculate_minimum_enclosing_bbox(
+ unmatched_bboxes
+ )
+ # check whether the new region bbox overlaps any existing region bbox; if it does, shrink the new region bbox
+ for region_info in region_det_res["boxes"]:
+ region_bbox = region_info["coordinate"]
+ overlap_ratio = calculate_overlap_ratio(
+ supplement_region_bbox, region_bbox
+ )
+ if overlap_ratio > 0:
+ supplement_region_bbox, matched_idxes = (
+ shrink_supplement_region_bbox(
+ supplement_region_bbox,
+ region_bbox,
+ image.shape[1],
+ image.shape[0],
+ block_idxes_set,
+ block_bboxes,
+ )
+ )
+ if len(matched_idxes) == 0:
+ matched_idxes = list(block_idxes_set)
+ region_idx = len(region_det_res["boxes"])
+ region_to_block_map[region_idx] = list(matched_idxes)
+ for block_idx in matched_idxes:
+ block_idxes_set.remove(block_idx)
+ region_det_res["boxes"].append(
+ {
+ "coordinate": supplement_region_bbox,
+ "label": "SupplementaryBlock",
+ "score": 1,
+ }
+ )
+
+ region_block_ocr_idx_map = dict(
+ region_to_block_map=region_to_block_map,
+ block_to_ocr_map=block_to_ocr_map,
+ )
- return layout_order_config, layout_det_res
+ return region_block_ocr_idx_map, region_det_res, layout_det_res
- def sort_line_by_x_projection(
+ def sort_line_by_projection(
self,
line: List[List[Union[List[int], str]]],
input_img: np.ndarray,
text_rec_model: Any,
text_rec_score_thresh: Union[float, None] = None,
+ direction: str = "vertical",
) -> None:
"""
Sort a line of text spans based on their vertical position within the layout bounding box.
@@ -443,24 +581,27 @@ def sort_line_by_x_projection(
Returns:
list: The sorted line of text spans.
"""
- splited_boxes = split_boxes_if_x_contained(line)
+ sort_index = 0 if direction == "horizontal" else 1
+ splited_boxes = split_boxes_by_projection(line, direction)
splited_lines = []
if len(line) != len(splited_boxes):
- splited_boxes.sort(key=lambda span: span[0][0])
+ splited_boxes.sort(key=lambda span: span[0][sort_index])
for span in splited_boxes:
- if span[2] == "text":
+ bbox, text, label = span
+ if label == "text":
crop_img = input_img[
- int(span[0][1]) : int(span[0][3]),
- int(span[0][0]) : int(span[0][2]),
+ int(bbox[1]) : int(bbox[3]),
+ int(bbox[0]) : int(bbox[2]),
]
crop_img_rec_res = next(text_rec_model([crop_img]))
crop_img_rec_score = crop_img_rec_res["rec_score"]
crop_img_rec_text = crop_img_rec_res["rec_text"]
- span[1] = (
+ text = (
crop_img_rec_text
if crop_img_rec_score >= text_rec_score_thresh
else ""
)
+ span[1] = text
splited_lines.append(span)
else:
@@ -471,91 +612,88 @@ def sort_line_by_x_projection(
def get_block_rec_content(
self,
image: list,
- layout_order_config: dict,
ocr_rec_res: dict,
block: LayoutParsingBlock,
text_rec_model: Any,
text_rec_score_thresh: Union[float, None] = None,
) -> str:
- text_delimiter_map = {
- "content": "\n",
- }
- line_delimiter_map = {
- "doc_title": " ",
- "content": "\n",
- }
if len(ocr_rec_res["rec_texts"]) == 0:
block.content = ""
return block
- label = block.label
- if label == "reference":
- rec_boxes = ocr_rec_res["boxes"]
- block_left_coordinate = min([box[0] for box in rec_boxes])
- block_right_coordinate = max([box[2] for box in rec_boxes])
- first_line_span_limit = (5,)
- last_line_span_limit = (20,)
- else:
- block_left_coordinate, _, block_right_coordinate, _ = block.bbox
- first_line_span_limit = (10,)
- last_line_span_limit = (10,)
-
- if label == "formula":
- ocr_rec_res["rec_texts"] = [
- rec_res_text.replace("$", "")
- for rec_res_text in ocr_rec_res["rec_texts"]
- ]
- lines = group_boxes_into_lines(
+ lines, text_direction = group_boxes_into_lines(
ocr_rec_res,
- block,
- layout_order_config.get("line_height_iou_threshold", 0.4),
+ LINE_SETTINGS.get("line_height_iou_threshold", 0.8),
)
- block.num_of_lines = len(lines)
+ if block.label == "reference":
+ rec_boxes = ocr_rec_res["boxes"]
+ block_right_coordinate = max([box[2] for box in rec_boxes])
+ else:
+ block_right_coordinate = block.bbox[2]
# format line
- new_lines = []
- horizontal_text_line_num = 0
- for line in lines:
- line.sort(key=lambda span: span[0][0])
+ text_lines = []
+ need_new_line_num = 0
+ start_index = 0 if text_direction == "horizontal" else 1
+ secondary_direction_start_index = 1 if text_direction == "horizontal" else 0
+ line_height_list, line_width_list = [], []
+ for idx, line in enumerate(lines):
+ line.sort(key=lambda span: span[0][start_index])
+
+ text_bboxes_height = [
+ span[0][secondary_direction_start_index + 2]
+ - span[0][secondary_direction_start_index]
+ for span in line
+ ]
+ text_bboxes_width = [
+ span[0][start_index + 2] - span[0][start_index] for span in line
+ ]
+ line_height = np.mean(text_bboxes_height)
+ line_height_list.append(line_height)
+ line_width_list.append(np.mean(text_bboxes_width))
# merge formula and text
ocr_labels = [span[2] for span in line]
if "formula" in ocr_labels:
- line = self.sort_line_by_x_projection(
- line, image, text_rec_model, text_rec_score_thresh
+ line = self.sort_line_by_projection(
+ line, image, text_rec_model, text_rec_score_thresh, text_direction
)
- text_orientation = calculate_text_orientation([span[0] for span in line])
- horizontal_text_line_num += 1 if text_orientation == "horizontal" else 0
-
- line_text = format_line(
+ line_text, need_new_line = format_line(
line,
- block_left_coordinate,
block_right_coordinate,
- first_line_span_limit=first_line_span_limit,
- last_line_span_limit=last_line_span_limit,
+ last_line_span_limit=line_height * 1.5,
block_label=block.label,
- delimiter_map=text_delimiter_map,
)
- new_lines.append(line_text)
-
- delim = line_delimiter_map.get(label, "")
- content = delim.join(new_lines)
+ if need_new_line:
+ need_new_line_num += 1
+ if idx == 0:
+ line_start_coordinate = line[0][0][0]
+ block.seg_start_coordinate = line_start_coordinate
+ elif idx == len(lines) - 1:
+ line_end_coordinate = line[-1][0][2]
+ block.seg_end_coordinate = line_end_coordinate
+ text_lines.append(line_text)
+
+ delim = LINE_SETTINGS["delimiter_map"].get(block.label, "")
+ if need_new_line_num > len(text_lines) * 0.5 and delim == "":
+ delim = "\n"
+ content = delim.join(text_lines)
block.content = content
- block.direction = (
- "horizontal"
- if horizontal_text_line_num > len(new_lines) * 0.5
- else "vertical"
- )
+ block.num_of_lines = len(text_lines)
+ block.direction = text_direction
+ block.text_line_height = np.mean(line_height_list)
+ block.text_line_width = np.mean(line_width_list)
return block
def get_layout_parsing_blocks(
self,
image: list,
- layout_order_config: dict,
+ region_block_ocr_idx_map: dict,
+ region_det_res: DetResult,
overall_ocr_res: OCRResult,
layout_det_res: DetResult,
table_res_list: list,
@@ -614,9 +752,9 @@ def get_layout_parsing_blocks(
_, ocr_idx_list = get_sub_regions_ocr_res(
overall_ocr_res, [block_bbox], return_match_idx=True
)
- layout_order_config["layout_to_ocr_mapping"][box_idx] = ocr_idx_list
+ region_block_ocr_idx_map["block_to_ocr_map"][box_idx] = ocr_idx_list
else:
- ocr_idx_list = layout_order_config["layout_to_ocr_mapping"].get(
+ ocr_idx_list = region_block_ocr_idx_map["block_to_ocr_map"].get(
box_idx, []
)
for box_no in ocr_idx_list:
@@ -630,7 +768,6 @@ def get_layout_parsing_blocks(
block = self.get_block_rec_content(
image=image,
block=block,
- layout_order_config=layout_order_config,
ocr_rec_res=rec_res,
text_rec_model=text_rec_model,
text_rec_score_thresh=text_rec_score_thresh,
@@ -644,28 +781,30 @@ def get_layout_parsing_blocks(
layout_parsing_blocks.append(block)
- # when there is no layout detection result but there is ocr result, use ocr result
- if len(layout_det_res["boxes"]) == 0:
- region_box = [65535, 65535, 0, 0]
- for ocr_idx, (ocr_rec_box, ocr_rec_text) in enumerate(
- zip(overall_ocr_res["rec_boxes"], overall_ocr_res["rec_texts"])
- ):
- update_layout_order_config_block_index(
- layout_order_config, "text", ocr_idx
- )
- region_box = update_region_box(ocr_rec_box, region_box)
- layout_parsing_blocks.append(
- LayoutParsingBlock(
- label="text", bbox=ocr_rec_box, content=ocr_rec_text
- )
- )
- layout_order_config["all_layout_region_box"] = region_box
+ region_list: List[LayoutParsingRegion] = []
+ for region_idx, region_info in enumerate(region_det_res["boxes"]):
+ region_bbox = region_info["coordinate"]
+ region_blocks = [
+ layout_parsing_blocks[idx]
+ for idx in region_block_ocr_idx_map["region_to_block_map"][region_idx]
+ ]
+ region = LayoutParsingRegion(
+ bbox=region_bbox,
+ blocks=region_blocks,
+ )
+ region_list.append(region)
- return layout_parsing_blocks, layout_order_config
+ region_list = sorted(
+ region_list,
+ key=lambda r: (r.euclidean_distance // 50, r.center_euclidean_distance),
+ )
+
+ return region_list
def get_layout_parsing_res(
self,
image: list,
+ region_det_res: DetResult,
layout_det_res: DetResult,
overall_ocr_res: OCRResult,
table_res_list: list,
@@ -686,23 +825,25 @@ def get_layout_parsing_res(
Returns:
list: A list of dictionaries representing the layout parsing result.
"""
- from .setting import layout_order_config
# Standardize data
- layout_order_config, layout_det_res = self.standardized_data(
- image=image,
- layout_order_config=copy.deepcopy(layout_order_config),
- layout_det_res=layout_det_res,
- overall_ocr_res=overall_ocr_res,
- formula_res_list=formula_res_list,
- text_rec_model=self.general_ocr_pipeline.text_rec_model,
- text_rec_score_thresh=text_rec_score_thresh,
+ region_block_ocr_idx_map, region_det_res, layout_det_res = (
+ self.standardized_data(
+ image=image,
+ region_det_res=region_det_res,
+ layout_det_res=layout_det_res,
+ overall_ocr_res=overall_ocr_res,
+ formula_res_list=formula_res_list,
+ text_rec_model=self.general_ocr_pipeline.text_rec_model,
+ text_rec_score_thresh=text_rec_score_thresh,
+ )
)
# Format layout parsing block
- parsing_res_list, layout_order_config = self.get_layout_parsing_blocks(
+ region_list = self.get_layout_parsing_blocks(
image=image,
- layout_order_config=layout_order_config,
+ region_block_ocr_idx_map=region_block_ocr_idx_map,
+ region_det_res=region_det_res,
overall_ocr_res=overall_ocr_res,
layout_det_res=layout_det_res,
table_res_list=table_res_list,
@@ -711,10 +852,15 @@ def get_layout_parsing_res(
text_rec_score_thresh=self.general_ocr_pipeline.text_rec_score_thresh,
)
- parsing_res_list = xycut_enhanced(
- parsing_res_list,
- layout_order_config,
- )
+ parsing_res_list = []
+ for region in region_list:
+ parsing_res_list.extend(region.sort())
+
+ index = 1
+ for block in parsing_res_list:
+ if block.label in BLOCK_LABEL_MAP["visualize_index_labels"]:
+ block.order_index = index
+ index += 1
return parsing_res_list
@@ -726,6 +872,9 @@ def get_model_settings(
use_seal_recognition: Union[bool, None],
use_table_recognition: Union[bool, None],
use_formula_recognition: Union[bool, None],
+ use_chart_recognition: Union[bool, None],
+ use_region_detection: Union[bool, None],
+ is_pretty_markdown: Union[bool, None],
) -> dict:
"""
Get the model settings based on the provided parameters or default values.
@@ -762,12 +911,18 @@ def get_model_settings(
if use_formula_recognition is None:
use_formula_recognition = self.use_formula_recognition
+ if use_region_detection is None:
+ use_region_detection = self.use_region_detection
+
return dict(
use_doc_preprocessor=use_doc_preprocessor,
use_general_ocr=use_general_ocr,
use_seal_recognition=use_seal_recognition,
use_table_recognition=use_table_recognition,
use_formula_recognition=use_formula_recognition,
+ use_chart_recognition=use_chart_recognition,
+ use_region_detection=use_region_detection,
+ is_pretty_markdown=is_pretty_markdown,
)
def predict(
@@ -780,6 +935,8 @@ def predict(
use_seal_recognition: Union[bool, None] = None,
use_table_recognition: Union[bool, None] = None,
use_formula_recognition: Union[bool, None] = None,
+ use_chart_recognition: Union[bool, None] = None,
+ use_region_detection: Union[bool, None] = None,
layout_threshold: Optional[Union[float, dict]] = None,
layout_nms: Optional[bool] = None,
layout_unclip_ratio: Optional[Union[float, Tuple[float, float], dict]] = None,
@@ -799,6 +956,7 @@ def predict(
use_table_cells_ocr_results: bool = False,
use_e2e_wired_table_rec_model: bool = False,
use_e2e_wireless_table_rec_model: bool = True,
+ is_pretty_markdown: Union[bool, None] = None,
**kwargs,
) -> LayoutParsingResultV2:
"""
@@ -812,6 +970,7 @@ def predict(
use_seal_recognition (Optional[bool]): Whether to use seal recognition.
use_table_recognition (Optional[bool]): Whether to use table recognition.
use_formula_recognition (Optional[bool]): Whether to use formula recognition.
+ use_region_detection (Optional[bool]): Whether to use region detection.
layout_threshold (Optional[float]): The threshold value to filter out low-confidence predictions. Default is None.
layout_nms (bool, optional): Whether to use layout-aware NMS. Defaults to False.
layout_unclip_ratio (Optional[Union[float, Tuple[float, float]]], optional): The ratio of unclipping the bounding box.
@@ -848,6 +1007,9 @@ def predict(
use_seal_recognition,
use_table_recognition,
use_formula_recognition,
+ use_chart_recognition,
+ use_region_detection,
+ is_pretty_markdown,
)
if not self.check_model_settings_valid(model_settings):
@@ -878,8 +1040,20 @@ def predict(
layout_merge_bboxes_mode=layout_merge_bboxes_mode,
)
)
+
imgs_in_doc = gather_imgs(doc_preprocessor_image, layout_det_res["boxes"])
+ if model_settings["use_region_detection"]:
+ region_det_res = next(
+ self.region_detection_model(
+ doc_preprocessor_image,
+ layout_nms=True,
+ layout_merge_bboxes_mode="small",
+ ),
+ )
+ else:
+ region_det_res = {"boxes": []}
+
if model_settings["use_formula_recognition"]:
formula_res_all = next(
self.formula_recognition_pipeline(
@@ -915,7 +1089,13 @@ def predict(
),
)
else:
- overall_ocr_res = {}
+ overall_ocr_res = {
+ "dt_polys": [],
+ "rec_texts": [],
+ "rec_scores": [],
+ "rec_polys": [],
+ "rec_boxes": np.array([]),
+ }
overall_ocr_res["rec_labels"] = ["text"] * len(overall_ocr_res["rec_texts"])
@@ -933,9 +1113,14 @@ def predict(
table_contents["rec_texts"].append(
f"${formula_res['rec_formula']}$"
)
- table_contents["rec_boxes"] = np.vstack(
- (table_contents["rec_boxes"], [formula_res["dt_polys"]])
- )
+ if table_contents["rec_boxes"].size == 0:
+ table_contents["rec_boxes"] = np.array(
+ [formula_res["dt_polys"]]
+ )
+ else:
+ table_contents["rec_boxes"] = np.vstack(
+ (table_contents["rec_boxes"], [formula_res["dt_polys"]])
+ )
table_contents["rec_polys"].append(poly_points)
table_contents["rec_scores"].append(1)
@@ -1002,6 +1187,7 @@ def predict(
parsing_res_list = self.get_layout_parsing_res(
doc_preprocessor_image,
+ region_det_res=region_det_res,
layout_det_res=layout_det_res,
overall_ocr_res=overall_ocr_res,
table_res_list=table_res_list,
@@ -1021,6 +1207,7 @@ def predict(
"page_index": batch_data.page_indexes[0],
"doc_preprocessor_res": doc_preprocessor_res,
"layout_det_res": layout_det_res,
+ "region_det_res": region_det_res,
"overall_ocr_res": overall_ocr_res,
"table_res_list": table_res_list,
"seal_res_list": seal_res_list,
diff --git a/paddlex/inference/pipelines/layout_parsing/result_v2.py b/paddlex/inference/pipelines/layout_parsing/result_v2.py
index 95cd2cfa66..6c41434e9b 100644
--- a/paddlex/inference/pipelines/layout_parsing/result_v2.py
+++ b/paddlex/inference/pipelines/layout_parsing/result_v2.py
@@ -14,13 +14,15 @@
from __future__ import annotations
import copy
+import math
import re
from pathlib import Path
from typing import List
import numpy as np
-from PIL import Image, ImageDraw
+from PIL import Image, ImageDraw, ImageFont
+from ....utils.fonts import PINGFANG_FONT_FILE_PATH
from ...common.result import (
BaseCVResult,
HtmlMixin,
@@ -28,6 +30,7 @@
MarkdownMixin,
XlsxMixin,
)
+from .setting import BLOCK_LABEL_MAP
class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
@@ -73,6 +76,9 @@ def _to_img(self) -> dict[str, np.ndarray]:
res_img_dict[key] = value
res_img_dict["layout_det_res"] = self["layout_det_res"].img["res"]
+ if model_settings["use_region_detection"]:
+ res_img_dict["region_det_res"] = self["region_det_res"].img["res"]
+
if model_settings["use_general_ocr"] or model_settings["use_table_recognition"]:
res_img_dict["overall_ocr_res"] = self["overall_ocr_res"].img["ocr_res_img"]
@@ -103,16 +109,23 @@ def _to_img(self) -> dict[str, np.ndarray]:
# for layout ordering image
image = Image.fromarray(self["doc_preprocessor_res"]["output_img"][:, :, ::-1])
draw = ImageDraw.Draw(image, "RGBA")
+ font_size = int(0.018 * int(image.width)) + 2
+ font = ImageFont.truetype(PINGFANG_FONT_FILE_PATH, font_size, encoding="utf-8")
parsing_result: List[LayoutParsingBlock] = self["parsing_res_list"]
for block in parsing_result:
bbox = block.bbox
- index = block.index
- label = block.label
- fill_color = get_show_color(label)
+ index = block.order_index
+ label = block.order_label
+ fill_color = get_show_color(label, True)
draw.rectangle(bbox, fill=fill_color)
if index is not None:
- text_position = (bbox[2] + 2, bbox[1] - 10)
- draw.text(text_position, str(index), fill="red")
+ text_position = (bbox[2] + 2, bbox[1] - font_size // 2)
+ if int(image.width) - bbox[2] < font_size:
+ text_position = (
+ int(bbox[2] - font_size * 1.1),
+ bbox[1] - font_size // 2,
+ )
+ draw.text(text_position, str(index), font=font, fill="red")
res_img_dict["layout_order_res"] = image
@@ -283,22 +296,33 @@ def format_title(title):
" ",
)
+ # def format_centered_text():
+ # return (
+ # f'{block.content}
'.replace(
+ # "-\n",
+ # "",
+ # ).replace("\n", " ")
+ # + "\n"
+ # )
+
def format_centered_text():
- return (
- f'{block.content}
'.replace(
- "-\n",
- "",
- ).replace("\n", " ")
- + "\n"
- )
+ return block.content
+
+ # def format_image():
+ # img_tags = []
+ # image_path = "".join(block.image.keys())
+ # img_tags.append(
+ # ''.format(
+ # image_path.replace("-\n", "").replace("\n", " "),
+ # ),
+ # )
+ # return "\n".join(img_tags)
def format_image():
img_tags = []
image_path = "".join(block.image.keys())
img_tags.append(
- ''.format(
- image_path.replace("-\n", "").replace("\n", " "),
- ),
+ "".format(image_path.replace("-\n", "").replace("\n", " "))
)
return "\n".join(img_tags)
@@ -332,7 +356,7 @@ def get_seg_flag(block: LayoutParsingBlock, prev_block: LayoutParsingBlock):
num_of_prev_lines = prev_block.num_of_lines
pre_block_seg_end_coordinate = prev_block.seg_end_coordinate
prev_end_space_small = (
- context_right_coordinate - pre_block_seg_end_coordinate < 10
+ abs(prev_block_bbox[2] - pre_block_seg_end_coordinate) < 10
)
prev_lines_more_than_one = num_of_prev_lines > 1
@@ -347,8 +371,12 @@ def get_seg_flag(block: LayoutParsingBlock, prev_block: LayoutParsingBlock):
prev_block_bbox[2], context_right_coordinate
)
prev_end_space_small = (
- prev_block_bbox[2] - pre_block_seg_end_coordinate < 10
+ abs(context_right_coordinate - pre_block_seg_end_coordinate)
+ < 10
)
+ edge_distance = 0
+ else:
+ edge_distance = abs(block_box[0] - prev_block_bbox[2])
current_start_space_small = (
seg_start_coordinate - context_left_coordinate < 10
@@ -358,6 +386,7 @@ def get_seg_flag(block: LayoutParsingBlock, prev_block: LayoutParsingBlock):
prev_end_space_small
and current_start_space_small
and prev_lines_more_than_one
+ and edge_distance < max(prev_block.width, block.width)
):
seg_start_flag = False
else:
@@ -371,6 +400,9 @@ def get_seg_flag(block: LayoutParsingBlock, prev_block: LayoutParsingBlock):
handlers = {
"paragraph_title": lambda: format_title(block.content),
+ "abstract_title": lambda: format_title(block.content),
+ "reference_title": lambda: format_title(block.content),
+ "content_title": lambda: format_title(block.content),
"doc_title": lambda: f"# {block.content}".replace(
"-\n",
"",
@@ -378,7 +410,9 @@ def get_seg_flag(block: LayoutParsingBlock, prev_block: LayoutParsingBlock):
"table_title": lambda: format_centered_text(),
"figure_title": lambda: format_centered_text(),
"chart_title": lambda: format_centered_text(),
- "text": lambda: block.content.replace("-\n", " ").replace("\n", " "),
+ "text": lambda: block.content.replace("\n\n", "\n").replace(
+ "\n", "\n\n"
+ ),
"abstract": lambda: format_first_line(
["摘要", "abstract"], lambda l: f"## {l}\n", " "
),
@@ -416,24 +450,7 @@ def get_seg_flag(block: LayoutParsingBlock, prev_block: LayoutParsingBlock):
if handler:
prev_block = block
if label == last_label == "text" and seg_start_flag == False:
- last_char_of_markdown = (
- markdown_content[-1] if markdown_content else ""
- )
- first_char_of_handler = handler()[0] if handler() else ""
- last_is_chinese_char = (
- re.match(r"[\u4e00-\u9fff]", last_char_of_markdown)
- if last_char_of_markdown
- else False
- )
- first_is_chinese_char = (
- re.match(r"[\u4e00-\u9fff]", first_char_of_handler)
- if first_char_of_handler
- else False
- )
- if not (last_is_chinese_char or first_is_chinese_char):
- markdown_content += " " + handler()
- else:
- markdown_content += handler()
+ markdown_content += handler()
else:
markdown_content += (
"\n\n" + handler() if markdown_content else handler()
@@ -467,8 +484,8 @@ class LayoutParsingBlock:
def __init__(self, label, bbox, content="") -> None:
self.label = label
- self.region_label = "other"
- self.bbox = [int(item) for item in bbox]
+ self.order_label = None
+ self.bbox = list(map(int, bbox))
self.content = content
self.seg_start_coordinate = float("inf")
self.seg_end_coordinate = float("-inf")
@@ -478,7 +495,9 @@ def __init__(self, label, bbox, content="") -> None:
self.num_of_lines = 1
self.image = None
self.index = None
- self.visual_index = None
+ self.order_index = None
+ self.text_line_width = 1
+ self.text_line_height = 1
self.direction = self.get_bbox_direction()
self.child_blocks = []
self.update_direction_info()
@@ -487,14 +506,14 @@ def __str__(self) -> str:
return f"{self.__dict__}"
def __repr__(self) -> str:
- _str = f"\n\n#################\nlabel:\t{self.label}\nregion_label:\t{self.region_label}\nbbox:\t{self.bbox}\ncontent:\t{self.content}\n#################"
+ _str = f"\n\n#################\nindex:\t{self.index}\nlabel:\t{self.label}\nregion_label:\t{self.order_label}\nbbox:\t{self.bbox}\ncontent:\t{self.content}\n#################"
return _str
def to_dict(self) -> dict:
return self.__dict__
def update_direction_info(self) -> None:
- if self.region_label == "vision":
+ if self.order_label == "vision":
self.direction = "horizontal"
if self.direction == "horizontal":
self.secondary_direction = "vertical"
@@ -542,19 +561,130 @@ def get_centroid(self) -> tuple:
centroid = ((x1 + x2) / 2, (y1 + y2) / 2)
return centroid
- def get_bbox_direction(self, orientation_ratio: float = 1.0) -> bool:
+ def get_bbox_direction(self, direction_ratio: float = 1.0) -> bool:
"""
Determine if a bounding box is horizontal or vertical.
Args:
bbox (List[float]): Bounding box [x_min, y_min, x_max, y_max].
- orientation_ratio (float): Ratio for determining orientation. Default is 1.0.
+ direction_ratio (float): Ratio for determining direction. Default is 1.0.
Returns:
str: "horizontal" or "vertical".
"""
return (
+ "horizontal" if self.width * direction_ratio >= self.height else "vertical"
+ )
+
+
+class LayoutParsingRegion:
+
+ def __init__(self, bbox, blocks: List[LayoutParsingBlock] = []) -> None:
+ self.bbox = bbox
+ self.block_map = {}
+ self.direction = "horizontal"
+ self.calculate_bbox_metrics()
+ self.doc_title_block_idxes = []
+ self.paragraph_title_block_idxes = []
+ self.vision_block_idxes = []
+ self.unordered_block_idxes = []
+ self.vision_title_block_idxes = []
+ self.normal_text_block_idxes = []
+ self.header_block_idxes = []
+ self.footer_block_idxes = []
+ self.text_line_width = 20
+ self.text_line_height = 10
+ self.init_region_info_from_layout(blocks)
+ self.init_direction_info()
+
+ def init_region_info_from_layout(self, blocks: List[LayoutParsingBlock]):
+ horizontal_normal_text_block_num = 0
+ text_line_height_list = []
+ text_line_width_list = []
+ for idx, block in enumerate(blocks):
+ self.block_map[idx] = block
+ block.index = idx
+ if block.label in BLOCK_LABEL_MAP["header_labels"]:
+ self.header_block_idxes.append(idx)
+ elif block.label in BLOCK_LABEL_MAP["doc_title_labels"]:
+ self.doc_title_block_idxes.append(idx)
+ elif block.label in BLOCK_LABEL_MAP["paragraph_title_labels"]:
+ self.paragraph_title_block_idxes.append(idx)
+ elif block.label in BLOCK_LABEL_MAP["vision_labels"]:
+ self.vision_block_idxes.append(idx)
+ elif block.label in BLOCK_LABEL_MAP["vision_title_labels"]:
+ self.vision_title_block_idxes.append(idx)
+ elif block.label in BLOCK_LABEL_MAP["footer_labels"]:
+ self.footer_block_idxes.append(idx)
+ elif block.label in BLOCK_LABEL_MAP["unordered_labels"]:
+ self.unordered_block_idxes.append(idx)
+ else:
+ self.normal_text_block_idxes.append(idx)
+ text_line_height_list.append(block.text_line_height)
+ text_line_width_list.append(block.text_line_width)
+ if block.direction == "horizontal":
+ horizontal_normal_text_block_num += 1
+ self.direction = (
"horizontal"
- if self.width * orientation_ratio >= self.height
+ if horizontal_normal_text_block_num
+ >= len(self.normal_text_block_idxes) * 0.5
else "vertical"
)
+ self.text_line_width = (
+ np.mean(text_line_width_list) if text_line_width_list else 20
+ )
+ self.text_line_height = (
+ np.mean(text_line_height_list) if text_line_height_list else 10
+ )
+
+ def init_direction_info(self):
+ if self.direction == "horizontal":
+ self.direction_start_index = 0
+ self.direction_end_index = 2
+ self.secondary_direction_start_index = 1
+ self.secondary_direction_end_index = 3
+ self.secondary_direction = "vertical"
+ else:
+ self.direction_start_index = 1
+ self.direction_end_index = 3
+ self.secondary_direction_start_index = 0
+ self.secondary_direction_end_index = 2
+ self.secondary_direction = "horizontal"
+
+ self.direction_center_coordinate = (
+ self.bbox[self.direction_start_index] + self.bbox[self.direction_end_index]
+ ) / 2
+ self.secondary_direction_center_coordinate = (
+ self.bbox[self.secondary_direction_start_index]
+ + self.bbox[self.secondary_direction_end_index]
+ ) / 2
+
+ def calculate_bbox_metrics(self):
+ x1, y1, x2, y2 = self.bbox
+ x_center, y_center = (x1 + x2) / 2, (y1 + y2) / 2
+ self.euclidean_distance = math.sqrt(((x1) ** 2 + (y1) ** 2))
+ self.center_euclidean_distance = math.sqrt(((x_center) ** 2 + (y_center) ** 2))
+ self.angle_rad = math.atan2(y_center, x_center)
+
+ def sort_normal_blocks(self, blocks):
+ if self.direction == "horizontal":
+ blocks.sort(
+ key=lambda x: (
+ x.bbox[1] // self.text_line_height,
+ x.bbox[0] // self.text_line_width,
+ x.bbox[1] ** 2 + x.bbox[0] ** 2,
+ ),
+ )
+ else:
+ blocks.sort(
+ key=lambda x: (
+ -x.bbox[0] // self.text_line_width,
+ x.bbox[1] // self.text_line_height,
+ -(x.bbox[2] ** 2 + x.bbox[1] ** 2),
+ ),
+ )
+
+ def sort(self):
+ from .xycut_enhanced import xycut_enhanced
+
+ return xycut_enhanced(self)
diff --git a/paddlex/inference/pipelines/layout_parsing/setting.py b/paddlex/inference/pipelines/layout_parsing/setting.py
index 97ba6ec1c2..82162a95fa 100644
--- a/paddlex/inference/pipelines/layout_parsing/setting.py
+++ b/paddlex/inference/pipelines/layout_parsing/setting.py
@@ -12,18 +12,47 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-layout_order_config = {
- # 人工配置项
- "line_height_iou_threshold": 0.4, # For line segmentation of OCR results
- "title_area_max_block_threshold": 0.3, # update paragraph_title -> doc_title
- "block_label_match_iou_threshold": 0.1,
- "block_title_match_iou_threshold": 0.1,
+
+XYCUT_SETTINGS = {
+ "child_block_overlap_ratio_threshold": 0.1,
+ "edge_distance_compare_tolerance_len": 2,
+ "distance_weight_map": {
+ "edge_weight": 10**4,
+ "up_edge_weight": 1,
+ "down_edge_weight": 0.0001,
+ },
+}
+
+REGION_SETTINGS = {
+ "match_block_overlap_ratio_threshold": 0.6,
+ "split_block_overlap_ratio_threshold": 0.4,
+}
+
+BLOCK_SETTINGS = {
+ "title_conversion_area_ratio_threshold": 0.3, # update paragraph_title -> doc_title
+}
+
+LINE_SETTINGS = {
+ "line_height_iou_threshold": 0.6, # For line segmentation of OCR results
+ "delimiter_map": {
+ "doc_title": " ",
+ "content": "\n",
+ },
+}
+
+BLOCK_LABEL_MAP = {
"doc_title_labels": ["doc_title"], # 文档标题
- "paragraph_title_labels": ["paragraph_title"], # 段落标题
+ "paragraph_title_labels": [
+ "paragraph_title",
+ "abstract_title",
+ "reference_title",
+ "content_title",
+ ], # 段落标题
"vision_labels": [
"image",
"table",
"chart",
+ "flowchart",
"figure",
], # 图、表、印章、图表、图
"vision_title_labels": ["table_title", "chart_title", "figure_title"], # 图表标题
@@ -52,19 +81,9 @@
"table",
"chart",
"figure",
+ "abstract_title",
+ "refer_title",
+ "content_title",
+ "flowchart",
],
- # 自动补全配置项
- "layout_to_ocr_mapping": {},
- "all_layout_region_box": [], # 区域box
- "doc_title_block_idxes": [],
- "paragraph_title_block_idxes": [],
- "text_title_labels": [], # doc_title_labels+paragraph_title_labels
- "text_title_block_idxes": [],
- "vision_block_idxes": [],
- "vision_title_block_idxes": [],
- "vision_footnote_block_idxes": [],
- "text_block_idxes": [],
- "header_block_idxes": [],
- "footer_block_idxes": [],
- "unordered_block_idxes": [],
}
diff --git a/paddlex/inference/pipelines/layout_parsing/utils.py b/paddlex/inference/pipelines/layout_parsing/utils.py
index 5f90e9d3aa..904156b932 100644
--- a/paddlex/inference/pipelines/layout_parsing/utils.py
+++ b/paddlex/inference/pipelines/layout_parsing/utils.py
@@ -16,7 +16,6 @@
"get_sub_regions_ocr_res",
"get_show_color",
"sorted_layout_boxes",
- "update_layout_order_config_block_index",
]
import re
@@ -28,7 +27,7 @@
from ..components import convert_points_to_boxes
from ..ocr.result import OCRResult
-from .xycut_enhanced import calculate_projection_iou
+from .setting import REGION_SETTINGS
def get_overlap_boxes_idx(src_boxes: np.ndarray, ref_boxes: np.ndarray) -> List:
@@ -172,64 +171,129 @@ def sorted_layout_boxes(res, w):
return new_res
-def _calculate_overlap_area_div_minbox_area_ratio(
- bbox1: Union[list, tuple],
- bbox2: Union[list, tuple],
+def calculate_projection_overlap_ratio(
+ bbox1: List[float],
+ bbox2: List[float],
+ direction: str = "horizontal",
+ mode="union",
) -> float:
"""
- Calculate the ratio of the overlap area between bbox1 and bbox2
- to the area of the smaller bounding box.
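+ Calculate the overlap ratio of two bounding boxes projected onto one axis; mode ('union', 'small' or 'large') selects the reference length used as the denominator.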
Args:
- bbox1 (list or tuple): Coordinates of the first bounding box [x_min, y_min, x_max, y_max].
- bbox2 (list or tuple): Coordinates of the second bounding box [x_min, y_min, x_max, y_max].
+ bbox1 (List[float]): First bounding box [x_min, y_min, x_max, y_max].
+ bbox2 (List[float]): Second bounding box [x_min, y_min, x_max, y_max].
+ direction (str): direction of the projection, "horizontal" or "vertical".
Returns:
- float: The ratio of the overlap area to the area of the smaller bounding box.
+ float: Line overlap ratio. Returns 0 if there is no overlap.
"""
- bbox1 = list(map(int, bbox1))
- bbox2 = list(map(int, bbox2))
+ start_index, end_index = 1, 3
+ if direction == "horizontal":
+ start_index, end_index = 0, 2
+
+ intersection_start = max(bbox1[start_index], bbox2[start_index])
+ intersection_end = min(bbox1[end_index], bbox2[end_index])
+ overlap = intersection_end - intersection_start
+ if overlap <= 0:
+ return 0
+
+ if mode == "union":
+ ref_width = max(bbox1[end_index], bbox2[end_index]) - min(
+ bbox1[start_index], bbox2[start_index]
+ )
+ elif mode == "small":
+ ref_width = min(
+ bbox1[end_index] - bbox1[start_index], bbox2[end_index] - bbox2[start_index]
+ )
+ elif mode == "large":
+ ref_width = max(
+ bbox1[end_index] - bbox1[start_index], bbox2[end_index] - bbox2[start_index]
+ )
+ else:
+ raise ValueError(
+ f"Invalid mode {mode}, must be one of ['union', 'small', 'large']."
+ )
- x_left = max(bbox1[0], bbox2[0])
- y_top = max(bbox1[1], bbox2[1])
- x_right = min(bbox1[2], bbox2[2])
- y_bottom = min(bbox1[3], bbox2[3])
+ return overlap / ref_width if ref_width > 0 else 0.0
- if x_right <= x_left or y_bottom <= y_top:
- return 0.0
- intersection_area = (x_right - x_left) * (y_bottom - y_top)
- area_bbox1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
- area_bbox2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
- min_box_area = min(area_bbox1, area_bbox2)
+def calculate_overlap_ratio(
+ bbox1: Union[list, tuple], bbox2: Union[list, tuple], mode="union"
+) -> float:
+ """
+ Calculate the overlap ratio between two bounding boxes.
+
+ Args:
+ bbox1 (list or tuple): The first bounding box, format [x_min, y_min, x_max, y_max]
+ bbox2 (list or tuple): The second bounding box, format [x_min, y_min, x_max, y_max]
+ mode (str): The mode of calculation, either 'union', 'small', or 'large'.
+
+ Returns:
+ float: The overlap ratio value between the two bounding boxes
+ """
+ x_min_inter = max(bbox1[0], bbox2[0])
+ y_min_inter = max(bbox1[1], bbox2[1])
+ x_max_inter = min(bbox1[2], bbox2[2])
+ y_max_inter = min(bbox1[3], bbox2[3])
+
+ inter_width = max(0, x_max_inter - x_min_inter)
+ inter_height = max(0, y_max_inter - y_min_inter)
- if min_box_area <= 0:
+ inter_area = inter_width * inter_height
+
+ bbox1_area = caculate_bbox_area(bbox1)
+ bbox2_area = caculate_bbox_area(bbox2)
+
+ if mode == "union":
+ ref_area = bbox1_area + bbox2_area - inter_area
+ elif mode == "small":
+ ref_area = min(bbox1_area, bbox2_area)
+ elif mode == "large":
+ ref_area = max(bbox1_area, bbox2_area)
+ else:
+ raise ValueError(
+ f"Invalid mode {mode}, must be one of ['union', 'small', 'large']."
+ )
+
+ if ref_area == 0:
return 0.0
- return intersection_area / min_box_area
+ return inter_area / ref_area
-def group_boxes_into_lines(ocr_rec_res, block_info, line_height_iou_threshold):
+def group_boxes_into_lines(ocr_rec_res, line_height_iou_threshold):
rec_boxes = ocr_rec_res["boxes"]
rec_texts = ocr_rec_res["rec_texts"]
rec_labels = ocr_rec_res["rec_labels"]
- spans = list(zip(rec_boxes, rec_texts, rec_labels))
+ text_boxes = [
+ rec_boxes[i] for i in range(len(rec_boxes)) if rec_labels[i] == "text"
+ ]
+ text_orientation = calculate_text_orientation(text_boxes)
+
+ match_direction = "vertical" if text_orientation == "horizontal" else "horizontal"
- spans.sort(key=lambda span: span[0][1])
+ spans = list(zip(rec_boxes, rec_texts, rec_labels))
+ sort_index = 1
+ reverse = False
+ if text_orientation == "vertical":
+ sort_index = 0
+ reverse = True
+ spans.sort(key=lambda span: span[0][sort_index], reverse=reverse)
spans = [list(span) for span in spans]
lines = []
line = [spans[0]]
- line_region_box = spans[0][0][:]
- block_info.seg_start_coordinate = spans[0][0][0]
- block_info.seg_end_coordinate = spans[-1][0][2]
+ line_region_box = spans[0][0].copy()
# merge line
for span in spans[1:]:
rec_bbox = span[0]
if (
- calculate_projection_iou(line_region_box, rec_bbox, "vertical")
+ calculate_projection_overlap_ratio(
+ line_region_box, rec_bbox, match_direction, mode="small"
+ )
>= line_height_iou_threshold
):
line.append(span)
@@ -238,10 +302,36 @@ def group_boxes_into_lines(ocr_rec_res, block_info, line_height_iou_threshold):
else:
lines.append(line)
line = [span]
- line_region_box = rec_bbox[:]
+ line_region_box = rec_bbox.copy()
lines.append(line)
- return lines
+ return lines, text_orientation
+
+
+def calculate_minimum_enclosing_bbox(bboxes):
+ """
+ Calculate the minimum enclosing bounding box for a list of bounding boxes.
+
+ Args:
+ bboxes (list): A list of bounding boxes represented as lists of four integers [x1, y1, x2, y2].
+
+ Returns:
+ list: The minimum enclosing bounding box represented as a list of four integers [x1, y1, x2, y2].
+ """
+ if not bboxes:
+ raise ValueError("The list of bounding boxes is empty.")
+
+ # Convert the list of bounding boxes to a NumPy array
+ bboxes_array = np.array(bboxes)
+
+ # Compute the minimum and maximum values along the respective axes
+ min_x = np.min(bboxes_array[:, 0])
+ min_y = np.min(bboxes_array[:, 1])
+ max_x = np.max(bboxes_array[:, 2])
+ max_y = np.max(bboxes_array[:, 3])
+
+ # Return the minimum enclosing bounding box
+ return [min_x, min_y, max_x, max_y]
def calculate_text_orientation(
@@ -258,24 +348,49 @@ def calculate_text_orientation(
str: "horizontal" or "vertical".
"""
- bboxes = np.array(bboxes)
- x_min = np.min(bboxes[:, 0])
- x_max = np.max(bboxes[:, 2])
- width = x_max - x_min
- y_min = np.min(bboxes[:, 1])
- y_max = np.max(bboxes[:, 3])
- height = y_max - y_min
- return "horizontal" if width * orientation_ratio >= height else "vertical"
+ horizontal_box_num = 0
+ for bbox in bboxes:
+ if len(bbox) != 4:
+ raise ValueError(
+ "Invalid bounding box format. Expected a list of length 4."
+ )
+ x1, y1, x2, y2 = bbox
+ width = x2 - x1
+ height = y2 - y1
+ horizontal_box_num += 1 if width * orientation_ratio >= height else 0
+
+ return "horizontal" if horizontal_box_num >= len(bboxes) * 0.5 else "vertical"
+
+
+def is_english_letter(char):
+ return bool(re.match(r"^[A-Za-z]$", char))
+
+
+def is_non_breaking_punctuation(char):
+ """
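+ Determine whether a character is a punctuation mark after which no line break is needed, covering both full-width and half-width forms.
+
+ :param char: str, a single character
+ :return: bool, True if the character is a punctuation mark that does not require a line break, otherwise False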
+ """
+ non_breaking_punctuations = {
+ ",", # 半角逗号
+ ",", # 全角逗号
+ "、", # 顿号
+ ";", # 半角分号
+ ";", # 全角分号
+ ":", # 半角冒号
+ ":", # 全角冒号
+ }
+
+ return char in non_breaking_punctuations
def format_line(
line: List[List[Union[List[int], str]]],
- block_left_coordinate: int,
block_right_coordinate: int,
- first_line_span_limit: int = 10,
last_line_span_limit: int = 10,
block_label: str = "text",
- delimiter_map: Dict = {},
) -> None:
"""
Format a line of text spans based on layout constraints.
@@ -290,92 +405,108 @@ def format_line(
Returns:
None: The function modifies the line in place.
"""
- first_span = line[0]
- last_span = line[-1]
-
- if first_span[0][0] - block_left_coordinate > first_line_span_limit:
- first_span[1] = "\n" + first_span[1]
- if block_right_coordinate - last_span[0][2] > last_line_span_limit:
- last_span[1] = last_span[1] + "\n"
+ last_span_box = line[-1][0]
- line[0] = first_span
- line[-1] = last_span
+ for span in line:
+ if span[2] == "formula" and block_label != "formula":
+ if len(line) > 1:
+ span[1] = f"${span[1]}$"
+ else:
+ span[1] = f"\n${span[1]}$"
- delim = delimiter_map.get(block_label, " ")
- line_text = delim.join([span[1] for span in line])
+ line_text = " ".join([span[1] for span in line])
- if block_label != "reference":
- line_text = remove_extra_space(line_text)
+ need_new_line = False
+ if (
+ block_right_coordinate - last_span_box[2] > last_line_span_limit
+ and not line_text.endswith("-")
+ and len(line_text) > 0
+ and not is_english_letter(line_text[-1])
+ and not is_non_breaking_punctuation(line_text[-1])
+ ):
+ need_new_line = True
if line_text.endswith("-"):
line_text = line_text[:-1]
- return line_text
+ elif (
+ len(line_text) > 0 and is_english_letter(line_text[-1])
+ ) or line_text.endswith("$"):
+ line_text += " "
+
+ return line_text, need_new_line
-def split_boxes_if_x_contained(boxes, offset=1e-5):
+def split_boxes_by_projection(spans: List[List[int]], direction, offset=1e-5):
"""
Check if there is any complete containment in the x-direction
between the bounding boxes and split the containing box accordingly.
Args:
- boxes (list of lists): Each element is a list containing an ndarray of length 4, a description, and a label.
+ spans (list of lists): Each element is a list containing an ndarray of length 4, a text string, and a label.
+ direction: 'horizontal' or 'vertical', indicating whether the spans are arranged horizontally or vertically.
offset (float): A small offset value to ensure that the split boxes are not too close to the original boxes.
Returns:
A new list of boxes, including split boxes, with the same `rec_text` and `label` attributes.
"""
- def is_x_contained(box_a, box_b):
+ def is_projection_contained(box_a, box_b, start_idx, end_idx):
"""Check if box_a completely contains box_b in the x-direction."""
- return box_a[0][0] <= box_b[0][0] and box_a[0][2] >= box_b[0][2]
+ return box_a[start_idx] <= box_b[start_idx] and box_a[end_idx] >= box_b[end_idx]
new_boxes = []
+ if direction == "horizontal":
+ projection_start_index, projection_end_index = 0, 2
+ else:
+ projection_start_index, projection_end_index = 1, 3
- for i in range(len(boxes)):
- box_a = boxes[i]
+ for i in range(len(spans)):
+ span = spans[i]
is_split = False
- for j in range(len(boxes)):
- if i == j:
- continue
- box_b = boxes[j]
- if is_x_contained(box_a, box_b):
+ for j in range(i, len(spans)):
+ box_b = spans[j][0]
+ box_a, text, label = span
+ if is_projection_contained(
+ box_a, box_b, projection_start_index, projection_end_index
+ ):
is_split = True
# Split box_a based on the x-coordinates of box_b
- if box_a[0][0] < box_b[0][0]:
- w = box_b[0][0] - offset - box_a[0][0]
+ if box_a[projection_start_index] < box_b[projection_start_index]:
+ w = (
+ box_b[projection_start_index]
+ - offset
+ - box_a[projection_start_index]
+ )
if w > 1:
+ new_bbox = box_a.copy()
+ new_bbox[projection_end_index] = (
+ box_b[projection_start_index] - offset
+ )
new_boxes.append(
[
- np.array(
- [
- box_a[0][0],
- box_a[0][1],
- box_b[0][0] - offset,
- box_a[0][3],
- ]
- ),
- box_a[1],
- box_a[2],
+ np.array(new_bbox),
+ text,
+ label,
]
)
- if box_a[0][2] > box_b[0][2]:
- w = box_a[0][2] - box_b[0][2] + offset
+ if box_a[projection_end_index] > box_b[projection_end_index]:
+ w = (
+ box_a[projection_end_index]
+ - box_b[projection_end_index]
+ + offset
+ )
if w > 1:
- box_a = [
- np.array(
- [
- box_b[0][2] + offset,
- box_a[0][1],
- box_a[0][2],
- box_a[0][3],
- ]
- ),
- box_a[1],
- box_a[2],
+ box_a[projection_start_index] = (
+ box_b[projection_end_index] + offset
+ )
+ span = [
+ np.array(box_a),
+ text,
+ label,
]
- if j == len(boxes) - 1 and is_split:
- new_boxes.append(box_a)
+ if j == len(spans) - 1 and is_split:
+ new_boxes.append(span)
if not is_split:
- new_boxes.append(box_a)
+ new_boxes.append(span)
return new_boxes
@@ -451,10 +582,10 @@ def _get_minbox_if_overlap_by_ratio(
The selected bounding box or None if the overlap ratio is not exceeded.
"""
# Calculate the areas of both bounding boxes
- area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
- area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
+ area1 = caculate_bbox_area(bbox1)
+ area2 = caculate_bbox_area(bbox2)
# Calculate the overlap ratio using a helper function
- overlap_ratio = _calculate_overlap_area_div_minbox_area_ratio(bbox1, bbox2)
+ overlap_ratio = calculate_overlap_ratio(bbox1, bbox2, mode="small")
# Check if the overlap ratio exceeds the threshold
if overlap_ratio > ratio:
if (area1 <= area2 and smaller) or (area1 >= area2 and not smaller):
@@ -496,8 +627,17 @@ def remove_overlap_blocks(
smaller=smaller,
)
if overlap_box_index is not None:
- # Determine which block to remove based on overlap_box_index
- if overlap_box_index == 1:
+ if block1["label"] == "image" and block2["label"] == "image":
+ # Determine which block to remove based on overlap_box_index
+ if overlap_box_index == 1:
+ drop_index = i
+ else:
+ drop_index = j
+ elif block1["label"] == "image" and block2["label"] != "image":
+ drop_index = i
+ elif block1["label"] != "image" and block2["label"] == "image":
+ drop_index = j
+ elif overlap_box_index == 1:
drop_index = i
else:
drop_index = j
@@ -556,39 +696,96 @@ def get_bbox_intersection(bbox1, bbox2, return_format="bbox"):
raise ValueError("return_format must be either 'bbox' or 'poly'.")
-def update_layout_order_config_block_index(
- config: dict, block_label: str, block_idx: int
-) -> None:
+def shrink_supplement_region_bbox(
+ supplement_region_bbox,
+ ref_region_bbox,
+ image_width,
+ image_height,
+ block_idxes_set,
+ block_bboxes,
+) -> List:
+ """
+ Shrink the supplement region bbox according to the reference region bbox and match the block bboxes.
- doc_title_labels = config["doc_title_labels"]
- paragraph_title_labels = config["paragraph_title_labels"]
- vision_labels = config["vision_labels"]
- vision_title_labels = config["vision_title_labels"]
- header_labels = config["header_labels"]
- unordered_labels = config["unordered_labels"]
- footer_labels = config["footer_labels"]
- text_labels = config["text_labels"]
- text_title_labels = doc_title_labels + paragraph_title_labels
- config["text_title_labels"] = text_title_labels
-
- if block_label in doc_title_labels:
- config["doc_title_block_idxes"].append(block_idx)
- if block_label in paragraph_title_labels:
- config["paragraph_title_block_idxes"].append(block_idx)
- if block_label in vision_labels:
- config["vision_block_idxes"].append(block_idx)
- if block_label in vision_title_labels:
- config["vision_title_block_idxes"].append(block_idx)
- if block_label in unordered_labels:
- config["unordered_block_idxes"].append(block_idx)
- if block_label in text_title_labels:
- config["text_title_block_idxes"].append(block_idx)
- if block_label in text_labels:
- config["text_block_idxes"].append(block_idx)
- if block_label in header_labels:
- config["header_block_idxes"].append(block_idx)
- if block_label in footer_labels:
- config["footer_block_idxes"].append(block_idx)
+ Args:
+ supplement_region_bbox (list): The supplement region bbox.
+ ref_region_bbox (list): The reference region bbox.
+ image_width (int): The width of the image.
+ image_height (int): The height of the image.
+ block_idxes_set (set): The indexes of the blocks that intersect with the region bbox.
+ block_bboxes (dict): The dictionary of block bboxes.
+
+ Returns:
+ list: The new region bbox and the matched block idxes.
+ """
+ x1, y1, x2, y2 = supplement_region_bbox
+ x1_prime, y1_prime, x2_prime, y2_prime = ref_region_bbox
+ index_conversion_map = {0: 2, 1: 3, 2: 0, 3: 1}
+ edge_distance_list = [
+ (x1_prime - x1) / image_width,
+ (y1_prime - y1) / image_height,
+ (x2 - x2_prime) / image_width,
+ (y2 - y2_prime) / image_height,
+ ]
+ edge_distance_list_tmp = edge_distance_list[:]
+ min_distance = min(edge_distance_list)
+ src_index = index_conversion_map[edge_distance_list.index(min_distance)]
+ if len(block_idxes_set) == 0:
+ return supplement_region_bbox, []
+ for _ in range(3):
+ dst_index = index_conversion_map[src_index]
+ tmp_region_bbox = supplement_region_bbox[:]
+ tmp_region_bbox[dst_index] = ref_region_bbox[src_index]
+ iner_block_idxes, split_block_idxes = [], []
+ for block_idx in block_idxes_set:
+ overlap_ratio = calculate_overlap_ratio(
+ tmp_region_bbox, block_bboxes[block_idx], mode="small"
+ )
+ if overlap_ratio > REGION_SETTINGS.get(
+ "match_block_overlap_ratio_threshold", 0.8
+ ):
+ iner_block_idxes.append(block_idx)
+ elif overlap_ratio > REGION_SETTINGS.get(
+ "split_block_overlap_ratio_threshold", 0.4
+ ):
+ split_block_idxes.append(block_idx)
+
+ if len(iner_block_idxes) > 0:
+ if len(split_block_idxes) > 0:
+ for split_block_idx in split_block_idxes:
+ split_block_bbox = block_bboxes[split_block_idx]
+ x1, y1, x2, y2 = tmp_region_bbox
+ x1_prime, y1_prime, x2_prime, y2_prime = split_block_bbox
+ edge_distance_list = [
+ (x1_prime - x1) / image_width,
+ (y1_prime - y1) / image_height,
+ (x2 - x2_prime) / image_width,
+ (y2 - y2_prime) / image_height,
+ ]
+ max_distance = max(edge_distance_list)
+ src_index = edge_distance_list.index(max_distance)
+ dst_index = index_conversion_map[src_index]
+ tmp_region_bbox[dst_index] = split_block_bbox[src_index]
+ tmp_region_bbox, iner_idxes = shrink_supplement_region_bbox(
+ tmp_region_bbox,
+ ref_region_bbox,
+ image_width,
+ image_height,
+ iner_block_idxes,
+ block_bboxes,
+ )
+ if len(iner_idxes) == 0:
+ continue
+ matched_bboxes = [block_bboxes[idx] for idx in iner_block_idxes]
+ supplement_region_bbox = calculate_minimum_enclosing_bbox(matched_bboxes)
+ break
+ else:
+ edge_distance_list_tmp = [
+ x for x in edge_distance_list_tmp if x != min_distance
+ ]
+ min_distance = min(edge_distance_list_tmp)
+ src_index = index_conversion_map[edge_distance_list.index(min_distance)]
+ return supplement_region_bbox, iner_block_idxes
def update_region_box(bbox, region_box):
@@ -618,51 +815,69 @@ def convert_formula_res_to_ocr_format(formula_res_list: List, ocr_res: dict):
(x_min, y_max),
]
ocr_res["dt_polys"].append(poly_points)
- ocr_res["rec_texts"].append(f"${formula_res['rec_formula']}$")
- ocr_res["rec_boxes"] = np.vstack(
- (ocr_res["rec_boxes"], [formula_res["dt_polys"]])
- )
+ ocr_res["rec_texts"].append(f"{formula_res['rec_formula']}")
+ if ocr_res["rec_boxes"].size == 0:
+ ocr_res["rec_boxes"] = np.array(formula_res["dt_polys"])
+ else:
+ ocr_res["rec_boxes"] = np.vstack(
+ (ocr_res["rec_boxes"], [formula_res["dt_polys"]])
+ )
ocr_res["rec_labels"].append("formula")
ocr_res["rec_polys"].append(poly_points)
ocr_res["rec_scores"].append(1)
def caculate_bbox_area(bbox):
- x1, y1, x2, y2 = bbox
+ x1, y1, x2, y2 = map(float, bbox)
area = abs((x2 - x1) * (y2 - y1))
return area
-def get_show_color(label: str) -> Tuple:
- label_colors = {
- # Medium Blue (from 'titles_list')
- "paragraph_title": (102, 102, 255, 100),
- "doc_title": (255, 248, 220, 100), # Cornsilk
- # Light Yellow (from 'tables_caption_list')
- "table_title": (255, 255, 102, 100),
- # Sky Blue (from 'imgs_caption_list')
- "figure_title": (102, 178, 255, 100),
- "chart_title": (221, 160, 221, 100), # Plum
- "vision_footnote": (144, 238, 144, 100), # Light Green
- # Deep Purple (from 'texts_list')
- "text": (153, 0, 76, 100),
- # Bright Green (from 'interequations_list')
- "formula": (0, 255, 0, 100),
- "abstract": (255, 239, 213, 100), # Papaya Whip
- # Medium Green (from 'lists_list' and 'indexs_list')
- "content": (40, 169, 92, 100),
- # Neutral Gray (from 'dropped_bbox_list')
- "seal": (158, 158, 158, 100),
- # Olive Yellow (from 'tables_body_list')
- "table": (204, 204, 0, 100),
- # Bright Green (from 'imgs_body_list')
- "image": (153, 255, 51, 100),
- # Bright Green (from 'imgs_body_list')
- "figure": (153, 255, 51, 100),
- "chart": (216, 191, 216, 100), # Thistle
- # Pale Yellow-Green (from 'tables_footnote_list')
- "reference": (229, 255, 204, 100),
- "algorithm": (255, 250, 240, 100), # Floral White
- }
+def get_show_color(label: str, order_label=False) -> Tuple:
+ if order_label:
+ label_colors = {
+ "doc_title": (255, 248, 220, 100), # Cornsilk
+ "doc_title_text": (255, 239, 213, 100),
+ "paragraph_title": (102, 102, 255, 100),
+ "sub_paragraph_title": (102, 178, 255, 100),
+ "vision": (153, 255, 51, 100),
+ "vision_title": (144, 238, 144, 100), # Light Green
+ "vision_footnote": (144, 238, 144, 100), # Light Green
+ "normal_text": (153, 0, 76, 100),
+ "cross_layout": (53, 218, 207, 100), # Turquoise (approx.)
+ "cross_reference": (221, 160, 221, 100), # Plum
+ }
+ else:
+ label_colors = {
+ # Medium Blue (from 'titles_list')
+ "paragraph_title": (102, 102, 255, 100),
+ "doc_title": (255, 248, 220, 100), # Cornsilk
+ # Light Yellow (from 'tables_caption_list')
+ "table_title": (255, 255, 102, 100),
+ # Sky Blue (from 'imgs_caption_list')
+ "figure_title": (102, 178, 255, 100),
+ "chart_title": (221, 160, 221, 100), # Plum
+ "vision_footnote": (144, 238, 144, 100), # Light Green
+ # Deep Purple (from 'texts_list')
+ "text": (153, 0, 76, 100),
+ # Bright Green (from 'interequations_list')
+ "formula": (0, 255, 0, 100),
+ "abstract": (255, 239, 213, 100), # Papaya Whip
+ # Medium Green (from 'lists_list' and 'indexs_list')
+ "content": (40, 169, 92, 100),
+ # Neutral Gray (from 'dropped_bbox_list')
+ "seal": (158, 158, 158, 100),
+ # Olive Yellow (from 'tables_body_list')
+ "table": (204, 204, 0, 100),
+ # Bright Green (from 'imgs_body_list')
+ "image": (153, 255, 51, 100),
+ # Bright Green (from 'imgs_body_list')
+ "figure": (153, 255, 51, 100),
+ "chart": (216, 191, 216, 100), # Thistle
+ # Pale Yellow-Green (from 'tables_footnote_list')
+ "reference": (229, 255, 204, 100),
+ # "reference_content": (229, 255, 204, 100),
+ "algorithm": (255, 250, 240, 100), # Floral White
+ }
default_color = (158, 158, 158, 100)
return label_colors.get(label, default_color)
diff --git a/paddlex/inference/pipelines/layout_parsing/xycut_enhanced/utils.py b/paddlex/inference/pipelines/layout_parsing/xycut_enhanced/utils.py
index 0f10610809..1f333db496 100644
--- a/paddlex/inference/pipelines/layout_parsing/xycut_enhanced/utils.py
+++ b/paddlex/inference/pipelines/layout_parsing/xycut_enhanced/utils.py
@@ -12,77 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Dict, List, Tuple, Union
+from typing import List, Tuple
import numpy as np
-from ..result_v2 import LayoutParsingBlock
-
-
-def calculate_projection_iou(
- bbox1: List[float], bbox2: List[float], direction: str = "horizontal"
-) -> float:
- """
- Calculate the IoU of lines between two bounding boxes.
-
- Args:
- bbox1 (List[float]): First bounding box [x_min, y_min, x_max, y_max].
- bbox2 (List[float]): Second bounding box [x_min, y_min, x_max, y_max].
- direction (str): direction of the projection, "horizontal" or "vertical".
-
- Returns:
- float: Line IoU. Returns 0 if there is no overlap.
- """
- start_index, end_index = 1, 3
- if direction == "horizontal":
- start_index, end_index = 0, 2
-
- intersection_start = max(bbox1[start_index], bbox2[start_index])
- intersection_end = min(bbox1[end_index], bbox2[end_index])
- overlap = intersection_end - intersection_start
- if overlap <= 0:
- return 0
- union_width = max(bbox1[end_index], bbox2[end_index]) - min(
- bbox1[start_index], bbox2[start_index]
- )
-
- return overlap / union_width if union_width > 0 else 0.0
-
-
-def calculate_iou(
- bbox1: Union[list, tuple],
- bbox2: Union[list, tuple],
-) -> float:
- """
- Calculate the Intersection over Union (IoU) of two bounding boxes.
-
- Parameters:
- bbox1 (list or tuple): The first bounding box, format [x_min, y_min, x_max, y_max]
- bbox2 (list or tuple): The second bounding box, format [x_min, y_min, x_max, y_max]
-
- Returns:
- float: The IoU value between the two bounding boxes
- """
-
- x_min_inter = max(bbox1[0], bbox2[0])
- y_min_inter = max(bbox1[1], bbox2[1])
- x_max_inter = min(bbox1[2], bbox2[2])
- y_max_inter = min(bbox1[3], bbox2[3])
-
- inter_width = max(0, x_max_inter - x_min_inter)
- inter_height = max(0, y_max_inter - y_min_inter)
-
- inter_area = inter_width * inter_height
-
- bbox1_area = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
- bbox2_area = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
-
- union_area = bbox1_area + bbox2_area - inter_area
-
- if union_area == 0:
- return 0.0
-
- return inter_area / union_area
+from ..result_v2 import LayoutParsingBlock, LayoutParsingRegion
+from ..setting import BLOCK_LABEL_MAP, XYCUT_SETTINGS
+from ..utils import calculate_projection_overlap_ratio
def get_nearest_edge_distance(
@@ -96,7 +32,7 @@ def get_nearest_edge_distance(
Args:
bbox1 (list): The bounding box coordinates [x1, y1, x2, y2] of the input object.
bbox2 (list): The bounding box coordinates [x1', y1', x2', y2'] of the object to match against.
- weight (list, optional): Directional weights for the edge distances [left, right, up, down]. Defaults to [1, 1, 1, 1].
+ weight (list, optional): directional weights for the edge distances [left, right, up, down]. Defaults to [1, 1, 1, 1].
Returns:
float: The calculated minimum edge distance between the bounding boxes.
@@ -104,8 +40,8 @@ def get_nearest_edge_distance(
x1, y1, x2, y2 = bbox1
x1_prime, y1_prime, x2_prime, y2_prime = bbox2
min_x_distance, min_y_distance = 0, 0
- horizontal_iou = calculate_projection_iou(bbox1, bbox2, "horizontal")
- vertical_iou = calculate_projection_iou(bbox1, bbox2, "vertical")
+ horizontal_iou = calculate_projection_overlap_ratio(bbox1, bbox2, "horizontal")
+ vertical_iou = calculate_projection_overlap_ratio(bbox1, bbox2, "vertical")
if horizontal_iou > 0 and vertical_iou > 0:
return 0.0
if horizontal_iou == 0:
@@ -319,8 +255,7 @@ def recursive_xy_cut(
def reference_insert(
block: LayoutParsingBlock,
sorted_blocks: List[LayoutParsingBlock],
- config: Dict,
- median_width: float = 0.0,
+ **kwargs,
):
"""
Insert reference block into sorted blocks based on the distance between the block and the nearest sorted block.
@@ -350,8 +285,7 @@ def reference_insert(
def manhattan_insert(
block: LayoutParsingBlock,
sorted_blocks: List[LayoutParsingBlock],
- config: Dict,
- median_width: float = 0.0,
+ **kwargs,
):
"""
Insert a block into a sorted list of blocks based on the Manhattan distance between the block and the nearest sorted block.
@@ -380,8 +314,7 @@ def manhattan_insert(
def weighted_distance_insert(
block: LayoutParsingBlock,
sorted_blocks: List[LayoutParsingBlock],
- config: Dict,
- median_width: float = 0.0,
+ region: LayoutParsingRegion,
):
"""
Insert a block into a sorted list of blocks based on the weighted distance between the block and the nearest sorted block.
@@ -395,11 +328,8 @@ def weighted_distance_insert(
Returns:
sorted_blocks: The updated sorted blocks after insertion.
"""
- doc_title_labels = config.get("doc_title_labels", [])
- paragraph_title_labels = config.get("paragraph_title_labels", [])
- vision_labels = config.get("vision_labels", [])
- xy_cut_block_labels = config.get("xy_cut_block_labels", [])
- tolerance_len = config.get("tolerance_len", 2)
+
+ tolerance_len = XYCUT_SETTINGS["edge_distance_compare_tolerance_len"]
x1, y1, x2, y2 = block.bbox
min_weighted_distance, min_edge_distance, min_up_edge_distance = (
float("inf"),
@@ -412,36 +342,43 @@ def weighted_distance_insert(
x1_prime, y1_prime, x2_prime, y2_prime = sorted_block.bbox
# Calculate edge distance
- weight = _get_weights(block.region_label, block.direction)
+ weight = _get_weights(block.order_label, block.direction)
edge_distance = get_nearest_edge_distance(block.bbox, sorted_block.bbox, weight)
- if block.label in doc_title_labels:
- disperse = max(1, median_width)
+ if block.label in BLOCK_LABEL_MAP["doc_title_labels"]:
+ disperse = max(1, region.text_line_width)
tolerance_len = max(tolerance_len, disperse)
if block.label == "abstract":
tolerance_len *= 2
edge_distance = max(0.1, edge_distance) * 10
# Calculate up edge distances
- up_edge_distance = y1_prime
- left_edge_distance = x1_prime
+ up_edge_distance = y1_prime if region.direction == "horizontal" else -x2_prime
+ left_edge_distance = x1_prime if region.direction == "horizontal" else y1_prime
+ is_below_sorted_block = (
+ y2_prime < y1 if region.direction == "horizontal" else x1_prime > x2
+ )
+
if (
- block.label in xy_cut_block_labels
- or block.label in doc_title_labels
- or block.label in paragraph_title_labels
- or block.label in vision_labels
- ) and y1 > y2_prime:
- up_edge_distance = -y2_prime
- left_edge_distance = -x2_prime
+ block.label not in BLOCK_LABEL_MAP["unordered_labels"]
+ or block.label in BLOCK_LABEL_MAP["doc_title_labels"]
+ or block.label in BLOCK_LABEL_MAP["paragraph_title_labels"]
+ or block.label in BLOCK_LABEL_MAP["vision_labels"]
+ ) and is_below_sorted_block:
+ up_edge_distance = -up_edge_distance
+ left_edge_distance = -left_edge_distance
if abs(min_up_edge_distance - up_edge_distance) <= tolerance_len:
up_edge_distance = min_up_edge_distance
# Calculate weighted distance
weighted_distance = (
- +edge_distance * config.get("edge_weight", 10**4)
- + up_edge_distance * config.get("up_edge_weight", 1)
- + left_edge_distance * config.get("left_edge_weight", 0.0001)
+ +edge_distance
+ * XYCUT_SETTINGS["distance_weight_map"].get("edge_weight", 10**4)
+ + up_edge_distance
+ * XYCUT_SETTINGS["distance_weight_map"].get("up_edge_weight", 1)
+ + left_edge_distance
+ * XYCUT_SETTINGS["distance_weight_map"].get("left_edge_weight", 0.0001)
)
min_edge_distance = min(edge_distance, min_edge_distance)
@@ -490,7 +427,7 @@ def sort_child_blocks(blocks, direction="horizontal") -> List[LayoutParsingBlock
Args:
blocks: A list of LayoutParsingBlock objects representing the child blocks.
- direction: Orientation of the blocks ('horizontal' or 'vertical'). Default is 'horizontal'.
+ direction: direction of the blocks ('horizontal' or 'vertical'). Default is 'horizontal'.
Returns:
sorted_blocks: A sorted list of LayoutParsingBlock objects.
"""
@@ -518,7 +455,7 @@ def sort_child_blocks(blocks, direction="horizontal") -> List[LayoutParsingBlock
def _get_weights(label, dircetion="horizontal"):
- """Define weights based on the label and orientation."""
+ """Define weights based on the label and direction."""
if label == "doc_title":
return (
[1, 0.1, 0.1, 1] if dircetion == "horizontal" else [0.2, 0.1, 1, 1]
@@ -583,6 +520,26 @@ def sort_blocks(blocks, median_width=None, reverse=False):
return blocks
+def sort_normal_blocks(blocks, text_line_height, text_line_width, region_direction):
+ if region_direction == "horizontal":
+ blocks.sort(
+ key=lambda x: (
+ x.bbox[1] // text_line_height,
+ x.bbox[0] // text_line_width,
+ x.bbox[1] ** 2 + x.bbox[0] ** 2,
+ ),
+ )
+ else:
+ blocks.sort(
+ key=lambda x: (
+ -x.bbox[0] // text_line_width,
+ x.bbox[1] // text_line_height,
+ -(x.bbox[2] ** 2 + x.bbox[1] ** 2),
+ ),
+ )
+ return blocks
+
+
def get_cut_blocks(
blocks, cut_direction, cut_coordinates, overall_region_box, mask_labels=[]
):
@@ -604,8 +561,7 @@ def get_cut_blocks(
# 0: horizontal, 1: vertical
cut_aixis = 0 if cut_direction == "horizontal" else 1
blocks.sort(key=lambda x: x.bbox[cut_aixis + 2])
- overall_max_axis_coordinate = overall_region_box[cut_aixis + 2]
- cut_coordinates.append(overall_max_axis_coordinate)
+ cut_coordinates.append(float("inf"))
cut_coordinates = list(set(cut_coordinates))
cut_coordinates.sort()
@@ -618,7 +574,7 @@ def get_cut_blocks(
block = blocks[block_idx]
if block.bbox[cut_aixis + 2] > cut_coordinate:
break
- elif block.region_label not in mask_labels:
+ elif block.order_label not in mask_labels:
group_blocks.append(block)
block_idx += 1
cut_idx = block_idx
@@ -628,44 +584,64 @@ def get_cut_blocks(
return cuted_list
-def split_sub_region_blocks(
- blocks: List[LayoutParsingBlock],
- config: Dict,
+def add_split_block(
+ blocks: List[LayoutParsingBlock], region_bbox: List[int]
+) -> List[LayoutParsingBlock]:
+ block_bboxes = np.array([block.bbox for block in blocks])
+ discontinuous = calculate_discontinuous_projection(
+ block_bboxes, direction="vertical"
+ )
+ current_interval = discontinuous[0]
+ for interval in discontinuous[1:]:
+ gap_len = interval[0] - current_interval[1]
+ if gap_len > 40:
+ x1, _, x2, __ = region_bbox
+ y1 = current_interval[1] + 5
+ y2 = interval[0] - 5
+ bbox = [x1, y1, x2, y2]
+ split_block = LayoutParsingBlock(label="split", bbox=bbox)
+ blocks.append(split_block)
+ current_interval = interval
+
+
+def get_nearest_blocks(
+ block: LayoutParsingBlock,
+ ref_blocks: List[LayoutParsingBlock],
+ overlap_threshold,
+ direction="horizontal",
) -> List:
"""
- Split blocks into sub regions based on the all layout region bbox.
-
+ Collect the reference blocks whose projection overlaps the current block, split into preceding and following blocks.
Args:
- blocks (List[LayoutParsingBlock]): A list of blocks.
- config (Dict): Configuration dictionary.
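+ block (LayoutParsingBlock): The current block.
+ ref_blocks (List[LayoutParsingBlock]): Candidate reference blocks; a candidate with the same index as `block` is skipped.
+ overlap_threshold: A reference block is kept only if its projection overlap ratio with `block` is greater than this value.
+ direction (str): Projection direction, "horizontal" or "vertical". Blocks are ordered by bbox[1] when "horizontal", otherwise by bbox[0].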
Returns:
- List: A list of lists of blocks, each representing a sub region.
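+ List[LayoutParsingBlock]: Kept reference blocks whose ordering coordinate is <= that of `block`, sorted in descending order of that coordinate.
+ List[LayoutParsingBlock]: Kept reference blocks whose ordering coordinate is greater than that of `block`, sorted in ascending order of that coordinate.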
"""
+ prev_blocks: List[LayoutParsingBlock] = []
+ post_blocks: List[LayoutParsingBlock] = []
+ sort_index = 1 if direction == "horizontal" else 0
+ for ref_block in ref_blocks:
+ if ref_block.index == block.index:
+ continue
+ overlap_ratio = calculate_projection_overlap_ratio(
+ block.bbox, ref_block.bbox, direction, mode="small"
+ )
+ if overlap_ratio > overlap_threshold:
+ if ref_block.bbox[sort_index] <= block.bbox[sort_index]:
+ prev_blocks.append(ref_block)
+ else:
+ post_blocks.append(ref_block)
- region_bbox = config.get("all_layout_region_box", None)
- x1, y1, x2, y2 = region_bbox
- region_width = x2 - x1
- region_height = y2 - y1
-
- if region_width < region_height:
- return [(blocks, region_bbox)]
-
- all_boxes = np.array([block.bbox for block in blocks])
- discontinuous = calculate_discontinuous_projection(all_boxes, direction="vertical")
- if len(discontinuous) > 1:
- cut_coordinates = []
- region_boxes = []
- current_interval = discontinuous[0]
- for x1, x2 in discontinuous[1:]:
- if x1 - current_interval[1] > 100:
- cut_coordinates.extend([x1, x2])
- region_boxes.append([x1, y1, x2, y2])
- current_interval = [x1, x2]
- region_blocks = get_cut_blocks(blocks, "vertical", cut_coordinates, region_bbox)
-
- return [region_info for region_info in zip(region_blocks, region_boxes)]
- else:
- return [(blocks, region_bbox)]
+ if prev_blocks:
+ prev_blocks.sort(key=lambda x: x.bbox[sort_index], reverse=True)
+ if post_blocks:
+ post_blocks.sort(key=lambda x: x.bbox[sort_index])
+
+ return prev_blocks, post_blocks
def get_adjacent_blocks_by_direction(
@@ -701,9 +677,9 @@ def get_adjacent_blocks_by_direction(
for ref_block_idx in ref_block_idxes:
ref_block = blocks[ref_block_idx]
ref_block_direction = ref_block.direction
- if ref_block.region_label in child_labels:
+ if ref_block.order_label in child_labels:
continue
- match_block_iou = calculate_projection_iou(
+ match_block_iou = calculate_projection_overlap_ratio(
block.bbox,
ref_block.bbox,
ref_block_direction,
@@ -711,7 +687,7 @@ def get_adjacent_blocks_by_direction(
child_match_distance_tolerance_len = block.short_side_length / 10
- if block.region_label == "vision":
+ if block.order_label == "vision":
if ref_block.num_of_lines == 1:
gap_tolerance_len = ref_block.short_side_length * 2
else:
@@ -770,11 +746,8 @@ def get_adjacent_blocks_by_direction(
def update_doc_title_child_blocks(
- blocks: List[LayoutParsingBlock],
block: LayoutParsingBlock,
- prev_idx: int,
- post_idx: int,
- config: dict,
+ region: LayoutParsingRegion,
) -> None:
"""
Update the child blocks of a document title block.
@@ -785,6 +758,7 @@ def update_doc_title_child_blocks(
3. Their short side length should be less than 80% of the parent's short side length.
4. Their long side length should be less than 150% of the parent's long side length.
5. The child block must be text block.
+ 6. The nearest edge distance to the parent should be less than twice the child block's text line height.
Args:
blocks (List[LayoutParsingBlock]): overall blocks.
@@ -797,10 +771,22 @@ def update_doc_title_child_blocks(
None
"""
- for idx in [prev_idx, post_idx]:
- if idx is None:
+ ref_blocks = [region.block_map[idx] for idx in region.normal_text_block_idxes]
+ overlap_threshold = XYCUT_SETTINGS["child_block_overlap_ratio_threshold"]
+ prev_blocks, post_blocks = get_nearest_blocks(
+ block, ref_blocks, overlap_threshold, block.direction
+ )
+ prev_block = None
+ post_block = None
+
+ if prev_blocks:
+ prev_block = prev_blocks[0]
+ if post_blocks:
+ post_block = post_blocks[0]
+
+ for ref_block in [prev_block, post_block]:
+ if ref_block is None:
continue
- ref_block = blocks[idx]
with_seem_direction = ref_block.direction == block.direction
short_side_length_condition = (
@@ -812,23 +798,24 @@ def update_doc_title_child_blocks(
or ref_block.long_side_length > 1.5 * block.long_side_length
)
+ nearest_edge_distance = get_nearest_edge_distance(block.bbox, ref_block.bbox)
+
if (
with_seem_direction
+ and ref_block.label in BLOCK_LABEL_MAP["text_labels"]
and short_side_length_condition
and long_side_length_condition
and ref_block.num_of_lines < 3
+ and nearest_edge_distance < ref_block.text_line_height * 2
):
- ref_block.region_label = "doc_title_text"
+ ref_block.order_label = "doc_title_text"
block.append_child_block(ref_block)
- config["text_block_idxes"].remove(idx)
+ region.normal_text_block_idxes.remove(ref_block.index)
def update_paragraph_title_child_blocks(
- blocks: List[LayoutParsingBlock],
block: LayoutParsingBlock,
- prev_idx: int,
- post_idx: int,
- config: dict,
+ region: LayoutParsingRegion,
) -> None:
"""
Update the child blocks of a paragraph title block.
@@ -849,25 +836,39 @@ def update_paragraph_title_child_blocks(
None
"""
- paragraph_title_labels = config.get("paragraph_title_labels", [])
- for idx in [prev_idx, post_idx]:
- if idx is None:
- continue
- ref_block = blocks[idx]
- with_seem_direction = ref_block.direction == block.direction
- if with_seem_direction and ref_block.label in paragraph_title_labels:
- ref_block.region_label = "sub_paragraph_title"
- block.append_child_block(ref_block)
- config["paragraph_title_block_idxes"].remove(idx)
+ if block.order_label == "sub_paragraph_title":
+ return
+ ref_blocks = [
+ region.block_map[idx]
+ for idx in region.paragraph_title_block_idxes + region.normal_text_block_idxes
+ ]
+ overlap_threshold = XYCUT_SETTINGS["child_block_overlap_ratio_threshold"]
+ prev_blocks, post_blocks = get_nearest_blocks(
+ block, ref_blocks, overlap_threshold, block.direction
+ )
+ for ref_blocks in [prev_blocks, post_blocks]:
+ for ref_block in ref_blocks:
+ if ref_block.label not in BLOCK_LABEL_MAP["paragraph_title_labels"]:
+ break
+ min_text_line_height = min(
+ block.text_line_height, ref_block.text_line_height
+ )
+ nearest_edge_distance = get_nearest_edge_distance(
+ block.bbox, ref_block.bbox
+ )
+ with_seem_direction = ref_block.direction == block.direction
+ if (
+ with_seem_direction
+ and nearest_edge_distance <= min_text_line_height * 1.5
+ ):
+ ref_block.order_label = "sub_paragraph_title"
+ block.append_child_block(ref_block)
+ region.paragraph_title_block_idxes.remove(ref_block.index)
def update_vision_child_blocks(
- blocks: List[LayoutParsingBlock],
block: LayoutParsingBlock,
- ref_block_idxes: List[int],
- prev_idx: int,
- post_idx: int,
- config: dict,
+ region: LayoutParsingRegion,
) -> None:
"""
Update the child blocks of a paragraph title block.
@@ -896,83 +897,185 @@ def update_vision_child_blocks(
None
"""
- vision_title_labels = config.get("vision_title_labels", [])
- text_labels = config.get("text_labels", [])
- for idx in [prev_idx, post_idx]:
- if idx is None:
- continue
- ref_block = blocks[idx]
- nearest_edge_distance = get_nearest_edge_distance(block.bbox, ref_block.bbox)
- block_center = block.get_centroid()
- ref_block_center = ref_block.get_centroid()
- if ref_block.label in vision_title_labels and nearest_edge_distance <= min(
- block.height * 0.5, ref_block.height * 2
- ):
- ref_block.region_label = "vision_title"
- block.append_child_block(ref_block)
- config["vision_title_block_idxes"].remove(idx)
- elif (
- nearest_edge_distance <= 15
- and ref_block.short_side_length < block.short_side_length
- and ref_block.long_side_length < 0.5 * block.long_side_length
- and ref_block.direction == block.direction
- and (
- abs(block_center[0] - ref_block_center[0]) < 10
- or (
- block.bbox[0] - ref_block.bbox[0] < 10
- and ref_block.num_of_lines == 1
- )
- or (
- block.bbox[2] - ref_block.bbox[2] < 10
- and ref_block.num_of_lines == 1
- )
+ ref_blocks = [
+ region.block_map[idx]
+ for idx in region.normal_text_block_idxes + region.vision_title_block_idxes
+ ]
+ overlap_threshold = XYCUT_SETTINGS["child_block_overlap_ratio_threshold"]
+ has_vision_footnote = False
+ has_vision_title = False
+ for direction in [block.direction, block.secondary_direction]:
+ prev_blocks, post_blocks = get_nearest_blocks(
+ block, ref_blocks, overlap_threshold, direction
+ )
+ for ref_block in prev_blocks:
+ if (
+ ref_block.label
+ not in BLOCK_LABEL_MAP["text_labels"]
+ + BLOCK_LABEL_MAP["vision_title_labels"]
+ ):
+ break
+ nearest_edge_distance = get_nearest_edge_distance(
+ block.bbox, ref_block.bbox
)
- ):
- has_vision_footnote = False
- if len(block.child_blocks) > 0:
- for child_block in block.child_blocks:
- if child_block.label in text_labels:
- has_vision_footnote = True
- if not has_vision_footnote:
- ref_block.region_label = "vision_footnote"
+ block_center = block.get_centroid()
+ ref_block_center = ref_block.get_centroid()
+ if ref_block.label in BLOCK_LABEL_MAP["vision_title_labels"]:
+ has_vision_title = True
+ ref_block.order_label = "vision_title"
block.append_child_block(ref_block)
- config["text_block_idxes"].remove(idx)
+ region.vision_title_block_idxes.remove(ref_block.index)
+ if ref_block.label in BLOCK_LABEL_MAP["text_labels"]:
+ if (
+ not has_vision_footnote
+ and nearest_edge_distance <= block.text_line_height * 2
+ and ref_block.short_side_length < block.short_side_length
+ and ref_block.long_side_length < 0.5 * block.long_side_length
+ and ref_block.direction == block.direction
+ and (
+ abs(block_center[0] - ref_block_center[0]) < 10
+ or (
+ block.bbox[0] - ref_block.bbox[0] < 10
+ and ref_block.num_of_lines == 1
+ )
+ or (
+ block.bbox[2] - ref_block.bbox[2] < 10
+ and ref_block.num_of_lines == 1
+ )
+ )
+ ):
+ has_vision_footnote = True
+ ref_block.order_label = "vision_footnote"
+ block.append_child_block(ref_block)
+ region.normal_text_block_idxes.remove(ref_block.index)
+ break
+ for ref_block in post_blocks:
+ if (
+ has_vision_footnote
+ and ref_block.label in BLOCK_LABEL_MAP["text_labels"]
+ ):
+ break
+ nearest_edge_distance = get_nearest_edge_distance(
+ block.bbox, ref_block.bbox
+ )
+ block_center = block.get_centroid()
+ ref_block_center = ref_block.get_centroid()
+ if ref_block.label in BLOCK_LABEL_MAP["vision_title_labels"]:
+ has_vision_title = True
+ ref_block.order_label = "vision_title"
+ block.append_child_block(ref_block)
+ region.vision_title_block_idxes.remove(ref_block.index)
+ if ref_block.label in BLOCK_LABEL_MAP["text_labels"]:
+ if (
+ not has_vision_footnote
+ and nearest_edge_distance <= block.text_line_height * 2
+ and ref_block.short_side_length < block.short_side_length
+ and ref_block.long_side_length < 0.5 * block.long_side_length
+ and ref_block.direction == block.direction
+ and (
+ abs(block_center[0] - ref_block_center[0]) < 10
+ or (
+ block.bbox[0] - ref_block.bbox[0] < 10
+ and ref_block.num_of_lines == 1
+ )
+ or (
+ block.bbox[2] - ref_block.bbox[2] < 10
+ and ref_block.num_of_lines == 1
+ )
+ )
+ ):
+ has_vision_footnote = True
+ ref_block.order_label = "vision_footnote"
+ block.append_child_block(ref_block)
+ region.normal_text_block_idxes.remove(ref_block.index)
+ break
+ if has_vision_title:
+ break
-def calculate_discontinuous_projection(boxes, direction="horizontal") -> List:
+def calculate_discontinuous_projection(
+ boxes, direction="horizontal", return_num=False
+) -> List:
"""
Calculate the discontinuous projection of boxes along the specified direction.
Args:
boxes (ndarray): Array of bounding boxes represented by [[x_min, y_min, x_max, y_max]].
- direction (str): Direction along which to perform the projection ('horizontal' or 'vertical').
+ direction (str): direction along which to perform the projection ('horizontal' or 'vertical').
Returns:
list: List of tuples representing the merged intervals.
"""
+ boxes = np.array(boxes)
if direction == "horizontal":
intervals = boxes[:, [0, 2]]
elif direction == "vertical":
intervals = boxes[:, [1, 3]]
else:
- raise ValueError("Direction must be 'horizontal' or 'vertical'")
+ raise ValueError("direction must be 'horizontal' or 'vertical'")
intervals = intervals[np.argsort(intervals[:, 0])]
merged_intervals = []
+ num = 1
current_start, current_end = intervals[0]
+ num_list = []
for start, end in intervals[1:]:
if start <= current_end:
+ num += 1
current_end = max(current_end, end)
else:
+ num_list.append(num)
merged_intervals.append((current_start, current_end))
+ num = 1
current_start, current_end = start, end
+ num_list.append(num)
merged_intervals.append((current_start, current_end))
+ if return_num:
+ return merged_intervals, num_list
return merged_intervals
+def is_projection_consistent(blocks, intervals, direction="horizontal"):
+
+ for interval in intervals:
+ if direction == "horizontal":
+ start_index, stop_index = 0, 2
+ interval_box = [interval[0], 0, interval[1], 1]
+ else:
+ start_index, stop_index = 1, 3
+ interval_box = [0, interval[0], 1, interval[1]]
+ same_interval_bboxes = []
+ for block in blocks:
+ overlap_ratio = calculate_projection_overlap_ratio(
+ interval_box, block.bbox, direction=direction
+ )
+ if overlap_ratio > 0 and block.label in BLOCK_LABEL_MAP["text_labels"]:
+ same_interval_bboxes.append(block.bbox)
+ start_coordinates = [bbox[start_index] for bbox in same_interval_bboxes]
+ if start_coordinates:
+ min_start_coordinate = min(start_coordinates)
+ max_start_coordinate = max(start_coordinates)
+ is_start_consistent = (
+ False
+ if max_start_coordinate - min_start_coordinate
+ >= abs(interval[0] - interval[1]) * 0.05
+ else True
+ )
+ stop_coordinates = [bbox[stop_index] for bbox in same_interval_bboxes]
+ min_stop_coordinate = min(stop_coordinates)
+ max_stop_coordinate = max(stop_coordinates)
+ if (
+ max_stop_coordinate - min_stop_coordinate
+ >= abs(interval[0] - interval[1]) * 0.05
+ and is_start_consistent
+ ):
+ return False
+ return True
+
+
def shrink_overlapping_boxes(
boxes, direction="horizontal", min_threshold=0, max_threshold=0.1
) -> List:
@@ -981,7 +1084,7 @@ def shrink_overlapping_boxes(
Args:
boxes (ndarray): Array of bounding boxes represented by [[x_min, y_min, x_max, y_max]].
- direction (str): Direction along which to perform the shrinking ('horizontal' or 'vertical').
+ direction (str): direction along which to perform the shrinking ('horizontal' or 'vertical').
min_threshold (float): Minimum threshold for shrinking. Default is 0.
max_threshold (float): Maximum threshold for shrinking. Default is 0.2.
@@ -992,10 +1095,10 @@ def shrink_overlapping_boxes(
for block in boxes[1:]:
x1, y1, x2, y2 = current_block.bbox
x1_prime, y1_prime, x2_prime, y2_prime = block.bbox
- cut_iou = calculate_projection_iou(
+ cut_iou = calculate_projection_overlap_ratio(
current_block.bbox, block.bbox, direction=direction
)
- match_iou = calculate_projection_iou(
+ match_iou = calculate_projection_overlap_ratio(
current_block.bbox,
block.bbox,
direction="horizontal" if direction == "vertical" else "vertical",
diff --git a/paddlex/inference/pipelines/layout_parsing/xycut_enhanced/xycuts.py b/paddlex/inference/pipelines/layout_parsing/xycut_enhanced/xycuts.py
index 1a5ffdd75e..1dec70a522 100644
--- a/paddlex/inference/pipelines/layout_parsing/xycut_enhanced/xycuts.py
+++ b/paddlex/inference/pipelines/layout_parsing/xycut_enhanced/xycuts.py
@@ -12,24 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, List, Tuple
+from typing import Dict, List, Tuple
import numpy as np
-from ..result_v2 import LayoutParsingBlock
+from ..result_v2 import LayoutParsingBlock, LayoutParsingRegion
+from ..setting import BLOCK_LABEL_MAP
+from ..utils import calculate_overlap_ratio, calculate_projection_overlap_ratio
from .utils import (
calculate_discontinuous_projection,
- calculate_iou,
- calculate_projection_iou,
- get_adjacent_blocks_by_direction,
get_cut_blocks,
+ get_nearest_edge_distance,
insert_child_blocks,
+ is_projection_consistent,
manhattan_insert,
recursive_xy_cut,
recursive_yx_cut,
reference_insert,
shrink_overlapping_boxes,
- sort_blocks,
+ sort_normal_blocks,
update_doc_title_child_blocks,
update_paragraph_title_child_blocks,
update_vision_child_blocks,
@@ -38,8 +39,7 @@
def pre_process(
- blocks: List[LayoutParsingBlock],
- config: Dict,
+ region: LayoutParsingRegion,
) -> List:
"""
Preprocess the layout for sorting purposes.
@@ -49,118 +49,116 @@ def pre_process(
2. Match the blocks with their children.
Args:
- blocks (List[LayoutParsingBlock]): A list of LayoutParsingBlock objects representing the layout.
- config (Dict): Configuration parameters that include settings for pre-cutting and sorting.
+ region: LayoutParsingRegion, the layout region to be pre-processed.
Returns:
List: A list of pre-cutted layout blocks list.
"""
- region_bbox = config.get("all_layout_region_box", None)
- region_x_center = (region_bbox[0] + region_bbox[2]) / 2
- region_y_center = (region_bbox[1] + region_bbox[3]) / 2
-
- header_block_idxes = config.get("header_block_idxes", [])
- header_blocks = []
- for idx in header_block_idxes:
- blocks[idx].region_label = "header"
- header_blocks.append(blocks[idx])
-
- unordered_block_idxes = config.get("unordered_block_idxes", [])
- unordered_blocks = []
- for idx in unordered_block_idxes:
- blocks[idx].region_label = "unordered"
- unordered_blocks.append(blocks[idx])
-
- footer_block_idxes = config.get("footer_block_idxes", [])
- footer_blocks = []
- for idx in footer_block_idxes:
- blocks[idx].region_label = "footer"
- footer_blocks.append(blocks[idx])
-
- mask_labels = ["header", "unordered", "footer"]
- child_labels = [
+ mask_labels = [
+ "header",
+ "unordered",
+ "footer",
"vision_footnote",
"sub_paragraph_title",
"doc_title_text",
"vision_title",
]
pre_cut_block_idxes = []
- for block_idx, block in enumerate(blocks):
- if block.label in mask_labels:
- continue
-
- if block.region_label not in child_labels:
- update_region_label(blocks, config, block_idx)
+ block_map = region.block_map
+ blocks: List[LayoutParsingBlock] = list(block_map.values())
+ for block in blocks:
+ if block.order_label not in mask_labels:
+ update_region_label(block, region)
block_direction = block.direction
if block_direction == "horizontal":
- region_bbox_center = region_x_center
tolerance_len = block.long_side_length // 5
else:
- region_bbox_center = region_y_center
tolerance_len = block.short_side_length // 10
block_center = (block.start_coordinate + block.end_coordinate) / 2
- center_offset = abs(block_center - region_bbox_center)
+ center_offset = abs(block_center - region.direction_center_coordinate)
is_centered = center_offset <= tolerance_len
if is_centered:
- pre_cut_block_idxes.append(block_idx)
+ pre_cut_block_idxes.append(block.index)
pre_cut_list = []
- cut_direction = "vertical"
+ cut_direction = region.secondary_direction
cut_coordinates = []
discontinuous = []
- mask_labels = child_labels + mask_labels
all_boxes = np.array(
- [block.bbox for block in blocks if block.region_label not in mask_labels]
+ [block.bbox for block in blocks if block.order_label not in mask_labels]
)
+ if len(all_boxes) == 0:
+ return pre_cut_list
if pre_cut_block_idxes:
- horizontal_cut_num = 0
- for block_idx in pre_cut_block_idxes:
- block = blocks[block_idx]
- horizontal_cut_num += 1 if block.secondary_direction == "horizontal" else 0
- cut_direction = (
- "horizontal"
- if horizontal_cut_num > len(pre_cut_block_idxes) * 0.5
- else "vertical"
- )
- discontinuous = calculate_discontinuous_projection(
- all_boxes, direction=cut_direction
+ discontinuous, num_list = calculate_discontinuous_projection(
+ all_boxes, direction=cut_direction, return_num=True
)
for idx in pre_cut_block_idxes:
- block = blocks[idx]
+ block = block_map[idx]
if (
- block.region_label not in mask_labels
+ block.order_label not in mask_labels
and block.secondary_direction == cut_direction
):
if (
block.secondary_direction_start_coordinate,
block.secondary_direction_end_coordinate,
) in discontinuous:
- cut_coordinates.append(block.secondary_direction_start_coordinate)
- cut_coordinates.append(block.secondary_direction_end_coordinate)
- if not discontinuous:
- discontinuous = calculate_discontinuous_projection(
- all_boxes, direction=cut_direction
- )
- current_interval = discontinuous[0]
- for interval in discontinuous[1:]:
- gap_len = interval[0] - current_interval[1]
- if gap_len > 40:
- cut_coordinates.append(current_interval[1])
- current_interval = interval
- overall_region_box = config.get("all_layout_region_box")
+ idx = discontinuous.index(
+ (
+ block.secondary_direction_start_coordinate,
+ block.secondary_direction_end_coordinate,
+ )
+ )
+ if num_list[idx] == 1:
+ cut_coordinates.append(
+ block.secondary_direction_start_coordinate
+ )
+ cut_coordinates.append(block.secondary_direction_end_coordinate)
+ secondary_discontinuous = calculate_discontinuous_projection(
+ all_boxes, direction=region.direction
+ )
+ if len(secondary_discontinuous) == 1:
+ if not discontinuous:
+ discontinuous = calculate_discontinuous_projection(
+ all_boxes, direction=cut_direction
+ )
+ current_interval = discontinuous[0]
+ for interval in discontinuous[1:]:
+ gap_len = interval[0] - current_interval[1]
+ if gap_len >= region.text_line_height * 5:
+ cut_coordinates.append(current_interval[1])
+ elif gap_len > region.text_line_height * 2:
+ x1, _, x2, __ = region.bbox
+ y1 = current_interval[1]
+ y2 = interval[0]
+ bbox = [x1, y1, x2, y2]
+ ref_interval = interval[0] - current_interval[1]
+ ref_bboxes = []
+ for block in blocks:
+ if get_nearest_edge_distance(bbox, block.bbox) < ref_interval * 2:
+ ref_bboxes.append(block.bbox)
+ discontinuous = calculate_discontinuous_projection(
+ ref_bboxes, direction=region.direction
+ )
+ if len(discontinuous) != 2:
+ cut_coordinates.append(current_interval[1])
+ current_interval = interval
cut_list = get_cut_blocks(
- blocks, cut_direction, cut_coordinates, overall_region_box, mask_labels
+ blocks, cut_direction, cut_coordinates, region.bbox, mask_labels
)
pre_cut_list.extend(cut_list)
+ if region.direction == "vertical":
+ pre_cut_list = pre_cut_list[::-1]
- return header_blocks, pre_cut_list, footer_blocks, unordered_blocks
+ return pre_cut_list
def update_region_label(
- blocks: List[LayoutParsingBlock], config: Dict[str, Any], block_idx: int
+ block: LayoutParsingBlock,
+ region: LayoutParsingRegion,
) -> None:
"""
Update the region label of a block based on its label and match the block with its children.
@@ -173,76 +171,51 @@ def update_region_label(
Returns:
None
"""
-
- # special title block labels
- doc_title_labels = config.get("doc_title_labels", [])
- paragraph_title_labels = config.get("paragraph_title_labels", [])
- vision_labels = config.get("vision_labels", [])
-
- block = blocks[block_idx]
- if block.label in doc_title_labels:
- block.region_label = "doc_title"
- # Force the direction of vision type to be horizontal
- if block.label in vision_labels:
- block.region_label = "vision"
+ if block.label in BLOCK_LABEL_MAP["header_labels"]:
+ block.order_label = "header"
+ elif block.label in BLOCK_LABEL_MAP["doc_title_labels"]:
+ block.order_label = "doc_title"
+ elif (
+ block.label in BLOCK_LABEL_MAP["paragraph_title_labels"]
+ and block.order_label is None
+ ):
+ block.order_label = "paragraph_title"
+ elif block.label in BLOCK_LABEL_MAP["vision_labels"]:
+ block.order_label = "vision"
+ block.num_of_lines = 1
block.update_direction_info()
- # some paragraph title block may be labeled as sub_title, so we need to check if block.region_label is "other"(default).
- if block.label in paragraph_title_labels and block.region_label == "other":
- block.region_label = "paragraph_title"
+ elif block.label in BLOCK_LABEL_MAP["footer_labels"]:
+ block.order_label = "footer"
+ elif block.label in BLOCK_LABEL_MAP["unordered_labels"]:
+ block.order_label = "unordered"
+ else:
+ block.order_label = "normal_text"
# only vision and doc title block can have child block
- if block.region_label not in ["vision", "doc_title", "paragraph_title"]:
+ if block.order_label not in ["vision", "doc_title", "paragraph_title"]:
return
- iou_threshold = config.get("child_block_match_iou_threshold", 0.1)
# match doc title text block
- if block.region_label == "doc_title":
- text_block_idxes = config.get("text_block_idxes", [])
- prev_idx, post_idx = get_adjacent_blocks_by_direction(
- blocks, block_idx, text_block_idxes, iou_threshold
- )
- update_doc_title_child_blocks(blocks, block, prev_idx, post_idx, config)
+ if block.order_label == "doc_title":
+ update_doc_title_child_blocks(block, region)
# match sub title block
- elif block.region_label == "paragraph_title":
- iou_threshold = config.get("sub_title_match_iou_threshold", 0.1)
- paragraph_title_block_idxes = config.get("paragraph_title_block_idxes", [])
- text_block_idxes = config.get("text_block_idxes", [])
- megred_block_idxes = text_block_idxes + paragraph_title_block_idxes
- prev_idx, post_idx = get_adjacent_blocks_by_direction(
- blocks, block_idx, megred_block_idxes, iou_threshold
- )
- update_paragraph_title_child_blocks(blocks, block, prev_idx, post_idx, config)
- # match vision title block
- elif block.region_label == "vision":
- # for matching vision title block
- vision_title_block_idxes = config.get("vision_title_block_idxes", [])
- # for matching vision footnote block
- text_block_idxes = config.get("text_block_idxes", [])
- megred_block_idxes = text_block_idxes + vision_title_block_idxes
- # Some vision title block may be matched with multiple vision title block, so we need to try multiple times
- for i in range(3):
- prev_idx, post_idx = get_adjacent_blocks_by_direction(
- blocks, block_idx, megred_block_idxes, iou_threshold
- )
- update_vision_child_blocks(
- blocks, block, megred_block_idxes, prev_idx, post_idx, config
- )
+ elif block.order_label == "paragraph_title":
+ update_paragraph_title_child_blocks(block, region)
+ # match vision title block and vision footnote block
+ elif block.order_label == "vision":
+ update_vision_child_blocks(block, region)
def get_layout_structure(
blocks: List[LayoutParsingBlock],
- median_width: float,
- config: dict,
- threshold: float = 0.8,
+ region_direction: str,
+ region_secondary_direction: str,
) -> Tuple[List[Dict[str, any]], bool]:
"""
Determine the layout cross column of blocks.
Args:
blocks (List[Dict[str, any]]): List of block dictionaries containing 'label' and 'block_bbox'.
- median_width (float): Median width of text blocks.
- no_mask_labels (List[str]): Labels of blocks to be considered for layout analysis.
- threshold (float): Threshold for determining layout overlap.
Returns:
Tuple[List[Dict[str, any]], bool]: Updated list of blocks with layout information and a boolean
@@ -251,69 +224,86 @@ def get_layout_structure(
blocks.sort(
key=lambda x: (x.bbox[0], x.width),
)
- check_single_layout = {}
- doc_title_labels = config.get("doc_title_labels", [])
- region_box = config.get("all_layout_region_box", [0, 0, 0, 0])
+ mask_labels = ["doc_title", "cross_layout", "cross_reference"]
for block_idx, block in enumerate(blocks):
- cover_count = 0
- match_block_with_threshold_indexes = []
+ if block.order_label in mask_labels:
+ continue
for ref_idx, ref_block in enumerate(blocks):
- if block_idx == ref_idx:
+ if block_idx == ref_idx or ref_block.order_label in mask_labels:
continue
- bbox_iou = calculate_iou(block.bbox, ref_block.bbox)
+ bbox_iou = calculate_overlap_ratio(block.bbox, ref_block.bbox)
if bbox_iou > 0:
- if block.region_label == "vision" or block.area < ref_block.area:
- block.region_label = "cross_text"
+ if ref_block.order_label == "vision":
+ ref_block.order_label = "cross_layout"
+ break
+ if block.order_label == "vision" or block.area < ref_block.area:
+ block.order_label = "cross_layout"
break
- match_projection_iou = calculate_projection_iou(
+ match_projection_iou = calculate_projection_overlap_ratio(
block.bbox,
ref_block.bbox,
- "horizontal",
+ region_direction,
)
-
if match_projection_iou > 0:
- cover_count += 1
- if match_projection_iou > threshold:
- match_block_with_threshold_indexes.append(
- (ref_idx, match_projection_iou),
+ for second_ref_idx, second_ref_block in enumerate(blocks):
+ if (
+ second_ref_idx in [block_idx, ref_idx]
+ or second_ref_block.order_label in mask_labels
+ ):
+ continue
+
+ bbox_iou = calculate_overlap_ratio(
+ block.bbox, second_ref_block.bbox
)
- if ref_block.bbox[2] >= block.bbox[2]:
- break
-
- block_center = (block.bbox[0] + block.bbox[2]) / 2
- region_bbox_center = (region_box[0] + region_box[2]) / 2
- center_offset = abs(block_center - region_bbox_center)
- is_centered = center_offset <= median_width * 0.05
- width_gather_than_median = block.width > median_width * 1.3
-
- if (
- cover_count >= 2
- and block.label not in doc_title_labels
- and (width_gather_than_median != is_centered)
- ):
- block.region_label = (
- "cross_reference" if block.label == "reference" else "cross_text"
- )
- else:
- check_single_layout[block_idx] = match_block_with_threshold_indexes
-
- # Check single-layout block
- for idx, single_layout in check_single_layout.items():
- if single_layout:
- index, match_iou = single_layout[-1]
- if match_iou > 0.9 and blocks[index].region_label == "cross_text":
- blocks[idx].region_label = (
- "cross_reference" if block.label == "reference" else "cross_text"
- )
+ if bbox_iou > 0.1:
+ if second_ref_block.order_label == "vision":
+ second_ref_block.order_label = "cross_layout"
+ break
+ if (
+ block.order_label == "vision"
+ or block.area < second_ref_block.area
+ ):
+ block.order_label = "cross_layout"
+ break
+
+ second_match_projection_iou = calculate_projection_overlap_ratio(
+ block.bbox,
+ second_ref_block.bbox,
+ region_direction,
+ )
+ ref_match_projection_iou = calculate_projection_overlap_ratio(
+ ref_block.bbox,
+ second_ref_block.bbox,
+ region_direction,
+ )
+ ref_match_projection_iou_ = calculate_projection_overlap_ratio(
+ ref_block.bbox,
+ second_ref_block.bbox,
+ region_secondary_direction,
+ )
+ if (
+ second_match_projection_iou > 0
+ and ref_match_projection_iou == 0
+ and ref_match_projection_iou_ > 0
+ ):
+ if block.order_label == "vision" or (
+ ref_block.order_label == "normal_text"
+ and second_ref_block.order_label == "normal_text"
+ ):
+ block.order_label = (
+ "cross_reference"
+ if block.label == "reference"
+ else "cross_layout"
+ )
def sort_by_xycut(
block_bboxes: List,
- direction: int = 0,
+ direction: str = "vertical",
min_gap: int = 1,
) -> List[int]:
"""
@@ -323,7 +313,7 @@ def sort_by_xycut(
block_bboxes (Union[np.ndarray, List[List[int]]]): An array or list of bounding boxes,
where each box is represented as
[x_min, y_min, x_max, y_max].
- direction (int): Direction for the initial cut. Use 1 for Y-axis first and 0 for X-axis first.
+ direction (str): Direction for the initial cut. Use "vertical" for Y-axis first and "horizontal" for X-axis first.
Defaults to 0.
min_gap (int): Minimum gap width to consider a separation between segments. Defaults to 1.
@@ -332,7 +322,7 @@ def sort_by_xycut(
"""
block_bboxes = np.asarray(block_bboxes).astype(int)
res = []
- if direction == 1:
+ if direction == "vertical":
recursive_yx_cut(
block_bboxes,
np.arange(len(block_bboxes)).tolist(),
@@ -352,8 +342,7 @@ def sort_by_xycut(
def match_unsorted_blocks(
sorted_blocks: List[LayoutParsingBlock],
unsorted_blocks: List[LayoutParsingBlock],
- config: Dict,
- median_width: int,
+ region: LayoutParsingRegion,
) -> List[LayoutParsingBlock]:
"""
Match special blocks with the sorted blocks based on their region labels.
@@ -367,7 +356,7 @@ def match_unsorted_blocks(
List[LayoutParsingBlock]: The updated sorted blocks after matching special blocks.
"""
distance_type_map = {
- "cross_text": weighted_distance_insert,
+ "cross_layout": weighted_distance_insert,
"paragraph_title": weighted_distance_insert,
"doc_title": weighted_distance_insert,
"vision_title": weighted_distance_insert,
@@ -377,27 +366,30 @@ def match_unsorted_blocks(
"other": manhattan_insert,
}
- unsorted_blocks = sort_blocks(unsorted_blocks, median_width, reverse=False)
+ unsorted_blocks = sort_normal_blocks(
+ unsorted_blocks,
+ region.text_line_height,
+ region.text_line_width,
+ region.direction,
+ )
for idx, block in enumerate(unsorted_blocks):
- region_label = block.region_label
- if idx == 0 and region_label == "doc_title":
+ order_label = block.order_label
+ if idx == 0 and order_label == "doc_title":
sorted_blocks.insert(0, block)
continue
- sorted_blocks = distance_type_map[region_label](
- block, sorted_blocks, config, median_width
- )
+ sorted_blocks = distance_type_map[order_label](block, sorted_blocks, region)
return sorted_blocks
def xycut_enhanced(
- blocks: List[LayoutParsingBlock], config: Dict
-) -> List[LayoutParsingBlock]:
+ region: LayoutParsingRegion,
+) -> LayoutParsingRegion:
"""
xycut_enhance function performs the following steps:
1. Preprocess the input blocks by extracting headers, footers, and pre-cut blocks.
2. Mask blocks that are crossing different blocks.
3. Perform xycut_enhanced algorithm on the remaining blocks.
- 4. Match special blocks with the sorted blocks based on their region labels.
+ 4. Match unsorted blocks with the sorted blocks based on their order labels.
5. Update child blocks of the sorted blocks based on their parent blocks.
6. Return the ordered result list.
@@ -407,45 +399,51 @@ def xycut_enhanced(
Returns:
List[LayoutParsingBlock]: Ordered result list after processing.
"""
- if len(blocks) == 0:
- return blocks
+ if len(region.block_map) == 0:
+ return []
- text_labels = config.get("text_labels", [])
- header_blocks, pre_cut_list, footer_blocks, unordered_blocks = pre_process(
- blocks, config
- )
+ pre_cut_list: List[List[LayoutParsingBlock]] = pre_process(region)
final_order_res_list: List[LayoutParsingBlock] = []
- header_blocks = sort_blocks(header_blocks)
- footer_blocks = sort_blocks(footer_blocks)
- unordered_blocks = sort_blocks(unordered_blocks)
+ header_blocks: List[LayoutParsingBlock] = [
+ region.block_map[idx] for idx in region.header_block_idxes
+ ]
+ unordered_blocks: List[LayoutParsingBlock] = [
+ region.block_map[idx] for idx in region.unordered_block_idxes
+ ]
+ footer_blocks: List[LayoutParsingBlock] = [
+ region.block_map[idx] for idx in region.footer_block_idxes
+ ]
+ header_blocks: List[LayoutParsingBlock] = sort_normal_blocks(
+ header_blocks, region.text_line_height, region.text_line_width, region.direction
+ )
+ footer_blocks: List[LayoutParsingBlock] = sort_normal_blocks(
+ footer_blocks, region.text_line_height, region.text_line_width, region.direction
+ )
+ unordered_blocks: List[LayoutParsingBlock] = sort_normal_blocks(
+ unordered_blocks,
+ region.text_line_height,
+ region.text_line_width,
+ region.direction,
+ )
final_order_res_list.extend(header_blocks)
unsorted_blocks: List[LayoutParsingBlock] = []
- sorted_blocks_by_pre_cuts = []
+ sorted_blocks_by_pre_cuts: List[LayoutParsingBlock] = []
for pre_cut_blocks in pre_cut_list:
sorted_blocks: List[LayoutParsingBlock] = []
doc_title_blocks: List[LayoutParsingBlock] = []
xy_cut_blocks: List[LayoutParsingBlock] = []
- pre_cut_blocks: List[LayoutParsingBlock]
- median_width = 1
- text_block_width = [
- block.width for block in pre_cut_blocks if block.label in text_labels
- ]
- if len(text_block_width) > 0:
- median_width = int(np.median(text_block_width))
get_layout_structure(
- pre_cut_blocks,
- median_width,
- config,
+ pre_cut_blocks, region.direction, region.secondary_direction
)
# Get xy cut blocks and add other blocks in special_block_map
for block in pre_cut_blocks:
- if block.region_label not in [
- "cross_text",
+ if block.order_label not in [
+ "cross_layout",
"cross_reference",
"doc_title",
"unordered",
@@ -460,53 +458,85 @@ def xycut_enhanced(
block_bboxes = np.array([block.bbox for block in xy_cut_blocks])
block_text_lines = [block.num_of_lines for block in xy_cut_blocks]
discontinuous = calculate_discontinuous_projection(
- block_bboxes, direction="horizontal"
+ block_bboxes, direction=region.direction
)
+ if len(discontinuous) > 1:
+ xy_cut_blocks = [block for block in xy_cut_blocks]
+ # if len(discontinuous) == 1 or max(block_text_lines) == 1 or (not is_projection_consistent(xy_cut_blocks, discontinuous, direction=region.direction) and len(discontinuous) > 2 and max(block_text_lines) - min(block_text_lines) < 3):
if len(discontinuous) == 1 or max(block_text_lines) == 1:
- xy_cut_blocks.sort(key=lambda x: (x.bbox[1] // 5, x.bbox[0]))
- xy_cut_blocks = shrink_overlapping_boxes(xy_cut_blocks, "vertical")
+ xy_cut_blocks.sort(
+ key=lambda x: (
+ x.bbox[region.secondary_direction_start_index]
+ // (region.text_line_height // 2),
+ x.bbox[region.direction_start_index],
+ )
+ )
+ xy_cut_blocks = shrink_overlapping_boxes(
+ xy_cut_blocks, region.secondary_direction
+ )
+ if (
+ len(discontinuous) == 1
+ or max(block_text_lines) == 1
+ or (
+ not is_projection_consistent(
+ xy_cut_blocks, discontinuous, direction=region.direction
+ )
+ and len(discontinuous) > 2
+ and max(block_text_lines) - min(block_text_lines) < 3
+ )
+ ):
+ xy_cut_blocks.sort(
+ key=lambda x: (
+ x.bbox[region.secondary_direction_start_index]
+ // (region.text_line_height // 2),
+ x.bbox[region.direction_start_index],
+ )
+ )
+ xy_cut_blocks = shrink_overlapping_boxes(
+ xy_cut_blocks, region.secondary_direction
+ )
block_bboxes = np.array([block.bbox for block in xy_cut_blocks])
- sorted_indexes = sort_by_xycut(block_bboxes, direction=1, min_gap=1)
+ sorted_indexes = sort_by_xycut(
+ block_bboxes, direction=region.secondary_direction, min_gap=1
+ )
else:
- xy_cut_blocks.sort(key=lambda x: (x.bbox[0] // 20, x.bbox[1]))
- xy_cut_blocks = shrink_overlapping_boxes(xy_cut_blocks, "horizontal")
+ xy_cut_blocks.sort(
+ key=lambda x: (
+ x.bbox[region.direction_start_index]
+ // (region.text_line_width // 2),
+ x.bbox[region.secondary_direction_start_index],
+ )
+ )
+ xy_cut_blocks = shrink_overlapping_boxes(
+ xy_cut_blocks, region.direction
+ )
block_bboxes = np.array([block.bbox for block in xy_cut_blocks])
- sorted_indexes = sort_by_xycut(block_bboxes, direction=0, min_gap=20)
+ sorted_indexes = sort_by_xycut(
+ block_bboxes, direction=region.direction, min_gap=1
+ )
sorted_blocks = [xy_cut_blocks[i] for i in sorted_indexes]
sorted_blocks = match_unsorted_blocks(
sorted_blocks,
doc_title_blocks,
- config,
- median_width,
+ region=region,
)
sorted_blocks_by_pre_cuts.extend(sorted_blocks)
- median_width = 1
- text_block_width = [block.width for block in blocks if block.label in text_labels]
- if len(text_block_width) > 0:
- median_width = int(np.median(text_block_width))
final_order_res_list = match_unsorted_blocks(
sorted_blocks_by_pre_cuts,
unsorted_blocks,
- config,
- median_width,
+ region=region,
)
final_order_res_list.extend(footer_blocks)
final_order_res_list.extend(unordered_blocks)
- index = 0
- visualize_index_labels = config.get("visualize_index_labels", [])
for block_idx, block in enumerate(final_order_res_list):
- if block.label not in visualize_index_labels:
- continue
final_order_res_list = insert_child_blocks(
block, block_idx, final_order_res_list
)
block = final_order_res_list[block_idx]
- index += 1
- block.index = index
return final_order_res_list