From ced2ebcf80d856755c2d02ee3a32decd6c0cb3c0 Mon Sep 17 00:00:00 2001 From: zhouchangda Date: Thu, 8 May 2025 13:18:03 +0000 Subject: [PATCH 1/2] update xycut_enhanced and add region detection --- .../table_structure_recognition/processors.py | 7 +- .../pipelines/layout_parsing/pipeline_v2.py | 504 +++++++++++++----- .../pipelines/layout_parsing/result_v2.py | 190 +++++-- .../pipelines/layout_parsing/setting.py | 58 +- .../pipelines/layout_parsing/utils.py | 442 ++++++++++----- .../layout_parsing/xycut_enhanced/utils.py | 271 ++++------ .../layout_parsing/xycut_enhanced/xycuts.py | 252 +++++---- 7 files changed, 1100 insertions(+), 624 deletions(-) diff --git a/paddlex/inference/models/table_structure_recognition/processors.py b/paddlex/inference/models/table_structure_recognition/processors.py index 508222befe..e1048ca860 100644 --- a/paddlex/inference/models/table_structure_recognition/processors.py +++ b/paddlex/inference/models/table_structure_recognition/processors.py @@ -130,12 +130,7 @@ def __call__(self, pred, img_size, ori_img_size): structure_probs, bbox_preds, img_size, ori_img_size ) structure_str_list = [ - ( - ["", "", ""] - + structure - + ["
", "", ""] - ) - for structure in structure_str_list + ([""] + structure + ["
"]) for structure in structure_str_list ] return [ {"bbox": bbox, "structure": structure, "structure_score": structure_score} diff --git a/paddlex/inference/pipelines/layout_parsing/pipeline_v2.py b/paddlex/inference/pipelines/layout_parsing/pipeline_v2.py index b5ca253050..7035dcf509 100644 --- a/paddlex/inference/pipelines/layout_parsing/pipeline_v2.py +++ b/paddlex/inference/pipelines/layout_parsing/pipeline_v2.py @@ -29,10 +29,11 @@ from ...utils.pp_option import PaddlePredictorOption from ..base import BasePipeline from ..ocr.result import OCRResult -from .result_v2 import LayoutParsingBlock, LayoutParsingResultV2 +from .result_v2 import LayoutParsingBlock, LayoutParsingRegion, LayoutParsingResultV2 from .utils import ( caculate_bbox_area, - calculate_text_orientation, + calculate_minimum_enclosing_bbox, + calculate_overlap_ratio, convert_formula_res_to_ocr_format, format_line, gather_imgs, @@ -40,11 +41,10 @@ get_sub_regions_ocr_res, group_boxes_into_lines, remove_overlap_blocks, - split_boxes_if_x_contained, - update_layout_order_config_block_index, + shrink_supplement_region_bbox, + split_boxes_by_projection, update_region_box, ) -from .xycut_enhanced import xycut_enhanced @pipeline_requires_extra("ocr") @@ -100,6 +100,10 @@ def inintial_predictor(self, config: dict) -> None: self.use_general_ocr = config.get("use_general_ocr", True) self.use_table_recognition = config.get("use_table_recognition", True) self.use_seal_recognition = config.get("use_seal_recognition", True) + self.use_region_detection = config.get( + "use_region_detection", + False, + ) self.use_formula_recognition = config.get( "use_formula_recognition", True, @@ -115,6 +119,16 @@ def inintial_predictor(self, config: dict) -> None: self.doc_preprocessor_pipeline = self.create_pipeline( doc_preprocessor_config, ) + if self.use_region_detection: + region_detection_config = config.get("SubModules", {}).get( + "RegionDetection", + { + "model_config_error": "config error for block_region_detection_model!" + }, + ) + self.region_detection_model = self.create_model( + region_detection_config, + ) layout_det_config = config.get("SubModules", {}).get( "LayoutDetection", @@ -246,7 +260,9 @@ def check_model_settings_valid(self, input_params: dict) -> bool: def standardized_data( self, image: list, - layout_order_config: dict, + parameters_config: dict, + block_label_mapping: dict, + region_det_res: DetResult, layout_det_res: DetResult, overall_ocr_res: OCRResult, formula_res_list: list, @@ -277,13 +293,16 @@ def standardized_data( """ matched_ocr_dict = {} - layout_to_ocr_mapping = {} + region_to_block_map = {} + block_to_ocr_map = {} object_boxes = [] footnote_list = [] + paragraph_title_list = [] bottom_text_y_max = 0 max_block_area = 0.0 + doc_title_num = 0 - region_box = [65535, 65535, 0, 0] + base_region_bbox = [65535, 65535, 0, 0] layout_det_res = remove_overlap_blocks( layout_det_res, threshold=0.5, @@ -301,22 +320,26 @@ def standardized_data( _, _, _, y2 = box # update the region box and max_block_area according to the layout boxes - region_box = update_region_box(box, region_box) + base_region_bbox = update_region_box(box, base_region_bbox) max_block_area = max(max_block_area, caculate_bbox_area(box)) - update_layout_order_config_block_index(layout_order_config, label, box_idx) + # update_layout_order_config_block_index(layout_order_config, label, box_idx) # set the label of footnote to text, when it is above the text boxes if label == "footnote": footnote_list.append(box_idx) + elif label == "paragraph_title": + paragraph_title_list.append(box_idx) if label == "text": bottom_text_y_max = max(y2, bottom_text_y_max) + if label == "doc_title": + doc_title_num += 1 if label not in ["formula", "table", "seal"]: _, matched_idxes = get_sub_regions_ocr_res( overall_ocr_res, [box], return_match_idx=True ) - layout_to_ocr_mapping[box_idx] = matched_idxes + block_to_ocr_map[box_idx] = matched_idxes for matched_idx in matched_idxes: if matched_ocr_dict.get(matched_idx, None) is None: matched_ocr_dict[matched_idx] = [box_idx] @@ -330,36 +353,21 @@ def standardized_data( < bottom_text_y_max ): layout_det_res["boxes"][footnote_idx]["label"] = "text" - layout_order_config["text_block_idxes"].append(footnote_idx) - layout_order_config["footer_block_idxes"].remove(footnote_idx) - # fix the doc_title label - doc_title_idxes = layout_order_config.get("doc_title_block_idxes", []) - paragraph_title_idxes = layout_order_config.get( - "paragraph_title_block_idxes", [] - ) # check if there is only one paragraph title and without doc_title - only_one_paragraph_title = ( - len(paragraph_title_idxes) == 1 and len(doc_title_idxes) == 0 - ) + only_one_paragraph_title = len(paragraph_title_list) == 1 and doc_title_num == 0 if only_one_paragraph_title: paragraph_title_block_area = caculate_bbox_area( - layout_det_res["boxes"][paragraph_title_idxes[0]]["coordinate"] + layout_det_res["boxes"][paragraph_title_list[0]]["coordinate"] ) - title_area_max_block_threshold = layout_order_config.get( - "title_area_max_block_threshold", 0.3 + title_area_max_block_threshold = parameters_config["block"].get( + "title_conversion_area_ratio_threshold", 0.3 ) if ( paragraph_title_block_area > max_block_area * title_area_max_block_threshold ): - layout_det_res["boxes"][paragraph_title_idxes[0]]["label"] = "doc_title" - layout_order_config["doc_title_block_idxes"].append( - paragraph_title_idxes[0] - ) - layout_order_config["paragraph_title_block_idxes"].remove( - paragraph_title_idxes[0] - ) + layout_det_res["boxes"][paragraph_title_list[0]]["label"] = "doc_title" # Replace the OCR information of the hurdles. for overall_ocr_idx, layout_box_ids in matched_ocr_dict.items(): @@ -374,6 +382,11 @@ def standardized_data( for box_idx in layout_box_ids: layout_box = layout_det_res["boxes"][box_idx]["coordinate"] crop_box = get_bbox_intersection(overall_ocr_box, layout_box) + for ocr_idx in block_to_ocr_map[box_idx]: + ocr_box = overall_ocr_res["rec_boxes"][ocr_idx] + iou = calculate_overlap_ratio(ocr_box, crop_box, "small") + if iou > 0.8: + overall_ocr_res["rec_texts"][ocr_idx] = "" x1, y1, x2, y2 = [int(i) for i in crop_box] crop_img = np.array(image)[y1:y2, x1:x2] crop_img_rec_res = next(text_rec_model([crop_img])) @@ -414,23 +427,150 @@ def standardized_data( overall_ocr_res["rec_scores"].append(crop_img_rec_score) overall_ocr_res["rec_texts"].append(crop_img_rec_text) overall_ocr_res["rec_labels"].append("text") - layout_to_ocr_mapping[box_idx].remove(overall_ocr_idx) - layout_to_ocr_mapping[box_idx].append( + block_to_ocr_map[box_idx].remove(overall_ocr_idx) + block_to_ocr_map[box_idx].append( len(overall_ocr_res["rec_texts"]) - 1 ) - layout_order_config["all_layout_region_box"] = region_box - layout_order_config["layout_to_ocr_mapping"] = layout_to_ocr_mapping - layout_order_config["matched_ocr_dict"] = matched_ocr_dict + # use layout bbox to do ocr recognition when there is no matched ocr + for layout_box_idx, overall_ocr_idxes in block_to_ocr_map.items(): + has_text = False + for idx in overall_ocr_idxes: + if overall_ocr_res["rec_texts"][idx] != "": + has_text = True + break + if not has_text and layout_det_res["boxes"][layout_box_idx][ + "label" + ] not in block_label_mapping.get("vision_labels", []): + crop_box = layout_det_res["boxes"][layout_box_idx]["coordinate"] + x1, y1, x2, y2 = [int(i) for i in crop_box] + crop_img = np.array(image)[y1:y2, x1:x2] + crop_img_rec_res = next(text_rec_model([crop_img])) + crop_img_dt_poly = get_bbox_intersection( + crop_box, crop_box, return_format="poly" + ) + crop_img_rec_score = crop_img_rec_res["rec_score"] + crop_img_rec_text = crop_img_rec_res["rec_text"] + text_rec_score_thresh = ( + text_rec_score_thresh + if text_rec_score_thresh is not None + else (self.general_ocr_pipeline.text_rec_score_thresh) + ) + if crop_img_rec_score >= text_rec_score_thresh: + overall_ocr_res["rec_boxes"] = np.vstack( + (overall_ocr_res["rec_boxes"], crop_box) + ) + overall_ocr_res["rec_polys"].append(crop_img_dt_poly) + overall_ocr_res["rec_scores"].append(crop_img_rec_score) + overall_ocr_res["rec_texts"].append(crop_img_rec_text) + overall_ocr_res["rec_labels"].append("text") + block_to_ocr_map[layout_box_idx].append( + len(overall_ocr_res["rec_texts"]) - 1 + ) + + # when there is no layout detection result but there is ocr result, convert ocr detection result to layout detection result + if len(layout_det_res["boxes"]) == 0 and len(overall_ocr_res["rec_boxes"]) > 0: + for idx, ocr_rec_box in enumerate(overall_ocr_res["rec_boxes"]): + base_region_bbox = update_region_box(ocr_rec_box, base_region_bbox) + layout_det_res["boxes"].append( + { + "label": "text", + "coordinate": ocr_rec_box, + "score": overall_ocr_res["rec_scores"][idx], + } + ) + block_to_ocr_map[idx] = [idx] + + block_bboxes = [box["coordinate"] for box in layout_det_res["boxes"]] + region_det_res["boxes"] = sorted( + region_det_res["boxes"], + key=lambda item: caculate_bbox_area(item["coordinate"]), + ) + if len(region_det_res["boxes"]) == 0: + region_det_res["boxes"] = [ + { + "coordinate": base_region_bbox, + "label": "SupplementaryBlock", + "score": 1, + } + ] + region_to_block_map[0] = range(len(block_bboxes)) + else: + block_idxes_set = set(range(len(block_bboxes))) + # match block to region + for region_idx, region_info in enumerate(region_det_res["boxes"]): + matched_idxes = [] + region_to_block_map[region_idx] = [] + region_bbox = region_info["coordinate"] + for block_idx in block_idxes_set: + overlap_ratio = calculate_overlap_ratio( + region_bbox, block_bboxes[block_idx], mode="small" + ) + if overlap_ratio > parameters_config["region"].get( + "match_block_overlap_ratio_threshold", 0.8 + ): + region_to_block_map[region_idx].append(block_idx) + matched_idxes.append(block_idx) + if len(matched_idxes) > 0: + for block_idx in matched_idxes: + block_idxes_set.remove(block_idx) + matched_bboxes = [block_bboxes[idx] for idx in matched_idxes] + new_region_bbox = calculate_minimum_enclosing_bbox(matched_bboxes) + region_det_res["boxes"][region_idx]["coordinate"] = new_region_bbox + # Supplement region block when there is no matched block + if len(block_idxes_set) > 0: + while len(block_idxes_set) > 0: + matched_idxes = [] + unmatched_bboxes = [block_bboxes[idx] for idx in block_idxes_set] + supplement_region_bbox = calculate_minimum_enclosing_bbox( + unmatched_bboxes + ) + # check if the new region bbox is overlapped with other region bbox, if have, then shrink the new region bbox + for region_info in region_det_res["boxes"]: + region_bbox = region_info["coordinate"] + overlap_ratio = calculate_overlap_ratio( + supplement_region_bbox, region_bbox + ) + if overlap_ratio > 0: + supplement_region_bbox, matched_idxes = ( + shrink_supplement_region_bbox( + supplement_region_bbox, + region_bbox, + image.shape[1], + image.shape[0], + block_idxes_set, + block_bboxes, + parameters_config, + ) + ) + if len(matched_idxes) == 0: + matched_idxes = list(block_idxes_set) + region_idx = len(region_det_res["boxes"]) + region_to_block_map[region_idx] = list(matched_idxes) + for block_idx in matched_idxes: + block_idxes_set.remove(block_idx) + region_det_res["boxes"].append( + { + "coordinate": supplement_region_bbox, + "label": "SupplementaryBlock", + "score": 1, + } + ) + + region_block_ocr_idx_map = dict( + region_to_block_map=region_to_block_map, + block_to_ocr_map=block_to_ocr_map, + ) - return layout_order_config, layout_det_res + return region_block_ocr_idx_map, region_det_res, layout_det_res - def sort_line_by_x_projection( + def sort_line_by_projection( self, line: List[List[Union[List[int], str]]], input_img: np.ndarray, text_rec_model: Any, text_rec_score_thresh: Union[float, None] = None, + orientation: str = "vertical", ) -> None: """ Sort a line of text spans based on their vertical position within the layout bounding box. @@ -443,24 +583,27 @@ def sort_line_by_x_projection( Returns: list: The sorted line of text spans. """ - splited_boxes = split_boxes_if_x_contained(line) + sort_index = 0 if orientation == "horizontal" else 1 + splited_boxes = split_boxes_by_projection(line, orientation) splited_lines = [] if len(line) != len(splited_boxes): - splited_boxes.sort(key=lambda span: span[0][0]) + splited_boxes.sort(key=lambda span: span[0][sort_index]) for span in splited_boxes: - if span[2] == "text": + bbox, text, label = span + if label == "text": crop_img = input_img[ - int(span[0][1]) : int(span[0][3]), - int(span[0][0]) : int(span[0][2]), + int(bbox[1]) : int(bbox[3]), + int(bbox[0]) : int(bbox[2]), ] crop_img_rec_res = next(text_rec_model([crop_img])) crop_img_rec_score = crop_img_rec_res["rec_score"] crop_img_rec_text = crop_img_rec_res["rec_text"] - span[1] = ( + text = ( crop_img_rec_text if crop_img_rec_score >= text_rec_score_thresh else "" ) + span[1] = text splited_lines.append(span) else: @@ -471,91 +614,77 @@ def sort_line_by_x_projection( def get_block_rec_content( self, image: list, - layout_order_config: dict, + line_parameters_config: dict, ocr_rec_res: dict, block: LayoutParsingBlock, text_rec_model: Any, text_rec_score_thresh: Union[float, None] = None, ) -> str: - text_delimiter_map = { - "content": "\n", - } - line_delimiter_map = { - "doc_title": " ", - "content": "\n", - } if len(ocr_rec_res["rec_texts"]) == 0: block.content = "" return block - label = block.label - if label == "reference": - rec_boxes = ocr_rec_res["boxes"] - block_left_coordinate = min([box[0] for box in rec_boxes]) - block_right_coordinate = max([box[2] for box in rec_boxes]) - first_line_span_limit = (5,) - last_line_span_limit = (20,) - else: - block_left_coordinate, _, block_right_coordinate, _ = block.bbox - first_line_span_limit = (10,) - last_line_span_limit = (10,) - - if label == "formula": - ocr_rec_res["rec_texts"] = [ - rec_res_text.replace("$", "") - for rec_res_text in ocr_rec_res["rec_texts"] - ] - lines = group_boxes_into_lines( + lines, text_orientation = group_boxes_into_lines( ocr_rec_res, - block, - layout_order_config.get("line_height_iou_threshold", 0.4), + line_parameters_config.get("line_height_iou_threshold", 0.8), ) - block.num_of_lines = len(lines) + if block.label == "reference": + rec_boxes = ocr_rec_res["boxes"] + block_right_coordinate = max([box[2] for box in rec_boxes]) + last_line_span_limit = 20 + else: + block_right_coordinate = block.bbox[2] + last_line_span_limit = 10 # format line - new_lines = [] - horizontal_text_line_num = 0 - for line in lines: - line.sort(key=lambda span: span[0][0]) + text_lines = [] + need_new_line_num = 0 + sort_index = 0 if text_orientation == "horizontal" else 1 + for idx, line in enumerate(lines): + line.sort(key=lambda span: span[0][sort_index]) # merge formula and text ocr_labels = [span[2] for span in line] if "formula" in ocr_labels: - line = self.sort_line_by_x_projection( - line, image, text_rec_model, text_rec_score_thresh + line = self.sort_line_by_projection( + line, image, text_rec_model, text_rec_score_thresh, text_orientation ) - text_orientation = calculate_text_orientation([span[0] for span in line]) - horizontal_text_line_num += 1 if text_orientation == "horizontal" else 0 - - line_text = format_line( + line_text, need_new_line = format_line( line, - block_left_coordinate, block_right_coordinate, - first_line_span_limit=first_line_span_limit, last_line_span_limit=last_line_span_limit, block_label=block.label, - delimiter_map=text_delimiter_map, ) - new_lines.append(line_text) - - delim = line_delimiter_map.get(label, "") - content = delim.join(new_lines) + if need_new_line: + need_new_line_num += 1 + if idx == 0: + line_start_coordinate = line[0][0][0] + block.seg_start_coordinate = line_start_coordinate + elif idx == len(lines) - 1: + line_end_coordinate = line[-1][0][2] + block.seg_end_coordinate = line_end_coordinate + text_lines.append(line_text) + + delim = line_parameters_config["delimiter_map"].get(block.label, "") + if need_new_line_num > len(text_lines) * 0.5 and delim == "": + delim = "\n" + content = delim.join(text_lines) block.content = content - block.direction = ( - "horizontal" - if horizontal_text_line_num > len(new_lines) * 0.5 - else "vertical" - ) + block.num_of_lines = len(text_lines) + block.orientation = text_orientation return block def get_layout_parsing_blocks( self, image: list, - layout_order_config: dict, + parameters_config: dict, + block_label_mapping: dict, + region_block_ocr_idx_map: dict, + region_det_res: DetResult, overall_ocr_res: OCRResult, layout_det_res: DetResult, table_res_list: list, @@ -614,9 +743,9 @@ def get_layout_parsing_blocks( _, ocr_idx_list = get_sub_regions_ocr_res( overall_ocr_res, [block_bbox], return_match_idx=True ) - layout_order_config["layout_to_ocr_mapping"][box_idx] = ocr_idx_list + region_block_ocr_idx_map["block_to_ocr_map"][box_idx] = ocr_idx_list else: - ocr_idx_list = layout_order_config["layout_to_ocr_mapping"].get( + ocr_idx_list = region_block_ocr_idx_map["block_to_ocr_map"].get( box_idx, [] ) for box_no in ocr_idx_list: @@ -630,7 +759,7 @@ def get_layout_parsing_blocks( block = self.get_block_rec_content( image=image, block=block, - layout_order_config=layout_order_config, + line_parameters_config=parameters_config["line"], ocr_rec_res=rec_res, text_rec_model=text_rec_model, text_rec_score_thresh=text_rec_score_thresh, @@ -644,28 +773,31 @@ def get_layout_parsing_blocks( layout_parsing_blocks.append(block) - # when there is no layout detection result but there is ocr result, use ocr result - if len(layout_det_res["boxes"]) == 0: - region_box = [65535, 65535, 0, 0] - for ocr_idx, (ocr_rec_box, ocr_rec_text) in enumerate( - zip(overall_ocr_res["rec_boxes"], overall_ocr_res["rec_texts"]) - ): - update_layout_order_config_block_index( - layout_order_config, "text", ocr_idx - ) - region_box = update_region_box(ocr_rec_box, region_box) - layout_parsing_blocks.append( - LayoutParsingBlock( - label="text", bbox=ocr_rec_box, content=ocr_rec_text - ) - ) - layout_order_config["all_layout_region_box"] = region_box + region_list: List[LayoutParsingRegion] = [] + for region_idx, region_info in enumerate(region_det_res["boxes"]): + region_bbox = region_info["coordinate"] + region_blocks = [ + layout_parsing_blocks[idx] + for idx in region_block_ocr_idx_map["region_to_block_map"][region_idx] + ] + region = LayoutParsingRegion( + region_bbox=region_bbox, + blocks=region_blocks, + block_label_mapping=block_label_mapping, + ) + region_list.append(region) - return layout_parsing_blocks, layout_order_config + region_list = sorted( + region_list, + key=lambda r: (r.euclidean_distance // 50, r.center_euclidean_distance), + ) + + return region_list def get_layout_parsing_res( self, image: list, + region_det_res: DetResult, layout_det_res: DetResult, overall_ocr_res: OCRResult, table_res_list: list, @@ -686,23 +818,30 @@ def get_layout_parsing_res( Returns: list: A list of dictionaries representing the layout parsing result. """ - from .setting import layout_order_config + from .setting import block_label_mapping, parameters_config # Standardize data - layout_order_config, layout_det_res = self.standardized_data( - image=image, - layout_order_config=copy.deepcopy(layout_order_config), - layout_det_res=layout_det_res, - overall_ocr_res=overall_ocr_res, - formula_res_list=formula_res_list, - text_rec_model=self.general_ocr_pipeline.text_rec_model, - text_rec_score_thresh=text_rec_score_thresh, + region_block_ocr_idx_map, region_det_res, layout_det_res = ( + self.standardized_data( + image=image, + parameters_config=parameters_config, + block_label_mapping=block_label_mapping, + region_det_res=region_det_res, + layout_det_res=layout_det_res, + overall_ocr_res=overall_ocr_res, + formula_res_list=formula_res_list, + text_rec_model=self.general_ocr_pipeline.text_rec_model, + text_rec_score_thresh=text_rec_score_thresh, + ) ) # Format layout parsing block - parsing_res_list, layout_order_config = self.get_layout_parsing_blocks( + region_list = self.get_layout_parsing_blocks( image=image, - layout_order_config=layout_order_config, + parameters_config=parameters_config, + block_label_mapping=block_label_mapping, + region_block_ocr_idx_map=region_block_ocr_idx_map, + region_det_res=region_det_res, overall_ocr_res=overall_ocr_res, layout_det_res=layout_det_res, table_res_list=table_res_list, @@ -711,10 +850,16 @@ def get_layout_parsing_res( text_rec_score_thresh=self.general_ocr_pipeline.text_rec_score_thresh, ) - parsing_res_list = xycut_enhanced( - parsing_res_list, - layout_order_config, - ) + parsing_res_list = [] + for region in region_list: + parsing_res_list.extend(region.sort()) + + visualize_index_labels = block_label_mapping["visualize_index_labels"] + index = 1 + for block in parsing_res_list: + if block.label in visualize_index_labels: + block.index = index + index += 1 return parsing_res_list @@ -726,6 +871,9 @@ def get_model_settings( use_seal_recognition: Union[bool, None], use_table_recognition: Union[bool, None], use_formula_recognition: Union[bool, None], + use_chart_recognition: Union[bool, None], + use_region_detection: Union[bool, None], + is_pretty_markdown: Union[bool, None], ) -> dict: """ Get the model settings based on the provided parameters or default values. @@ -762,12 +910,18 @@ def get_model_settings( if use_formula_recognition is None: use_formula_recognition = self.use_formula_recognition + if use_region_detection is None: + use_region_detection = self.use_region_detection + return dict( use_doc_preprocessor=use_doc_preprocessor, use_general_ocr=use_general_ocr, use_seal_recognition=use_seal_recognition, use_table_recognition=use_table_recognition, use_formula_recognition=use_formula_recognition, + use_chart_recognition=use_chart_recognition, + use_region_detection=use_region_detection, + is_pretty_markdown=is_pretty_markdown, ) def predict( @@ -780,6 +934,8 @@ def predict( use_seal_recognition: Union[bool, None] = None, use_table_recognition: Union[bool, None] = None, use_formula_recognition: Union[bool, None] = None, + use_chart_recognition: Union[bool, None] = None, + use_region_detection: Union[bool, None] = None, layout_threshold: Optional[Union[float, dict]] = None, layout_nms: Optional[bool] = None, layout_unclip_ratio: Optional[Union[float, Tuple[float, float], dict]] = None, @@ -799,6 +955,9 @@ def predict( use_table_cells_ocr_results: bool = False, use_e2e_wired_table_rec_model: bool = False, use_e2e_wireless_table_rec_model: bool = True, + is_pretty_markdown: Union[bool, None] = None, + use_layout_gt: bool = False, + layout_gt_dir: Union[str, None] = None, **kwargs, ) -> LayoutParsingResultV2: """ @@ -812,6 +971,7 @@ def predict( use_seal_recognition (Optional[bool]): Whether to use seal recognition. use_table_recognition (Optional[bool]): Whether to use table recognition. use_formula_recognition (Optional[bool]): Whether to use formula recognition. + use_region_detection (Optional[bool]): Whether to use region detection. layout_threshold (Optional[float]): The threshold value to filter out low-confidence predictions. Default is None. layout_nms (bool, optional): Whether to use layout-aware NMS. Defaults to False. layout_unclip_ratio (Optional[Union[float, Tuple[float, float]]], optional): The ratio of unclipping the bounding box. @@ -848,6 +1008,9 @@ def predict( use_seal_recognition, use_table_recognition, use_formula_recognition, + use_chart_recognition, + use_region_detection, + is_pretty_markdown, ) if not self.check_model_settings_valid(model_settings): @@ -869,17 +1032,78 @@ def predict( doc_preprocessor_image = doc_preprocessor_res["output_img"] - layout_det_res = next( - self.layout_det_model( - doc_preprocessor_image, - threshold=layout_threshold, - layout_nms=layout_nms, - layout_unclip_ratio=layout_unclip_ratio, - layout_merge_bboxes_mode=layout_merge_bboxes_mode, + use_layout_gt = use_layout_gt + if not use_layout_gt: + layout_det_res = next( + self.layout_det_model( + doc_preprocessor_image, + threshold=layout_threshold, + layout_nms=layout_nms, + layout_unclip_ratio=layout_unclip_ratio, + layout_merge_bboxes_mode=layout_merge_bboxes_mode, + ) ) - ) + else: + import json + import os + + from ...models.object_detection.result import DetResult + + label_dir = layout_gt_dir + notes_path = f"{label_dir}/notes.json" + labels = f"{label_dir}/labels" + gt_file = os.path.basename(input)[:-4] + ".txt" + gt_path = f"{labels}/{gt_file}" + with open(notes_path, "r") as f: + notes = json.load(f) + categories_map = {} + for categories in notes["categories"]: + id = int(categories["id"]) + name = categories["name"] + categories_map[id] = name + with open(gt_path, "r") as f: + lines = f.readlines() + layout_det_res_dic = { + "input_img": doc_preprocessor_image, + "page_index": None, + "boxes": [], + } + for line in lines: + line = line.strip().split(" ") + category_id = int(line[0]) + label = categories_map[category_id] + img_h, img_w = doc_preprocessor_image.shape[:2] + center_x = float(line[1]) * img_w + center_y = float(line[2]) * img_h + w = float(line[3]) * img_w + h = float(line[4]) * img_h + x0 = center_x - w / 2 + y0 = center_y - h / 2 + x1 = center_x + w / 2 + y1 = center_y + h / 2 + box = [x0, y0, x1, y1] + layout_det_res_dic["boxes"].append( + { + "cls_id": category_id, + "label": label, + "coordinate": box, + "score": 1.0, + } + ) + layout_det_res = DetResult(layout_det_res_dic) imgs_in_doc = gather_imgs(doc_preprocessor_image, layout_det_res["boxes"]) + if model_settings["use_region_detection"]: + region_det_res = next( + self.region_detection_model( + doc_preprocessor_image, + layout_nms=True, + layout_merge_bboxes_mode="small", + ), + ) + else: + region_det_res = {"boxes": []} + if model_settings["use_formula_recognition"]: formula_res_all = next( self.formula_recognition_pipeline( @@ -1002,6 +1226,7 @@ def predict( parsing_res_list = self.get_layout_parsing_res( doc_preprocessor_image, + region_det_res=region_det_res, layout_det_res=layout_det_res, overall_ocr_res=overall_ocr_res, table_res_list=table_res_list, @@ -1021,6 +1246,7 @@ def predict( "page_index": batch_data.page_indexes[0], "doc_preprocessor_res": doc_preprocessor_res, "layout_det_res": layout_det_res, + "region_det_res": region_det_res, "overall_ocr_res": overall_ocr_res, "table_res_list": table_res_list, "seal_res_list": seal_res_list, diff --git a/paddlex/inference/pipelines/layout_parsing/result_v2.py b/paddlex/inference/pipelines/layout_parsing/result_v2.py index 95cd2cfa66..c27c47bc87 100644 --- a/paddlex/inference/pipelines/layout_parsing/result_v2.py +++ b/paddlex/inference/pipelines/layout_parsing/result_v2.py @@ -14,6 +14,7 @@ from __future__ import annotations import copy +import math import re from pathlib import Path from typing import List @@ -73,6 +74,9 @@ def _to_img(self) -> dict[str, np.ndarray]: res_img_dict[key] = value res_img_dict["layout_det_res"] = self["layout_det_res"].img["res"] + if model_settings["use_region_detection"]: + res_img_dict["region_det_res"] = self["region_det_res"].img["res"] + if model_settings["use_general_ocr"] or model_settings["use_table_recognition"]: res_img_dict["overall_ocr_res"] = self["overall_ocr_res"].img["ocr_res_img"] @@ -283,22 +287,33 @@ def format_title(title): " ", ) + # def format_centered_text(): + # return ( + # f'
{block.content}
'.replace( + # "-\n", + # "", + # ).replace("\n", " ") + # + "\n" + # ) + def format_centered_text(): - return ( - f'
{block.content}
'.replace( - "-\n", - "", - ).replace("\n", " ") - + "\n" - ) + return block.content + + # def format_image(): + # img_tags = [] + # image_path = "".join(block.image.keys()) + # img_tags.append( + # '
Image
'.format( + # image_path.replace("-\n", "").replace("\n", " "), + # ), + # ) + # return "\n".join(img_tags) def format_image(): img_tags = [] image_path = "".join(block.image.keys()) img_tags.append( - '
Image
'.format( - image_path.replace("-\n", "").replace("\n", " "), - ), + "![]({})".format(image_path.replace("-\n", "").replace("\n", " ")) ) return "\n".join(img_tags) @@ -332,7 +347,7 @@ def get_seg_flag(block: LayoutParsingBlock, prev_block: LayoutParsingBlock): num_of_prev_lines = prev_block.num_of_lines pre_block_seg_end_coordinate = prev_block.seg_end_coordinate prev_end_space_small = ( - context_right_coordinate - pre_block_seg_end_coordinate < 10 + abs(prev_block_bbox[2] - pre_block_seg_end_coordinate) < 10 ) prev_lines_more_than_one = num_of_prev_lines > 1 @@ -347,8 +362,12 @@ def get_seg_flag(block: LayoutParsingBlock, prev_block: LayoutParsingBlock): prev_block_bbox[2], context_right_coordinate ) prev_end_space_small = ( - prev_block_bbox[2] - pre_block_seg_end_coordinate < 10 + abs(context_right_coordinate - pre_block_seg_end_coordinate) + < 10 ) + edge_distance = 0 + else: + edge_distance = abs(block_box[0] - prev_block_bbox[2]) current_start_space_small = ( seg_start_coordinate - context_left_coordinate < 10 @@ -358,6 +377,7 @@ def get_seg_flag(block: LayoutParsingBlock, prev_block: LayoutParsingBlock): prev_end_space_small and current_start_space_small and prev_lines_more_than_one + and edge_distance < max(prev_block.width, block.width) ): seg_start_flag = False else: @@ -371,6 +391,9 @@ def get_seg_flag(block: LayoutParsingBlock, prev_block: LayoutParsingBlock): handlers = { "paragraph_title": lambda: format_title(block.content), + "abstract_title": lambda: format_title(block.content), + "reference_title": lambda: format_title(block.content), + "content_title": lambda: format_title(block.content), "doc_title": lambda: f"# {block.content}".replace( "-\n", "", @@ -378,7 +401,9 @@ def get_seg_flag(block: LayoutParsingBlock, prev_block: LayoutParsingBlock): "table_title": lambda: format_centered_text(), "figure_title": lambda: format_centered_text(), "chart_title": lambda: format_centered_text(), - "text": lambda: block.content.replace("-\n", " ").replace("\n", " "), + "text": lambda: block.content.replace("\n\n", "\n").replace( + "\n", "\n\n" + ), "abstract": lambda: format_first_line( ["摘要", "abstract"], lambda l: f"## {l}\n", " " ), @@ -416,24 +441,7 @@ def get_seg_flag(block: LayoutParsingBlock, prev_block: LayoutParsingBlock): if handler: prev_block = block if label == last_label == "text" and seg_start_flag == False: - last_char_of_markdown = ( - markdown_content[-1] if markdown_content else "" - ) - first_char_of_handler = handler()[0] if handler() else "" - last_is_chinese_char = ( - re.match(r"[\u4e00-\u9fff]", last_char_of_markdown) - if last_char_of_markdown - else False - ) - first_is_chinese_char = ( - re.match(r"[\u4e00-\u9fff]", first_char_of_handler) - if first_char_of_handler - else False - ) - if not (last_is_chinese_char or first_is_chinese_char): - markdown_content += " " + handler() - else: - markdown_content += handler() + markdown_content += handler() else: markdown_content += ( "\n\n" + handler() if markdown_content else handler() @@ -467,7 +475,7 @@ class LayoutParsingBlock: def __init__(self, label, bbox, content="") -> None: self.label = label - self.region_label = "other" + self.order_label = "other" self.bbox = [int(item) for item in bbox] self.content = content self.seg_start_coordinate = float("inf") @@ -479,39 +487,39 @@ def __init__(self, label, bbox, content="") -> None: self.image = None self.index = None self.visual_index = None - self.direction = self.get_bbox_direction() + self.orientation = self.get_bbox_orientation() self.child_blocks = [] - self.update_direction_info() + self.update_orientation_info() def __str__(self) -> str: return f"{self.__dict__}" def __repr__(self) -> str: - _str = f"\n\n#################\nlabel:\t{self.label}\nregion_label:\t{self.region_label}\nbbox:\t{self.bbox}\ncontent:\t{self.content}\n#################" + _str = f"\n\n#################\nlabel:\t{self.label}\nregion_label:\t{self.order_label}\nbbox:\t{self.bbox}\ncontent:\t{self.content}\n#################" return _str def to_dict(self) -> dict: return self.__dict__ - def update_direction_info(self) -> None: - if self.region_label == "vision": - self.direction = "horizontal" - if self.direction == "horizontal": - self.secondary_direction = "vertical" + def update_orientation_info(self) -> None: + if self.order_label == "vision": + self.orientation = "horizontal" + if self.orientation == "horizontal": + self.secondary_orientation = "vertical" self.short_side_length = self.height self.long_side_length = self.width self.start_coordinate = self.bbox[0] self.end_coordinate = self.bbox[2] - self.secondary_direction_start_coordinate = self.bbox[1] - self.secondary_direction_end_coordinate = self.bbox[3] + self.secondary_orientation_start_coordinate = self.bbox[1] + self.secondary_orientation_end_coordinate = self.bbox[3] else: - self.secondary_direction = "horizontal" + self.secondary_orientation = "horizontal" self.short_side_length = self.width self.long_side_length = self.height self.start_coordinate = self.bbox[1] self.end_coordinate = self.bbox[3] - self.secondary_direction_start_coordinate = self.bbox[0] - self.secondary_direction_end_coordinate = self.bbox[2] + self.secondary_orientation_start_coordinate = self.bbox[0] + self.secondary_orientation_end_coordinate = self.bbox[2] def append_child_block(self, child_block: LayoutParsingBlock) -> None: if not self.child_blocks: @@ -525,7 +533,7 @@ def append_child_block(self, child_block: LayoutParsingBlock) -> None: max(y2, y2_child), ) self.bbox = union_bbox - self.update_direction_info() + self.update_orientation_info() child_blocks = [child_block] if child_block.child_blocks: child_blocks.extend(child_block.get_child_blocks()) @@ -542,7 +550,7 @@ def get_centroid(self) -> tuple: centroid = ((x1 + x2) / 2, (y1 + y2) / 2) return centroid - def get_bbox_direction(self, orientation_ratio: float = 1.0) -> bool: + def get_bbox_orientation(self, orientation_ratio: float = 1.0) -> bool: """ Determine if a bounding box is horizontal or vertical. @@ -558,3 +566,91 @@ def get_bbox_direction(self, orientation_ratio: float = 1.0) -> bool: if self.width * orientation_ratio >= self.height else "vertical" ) + + +class LayoutParsingRegion: + + def __init__( + self, region_bbox, blocks: List[LayoutParsingBlock] = [], block_label_mapping={} + ) -> None: + self.region_bbox = region_bbox + self.blocks = blocks + self.block_map = {} + self.update_config(block_label_mapping) + self.orientation = None + self.calculate_bbox_metrics() + + def update_config(self, block_label_mapping): + self.block_map = {} + self.config = copy.deepcopy(block_label_mapping) + self.config["region_bbox"] = self.region_bbox + horizontal_text_block_num = 0 + for idx, block in enumerate(self.blocks): + label = block.label + if ( + block.order_label not in ["vision", "vision_title"] + and block.orientation == "horizontal" + ): + horizontal_text_block_num += 1 + self.block_map[idx] = block + self.update_layout_order_config_block_index(label, idx) + text_block_num = ( + len(self.blocks) + - len(self.config.get("vision_block_idxes", [])) + - len(self.config.get("vision_title_block_idxes", [])) + ) + self.orientation = ( + "horizontal" + if horizontal_text_block_num >= text_block_num * 0.5 + else "vertical" + ) + self.config["region_orientation"] = self.orientation + + def calculate_bbox_metrics(self): + x1, y1, x2, y2 = self.region_bbox + x_center, y_center = (x1 + x2) / 2, (y1 + y2) / 2 + self.euclidean_distance = math.sqrt(((x1) ** 2 + (y1) ** 2)) + self.center_euclidean_distance = math.sqrt(((x_center) ** 2 + (y_center) ** 2)) + self.angle_rad = math.atan2(y_center, x_center) + + def sort(self): + from .xycut_enhanced import xycut_enhanced + + return xycut_enhanced(self.blocks, self.config) + + def update_layout_order_config_block_index( + self, block_label: str, block_idx: int + ) -> None: + doc_title_labels = self.config["doc_title_labels"] + paragraph_title_labels = self.config["paragraph_title_labels"] + vision_labels = self.config["vision_labels"] + vision_title_labels = self.config["vision_title_labels"] + header_labels = self.config["header_labels"] + unordered_labels = self.config["unordered_labels"] + footer_labels = self.config["footer_labels"] + text_labels = self.config["text_labels"] + self.config.setdefault("doc_title_block_idxes", []) + self.config.setdefault("paragraph_title_block_idxes", []) + self.config.setdefault("vision_block_idxes", []) + self.config.setdefault("vision_title_block_idxes", []) + self.config.setdefault("unordered_block_idxes", []) + self.config.setdefault("text_block_idxes", []) + self.config.setdefault("header_block_idxes", []) + self.config.setdefault("footer_block_idxes", []) + + if block_label in doc_title_labels: + self.config["doc_title_block_idxes"].append(block_idx) + if block_label in paragraph_title_labels: + self.config["paragraph_title_block_idxes"].append(block_idx) + if block_label in vision_labels: + self.config["vision_block_idxes"].append(block_idx) + if block_label in vision_title_labels: + self.config["vision_title_block_idxes"].append(block_idx) + if block_label in unordered_labels: + self.config["unordered_block_idxes"].append(block_idx) + if block_label in text_labels: + self.config["text_block_idxes"].append(block_idx) + if block_label in header_labels: + self.config["header_block_idxes"].append(block_idx) + if block_label in footer_labels: + self.config["footer_block_idxes"].append(block_idx) diff --git a/paddlex/inference/pipelines/layout_parsing/setting.py b/paddlex/inference/pipelines/layout_parsing/setting.py index 97ba6ec1c2..7affbf3096 100644 --- a/paddlex/inference/pipelines/layout_parsing/setting.py +++ b/paddlex/inference/pipelines/layout_parsing/setting.py @@ -12,18 +12,44 @@ # See the License for the specific language governing permissions and # limitations under the License. -layout_order_config = { - # 人工配置项 - "line_height_iou_threshold": 0.4, # For line segmentation of OCR results - "title_area_max_block_threshold": 0.3, # update paragraph_title -> doc_title - "block_label_match_iou_threshold": 0.1, - "block_title_match_iou_threshold": 0.1, +parameters_config = { + "page": {}, + "region": { + "match_block_overlap_ratio_threshold": 0.8, + "split_block_overlap_ratio_threshold": 0.4, + }, + "block": { + "title_conversion_area_ratio_threshold": 0.3, # update paragraph_title -> doc_title + }, + "line": { + "line_height_iou_threshold": 0.6, # For line segmentation of OCR results + "delimiter_map": { + "doc_title": " ", + "content": "\n", + }, + }, + "word": { + "delimiter": " ", + }, + "order": { + "block_label_match_iou_threshold": 0.1, + "block_title_match_iou_threshold": 0.1, + }, +} + +block_label_mapping = { "doc_title_labels": ["doc_title"], # 文档标题 - "paragraph_title_labels": ["paragraph_title"], # 段落标题 + "paragraph_title_labels": [ + "paragraph_title", + "abstract_title", + "reference_title", + "content_title", + ], # 段落标题 "vision_labels": [ "image", "table", "chart", + "flowchart", "figure", ], # 图、表、印章、图表、图 "vision_title_labels": ["table_title", "chart_title", "figure_title"], # 图表标题 @@ -52,19 +78,9 @@ "table", "chart", "figure", + "abstract_title", + "refer_title", + "content_title", + "flowchart", ], - # 自动补全配置项 - "layout_to_ocr_mapping": {}, - "all_layout_region_box": [], # 区域box - "doc_title_block_idxes": [], - "paragraph_title_block_idxes": [], - "text_title_labels": [], # doc_title_labels+paragraph_title_labels - "text_title_block_idxes": [], - "vision_block_idxes": [], - "vision_title_block_idxes": [], - "vision_footnote_block_idxes": [], - "text_block_idxes": [], - "header_block_idxes": [], - "footer_block_idxes": [], - "unordered_block_idxes": [], } diff --git a/paddlex/inference/pipelines/layout_parsing/utils.py b/paddlex/inference/pipelines/layout_parsing/utils.py index 5f90e9d3aa..961c2402bf 100644 --- a/paddlex/inference/pipelines/layout_parsing/utils.py +++ b/paddlex/inference/pipelines/layout_parsing/utils.py @@ -16,7 +16,6 @@ "get_sub_regions_ocr_res", "get_show_color", "sorted_layout_boxes", - "update_layout_order_config_block_index", ] import re @@ -28,7 +27,6 @@ from ..components import convert_points_to_boxes from ..ocr.result import OCRResult -from .xycut_enhanced import calculate_projection_iou def get_overlap_boxes_idx(src_boxes: np.ndarray, ref_boxes: np.ndarray) -> List: @@ -172,64 +170,129 @@ def sorted_layout_boxes(res, w): return new_res -def _calculate_overlap_area_div_minbox_area_ratio( - bbox1: Union[list, tuple], - bbox2: Union[list, tuple], +def calculate_projection_overlap_ratio( + bbox1: List[float], + bbox2: List[float], + orientation: str = "horizontal", + mode="union", ) -> float: """ - Calculate the ratio of the overlap area between bbox1 and bbox2 - to the area of the smaller bounding box. + Calculate the IoU of lines between two bounding boxes. Args: - bbox1 (list or tuple): Coordinates of the first bounding box [x_min, y_min, x_max, y_max]. - bbox2 (list or tuple): Coordinates of the second bounding box [x_min, y_min, x_max, y_max]. + bbox1 (List[float]): First bounding box [x_min, y_min, x_max, y_max]. + bbox2 (List[float]): Second bounding box [x_min, y_min, x_max, y_max]. + orientation (str): orientation of the projection, "horizontal" or "vertical". Returns: - float: The ratio of the overlap area to the area of the smaller bounding box. + float: Line overlap ratio. Returns 0 if there is no overlap. """ - bbox1 = list(map(int, bbox1)) - bbox2 = list(map(int, bbox2)) + start_index, end_index = 1, 3 + if orientation == "horizontal": + start_index, end_index = 0, 2 + + intersection_start = max(bbox1[start_index], bbox2[start_index]) + intersection_end = min(bbox1[end_index], bbox2[end_index]) + overlap = intersection_end - intersection_start + if overlap <= 0: + return 0 + + if mode == "union": + ref_width = max(bbox1[end_index], bbox2[end_index]) - min( + bbox1[start_index], bbox2[start_index] + ) + elif mode == "small": + ref_width = min( + bbox1[end_index] - bbox1[start_index], bbox2[end_index] - bbox2[start_index] + ) + elif mode == "large": + ref_width = max( + bbox1[end_index] - bbox1[start_index], bbox2[end_index] - bbox2[start_index] + ) + else: + raise ValueError( + f"Invalid mode {mode}, must be one of ['union', 'small', 'large']." + ) - x_left = max(bbox1[0], bbox2[0]) - y_top = max(bbox1[1], bbox2[1]) - x_right = min(bbox1[2], bbox2[2]) - y_bottom = min(bbox1[3], bbox2[3]) + return overlap / ref_width if ref_width > 0 else 0.0 - if x_right <= x_left or y_bottom <= y_top: - return 0.0 - intersection_area = (x_right - x_left) * (y_bottom - y_top) - area_bbox1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1]) - area_bbox2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1]) - min_box_area = min(area_bbox1, area_bbox2) +def calculate_overlap_ratio( + bbox1: Union[list, tuple], bbox2: Union[list, tuple], mode="union" +) -> float: + """ + Calculate the overlap ratio between two bounding boxes. + + Args: + bbox1 (list or tuple): The first bounding box, format [x_min, y_min, x_max, y_max] + bbox2 (list or tuple): The second bounding box, format [x_min, y_min, x_max, y_max] + mode (str): The mode of calculation, either 'union', 'small', or 'large'. + + Returns: + float: The overlap ratio value between the two bounding boxes + """ + x_min_inter = max(bbox1[0], bbox2[0]) + y_min_inter = max(bbox1[1], bbox2[1]) + x_max_inter = min(bbox1[2], bbox2[2]) + y_max_inter = min(bbox1[3], bbox2[3]) + + inter_width = max(0, x_max_inter - x_min_inter) + inter_height = max(0, y_max_inter - y_min_inter) - if min_box_area <= 0: + inter_area = inter_width * inter_height + + bbox1_area = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1]) + bbox2_area = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1]) + + if mode == "union": + ref_area = bbox1_area + bbox2_area - inter_area + elif mode == "small": + ref_area = min(bbox1_area, bbox2_area) + elif mode == "large": + ref_area = max(bbox1_area, bbox2_area) + else: + raise ValueError( + f"Invalid mode {mode}, must be one of ['union', 'small', 'large']." + ) + + if ref_area == 0: return 0.0 - return intersection_area / min_box_area + return inter_area / ref_area -def group_boxes_into_lines(ocr_rec_res, block_info, line_height_iou_threshold): +def group_boxes_into_lines(ocr_rec_res, line_height_iou_threshold): rec_boxes = ocr_rec_res["boxes"] rec_texts = ocr_rec_res["rec_texts"] rec_labels = ocr_rec_res["rec_labels"] - spans = list(zip(rec_boxes, rec_texts, rec_labels)) + text_boxes = [ + rec_boxes[i] for i in range(len(rec_boxes)) if rec_labels[i] == "text" + ] + text_orientation = calculate_text_orientation(text_boxes) + + match_orientation = "vertical" if text_orientation == "horizontal" else "horizontal" - spans.sort(key=lambda span: span[0][1]) + spans = list(zip(rec_boxes, rec_texts, rec_labels)) + sort_index = 1 + reverse = False + if text_orientation == "vertical": + sort_index = 0 + reverse = True + spans.sort(key=lambda span: span[0][sort_index], reverse=reverse) spans = [list(span) for span in spans] lines = [] line = [spans[0]] line_region_box = spans[0][0][:] - block_info.seg_start_coordinate = spans[0][0][0] - block_info.seg_end_coordinate = spans[-1][0][2] # merge line for span in spans[1:]: rec_bbox = span[0] if ( - calculate_projection_iou(line_region_box, rec_bbox, "vertical") + calculate_projection_overlap_ratio( + line_region_box, rec_bbox, match_orientation, mode="small" + ) >= line_height_iou_threshold ): line.append(span) @@ -241,7 +304,33 @@ def group_boxes_into_lines(ocr_rec_res, block_info, line_height_iou_threshold): line_region_box = rec_bbox[:] lines.append(line) - return lines + return lines, text_orientation + + +def calculate_minimum_enclosing_bbox(bboxes): + """ + Calculate the minimum enclosing bounding box for a list of bounding boxes. + + Args: + bboxes (list): A list of bounding boxes represented as lists of four integers [x1, y1, x2, y2]. + + Returns: + list: The minimum enclosing bounding box represented as a list of four integers [x1, y1, x2, y2]. + """ + if not bboxes: + raise ValueError("The list of bounding boxes is empty.") + + # Convert the list of bounding boxes to a NumPy array + bboxes_array = np.array(bboxes) + + # Compute the minimum and maximum values along the respective axes + min_x = np.min(bboxes_array[:, 0]) + min_y = np.min(bboxes_array[:, 1]) + max_x = np.max(bboxes_array[:, 2]) + max_y = np.max(bboxes_array[:, 3]) + + # Return the minimum enclosing bounding box + return [min_x, min_y, max_x, max_y] def calculate_text_orientation( @@ -258,24 +347,30 @@ def calculate_text_orientation( str: "horizontal" or "vertical". """ - bboxes = np.array(bboxes) - x_min = np.min(bboxes[:, 0]) - x_max = np.max(bboxes[:, 2]) - width = x_max - x_min - y_min = np.min(bboxes[:, 1]) - y_max = np.max(bboxes[:, 3]) - height = y_max - y_min - return "horizontal" if width * orientation_ratio >= height else "vertical" + horizontal_box_num = 0 + for bbox in bboxes: + if len(bbox) != 4: + raise ValueError( + "Invalid bounding box format. Expected a list of length 4." + ) + x1, y1, x2, y2 = bbox + width = x2 - x1 + height = y2 - y1 + horizontal_box_num += 1 if width * orientation_ratio >= height else 0 + + return "horizontal" if horizontal_box_num >= len(bboxes) * 0.5 else "vertical" + + +def is_english_letter(char): + return bool(re.match(r"^[A-Za-z]$", char)) def format_line( line: List[List[Union[List[int], str]]], - block_left_coordinate: int, block_right_coordinate: int, - first_line_span_limit: int = 10, last_line_span_limit: int = 10, block_label: str = "text", - delimiter_map: Dict = {}, + # delimiter_map: Dict = {}, ) -> None: """ Format a line of text spans based on layout constraints. @@ -290,92 +385,108 @@ def format_line( Returns: None: The function modifies the line in place. """ - first_span = line[0] - last_span = line[-1] - - if first_span[0][0] - block_left_coordinate > first_line_span_limit: - first_span[1] = "\n" + first_span[1] - if block_right_coordinate - last_span[0][2] > last_line_span_limit: - last_span[1] = last_span[1] + "\n" + last_span_box = line[-1][0] - line[0] = first_span - line[-1] = last_span + for span in line: + if span[2] == "formula" and block_label != "formula": + if len(line) > 1: + span[1] = f"${span[1]}$" + else: + span[1] = f"\n${span[1]}$" - delim = delimiter_map.get(block_label, " ") - line_text = delim.join([span[1] for span in line]) + line_text = " ".join([span[1] for span in line]) - if block_label != "reference": - line_text = remove_extra_space(line_text) + need_new_line = False + if ( + block_right_coordinate - last_span_box[2] > last_line_span_limit + and not line_text.endswith("-") + and len(line_text) > 0 + and not is_english_letter(line_text[-1]) + ): + need_new_line = True if line_text.endswith("-"): line_text = line_text[:-1] - return line_text + elif ( + len(line_text) > 0 and is_english_letter(line_text[-1]) + ) or line_text.endswith("$"): + line_text += " " + + return line_text, need_new_line -def split_boxes_if_x_contained(boxes, offset=1e-5): +def split_boxes_by_projection(spans: List[List[int]], orientation, offset=1e-5): """ - Check if there is any complete containment in the x-direction + Check if there is any complete containment in the x-orientation between the bounding boxes and split the containing box accordingly. Args: - boxes (list of lists): Each element is a list containing an ndarray of length 4, a description, and a label. + spans (list of lists): Each element is a list containing an ndarray of length 4, a text string, and a label. + orientation: 'horizontal' or 'vertical', indicating whether the spans are arranged horizontally or vertically. offset (float): A small offset value to ensure that the split boxes are not too close to the original boxes. Returns: A new list of boxes, including split boxes, with the same `rec_text` and `label` attributes. """ - def is_x_contained(box_a, box_b): - """Check if box_a completely contains box_b in the x-direction.""" - return box_a[0][0] <= box_b[0][0] and box_a[0][2] >= box_b[0][2] + def is_projection_contained(box_a, box_b, start_idx, end_idx): + """Check if box_a completely contains box_b in the x-orientation.""" + return box_a[start_idx] <= box_b[start_idx] and box_a[end_idx] >= box_b[end_idx] new_boxes = [] + if orientation == "horizontal": + projection_start_index, projection_end_index = 0, 2 + else: + projection_start_index, projection_end_index = 1, 3 - for i in range(len(boxes)): - box_a = boxes[i] + for i in range(len(spans)): + span = spans[i] + box_a, text, label = span is_split = False - for j in range(len(boxes)): + for j in range(len(spans)): if i == j: continue - box_b = boxes[j] - if is_x_contained(box_a, box_b): + box_b = spans[j][0] + if is_projection_contained( + box_a, box_b, projection_start_index, projection_end_index + ): is_split = True # Split box_a based on the x-coordinates of box_b - if box_a[0][0] < box_b[0][0]: - w = box_b[0][0] - offset - box_a[0][0] + if box_a[projection_start_index] < box_b[projection_start_index]: + w = ( + box_b[projection_start_index] + - offset + - box_a[projection_start_index] + ) if w > 1: + box_a[projection_end_index] = ( + box_b[projection_start_index] - offset + ) new_boxes.append( [ - np.array( - [ - box_a[0][0], - box_a[0][1], - box_b[0][0] - offset, - box_a[0][3], - ] - ), - box_a[1], - box_a[2], + np.array(box_a), + text, + label, ] ) - if box_a[0][2] > box_b[0][2]: - w = box_a[0][2] - box_b[0][2] + offset + if box_a[projection_end_index] > box_b[projection_end_index]: + w = ( + box_a[projection_end_index] + - box_b[projection_end_index] + + offset + ) if w > 1: - box_a = [ - np.array( - [ - box_b[0][2] + offset, - box_a[0][1], - box_a[0][2], - box_a[0][3], - ] - ), - box_a[1], - box_a[2], + box_a[projection_start_index] = ( + box_b[projection_end_index] + offset + ) + span = [ + np.array(box_a), + text, + label, ] - if j == len(boxes) - 1 and is_split: - new_boxes.append(box_a) + if j == len(spans) - 1 and is_split: + new_boxes.append(span) if not is_split: - new_boxes.append(box_a) + new_boxes.append(span) return new_boxes @@ -454,7 +565,7 @@ def _get_minbox_if_overlap_by_ratio( area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1]) area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1]) # Calculate the overlap ratio using a helper function - overlap_ratio = _calculate_overlap_area_div_minbox_area_ratio(bbox1, bbox2) + overlap_ratio = calculate_overlap_ratio(bbox1, bbox2, mode="small") # Check if the overlap ratio exceeds the threshold if overlap_ratio > ratio: if (area1 <= area2 and smaller) or (area1 >= area2 and not smaller): @@ -496,8 +607,17 @@ def remove_overlap_blocks( smaller=smaller, ) if overlap_box_index is not None: - # Determine which block to remove based on overlap_box_index - if overlap_box_index == 1: + if block1["label"] == "image" and block2["label"] == "image": + # Determine which block to remove based on overlap_box_index + if overlap_box_index == 1: + drop_index = i + else: + drop_index = j + elif block1["label"] == "image" and block2["label"] != "image": + drop_index = i + elif block1["label"] != "image" and block2["label"] == "image": + drop_index = j + elif overlap_box_index == 1: drop_index = i else: drop_index = j @@ -556,39 +676,99 @@ def get_bbox_intersection(bbox1, bbox2, return_format="bbox"): raise ValueError("return_format must be either 'bbox' or 'poly'.") -def update_layout_order_config_block_index( - config: dict, block_label: str, block_idx: int -) -> None: +def shrink_supplement_region_bbox( + supplement_region_bbox, + ref_region_bbox, + image_width, + image_height, + block_idxes_set, + block_bboxes, + parameters_config, +) -> List: + """ + Shrink the supplement region bbox according to the reference region bbox and match the block bboxes. - doc_title_labels = config["doc_title_labels"] - paragraph_title_labels = config["paragraph_title_labels"] - vision_labels = config["vision_labels"] - vision_title_labels = config["vision_title_labels"] - header_labels = config["header_labels"] - unordered_labels = config["unordered_labels"] - footer_labels = config["footer_labels"] - text_labels = config["text_labels"] - text_title_labels = doc_title_labels + paragraph_title_labels - config["text_title_labels"] = text_title_labels - - if block_label in doc_title_labels: - config["doc_title_block_idxes"].append(block_idx) - if block_label in paragraph_title_labels: - config["paragraph_title_block_idxes"].append(block_idx) - if block_label in vision_labels: - config["vision_block_idxes"].append(block_idx) - if block_label in vision_title_labels: - config["vision_title_block_idxes"].append(block_idx) - if block_label in unordered_labels: - config["unordered_block_idxes"].append(block_idx) - if block_label in text_title_labels: - config["text_title_block_idxes"].append(block_idx) - if block_label in text_labels: - config["text_block_idxes"].append(block_idx) - if block_label in header_labels: - config["header_block_idxes"].append(block_idx) - if block_label in footer_labels: - config["footer_block_idxes"].append(block_idx) + Args: + supplement_region_bbox (list): The supplement region bbox. + ref_region_bbox (list): The reference region bbox. + image_width (int): The width of the image. + image_height (int): The height of the image. + block_idxes_set (set): The indexes of the blocks that intersect with the region bbox. + block_bboxes (dict): The dictionary of block bboxes. + parameters_config (dict): The configuration parameters. + + Returns: + list: The new region bbox and the matched block idxes. + """ + x1, y1, x2, y2 = supplement_region_bbox + x1_prime, y1_prime, x2_prime, y2_prime = ref_region_bbox + index_conversion_map = {0: 2, 1: 3, 2: 0, 3: 1} + edge_distance_list = [ + (x1_prime - x1) / image_width, + (y1_prime - y1) / image_height, + (x2 - x2_prime) / image_width, + (y2 - y2_prime) / image_height, + ] + edge_distance_list_tmp = edge_distance_list[:] + min_distance = min(edge_distance_list) + src_index = index_conversion_map[edge_distance_list.index(min_distance)] + if len(block_idxes_set) == 0: + return supplement_region_bbox, [] + for _ in range(3): + dst_index = index_conversion_map[src_index] + tmp_region_bbox = supplement_region_bbox[:] + tmp_region_bbox[dst_index] = ref_region_bbox[src_index] + iner_block_idxes, split_block_idxes = [], [] + for block_idx in block_idxes_set: + overlap_ratio = calculate_overlap_ratio( + tmp_region_bbox, block_bboxes[block_idx], mode="small" + ) + if overlap_ratio > parameters_config["region"].get( + "match_block_overlap_ratio_threshold", 0.8 + ): + iner_block_idxes.append(block_idx) + elif overlap_ratio > parameters_config["region"].get( + "split_block_overlap_ratio_threshold", 0.4 + ): + split_block_idxes.append(block_idx) + + if len(iner_block_idxes) > 0: + if len(split_block_idxes) > 0: + for split_block_idx in split_block_idxes: + split_block_bbox = block_bboxes[split_block_idx] + x1, y1, x2, y2 = tmp_region_bbox + x1_prime, y1_prime, x2_prime, y2_prime = split_block_bbox + edge_distance_list = [ + (x1_prime - x1) / image_width, + (y1_prime - y1) / image_height, + (x2 - x2_prime) / image_width, + (y2 - y2_prime) / image_height, + ] + max_distance = max(edge_distance_list) + src_index = edge_distance_list.index(max_distance) + dst_index = index_conversion_map[src_index] + tmp_region_bbox[dst_index] = split_block_bbox[src_index] + tmp_region_bbox, iner_idxes = shrink_supplement_region_bbox( + tmp_region_bbox, + ref_region_bbox, + image_width, + image_height, + iner_block_idxes, + block_bboxes, + parameters_config, + ) + if len(iner_idxes) == 0: + continue + matched_bboxes = [block_bboxes[idx] for idx in iner_block_idxes] + supplement_region_bbox = calculate_minimum_enclosing_bbox(matched_bboxes) + break + else: + edge_distance_list_tmp = [ + x for x in edge_distance_list_tmp if x != min_distance + ] + min_distance = min(edge_distance_list_tmp) + src_index = index_conversion_map[edge_distance_list.index(min_distance)] + return supplement_region_bbox, iner_block_idxes def update_region_box(bbox, region_box): @@ -618,7 +798,7 @@ def convert_formula_res_to_ocr_format(formula_res_list: List, ocr_res: dict): (x_min, y_max), ] ocr_res["dt_polys"].append(poly_points) - ocr_res["rec_texts"].append(f"${formula_res['rec_formula']}$") + ocr_res["rec_texts"].append(f"{formula_res['rec_formula']}") ocr_res["rec_boxes"] = np.vstack( (ocr_res["rec_boxes"], [formula_res["dt_polys"]]) ) diff --git a/paddlex/inference/pipelines/layout_parsing/xycut_enhanced/utils.py b/paddlex/inference/pipelines/layout_parsing/xycut_enhanced/utils.py index 0f10610809..efec88b062 100644 --- a/paddlex/inference/pipelines/layout_parsing/xycut_enhanced/utils.py +++ b/paddlex/inference/pipelines/layout_parsing/xycut_enhanced/utils.py @@ -12,77 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Tuple, Union +from typing import Dict, List, Tuple import numpy as np from ..result_v2 import LayoutParsingBlock - - -def calculate_projection_iou( - bbox1: List[float], bbox2: List[float], direction: str = "horizontal" -) -> float: - """ - Calculate the IoU of lines between two bounding boxes. - - Args: - bbox1 (List[float]): First bounding box [x_min, y_min, x_max, y_max]. - bbox2 (List[float]): Second bounding box [x_min, y_min, x_max, y_max]. - direction (str): direction of the projection, "horizontal" or "vertical". - - Returns: - float: Line IoU. Returns 0 if there is no overlap. - """ - start_index, end_index = 1, 3 - if direction == "horizontal": - start_index, end_index = 0, 2 - - intersection_start = max(bbox1[start_index], bbox2[start_index]) - intersection_end = min(bbox1[end_index], bbox2[end_index]) - overlap = intersection_end - intersection_start - if overlap <= 0: - return 0 - union_width = max(bbox1[end_index], bbox2[end_index]) - min( - bbox1[start_index], bbox2[start_index] - ) - - return overlap / union_width if union_width > 0 else 0.0 - - -def calculate_iou( - bbox1: Union[list, tuple], - bbox2: Union[list, tuple], -) -> float: - """ - Calculate the Intersection over Union (IoU) of two bounding boxes. - - Parameters: - bbox1 (list or tuple): The first bounding box, format [x_min, y_min, x_max, y_max] - bbox2 (list or tuple): The second bounding box, format [x_min, y_min, x_max, y_max] - - Returns: - float: The IoU value between the two bounding boxes - """ - - x_min_inter = max(bbox1[0], bbox2[0]) - y_min_inter = max(bbox1[1], bbox2[1]) - x_max_inter = min(bbox1[2], bbox2[2]) - y_max_inter = min(bbox1[3], bbox2[3]) - - inter_width = max(0, x_max_inter - x_min_inter) - inter_height = max(0, y_max_inter - y_min_inter) - - inter_area = inter_width * inter_height - - bbox1_area = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1]) - bbox2_area = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1]) - - union_area = bbox1_area + bbox2_area - inter_area - - if union_area == 0: - return 0.0 - - return inter_area / union_area +from ..utils import calculate_projection_overlap_ratio def get_nearest_edge_distance( @@ -91,12 +26,12 @@ def get_nearest_edge_distance( weight: List[float] = [1.0, 1.0, 1.0, 1.0], ) -> Tuple[float]: """ - Calculate the nearest edge distance between two bounding boxes, considering directional weights. + Calculate the nearest edge distance between two bounding boxes, considering orientational weights. Args: bbox1 (list): The bounding box coordinates [x1, y1, x2, y2] of the input object. bbox2 (list): The bounding box coordinates [x1', y1', x2', y2'] of the object to match against. - weight (list, optional): Directional weights for the edge distances [left, right, up, down]. Defaults to [1, 1, 1, 1]. + weight (list, optional): orientational weights for the edge distances [left, right, up, down]. Defaults to [1, 1, 1, 1]. Returns: float: The calculated minimum edge distance between the bounding boxes. @@ -104,8 +39,8 @@ def get_nearest_edge_distance( x1, y1, x2, y2 = bbox1 x1_prime, y1_prime, x2_prime, y2_prime = bbox2 min_x_distance, min_y_distance = 0, 0 - horizontal_iou = calculate_projection_iou(bbox1, bbox2, "horizontal") - vertical_iou = calculate_projection_iou(bbox1, bbox2, "vertical") + horizontal_iou = calculate_projection_overlap_ratio(bbox1, bbox2, "horizontal") + vertical_iou = calculate_projection_overlap_ratio(bbox1, bbox2, "vertical") if horizontal_iou > 0 and vertical_iou > 0: return 0.0 if horizontal_iou == 0: @@ -412,7 +347,7 @@ def weighted_distance_insert( x1_prime, y1_prime, x2_prime, y2_prime = sorted_block.bbox # Calculate edge distance - weight = _get_weights(block.region_label, block.direction) + weight = _get_weights(block.order_label, block.orientation) edge_distance = get_nearest_edge_distance(block.bbox, sorted_block.bbox, weight) if block.label in doc_title_labels: @@ -476,7 +411,7 @@ def insert_child_blocks( if block.child_blocks: sub_blocks = block.get_child_blocks() sub_blocks.append(block) - sub_blocks = sort_child_blocks(sub_blocks, block.direction) + sub_blocks = sort_child_blocks(sub_blocks, block.orientation) sorted_blocks[block_idx] = sub_blocks[0] for block in sub_blocks[1:]: block_idx += 1 @@ -484,17 +419,17 @@ def insert_child_blocks( return sorted_blocks -def sort_child_blocks(blocks, direction="horizontal") -> List[LayoutParsingBlock]: +def sort_child_blocks(blocks, orientation="horizontal") -> List[LayoutParsingBlock]: """ Sort child blocks based on their bounding box coordinates. Args: blocks: A list of LayoutParsingBlock objects representing the child blocks. - direction: Orientation of the blocks ('horizontal' or 'vertical'). Default is 'horizontal'. + orientation: Orientation of the blocks ('horizontal' or 'vertical'). Default is 'horizontal'. Returns: sorted_blocks: A sorted list of LayoutParsingBlock objects. """ - if direction == "horizontal": + if orientation == "horizontal": # from top to bottom blocks.sort( key=lambda x: ( @@ -584,14 +519,14 @@ def sort_blocks(blocks, median_width=None, reverse=False): def get_cut_blocks( - blocks, cut_direction, cut_coordinates, overall_region_box, mask_labels=[] + blocks, cut_orientation, cut_coordinates, overall_region_box, mask_labels=[] ): """ - Cut blocks based on the given cut direction and coordinates. + Cut blocks based on the given cut orientation and coordinates. Args: blocks (list): list of blocks to be cut. - cut_direction (str): cut direction, either "horizontal" or "vertical". + cut_orientation (str): cut orientation, either "horizontal" or "vertical". cut_coordinates (list): list of cut coordinates. overall_region_box (list): the overall region box that contains all blocks. @@ -602,10 +537,9 @@ def get_cut_blocks( # filter out mask blocks,including header, footer, unordered and child_blocks # 0: horizontal, 1: vertical - cut_aixis = 0 if cut_direction == "horizontal" else 1 + cut_aixis = 0 if cut_orientation == "horizontal" else 1 blocks.sort(key=lambda x: x.bbox[cut_aixis + 2]) - overall_max_axis_coordinate = overall_region_box[cut_aixis + 2] - cut_coordinates.append(overall_max_axis_coordinate) + cut_coordinates.append(float("inf")) cut_coordinates = list(set(cut_coordinates)) cut_coordinates.sort() @@ -618,7 +552,7 @@ def get_cut_blocks( block = blocks[block_idx] if block.bbox[cut_aixis + 2] > cut_coordinate: break - elif block.region_label not in mask_labels: + elif block.order_label not in mask_labels: group_blocks.append(block) block_idx += 1 cut_idx = block_idx @@ -628,62 +562,42 @@ def get_cut_blocks( return cuted_list -def split_sub_region_blocks( - blocks: List[LayoutParsingBlock], - config: Dict, -) -> List: - """ - Split blocks into sub regions based on the all layout region bbox. - - Args: - blocks (List[LayoutParsingBlock]): A list of blocks. - config (Dict): Configuration dictionary. - Returns: - List: A list of lists of blocks, each representing a sub region. - """ - - region_bbox = config.get("all_layout_region_box", None) - x1, y1, x2, y2 = region_bbox - region_width = x2 - x1 - region_height = y2 - y1 - - if region_width < region_height: - return [(blocks, region_bbox)] - - all_boxes = np.array([block.bbox for block in blocks]) - discontinuous = calculate_discontinuous_projection(all_boxes, direction="vertical") - if len(discontinuous) > 1: - cut_coordinates = [] - region_boxes = [] - current_interval = discontinuous[0] - for x1, x2 in discontinuous[1:]: - if x1 - current_interval[1] > 100: - cut_coordinates.extend([x1, x2]) - region_boxes.append([x1, y1, x2, y2]) - current_interval = [x1, x2] - region_blocks = get_cut_blocks(blocks, "vertical", cut_coordinates, region_bbox) - - return [region_info for region_info in zip(region_blocks, region_boxes)] - else: - return [(blocks, region_bbox)] - - -def get_adjacent_blocks_by_direction( +def add_split_block( + blocks: List[LayoutParsingBlock], region_bbox: List[int] +) -> List[LayoutParsingBlock]: + block_bboxes = np.array([block.bbox for block in blocks]) + discontinuous = calculate_discontinuous_projection( + block_bboxes, orientation="vertical" + ) + current_interval = discontinuous[0] + for interval in discontinuous[1:]: + gap_len = interval[0] - current_interval[1] + if gap_len > 40: + x1, _, x2, __ = region_bbox + y1 = current_interval[1] + 5 + y2 = interval[0] - 5 + bbox = [x1, y1, x2, y2] + split_block = LayoutParsingBlock(label="split", bbox=bbox) + blocks.append(split_block) + current_interval = interval + + +def get_adjacent_blocks_by_orientation( blocks: List[LayoutParsingBlock], block_idx: int, ref_block_idxes: List[int], iou_threshold, ) -> List: """ - Get the adjacent blocks with the same direction as the current block. + Get the adjacent blocks with the same orientation as the current block. Args: block (LayoutParsingBlock): The current block. blocks (List[LayoutParsingBlock]): A list of all blocks. ref_block_idxes (List[int]): A list of indices of reference blocks. iou_threshold (float): The IOU threshold to determine if two blocks are considered adjacent. Returns: - Int: The index of the previous block with same direction. - Int: The index of the following block with same direction. + Int: The index of the previous block with same orientation. + Int: The index of the following block with same orientation. """ min_prev_block_distance = float("inf") prev_block_index = None @@ -697,21 +611,21 @@ def get_adjacent_blocks_by_direction( "vision_title", ] - # find the nearest text block with same direction to the current block + # find the nearest text block with same orientation to the current block for ref_block_idx in ref_block_idxes: ref_block = blocks[ref_block_idx] - ref_block_direction = ref_block.direction - if ref_block.region_label in child_labels: + ref_block_orientation = ref_block.orientation + if ref_block.order_label in child_labels: continue - match_block_iou = calculate_projection_iou( + match_block_iou = calculate_projection_overlap_ratio( block.bbox, ref_block.bbox, - ref_block_direction, + ref_block_orientation, ) child_match_distance_tolerance_len = block.short_side_length / 10 - if block.region_label == "vision": + if block.order_label == "vision": if ref_block.num_of_lines == 1: gap_tolerance_len = ref_block.short_side_length * 2 else: @@ -721,38 +635,38 @@ def get_adjacent_blocks_by_direction( if match_block_iou >= iou_threshold: prev_distance = ( - block.secondary_direction_start_coordinate - - ref_block.secondary_direction_end_coordinate + block.secondary_orientation_start_coordinate + - ref_block.secondary_orientation_end_coordinate + child_match_distance_tolerance_len ) // 5 + ref_block.start_coordinate / 5000 next_distance = ( - ref_block.secondary_direction_start_coordinate - - block.secondary_direction_end_coordinate + ref_block.secondary_orientation_start_coordinate + - block.secondary_orientation_end_coordinate + child_match_distance_tolerance_len ) // 5 + ref_block.start_coordinate / 5000 if ( - ref_block.secondary_direction_end_coordinate - <= block.secondary_direction_start_coordinate + ref_block.secondary_orientation_end_coordinate + <= block.secondary_orientation_start_coordinate + child_match_distance_tolerance_len and prev_distance < min_prev_block_distance ): min_prev_block_distance = prev_distance if ( - block.secondary_direction_start_coordinate - - ref_block.secondary_direction_end_coordinate + block.secondary_orientation_start_coordinate + - ref_block.secondary_orientation_end_coordinate < gap_tolerance_len ): prev_block_index = ref_block_idx elif ( - ref_block.secondary_direction_start_coordinate - > block.secondary_direction_end_coordinate + ref_block.secondary_orientation_start_coordinate + > block.secondary_orientation_end_coordinate - child_match_distance_tolerance_len and next_distance < min_post_block_distance ): min_post_block_distance = next_distance if ( - ref_block.secondary_direction_start_coordinate - - block.secondary_direction_end_coordinate + ref_block.secondary_orientation_start_coordinate + - block.secondary_orientation_end_coordinate < gap_tolerance_len ): post_block_index = ref_block_idx @@ -781,7 +695,7 @@ def update_doc_title_child_blocks( The child blocks need to meet the following conditions: 1. They must be adjacent - 2. They must have the same direction as the parent block. + 2. They must have the same orientation as the parent block. 3. Their short side length should be less than 80% of the parent's short side length. 4. Their long side length should be less than 150% of the parent's long side length. 5. The child block must be text block. @@ -801,7 +715,7 @@ def update_doc_title_child_blocks( if idx is None: continue ref_block = blocks[idx] - with_seem_direction = ref_block.direction == block.direction + with_seem_orientation = ref_block.orientation == block.orientation short_side_length_condition = ( ref_block.short_side_length < block.short_side_length * 0.8 @@ -813,12 +727,12 @@ def update_doc_title_child_blocks( ) if ( - with_seem_direction + with_seem_orientation and short_side_length_condition and long_side_length_condition and ref_block.num_of_lines < 3 ): - ref_block.region_label = "doc_title_text" + ref_block.order_label = "doc_title_text" block.append_child_block(ref_block) config["text_block_idxes"].remove(idx) @@ -835,7 +749,7 @@ def update_paragraph_title_child_blocks( The child blocks need to meet the following conditions: 1. They must be adjacent - 2. They must have the same direction as the parent block. + 2. They must have the same orientation as the parent block. 3. The child block must be paragraph title block. Args: @@ -854,9 +768,15 @@ def update_paragraph_title_child_blocks( if idx is None: continue ref_block = blocks[idx] - with_seem_direction = ref_block.direction == block.direction - if with_seem_direction and ref_block.label in paragraph_title_labels: - ref_block.region_label = "sub_paragraph_title" + min_height = min(block.height, ref_block.height) + nearest_edge_distance = get_nearest_edge_distance(block.bbox, ref_block.bbox) + with_seem_orientation = ref_block.orientation == block.orientation + if ( + with_seem_orientation + and ref_block.label in paragraph_title_labels + and nearest_edge_distance <= min_height * 2 + ): + ref_block.order_label = "sub_paragraph_title" block.append_child_block(ref_block) config["paragraph_title_block_idxes"].remove(idx) @@ -908,14 +828,14 @@ def update_vision_child_blocks( if ref_block.label in vision_title_labels and nearest_edge_distance <= min( block.height * 0.5, ref_block.height * 2 ): - ref_block.region_label = "vision_title" + ref_block.order_label = "vision_title" block.append_child_block(ref_block) config["vision_title_block_idxes"].remove(idx) elif ( nearest_edge_distance <= 15 and ref_block.short_side_length < block.short_side_length and ref_block.long_side_length < 0.5 * block.long_side_length - and ref_block.direction == block.direction + and ref_block.orientation == block.orientation and ( abs(block_center[0] - ref_block_center[0]) < 10 or ( @@ -934,54 +854,65 @@ def update_vision_child_blocks( if child_block.label in text_labels: has_vision_footnote = True if not has_vision_footnote: - ref_block.region_label = "vision_footnote" + ref_block.order_label = "vision_footnote" block.append_child_block(ref_block) config["text_block_idxes"].remove(idx) -def calculate_discontinuous_projection(boxes, direction="horizontal") -> List: +def calculate_discontinuous_projection( + boxes, orientation="horizontal", return_num=False +) -> List: """ - Calculate the discontinuous projection of boxes along the specified direction. + Calculate the discontinuous projection of boxes along the specified orientation. Args: boxes (ndarray): Array of bounding boxes represented by [[x_min, y_min, x_max, y_max]]. - direction (str): Direction along which to perform the projection ('horizontal' or 'vertical'). + orientation (str): orientation along which to perform the projection ('horizontal' or 'vertical'). Returns: list: List of tuples representing the merged intervals. """ - if direction == "horizontal": + boxes = np.array(boxes) + if orientation == "horizontal": intervals = boxes[:, [0, 2]] - elif direction == "vertical": + elif orientation == "vertical": intervals = boxes[:, [1, 3]] else: - raise ValueError("Direction must be 'horizontal' or 'vertical'") + raise ValueError("orientation must be 'horizontal' or 'vertical'") intervals = intervals[np.argsort(intervals[:, 0])] merged_intervals = [] + num = 1 current_start, current_end = intervals[0] + num_list = [] for start, end in intervals[1:]: if start <= current_end: + num += 1 current_end = max(current_end, end) else: + num_list.append(num) merged_intervals.append((current_start, current_end)) + num = 1 current_start, current_end = start, end + num_list.append(num) merged_intervals.append((current_start, current_end)) + if return_num: + return merged_intervals, num_list return merged_intervals def shrink_overlapping_boxes( - boxes, direction="horizontal", min_threshold=0, max_threshold=0.1 + boxes, orientation="horizontal", min_threshold=0, max_threshold=0.1 ) -> List: """ - Shrink overlapping boxes along the specified direction. + Shrink overlapping boxes along the specified orientation. Args: boxes (ndarray): Array of bounding boxes represented by [[x_min, y_min, x_max, y_max]]. - direction (str): Direction along which to perform the shrinking ('horizontal' or 'vertical'). + orientation (str): orientation along which to perform the shrinking ('horizontal' or 'vertical'). min_threshold (float): Minimum threshold for shrinking. Default is 0. max_threshold (float): Maximum threshold for shrinking. Default is 0.2. @@ -992,15 +923,15 @@ def shrink_overlapping_boxes( for block in boxes[1:]: x1, y1, x2, y2 = current_block.bbox x1_prime, y1_prime, x2_prime, y2_prime = block.bbox - cut_iou = calculate_projection_iou( - current_block.bbox, block.bbox, direction=direction + cut_iou = calculate_projection_overlap_ratio( + current_block.bbox, block.bbox, orientation=orientation ) - match_iou = calculate_projection_iou( + match_iou = calculate_projection_overlap_ratio( current_block.bbox, block.bbox, - direction="horizontal" if direction == "vertical" else "vertical", + orientation="horizontal" if orientation == "vertical" else "vertical", ) - if direction == "vertical": + if orientation == "vertical": if ( (match_iou > 0 and cut_iou > min_threshold and cut_iou < max_threshold) or y2 == y1_prime diff --git a/paddlex/inference/pipelines/layout_parsing/xycut_enhanced/xycuts.py b/paddlex/inference/pipelines/layout_parsing/xycut_enhanced/xycuts.py index 1a5ffdd75e..087c733b8b 100644 --- a/paddlex/inference/pipelines/layout_parsing/xycut_enhanced/xycuts.py +++ b/paddlex/inference/pipelines/layout_parsing/xycut_enhanced/xycuts.py @@ -17,12 +17,12 @@ import numpy as np from ..result_v2 import LayoutParsingBlock +from ..utils import calculate_overlap_ratio, calculate_projection_overlap_ratio from .utils import ( calculate_discontinuous_projection, - calculate_iou, - calculate_projection_iou, - get_adjacent_blocks_by_direction, + get_adjacent_blocks_by_orientation, get_cut_blocks, + get_nearest_edge_distance, insert_child_blocks, manhattan_insert, recursive_xy_cut, @@ -55,26 +55,26 @@ def pre_process( Returns: List: A list of pre-cutted layout blocks list. """ - region_bbox = config.get("all_layout_region_box", None) + region_bbox = config.get("region_bbox", None) region_x_center = (region_bbox[0] + region_bbox[2]) / 2 region_y_center = (region_bbox[1] + region_bbox[3]) / 2 header_block_idxes = config.get("header_block_idxes", []) header_blocks = [] for idx in header_block_idxes: - blocks[idx].region_label = "header" + blocks[idx].order_label = "header" header_blocks.append(blocks[idx]) unordered_block_idxes = config.get("unordered_block_idxes", []) unordered_blocks = [] for idx in unordered_block_idxes: - blocks[idx].region_label = "unordered" + blocks[idx].order_label = "unordered" unordered_blocks.append(blocks[idx]) footer_block_idxes = config.get("footer_block_idxes", []) footer_blocks = [] for idx in footer_block_idxes: - blocks[idx].region_label = "footer" + blocks[idx].order_label = "footer" footer_blocks.append(blocks[idx]) mask_labels = ["header", "unordered", "footer"] @@ -89,11 +89,11 @@ def pre_process( if block.label in mask_labels: continue - if block.region_label not in child_labels: + if block.order_label not in child_labels: update_region_label(blocks, config, block_idx) - block_direction = block.direction - if block_direction == "horizontal": + block_orientation = block.orientation + if block_orientation == "horizontal": region_bbox_center = region_x_center tolerance_len = block.long_side_length // 5 else: @@ -108,51 +108,80 @@ def pre_process( pre_cut_block_idxes.append(block_idx) pre_cut_list = [] - cut_direction = "vertical" + cut_orientation = "vertical" cut_coordinates = [] discontinuous = [] mask_labels = child_labels + mask_labels all_boxes = np.array( - [block.bbox for block in blocks if block.region_label not in mask_labels] + [block.bbox for block in blocks if block.order_label not in mask_labels] ) + if len(all_boxes) == 0: + return header_blocks, pre_cut_list, footer_blocks, unordered_blocks if pre_cut_block_idxes: horizontal_cut_num = 0 for block_idx in pre_cut_block_idxes: block = blocks[block_idx] - horizontal_cut_num += 1 if block.secondary_direction == "horizontal" else 0 - cut_direction = ( + horizontal_cut_num += ( + 1 if block.secondary_orientation == "horizontal" else 0 + ) + cut_orientation = ( "horizontal" if horizontal_cut_num > len(pre_cut_block_idxes) * 0.5 else "vertical" ) - discontinuous = calculate_discontinuous_projection( - all_boxes, direction=cut_direction + discontinuous, num_list = calculate_discontinuous_projection( + all_boxes, orientation=cut_orientation, return_num=True ) for idx in pre_cut_block_idxes: block = blocks[idx] if ( - block.region_label not in mask_labels - and block.secondary_direction == cut_direction + block.order_label not in mask_labels + and block.secondary_orientation == cut_orientation ): if ( - block.secondary_direction_start_coordinate, - block.secondary_direction_end_coordinate, + block.secondary_orientation_start_coordinate, + block.secondary_orientation_end_coordinate, ) in discontinuous: - cut_coordinates.append(block.secondary_direction_start_coordinate) - cut_coordinates.append(block.secondary_direction_end_coordinate) + idx = discontinuous.index( + ( + block.secondary_orientation_start_coordinate, + block.secondary_orientation_end_coordinate, + ) + ) + if num_list[idx] == 1: + cut_coordinates.append( + block.secondary_orientation_start_coordinate + ) + cut_coordinates.append( + block.secondary_orientation_end_coordinate + ) if not discontinuous: discontinuous = calculate_discontinuous_projection( - all_boxes, direction=cut_direction + all_boxes, orientation=cut_orientation ) current_interval = discontinuous[0] for interval in discontinuous[1:]: gap_len = interval[0] - current_interval[1] - if gap_len > 40: + if gap_len >= 60: cut_coordinates.append(current_interval[1]) + elif gap_len > 40: + x1, _, x2, __ = region_bbox + y1 = current_interval[1] + y2 = interval[0] + bbox = [x1, y1, x2, y2] + ref_interval = interval[0] - current_interval[1] + ref_bboxes = [] + for block in blocks: + if get_nearest_edge_distance(bbox, block.bbox) < ref_interval * 2: + ref_bboxes.append(block.bbox) + discontinuous = calculate_discontinuous_projection( + ref_bboxes, orientation="horizontal" + ) + if len(discontinuous) != 2: + cut_coordinates.append(current_interval[1]) current_interval = interval - overall_region_box = config.get("all_layout_region_box") cut_list = get_cut_blocks( - blocks, cut_direction, cut_coordinates, overall_region_box, mask_labels + blocks, cut_orientation, cut_coordinates, region_bbox, mask_labels ) pre_cut_list.extend(cut_list) @@ -181,39 +210,40 @@ def update_region_label( block = blocks[block_idx] if block.label in doc_title_labels: - block.region_label = "doc_title" - # Force the direction of vision type to be horizontal + block.order_label = "doc_title" + # Force the orientation of vision type to be horizontal if block.label in vision_labels: - block.region_label = "vision" - block.update_direction_info() - # some paragraph title block may be labeled as sub_title, so we need to check if block.region_label is "other"(default). - if block.label in paragraph_title_labels and block.region_label == "other": - block.region_label = "paragraph_title" + block.order_label = "vision" + block.num_of_lines = 1 + block.update_orientation_info() + # some paragraph title block may be labeled as sub_title, so we need to check if block.order_label is "other"(default). + if block.label in paragraph_title_labels and block.order_label == "other": + block.order_label = "paragraph_title" # only vision and doc title block can have child block - if block.region_label not in ["vision", "doc_title", "paragraph_title"]: + if block.order_label not in ["vision", "doc_title", "paragraph_title"]: return iou_threshold = config.get("child_block_match_iou_threshold", 0.1) # match doc title text block - if block.region_label == "doc_title": + if block.order_label == "doc_title": text_block_idxes = config.get("text_block_idxes", []) - prev_idx, post_idx = get_adjacent_blocks_by_direction( + prev_idx, post_idx = get_adjacent_blocks_by_orientation( blocks, block_idx, text_block_idxes, iou_threshold ) update_doc_title_child_blocks(blocks, block, prev_idx, post_idx, config) # match sub title block - elif block.region_label == "paragraph_title": + elif block.order_label == "paragraph_title": iou_threshold = config.get("sub_title_match_iou_threshold", 0.1) paragraph_title_block_idxes = config.get("paragraph_title_block_idxes", []) text_block_idxes = config.get("text_block_idxes", []) megred_block_idxes = text_block_idxes + paragraph_title_block_idxes - prev_idx, post_idx = get_adjacent_blocks_by_direction( + prev_idx, post_idx = get_adjacent_blocks_by_orientation( blocks, block_idx, megred_block_idxes, iou_threshold ) update_paragraph_title_child_blocks(blocks, block, prev_idx, post_idx, config) # match vision title block - elif block.region_label == "vision": + elif block.order_label == "vision": # for matching vision title block vision_title_block_idxes = config.get("vision_title_block_idxes", []) # for matching vision footnote block @@ -221,7 +251,7 @@ def update_region_label( megred_block_idxes = text_block_idxes + vision_title_block_idxes # Some vision title block may be matched with multiple vision title block, so we need to try multiple times for i in range(3): - prev_idx, post_idx = get_adjacent_blocks_by_direction( + prev_idx, post_idx = get_adjacent_blocks_by_orientation( blocks, block_idx, megred_block_idxes, iou_threshold ) update_vision_child_blocks( @@ -231,18 +261,12 @@ def update_region_label( def get_layout_structure( blocks: List[LayoutParsingBlock], - median_width: float, - config: dict, - threshold: float = 0.8, ) -> Tuple[List[Dict[str, any]], bool]: """ Determine the layout cross column of blocks. Args: blocks (List[Dict[str, any]]): List of block dictionaries containing 'label' and 'block_bbox'. - median_width (float): Median width of text blocks. - no_mask_labels (List[str]): Labels of blocks to be considered for layout analysis. - threshold (float): Threshold for determining layout overlap. Returns: Tuple[List[Dict[str, any]], bool]: Updated list of blocks with layout information and a boolean @@ -251,79 +275,94 @@ def get_layout_structure( blocks.sort( key=lambda x: (x.bbox[0], x.width), ) - check_single_layout = {} - doc_title_labels = config.get("doc_title_labels", []) - region_box = config.get("all_layout_region_box", [0, 0, 0, 0]) + mask_labels = ["doc_title", "cross_text", "cross_reference"] for block_idx, block in enumerate(blocks): - cover_count = 0 - match_block_with_threshold_indexes = [] + if block.order_label in mask_labels: + continue for ref_idx, ref_block in enumerate(blocks): - if block_idx == ref_idx: + if block_idx == ref_idx or ref_block.order_label in mask_labels: continue - bbox_iou = calculate_iou(block.bbox, ref_block.bbox) + bbox_iou = calculate_overlap_ratio(block.bbox, ref_block.bbox) if bbox_iou > 0: - if block.region_label == "vision" or block.area < ref_block.area: - block.region_label = "cross_text" + if ref_block.order_label == "vision": + ref_block.order_label = "cross_text" + break + if block.order_label == "vision" or block.area < ref_block.area: + block.order_label = "cross_text" break - match_projection_iou = calculate_projection_iou( + match_projection_iou = calculate_projection_overlap_ratio( block.bbox, ref_block.bbox, "horizontal", ) - if match_projection_iou > 0: - cover_count += 1 - if match_projection_iou > threshold: - match_block_with_threshold_indexes.append( - (ref_idx, match_projection_iou), + for second_ref_idx, second_ref_block in enumerate(blocks): + if ( + second_ref_idx in [block_idx, ref_idx] + or second_ref_block.order_label in mask_labels + ): + continue + + bbox_iou = calculate_overlap_ratio( + block.bbox, second_ref_block.bbox ) - if ref_block.bbox[2] >= block.bbox[2]: - break - - block_center = (block.bbox[0] + block.bbox[2]) / 2 - region_bbox_center = (region_box[0] + region_box[2]) / 2 - center_offset = abs(block_center - region_bbox_center) - is_centered = center_offset <= median_width * 0.05 - width_gather_than_median = block.width > median_width * 1.3 - - if ( - cover_count >= 2 - and block.label not in doc_title_labels - and (width_gather_than_median != is_centered) - ): - block.region_label = ( - "cross_reference" if block.label == "reference" else "cross_text" - ) - else: - check_single_layout[block_idx] = match_block_with_threshold_indexes - - # Check single-layout block - for idx, single_layout in check_single_layout.items(): - if single_layout: - index, match_iou = single_layout[-1] - if match_iou > 0.9 and blocks[index].region_label == "cross_text": - blocks[idx].region_label = ( - "cross_reference" if block.label == "reference" else "cross_text" - ) + if bbox_iou > 0.1: + if second_ref_block.order_label == "vision": + second_ref_block.order_label = "cross_text" + break + if ( + block.order_label == "vision" + or block.area < second_ref_block.area + ): + block.order_label = "cross_text" + break + + second_match_projection_iou = calculate_projection_overlap_ratio( + block.bbox, + second_ref_block.bbox, + "horizontal", + ) + ref_match_projection_iou = calculate_projection_overlap_ratio( + ref_block.bbox, + second_ref_block.bbox, + "horizontal", + ) + ref_match_projection_iou_ = calculate_projection_overlap_ratio( + ref_block.bbox, + second_ref_block.bbox, + "vertical", + ) + if ( + second_match_projection_iou > 0 + and ref_match_projection_iou == 0 + and ref_match_projection_iou_ > 0 + and "vision" + not in [ref_block.order_label, second_ref_block.order_label] + ): + block.order_label = ( + "cross_reference" + if block.label == "reference" + else "cross_text" + ) def sort_by_xycut( block_bboxes: List, - direction: int = 0, + orientation: int = 0, min_gap: int = 1, ) -> List[int]: """ - Sort bounding boxes using recursive XY cut method based on the specified direction. + Sort bounding boxes using recursive XY cut method based on the specified orientation. Args: block_bboxes (Union[np.ndarray, List[List[int]]]): An array or list of bounding boxes, where each box is represented as [x_min, y_min, x_max, y_max]. - direction (int): Direction for the initial cut. Use 1 for Y-axis first and 0 for X-axis first. + orientation (int): orientation for the initial cut. Use 1 for Y-axis first and 0 for X-axis first. Defaults to 0. min_gap (int): Minimum gap width to consider a separation between segments. Defaults to 1. @@ -332,7 +371,7 @@ def sort_by_xycut( """ block_bboxes = np.asarray(block_bboxes).astype(int) res = [] - if direction == 1: + if orientation == 1: recursive_yx_cut( block_bboxes, np.arange(len(block_bboxes)).tolist(), @@ -379,11 +418,11 @@ def match_unsorted_blocks( unsorted_blocks = sort_blocks(unsorted_blocks, median_width, reverse=False) for idx, block in enumerate(unsorted_blocks): - region_label = block.region_label - if idx == 0 and region_label == "doc_title": + order_label = block.order_label + if idx == 0 and order_label == "doc_title": sorted_blocks.insert(0, block) continue - sorted_blocks = distance_type_map[region_label]( + sorted_blocks = distance_type_map[order_label]( block, sorted_blocks, config, median_width ) return sorted_blocks @@ -397,7 +436,7 @@ def xycut_enhanced( 1. Preprocess the input blocks by extracting headers, footers, and pre-cut blocks. 2. Mask blocks that are crossing different blocks. 3. Perform xycut_enhanced algorithm on the remaining blocks. - 4. Match special blocks with the sorted blocks based on their region labels. + 4. Match unsorted blocks with the sorted blocks based on their order labels. 5. Update child blocks of the sorted blocks based on their parent blocks. 6. Return the ordered result list. @@ -419,7 +458,6 @@ def xycut_enhanced( header_blocks = sort_blocks(header_blocks) footer_blocks = sort_blocks(footer_blocks) unordered_blocks = sort_blocks(unordered_blocks) - final_order_res_list.extend(header_blocks) unsorted_blocks: List[LayoutParsingBlock] = [] @@ -438,13 +476,11 @@ def xycut_enhanced( get_layout_structure( pre_cut_blocks, - median_width, - config, ) # Get xy cut blocks and add other blocks in special_block_map for block in pre_cut_blocks: - if block.region_label not in [ + if block.order_label not in [ "cross_text", "cross_reference", "doc_title", @@ -460,18 +496,20 @@ def xycut_enhanced( block_bboxes = np.array([block.bbox for block in xy_cut_blocks]) block_text_lines = [block.num_of_lines for block in xy_cut_blocks] discontinuous = calculate_discontinuous_projection( - block_bboxes, direction="horizontal" + block_bboxes, orientation="horizontal" ) + if len(discontinuous) > 1: + xy_cut_blocks = [block for block in xy_cut_blocks] if len(discontinuous) == 1 or max(block_text_lines) == 1: xy_cut_blocks.sort(key=lambda x: (x.bbox[1] // 5, x.bbox[0])) xy_cut_blocks = shrink_overlapping_boxes(xy_cut_blocks, "vertical") block_bboxes = np.array([block.bbox for block in xy_cut_blocks]) - sorted_indexes = sort_by_xycut(block_bboxes, direction=1, min_gap=1) + sorted_indexes = sort_by_xycut(block_bboxes, orientation=1, min_gap=1) else: xy_cut_blocks.sort(key=lambda x: (x.bbox[0] // 20, x.bbox[1])) xy_cut_blocks = shrink_overlapping_boxes(xy_cut_blocks, "horizontal") block_bboxes = np.array([block.bbox for block in xy_cut_blocks]) - sorted_indexes = sort_by_xycut(block_bboxes, direction=0, min_gap=20) + sorted_indexes = sort_by_xycut(block_bboxes, orientation=0, min_gap=20) sorted_blocks = [xy_cut_blocks[i] for i in sorted_indexes] @@ -498,15 +536,9 @@ def xycut_enhanced( final_order_res_list.extend(footer_blocks) final_order_res_list.extend(unordered_blocks) - index = 0 - visualize_index_labels = config.get("visualize_index_labels", []) for block_idx, block in enumerate(final_order_res_list): - if block.label not in visualize_index_labels: - continue final_order_res_list = insert_child_blocks( block, block_idx, final_order_res_list ) block = final_order_res_list[block_idx] - index += 1 - block.index = index return final_order_res_list From a095e76fd883fef5449c2945432f5d28d12bd18c Mon Sep 17 00:00:00 2001 From: zhouchangda Date: Sun, 11 May 2025 10:15:59 +0000 Subject: [PATCH 2/2] support sort by different text line --- .../pipelines/layout_parsing/pipeline_v2.py | 155 +++--- .../pipelines/layout_parsing/result_v2.py | 224 +++++---- .../pipelines/layout_parsing/setting.py | 47 +- .../pipelines/layout_parsing/utils.py | 161 +++--- .../layout_parsing/xycut_enhanced/utils.py | 466 ++++++++++++------ .../layout_parsing/xycut_enhanced/xycuts.py | 398 ++++++++------- 6 files changed, 827 insertions(+), 624 deletions(-) diff --git a/paddlex/inference/pipelines/layout_parsing/pipeline_v2.py b/paddlex/inference/pipelines/layout_parsing/pipeline_v2.py index 7035dcf509..a7afd3b54a 100644 --- a/paddlex/inference/pipelines/layout_parsing/pipeline_v2.py +++ b/paddlex/inference/pipelines/layout_parsing/pipeline_v2.py @@ -30,6 +30,7 @@ from ..base import BasePipeline from ..ocr.result import OCRResult from .result_v2 import LayoutParsingBlock, LayoutParsingRegion, LayoutParsingResultV2 +from .setting import BLOCK_LABEL_MAP, BLOCK_SETTINGS, LINE_SETTINGS, REGION_SETTINGS from .utils import ( caculate_bbox_area, calculate_minimum_enclosing_bbox, @@ -260,8 +261,6 @@ def check_model_settings_valid(self, input_params: dict) -> bool: def standardized_data( self, image: list, - parameters_config: dict, - block_label_mapping: dict, region_det_res: DetResult, layout_det_res: DetResult, overall_ocr_res: OCRResult, @@ -360,7 +359,7 @@ def standardized_data( paragraph_title_block_area = caculate_bbox_area( layout_det_res["boxes"][paragraph_title_list[0]]["coordinate"] ) - title_area_max_block_threshold = parameters_config["block"].get( + title_area_max_block_threshold = BLOCK_SETTINGS.get( "title_conversion_area_ratio_threshold", 0.3 ) if ( @@ -441,7 +440,7 @@ def standardized_data( break if not has_text and layout_det_res["boxes"][layout_box_idx][ "label" - ] not in block_label_mapping.get("vision_labels", []): + ] not in BLOCK_LABEL_MAP.get("vision_labels", []): crop_box = layout_det_res["boxes"][layout_box_idx]["coordinate"] x1, y1, x2, y2 = [int(i) for i in crop_box] crop_img = np.array(image)[y1:y2, x1:x2] @@ -506,7 +505,7 @@ def standardized_data( overlap_ratio = calculate_overlap_ratio( region_bbox, block_bboxes[block_idx], mode="small" ) - if overlap_ratio > parameters_config["region"].get( + if overlap_ratio > REGION_SETTINGS.get( "match_block_overlap_ratio_threshold", 0.8 ): region_to_block_map[region_idx].append(block_idx) @@ -540,7 +539,6 @@ def standardized_data( image.shape[0], block_idxes_set, block_bboxes, - parameters_config, ) ) if len(matched_idxes) == 0: @@ -570,7 +568,7 @@ def sort_line_by_projection( input_img: np.ndarray, text_rec_model: Any, text_rec_score_thresh: Union[float, None] = None, - orientation: str = "vertical", + direction: str = "vertical", ) -> None: """ Sort a line of text spans based on their vertical position within the layout bounding box. @@ -583,8 +581,8 @@ def sort_line_by_projection( Returns: list: The sorted line of text spans. """ - sort_index = 0 if orientation == "horizontal" else 1 - splited_boxes = split_boxes_by_projection(line, orientation) + sort_index = 0 if direction == "horizontal" else 1 + splited_boxes = split_boxes_by_projection(line, direction) splited_lines = [] if len(line) != len(splited_boxes): splited_boxes.sort(key=lambda span: span[0][sort_index]) @@ -614,7 +612,6 @@ def sort_line_by_projection( def get_block_rec_content( self, image: list, - line_parameters_config: dict, ocr_rec_res: dict, block: LayoutParsingBlock, text_rec_model: Any, @@ -625,37 +622,49 @@ def get_block_rec_content( block.content = "" return block - lines, text_orientation = group_boxes_into_lines( + lines, text_direction = group_boxes_into_lines( ocr_rec_res, - line_parameters_config.get("line_height_iou_threshold", 0.8), + LINE_SETTINGS.get("line_height_iou_threshold", 0.8), ) if block.label == "reference": rec_boxes = ocr_rec_res["boxes"] block_right_coordinate = max([box[2] for box in rec_boxes]) - last_line_span_limit = 20 else: block_right_coordinate = block.bbox[2] - last_line_span_limit = 10 # format line text_lines = [] need_new_line_num = 0 - sort_index = 0 if text_orientation == "horizontal" else 1 + start_index = 0 if text_direction == "horizontal" else 1 + secondary_direction_start_index = 1 if text_direction == "horizontal" else 0 + line_height_list, line_width_list = [], [] for idx, line in enumerate(lines): - line.sort(key=lambda span: span[0][sort_index]) + line.sort(key=lambda span: span[0][start_index]) + text_bboxes_height = [ + span[0][secondary_direction_start_index + 2] + - span[0][secondary_direction_start_index] + for span in line + ] + text_bboxes_width = [ + span[0][start_index + 2] - span[0][start_index] for span in line + ] + + line_height = np.mean(text_bboxes_height) + line_height_list.append(line_height) + line_width_list.append(np.mean(text_bboxes_width)) # merge formula and text ocr_labels = [span[2] for span in line] if "formula" in ocr_labels: line = self.sort_line_by_projection( - line, image, text_rec_model, text_rec_score_thresh, text_orientation + line, image, text_rec_model, text_rec_score_thresh, text_direction ) line_text, need_new_line = format_line( line, block_right_coordinate, - last_line_span_limit=last_line_span_limit, + last_line_span_limit=line_height * 1.5, block_label=block.label, ) if need_new_line: @@ -668,21 +677,21 @@ def get_block_rec_content( block.seg_end_coordinate = line_end_coordinate text_lines.append(line_text) - delim = line_parameters_config["delimiter_map"].get(block.label, "") + delim = LINE_SETTINGS["delimiter_map"].get(block.label, "") if need_new_line_num > len(text_lines) * 0.5 and delim == "": delim = "\n" content = delim.join(text_lines) block.content = content block.num_of_lines = len(text_lines) - block.orientation = text_orientation + block.direction = text_direction + block.text_line_height = np.mean(line_height_list) + block.text_line_width = np.mean(line_width_list) return block def get_layout_parsing_blocks( self, image: list, - parameters_config: dict, - block_label_mapping: dict, region_block_ocr_idx_map: dict, region_det_res: DetResult, overall_ocr_res: OCRResult, @@ -759,7 +768,6 @@ def get_layout_parsing_blocks( block = self.get_block_rec_content( image=image, block=block, - line_parameters_config=parameters_config["line"], ocr_rec_res=rec_res, text_rec_model=text_rec_model, text_rec_score_thresh=text_rec_score_thresh, @@ -781,9 +789,8 @@ def get_layout_parsing_blocks( for idx in region_block_ocr_idx_map["region_to_block_map"][region_idx] ] region = LayoutParsingRegion( - region_bbox=region_bbox, + bbox=region_bbox, blocks=region_blocks, - block_label_mapping=block_label_mapping, ) region_list.append(region) @@ -818,14 +825,11 @@ def get_layout_parsing_res( Returns: list: A list of dictionaries representing the layout parsing result. """ - from .setting import block_label_mapping, parameters_config # Standardize data region_block_ocr_idx_map, region_det_res, layout_det_res = ( self.standardized_data( image=image, - parameters_config=parameters_config, - block_label_mapping=block_label_mapping, region_det_res=region_det_res, layout_det_res=layout_det_res, overall_ocr_res=overall_ocr_res, @@ -838,8 +842,6 @@ def get_layout_parsing_res( # Format layout parsing block region_list = self.get_layout_parsing_blocks( image=image, - parameters_config=parameters_config, - block_label_mapping=block_label_mapping, region_block_ocr_idx_map=region_block_ocr_idx_map, region_det_res=region_det_res, overall_ocr_res=overall_ocr_res, @@ -854,11 +856,10 @@ def get_layout_parsing_res( for region in region_list: parsing_res_list.extend(region.sort()) - visualize_index_labels = block_label_mapping["visualize_index_labels"] index = 1 for block in parsing_res_list: - if block.label in visualize_index_labels: - block.index = index + if block.label in BLOCK_LABEL_MAP["visualize_index_labels"]: + block.order_index = index index += 1 return parsing_res_list @@ -956,8 +957,6 @@ def predict( use_e2e_wired_table_rec_model: bool = False, use_e2e_wireless_table_rec_model: bool = True, is_pretty_markdown: Union[bool, None] = None, - use_layout_gt: bool = False, - layout_gt_dir: Union[str, None] = None, **kwargs, ) -> LayoutParsingResultV2: """ @@ -1032,65 +1031,16 @@ def predict( doc_preprocessor_image = doc_preprocessor_res["output_img"] - use_layout_gt = use_layout_gt - if not use_layout_gt: - layout_det_res = next( - self.layout_det_model( - doc_preprocessor_image, - threshold=layout_threshold, - layout_nms=layout_nms, - layout_unclip_ratio=layout_unclip_ratio, - layout_merge_bboxes_mode=layout_merge_bboxes_mode, - ) + layout_det_res = next( + self.layout_det_model( + doc_preprocessor_image, + threshold=layout_threshold, + layout_nms=layout_nms, + layout_unclip_ratio=layout_unclip_ratio, + layout_merge_bboxes_mode=layout_merge_bboxes_mode, ) - else: - import json - import os - - from ...models.object_detection.result import DetResult - - label_dir = layout_gt_dir - notes_path = f"{label_dir}/notes.json" - labels = f"{label_dir}/labels" - gt_file = os.path.basename(input)[:-4] + ".txt" - gt_path = f"{labels}/{gt_file}" - with open(notes_path, "r") as f: - notes = json.load(f) - categories_map = {} - for categories in notes["categories"]: - id = int(categories["id"]) - name = categories["name"] - categories_map[id] = name - with open(gt_path, "r") as f: - lines = f.readlines() - layout_det_res_dic = { - "input_img": doc_preprocessor_image, - "page_index": None, - "boxes": [], - } - for line in lines: - line = line.strip().split(" ") - category_id = int(line[0]) - label = categories_map[category_id] - img_h, img_w = doc_preprocessor_image.shape[:2] - center_x = float(line[1]) * img_w - center_y = float(line[2]) * img_h - w = float(line[3]) * img_w - h = float(line[4]) * img_h - x0 = center_x - w / 2 - y0 = center_y - h / 2 - x1 = center_x + w / 2 - y1 = center_y + h / 2 - box = [x0, y0, x1, y1] - layout_det_res_dic["boxes"].append( - { - "cls_id": category_id, - "label": label, - "coordinate": box, - "score": 1.0, - } - ) - layout_det_res = DetResult(layout_det_res_dic) + ) + imgs_in_doc = gather_imgs(doc_preprocessor_image, layout_det_res["boxes"]) if model_settings["use_region_detection"]: @@ -1139,7 +1089,13 @@ def predict( ), ) else: - overall_ocr_res = {} + overall_ocr_res = { + "dt_polys": [], + "rec_texts": [], + "rec_scores": [], + "rec_polys": [], + "rec_boxes": np.array([]), + } overall_ocr_res["rec_labels"] = ["text"] * len(overall_ocr_res["rec_texts"]) @@ -1157,9 +1113,14 @@ def predict( table_contents["rec_texts"].append( f"${formula_res['rec_formula']}$" ) - table_contents["rec_boxes"] = np.vstack( - (table_contents["rec_boxes"], [formula_res["dt_polys"]]) - ) + if table_contents["rec_boxes"].size == 0: + table_contents["rec_boxes"] = np.array( + [formula_res["dt_polys"]] + ) + else: + table_contents["rec_boxes"] = np.vstack( + (table_contents["rec_boxes"], [formula_res["dt_polys"]]) + ) table_contents["rec_polys"].append(poly_points) table_contents["rec_scores"].append(1) diff --git a/paddlex/inference/pipelines/layout_parsing/result_v2.py b/paddlex/inference/pipelines/layout_parsing/result_v2.py index c27c47bc87..6c41434e9b 100644 --- a/paddlex/inference/pipelines/layout_parsing/result_v2.py +++ b/paddlex/inference/pipelines/layout_parsing/result_v2.py @@ -20,8 +20,9 @@ from typing import List import numpy as np -from PIL import Image, ImageDraw +from PIL import Image, ImageDraw, ImageFont +from ....utils.fonts import PINGFANG_FONT_FILE_PATH from ...common.result import ( BaseCVResult, HtmlMixin, @@ -29,6 +30,7 @@ MarkdownMixin, XlsxMixin, ) +from .setting import BLOCK_LABEL_MAP class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin): @@ -107,16 +109,23 @@ def _to_img(self) -> dict[str, np.ndarray]: # for layout ordering image image = Image.fromarray(self["doc_preprocessor_res"]["output_img"][:, :, ::-1]) draw = ImageDraw.Draw(image, "RGBA") + font_size = int(0.018 * int(image.width)) + 2 + font = ImageFont.truetype(PINGFANG_FONT_FILE_PATH, font_size, encoding="utf-8") parsing_result: List[LayoutParsingBlock] = self["parsing_res_list"] for block in parsing_result: bbox = block.bbox - index = block.index - label = block.label - fill_color = get_show_color(label) + index = block.order_index + label = block.order_label + fill_color = get_show_color(label, True) draw.rectangle(bbox, fill=fill_color) if index is not None: - text_position = (bbox[2] + 2, bbox[1] - 10) - draw.text(text_position, str(index), fill="red") + text_position = (bbox[2] + 2, bbox[1] - font_size // 2) + if int(image.width) - bbox[2] < font_size: + text_position = ( + int(bbox[2] - font_size * 1.1), + bbox[1] - font_size // 2, + ) + draw.text(text_position, str(index), font=font, fill="red") res_img_dict["layout_order_res"] = image @@ -475,8 +484,8 @@ class LayoutParsingBlock: def __init__(self, label, bbox, content="") -> None: self.label = label - self.order_label = "other" - self.bbox = [int(item) for item in bbox] + self.order_label = None + self.bbox = list(map(int, bbox)) self.content = content self.seg_start_coordinate = float("inf") self.seg_end_coordinate = float("-inf") @@ -486,40 +495,42 @@ def __init__(self, label, bbox, content="") -> None: self.num_of_lines = 1 self.image = None self.index = None - self.visual_index = None - self.orientation = self.get_bbox_orientation() + self.order_index = None + self.text_line_width = 1 + self.text_line_height = 1 + self.direction = self.get_bbox_direction() self.child_blocks = [] - self.update_orientation_info() + self.update_direction_info() def __str__(self) -> str: return f"{self.__dict__}" def __repr__(self) -> str: - _str = f"\n\n#################\nlabel:\t{self.label}\nregion_label:\t{self.order_label}\nbbox:\t{self.bbox}\ncontent:\t{self.content}\n#################" + _str = f"\n\n#################\nindex:\t{self.index}\nlabel:\t{self.label}\nregion_label:\t{self.order_label}\nbbox:\t{self.bbox}\ncontent:\t{self.content}\n#################" return _str def to_dict(self) -> dict: return self.__dict__ - def update_orientation_info(self) -> None: + def update_direction_info(self) -> None: if self.order_label == "vision": - self.orientation = "horizontal" - if self.orientation == "horizontal": - self.secondary_orientation = "vertical" + self.direction = "horizontal" + if self.direction == "horizontal": + self.secondary_direction = "vertical" self.short_side_length = self.height self.long_side_length = self.width self.start_coordinate = self.bbox[0] self.end_coordinate = self.bbox[2] - self.secondary_orientation_start_coordinate = self.bbox[1] - self.secondary_orientation_end_coordinate = self.bbox[3] + self.secondary_direction_start_coordinate = self.bbox[1] + self.secondary_direction_end_coordinate = self.bbox[3] else: - self.secondary_orientation = "horizontal" + self.secondary_direction = "horizontal" self.short_side_length = self.width self.long_side_length = self.height self.start_coordinate = self.bbox[1] self.end_coordinate = self.bbox[3] - self.secondary_orientation_start_coordinate = self.bbox[0] - self.secondary_orientation_end_coordinate = self.bbox[2] + self.secondary_direction_start_coordinate = self.bbox[0] + self.secondary_direction_end_coordinate = self.bbox[2] def append_child_block(self, child_block: LayoutParsingBlock) -> None: if not self.child_blocks: @@ -533,7 +544,7 @@ def append_child_block(self, child_block: LayoutParsingBlock) -> None: max(y2, y2_child), ) self.bbox = union_bbox - self.update_orientation_info() + self.update_direction_info() child_blocks = [child_block] if child_block.child_blocks: child_blocks.extend(child_block.get_child_blocks()) @@ -550,107 +561,130 @@ def get_centroid(self) -> tuple: centroid = ((x1 + x2) / 2, (y1 + y2) / 2) return centroid - def get_bbox_orientation(self, orientation_ratio: float = 1.0) -> bool: + def get_bbox_direction(self, direction_ratio: float = 1.0) -> bool: """ Determine if a bounding box is horizontal or vertical. Args: bbox (List[float]): Bounding box [x_min, y_min, x_max, y_max]. - orientation_ratio (float): Ratio for determining orientation. Default is 1.0. + direction_ratio (float): Ratio for determining direction. Default is 1.0. Returns: str: "horizontal" or "vertical". """ return ( - "horizontal" - if self.width * orientation_ratio >= self.height - else "vertical" + "horizontal" if self.width * direction_ratio >= self.height else "vertical" ) class LayoutParsingRegion: - def __init__( - self, region_bbox, blocks: List[LayoutParsingBlock] = [], block_label_mapping={} - ) -> None: - self.region_bbox = region_bbox - self.blocks = blocks + def __init__(self, bbox, blocks: List[LayoutParsingBlock] = []) -> None: + self.bbox = bbox self.block_map = {} - self.update_config(block_label_mapping) - self.orientation = None + self.direction = "horizontal" self.calculate_bbox_metrics() - - def update_config(self, block_label_mapping): - self.block_map = {} - self.config = copy.deepcopy(block_label_mapping) - self.config["region_bbox"] = self.region_bbox - horizontal_text_block_num = 0 - for idx, block in enumerate(self.blocks): - label = block.label - if ( - block.order_label not in ["vision", "vision_title"] - and block.orientation == "horizontal" - ): - horizontal_text_block_num += 1 + self.doc_title_block_idxes = [] + self.paragraph_title_block_idxes = [] + self.vision_block_idxes = [] + self.unordered_block_idxes = [] + self.vision_title_block_idxes = [] + self.normal_text_block_idxes = [] + self.header_block_idxes = [] + self.footer_block_idxes = [] + self.text_line_width = 20 + self.text_line_height = 10 + self.init_region_info_from_layout(blocks) + self.init_direction_info() + + def init_region_info_from_layout(self, blocks: List[LayoutParsingBlock]): + horizontal_normal_text_block_num = 0 + text_line_height_list = [] + text_line_width_list = [] + for idx, block in enumerate(blocks): self.block_map[idx] = block - self.update_layout_order_config_block_index(label, idx) - text_block_num = ( - len(self.blocks) - - len(self.config.get("vision_block_idxes", [])) - - len(self.config.get("vision_title_block_idxes", [])) - ) - self.orientation = ( + block.index = idx + if block.label in BLOCK_LABEL_MAP["header_labels"]: + self.header_block_idxes.append(idx) + elif block.label in BLOCK_LABEL_MAP["doc_title_labels"]: + self.doc_title_block_idxes.append(idx) + elif block.label in BLOCK_LABEL_MAP["paragraph_title_labels"]: + self.paragraph_title_block_idxes.append(idx) + elif block.label in BLOCK_LABEL_MAP["vision_labels"]: + self.vision_block_idxes.append(idx) + elif block.label in BLOCK_LABEL_MAP["vision_title_labels"]: + self.vision_title_block_idxes.append(idx) + elif block.label in BLOCK_LABEL_MAP["footer_labels"]: + self.footer_block_idxes.append(idx) + elif block.label in BLOCK_LABEL_MAP["unordered_labels"]: + self.unordered_block_idxes.append(idx) + else: + self.normal_text_block_idxes.append(idx) + text_line_height_list.append(block.text_line_height) + text_line_width_list.append(block.text_line_width) + if block.direction == "horizontal": + horizontal_normal_text_block_num += 1 + self.direction = ( "horizontal" - if horizontal_text_block_num >= text_block_num * 0.5 + if horizontal_normal_text_block_num + >= len(self.normal_text_block_idxes) * 0.5 else "vertical" ) - self.config["region_orientation"] = self.orientation + self.text_line_width = ( + np.mean(text_line_width_list) if text_line_width_list else 20 + ) + self.text_line_height = ( + np.mean(text_line_height_list) if text_line_height_list else 10 + ) + + def init_direction_info(self): + if self.direction == "horizontal": + self.direction_start_index = 0 + self.direction_end_index = 2 + self.secondary_direction_start_index = 1 + self.secondary_direction_end_index = 3 + self.secondary_direction = "vertical" + else: + self.direction_start_index = 1 + self.direction_end_index = 3 + self.secondary_direction_start_index = 0 + self.secondary_direction_end_index = 2 + self.secondary_direction = "horizontal" + + self.direction_center_coordinate = ( + self.bbox[self.direction_start_index] + self.bbox[self.direction_end_index] + ) / 2 + self.secondary_direction_center_coordinate = ( + self.bbox[self.secondary_direction_start_index] + + self.bbox[self.secondary_direction_end_index] + ) / 2 def calculate_bbox_metrics(self): - x1, y1, x2, y2 = self.region_bbox + x1, y1, x2, y2 = self.bbox x_center, y_center = (x1 + x2) / 2, (y1 + y2) / 2 self.euclidean_distance = math.sqrt(((x1) ** 2 + (y1) ** 2)) self.center_euclidean_distance = math.sqrt(((x_center) ** 2 + (y_center) ** 2)) self.angle_rad = math.atan2(y_center, x_center) + def sort_normal_blocks(self, blocks): + if self.direction == "horizontal": + blocks.sort( + key=lambda x: ( + x.bbox[1] // self.text_line_height, + x.bbox[0] // self.text_line_width, + x.bbox[1] ** 2 + x.bbox[0] ** 2, + ), + ) + else: + blocks.sort( + key=lambda x: ( + -x.bbox[0] // self.text_line_width, + x.bbox[1] // self.text_line_height, + -(x.bbox[2] ** 2 + x.bbox[1] ** 2), + ), + ) + def sort(self): from .xycut_enhanced import xycut_enhanced - return xycut_enhanced(self.blocks, self.config) - - def update_layout_order_config_block_index( - self, block_label: str, block_idx: int - ) -> None: - doc_title_labels = self.config["doc_title_labels"] - paragraph_title_labels = self.config["paragraph_title_labels"] - vision_labels = self.config["vision_labels"] - vision_title_labels = self.config["vision_title_labels"] - header_labels = self.config["header_labels"] - unordered_labels = self.config["unordered_labels"] - footer_labels = self.config["footer_labels"] - text_labels = self.config["text_labels"] - self.config.setdefault("doc_title_block_idxes", []) - self.config.setdefault("paragraph_title_block_idxes", []) - self.config.setdefault("vision_block_idxes", []) - self.config.setdefault("vision_title_block_idxes", []) - self.config.setdefault("unordered_block_idxes", []) - self.config.setdefault("text_block_idxes", []) - self.config.setdefault("header_block_idxes", []) - self.config.setdefault("footer_block_idxes", []) - - if block_label in doc_title_labels: - self.config["doc_title_block_idxes"].append(block_idx) - if block_label in paragraph_title_labels: - self.config["paragraph_title_block_idxes"].append(block_idx) - if block_label in vision_labels: - self.config["vision_block_idxes"].append(block_idx) - if block_label in vision_title_labels: - self.config["vision_title_block_idxes"].append(block_idx) - if block_label in unordered_labels: - self.config["unordered_block_idxes"].append(block_idx) - if block_label in text_labels: - self.config["text_block_idxes"].append(block_idx) - if block_label in header_labels: - self.config["header_block_idxes"].append(block_idx) - if block_label in footer_labels: - self.config["footer_block_idxes"].append(block_idx) + return xycut_enhanced(self) diff --git a/paddlex/inference/pipelines/layout_parsing/setting.py b/paddlex/inference/pipelines/layout_parsing/setting.py index 7affbf3096..82162a95fa 100644 --- a/paddlex/inference/pipelines/layout_parsing/setting.py +++ b/paddlex/inference/pipelines/layout_parsing/setting.py @@ -12,32 +12,35 @@ # See the License for the specific language governing permissions and # limitations under the License. -parameters_config = { - "page": {}, - "region": { - "match_block_overlap_ratio_threshold": 0.8, - "split_block_overlap_ratio_threshold": 0.4, - }, - "block": { - "title_conversion_area_ratio_threshold": 0.3, # update paragraph_title -> doc_title - }, - "line": { - "line_height_iou_threshold": 0.6, # For line segmentation of OCR results - "delimiter_map": { - "doc_title": " ", - "content": "\n", - }, - }, - "word": { - "delimiter": " ", + +XYCUT_SETTINGS = { + "child_block_overlap_ratio_threshold": 0.1, + "edge_distance_compare_tolerance_len": 2, + "distance_weight_map": { + "edge_weight": 10**4, + "up_edge_weight": 1, + "down_edge_weight": 0.0001, }, - "order": { - "block_label_match_iou_threshold": 0.1, - "block_title_match_iou_threshold": 0.1, +} + +REGION_SETTINGS = { + "match_block_overlap_ratio_threshold": 0.6, + "split_block_overlap_ratio_threshold": 0.4, +} + +BLOCK_SETTINGS = { + "title_conversion_area_ratio_threshold": 0.3, # update paragraph_title -> doc_title +} + +LINE_SETTINGS = { + "line_height_iou_threshold": 0.6, # For line segmentation of OCR results + "delimiter_map": { + "doc_title": " ", + "content": "\n", }, } -block_label_mapping = { +BLOCK_LABEL_MAP = { "doc_title_labels": ["doc_title"], # 文档标题 "paragraph_title_labels": [ "paragraph_title", diff --git a/paddlex/inference/pipelines/layout_parsing/utils.py b/paddlex/inference/pipelines/layout_parsing/utils.py index 961c2402bf..904156b932 100644 --- a/paddlex/inference/pipelines/layout_parsing/utils.py +++ b/paddlex/inference/pipelines/layout_parsing/utils.py @@ -27,6 +27,7 @@ from ..components import convert_points_to_boxes from ..ocr.result import OCRResult +from .setting import REGION_SETTINGS def get_overlap_boxes_idx(src_boxes: np.ndarray, ref_boxes: np.ndarray) -> List: @@ -173,7 +174,7 @@ def sorted_layout_boxes(res, w): def calculate_projection_overlap_ratio( bbox1: List[float], bbox2: List[float], - orientation: str = "horizontal", + direction: str = "horizontal", mode="union", ) -> float: """ @@ -182,13 +183,13 @@ def calculate_projection_overlap_ratio( Args: bbox1 (List[float]): First bounding box [x_min, y_min, x_max, y_max]. bbox2 (List[float]): Second bounding box [x_min, y_min, x_max, y_max]. - orientation (str): orientation of the projection, "horizontal" or "vertical". + direction (str): direction of the projection, "horizontal" or "vertical". Returns: float: Line overlap ratio. Returns 0 if there is no overlap. """ start_index, end_index = 1, 3 - if orientation == "horizontal": + if direction == "horizontal": start_index, end_index = 0, 2 intersection_start = max(bbox1[start_index], bbox2[start_index]) @@ -241,8 +242,8 @@ def calculate_overlap_ratio( inter_area = inter_width * inter_height - bbox1_area = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1]) - bbox2_area = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1]) + bbox1_area = caculate_bbox_area(bbox1) + bbox2_area = caculate_bbox_area(bbox2) if mode == "union": ref_area = bbox1_area + bbox2_area - inter_area @@ -271,7 +272,7 @@ def group_boxes_into_lines(ocr_rec_res, line_height_iou_threshold): ] text_orientation = calculate_text_orientation(text_boxes) - match_orientation = "vertical" if text_orientation == "horizontal" else "horizontal" + match_direction = "vertical" if text_orientation == "horizontal" else "horizontal" spans = list(zip(rec_boxes, rec_texts, rec_labels)) sort_index = 1 @@ -284,14 +285,14 @@ def group_boxes_into_lines(ocr_rec_res, line_height_iou_threshold): lines = [] line = [spans[0]] - line_region_box = spans[0][0][:] + line_region_box = spans[0][0].copy() # merge line for span in spans[1:]: rec_bbox = span[0] if ( calculate_projection_overlap_ratio( - line_region_box, rec_bbox, match_orientation, mode="small" + line_region_box, rec_bbox, match_direction, mode="small" ) >= line_height_iou_threshold ): @@ -301,7 +302,7 @@ def group_boxes_into_lines(ocr_rec_res, line_height_iou_threshold): else: lines.append(line) line = [span] - line_region_box = rec_bbox[:] + line_region_box = rec_bbox.copy() lines.append(line) return lines, text_orientation @@ -365,12 +366,31 @@ def is_english_letter(char): return bool(re.match(r"^[A-Za-z]$", char)) +def is_non_breaking_punctuation(char): + """ + 判断一个字符是否是不需要换行的标点符号,包括全角和半角的符号。 + + :param char: str, 单个字符 + :return: bool, 如果字符是不需要换行的标点符号,返回True,否则返回False + """ + non_breaking_punctuations = { + ",", # 半角逗号 + ",", # 全角逗号 + "、", # 顿号 + ";", # 半角分号 + ";", # 全角分号 + ":", # 半角冒号 + ":", # 全角冒号 + } + + return char in non_breaking_punctuations + + def format_line( line: List[List[Union[List[int], str]]], block_right_coordinate: int, last_line_span_limit: int = 10, block_label: str = "text", - # delimiter_map: Dict = {}, ) -> None: """ Format a line of text spans based on layout constraints. @@ -402,6 +422,7 @@ def format_line( and not line_text.endswith("-") and len(line_text) > 0 and not is_english_letter(line_text[-1]) + and not is_non_breaking_punctuation(line_text[-1]) ): need_new_line = True @@ -415,37 +436,35 @@ def format_line( return line_text, need_new_line -def split_boxes_by_projection(spans: List[List[int]], orientation, offset=1e-5): +def split_boxes_by_projection(spans: List[List[int]], direction, offset=1e-5): """ - Check if there is any complete containment in the x-orientation + Check if there is any complete containment in the x-direction between the bounding boxes and split the containing box accordingly. Args: spans (list of lists): Each element is a list containing an ndarray of length 4, a text string, and a label. - orientation: 'horizontal' or 'vertical', indicating whether the spans are arranged horizontally or vertically. + direction: 'horizontal' or 'vertical', indicating whether the spans are arranged horizontally or vertically. offset (float): A small offset value to ensure that the split boxes are not too close to the original boxes. Returns: A new list of boxes, including split boxes, with the same `rec_text` and `label` attributes. """ def is_projection_contained(box_a, box_b, start_idx, end_idx): - """Check if box_a completely contains box_b in the x-orientation.""" + """Check if box_a completely contains box_b in the x-direction.""" return box_a[start_idx] <= box_b[start_idx] and box_a[end_idx] >= box_b[end_idx] new_boxes = [] - if orientation == "horizontal": + if direction == "horizontal": projection_start_index, projection_end_index = 0, 2 else: projection_start_index, projection_end_index = 1, 3 for i in range(len(spans)): span = spans[i] - box_a, text, label = span is_split = False - for j in range(len(spans)): - if i == j: - continue + for j in range(i, len(spans)): box_b = spans[j][0] + box_a, text, label = span if is_projection_contained( box_a, box_b, projection_start_index, projection_end_index ): @@ -458,12 +477,13 @@ def is_projection_contained(box_a, box_b, start_idx, end_idx): - box_a[projection_start_index] ) if w > 1: - box_a[projection_end_index] = ( + new_bbox = box_a.copy() + new_bbox[projection_end_index] = ( box_b[projection_start_index] - offset ) new_boxes.append( [ - np.array(box_a), + np.array(new_bbox), text, label, ] @@ -562,8 +582,8 @@ def _get_minbox_if_overlap_by_ratio( The selected bounding box or None if the overlap ratio is not exceeded. """ # Calculate the areas of both bounding boxes - area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1]) - area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1]) + area1 = caculate_bbox_area(bbox1) + area2 = caculate_bbox_area(bbox2) # Calculate the overlap ratio using a helper function overlap_ratio = calculate_overlap_ratio(bbox1, bbox2, mode="small") # Check if the overlap ratio exceeds the threshold @@ -683,7 +703,6 @@ def shrink_supplement_region_bbox( image_height, block_idxes_set, block_bboxes, - parameters_config, ) -> List: """ Shrink the supplement region bbox according to the reference region bbox and match the block bboxes. @@ -695,7 +714,6 @@ def shrink_supplement_region_bbox( image_height (int): The height of the image. block_idxes_set (set): The indexes of the blocks that intersect with the region bbox. block_bboxes (dict): The dictionary of block bboxes. - parameters_config (dict): The configuration parameters. Returns: list: The new region bbox and the matched block idxes. @@ -723,11 +741,11 @@ def shrink_supplement_region_bbox( overlap_ratio = calculate_overlap_ratio( tmp_region_bbox, block_bboxes[block_idx], mode="small" ) - if overlap_ratio > parameters_config["region"].get( + if overlap_ratio > REGION_SETTINGS.get( "match_block_overlap_ratio_threshold", 0.8 ): iner_block_idxes.append(block_idx) - elif overlap_ratio > parameters_config["region"].get( + elif overlap_ratio > REGION_SETTINGS.get( "split_block_overlap_ratio_threshold", 0.4 ): split_block_idxes.append(block_idx) @@ -755,7 +773,6 @@ def shrink_supplement_region_bbox( image_height, iner_block_idxes, block_bboxes, - parameters_config, ) if len(iner_idxes) == 0: continue @@ -799,50 +816,68 @@ def convert_formula_res_to_ocr_format(formula_res_list: List, ocr_res: dict): ] ocr_res["dt_polys"].append(poly_points) ocr_res["rec_texts"].append(f"{formula_res['rec_formula']}") - ocr_res["rec_boxes"] = np.vstack( - (ocr_res["rec_boxes"], [formula_res["dt_polys"]]) - ) + if ocr_res["rec_boxes"].size == 0: + ocr_res["rec_boxes"] = np.array(formula_res["dt_polys"]) + else: + ocr_res["rec_boxes"] = np.vstack( + (ocr_res["rec_boxes"], [formula_res["dt_polys"]]) + ) ocr_res["rec_labels"].append("formula") ocr_res["rec_polys"].append(poly_points) ocr_res["rec_scores"].append(1) def caculate_bbox_area(bbox): - x1, y1, x2, y2 = bbox + x1, y1, x2, y2 = map(float, bbox) area = abs((x2 - x1) * (y2 - y1)) return area -def get_show_color(label: str) -> Tuple: - label_colors = { - # Medium Blue (from 'titles_list') - "paragraph_title": (102, 102, 255, 100), - "doc_title": (255, 248, 220, 100), # Cornsilk - # Light Yellow (from 'tables_caption_list') - "table_title": (255, 255, 102, 100), - # Sky Blue (from 'imgs_caption_list') - "figure_title": (102, 178, 255, 100), - "chart_title": (221, 160, 221, 100), # Plum - "vision_footnote": (144, 238, 144, 100), # Light Green - # Deep Purple (from 'texts_list') - "text": (153, 0, 76, 100), - # Bright Green (from 'interequations_list') - "formula": (0, 255, 0, 100), - "abstract": (255, 239, 213, 100), # Papaya Whip - # Medium Green (from 'lists_list' and 'indexs_list') - "content": (40, 169, 92, 100), - # Neutral Gray (from 'dropped_bbox_list') - "seal": (158, 158, 158, 100), - # Olive Yellow (from 'tables_body_list') - "table": (204, 204, 0, 100), - # Bright Green (from 'imgs_body_list') - "image": (153, 255, 51, 100), - # Bright Green (from 'imgs_body_list') - "figure": (153, 255, 51, 100), - "chart": (216, 191, 216, 100), # Thistle - # Pale Yellow-Green (from 'tables_footnote_list') - "reference": (229, 255, 204, 100), - "algorithm": (255, 250, 240, 100), # Floral White - } +def get_show_color(label: str, order_label=False) -> Tuple: + if order_label: + label_colors = { + "doc_title": (255, 248, 220, 100), # Cornsilk + "doc_title_text": (255, 239, 213, 100), + "paragraph_title": (102, 102, 255, 100), + "sub_paragraph_title": (102, 178, 255, 100), + "vision": (153, 255, 51, 100), + "vision_title": (144, 238, 144, 100), # Light Green + "vision_footnote": (144, 238, 144, 100), # Light Green + "normal_text": (153, 0, 76, 100), + "cross_layout": (53, 218, 207, 100), # Thistle + "cross_reference": (221, 160, 221, 100), # Floral White + } + else: + label_colors = { + # Medium Blue (from 'titles_list') + "paragraph_title": (102, 102, 255, 100), + "doc_title": (255, 248, 220, 100), # Cornsilk + # Light Yellow (from 'tables_caption_list') + "table_title": (255, 255, 102, 100), + # Sky Blue (from 'imgs_caption_list') + "figure_title": (102, 178, 255, 100), + "chart_title": (221, 160, 221, 100), # Plum + "vision_footnote": (144, 238, 144, 100), # Light Green + # Deep Purple (from 'texts_list') + "text": (153, 0, 76, 100), + # Bright Green (from 'interequations_list') + "formula": (0, 255, 0, 100), + "abstract": (255, 239, 213, 100), # Papaya Whip + # Medium Green (from 'lists_list' and 'indexs_list') + "content": (40, 169, 92, 100), + # Neutral Gray (from 'dropped_bbox_list') + "seal": (158, 158, 158, 100), + # Olive Yellow (from 'tables_body_list') + "table": (204, 204, 0, 100), + # Bright Green (from 'imgs_body_list') + "image": (153, 255, 51, 100), + # Bright Green (from 'imgs_body_list') + "figure": (153, 255, 51, 100), + "chart": (216, 191, 216, 100), # Thistle + # Pale Yellow-Green (from 'tables_footnote_list') + "reference": (229, 255, 204, 100), + # "reference_content": (229, 255, 204, 100), + "algorithm": (255, 250, 240, 100), # Floral White + } default_color = (158, 158, 158, 100) return label_colors.get(label, default_color) diff --git a/paddlex/inference/pipelines/layout_parsing/xycut_enhanced/utils.py b/paddlex/inference/pipelines/layout_parsing/xycut_enhanced/utils.py index efec88b062..1f333db496 100644 --- a/paddlex/inference/pipelines/layout_parsing/xycut_enhanced/utils.py +++ b/paddlex/inference/pipelines/layout_parsing/xycut_enhanced/utils.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Tuple +from typing import List, Tuple import numpy as np -from ..result_v2 import LayoutParsingBlock +from ..result_v2 import LayoutParsingBlock, LayoutParsingRegion +from ..setting import BLOCK_LABEL_MAP, XYCUT_SETTINGS from ..utils import calculate_projection_overlap_ratio @@ -26,12 +27,12 @@ def get_nearest_edge_distance( weight: List[float] = [1.0, 1.0, 1.0, 1.0], ) -> Tuple[float]: """ - Calculate the nearest edge distance between two bounding boxes, considering orientational weights. + Calculate the nearest edge distance between two bounding boxes, considering directional weights. Args: bbox1 (list): The bounding box coordinates [x1, y1, x2, y2] of the input object. bbox2 (list): The bounding box coordinates [x1', y1', x2', y2'] of the object to match against. - weight (list, optional): orientational weights for the edge distances [left, right, up, down]. Defaults to [1, 1, 1, 1]. + weight (list, optional): directional weights for the edge distances [left, right, up, down]. Defaults to [1, 1, 1, 1]. Returns: float: The calculated minimum edge distance between the bounding boxes. @@ -254,8 +255,7 @@ def recursive_xy_cut( def reference_insert( block: LayoutParsingBlock, sorted_blocks: List[LayoutParsingBlock], - config: Dict, - median_width: float = 0.0, + **kwargs, ): """ Insert reference block into sorted blocks based on the distance between the block and the nearest sorted block. @@ -285,8 +285,7 @@ def reference_insert( def manhattan_insert( block: LayoutParsingBlock, sorted_blocks: List[LayoutParsingBlock], - config: Dict, - median_width: float = 0.0, + **kwargs, ): """ Insert a block into a sorted list of blocks based on the Manhattan distance between the block and the nearest sorted block. @@ -315,8 +314,7 @@ def manhattan_insert( def weighted_distance_insert( block: LayoutParsingBlock, sorted_blocks: List[LayoutParsingBlock], - config: Dict, - median_width: float = 0.0, + region: LayoutParsingRegion, ): """ Insert a block into a sorted list of blocks based on the weighted distance between the block and the nearest sorted block. @@ -330,11 +328,8 @@ def weighted_distance_insert( Returns: sorted_blocks: The updated sorted blocks after insertion. """ - doc_title_labels = config.get("doc_title_labels", []) - paragraph_title_labels = config.get("paragraph_title_labels", []) - vision_labels = config.get("vision_labels", []) - xy_cut_block_labels = config.get("xy_cut_block_labels", []) - tolerance_len = config.get("tolerance_len", 2) + + tolerance_len = XYCUT_SETTINGS["edge_distance_compare_tolerance_len"] x1, y1, x2, y2 = block.bbox min_weighted_distance, min_edge_distance, min_up_edge_distance = ( float("inf"), @@ -347,36 +342,43 @@ def weighted_distance_insert( x1_prime, y1_prime, x2_prime, y2_prime = sorted_block.bbox # Calculate edge distance - weight = _get_weights(block.order_label, block.orientation) + weight = _get_weights(block.order_label, block.direction) edge_distance = get_nearest_edge_distance(block.bbox, sorted_block.bbox, weight) - if block.label in doc_title_labels: - disperse = max(1, median_width) + if block.label in BLOCK_LABEL_MAP["doc_title_labels"]: + disperse = max(1, region.text_line_width) tolerance_len = max(tolerance_len, disperse) if block.label == "abstract": tolerance_len *= 2 edge_distance = max(0.1, edge_distance) * 10 # Calculate up edge distances - up_edge_distance = y1_prime - left_edge_distance = x1_prime + up_edge_distance = y1_prime if region.direction == "horizontal" else -x2_prime + left_edge_distance = x1_prime if region.direction == "horizontal" else y1_prime + is_below_sorted_block = ( + y2_prime < y1 if region.direction == "horizontal" else x1_prime > x2 + ) + if ( - block.label in xy_cut_block_labels - or block.label in doc_title_labels - or block.label in paragraph_title_labels - or block.label in vision_labels - ) and y1 > y2_prime: - up_edge_distance = -y2_prime - left_edge_distance = -x2_prime + block.label not in BLOCK_LABEL_MAP["unordered_labels"] + or block.label in BLOCK_LABEL_MAP["doc_title_labels"] + or block.label in BLOCK_LABEL_MAP["paragraph_title_labels"] + or block.label in BLOCK_LABEL_MAP["vision_labels"] + ) and is_below_sorted_block: + up_edge_distance = -up_edge_distance + left_edge_distance = -left_edge_distance if abs(min_up_edge_distance - up_edge_distance) <= tolerance_len: up_edge_distance = min_up_edge_distance # Calculate weighted distance weighted_distance = ( - +edge_distance * config.get("edge_weight", 10**4) - + up_edge_distance * config.get("up_edge_weight", 1) - + left_edge_distance * config.get("left_edge_weight", 0.0001) + +edge_distance + * XYCUT_SETTINGS["distance_weight_map"].get("edge_weight", 10**4) + + up_edge_distance + * XYCUT_SETTINGS["distance_weight_map"].get("up_edge_weight", 1) + + left_edge_distance + * XYCUT_SETTINGS["distance_weight_map"].get("left_edge_weight", 0.0001) ) min_edge_distance = min(edge_distance, min_edge_distance) @@ -411,7 +413,7 @@ def insert_child_blocks( if block.child_blocks: sub_blocks = block.get_child_blocks() sub_blocks.append(block) - sub_blocks = sort_child_blocks(sub_blocks, block.orientation) + sub_blocks = sort_child_blocks(sub_blocks, block.direction) sorted_blocks[block_idx] = sub_blocks[0] for block in sub_blocks[1:]: block_idx += 1 @@ -419,17 +421,17 @@ def insert_child_blocks( return sorted_blocks -def sort_child_blocks(blocks, orientation="horizontal") -> List[LayoutParsingBlock]: +def sort_child_blocks(blocks, direction="horizontal") -> List[LayoutParsingBlock]: """ Sort child blocks based on their bounding box coordinates. Args: blocks: A list of LayoutParsingBlock objects representing the child blocks. - orientation: Orientation of the blocks ('horizontal' or 'vertical'). Default is 'horizontal'. + direction: direction of the blocks ('horizontal' or 'vertical'). Default is 'horizontal'. Returns: sorted_blocks: A sorted list of LayoutParsingBlock objects. """ - if orientation == "horizontal": + if direction == "horizontal": # from top to bottom blocks.sort( key=lambda x: ( @@ -453,7 +455,7 @@ def sort_child_blocks(blocks, orientation="horizontal") -> List[LayoutParsingBlo def _get_weights(label, dircetion="horizontal"): - """Define weights based on the label and orientation.""" + """Define weights based on the label and direction.""" if label == "doc_title": return ( [1, 0.1, 0.1, 1] if dircetion == "horizontal" else [0.2, 0.1, 1, 1] @@ -518,15 +520,35 @@ def sort_blocks(blocks, median_width=None, reverse=False): return blocks +def sort_normal_blocks(blocks, text_line_height, text_line_width, region_direction): + if region_direction == "horizontal": + blocks.sort( + key=lambda x: ( + x.bbox[1] // text_line_height, + x.bbox[0] // text_line_width, + x.bbox[1] ** 2 + x.bbox[0] ** 2, + ), + ) + else: + blocks.sort( + key=lambda x: ( + -x.bbox[0] // text_line_width, + x.bbox[1] // text_line_height, + -(x.bbox[2] ** 2 + x.bbox[1] ** 2), + ), + ) + return blocks + + def get_cut_blocks( - blocks, cut_orientation, cut_coordinates, overall_region_box, mask_labels=[] + blocks, cut_direction, cut_coordinates, overall_region_box, mask_labels=[] ): """ - Cut blocks based on the given cut orientation and coordinates. + Cut blocks based on the given cut direction and coordinates. Args: blocks (list): list of blocks to be cut. - cut_orientation (str): cut orientation, either "horizontal" or "vertical". + cut_direction (str): cut direction, either "horizontal" or "vertical". cut_coordinates (list): list of cut coordinates. overall_region_box (list): the overall region box that contains all blocks. @@ -537,7 +559,7 @@ def get_cut_blocks( # filter out mask blocks,including header, footer, unordered and child_blocks # 0: horizontal, 1: vertical - cut_aixis = 0 if cut_orientation == "horizontal" else 1 + cut_aixis = 0 if cut_direction == "horizontal" else 1 blocks.sort(key=lambda x: x.bbox[cut_aixis + 2]) cut_coordinates.append(float("inf")) @@ -567,7 +589,7 @@ def add_split_block( ) -> List[LayoutParsingBlock]: block_bboxes = np.array([block.bbox for block in blocks]) discontinuous = calculate_discontinuous_projection( - block_bboxes, orientation="vertical" + block_bboxes, direction="vertical" ) current_interval = discontinuous[0] for interval in discontinuous[1:]: @@ -582,22 +604,62 @@ def add_split_block( current_interval = interval -def get_adjacent_blocks_by_orientation( +def get_nearest_blocks( + block: LayoutParsingBlock, + ref_blocks: List[LayoutParsingBlock], + overlap_threshold, + direction="horizontal", +) -> List: + """ + Get the adjacent blocks with the same direction as the current block. + Args: + block (LayoutParsingBlock): The current block. + blocks (List[LayoutParsingBlock]): A list of all blocks. + ref_block_idxes (List[int]): A list of indices of reference blocks. + iou_threshold (float): The IOU threshold to determine if two blocks are considered adjacent. + Returns: + Int: The index of the previous block with same direction. + Int: The index of the following block with same direction. + """ + prev_blocks: List[LayoutParsingBlock] = [] + post_blocks: List[LayoutParsingBlock] = [] + sort_index = 1 if direction == "horizontal" else 0 + for ref_block in ref_blocks: + if ref_block.index == block.index: + continue + overlap_ratio = calculate_projection_overlap_ratio( + block.bbox, ref_block.bbox, direction, mode="small" + ) + if overlap_ratio > overlap_threshold: + if ref_block.bbox[sort_index] <= block.bbox[sort_index]: + prev_blocks.append(ref_block) + else: + post_blocks.append(ref_block) + + if prev_blocks: + prev_blocks.sort(key=lambda x: x.bbox[sort_index], reverse=True) + if post_blocks: + post_blocks.sort(key=lambda x: x.bbox[sort_index]) + + return prev_blocks, post_blocks + + +def get_adjacent_blocks_by_direction( blocks: List[LayoutParsingBlock], block_idx: int, ref_block_idxes: List[int], iou_threshold, ) -> List: """ - Get the adjacent blocks with the same orientation as the current block. + Get the adjacent blocks with the same direction as the current block. Args: block (LayoutParsingBlock): The current block. blocks (List[LayoutParsingBlock]): A list of all blocks. ref_block_idxes (List[int]): A list of indices of reference blocks. iou_threshold (float): The IOU threshold to determine if two blocks are considered adjacent. Returns: - Int: The index of the previous block with same orientation. - Int: The index of the following block with same orientation. + Int: The index of the previous block with same direction. + Int: The index of the following block with same direction. """ min_prev_block_distance = float("inf") prev_block_index = None @@ -611,16 +673,16 @@ def get_adjacent_blocks_by_orientation( "vision_title", ] - # find the nearest text block with same orientation to the current block + # find the nearest text block with same direction to the current block for ref_block_idx in ref_block_idxes: ref_block = blocks[ref_block_idx] - ref_block_orientation = ref_block.orientation + ref_block_direction = ref_block.direction if ref_block.order_label in child_labels: continue match_block_iou = calculate_projection_overlap_ratio( block.bbox, ref_block.bbox, - ref_block_orientation, + ref_block_direction, ) child_match_distance_tolerance_len = block.short_side_length / 10 @@ -635,38 +697,38 @@ def get_adjacent_blocks_by_orientation( if match_block_iou >= iou_threshold: prev_distance = ( - block.secondary_orientation_start_coordinate - - ref_block.secondary_orientation_end_coordinate + block.secondary_direction_start_coordinate + - ref_block.secondary_direction_end_coordinate + child_match_distance_tolerance_len ) // 5 + ref_block.start_coordinate / 5000 next_distance = ( - ref_block.secondary_orientation_start_coordinate - - block.secondary_orientation_end_coordinate + ref_block.secondary_direction_start_coordinate + - block.secondary_direction_end_coordinate + child_match_distance_tolerance_len ) // 5 + ref_block.start_coordinate / 5000 if ( - ref_block.secondary_orientation_end_coordinate - <= block.secondary_orientation_start_coordinate + ref_block.secondary_direction_end_coordinate + <= block.secondary_direction_start_coordinate + child_match_distance_tolerance_len and prev_distance < min_prev_block_distance ): min_prev_block_distance = prev_distance if ( - block.secondary_orientation_start_coordinate - - ref_block.secondary_orientation_end_coordinate + block.secondary_direction_start_coordinate + - ref_block.secondary_direction_end_coordinate < gap_tolerance_len ): prev_block_index = ref_block_idx elif ( - ref_block.secondary_orientation_start_coordinate - > block.secondary_orientation_end_coordinate + ref_block.secondary_direction_start_coordinate + > block.secondary_direction_end_coordinate - child_match_distance_tolerance_len and next_distance < min_post_block_distance ): min_post_block_distance = next_distance if ( - ref_block.secondary_orientation_start_coordinate - - block.secondary_orientation_end_coordinate + ref_block.secondary_direction_start_coordinate + - block.secondary_direction_end_coordinate < gap_tolerance_len ): post_block_index = ref_block_idx @@ -684,21 +746,19 @@ def get_adjacent_blocks_by_orientation( def update_doc_title_child_blocks( - blocks: List[LayoutParsingBlock], block: LayoutParsingBlock, - prev_idx: int, - post_idx: int, - config: dict, + region: LayoutParsingRegion, ) -> None: """ Update the child blocks of a document title block. The child blocks need to meet the following conditions: 1. They must be adjacent - 2. They must have the same orientation as the parent block. + 2. They must have the same direction as the parent block. 3. Their short side length should be less than 80% of the parent's short side length. 4. Their long side length should be less than 150% of the parent's long side length. 5. The child block must be text block. + 6. The nearest edge distance should be less than 2 times of the text line height. Args: blocks (List[LayoutParsingBlock]): overall blocks. @@ -711,11 +771,23 @@ def update_doc_title_child_blocks( None """ - for idx in [prev_idx, post_idx]: - if idx is None: + ref_blocks = [region.block_map[idx] for idx in region.normal_text_block_idxes] + overlap_threshold = XYCUT_SETTINGS["child_block_overlap_ratio_threshold"] + prev_blocks, post_blocks = get_nearest_blocks( + block, ref_blocks, overlap_threshold, block.direction + ) + prev_block = None + post_block = None + + if prev_blocks: + prev_block = prev_blocks[0] + if post_blocks: + post_block = post_blocks[0] + + for ref_block in [prev_block, post_block]: + if ref_block is None: continue - ref_block = blocks[idx] - with_seem_orientation = ref_block.orientation == block.orientation + with_seem_direction = ref_block.direction == block.direction short_side_length_condition = ( ref_block.short_side_length < block.short_side_length * 0.8 @@ -726,30 +798,31 @@ def update_doc_title_child_blocks( or ref_block.long_side_length > 1.5 * block.long_side_length ) + nearest_edge_distance = get_nearest_edge_distance(block.bbox, ref_block.bbox) + if ( - with_seem_orientation + with_seem_direction + and ref_block.label in BLOCK_LABEL_MAP["text_labels"] and short_side_length_condition and long_side_length_condition and ref_block.num_of_lines < 3 + and nearest_edge_distance < ref_block.text_line_height * 2 ): ref_block.order_label = "doc_title_text" block.append_child_block(ref_block) - config["text_block_idxes"].remove(idx) + region.normal_text_block_idxes.remove(ref_block.index) def update_paragraph_title_child_blocks( - blocks: List[LayoutParsingBlock], block: LayoutParsingBlock, - prev_idx: int, - post_idx: int, - config: dict, + region: LayoutParsingRegion, ) -> None: """ Update the child blocks of a paragraph title block. The child blocks need to meet the following conditions: 1. They must be adjacent - 2. They must have the same orientation as the parent block. + 2. They must have the same direction as the parent block. 3. The child block must be paragraph title block. Args: @@ -763,31 +836,39 @@ def update_paragraph_title_child_blocks( None """ - paragraph_title_labels = config.get("paragraph_title_labels", []) - for idx in [prev_idx, post_idx]: - if idx is None: - continue - ref_block = blocks[idx] - min_height = min(block.height, ref_block.height) - nearest_edge_distance = get_nearest_edge_distance(block.bbox, ref_block.bbox) - with_seem_orientation = ref_block.orientation == block.orientation - if ( - with_seem_orientation - and ref_block.label in paragraph_title_labels - and nearest_edge_distance <= min_height * 2 - ): - ref_block.order_label = "sub_paragraph_title" - block.append_child_block(ref_block) - config["paragraph_title_block_idxes"].remove(idx) + if block.order_label == "sub_paragraph_title": + return + ref_blocks = [ + region.block_map[idx] + for idx in region.paragraph_title_block_idxes + region.normal_text_block_idxes + ] + overlap_threshold = XYCUT_SETTINGS["child_block_overlap_ratio_threshold"] + prev_blocks, post_blocks = get_nearest_blocks( + block, ref_blocks, overlap_threshold, block.direction + ) + for ref_blocks in [prev_blocks, post_blocks]: + for ref_block in ref_blocks: + if ref_block.label not in BLOCK_LABEL_MAP["paragraph_title_labels"]: + break + min_text_line_height = min( + block.text_line_height, ref_block.text_line_height + ) + nearest_edge_distance = get_nearest_edge_distance( + block.bbox, ref_block.bbox + ) + with_seem_direction = ref_block.direction == block.direction + if ( + with_seem_direction + and nearest_edge_distance <= min_text_line_height * 1.5 + ): + ref_block.order_label = "sub_paragraph_title" + block.append_child_block(ref_block) + region.paragraph_title_block_idxes.remove(ref_block.index) def update_vision_child_blocks( - blocks: List[LayoutParsingBlock], block: LayoutParsingBlock, - ref_block_idxes: List[int], - prev_idx: int, - post_idx: int, - config: dict, + region: LayoutParsingRegion, ) -> None: """ Update the child blocks of a paragraph title block. @@ -816,69 +897,122 @@ def update_vision_child_blocks( None """ - vision_title_labels = config.get("vision_title_labels", []) - text_labels = config.get("text_labels", []) - for idx in [prev_idx, post_idx]: - if idx is None: - continue - ref_block = blocks[idx] - nearest_edge_distance = get_nearest_edge_distance(block.bbox, ref_block.bbox) - block_center = block.get_centroid() - ref_block_center = ref_block.get_centroid() - if ref_block.label in vision_title_labels and nearest_edge_distance <= min( - block.height * 0.5, ref_block.height * 2 - ): - ref_block.order_label = "vision_title" - block.append_child_block(ref_block) - config["vision_title_block_idxes"].remove(idx) - elif ( - nearest_edge_distance <= 15 - and ref_block.short_side_length < block.short_side_length - and ref_block.long_side_length < 0.5 * block.long_side_length - and ref_block.orientation == block.orientation - and ( - abs(block_center[0] - ref_block_center[0]) < 10 - or ( - block.bbox[0] - ref_block.bbox[0] < 10 - and ref_block.num_of_lines == 1 - ) - or ( - block.bbox[2] - ref_block.bbox[2] < 10 - and ref_block.num_of_lines == 1 - ) + ref_blocks = [ + region.block_map[idx] + for idx in region.normal_text_block_idxes + region.vision_title_block_idxes + ] + overlap_threshold = XYCUT_SETTINGS["child_block_overlap_ratio_threshold"] + has_vision_footnote = False + has_vision_title = False + for direction in [block.direction, block.secondary_direction]: + prev_blocks, post_blocks = get_nearest_blocks( + block, ref_blocks, overlap_threshold, direction + ) + for ref_block in prev_blocks: + if ( + ref_block.label + not in BLOCK_LABEL_MAP["text_labels"] + + BLOCK_LABEL_MAP["vision_title_labels"] + ): + break + nearest_edge_distance = get_nearest_edge_distance( + block.bbox, ref_block.bbox ) - ): - has_vision_footnote = False - if len(block.child_blocks) > 0: - for child_block in block.child_blocks: - if child_block.label in text_labels: - has_vision_footnote = True - if not has_vision_footnote: - ref_block.order_label = "vision_footnote" + block_center = block.get_centroid() + ref_block_center = ref_block.get_centroid() + if ref_block.label in BLOCK_LABEL_MAP["vision_title_labels"]: + has_vision_title = True + ref_block.order_label = "vision_title" block.append_child_block(ref_block) - config["text_block_idxes"].remove(idx) + region.vision_title_block_idxes.remove(ref_block.index) + if ref_block.label in BLOCK_LABEL_MAP["text_labels"]: + if ( + not has_vision_footnote + and nearest_edge_distance <= block.text_line_height * 2 + and ref_block.short_side_length < block.short_side_length + and ref_block.long_side_length < 0.5 * block.long_side_length + and ref_block.direction == block.direction + and ( + abs(block_center[0] - ref_block_center[0]) < 10 + or ( + block.bbox[0] - ref_block.bbox[0] < 10 + and ref_block.num_of_lines == 1 + ) + or ( + block.bbox[2] - ref_block.bbox[2] < 10 + and ref_block.num_of_lines == 1 + ) + ) + ): + has_vision_footnote = True + ref_block.order_label = "vision_footnote" + block.append_child_block(ref_block) + region.normal_text_block_idxes.remove(ref_block.index) + break + for ref_block in post_blocks: + if ( + has_vision_footnote + and ref_block.label in BLOCK_LABEL_MAP["text_labels"] + ): + break + nearest_edge_distance = get_nearest_edge_distance( + block.bbox, ref_block.bbox + ) + block_center = block.get_centroid() + ref_block_center = ref_block.get_centroid() + if ref_block.label in BLOCK_LABEL_MAP["vision_title_labels"]: + has_vision_title = True + ref_block.order_label = "vision_title" + block.append_child_block(ref_block) + region.vision_title_block_idxes.remove(ref_block.index) + if ref_block.label in BLOCK_LABEL_MAP["text_labels"]: + if ( + not has_vision_footnote + and nearest_edge_distance <= block.text_line_height * 2 + and ref_block.short_side_length < block.short_side_length + and ref_block.long_side_length < 0.5 * block.long_side_length + and ref_block.direction == block.direction + and ( + abs(block_center[0] - ref_block_center[0]) < 10 + or ( + block.bbox[0] - ref_block.bbox[0] < 10 + and ref_block.num_of_lines == 1 + ) + or ( + block.bbox[2] - ref_block.bbox[2] < 10 + and ref_block.num_of_lines == 1 + ) + ) + ): + has_vision_footnote = True + ref_block.order_label = "vision_footnote" + block.append_child_block(ref_block) + region.normal_text_block_idxes.remove(ref_block.index) + break + if has_vision_title: + break def calculate_discontinuous_projection( - boxes, orientation="horizontal", return_num=False + boxes, direction="horizontal", return_num=False ) -> List: """ - Calculate the discontinuous projection of boxes along the specified orientation. + Calculate the discontinuous projection of boxes along the specified direction. Args: boxes (ndarray): Array of bounding boxes represented by [[x_min, y_min, x_max, y_max]]. - orientation (str): orientation along which to perform the projection ('horizontal' or 'vertical'). + direction (str): direction along which to perform the projection ('horizontal' or 'vertical'). Returns: list: List of tuples representing the merged intervals. """ boxes = np.array(boxes) - if orientation == "horizontal": + if direction == "horizontal": intervals = boxes[:, [0, 2]] - elif orientation == "vertical": + elif direction == "vertical": intervals = boxes[:, [1, 3]] else: - raise ValueError("orientation must be 'horizontal' or 'vertical'") + raise ValueError("direction must be 'horizontal' or 'vertical'") intervals = intervals[np.argsort(intervals[:, 0])] @@ -904,15 +1038,53 @@ def calculate_discontinuous_projection( return merged_intervals +def is_projection_consistent(blocks, intervals, direction="horizontal"): + + for interval in intervals: + if direction == "horizontal": + start_index, stop_index = 0, 2 + interval_box = [interval[0], 0, interval[1], 1] + else: + start_index, stop_index = 1, 3 + interval_box = [0, interval[0], 1, interval[1]] + same_interval_bboxes = [] + for block in blocks: + overlap_ratio = calculate_projection_overlap_ratio( + interval_box, block.bbox, direction=direction + ) + if overlap_ratio > 0 and block.label in BLOCK_LABEL_MAP["text_labels"]: + same_interval_bboxes.append(block.bbox) + start_coordinates = [bbox[start_index] for bbox in same_interval_bboxes] + if start_coordinates: + min_start_coordinate = min(start_coordinates) + max_start_coordinate = max(start_coordinates) + is_start_consistent = ( + False + if max_start_coordinate - min_start_coordinate + >= abs(interval[0] - interval[1]) * 0.05 + else True + ) + stop_coordinates = [bbox[stop_index] for bbox in same_interval_bboxes] + min_stop_coordinate = min(stop_coordinates) + max_stop_coordinate = max(stop_coordinates) + if ( + max_stop_coordinate - min_stop_coordinate + >= abs(interval[0] - interval[1]) * 0.05 + and is_start_consistent + ): + return False + return True + + def shrink_overlapping_boxes( - boxes, orientation="horizontal", min_threshold=0, max_threshold=0.1 + boxes, direction="horizontal", min_threshold=0, max_threshold=0.1 ) -> List: """ - Shrink overlapping boxes along the specified orientation. + Shrink overlapping boxes along the specified direction. Args: boxes (ndarray): Array of bounding boxes represented by [[x_min, y_min, x_max, y_max]]. - orientation (str): orientation along which to perform the shrinking ('horizontal' or 'vertical'). + direction (str): direction along which to perform the shrinking ('horizontal' or 'vertical'). min_threshold (float): Minimum threshold for shrinking. Default is 0. max_threshold (float): Maximum threshold for shrinking. Default is 0.2. @@ -924,14 +1096,14 @@ def shrink_overlapping_boxes( x1, y1, x2, y2 = current_block.bbox x1_prime, y1_prime, x2_prime, y2_prime = block.bbox cut_iou = calculate_projection_overlap_ratio( - current_block.bbox, block.bbox, orientation=orientation + current_block.bbox, block.bbox, direction=direction ) match_iou = calculate_projection_overlap_ratio( current_block.bbox, block.bbox, - orientation="horizontal" if orientation == "vertical" else "vertical", + direction="horizontal" if direction == "vertical" else "vertical", ) - if orientation == "vertical": + if direction == "vertical": if ( (match_iou > 0 and cut_iou > min_threshold and cut_iou < max_threshold) or y2 == y1_prime diff --git a/paddlex/inference/pipelines/layout_parsing/xycut_enhanced/xycuts.py b/paddlex/inference/pipelines/layout_parsing/xycut_enhanced/xycuts.py index 087c733b8b..1dec70a522 100644 --- a/paddlex/inference/pipelines/layout_parsing/xycut_enhanced/xycuts.py +++ b/paddlex/inference/pipelines/layout_parsing/xycut_enhanced/xycuts.py @@ -12,24 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Tuple +from typing import Dict, List, Tuple import numpy as np -from ..result_v2 import LayoutParsingBlock +from ..result_v2 import LayoutParsingBlock, LayoutParsingRegion +from ..setting import BLOCK_LABEL_MAP from ..utils import calculate_overlap_ratio, calculate_projection_overlap_ratio from .utils import ( calculate_discontinuous_projection, - get_adjacent_blocks_by_orientation, get_cut_blocks, get_nearest_edge_distance, insert_child_blocks, + is_projection_consistent, manhattan_insert, recursive_xy_cut, recursive_yx_cut, reference_insert, shrink_overlapping_boxes, - sort_blocks, + sort_normal_blocks, update_doc_title_child_blocks, update_paragraph_title_child_blocks, update_vision_child_blocks, @@ -38,8 +39,7 @@ def pre_process( - blocks: List[LayoutParsingBlock], - config: Dict, + region: LayoutParsingRegion, ) -> List: """ Preprocess the layout for sorting purposes. @@ -49,147 +49,116 @@ def pre_process( 2. Match the blocks with their children. Args: - blocks (List[LayoutParsingBlock]): A list of LayoutParsingBlock objects representing the layout. - config (Dict): Configuration parameters that include settings for pre-cutting and sorting. + region: LayoutParsingRegion, the layout region to be pre-processed. Returns: List: A list of pre-cutted layout blocks list. """ - region_bbox = config.get("region_bbox", None) - region_x_center = (region_bbox[0] + region_bbox[2]) / 2 - region_y_center = (region_bbox[1] + region_bbox[3]) / 2 - - header_block_idxes = config.get("header_block_idxes", []) - header_blocks = [] - for idx in header_block_idxes: - blocks[idx].order_label = "header" - header_blocks.append(blocks[idx]) - - unordered_block_idxes = config.get("unordered_block_idxes", []) - unordered_blocks = [] - for idx in unordered_block_idxes: - blocks[idx].order_label = "unordered" - unordered_blocks.append(blocks[idx]) - - footer_block_idxes = config.get("footer_block_idxes", []) - footer_blocks = [] - for idx in footer_block_idxes: - blocks[idx].order_label = "footer" - footer_blocks.append(blocks[idx]) - - mask_labels = ["header", "unordered", "footer"] - child_labels = [ + mask_labels = [ + "header", + "unordered", + "footer", "vision_footnote", "sub_paragraph_title", "doc_title_text", "vision_title", ] pre_cut_block_idxes = [] - for block_idx, block in enumerate(blocks): - if block.label in mask_labels: - continue - - if block.order_label not in child_labels: - update_region_label(blocks, config, block_idx) - - block_orientation = block.orientation - if block_orientation == "horizontal": - region_bbox_center = region_x_center + block_map = region.block_map + blocks: List[LayoutParsingBlock] = list(block_map.values()) + for block in blocks: + if block.order_label not in mask_labels: + update_region_label(block, region) + + block_direction = block.direction + if block_direction == "horizontal": tolerance_len = block.long_side_length // 5 else: - region_bbox_center = region_y_center tolerance_len = block.short_side_length // 10 block_center = (block.start_coordinate + block.end_coordinate) / 2 - center_offset = abs(block_center - region_bbox_center) + center_offset = abs(block_center - region.direction_center_coordinate) is_centered = center_offset <= tolerance_len if is_centered: - pre_cut_block_idxes.append(block_idx) + pre_cut_block_idxes.append(block.index) pre_cut_list = [] - cut_orientation = "vertical" + cut_direction = region.secondary_direction cut_coordinates = [] discontinuous = [] - mask_labels = child_labels + mask_labels all_boxes = np.array( [block.bbox for block in blocks if block.order_label not in mask_labels] ) if len(all_boxes) == 0: - return header_blocks, pre_cut_list, footer_blocks, unordered_blocks + return pre_cut_list if pre_cut_block_idxes: - horizontal_cut_num = 0 - for block_idx in pre_cut_block_idxes: - block = blocks[block_idx] - horizontal_cut_num += ( - 1 if block.secondary_orientation == "horizontal" else 0 - ) - cut_orientation = ( - "horizontal" - if horizontal_cut_num > len(pre_cut_block_idxes) * 0.5 - else "vertical" - ) discontinuous, num_list = calculate_discontinuous_projection( - all_boxes, orientation=cut_orientation, return_num=True + all_boxes, direction=cut_direction, return_num=True ) for idx in pre_cut_block_idxes: - block = blocks[idx] + block = block_map[idx] if ( block.order_label not in mask_labels - and block.secondary_orientation == cut_orientation + and block.secondary_direction == cut_direction ): if ( - block.secondary_orientation_start_coordinate, - block.secondary_orientation_end_coordinate, + block.secondary_direction_start_coordinate, + block.secondary_direction_end_coordinate, ) in discontinuous: idx = discontinuous.index( ( - block.secondary_orientation_start_coordinate, - block.secondary_orientation_end_coordinate, + block.secondary_direction_start_coordinate, + block.secondary_direction_end_coordinate, ) ) if num_list[idx] == 1: cut_coordinates.append( - block.secondary_orientation_start_coordinate + block.secondary_direction_start_coordinate ) - cut_coordinates.append( - block.secondary_orientation_end_coordinate - ) - if not discontinuous: - discontinuous = calculate_discontinuous_projection( - all_boxes, orientation=cut_orientation - ) - current_interval = discontinuous[0] - for interval in discontinuous[1:]: - gap_len = interval[0] - current_interval[1] - if gap_len >= 60: - cut_coordinates.append(current_interval[1]) - elif gap_len > 40: - x1, _, x2, __ = region_bbox - y1 = current_interval[1] - y2 = interval[0] - bbox = [x1, y1, x2, y2] - ref_interval = interval[0] - current_interval[1] - ref_bboxes = [] - for block in blocks: - if get_nearest_edge_distance(bbox, block.bbox) < ref_interval * 2: - ref_bboxes.append(block.bbox) + cut_coordinates.append(block.secondary_direction_end_coordinate) + secondary_discontinuous = calculate_discontinuous_projection( + all_boxes, direction=region.direction + ) + if len(secondary_discontinuous) == 1: + if not discontinuous: discontinuous = calculate_discontinuous_projection( - ref_bboxes, orientation="horizontal" + all_boxes, direction=cut_direction ) - if len(discontinuous) != 2: + current_interval = discontinuous[0] + for interval in discontinuous[1:]: + gap_len = interval[0] - current_interval[1] + if gap_len >= region.text_line_height * 5: cut_coordinates.append(current_interval[1]) - current_interval = interval + elif gap_len > region.text_line_height * 2: + x1, _, x2, __ = region.bbox + y1 = current_interval[1] + y2 = interval[0] + bbox = [x1, y1, x2, y2] + ref_interval = interval[0] - current_interval[1] + ref_bboxes = [] + for block in blocks: + if get_nearest_edge_distance(bbox, block.bbox) < ref_interval * 2: + ref_bboxes.append(block.bbox) + discontinuous = calculate_discontinuous_projection( + ref_bboxes, direction=region.direction + ) + if len(discontinuous) != 2: + cut_coordinates.append(current_interval[1]) + current_interval = interval cut_list = get_cut_blocks( - blocks, cut_orientation, cut_coordinates, region_bbox, mask_labels + blocks, cut_direction, cut_coordinates, region.bbox, mask_labels ) pre_cut_list.extend(cut_list) + if region.direction == "vertical": + pre_cut_list = pre_cut_list[::-1] - return header_blocks, pre_cut_list, footer_blocks, unordered_blocks + return pre_cut_list def update_region_label( - blocks: List[LayoutParsingBlock], config: Dict[str, Any], block_idx: int + block: LayoutParsingBlock, + region: LayoutParsingRegion, ) -> None: """ Update the region label of a block based on its label and match the block with its children. @@ -202,65 +171,45 @@ def update_region_label( Returns: None """ - - # special title block labels - doc_title_labels = config.get("doc_title_labels", []) - paragraph_title_labels = config.get("paragraph_title_labels", []) - vision_labels = config.get("vision_labels", []) - - block = blocks[block_idx] - if block.label in doc_title_labels: + if block.label in BLOCK_LABEL_MAP["header_labels"]: + block.order_label = "header" + elif block.label in BLOCK_LABEL_MAP["doc_title_labels"]: block.order_label = "doc_title" - # Force the orientation of vision type to be horizontal - if block.label in vision_labels: + elif ( + block.label in BLOCK_LABEL_MAP["paragraph_title_labels"] + and block.order_label is None + ): + block.order_label = "paragraph_title" + elif block.label in BLOCK_LABEL_MAP["vision_labels"]: block.order_label = "vision" block.num_of_lines = 1 - block.update_orientation_info() - # some paragraph title block may be labeled as sub_title, so we need to check if block.order_label is "other"(default). - if block.label in paragraph_title_labels and block.order_label == "other": - block.order_label = "paragraph_title" + block.update_direction_info() + elif block.label in BLOCK_LABEL_MAP["footer_labels"]: + block.order_label = "footer" + elif block.label in BLOCK_LABEL_MAP["unordered_labels"]: + block.order_label = "unordered" + else: + block.order_label = "normal_text" # only vision and doc title block can have child block if block.order_label not in ["vision", "doc_title", "paragraph_title"]: return - iou_threshold = config.get("child_block_match_iou_threshold", 0.1) # match doc title text block if block.order_label == "doc_title": - text_block_idxes = config.get("text_block_idxes", []) - prev_idx, post_idx = get_adjacent_blocks_by_orientation( - blocks, block_idx, text_block_idxes, iou_threshold - ) - update_doc_title_child_blocks(blocks, block, prev_idx, post_idx, config) + update_doc_title_child_blocks(block, region) # match sub title block elif block.order_label == "paragraph_title": - iou_threshold = config.get("sub_title_match_iou_threshold", 0.1) - paragraph_title_block_idxes = config.get("paragraph_title_block_idxes", []) - text_block_idxes = config.get("text_block_idxes", []) - megred_block_idxes = text_block_idxes + paragraph_title_block_idxes - prev_idx, post_idx = get_adjacent_blocks_by_orientation( - blocks, block_idx, megred_block_idxes, iou_threshold - ) - update_paragraph_title_child_blocks(blocks, block, prev_idx, post_idx, config) - # match vision title block + update_paragraph_title_child_blocks(block, region) + # match vision title block and vision footnote block elif block.order_label == "vision": - # for matching vision title block - vision_title_block_idxes = config.get("vision_title_block_idxes", []) - # for matching vision footnote block - text_block_idxes = config.get("text_block_idxes", []) - megred_block_idxes = text_block_idxes + vision_title_block_idxes - # Some vision title block may be matched with multiple vision title block, so we need to try multiple times - for i in range(3): - prev_idx, post_idx = get_adjacent_blocks_by_orientation( - blocks, block_idx, megred_block_idxes, iou_threshold - ) - update_vision_child_blocks( - blocks, block, megred_block_idxes, prev_idx, post_idx, config - ) + update_vision_child_blocks(block, region) def get_layout_structure( blocks: List[LayoutParsingBlock], + region_direction: str, + region_secondary_direction: str, ) -> Tuple[List[Dict[str, any]], bool]: """ Determine the layout cross column of blocks. @@ -276,7 +225,7 @@ def get_layout_structure( key=lambda x: (x.bbox[0], x.width), ) - mask_labels = ["doc_title", "cross_text", "cross_reference"] + mask_labels = ["doc_title", "cross_layout", "cross_reference"] for block_idx, block in enumerate(blocks): if block.order_label in mask_labels: continue @@ -288,16 +237,16 @@ def get_layout_structure( bbox_iou = calculate_overlap_ratio(block.bbox, ref_block.bbox) if bbox_iou > 0: if ref_block.order_label == "vision": - ref_block.order_label = "cross_text" + ref_block.order_label = "cross_layout" break if block.order_label == "vision" or block.area < ref_block.area: - block.order_label = "cross_text" + block.order_label = "cross_layout" break match_projection_iou = calculate_projection_overlap_ratio( block.bbox, ref_block.bbox, - "horizontal", + region_direction, ) if match_projection_iou > 0: for second_ref_idx, second_ref_block in enumerate(blocks): @@ -312,57 +261,59 @@ def get_layout_structure( ) if bbox_iou > 0.1: if second_ref_block.order_label == "vision": - second_ref_block.order_label = "cross_text" + second_ref_block.order_label = "cross_layout" break if ( block.order_label == "vision" or block.area < second_ref_block.area ): - block.order_label = "cross_text" + block.order_label = "cross_layout" break second_match_projection_iou = calculate_projection_overlap_ratio( block.bbox, second_ref_block.bbox, - "horizontal", + region_direction, ) ref_match_projection_iou = calculate_projection_overlap_ratio( ref_block.bbox, second_ref_block.bbox, - "horizontal", + region_direction, ) ref_match_projection_iou_ = calculate_projection_overlap_ratio( ref_block.bbox, second_ref_block.bbox, - "vertical", + region_secondary_direction, ) if ( second_match_projection_iou > 0 and ref_match_projection_iou == 0 and ref_match_projection_iou_ > 0 - and "vision" - not in [ref_block.order_label, second_ref_block.order_label] ): - block.order_label = ( - "cross_reference" - if block.label == "reference" - else "cross_text" - ) + if block.order_label == "vision" or ( + ref_block.order_label == "normal_text" + and second_ref_block.order_label == "normal_text" + ): + block.order_label = ( + "cross_reference" + if block.label == "reference" + else "cross_layout" + ) def sort_by_xycut( block_bboxes: List, - orientation: int = 0, + direction: str = "vertical", min_gap: int = 1, ) -> List[int]: """ - Sort bounding boxes using recursive XY cut method based on the specified orientation. + Sort bounding boxes using recursive XY cut method based on the specified direction. Args: block_bboxes (Union[np.ndarray, List[List[int]]]): An array or list of bounding boxes, where each box is represented as [x_min, y_min, x_max, y_max]. - orientation (int): orientation for the initial cut. Use 1 for Y-axis first and 0 for X-axis first. + direction (int): direction for the initial cut. Use 1 for Y-axis first and 0 for X-axis first. Defaults to 0. min_gap (int): Minimum gap width to consider a separation between segments. Defaults to 1. @@ -371,7 +322,7 @@ def sort_by_xycut( """ block_bboxes = np.asarray(block_bboxes).astype(int) res = [] - if orientation == 1: + if direction == "vertical": recursive_yx_cut( block_bboxes, np.arange(len(block_bboxes)).tolist(), @@ -391,8 +342,7 @@ def sort_by_xycut( def match_unsorted_blocks( sorted_blocks: List[LayoutParsingBlock], unsorted_blocks: List[LayoutParsingBlock], - config: Dict, - median_width: int, + region: LayoutParsingRegion, ) -> List[LayoutParsingBlock]: """ Match special blocks with the sorted blocks based on their region labels. @@ -406,7 +356,7 @@ def match_unsorted_blocks( List[LayoutParsingBlock]: The updated sorted blocks after matching special blocks. """ distance_type_map = { - "cross_text": weighted_distance_insert, + "cross_layout": weighted_distance_insert, "paragraph_title": weighted_distance_insert, "doc_title": weighted_distance_insert, "vision_title": weighted_distance_insert, @@ -416,21 +366,24 @@ def match_unsorted_blocks( "other": manhattan_insert, } - unsorted_blocks = sort_blocks(unsorted_blocks, median_width, reverse=False) + unsorted_blocks = sort_normal_blocks( + unsorted_blocks, + region.text_line_height, + region.text_line_width, + region.direction, + ) for idx, block in enumerate(unsorted_blocks): order_label = block.order_label if idx == 0 and order_label == "doc_title": sorted_blocks.insert(0, block) continue - sorted_blocks = distance_type_map[order_label]( - block, sorted_blocks, config, median_width - ) + sorted_blocks = distance_type_map[order_label](block, sorted_blocks, region) return sorted_blocks def xycut_enhanced( - blocks: List[LayoutParsingBlock], config: Dict -) -> List[LayoutParsingBlock]: + region: LayoutParsingRegion, +) -> LayoutParsingRegion: """ xycut_enhance function performs the following steps: 1. Preprocess the input blocks by extracting headers, footers, and pre-cut blocks. @@ -446,42 +399,51 @@ def xycut_enhanced( Returns: List[LayoutParsingBlock]: Ordered result list after processing. """ - if len(blocks) == 0: - return blocks + if len(region.block_map) == 0: + return [] - text_labels = config.get("text_labels", []) - header_blocks, pre_cut_list, footer_blocks, unordered_blocks = pre_process( - blocks, config - ) + pre_cut_list: List[List[LayoutParsingBlock]] = pre_process(region) final_order_res_list: List[LayoutParsingBlock] = [] - header_blocks = sort_blocks(header_blocks) - footer_blocks = sort_blocks(footer_blocks) - unordered_blocks = sort_blocks(unordered_blocks) + header_blocks: List[LayoutParsingBlock] = [ + region.block_map[idx] for idx in region.header_block_idxes + ] + unordered_blocks: List[LayoutParsingBlock] = [ + region.block_map[idx] for idx in region.unordered_block_idxes + ] + footer_blocks: List[LayoutParsingBlock] = [ + region.block_map[idx] for idx in region.footer_block_idxes + ] + + header_blocks: List[LayoutParsingBlock] = sort_normal_blocks( + header_blocks, region.text_line_height, region.text_line_width, region.direction + ) + footer_blocks: List[LayoutParsingBlock] = sort_normal_blocks( + footer_blocks, region.text_line_height, region.text_line_width, region.direction + ) + unordered_blocks: List[LayoutParsingBlock] = sort_normal_blocks( + unordered_blocks, + region.text_line_height, + region.text_line_width, + region.direction, + ) final_order_res_list.extend(header_blocks) unsorted_blocks: List[LayoutParsingBlock] = [] - sorted_blocks_by_pre_cuts = [] + sorted_blocks_by_pre_cuts: List[LayoutParsingBlock] = [] for pre_cut_blocks in pre_cut_list: sorted_blocks: List[LayoutParsingBlock] = [] doc_title_blocks: List[LayoutParsingBlock] = [] xy_cut_blocks: List[LayoutParsingBlock] = [] - pre_cut_blocks: List[LayoutParsingBlock] - median_width = 1 - text_block_width = [ - block.width for block in pre_cut_blocks if block.label in text_labels - ] - if len(text_block_width) > 0: - median_width = int(np.median(text_block_width)) get_layout_structure( - pre_cut_blocks, + pre_cut_blocks, region.direction, region.secondary_direction ) # Get xy cut blocks and add other blocks in special_block_map for block in pre_cut_blocks: if block.order_label not in [ - "cross_text", + "cross_layout", "cross_reference", "doc_title", "unordered", @@ -496,41 +458,77 @@ def xycut_enhanced( block_bboxes = np.array([block.bbox for block in xy_cut_blocks]) block_text_lines = [block.num_of_lines for block in xy_cut_blocks] discontinuous = calculate_discontinuous_projection( - block_bboxes, orientation="horizontal" + block_bboxes, direction=region.direction ) if len(discontinuous) > 1: xy_cut_blocks = [block for block in xy_cut_blocks] + # if len(discontinuous) == 1 or max(block_text_lines) == 1 or (not is_projection_consistent(xy_cut_blocks, discontinuous, direction=region.direction) and len(discontinuous) > 2 and max(block_text_lines) - min(block_text_lines) < 3): if len(discontinuous) == 1 or max(block_text_lines) == 1: - xy_cut_blocks.sort(key=lambda x: (x.bbox[1] // 5, x.bbox[0])) - xy_cut_blocks = shrink_overlapping_boxes(xy_cut_blocks, "vertical") + xy_cut_blocks.sort( + key=lambda x: ( + x.bbox[region.secondary_direction_start_index] + // (region.text_line_height // 2), + x.bbox[region.direction_start_index], + ) + ) + xy_cut_blocks = shrink_overlapping_boxes( + xy_cut_blocks, region.secondary_direction + ) + if ( + len(discontinuous) == 1 + or max(block_text_lines) == 1 + or ( + not is_projection_consistent( + xy_cut_blocks, discontinuous, direction=region.direction + ) + and len(discontinuous) > 2 + and max(block_text_lines) - min(block_text_lines) < 3 + ) + ): + xy_cut_blocks.sort( + key=lambda x: ( + x.bbox[region.secondary_direction_start_index] + // (region.text_line_height // 2), + x.bbox[region.direction_start_index], + ) + ) + xy_cut_blocks = shrink_overlapping_boxes( + xy_cut_blocks, region.secondary_direction + ) block_bboxes = np.array([block.bbox for block in xy_cut_blocks]) - sorted_indexes = sort_by_xycut(block_bboxes, orientation=1, min_gap=1) + sorted_indexes = sort_by_xycut( + block_bboxes, direction=region.secondary_direction, min_gap=1 + ) else: - xy_cut_blocks.sort(key=lambda x: (x.bbox[0] // 20, x.bbox[1])) - xy_cut_blocks = shrink_overlapping_boxes(xy_cut_blocks, "horizontal") + xy_cut_blocks.sort( + key=lambda x: ( + x.bbox[region.direction_start_index] + // (region.text_line_width // 2), + x.bbox[region.secondary_direction_start_index], + ) + ) + xy_cut_blocks = shrink_overlapping_boxes( + xy_cut_blocks, region.direction + ) block_bboxes = np.array([block.bbox for block in xy_cut_blocks]) - sorted_indexes = sort_by_xycut(block_bboxes, orientation=0, min_gap=20) + sorted_indexes = sort_by_xycut( + block_bboxes, direction=region.direction, min_gap=1 + ) sorted_blocks = [xy_cut_blocks[i] for i in sorted_indexes] sorted_blocks = match_unsorted_blocks( sorted_blocks, doc_title_blocks, - config, - median_width, + region=region, ) sorted_blocks_by_pre_cuts.extend(sorted_blocks) - median_width = 1 - text_block_width = [block.width for block in blocks if block.label in text_labels] - if len(text_block_width) > 0: - median_width = int(np.median(text_block_width)) final_order_res_list = match_unsorted_blocks( sorted_blocks_by_pre_cuts, unsorted_blocks, - config, - median_width, + region=region, ) final_order_res_list.extend(footer_blocks)