Skip to content

Commit 631e6fb

Browse files
authored
Feat/improve chipper bounding boxes (#292)
Chipper bounding boxes largely exceed the area where the annotated text is. This PR intends solving this problem. There are several fixes applied: * Improved cross attention processing. Maps are filtered at the head level, which is normalised [0-1], this improves the bounding box definition. * The correlation between the token index and cross attention map has been resolved. Before, with beam search size = 1 or 3 there were cases in which the cross attention map did not match the token being processed. In addition, the empty areas of the bounding boxes have been cleaned and the overlaps between bounding boxes are resolved by identifying the largest margin that separates both bounding boxes, identified as the largest gap without text either in the horizontal or vertical directions. Note: overlapping bounding boxes for child elements are not resolved for now. In the case of a list with list items, the list element will be affected by the overlapping resolution code. --------- Co-authored-by: Antonio Jimeno Yepes <[email protected]>
1 parent 0f0c2be commit 631e6fb

File tree

4 files changed

+668
-31
lines changed

4 files changed

+668
-31
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
## 0.7.16-dev1
2+
3+
* enhancement: improved Chipper bounding boxes
4+
15
## 0.7.16
26

37
* bug: Allow supplied ONNX models to use label_map dictionary from json file

test_unstructured_inference/models/test_chippermodel.py

Lines changed: 120 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,10 @@ def test_no_repeat_ngram_logits():
197197
@pytest.mark.parametrize(
198198
("decoded_str", "expected_classes"),
199199
[
200-
("<s><s_Misc> 1</s_Misc><s_Text>There is some text here.</s_Text></s>", ["Misc", "Text"]),
200+
(
201+
"<s><s_Misc> 1</s_Misc><s_Text>There is some text here.</s_Text></s>",
202+
["Misc", "Text"],
203+
),
201204
(
202205
"<s><s_List><s_List-item>Text here.</s_List-item><s_List><s_List-item>Final one",
203206
["List", "List-item", "List", "List-item"],
@@ -245,3 +248,119 @@ def test_run_chipper_v2():
245248
tables = [el for el in elements if el.type == "Table"]
246249
assert all(table.text_as_html.startswith("<table>") for table in tables)
247250
assert all("<table>" not in table.text for table in tables)
251+
252+
253+
@pytest.mark.parametrize(
254+
("bbox", "output"),
255+
[
256+
(
257+
[0, 0, 0, 0],
258+
None,
259+
),
260+
(
261+
[0, 1, 1, -1],
262+
None,
263+
),
264+
],
265+
)
266+
def test_largest_margin(bbox, output):
267+
model = get_model("chipper")
268+
img = Image.open("sample-docs/easy_table.jpg")
269+
assert model.largest_margin(img, bbox) is output
270+
271+
272+
@pytest.mark.parametrize(
273+
("bbox", "output"),
274+
[
275+
(
276+
[0, 1, 0, -1],
277+
[0, 1, 0, -1],
278+
),
279+
(
280+
[0, 1, 1, -1],
281+
[0, 1, 1, -1],
282+
),
283+
(
284+
[20, 10, 30, 40],
285+
[20, 10, 30, 40],
286+
),
287+
],
288+
)
289+
def test_reduce_bbox_overlap(bbox, output):
290+
model = get_model("chipper")
291+
img = Image.open("sample-docs/easy_table.jpg")
292+
assert model.reduce_bbox_overlap(img, bbox) == output
293+
294+
295+
@pytest.mark.parametrize(
296+
("bbox", "output"),
297+
[
298+
(
299+
[20, 10, 30, 40],
300+
[20, 10, 30, 40],
301+
),
302+
],
303+
)
304+
def test_reduce_bbox_no_overlap(bbox, output):
305+
model = get_model("chipper")
306+
img = Image.open("sample-docs/easy_table.jpg")
307+
assert model.reduce_bbox_no_overlap(img, bbox) == output
308+
309+
310+
@pytest.mark.parametrize(
311+
("bbox1", "bbox2", "output"),
312+
[
313+
(
314+
[0, 50, 20, 80],
315+
[10, 10, 30, 30],
316+
(
317+
"horizontal",
318+
[10, 10, 30, 30],
319+
[0, 50, 20, 80],
320+
[0, 50, 20, 80],
321+
[10, 10, 30, 30],
322+
None,
323+
),
324+
),
325+
(
326+
[10, 10, 30, 30],
327+
[40, 10, 60, 30],
328+
(
329+
"vertical",
330+
[40, 10, 60, 30],
331+
[10, 10, 30, 30],
332+
[10, 10, 30, 30],
333+
[40, 10, 60, 30],
334+
None,
335+
),
336+
),
337+
(
338+
[10, 80, 30, 100],
339+
[40, 10, 60, 30],
340+
(
341+
"none",
342+
[40, 10, 60, 30],
343+
[10, 80, 30, 100],
344+
[10, 80, 30, 100],
345+
[40, 10, 60, 30],
346+
None,
347+
),
348+
),
349+
(
350+
[40, 10, 60, 30],
351+
[10, 10, 30, 30],
352+
(
353+
"vertical",
354+
[10, 10, 30, 30],
355+
[40, 10, 60, 30],
356+
[10, 10, 30, 30],
357+
[40, 10, 60, 30],
358+
None,
359+
),
360+
),
361+
],
362+
)
363+
def test_check_overlap(bbox1, bbox2, output):
364+
model = get_model("chipper")
365+
366+
assert model.check_overlap(bbox1, bbox2) == output
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.7.16" # pragma: no cover
1+
__version__ = "0.7.16-dev1" # pragma: no cover

0 commit comments

Comments
 (0)