Skip to content

Commit 605c1f2

Browse files
authored
feat: Refactor for local inference (#5)
Small refactor, primarily moving some of the API processing code out to where it can be reused.
1 parent 1b1507d commit 605c1f2

File tree

7 files changed

+109
-28
lines changed

7 files changed

+109
-28
lines changed

CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
## 0.2.1-dev0
1+
## 0.2.1-dev1
22

3+
* Refactor to facilitate local inference
34
* Removes BasicConfig from logger configuration
45
* Implement auto model downloading
56

test_unstructured_inference/inference/test_layout.py

Lines changed: 71 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import pytest
2-
from unittest.mock import patch
2+
from unittest.mock import patch, mock_open
33

44
import layoutparser as lp
55
from layoutparser.elements import Layout, Rectangle, TextBlock
66
import numpy as np
77
from PIL import Image
88

9-
from unstructured_inference.inference.layout import DocumentLayout, PageLayout
9+
import unstructured_inference.inference.layout as layout
10+
import unstructured_inference.models as models
11+
1012
import unstructured_inference.models.detectron2 as detectron2
1113
import unstructured_inference.models.tesseract as tesseract
1214

@@ -28,7 +30,7 @@ def mock_page_layout():
2830

2931

3032
def test_pdf_page_converts_images_to_array(mock_image):
31-
page = PageLayout(number=0, image=mock_image, layout=Layout())
33+
page = layout.PageLayout(number=0, image=mock_image, layout=Layout())
3234
assert page.image_array is None
3335

3436
image_array = page._get_image_array()
@@ -47,7 +49,7 @@ def detect(self, *args):
4749
monkeypatch.setattr(tesseract, "is_pytesseract_available", lambda *args: True)
4850

4951
image = np.random.randint(12, 24, (40, 40))
50-
page = PageLayout(number=0, image=image, layout=Layout())
52+
page = layout.PageLayout(number=0, image=image, layout=Layout())
5153
rectangle = Rectangle(1, 2, 3, 4)
5254
text_block = TextBlock(rectangle, text=None)
5355

@@ -67,7 +69,7 @@ def test_get_page_elements(monkeypatch, mock_page_layout):
6769
monkeypatch.setattr(detectron2, "is_detectron2_available", lambda *args: True)
6870

6971
image = np.random.randint(12, 24, (40, 40))
70-
page = PageLayout(number=0, image=image, layout=mock_page_layout)
72+
page = layout.PageLayout(number=0, image=image, layout=mock_page_layout)
7173

7274
elements = page.get_elements(inplace=False)
7375

@@ -79,17 +81,17 @@ def test_get_page_elements(monkeypatch, mock_page_layout):
7981

8082

8183
def test_get_page_elements_with_ocr(monkeypatch):
82-
monkeypatch.setattr(PageLayout, "ocr", lambda *args: "An Even Catchier Title")
84+
monkeypatch.setattr(layout.PageLayout, "ocr", lambda *args: "An Even Catchier Title")
8385

8486
rectangle = Rectangle(2, 4, 6, 8)
8587
text_block = TextBlock(rectangle, text=None, type="Title")
86-
layout = Layout([text_block])
88+
doc_layout = Layout([text_block])
8789

88-
monkeypatch.setattr(detectron2, "load_default_model", lambda: MockLayoutModel(layout))
90+
monkeypatch.setattr(detectron2, "load_default_model", lambda: MockLayoutModel(doc_layout))
8991
monkeypatch.setattr(detectron2, "is_detectron2_available", lambda *args: True)
9092

9193
image = np.random.randint(12, 24, (40, 40))
92-
page = PageLayout(number=0, image=image, layout=layout)
94+
page = layout.PageLayout(number=0, image=image, layout=doc_layout)
9395
page.get_elements()
9496

9597
assert str(page) == "An Even Catchier Title"
@@ -105,7 +107,7 @@ def test_read_pdf(monkeypatch, mock_page_layout):
105107
monkeypatch.setattr(detectron2, "is_detectron2_available", lambda *args: True)
106108

107109
with patch.object(lp, "load_pdf", return_value=(layouts, images)):
108-
doc = DocumentLayout.from_file("fake-file.pdf")
110+
doc = layout.DocumentLayout.from_file("fake-file.pdf")
109111

110112
assert str(doc).startswith("A Catchy Title")
111113
assert str(doc).count("A Catchy Title") == 2 # Once for each page
@@ -115,3 +117,62 @@ def test_read_pdf(monkeypatch, mock_page_layout):
115117

116118
pages = doc.pages
117119
assert str(doc) == "\n\n".join([str(page) for page in pages])
120+
121+
122+
@pytest.mark.parametrize("model_name", [None, "checkbox", "fake"])
def test_process_data_with_model(monkeypatch, mock_page_layout, model_name):
    """Checks that process_data_with_model returns a truthy result for each parametrized
    model name once model loading and DocumentLayout.from_file have been stubbed out."""

    def stub_get_model(x):
        return MockLayoutModel(mock_page_layout)

    def stub_from_file(*args, **kwargs):
        return layout.DocumentLayout.from_pages([])

    def stub_load_model(*args, **kwargs):
        return MockLayoutModel(mock_page_layout)

    def stub_loading_info(*args, **kwargs):
        return (
            "fake-binary-path",
            "fake-config-path",
            {0: "Unchecked", 1: "Checked"},
        )

    # NOTE(review): layout imports get_model by name, so patching it on `models` may not
    # change the binding layout uses -- confirm whether this patch is needed.
    monkeypatch.setattr(models, "get_model", stub_get_model)
    monkeypatch.setattr(layout.DocumentLayout, "from_file", stub_from_file)
    monkeypatch.setattr(models, "load_model", stub_load_model)
    monkeypatch.setattr(models, "_get_model_loading_info", stub_loading_info)

    fake_open = mock_open(read_data=b"000000")
    with patch("builtins.open", fake_open):
        handle = open("")
        assert layout.process_data_with_model(handle, model_name=model_name)
144+
145+
146+
def test_process_data_with_model_raises_on_invalid_model_name():
    """Checks that an unrecognized model name surfaces as UnknownModelException when
    processing data from a file handle."""
    fake_open = mock_open(read_data=b"000000")
    with patch("builtins.open", fake_open), pytest.raises(models.UnknownModelException):
        layout.process_data_with_model(open(""), model_name="fake")
150+
151+
152+
@pytest.mark.parametrize("model_name", [None, "checkbox"])
def test_process_file_with_model(monkeypatch, mock_page_layout, model_name):
    """Checks that process_file_with_model returns a truthy result for each parametrized
    model name once model loading and DocumentLayout.from_file have been stubbed out."""

    def stub_get_model(x):
        return MockLayoutModel(mock_page_layout)

    def stub_from_file(*args, **kwargs):
        return layout.DocumentLayout.from_pages([])

    def stub_load_model(*args, **kwargs):
        return MockLayoutModel(mock_page_layout)

    def stub_loading_info(*args, **kwargs):
        return (
            "fake-binary-path",
            "fake-config-path",
            {0: "Unchecked", 1: "Checked"},
        )

    # NOTE(review): layout imports get_model by name, so patching it on `models` may not
    # change the binding layout uses -- confirm whether this patch is needed.
    monkeypatch.setattr(models, "get_model", stub_get_model)
    monkeypatch.setattr(layout.DocumentLayout, "from_file", stub_from_file)
    monkeypatch.setattr(models, "load_model", stub_load_model)
    monkeypatch.setattr(models, "_get_model_loading_info", stub_loading_info)

    filename = ""
    result = layout.process_file_with_model(filename, model_name=model_name)
    assert result
174+
175+
176+
def test_process_file_with_model_raises_on_invalid_model_name():
    """Checks that an unrecognized model name surfaces as UnknownModelException when
    processing a file by name."""
    unknown_name = "fake"
    with pytest.raises(models.UnknownModelException):
        layout.process_file_with_model("", model_name=unknown_name)

test_unstructured_inference/models/test_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,5 +24,5 @@ def test_get_model(monkeypatch):
2424

2525

2626
def test_raises_invalid_model():
27-
with pytest.raises(ValueError):
27+
with pytest.raises(models.UnknownModelException):
2828
models.get_model("fake_model")
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.2.1-dev0" # pragma: no cover
1+
__version__ = "0.2.1-dev1" # pragma: no cover

unstructured_inference/api.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
from fastapi import FastAPI, File, status, Request, UploadFile, Form, HTTPException
2-
from unstructured_inference.inference.layout import DocumentLayout
3-
from unstructured_inference.models import get_model
2+
from unstructured_inference.inference.layout import process_data_with_model
3+
from unstructured_inference.models import UnknownModelException
44
from typing import List
5-
import tempfile
65

76
app = FastAPI()
87

@@ -15,16 +14,10 @@ async def layout_parsing_pdf(
1514
include_elems: List[str] = Form(default=ALL_ELEMS),
1615
model: str = Form(default=None),
1716
):
18-
with tempfile.NamedTemporaryFile() as tmp_file:
19-
tmp_file.write(file.file.read())
20-
if model is None:
21-
layout = DocumentLayout.from_file(tmp_file.name)
22-
else:
23-
try:
24-
detector = get_model(model)
25-
except ValueError as e:
26-
raise HTTPException(status.HTTP_422_UNPROCESSABLE_ENTITY, str(e))
27-
layout = DocumentLayout.from_file(tmp_file.name, model=detector)
17+
try:
18+
layout = process_data_with_model(file.file, model)
19+
except UnknownModelException as e:
20+
raise HTTPException(status.HTTP_422_UNPROCESSABLE_ENTITY, str(e))
2821
pages_layout = [
2922
{
3023
"number": page.number,

unstructured_inference/inference/layout.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22
from dataclasses import dataclass
3-
from typing import List, Optional, Tuple, Union
3+
import tempfile
4+
from typing import List, Optional, Tuple, Union, BinaryIO
45

56
import layoutparser as lp
67
from layoutparser.models.detectron2.layoutmodel import Detectron2LayoutModel
@@ -10,6 +11,7 @@
1011
from unstructured_inference.logger import logger
1112
import unstructured_inference.models.tesseract as tesseract
1213
import unstructured_inference.models.detectron2 as detectron2
14+
from unstructured_inference.models import get_model
1315

1416

1517
@dataclass
@@ -136,3 +138,21 @@ def _get_image_array(self) -> Union[np.ndarray, None]:
136138
if self.image_array is None:
137139
self.image_array = np.array(self.image)
138140
return self.image_array
141+
142+
143+
def process_data_with_model(data: BinaryIO, model_name: Optional[str]) -> DocumentLayout:
    """Processes pdf file in the form of a file handler (supporting a read method) into a
    DocumentLayout by using a model identified by model_name.

    The bytes read from data are copied into a named temporary file, and that file's path is
    handed to process_file_with_model together with model_name (which may be None and is
    forwarded unchanged). The temporary file is removed when processing finishes. Any
    exception raised while processing the file propagates to the caller.
    """
    with tempfile.NamedTemporaryFile() as tmp_file:
        tmp_file.write(data.read())
        # The file is reopened by name downstream, so buffered bytes must be pushed to disk
        # first; otherwise a small document may still be sitting in the write buffer.
        tmp_file.flush()
        layout = process_file_with_model(tmp_file.name, model_name)

    return layout
151+
152+
153+
def process_file_with_model(filename: str, model_name: Optional[str]) -> DocumentLayout:
    """Processes pdf file with name filename into a DocumentLayout by using a model identified by
    model_name.

    When model_name is None no model is looked up and model=None is passed to
    DocumentLayout.from_file; otherwise the model returned by get_model(model_name) is used.
    Exceptions raised by get_model (e.g. for an unrecognized name) propagate to the caller.
    """
    model = None if model_name is None else get_model(model_name)
    return DocumentLayout.from_file(filename, model=model)

unstructured_inference/models/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,5 +24,11 @@ def _get_model_loading_info(model: str) -> Tuple[str, str, Dict[int, str]]:
2424
config_path = hf_hub_download(repo_id, config_fn)
2525
label_map = {0: "Unchecked", 1: "Checked"}
2626
else:
27-
raise ValueError(f"Unknown model type: {model}")
27+
raise UnknownModelException(f"Unknown model type: {model}")
2828
return model_path, config_path, label_map
29+
30+
31+
class UnknownModelException(ValueError):
    """Exception for the case where a model is called for with an unrecognized identifier.

    Derives from ValueError so that callers written against the ValueError that used to be
    raised for unknown model names keep working.
    """

0 commit comments

Comments
 (0)