Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 36 additions & 10 deletions docling/models/layout_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
import logging
import warnings
from collections.abc import Iterable
from copy import deepcopy
from pathlib import Path
from typing import List, Optional, Union
from typing import List, Optional, Union, cast

import numpy as np
from docling_core.types.doc import DocItemLabel
Expand All @@ -19,6 +20,7 @@
from docling.models.utils.hf_model_download import download_hf_model
from docling.utils.accelerator_utils import decide_device
from docling.utils.layout_postprocessor import LayoutPostprocessor
from docling.utils.orientation import detect_orientation, rotate_bounding_box
from docling.utils.profiling import TimeRecorder
from docling.utils.visualization import draw_clusters

Expand Down Expand Up @@ -102,17 +104,22 @@ def download_models(
)

def draw_clusters_and_cells_side_by_side(
self, conv_res, page, clusters, mode_prefix: str, show: bool = False
self,
conv_res,
page,
clusters,
mode_prefix: str,
show: bool = False,
):
"""
Draws a page image side by side with clusters filtered into two categories:
- Left: Clusters excluding FORM, KEY_VALUE_REGION, and PICTURE.
- Right: Clusters including FORM, KEY_VALUE_REGION, and PICTURE.
Includes label names and confidence scores for each cluster.
"""
scale_x = page.image.width / page.size.width
scale_y = page.image.height / page.size.height

page_image = deepcopy(page.image)
scale_x = page_image.width / page.size.width
scale_y = page_image.height / page.size.height
# Filter clusters for left and right images
exclude_labels = {
DocItemLabel.FORM,
Expand Down Expand Up @@ -152,8 +159,8 @@ def __call__(
pages = list(page_batch)

# Separate valid and invalid pages
valid_pages = []
valid_page_images: List[Union[Image.Image, np.ndarray]] = []
valid_page_orientations: List[int] = []

for page in pages:
assert page._backend is not None
Expand All @@ -164,8 +171,12 @@ def __call__(
page_image = page.get_image(scale=1.0)
assert page_image is not None

valid_pages.append(page)
page_orientation = detect_orientation(page.cells)
if page_orientation:
page_image = page_image.rotate(-page_orientation, expand=True)

valid_page_images.append(page_image)
valid_page_orientations.append(page_orientation)

# Process all valid pages with batch prediction
batch_predictions = []
Expand All @@ -184,25 +195,37 @@ def __call__(
continue

page_predictions = batch_predictions[valid_page_idx]
page_image = valid_page_images[valid_page_idx] # type: ignore[assignment]
page_orientation = valid_page_orientations[valid_page_idx]
valid_page_idx += 1

clusters = []
for ix, pred_item in enumerate(page_predictions):
label = DocItemLabel(
pred_item["label"].lower().replace(" ", "_").replace("-", "_")
) # Temporary, until docling-ibm-model uses docling-core types
bbox = BoundingBox.model_validate(pred_item)
if page_orientation:
bbox = rotate_bounding_box(
bbox,
page_orientation,
page_image.size, # type: ignore[union-attr]
).to_bounding_box()
cluster = Cluster(
id=ix,
label=label,
confidence=pred_item["confidence"],
bbox=BoundingBox.model_validate(pred_item),
bbox=bbox,
cells=[],
)
clusters.append(cluster)

if settings.debug.visualize_raw_layout:
self.draw_clusters_and_cells_side_by_side(
conv_res, page, clusters, mode_prefix="raw"
conv_res,
page,
clusters,
mode_prefix="raw",
)

# Apply postprocessing
Expand Down Expand Up @@ -231,7 +254,10 @@ def __call__(

if settings.debug.visualize_layout:
self.draw_clusters_and_cells_side_by_side(
conv_res, page, processed_clusters, mode_prefix="postprocessed"
conv_res,
page,
processed_clusters,
mode_prefix="postprocessed",
)

yield page
8 changes: 4 additions & 4 deletions docling/models/ocr_mac_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,10 @@ def __call__(
x2 = x1 + w * im_width
y1 = y2 - h * im_height

left = x1 / self.scale
top = y1 / self.scale
right = x2 / self.scale
bottom = y2 / self.scale
left = x1 / self.scale + ocr_rect.l
top = y1 / self.scale + ocr_rect.t
right = x2 / self.scale + ocr_rect.l
bottom = y2 / self.scale + ocr_rect.t

cells.append(
TextCell(
Expand Down
Loading
Loading