Skip to content

Commit c80b583

Browse files
maxmnemonic (Maksym Lysak) and vagenas
authored
feat: convert regions into TableData (#430)
* Introduction of regions.py - convert regions into TableData
  Signed-off-by: Maksym Lysak <[email protected]>
* Cleaned up bbox helper functions, reusing more of the existing code
  Signed-off-by: Maksym Lysak <[email protected]>
* Small fixes
  Signed-off-by: Maksym Lysak <[email protected]>
* refactored _bbox_intersection from regions.py into a method of BoundingBox.get_intersection_bbox
  Signed-off-by: Maksym Lysak <[email protected]>
* move region-based construction into `TableData` class
  Signed-off-by: Panos Vagenas <[email protected]>
---------
Signed-off-by: Maksym Lysak <[email protected]>
Signed-off-by: Panos Vagenas <[email protected]>
Co-authored-by: Maksym Lysak <[email protected]>
Co-authored-by: Panos Vagenas <[email protected]>
1 parent 7bd274b commit c80b583

File tree

3 files changed

+390
-2
lines changed

3 files changed

+390
-2
lines changed

docling_core/types/doc/base.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Models for the base data types."""
22

33
from enum import Enum
4-
from typing import Any, List, Tuple
4+
from typing import Any, List, Optional, Tuple
55

66
from pydantic import BaseModel, FieldSerializationInfo, field_serializer
77

@@ -231,6 +231,31 @@ def to_bottom_left_origin(self, page_height: float) -> "BoundingBox":
231231
coord_origin=CoordOrigin.BOTTOMLEFT,
232232
)
233233

234+
def get_intersection_bbox(self, other: "BoundingBox") -> Optional["BoundingBox"]:
    """Compute the region shared by this bounding box and ``other``.

    Both boxes must use the same ``coord_origin``. The result keeps that
    origin. Boxes that only touch along an edge (zero-width or zero-height
    overlap) are treated as disjoint.

    Args:
        other: The bounding box to intersect with.

    Returns:
        The overlapping bounding box, or ``None`` when there is no overlap.

    Raises:
        ValueError: If the two boxes have different ``coord_origin`` values.
    """
    if self.coord_origin != other.coord_origin:
        raise ValueError("BoundingBoxes have different CoordOrigin")

    # Horizontal extent is independent of the coordinate origin.
    left = max(self.l, other.l)
    right = min(self.r, other.r)

    if self.coord_origin == CoordOrigin.TOPLEFT:
        # y grows downwards: the shared top is the larger ``t``.
        top, bottom = max(self.t, other.t), min(self.b, other.b)
        no_vertical_overlap = bottom <= top
    else:
        # y grows upwards: the shared top is the smaller ``t``.
        top, bottom = min(self.t, other.t), max(self.b, other.b)
        no_vertical_overlap = top <= bottom

    if right <= left or no_vertical_overlap:
        return None

    return BoundingBox(
        l=left, t=top, r=right, b=bottom, coord_origin=self.coord_origin
    )
258+
234259
def to_top_left_origin(self, page_height: float) -> "BoundingBox":
235260
"""to_top_left_origin.
236261

docling_core/types/doc/document.py

Lines changed: 285 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,18 @@
1414
from enum import Enum
1515
from io import BytesIO
1616
from pathlib import Path
17-
from typing import Any, Dict, Final, List, Literal, Optional, Sequence, Tuple, Union
17+
from typing import (
18+
Any,
19+
Dict,
20+
Final,
21+
List,
22+
Literal,
23+
Optional,
24+
Sequence,
25+
Set,
26+
Tuple,
27+
Union,
28+
)
1829
from urllib.parse import unquote
1930

2031
import pandas as pd
@@ -679,6 +690,279 @@ def get_column_bounding_boxes(self) -> dict[int, BoundingBox]:
679690

680691
return col_bboxes
681692

693+
@classmethod
def _dedupe_bboxes(
    cls,
    elements: Sequence[BoundingBox],
    *,
    iou_threshold: float = 0.9,
) -> list[BoundingBox]:
    """Drop near-duplicate bounding boxes, keeping the first of each group.

    Boxes are visited in input order. A box is kept only if its
    intersection-over-union with every box kept so far is strictly below
    ``iou_threshold``.

    Args:
        elements: Candidate bounding boxes.
        iou_threshold: IoU at or above which two boxes count as duplicates.

    Returns:
        A new list with the retained boxes, in their original relative order.
    """
    unique: list[BoundingBox] = []
    for candidate in elements:
        for kept in unique:
            if not candidate.intersection_over_union(kept) < iou_threshold:
                # Overlaps an already-kept box too strongly: discard it.
                break
        else:
            unique.append(candidate)
    return unique
709+
710+
@classmethod
def _process_table_headers(
    cls,
    bbox: BoundingBox,
    row_headers: Optional[List[BoundingBox]] = None,
    col_headers: Optional[List[BoundingBox]] = None,
    row_sections: Optional[List[BoundingBox]] = None,
) -> Tuple[bool, bool, bool]:
    """Classify a cell bounding box against header and section regions.

    A cell belongs to a region kind when at least half of the cell's own
    area (``intersection_over_self``) is covered by any single region of
    that kind.

    Args:
        bbox: Bounding box of the cell to classify.
        row_headers: Regions marking row headers; ``None`` means no regions.
        col_headers: Regions marking column headers; ``None`` means no regions.
        row_sections: Regions marking row sections; ``None`` means no regions.

    Returns:
        A ``(column_header, row_header, row_section)`` tuple of flags. Note
        that the column-header flag comes first, unlike the parameter order.
    """
    min_coverage = 0.5  # fraction of the cell that a region must cover

    def covered_by_any(regions: Optional[List[BoundingBox]]) -> bool:
        # ``None`` defaults replace the former shared mutable ``[]`` defaults.
        return any(
            bbox.intersection_over_self(region) >= min_coverage
            for region in (regions or ())
        )

    c_column_header = covered_by_any(col_headers)
    c_row_header = covered_by_any(row_headers)
    c_row_section = covered_by_any(row_sections)
    return c_column_header, c_row_header, c_row_section
732+
733+
@classmethod
def _compute_cells(
    cls,
    rows: List[BoundingBox],
    columns: List[BoundingBox],
    merges: List[BoundingBox],
    row_headers: Optional[List[BoundingBox]] = None,
    col_headers: Optional[List[BoundingBox]] = None,
    row_sections: Optional[List[BoundingBox]] = None,
    row_overlap_threshold: float = 0.5,  # how much of a row a merge must cover vertically
    col_overlap_threshold: float = 0.5,  # how much of a column a merge must cover horizontally
) -> List[TableCell]:
    """Build the table cells of a grid described by row/column/merge regions.

    Rows are ordered by vertical centre and columns by horizontal centre.
    Each merge region is snapped to the span of rows and columns it
    overlaps and emitted as one multi-span cell. Every remaining
    (row, column) pair whose boxes intersect yields a 1x1 cell. All cells
    are created with empty text.

    NOTE(review): rows are sorted by ascending ``(t + b) / 2``, which looks
    like it assumes a top-left coordinate origin - confirm before feeding
    bottom-left boxes.

    NOTE(review): ``row_headers`` and ``col_headers`` are forwarded
    positionally to ``_process_table_headers`` in swapped order (this
    method's ``col_headers`` fills that helper's ``row_headers`` slot and
    vice versa). ``from_regions`` passes them to this method swapped as
    well, so the two swaps cancel out. The order is kept so existing
    callers behave the same.

    Args:
        rows: Row regions. The list is not modified.
        columns: Column regions. The list is not modified.
        merges: Regions of merged (multi-span) cells.
        row_headers: Header regions (see the note above); ``None`` means none.
        col_headers: Header regions (see the note above); ``None`` means none.
        row_sections: Row-section regions; ``None`` means none.
        row_overlap_threshold: Fraction of a row's height that a merge must
            cover for the row to be part of the merge.
        col_overlap_threshold: Fraction of a column's width that a merge
            must cover for the column to be part of the merge.

    Returns:
        Merged cells first (in ``merges`` order), then the 1x1 cells in
        row-major order.
    """
    # Sort copies so the caller's lists keep their original order.
    rows = sorted(rows, key=lambda r: (r.t + r.b) / 2.0)
    columns = sorted(columns, key=lambda c: (c.l + c.r) / 2.0)

    # ``None`` defaults replace the former shared mutable ``[]`` defaults.
    row_headers = [] if row_headers is None else row_headers
    col_headers = [] if col_headers is None else col_headers
    row_sections = [] if row_sections is None else row_sections

    def span_from_merge(
        m: BoundingBox, lines: List[BoundingBox], axis: str, frac_threshold: float
    ) -> Optional[Tuple[int, int]]:
        """Map a merge bbox to an inclusive index span over rows or columns.

        axis='row' uses vertical overlap vs row height; axis='col' uses horizontal overlap vs col width.
        If nothing meets threshold, pick the single best-overlapping line if overlap>0; else return None.
        """
        idxs: List[int] = []
        best_i: Optional[int] = None
        best_len = 0.0
        for i, elem in enumerate(lines):
            inter = m.get_intersection_bbox(elem)
            if inter is None:
                continue
            if axis == "row":
                overlap_len = inter.height
                base = max(1e-9, elem.height)  # guard against zero-height rows
            else:
                overlap_len = inter.width
                base = max(1e-9, elem.width)  # guard against zero-width columns

            if overlap_len / base >= frac_threshold:
                idxs.append(i)

            if overlap_len > best_len:
                best_len, best_i = overlap_len, i

        if idxs:
            return min(idxs), max(idxs)
        if best_i is not None and best_len > 0.0:
            return best_i, best_i
        return None

    def make_cell(
        bbox: BoundingBox, sr: int, er: int, sc: int, ec: int
    ) -> TableCell:
        """Create an empty-text cell covering rows sr..er and columns sc..ec (inclusive)."""
        # Positional order intentionally preserved - see the docstring note.
        c_column_header, c_row_header, c_row_section = cls._process_table_headers(
            bbox, col_headers, row_headers, row_sections
        )
        return TableCell(
            text="",
            row_span=er - sr + 1,
            col_span=ec - sc + 1,
            start_row_offset_idx=sr,
            end_row_offset_idx=er + 1,
            start_col_offset_idx=sc,
            end_col_offset_idx=ec + 1,
            bbox=bbox,
            column_header=c_column_header,
            row_header=c_row_header,
            row_section=c_row_section,
        )

    cells: List[TableCell] = []
    covered: Set[Tuple[int, int]] = set()
    seen_merge_rects: Set[Tuple[int, int, int, int]] = set()

    # 1) Add merged cells first (and mark their covered simple cells)
    for m in merges:
        rspan = span_from_merge(
            m, rows, axis="row", frac_threshold=row_overlap_threshold
        )
        cspan = span_from_merge(
            m, columns, axis="col", frac_threshold=col_overlap_threshold
        )
        if rspan is None or cspan is None:
            # Can't confidently map this merge to grid -> skip it
            continue

        sr, er = rspan
        sc, ec = cspan
        rect_key = (sr, er, sc, ec)
        if rect_key in seen_merge_rects:
            # Another merge already snapped to exactly this grid rectangle.
            continue
        seen_merge_rects.add(rect_key)

        # Grid-aligned bbox for the merged cell. Carry over the origin of
        # the boxes it is built from instead of relying on the default.
        grid_bbox = BoundingBox(
            l=columns[sc].l,
            t=rows[sr].t,
            r=columns[ec].r,
            b=rows[er].b,
            coord_origin=rows[sr].coord_origin,
        )
        cells.append(make_cell(grid_bbox, sr, er, sc, ec))
        covered.update(
            (ri, ci) for ri in range(sr, er + 1) for ci in range(sc, ec + 1)
        )

    # 2) Add simple (1x1) cells where not covered by merges
    for ri, row in enumerate(rows):
        for ci, col in enumerate(columns):
            if (ri, ci) in covered:
                continue
            inter = row.get_intersection_bbox(col)
            if inter is None:
                # In degenerate cases (big gaps), there might be no intersection; skip.
                continue
            cells.append(make_cell(inter, ri, ri, ci, ci))
    return cells
869+
870+
@classmethod
def from_regions(
    cls,
    table_bbox: BoundingBox,
    rows: List[BoundingBox],
    cols: List[BoundingBox],
    merges: List[BoundingBox],
    row_headers: Optional[List[BoundingBox]] = None,
    col_headers: Optional[List[BoundingBox]] = None,
    row_sections: Optional[List[BoundingBox]] = None,
) -> Self:
    """Converts regions: rows, columns, merged cells into table_data structure.

    Adds semantics for regions of row_headers, col_headers, row_section.

    Every region list is first reduced to the regions that lie at least
    half inside ``table_bbox`` and then de-duplicated. Row sections are
    also treated as rows, to compensate for missing row regions. None of
    the input lists is modified.

    Args:
        table_bbox: Bounding box of the whole table.
        rows: Row regions.
        cols: Column regions.
        merges: Regions of merged (multi-span) cells.
        row_headers: Row-header regions; ``None`` means none.
        col_headers: Column-header regions; ``None`` means none.
        row_sections: Row-section regions; ``None`` means none.

    Returns:
        A new instance whose ``table_cells`` hold the computed grid. When no
        rows or no columns remain after filtering, the result is a 1x1
        table with a single cell spanning ``table_bbox``.
    """
    default_containment_thresh = 0.5

    def inside_table(regions: Optional[List[BoundingBox]]) -> List[BoundingBox]:
        # Keep regions mostly contained in the table, then drop near-duplicates.
        return cls._dedupe_bboxes(
            [
                e
                for e in (regions or [])
                if e.intersection_over_self(table_bbox) >= default_containment_thresh
            ]
        )

    # Use row sections to compensate for missing rows. Build a new list
    # instead of ``rows.extend(...)`` so the caller's ``rows`` is untouched.
    rows = inside_table([*rows, *(row_sections or [])])
    cols = inside_table(cols)
    merges = inside_table(merges)

    col_headers = inside_table(col_headers)
    row_headers = inside_table(row_headers)
    row_sections = inside_table(row_sections)

    # If no table structure found, create single fake cell for content
    if not rows or not cols:
        table_data = cls(num_rows=1, num_cols=1)
        table_data.table_cells = [
            TableCell(
                text="",
                row_span=1,
                col_span=1,
                start_row_offset_idx=0,
                end_row_offset_idx=1,
                start_col_offset_idx=0,
                end_col_offset_idx=1,
                bbox=table_bbox,
                column_header=False,
                row_header=False,
                row_section=False,
            )
        ]
        return table_data

    table_data = cls(num_rows=len(rows), num_cols=len(cols))
    # Compute table cells from CVAT elements: rows, cols, merges.
    # NOTE(review): the header lists are passed positionally as
    # (col_headers, row_headers) although ``_compute_cells`` declares
    # (row_headers, col_headers). ``_compute_cells`` swaps them again when
    # it calls ``_process_table_headers``, so the result is correct. The
    # order is kept to stay in sync with that method.
    table_data.table_cells = cls._compute_cells(
        rows,
        cols,
        merges,
        col_headers,
        row_headers,
        row_sections,
    )

    return table_data
965+
682966

683967
class PictureTabularChartData(PictureChartData):
684968
"""Base class for picture chart data.

0 commit comments

Comments
 (0)