14
14
from __future__ import annotations
15
15
16
16
import copy
17
+ import math
17
18
import re
18
19
from pathlib import Path
19
20
from typing import List
@@ -73,6 +74,9 @@ def _to_img(self) -> dict[str, np.ndarray]:
73
74
res_img_dict [key ] = value
74
75
res_img_dict ["layout_det_res" ] = self ["layout_det_res" ].img ["res" ]
75
76
77
+ if model_settings ["use_region_detection" ]:
78
+ res_img_dict ["region_det_res" ] = self ["region_det_res" ].img ["res" ]
79
+
76
80
if model_settings ["use_general_ocr" ] or model_settings ["use_table_recognition" ]:
77
81
res_img_dict ["overall_ocr_res" ] = self ["overall_ocr_res" ].img ["ocr_res_img" ]
78
82
@@ -283,22 +287,33 @@ def format_title(title):
283
287
" " ,
284
288
)
285
289
290
+ # def format_centered_text():
291
+ # return (
292
+ # f'<div style="text-align: center;">{block.content}</div>'.replace(
293
+ # "-\n",
294
+ # "",
295
+ # ).replace("\n", " ")
296
+ # + "\n"
297
+ # )
298
+
286
299
def format_centered_text ():
287
- return (
288
- f'<div style="text-align: center;">{ block .content } </div>' .replace (
289
- "-\n " ,
290
- "" ,
291
- ).replace ("\n " , " " )
292
- + "\n "
293
- )
300
+ return block .content
301
+
302
+ # def format_image():
303
+ # img_tags = []
304
+ # image_path = "".join(block.image.keys())
305
+ # img_tags.append(
306
+ # '<div style="text-align: center;"><img src="{}" alt="Image" /></div>'.format(
307
+ # image_path.replace("-\n", "").replace("\n", " "),
308
+ # ),
309
+ # )
310
+ # return "\n".join(img_tags)
294
311
295
312
def format_image ():
296
313
img_tags = []
297
314
image_path = "" .join (block .image .keys ())
298
315
img_tags .append (
299
- '<div style="text-align: center;"><img src="{}" alt="Image" /></div>' .format (
300
- image_path .replace ("-\n " , "" ).replace ("\n " , " " ),
301
- ),
316
+ "" .format (image_path .replace ("-\n " , "" ).replace ("\n " , " " ))
302
317
)
303
318
return "\n " .join (img_tags )
304
319
@@ -332,7 +347,7 @@ def get_seg_flag(block: LayoutParsingBlock, prev_block: LayoutParsingBlock):
332
347
num_of_prev_lines = prev_block .num_of_lines
333
348
pre_block_seg_end_coordinate = prev_block .seg_end_coordinate
334
349
prev_end_space_small = (
335
- context_right_coordinate - pre_block_seg_end_coordinate < 10
350
+ abs ( prev_block_bbox [ 2 ] - pre_block_seg_end_coordinate ) < 10
336
351
)
337
352
prev_lines_more_than_one = num_of_prev_lines > 1
338
353
@@ -347,8 +362,12 @@ def get_seg_flag(block: LayoutParsingBlock, prev_block: LayoutParsingBlock):
347
362
prev_block_bbox [2 ], context_right_coordinate
348
363
)
349
364
prev_end_space_small = (
350
- prev_block_bbox [2 ] - pre_block_seg_end_coordinate < 10
365
+ abs (context_right_coordinate - pre_block_seg_end_coordinate )
366
+ < 10
351
367
)
368
+ edge_distance = 0
369
+ else :
370
+ edge_distance = abs (block_box [0 ] - prev_block_bbox [2 ])
352
371
353
372
current_start_space_small = (
354
373
seg_start_coordinate - context_left_coordinate < 10
@@ -358,6 +377,7 @@ def get_seg_flag(block: LayoutParsingBlock, prev_block: LayoutParsingBlock):
358
377
prev_end_space_small
359
378
and current_start_space_small
360
379
and prev_lines_more_than_one
380
+ and edge_distance < max (prev_block .width , block .width )
361
381
):
362
382
seg_start_flag = False
363
383
else :
@@ -371,14 +391,19 @@ def get_seg_flag(block: LayoutParsingBlock, prev_block: LayoutParsingBlock):
371
391
372
392
handlers = {
373
393
"paragraph_title" : lambda : format_title (block .content ),
394
+ "abstract_title" : lambda : format_title (block .content ),
395
+ "reference_title" : lambda : format_title (block .content ),
396
+ "content_title" : lambda : format_title (block .content ),
374
397
"doc_title" : lambda : f"# { block .content } " .replace (
375
398
"-\n " ,
376
399
"" ,
377
400
).replace ("\n " , " " ),
378
401
"table_title" : lambda : format_centered_text (),
379
402
"figure_title" : lambda : format_centered_text (),
380
403
"chart_title" : lambda : format_centered_text (),
381
- "text" : lambda : block .content .replace ("-\n " , " " ).replace ("\n " , " " ),
404
+ "text" : lambda : block .content .replace ("\n \n " , "\n " ).replace (
405
+ "\n " , "\n \n "
406
+ ),
382
407
"abstract" : lambda : format_first_line (
383
408
["摘要" , "abstract" ], lambda l : f"## { l } \n " , " "
384
409
),
@@ -416,24 +441,7 @@ def get_seg_flag(block: LayoutParsingBlock, prev_block: LayoutParsingBlock):
416
441
if handler :
417
442
prev_block = block
418
443
if label == last_label == "text" and seg_start_flag == False :
419
- last_char_of_markdown = (
420
- markdown_content [- 1 ] if markdown_content else ""
421
- )
422
- first_char_of_handler = handler ()[0 ] if handler () else ""
423
- last_is_chinese_char = (
424
- re .match (r"[\u4e00-\u9fff]" , last_char_of_markdown )
425
- if last_char_of_markdown
426
- else False
427
- )
428
- first_is_chinese_char = (
429
- re .match (r"[\u4e00-\u9fff]" , first_char_of_handler )
430
- if first_char_of_handler
431
- else False
432
- )
433
- if not (last_is_chinese_char or first_is_chinese_char ):
434
- markdown_content += " " + handler ()
435
- else :
436
- markdown_content += handler ()
444
+ markdown_content += handler ()
437
445
else :
438
446
markdown_content += (
439
447
"\n \n " + handler () if markdown_content else handler ()
@@ -467,7 +475,7 @@ class LayoutParsingBlock:
467
475
468
476
def __init__ (self , label , bbox , content = "" ) -> None :
469
477
self .label = label
470
- self .region_label = "other"
478
+ self .order_label = "other"
471
479
self .bbox = [int (item ) for item in bbox ]
472
480
self .content = content
473
481
self .seg_start_coordinate = float ("inf" )
@@ -479,39 +487,39 @@ def __init__(self, label, bbox, content="") -> None:
479
487
self .image = None
480
488
self .index = None
481
489
self .visual_index = None
482
- self .direction = self .get_bbox_direction ()
490
+ self .orientation = self .get_bbox_orientation ()
483
491
self .child_blocks = []
484
- self .update_direction_info ()
492
+ self .update_orientation_info ()
485
493
486
494
def __str__ (self ) -> str :
487
495
return f"{ self .__dict__ } "
488
496
489
497
def __repr__ (self ) -> str :
490
- _str = f"\n \n #################\n label:\t { self .label } \n region_label:\t { self .region_label } \n bbox:\t { self .bbox } \n content:\t { self .content } \n #################"
498
+ _str = f"\n \n #################\n label:\t { self .label } \n region_label:\t { self .order_label } \n bbox:\t { self .bbox } \n content:\t { self .content } \n #################"
491
499
return _str
492
500
493
501
def to_dict (self ) -> dict :
494
502
return self .__dict__
495
503
496
- def update_direction_info (self ) -> None :
497
- if self .region_label == "vision" :
498
- self .direction = "horizontal"
499
- if self .direction == "horizontal" :
500
- self .secondary_direction = "vertical"
504
+ def update_orientation_info (self ) -> None :
505
+ if self .order_label == "vision" :
506
+ self .orientation = "horizontal"
507
+ if self .orientation == "horizontal" :
508
+ self .secondary_orientation = "vertical"
501
509
self .short_side_length = self .height
502
510
self .long_side_length = self .width
503
511
self .start_coordinate = self .bbox [0 ]
504
512
self .end_coordinate = self .bbox [2 ]
505
- self .secondary_direction_start_coordinate = self .bbox [1 ]
506
- self .secondary_direction_end_coordinate = self .bbox [3 ]
513
+ self .secondary_orientation_start_coordinate = self .bbox [1 ]
514
+ self .secondary_orientation_end_coordinate = self .bbox [3 ]
507
515
else :
508
- self .secondary_direction = "horizontal"
516
+ self .secondary_orientation = "horizontal"
509
517
self .short_side_length = self .width
510
518
self .long_side_length = self .height
511
519
self .start_coordinate = self .bbox [1 ]
512
520
self .end_coordinate = self .bbox [3 ]
513
- self .secondary_direction_start_coordinate = self .bbox [0 ]
514
- self .secondary_direction_end_coordinate = self .bbox [2 ]
521
+ self .secondary_orientation_start_coordinate = self .bbox [0 ]
522
+ self .secondary_orientation_end_coordinate = self .bbox [2 ]
515
523
516
524
def append_child_block (self , child_block : LayoutParsingBlock ) -> None :
517
525
if not self .child_blocks :
@@ -525,7 +533,7 @@ def append_child_block(self, child_block: LayoutParsingBlock) -> None:
525
533
max (y2 , y2_child ),
526
534
)
527
535
self .bbox = union_bbox
528
- self .update_direction_info ()
536
+ self .update_orientation_info ()
529
537
child_blocks = [child_block ]
530
538
if child_block .child_blocks :
531
539
child_blocks .extend (child_block .get_child_blocks ())
@@ -542,7 +550,7 @@ def get_centroid(self) -> tuple:
542
550
centroid = ((x1 + x2 ) / 2 , (y1 + y2 ) / 2 )
543
551
return centroid
544
552
545
- def get_bbox_direction (self , orientation_ratio : float = 1.0 ) -> bool :
553
+ def get_bbox_orientation (self , orientation_ratio : float = 1.0 ) -> bool :
546
554
"""
547
555
Determine if a bounding box is horizontal or vertical.
548
556
@@ -558,3 +566,91 @@ def get_bbox_direction(self, orientation_ratio: float = 1.0) -> bool:
558
566
if self .width * orientation_ratio >= self .height
559
567
else "vertical"
560
568
)
569
+
570
+
571
+ class LayoutParsingRegion :
572
+
573
+ def __init__ (
574
+ self , region_bbox , blocks : List [LayoutParsingBlock ] = [], block_label_mapping = {}
575
+ ) -> None :
576
+ self .region_bbox = region_bbox
577
+ self .blocks = blocks
578
+ self .block_map = {}
579
+ self .update_config (block_label_mapping )
580
+ self .orientation = None
581
+ self .calculate_bbox_metrics ()
582
+
583
+ def update_config (self , block_label_mapping ):
584
+ self .block_map = {}
585
+ self .config = copy .deepcopy (block_label_mapping )
586
+ self .config ["region_bbox" ] = self .region_bbox
587
+ horizontal_text_block_num = 0
588
+ for idx , block in enumerate (self .blocks ):
589
+ label = block .label
590
+ if (
591
+ block .order_label not in ["vision" , "vision_title" ]
592
+ and block .orientation == "horizontal"
593
+ ):
594
+ horizontal_text_block_num += 1
595
+ self .block_map [idx ] = block
596
+ self .update_layout_order_config_block_index (label , idx )
597
+ text_block_num = (
598
+ len (self .blocks )
599
+ - len (self .config ["vision_block_idxes" ])
600
+ - len (self .config ["vision_title_block_idxes" ])
601
+ )
602
+ self .orientation = (
603
+ "horizontal"
604
+ if horizontal_text_block_num >= text_block_num * 0.5
605
+ else "vertical"
606
+ )
607
+ self .config ["region_orientation" ] = self .orientation
608
+
609
+ def calculate_bbox_metrics (self ):
610
+ x1 , y1 , x2 , y2 = self .region_bbox
611
+ x_center , y_center = (x1 + x2 ) / 2 , (y1 + y2 ) / 2
612
+ self .euclidean_distance = math .sqrt (((x1 ) ** 2 + (y1 ) ** 2 ))
613
+ self .center_euclidean_distance = math .sqrt (((x_center ) ** 2 + (y_center ) ** 2 ))
614
+ self .angle_rad = math .atan2 (y_center , x_center )
615
+
616
+ def sort (self ):
617
+ from .xycut_enhanced import xycut_enhanced
618
+
619
+ return xycut_enhanced (self .blocks , self .config )
620
+
621
+ def update_layout_order_config_block_index (
622
+ self , block_label : str , block_idx : int
623
+ ) -> None :
624
+ doc_title_labels = self .config ["doc_title_labels" ]
625
+ paragraph_title_labels = self .config ["paragraph_title_labels" ]
626
+ vision_labels = self .config ["vision_labels" ]
627
+ vision_title_labels = self .config ["vision_title_labels" ]
628
+ header_labels = self .config ["header_labels" ]
629
+ unordered_labels = self .config ["unordered_labels" ]
630
+ footer_labels = self .config ["footer_labels" ]
631
+ text_labels = self .config ["text_labels" ]
632
+ self .config .setdefault ("doc_title_block_idxes" , [])
633
+ self .config .setdefault ("paragraph_title_block_idxes" , [])
634
+ self .config .setdefault ("vision_block_idxes" , [])
635
+ self .config .setdefault ("vision_title_block_idxes" , [])
636
+ self .config .setdefault ("unordered_block_idxes" , [])
637
+ self .config .setdefault ("text_block_idxes" , [])
638
+ self .config .setdefault ("header_block_idxes" , [])
639
+ self .config .setdefault ("footer_block_idxes" , [])
640
+
641
+ if block_label in doc_title_labels :
642
+ self .config ["doc_title_block_idxes" ].append (block_idx )
643
+ if block_label in paragraph_title_labels :
644
+ self .config ["paragraph_title_block_idxes" ].append (block_idx )
645
+ if block_label in vision_labels :
646
+ self .config ["vision_block_idxes" ].append (block_idx )
647
+ if block_label in vision_title_labels :
648
+ self .config ["vision_title_block_idxes" ].append (block_idx )
649
+ if block_label in unordered_labels :
650
+ self .config ["unordered_block_idxes" ].append (block_idx )
651
+ if block_label in text_labels :
652
+ self .config ["text_block_idxes" ].append (block_idx )
653
+ if block_label in header_labels :
654
+ self .config ["header_block_idxes" ].append (block_idx )
655
+ if block_label in footer_labels :
656
+ self .config ["footer_block_idxes" ].append (block_idx )
0 commit comments