Skip to content

Commit 8759b0a

Browse files
authored
feat: allow passing down of ocr agent and table agent (#3954)
This PR allows passing both `ocr_agent` and `table_ocr_agent` as parameters to specify the `OCRAgent` class used for the page and for tables (if any), respectively. Both default to `tesseract`, consistent with the present default behavior. We used to rely on environment variables to specify the agents, but the OS environment can be changed at runtime outside of the caller's control. Passing the agents down as parameters ensures that the choice of agent is independent of environment changes.

## Testing

Use `example-docs/img/layout-parser-paper-with-table.jpg` and run partition with two different settings. Note that this test requires the `paddleocr` extra.

```python
from unstructured.partition.auto import partition
from unstructured.partition.utils.constants import OCR_AGENT_TESSERACT, OCR_AGENT_PADDLE

elements = partition(f, strategy="hi_res", skip_infer_table_types=[], ocr_agent=OCR_AGENT_TESSERACT, table_ocr_agent=OCR_AGENT_PADDLE)
elements_alt = partition(f, strategy="hi_res", skip_infer_table_types=[], ocr_agent=OCR_AGENT_PADDLE, table_ocr_agent=OCR_AGENT_TESSERACT)
```

Both calls should finish, and the table element's text attribute should differ slightly between the two results.
1 parent 0001a33 commit 8759b0a

File tree

8 files changed

+122
-18
lines changed

8 files changed

+122
-18
lines changed

.github/workflows/ci.yml

+1
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,7 @@ jobs:
218218
sudo apt-get install -y tesseract-ocr tesseract-ocr-kor
219219
tesseract --version
220220
make install-${{ matrix.extra }}
221+
[[ ${{ matrix.extra }} == "pdf-image" ]] && make install-paddleocr
221222
make test-extra-${{ matrix.extra }} CI=true
222223
223224
setup_ingest:

CHANGELOG.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
## 0.16.26-dev2
1+
## 0.16.26-dev3
22

33
### Enhancements
44

55
- **Add support for images in html partitioner** `<img>` tags will now be parsed as `Image` elements. When `extract_image_block_types` includes `Image` and `extract_image_block_to_payload`=True then the `image_base64` will be included for images that specify the base64 data (rather than url) as the source.
6+
- **Use kwargs instead of environment variables to specify `ocr_agent` and `table_ocr_agent`** for the `hi_res` strategy.
67

78
### Features
89

Makefile

+5-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ install-base: install-base-pip-packages install-nltk-models
2222
install: install-base-pip-packages install-dev install-nltk-models install-test install-huggingface install-all-docs
2323

2424
.PHONY: install-ci
25-
install-ci: install-base-pip-packages install-nltk-models install-huggingface install-all-docs install-test install-pandoc
25+
install-ci: install-base-pip-packages install-nltk-models install-huggingface install-all-docs install-test install-pandoc install-paddleocr
2626

2727
.PHONY: install-base-ci
2828
install-base-ci: install-base-pip-packages install-nltk-models install-test install-pandoc
@@ -80,6 +80,10 @@ install-odt:
8080
install-pypandoc:
8181
${PYTHON} -m pip install -r requirements/extra-pandoc.txt
8282

83+
.PHONY: install-paddleocr
84+
install-paddleocr:
85+
${PYTHON} -m pip install -r requirements/extra-paddleocr.txt
86+
8387
.PHONY: install-markdown
8488
install-markdown:
8589
${PYTHON} -m pip install -r requirements/extra-markdown.txt

test_unstructured/partition/pdf_image/test_ocr.py

+56-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from collections import namedtuple
22
from typing import Optional
3-
from unittest.mock import patch
3+
from unittest.mock import MagicMock, patch
44

55
import numpy as np
66
import pandas as pd
@@ -10,7 +10,7 @@
1010
from pdf2image.exceptions import PDFPageCountError
1111
from PIL import Image, UnidentifiedImageError
1212
from unstructured_inference.inference.elements import EmbeddedTextRegion, TextRegion, TextRegions
13-
from unstructured_inference.inference.layout import DocumentLayout
13+
from unstructured_inference.inference.layout import DocumentLayout, PageLayout
1414
from unstructured_inference.inference.layoutelement import (
1515
LayoutElement,
1616
LayoutElements,
@@ -25,6 +25,8 @@
2525
)
2626
from unstructured.partition.utils.config import env_config
2727
from unstructured.partition.utils.constants import (
28+
OCR_AGENT_PADDLE,
29+
OCR_AGENT_TESSERACT,
2830
Source,
2931
)
3032
from unstructured.partition.utils.ocr_models.google_vision_ocr import OCRAgentGoogleVision
@@ -66,12 +68,10 @@ def test_process_file_with_ocr_invalid_filename(is_image):
6668
)
6769

