30
30
from ..base import BasePipeline
31
31
from ..ocr .result import OCRResult
32
32
from .result_v2 import LayoutParsingBlock , LayoutParsingRegion , LayoutParsingResultV2
33
+ from .setting import BLOCK_LABEL_MAP , BLOCK_SETTINGS , LINE_SETTINGS , REGION_SETTINGS
33
34
from .utils import (
34
35
caculate_bbox_area ,
35
36
calculate_minimum_enclosing_bbox ,
@@ -260,8 +261,6 @@ def check_model_settings_valid(self, input_params: dict) -> bool:
260
261
def standardized_data (
261
262
self ,
262
263
image : list ,
263
- parameters_config : dict ,
264
- block_label_mapping : dict ,
265
264
region_det_res : DetResult ,
266
265
layout_det_res : DetResult ,
267
266
overall_ocr_res : OCRResult ,
@@ -360,7 +359,7 @@ def standardized_data(
360
359
paragraph_title_block_area = caculate_bbox_area (
361
360
layout_det_res ["boxes" ][paragraph_title_list [0 ]]["coordinate" ]
362
361
)
363
- title_area_max_block_threshold = parameters_config [ "block" ] .get (
362
+ title_area_max_block_threshold = BLOCK_SETTINGS .get (
364
363
"title_conversion_area_ratio_threshold" , 0.3
365
364
)
366
365
if (
@@ -441,7 +440,7 @@ def standardized_data(
441
440
break
442
441
if not has_text and layout_det_res ["boxes" ][layout_box_idx ][
443
442
"label"
444
- ] not in block_label_mapping .get ("vision_labels" , []):
443
+ ] not in BLOCK_LABEL_MAP .get ("vision_labels" , []):
445
444
crop_box = layout_det_res ["boxes" ][layout_box_idx ]["coordinate" ]
446
445
x1 , y1 , x2 , y2 = [int (i ) for i in crop_box ]
447
446
crop_img = np .array (image )[y1 :y2 , x1 :x2 ]
@@ -506,7 +505,7 @@ def standardized_data(
506
505
overlap_ratio = calculate_overlap_ratio (
507
506
region_bbox , block_bboxes [block_idx ], mode = "small"
508
507
)
509
- if overlap_ratio > parameters_config [ "region" ] .get (
508
+ if overlap_ratio > REGION_SETTINGS .get (
510
509
"match_block_overlap_ratio_threshold" , 0.8
511
510
):
512
511
region_to_block_map [region_idx ].append (block_idx )
@@ -540,7 +539,6 @@ def standardized_data(
540
539
image .shape [0 ],
541
540
block_idxes_set ,
542
541
block_bboxes ,
543
- parameters_config ,
544
542
)
545
543
)
546
544
if len (matched_idxes ) == 0 :
@@ -570,7 +568,7 @@ def sort_line_by_projection(
570
568
input_img : np .ndarray ,
571
569
text_rec_model : Any ,
572
570
text_rec_score_thresh : Union [float , None ] = None ,
573
- orientation : str = "vertical" ,
571
+ direction : str = "vertical" ,
574
572
) -> None :
575
573
"""
576
574
Sort a line of text spans based on their vertical position within the layout bounding box.
@@ -583,8 +581,8 @@ def sort_line_by_projection(
583
581
Returns:
584
582
list: The sorted line of text spans.
585
583
"""
586
- sort_index = 0 if orientation == "horizontal" else 1
587
- splited_boxes = split_boxes_by_projection (line , orientation )
584
+ sort_index = 0 if direction == "horizontal" else 1
585
+ splited_boxes = split_boxes_by_projection (line , direction )
588
586
splited_lines = []
589
587
if len (line ) != len (splited_boxes ):
590
588
splited_boxes .sort (key = lambda span : span [0 ][sort_index ])
@@ -614,7 +612,6 @@ def sort_line_by_projection(
614
612
def get_block_rec_content (
615
613
self ,
616
614
image : list ,
617
- line_parameters_config : dict ,
618
615
ocr_rec_res : dict ,
619
616
block : LayoutParsingBlock ,
620
617
text_rec_model : Any ,
@@ -625,37 +622,49 @@ def get_block_rec_content(
625
622
block .content = ""
626
623
return block
627
624
628
- lines , text_orientation = group_boxes_into_lines (
625
+ lines , text_direction = group_boxes_into_lines (
629
626
ocr_rec_res ,
630
- line_parameters_config .get ("line_height_iou_threshold" , 0.8 ),
627
+ LINE_SETTINGS .get ("line_height_iou_threshold" , 0.8 ),
631
628
)
632
629
633
630
if block .label == "reference" :
634
631
rec_boxes = ocr_rec_res ["boxes" ]
635
632
block_right_coordinate = max ([box [2 ] for box in rec_boxes ])
636
- last_line_span_limit = 20
637
633
else :
638
634
block_right_coordinate = block .bbox [2 ]
639
- last_line_span_limit = 10
640
635
641
636
# format line
642
637
text_lines = []
643
638
need_new_line_num = 0
644
- sort_index = 0 if text_orientation == "horizontal" else 1
639
+ start_index = 0 if text_direction == "horizontal" else 1
640
+ secondary_direction_start_index = 1 if text_direction == "horizontal" else 0
641
+ line_height_list , line_width_list = [], []
645
642
for idx , line in enumerate (lines ):
646
- line .sort (key = lambda span : span [0 ][sort_index ])
643
+ line .sort (key = lambda span : span [0 ][start_index ])
647
644
645
+ text_bboxes_height = [
646
+ span [0 ][secondary_direction_start_index + 2 ]
647
+ - span [0 ][secondary_direction_start_index ]
648
+ for span in line
649
+ ]
650
+ text_bboxes_width = [
651
+ span [0 ][start_index + 2 ] - span [0 ][start_index ] for span in line
652
+ ]
653
+
654
+ line_height = np .mean (text_bboxes_height )
655
+ line_height_list .append (line_height )
656
+ line_width_list .append (np .mean (text_bboxes_width ))
648
657
# merge formula and text
649
658
ocr_labels = [span [2 ] for span in line ]
650
659
if "formula" in ocr_labels :
651
660
line = self .sort_line_by_projection (
652
- line , image , text_rec_model , text_rec_score_thresh , text_orientation
661
+ line , image , text_rec_model , text_rec_score_thresh , text_direction
653
662
)
654
663
655
664
line_text , need_new_line = format_line (
656
665
line ,
657
666
block_right_coordinate ,
658
- last_line_span_limit = last_line_span_limit ,
667
+ last_line_span_limit = line_height * 1.5 ,
659
668
block_label = block .label ,
660
669
)
661
670
if need_new_line :
@@ -668,21 +677,21 @@ def get_block_rec_content(
668
677
block .seg_end_coordinate = line_end_coordinate
669
678
text_lines .append (line_text )
670
679
671
- delim = line_parameters_config ["delimiter_map" ].get (block .label , "" )
680
+ delim = LINE_SETTINGS ["delimiter_map" ].get (block .label , "" )
672
681
if need_new_line_num > len (text_lines ) * 0.5 and delim == "" :
673
682
delim = "\n "
674
683
content = delim .join (text_lines )
675
684
block .content = content
676
685
block .num_of_lines = len (text_lines )
677
- block .orientation = text_orientation
686
+ block .direction = text_direction
687
+ block .text_line_height = np .mean (line_height_list )
688
+ block .text_line_width = np .mean (line_width_list )
678
689
679
690
return block
680
691
681
692
def get_layout_parsing_blocks (
682
693
self ,
683
694
image : list ,
684
- parameters_config : dict ,
685
- block_label_mapping : dict ,
686
695
region_block_ocr_idx_map : dict ,
687
696
region_det_res : DetResult ,
688
697
overall_ocr_res : OCRResult ,
@@ -759,7 +768,6 @@ def get_layout_parsing_blocks(
759
768
block = self .get_block_rec_content (
760
769
image = image ,
761
770
block = block ,
762
- line_parameters_config = parameters_config ["line" ],
763
771
ocr_rec_res = rec_res ,
764
772
text_rec_model = text_rec_model ,
765
773
text_rec_score_thresh = text_rec_score_thresh ,
@@ -781,9 +789,8 @@ def get_layout_parsing_blocks(
781
789
for idx in region_block_ocr_idx_map ["region_to_block_map" ][region_idx ]
782
790
]
783
791
region = LayoutParsingRegion (
784
- region_bbox = region_bbox ,
792
+ bbox = region_bbox ,
785
793
blocks = region_blocks ,
786
- block_label_mapping = block_label_mapping ,
787
794
)
788
795
region_list .append (region )
789
796
@@ -818,14 +825,11 @@ def get_layout_parsing_res(
818
825
Returns:
819
826
list: A list of dictionaries representing the layout parsing result.
820
827
"""
821
- from .setting import block_label_mapping , parameters_config
822
828
823
829
# Standardize data
824
830
region_block_ocr_idx_map , region_det_res , layout_det_res = (
825
831
self .standardized_data (
826
832
image = image ,
827
- parameters_config = parameters_config ,
828
- block_label_mapping = block_label_mapping ,
829
833
region_det_res = region_det_res ,
830
834
layout_det_res = layout_det_res ,
831
835
overall_ocr_res = overall_ocr_res ,
@@ -838,8 +842,6 @@ def get_layout_parsing_res(
838
842
# Format layout parsing block
839
843
region_list = self .get_layout_parsing_blocks (
840
844
image = image ,
841
- parameters_config = parameters_config ,
842
- block_label_mapping = block_label_mapping ,
843
845
region_block_ocr_idx_map = region_block_ocr_idx_map ,
844
846
region_det_res = region_det_res ,
845
847
overall_ocr_res = overall_ocr_res ,
@@ -854,11 +856,10 @@ def get_layout_parsing_res(
854
856
for region in region_list :
855
857
parsing_res_list .extend (region .sort ())
856
858
857
- visualize_index_labels = block_label_mapping ["visualize_index_labels" ]
858
859
index = 1
859
860
for block in parsing_res_list :
860
- if block .label in visualize_index_labels :
861
- block .index = index
861
+ if block .label in BLOCK_LABEL_MAP [ " visualize_index_labels" ] :
862
+ block .order_index = index
862
863
index += 1
863
864
864
865
return parsing_res_list
@@ -956,8 +957,6 @@ def predict(
956
957
use_e2e_wired_table_rec_model : bool = False ,
957
958
use_e2e_wireless_table_rec_model : bool = True ,
958
959
is_pretty_markdown : Union [bool , None ] = None ,
959
- use_layout_gt : bool = False ,
960
- layout_gt_dir : Union [str , None ] = None ,
961
960
** kwargs ,
962
961
) -> LayoutParsingResultV2 :
963
962
"""
@@ -1032,65 +1031,16 @@ def predict(
1032
1031
1033
1032
doc_preprocessor_image = doc_preprocessor_res ["output_img" ]
1034
1033
1035
- use_layout_gt = use_layout_gt
1036
- if not use_layout_gt :
1037
- layout_det_res = next (
1038
- self .layout_det_model (
1039
- doc_preprocessor_image ,
1040
- threshold = layout_threshold ,
1041
- layout_nms = layout_nms ,
1042
- layout_unclip_ratio = layout_unclip_ratio ,
1043
- layout_merge_bboxes_mode = layout_merge_bboxes_mode ,
1044
- )
1034
+ layout_det_res = next (
1035
+ self .layout_det_model (
1036
+ doc_preprocessor_image ,
1037
+ threshold = layout_threshold ,
1038
+ layout_nms = layout_nms ,
1039
+ layout_unclip_ratio = layout_unclip_ratio ,
1040
+ layout_merge_bboxes_mode = layout_merge_bboxes_mode ,
1045
1041
)
1046
- else :
1047
- import json
1048
- import os
1049
-
1050
- from ...models .object_detection .result import DetResult
1051
-
1052
- label_dir = layout_gt_dir
1053
- notes_path = f"{ label_dir } /notes.json"
1054
- labels = f"{ label_dir } /labels"
1055
- gt_file = os .path .basename (input )[:- 4 ] + ".txt"
1056
- gt_path = f"{ labels } /{ gt_file } "
1057
- with open (notes_path , "r" ) as f :
1058
- notes = json .load (f )
1059
- categories_map = {}
1060
- for categories in notes ["categories" ]:
1061
- id = int (categories ["id" ])
1062
- name = categories ["name" ]
1063
- categories_map [id ] = name
1064
- with open (gt_path , "r" ) as f :
1065
- lines = f .readlines ()
1066
- layout_det_res_dic = {
1067
- "input_img" : doc_preprocessor_image ,
1068
- "page_index" : None ,
1069
- "boxes" : [],
1070
- }
1071
- for line in lines :
1072
- line = line .strip ().split (" " )
1073
- category_id = int (line [0 ])
1074
- label = categories_map [category_id ]
1075
- img_h , img_w = doc_preprocessor_image .shape [:2 ]
1076
- center_x = float (line [1 ]) * img_w
1077
- center_y = float (line [2 ]) * img_h
1078
- w = float (line [3 ]) * img_w
1079
- h = float (line [4 ]) * img_h
1080
- x0 = center_x - w / 2
1081
- y0 = center_y - h / 2
1082
- x1 = center_x + w / 2
1083
- y1 = center_y + h / 2
1084
- box = [x0 , y0 , x1 , y1 ]
1085
- layout_det_res_dic ["boxes" ].append (
1086
- {
1087
- "cls_id" : category_id ,
1088
- "label" : label ,
1089
- "coordinate" : box ,
1090
- "score" : 1.0 ,
1091
- }
1092
- )
1093
- layout_det_res = DetResult (layout_det_res_dic )
1042
+ )
1043
+
1094
1044
imgs_in_doc = gather_imgs (doc_preprocessor_image , layout_det_res ["boxes" ])
1095
1045
1096
1046
if model_settings ["use_region_detection" ]:
@@ -1139,7 +1089,13 @@ def predict(
1139
1089
),
1140
1090
)
1141
1091
else :
1142
- overall_ocr_res = {}
1092
+ overall_ocr_res = {
1093
+ "dt_polys" : [],
1094
+ "rec_texts" : [],
1095
+ "rec_scores" : [],
1096
+ "rec_polys" : [],
1097
+ "rec_boxes" : np .array ([]),
1098
+ }
1143
1099
1144
1100
overall_ocr_res ["rec_labels" ] = ["text" ] * len (overall_ocr_res ["rec_texts" ])
1145
1101
@@ -1157,9 +1113,14 @@ def predict(
1157
1113
table_contents ["rec_texts" ].append (
1158
1114
f"${ formula_res ['rec_formula' ]} $"
1159
1115
)
1160
- table_contents ["rec_boxes" ] = np .vstack (
1161
- (table_contents ["rec_boxes" ], [formula_res ["dt_polys" ]])
1162
- )
1116
+ if table_contents ["rec_boxes" ].size == 0 :
1117
+ table_contents ["rec_boxes" ] = np .array (
1118
+ [formula_res ["dt_polys" ]]
1119
+ )
1120
+ else :
1121
+ table_contents ["rec_boxes" ] = np .vstack (
1122
+ (table_contents ["rec_boxes" ], [formula_res ["dt_polys" ]])
1123
+ )
1163
1124
table_contents ["rec_polys" ].append (poly_points )
1164
1125
table_contents ["rec_scores" ].append (1 )
1165
1126
0 commit comments