Skip to content

Commit 4e7464f

Browse files
authored
refactor: partition in elements (#84)
Added the capability to partition identified granular elements (words, characters) into groups by proximity, using the word/character height as the reference distance. In many cases this does a good job of grouping text blocks. Also moved the logic for extracting text from a region into the region itself, so that in the future different logic can be used for embedded text and for images.
1 parent a9f1255 commit 4e7464f

File tree

9 files changed

+491
-213
lines changed

9 files changed

+491
-213
lines changed

CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1-
## 0.3.3-dev2
1+
## 0.4.0
22

3+
* Added logic to partition granular elements (words, characters) by proximity
4+
* Text extraction is now delegated to text regions rather than being handled centrally
35
* Fixed embedded image coordinates being interpreted differently than embedded text coordinates
46
* Update to how dependencies are being handled
57
* Update detectron2 version

test_unstructured_inference/inference/test_layout.py

Lines changed: 78 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from PIL import Image
99

1010
import unstructured_inference.inference.layout as layout
11+
import unstructured_inference.inference.elements as elements
1112
import unstructured_inference.models.base as models
1213
import unstructured_inference.models.detectron2 as detectron2
1314
import unstructured_inference.models.tesseract as tesseract
@@ -20,9 +21,9 @@ def mock_image():
2021

2122
@pytest.fixture
2223
def mock_page_layout():
23-
text_block = layout.TextRegion(2, 4, 6, 8, text="A very repetitive narrative. " * 10)
24+
text_block = layout.EmbeddedTextRegion(2, 4, 6, 8, text="A very repetitive narrative. " * 10)
2425

25-
title_block = layout.TextRegion(1, 2, 3, 4, text="A Catchy Title")
26+
title_block = layout.EmbeddedTextRegion(1, 2, 3, 4, text="A Catchy Title")
2627

2728
return [text_block, title_block]
2829

@@ -49,7 +50,7 @@ def detect(self, *args):
4950
image = Image.fromarray(np.random.randint(12, 24, (40, 40)), mode="RGB")
5051
text_block = layout.TextRegion(1, 2, 3, 4, text=None)
5152

52-
assert layout.ocr(text_block, image=image) == mock_text
53+
assert elements.ocr(text_block, image=image) == mock_text
5354

5455

5556
class MockLayoutModel:
@@ -69,12 +70,12 @@ def test_get_page_elements(monkeypatch, mock_page_layout):
6970
number=0, image=image, layout=mock_page_layout, model=MockLayoutModel(mock_page_layout)
7071
)
7172

72-
elements = page.get_elements(inplace=False)
73+
elements = page.get_elements_with_model(inplace=False)
7374

7475
assert str(elements[0]) == "A Catchy Title"
7576
assert str(elements[1]).startswith("A very repetitive narrative.")
7677

77-
page.get_elements(inplace=True)
78+
page.get_elements_with_model(inplace=True)
7879
assert elements == page.elements
7980

8081

@@ -95,13 +96,13 @@ def test_get_page_elements_with_ocr(monkeypatch):
9596
doc_layout = [text_block, image_block]
9697

9798
monkeypatch.setattr(detectron2, "is_detectron2_available", lambda *args: True)
98-
monkeypatch.setattr(layout, "ocr", lambda *args: "An Even Catchier Title")
99+
monkeypatch.setattr(elements, "ocr", lambda *args: "An Even Catchier Title")
99100

100101
image = Image.fromarray(np.random.randint(12, 14, size=(40, 10, 3)), mode="RGB")
101102
page = layout.PageLayout(
102103
number=0, image=image, layout=doc_layout, model=MockLayoutModel(doc_layout)
103104
)
104-
page.get_elements()
105+
page.get_elements_with_model()
105106

106107
assert str(page) == "\n\nAn Even Catchier Title"
107108

@@ -174,7 +175,7 @@ def tolist(self):
174175
return [1, 2, 3, 4]
175176

