-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdiscretize_bboxes.py
More file actions
96 lines (79 loc) · 3.26 KB
/
Copy pathdiscretize_bboxes.py
File metadata and controls
96 lines (79 loc) · 3.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import copy
from typing import Any, Union
import numpy as np
from langchain_core.runnables import Runnable
from langchain_core.runnables.config import RunnableConfig
from pydantic import BaseModel
from layout_prompter.models import LayoutData, ProcessedLayoutData
from layout_prompter.utils import decapsulate
class DiscretizeBboxes(BaseModel, Runnable):
    """Runnable that converts normalized bounding boxes to integer canvas coordinates.

    Boxes are handled as ``(x1, y1, x2, y2)``. :meth:`discretize` scales
    coordinates (clipped to ``[0, 1]``) onto a ``width`` x ``height`` integer
    grid; :meth:`continuize` divides integer coordinates back by the canvas
    size. :meth:`invoke` applies the discretization to a layout and returns a
    :class:`ProcessedLayoutData`.
    """

    name: str = "discretize-bboxes"

    def discretize(self, bboxes: np.ndarray, width: int, height: int) -> np.ndarray:
        """Map normalized boxes onto an integer ``width`` x ``height`` grid.

        Coordinates are clipped to ``[0, 1]``, multiplied by the canvas size
        (x by ``width``, y by ``height``) and floored. A coordinate of exactly
        ``1.0`` therefore maps to ``width`` / ``height`` itself.

        Args:
            bboxes: Array with 4 entries, ``(x1, y1, x2, y2)``, along axis 1;
                normally of shape ``(N, 4)``.
            width: Canvas width used to scale the x coordinates.
            height: Canvas height used to scale the y coordinates.

        Returns:
            An ``int32`` array with the discretized coordinates stacked on the
            last axis in ``(x1, y1, x2, y2)`` order.

        Raises:
            AssertionError: If ``bboxes`` does not have 4 entries along axis 1.
        """
        # Check ndim first so a 1-D array fails with the intended message
        # rather than an IndexError from ``shape[1]``.
        assert bboxes.ndim >= 2 and bboxes.shape[1] == 4, (
            "bboxes should be of shape (N, 4)"
        )
        clipped = np.clip(bboxes, a_min=0.0, a_max=1.0)
        # NOTE(review): presumably ``decapsulate`` splits the box array into
        # its four coordinate arrays — confirm against layout_prompter.utils.
        x1, y1, x2, y2 = decapsulate(clipped)
        discrete = np.stack(
            [
                np.floor(x1 * width),
                np.floor(y1 * height),
                np.floor(x2 * width),
                np.floor(y2 * height),
            ],
            axis=-1,
        )
        return discrete.astype(np.int32)

    def continuize(self, bboxes: np.ndarray, width: int, height: int) -> np.ndarray:
        """Scale integer canvas boxes back to normalized float coordinates.

        This is the approximate inverse of :meth:`discretize`: x coordinates
        are divided by ``width`` and y coordinates by ``height``. No clipping
        is applied, and the flooring done by :meth:`discretize` is not undone.

        Args:
            bboxes: Boxes in ``(x1, y1, x2, y2)`` canvas coordinates.
            width: Canvas width used to normalize the x coordinates.
            height: Canvas height used to normalize the y coordinates.

        Returns:
            A ``float32`` array with the normalized coordinates stacked on the
            last axis in ``(x1, y1, x2, y2)`` order.
        """
        x1, y1, x2, y2 = decapsulate(bboxes)
        continuous = np.stack(
            [x1 / width, y1 / height, x2 / width, y2 / height], axis=-1
        )
        return continuous.astype(np.float32)

    def invoke(
        self,
        input: Union[LayoutData, ProcessedLayoutData],
        config: RunnableConfig | None = None,
        **kwargs: Any,
    ) -> ProcessedLayoutData:
        """Discretize a layout's boxes and package the result.

        The layout's ``bboxes``, ``labels`` and (when the layout is content
        aware) ``content_bboxes`` are deep-copied and discretized onto the
        layout's canvas size.

        For a :class:`LayoutData` input, the current boxes/labels are also
        recorded as the gold and original values and the encoded image is
        carried over. For any other input, the previously stored
        ``gold_bboxes`` / ``orig_bboxes`` / ``orig_labels`` are reused and
        ``encoded_image`` is set to ``None``.

        Args:
            input: Layout to process; ``bboxes`` and ``labels`` must be set.
            config: Unused; accepted for ``Runnable`` compatibility.
            **kwargs: Unused; accepted for ``Runnable`` compatibility.

        Returns:
            A :class:`ProcessedLayoutData` holding both the continuous and the
            discretized boxes.

        Raises:
            AssertionError: If ``input.bboxes`` or ``input.labels`` is ``None``.
        """
        assert input.bboxes is not None and input.labels is not None

        # Plain dict (expected to hold ``width`` / ``height``) so it can be
        # splatted into ``discretize`` below.
        canvas_size = input.canvas_size.model_dump()

        # Deep copies keep the caller's arrays independent of the output.
        bboxes = copy.deepcopy(input.bboxes)
        labels = copy.deepcopy(input.labels)
        content_bboxes = (
            copy.deepcopy(input.content_bboxes) if input.is_content_aware() else None
        )

        # NOTE(review): if ProcessedLayoutData subclasses LayoutData, the else
        # branch is unreachable and previously stored gold/orig values are
        # overwritten on re-processing — confirm against the model definitions.
        if isinstance(input, LayoutData):
            # Raw layout: the current boxes/labels become the ground truth.
            encoded_image = input.encoded_image
            gold_bboxes = copy.deepcopy(input.bboxes)
            orig_bboxes = copy.deepcopy(gold_bboxes)
            orig_labels = copy.deepcopy(input.labels)
        else:
            # Already processed: keep the ground truth recorded earlier.
            encoded_image = None
            gold_bboxes = input.gold_bboxes
            orig_bboxes = input.orig_bboxes
            orig_labels = input.orig_labels

        discrete_bboxes = self.discretize(bboxes, **canvas_size)
        discrete_gold_bboxes = self.discretize(gold_bboxes, **canvas_size)
        discrete_content_bboxes = (
            self.discretize(content_bboxes, **canvas_size)
            if content_bboxes is not None
            else None
        )

        return ProcessedLayoutData(
            bboxes=bboxes,
            labels=labels,
            gold_bboxes=gold_bboxes,
            encoded_image=encoded_image,
            content_bboxes=content_bboxes,
            discrete_bboxes=discrete_bboxes,
            discrete_gold_bboxes=discrete_gold_bboxes,
            discrete_content_bboxes=discrete_content_bboxes,
            orig_bboxes=orig_bboxes,
            orig_labels=orig_labels,
            canvas_size=canvas_size,
        )