6870

69-
def test_supplement_page_layout_with_ocr_invalid_ocr(monkeypatch):
70-
monkeypatch.setenv("OCR_AGENT", "invalid_ocr")
71+
def test_supplement_page_layout_with_ocr_invalid_ocr():
7172
with pytest.raises(ValueError):
7273
_ = ocr.supplement_page_layout_with_ocr(
73-
page_layout=None,
74-
image=None,
74+
page_layout=None, image=None, ocr_agent="invliad_ocr"
7575
)
7676

7777

@@ -610,3 +610,53 @@ def test_hocr_to_dataframe_when_no_prediction_empty_df():
610610
assert "width" in df.columns
611611
assert "text" in df.columns
612612
assert "text" in df.columns
613+
614+
615+
@pytest.fixture
def mock_page(mock_ocr_layout, mock_layout):
    """Provide a `PageLayout`-spec'd mock whose `elements_array` is the `mock_layout` fixture."""
    # NOTE: `mock_ocr_layout` is requested but not read in this body; it stays in the
    # signature so the fixture's dependency set is unchanged.
    page = MagicMock(PageLayout)
    page.elements_array = mock_layout
    return page
620+
621+
622+
def test_supplement_layout_with_ocr(mocker, mock_page, mock_ocr_layout):
    """Page OCR and table OCR each build their agent from the kwarg that was passed in.

    Calls `ocr.supplement_page_layout_with_ocr` with Tesseract as the page agent and
    Paddle as the table agent, then inspects the keyword arguments of the first two
    `OCRAgent.get_instance` calls: the first must name the page agent, the second the
    table agent. The Paddle call is expected to receive `"en"` for the `"eng"` input.
    """
    from unstructured.partition.pdf_image.ocr import OCRAgent

    # `mock_ocr_layout` is requested as a fixture in the signature so that the patched
    # method returns the fixture's value rather than the module-level fixture function.
    mocker.patch.object(OCRAgent, "get_layout_from_image", return_value=mock_ocr_layout)
    spy = mocker.spy(OCRAgent, "get_instance")

    ocr.supplement_page_layout_with_ocr(
        mock_page,
        Image.new("RGB", (100, 100)),
        infer_table_structure=True,
        ocr_agent=OCR_AGENT_TESSERACT,
        ocr_languages="eng",
        table_ocr_agent=OCR_AGENT_PADDLE,
    )

    # Index [1] of a recorded call is its kwargs dict.
    assert spy.call_args_list[0][1] == {"language": "eng", "ocr_agent_module": OCR_AGENT_TESSERACT}
    assert spy.call_args_list[1][1] == {"language": "en", "ocr_agent_module": OCR_AGENT_PADDLE}
639+
640+
641+
def test_pass_down_agents(mocker, mock_page, mock_ocr_layout):
    """`process_file_with_ocr` hands both agent kwargs down to the `OCRAgent` factory.

    Runs `ocr.process_file_with_ocr` on a one-page mocked `DocumentLayout` with Paddle
    as the page agent and Tesseract as the table agent, then checks the keyword
    arguments of the first two `OCRAgent.get_instance` calls. The Paddle call is
    expected to receive `"en"` for the `"eng"` input.
    """
    from unstructured.partition.pdf_image.ocr import OCRAgent, PILImage

    # `mock_ocr_layout` is requested as a fixture in the signature so that the patched
    # method returns the fixture's value rather than the module-level fixture function.
    mocker.patch.object(OCRAgent, "get_layout_from_image", return_value=mock_ocr_layout)
    # Presumably patched so the placeholder filename "foo" is never read from disk.
    mocker.patch.object(PILImage, "open", return_value=Image.new("RGB", (100, 100)))
    spy = mocker.spy(OCRAgent, "get_instance")
    doc = MagicMock(DocumentLayout)
    doc.pages = [mock_page]

    ocr.process_file_with_ocr(
        "foo",
        doc,
        [],
        infer_table_structure=True,
        is_image=True,
        ocr_agent=OCR_AGENT_PADDLE,
        ocr_languages="eng",
        table_ocr_agent=OCR_AGENT_TESSERACT,
    )

    # Index [1] of a recorded call is its kwargs dict.
    assert spy.call_args_list[0][1] == {"language": "en", "ocr_agent_module": OCR_AGENT_PADDLE}
    assert spy.call_args_list[1][1] == {"language": "eng", "ocr_agent_module": OCR_AGENT_TESSERACT}