176177

177-
class MockTextRegion(layout.TextRegion):
178+
class MockEmbeddedTextRegion(layout.EmbeddedTextRegion):
178179
def __init__(self, type=None, text=None, ocr_text=None):
179180
self.type = type
180181
self.text = text
@@ -193,7 +194,7 @@ def __init__(self, layout=None, model=None, ocr_strategy="auto", extract_tables=
193194
self.ocr_strategy = ocr_strategy
194195
self.extract_tables = extract_tables
195196

196-
def ocr(self, text_block: MockTextRegion):
197+
def ocr(self, text_block: MockEmbeddedTextRegion):
197198
return text_block.ocr_text
198199

199200

@@ -202,15 +203,15 @@ def ocr(self, text_block: MockTextRegion):
202203
[("base", 0.0), ("", 0.0), ("(cid:2)", 1.0), ("(cid:1)a", 0.5), ("c(cid:1)ab", 0.25)],
203204
)
204205
def test_cid_ratio(text, expected):
205-
assert layout.cid_ratio(text) == expected
206+
assert elements.cid_ratio(text) == expected
206207

207208

208209
@pytest.mark.parametrize(
209210
"text, expected",
210211
[("base", False), ("(cid:2)", True), ("(cid:1234567890)", True), ("jkl;(cid:12)asdf", True)],
211212
)
212213
def test_is_cid_present(text, expected):
213-
assert layout.is_cid_present(text) == expected
214+
assert elements.is_cid_present(text) == expected
214215

215216

216217
class MockLayout:
@@ -241,7 +242,7 @@ def filter_by(self, *args, **kwargs):
241242
],
242243
)
243244
def test_get_element_from_block(block_text, layout_texts, mock_image, expected_text):
244-
with patch("unstructured_inference.inference.layout.ocr", return_value="ocr"):
245+
with patch("unstructured_inference.inference.elements.ocr", return_value="ocr"):
245246
block = layout.TextRegion(0, 0, 10, 10, text=block_text)
246247
captured_layout = [
247248
layout.TextRegion(i + 1, i + 1, i + 2, i + 2, text=text)
@@ -263,7 +264,7 @@ def test_from_image_file(monkeypatch, mock_page_layout, filetype):
263264
def mock_get_elements(self, *args, **kwargs):
264265
self.elements = [mock_page_layout]
265266

266-
monkeypatch.setattr(layout.PageLayout, "get_elements", mock_get_elements)
267+
monkeypatch.setattr(layout.PageLayout, "get_elements_with_model", mock_get_elements)
267268
elements = (
268269
layout.DocumentLayout.from_image_file(f"sample-docs/loremipsum.{filetype}")
269270
.pages[0]
@@ -301,12 +302,12 @@ def test_get_elements_from_layout(mock_page_layout, idx):
301302
@pytest.mark.parametrize(
302303
"fixed_layouts, called_method, not_called_method",
303304
[
304-
([MockLayout()], "get_elements_from_layout", "get_elements"),
305-
(None, "get_elements", "get_elements_from_layout"),
305+
([MockLayout()], "get_elements_from_layout", "get_elements_with_model"),
306+
(None, "get_elements_with_model", "get_elements_from_layout"),
306307
],
307308
)
308309
def test_from_file_fixed_layout(fixed_layouts, called_method, not_called_method):
309-
with patch.object(layout.PageLayout, "get_elements", return_value=[]), patch.object(
310+
with patch.object(layout.PageLayout, "get_elements_with_model", return_value=[]), patch.object(
310311
layout.PageLayout, "get_elements_from_layout", return_value=[]
311312
):
312313
layout.DocumentLayout.from_file("sample-docs/loremipsum.pdf", fixed_layouts=fixed_layouts)
@@ -323,70 +324,105 @@ def test_invalid_ocr_strategy_raises(mock_image):
323324
("text", "expected"), [("a\ts\x0cd\nfas\fd\rf\b", "asdfasdf"), ("\"'\\", "\"'\\")]
324325
)
325326
def test_remove_control_characters(text, expected):
326-
assert layout.remove_control_characters(text) == expected
327+
assert elements.remove_control_characters(text) == expected
327328

328329

329-
no_text_region = layout.TextRegion(0, 0, 100, 100)
330-
text_region = layout.TextRegion(0, 0, 100, 100, text="test")
331-
cid_text_region = layout.TextRegion(0, 0, 100, 100, text="(cid:1)(cid:2)(cid:3)(cid:4)(cid:5)")
330+
no_text_region = layout.EmbeddedTextRegion(0, 0, 100, 100)
331+
text_region = layout.EmbeddedTextRegion(0, 0, 100, 100, text="test")
332+
cid_text_region = layout.EmbeddedTextRegion(
333+
0, 0, 100, 100, text="(cid:1)(cid:2)(cid:3)(cid:4)(cid:5)"
334+
)
332335
overlapping_rect = layout.ImageTextRegion(50, 50, 150, 150)
333336
nonoverlapping_rect = layout.ImageTextRegion(150, 150, 200, 200)
334-
populated_text_region = layout.TextRegion(50, 50, 60, 60, text="test")
335-
unpopulated_text_region = layout.TextRegion(50, 50, 60, 60, text=None)
337+
populated_text_region = layout.EmbeddedTextRegion(50, 50, 60, 60, text="test")
338+
unpopulated_text_region = layout.EmbeddedTextRegion(50, 50, 60, 60, text=None)
336339

337340

338341
@pytest.mark.parametrize(
339-
("region", "text_objects", "image_objects", "ocr_strategy", "expected"),
342+
("region", "objects", "ocr_strategy", "expected"),
340343
[
341-
(no_text_region, [], [nonoverlapping_rect], "auto", False),
342-
(no_text_region, [], [overlapping_rect], "auto", True),
343-
(no_text_region, [], [], "auto", False),
344-
(no_text_region, [populated_text_region], [nonoverlapping_rect], "auto", False),
345-
(no_text_region, [populated_text_region], [overlapping_rect], "auto", False),
346-
(no_text_region, [populated_text_region], [], "auto", False),
347-
(no_text_region, [unpopulated_text_region], [nonoverlapping_rect], "auto", False),
348-
(no_text_region, [unpopulated_text_region], [overlapping_rect], "auto", True),
349-
(no_text_region, [unpopulated_text_region], [], "auto", False),
344+
(no_text_region, [nonoverlapping_rect], "auto", False),
345+
(no_text_region, [overlapping_rect], "auto", True),
346+
(no_text_region, [], "auto", False),
347+
(no_text_region, [populated_text_region, nonoverlapping_rect], "auto", False),
348+
(no_text_region, [populated_text_region, overlapping_rect], "auto", False),
349+
(no_text_region, [populated_text_region], "auto", False),
350+
(no_text_region, [unpopulated_text_region, nonoverlapping_rect], "auto", False),
351+
(no_text_region, [unpopulated_text_region, overlapping_rect], "auto", True),
352+
(no_text_region, [unpopulated_text_region], "auto", False),
350353
*list(
351354
product(
352355
[text_region],
353-
[[], [populated_text_region], [unpopulated_text_region]],
354-
[[], [nonoverlapping_rect], [overlapping_rect]],
356+
[
357+
[],
358+
[populated_text_region],
359+
[unpopulated_text_region],
360+
[nonoverlapping_rect],
361+
[overlapping_rect],
362+
[populated_text_region, nonoverlapping_rect],
363+
[populated_text_region, overlapping_rect],
364+
[unpopulated_text_region, nonoverlapping_rect],
365+
[unpopulated_text_region, overlapping_rect],
366+
],
355367
["auto"],
356368
[False],
357369
)
358370
),
359371
*list(
360372
product(
361373
[cid_text_region],
362-
[[], [populated_text_region], [unpopulated_text_region]],
363-
[[overlapping_rect]],
374+
[
375+
[],
376+
[populated_text_region],
377+
[unpopulated_text_region],
378+
[overlapping_rect],
379+
[populated_text_region, overlapping_rect],
380+
[unpopulated_text_region, overlapping_rect],
381+
],
364382
["auto"],
365383
[True],
366384
)
367385
),
368386
*list(
369387
product(
370388
[no_text_region, text_region, cid_text_region],
371-
[[], [populated_text_region], [unpopulated_text_region]],
372-
[[], [nonoverlapping_rect], [overlapping_rect]],
389+
[
390+
[],
391+
[populated_text_region],
392+
[unpopulated_text_region],
393+
[nonoverlapping_rect],
394+
[overlapping_rect],
395+
[populated_text_region, nonoverlapping_rect],
396+
[populated_text_region, overlapping_rect],
397+
[unpopulated_text_region, nonoverlapping_rect],
398+
[unpopulated_text_region, overlapping_rect],
399+
],
373400
["force"],
374401
[True],
375402
)
376403
),
377404
*list(
378405
product(
379406
[no_text_region, text_region, cid_text_region],
380-
[[], [populated_text_region], [unpopulated_text_region]],
381-
[[], [nonoverlapping_rect], [overlapping_rect]],
407+
[
408+
[],
409+
[populated_text_region],
410+
[unpopulated_text_region],
411+
[nonoverlapping_rect],
412+
[overlapping_rect],
413+
[populated_text_region, nonoverlapping_rect],
414+
[populated_text_region, overlapping_rect],
415+
[unpopulated_text_region, nonoverlapping_rect],
416+
[unpopulated_text_region, overlapping_rect],
417+
],
382418
["never"],
383419
[False],
384420
)
385421
),
386422
],
387423
)
388-
def test_ocr_image(region, text_objects, image_objects, ocr_strategy, expected):
389-
assert layout.needs_ocr(region, text_objects, image_objects, ocr_strategy) is expected
424+
def test_ocr_image(region, objects, ocr_strategy, expected):
425+
assert elements.needs_ocr(region, objects, ocr_strategy) is expected
390426

391427

392428
def test_load_pdf():
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
from random import randint
2+
from unstructured_inference.inference import elements
3+
from unstructured_inference.inference.layout import load_pdf
4+
5+
6+
def intersect_brute(rect1, rect2):
    """Brute-force oracle: does any integer grid point of ``rect1`` lie in ``rect2``?

    Every integer ``(x, y)`` with ``rect1.x1 <= x <= rect1.x2`` and
    ``rect1.y1 <= y <= rect1.y2`` is enumerated with ``range`` (so ``rect1``
    must have integer coordinates) and tested against the closed bounds of
    ``rect2``. Returns ``True`` as soon as one such point is found.
    """
    return any(
        (rect2.x1 <= x <= rect2.x2) and (rect2.y1 <= y <= rect2.y2)
        for x in range(rect1.x1, rect1.x2 + 1)
        for y in range(rect1.y1, rect1.y2 + 1)
    )
12+
13+
14+
def rand_rect(size=10):
    """Return a random ``elements.Rectangle`` that is a square of side ``size``.

    The ``(x1, y1)`` corner is drawn with ``randint`` from ``0 .. 30 - size``
    on each axis, so the opposite corner ``(x1 + size, y1 + size)`` never
    exceeds 30.
    """
    x1 = randint(0, 30 - size)
    y1 = randint(0, 30 - size)
    return elements.Rectangle(x1, y1, x1 + size, y1 + size)
18+
19+
20+
def test_intersects_overlap():
    """``intersects`` must match the brute-force oracle and be symmetric.

    Runs 1000 random pairs of default-sized (10-unit) squares.
    """
    for _ in range(1000):
        rect1 = rand_rect()
        rect2 = rand_rect()
        assert intersect_brute(rect1, rect2) == rect1.intersects(rect2) == rect2.intersects(rect1)
25+
26+
27+
def test_intersects_subset():
    """``intersects`` must match the oracle for squares of different sizes.

    Runs 1000 random pairs of a 10-unit square against a 20-unit square and
    checks the result in both call orders.
    """
    for _ in range(1000):
        rect1 = rand_rect()
        rect2 = rand_rect(20)
        assert intersect_brute(rect1, rect2) == rect1.intersects(rect2) == rect2.intersects(rect1)
32+
33+
34+
def test_intersection_of_lots_of_rects():
    """``elements.intersections`` must agree with the pairwise brute-force oracle.

    For 1000 random batches of ten 6-unit squares, every entry ``[i, j]`` of
    the returned matrix is compared with the oracle and with its mirrored
    entry ``[j, i]``.
    """
    n_rects = 10  # loop-invariant, so defined once outside the sampling loop
    for _ in range(1000):
        rects = [rand_rect(6) for _ in range(n_rects)]
        intersection_mtx = elements.intersections(*rects)
        for i in range(n_rects):
            for j in range(n_rects):
                assert (
                    intersect_brute(rects[i], rects[j])
                    == intersection_mtx[i, j]
                    == intersection_mtx[j, i]
                )
46+
47+
48+
def test_rectangle_width_height():
    """``Rectangle.width`` / ``height`` must equal the coordinate differences.

    Coordinates are drawn so that ``x2 > x1`` and ``y2 > y1`` in every sample.
    """
    for _ in range(1000):
        x1 = randint(0, 50)
        x2 = randint(x1 + 1, 100)
        y1 = randint(0, 50)
        y2 = randint(y1 + 1, 100)
        rect = elements.Rectangle(x1, y1, x2, y2)
        assert rect.width == x2 - x1
        assert rect.height == y2 - y1
57+
58+
59+
def test_minimal_containing_rect():
    """``minimal_containing_region`` must contain both inputs and be tight.

    Tightness is checked by shrinking the result by one unit on each side in
    turn: the shrunken rectangle must fail to contain at least one input.
    """
    attrs = ["x1", "y1", "x2", "y2"]
    for _ in range(1000):
        rect1 = rand_rect()
        rect2 = rand_rect()
        big_rect = elements.minimal_containing_region(rect1, rect2)
        for decrease_attr in attrs:
            # A throwaway rectangle whose coordinates are all overwritten below.
            almost_as_big_rect = rand_rect()
            # Moving an "…1" edge by +1 or an "…2" edge by -1 shrinks the box.
            mod = 1 if decrease_attr.endswith("1") else -1
            for attr in attrs:
                if attr == decrease_attr:
                    setattr(almost_as_big_rect, attr, getattr(big_rect, attr) + mod)
                else:
                    setattr(almost_as_big_rect, attr, getattr(big_rect, attr))
            assert not rect1.is_in(almost_as_big_rect) or not rect2.is_in(almost_as_big_rect)

        assert rect1.is_in(big_rect)
        assert rect2.is_in(big_rect)
76+
77+
78+
def test_partition_groups_from_regions():
    """Partition the sample paper's extracted regions and check the grouping.

    Uses the first entry of the first value returned by ``load_pdf`` (presumably
    the first page's words; confirm against ``load_pdf``). Expects nine groups,
    and the group whose first element has the largest ``y1`` must start with
    "Deep".
    """
    words, _ = load_pdf("sample-docs/layout-parser-paper.pdf")
    groups = elements.partition_groups_from_regions(words[0])
    assert len(groups) == 9
    sorted_groups = sorted(groups, key=lambda group: group[0].y1)
    text = "".join(el.text for el in sorted_groups[-1])
    assert text.startswith("Deep")
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.3.3-dev2" # pragma: no cover
1+
__version__ = "0.4.0" # pragma: no cover

0 commit comments

Comments
 (0)