test_unstructured/partition/pdf_image/test_pdf.py

+19
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939
from unstructured.partition.pdf_image import ocr, pdfminer_processing
4040
from unstructured.partition.pdf_image.pdfminer_processing import get_uris_from_annots
4141
from unstructured.partition.utils.constants import (
42+
OCR_AGENT_PADDLE,
43+
OCR_AGENT_TESSERACT,
4244
SORT_MODE_BASIC,
4345
SORT_MODE_DONT,
4446
SORT_MODE_XY_CUT,
@@ -1585,3 +1587,20 @@ def _test(result):
15851587
file=spooled_temp_file, strategy=strategy, password="password"
15861588
)
15871589
_test(result)
1590+
1591+
1592+
def test_partition_pdf_with_specified_ocr_agents(mocker):
    """`partition_pdf` passes `ocr_agent` and `table_ocr_agent` through to `OCRAgent`.

    Partitions a PDF containing a table with the hi-res strategy, using Tesseract as the
    page agent and Paddle as the table agent, and checks the keyword arguments recorded
    for the first two `OCRAgent.get_instance` calls.
    """
    from unstructured.partition.pdf_image.ocr import OCRAgent

    get_instance_spy = mocker.spy(OCRAgent, "get_instance")
    expected_page_kwargs = {"language": "eng", "ocr_agent_module": OCR_AGENT_TESSERACT}
    expected_table_kwargs = {"language": "en", "ocr_agent_module": OCR_AGENT_PADDLE}

    pdf.partition_pdf(
        filename=example_doc_path("pdf/layout-parser-paper-with-table.pdf"),
        strategy=PartitionStrategy.HI_RES,
        infer_table_structure=True,
        ocr_agent=OCR_AGENT_TESSERACT,
        table_ocr_agent=OCR_AGENT_PADDLE,
    )

    recorded_calls = get_instance_spy.call_args_list
    # Index [1] of a recorded call is its kwargs dict.
    assert recorded_calls[0][1] == expected_page_kwargs
    assert recorded_calls[1][1] == expected_table_kwargs

unstructured/__version__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.16.26-dev2" # pragma: no cover
1+
__version__ = "0.16.26-dev3" # pragma: no cover

unstructured/partition/pdf.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@
5454
from unstructured.partition.common.lang import (
5555
check_language_args,
5656
prepare_languages_for_tesseract,
57-
tesseract_to_paddle_language,
5857
)
5958
from unstructured.partition.common.metadata import get_last_modified_date
6059
from unstructured.partition.pdf_image.analysis.layout_dump import (
@@ -88,7 +87,7 @@
8887
from unstructured.partition.text import element_from_text
8988
from unstructured.partition.utils.config import env_config
9089
from unstructured.partition.utils.constants import (
91-
OCR_AGENT_PADDLE,
90+
OCR_AGENT_TESSERACT,
9291
SORT_MODE_BASIC,
9392
SORT_MODE_DONT,
9493
SORT_MODE_XY_CUT,
@@ -273,6 +272,8 @@ def partition_pdf_or_image(
273272
pdfminer_char_margin: Optional[float] = None,
274273
pdfminer_line_overlap: Optional[float] = None,
275274
pdfminer_word_margin: Optional[float] = 0.185,
275+
ocr_agent: str = OCR_AGENT_TESSERACT,
276+
table_ocr_agent: str = OCR_AGENT_TESSERACT,
276277
**kwargs: Any,
277278
) -> list[Element]:
278279
"""Parses a pdf or image document into a list of interpreted elements."""
@@ -332,8 +333,6 @@ def partition_pdf_or_image(
332333
file.seek(0)
333334

334335
ocr_languages = prepare_languages_for_tesseract(languages)
335-
if env_config.OCR_AGENT == OCR_AGENT_PADDLE:
336-
ocr_languages = tesseract_to_paddle_language(ocr_languages)
337336

338337
if strategy == PartitionStrategy.HI_RES:
339338
# NOTE(robinson): Catches a UserWarning that occurs when detection is called
@@ -359,6 +358,8 @@ def partition_pdf_or_image(
359358
form_extraction_skip_tables=form_extraction_skip_tables,
360359
password=password,
361360
pdfminer_config=pdfminer_config,
361+
ocr_agent=ocr_agent,
362+
table_ocr_agent=table_ocr_agent,
362363
**kwargs,
363364
)
364365
out_elements = _process_uncategorized_text_elements(elements)
@@ -609,6 +610,8 @@ def _partition_pdf_or_image_local(
609610
pdf_hi_res_max_pages: Optional[int] = None,
610611
password: Optional[str] = None,
611612
pdfminer_config: Optional[PDFMinerConfig] = None,
613+
ocr_agent: str = OCR_AGENT_TESSERACT,
614+
table_ocr_agent: str = OCR_AGENT_TESSERACT,
612615
**kwargs: Any,
613616
) -> list[Element]:
614617
"""Partition using package installed locally"""
@@ -690,11 +693,13 @@ def _partition_pdf_or_image_local(
690693
extracted_layout=extracted_layout,
691694
is_image=is_image,
692695
infer_table_structure=infer_table_structure,
696+
ocr_agent=ocr_agent,
693697
ocr_languages=ocr_languages,
694698
ocr_mode=ocr_mode,
695699
pdf_image_dpi=pdf_image_dpi,
696700
ocr_layout_dumper=ocr_layout_dumper,
697701
password=password,
702+
table_ocr_agent=table_ocr_agent,
698703
)
699704
else:
700705
inferred_document_layout = process_data_with_model(
@@ -749,11 +754,13 @@ def _partition_pdf_or_image_local(
749754
extracted_layout=extracted_layout,
750755
is_image=is_image,
751756
infer_table_structure=infer_table_structure,
757+
ocr_agent=ocr_agent,
752758
ocr_languages=ocr_languages,
753759
ocr_mode=ocr_mode,
754760
pdf_image_dpi=pdf_image_dpi,
755761
ocr_layout_dumper=ocr_layout_dumper,
756762
password=password,
763+
table_ocr_agent=table_ocr_agent,
757764
)
758765

759766
# vectorization of the data structure ends here

unstructured/partition/pdf_image/ocr.py

+27-5
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,15 @@
1414

1515
from unstructured.documents.elements import ElementType
1616
from unstructured.metrics.table.table_formats import SimpleTableCell
17+
from unstructured.partition.common.lang import tesseract_to_paddle_language
1718
from unstructured.partition.pdf_image.analysis.layout_dump import OCRLayoutDumper
1819
from unstructured.partition.pdf_image.pdf_image_utils import valid_text
1920
from unstructured.partition.pdf_image.pdfminer_processing import (
2021
aggregate_embedded_text_by_block,
2122
bboxes1_is_almost_subregion_of_bboxes2,
2223
)
2324
from unstructured.partition.utils.config import env_config
24-
from unstructured.partition.utils.constants import OCRMode
25+
from unstructured.partition.utils.constants import OCR_AGENT_PADDLE, OCR_AGENT_TESSERACT, OCRMode
2526
from unstructured.partition.utils.ocr_models.ocr_interface import OCRAgent
2627
from unstructured.utils import requires_dependencies
2728

@@ -38,11 +39,13 @@ def process_data_with_ocr(
3839
extracted_layout: List[List["TextRegion"]],
3940
is_image: bool = False,
4041
infer_table_structure: bool = False,
42+
ocr_agent: str = OCR_AGENT_TESSERACT,
4143
ocr_languages: str = "eng",
4244
ocr_mode: str = OCRMode.FULL_PAGE.value,
4345
pdf_image_dpi: int = 200,
4446
ocr_layout_dumper: Optional[OCRLayoutDumper] = None,
4547
password: Optional[str] = None,
48+
table_ocr_agent: str = OCR_AGENT_TESSERACT,
4649
) -> "DocumentLayout":
4750
"""
4851
Process OCR data from a given data and supplement the output DocumentLayout
@@ -86,11 +89,13 @@ def process_data_with_ocr(
8689
extracted_layout=extracted_layout,
8790
is_image=is_image,
8891
infer_table_structure=infer_table_structure,
92+
ocr_agent=ocr_agent,
8993
ocr_languages=ocr_languages,
9094
ocr_mode=ocr_mode,
9195
pdf_image_dpi=pdf_image_dpi,
9296
ocr_layout_dumper=ocr_layout_dumper,
9397
password=password,
98+
table_ocr_agent=table_ocr_agent,
9499
)
95100

96101
return merged_layouts
@@ -103,11 +108,13 @@ def process_file_with_ocr(
103108
extracted_layout: List[TextRegions],
104109
is_image: bool = False,
105110
infer_table_structure: bool = False,
111+
ocr_agent: str = OCR_AGENT_TESSERACT,
106112
ocr_languages: str = "eng",
107113
ocr_mode: str = OCRMode.FULL_PAGE.value,
108114
pdf_image_dpi: int = 200,
109115
ocr_layout_dumper: Optional[OCRLayoutDumper] = None,
110116
password: Optional[str] = None,
117+
table_ocr_agent: str = OCR_AGENT_TESSERACT,
111118
) -> "DocumentLayout":
112119
"""
113120
Process OCR data from a given file and supplement the output DocumentLayout
@@ -154,10 +161,12 @@ def process_file_with_ocr(
154161
page_layout=out_layout.pages[i],
155162
image=image,
156163
infer_table_structure=infer_table_structure,
164+
ocr_agent=ocr_agent,
157165
ocr_languages=ocr_languages,
158166
ocr_mode=ocr_mode,
159167
extracted_regions=extracted_regions,
160168
ocr_layout_dumper=ocr_layout_dumper,
169+
table_ocr_agent=table_ocr_agent,
161170
)
162171
merged_page_layouts.append(merged_page_layout)
163172
return DocumentLayout.from_pages(merged_page_layouts)
@@ -178,10 +187,12 @@ def process_file_with_ocr(
178187
page_layout=out_layout.pages[i],
179188
image=image,
180189
infer_table_structure=infer_table_structure,
190+
ocr_agent=ocr_agent,
181191
ocr_languages=ocr_languages,
182192
ocr_mode=ocr_mode,
183193
extracted_regions=extracted_regions,
184194
ocr_layout_dumper=ocr_layout_dumper,
195+
table_ocr_agent=table_ocr_agent,
185196
)
186197
merged_page_layouts.append(merged_page_layout)
187198
return DocumentLayout.from_pages(merged_page_layouts)
@@ -197,10 +208,12 @@ def supplement_page_layout_with_ocr(
197208
page_layout: "PageLayout",
198209
image: PILImage.Image,
199210
infer_table_structure: bool = False,
211+
ocr_agent: str = OCR_AGENT_TESSERACT,
200212
ocr_languages: str = "eng",
201213
ocr_mode: str = OCRMode.FULL_PAGE.value,
202214
extracted_regions: Optional[TextRegions] = None,
203215
ocr_layout_dumper: Optional[OCRLayoutDumper] = None,
216+
table_ocr_agent: str = OCR_AGENT_TESSERACT,
204217
) -> "PageLayout":
205218
"""
206219
Supplement a PageLayout with OCR results depending on OCR mode.
@@ -210,9 +223,12 @@ def supplement_page_layout_with_ocr(
210223
with no text and add text from OCR to each element.
211224
"""
212225

213-
ocr_agent = OCRAgent.get_agent(language=ocr_languages)
226+
language = ocr_languages
227+
if ocr_agent == OCR_AGENT_PADDLE:
228+
language = tesseract_to_paddle_language(ocr_languages)
229+
_ocr_agent = OCRAgent.get_instance(ocr_agent_module=ocr_agent, language=language)
214230
if ocr_mode == OCRMode.FULL_PAGE.value:
215-
ocr_layout = ocr_agent.get_layout_from_image(image)
231+
ocr_layout = _ocr_agent.get_layout_from_image(image)
216232
if ocr_layout_dumper:
217233
ocr_layout_dumper.add_ocred_page(ocr_layout.as_list())
218234
page_layout.elements_array = merge_out_layout_with_ocr_layout(
@@ -236,7 +252,7 @@ def supplement_page_layout_with_ocr(
236252
)
237253
# Note(yuming): instead of getting OCR layout, we just need
238254
# the text extracted from OCR for individual elements
239-
text_from_ocr = ocr_agent.get_text_from_image(cropped_image)
255+
text_from_ocr = _ocr_agent.get_text_from_image(cropped_image)
240256
page_layout.elements_array.texts[i] = text_from_ocr
241257
else:
242258
raise ValueError(
@@ -246,6 +262,12 @@ def supplement_page_layout_with_ocr(
246262

247263
# Note(yuming): use the OCR data from entire page OCR for table extraction
248264
if infer_table_structure:
265+
language = ocr_languages
266+
if table_ocr_agent == OCR_AGENT_PADDLE:
267+
language = tesseract_to_paddle_language(ocr_languages)
268+
_table_ocr_agent = OCRAgent.get_instance(
269+
ocr_agent_module=table_ocr_agent, language=language
270+
)
249271
from unstructured_inference.models import tables
250272

251273
tables.load_agent()
@@ -256,7 +278,7 @@ def supplement_page_layout_with_ocr(
256278
elements=page_layout.elements_array,
257279
image=image,
258280
tables_agent=tables.tables_agent,
259-
ocr_agent=ocr_agent,
281+
ocr_agent=_table_ocr_agent,
260282
extracted_regions=extracted_regions,
261283
)
262284
page_layout.elements = page_layout.elements_array.as_list()

0 commit comments

Comments
 (0)