-
Notifications
You must be signed in to change notification settings - Fork 904
/
Copy pathtest_ocr_interface.py
111 lines (91 loc) · 4.19 KB
/
test_ocr_interface.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
# pyright: reportPrivateUsage=false
"""Unit-test suite for the `unstructured.partition.utils.ocr_models.ocr_interface` module."""
from __future__ import annotations
from unittest.mock import patch
import pytest
from test_unstructured.unit_utils import (
FixtureRequest,
LogCaptureFixture,
Mock,
instance_mock,
method_mock,
property_mock,
)
from unstructured.partition.utils.config import ENVConfig
from unstructured.partition.utils.constants import (
OCR_AGENT_PADDLE,
OCR_AGENT_PADDLE_OLD,
OCR_AGENT_TESSERACT,
OCR_AGENT_TESSERACT_OLD,
)
from unstructured.partition.utils.ocr_models.ocr_interface import OCRAgent
class DescribeOCRAgent:
    """Unit-test suite for `unstructured.partition.utils...ocr_interface.OCRAgent` class."""

    def it_provides_access_to_the_configured_OCR_agent(
        self, _get_ocr_agent_cls_qname_: Mock, get_instance_: Mock, ocr_agent_: Mock
    ):
        _get_ocr_agent_cls_qname_.return_value = OCR_AGENT_TESSERACT
        get_instance_.return_value = ocr_agent_

        ocr_agent = OCRAgent.get_agent(language="eng")

        _get_ocr_agent_cls_qname_.assert_called_once_with()
        get_instance_.assert_called_once_with(OCR_AGENT_TESSERACT, "eng")
        assert ocr_agent is ocr_agent_

    def but_it_raises_when_the_requested_agent_is_not_whitelisted(
        self, _get_ocr_agent_cls_qname_: Mock
    ):
        _get_ocr_agent_cls_qname_.return_value = "Invalid.Ocr.Agent.Qname"

        with pytest.raises(ValueError, match="must be set to a whitelisted module"):
            OCRAgent.get_agent(language="eng")

    # -- `_clear_cache` yields no value, so it is requested with `usefixtures` rather than
    # -- injected as an unused parameter (ruff PT019).
    @pytest.mark.usefixtures("_clear_cache")
    @pytest.mark.parametrize("exception_cls", [ImportError, AttributeError])
    def and_it_raises_when_the_requested_agent_cannot_be_loaded(
        self, _get_ocr_agent_cls_qname_: Mock, exception_cls: type[Exception]
    ):
        # -- `OCRAgent.get_instance()` is lru-cached. An agent for this same
        # -- (OCR_AGENT_TESSERACT, "eng") pair cached by an earlier test could be returned without
        # -- ever reaching the patched `import_module()`. That is why the cache is cleared first.
        _get_ocr_agent_cls_qname_.return_value = OCR_AGENT_TESSERACT

        with patch(
            "unstructured.partition.utils.ocr_models.ocr_interface.importlib.import_module",
            side_effect=exception_cls,
        ), pytest.raises(RuntimeError, match="Could not get the OCRAgent instance"):
            OCRAgent.get_agent(language="eng")

    @pytest.mark.parametrize(
        ("OCR_AGENT", "expected_value"),
        [
            (OCR_AGENT_PADDLE, OCR_AGENT_PADDLE),
            (OCR_AGENT_PADDLE_OLD, OCR_AGENT_PADDLE),
            (OCR_AGENT_TESSERACT, OCR_AGENT_TESSERACT),
            (OCR_AGENT_TESSERACT_OLD, OCR_AGENT_TESSERACT),
        ],
    )
    def it_computes_the_OCR_agent_qualified_module_name(
        self, OCR_AGENT: str, expected_value: str, OCR_AGENT_prop_: Mock
    ):
        # -- both the current and the obsolete ("_OLD") names must map to the current name --
        OCR_AGENT_prop_.return_value = OCR_AGENT

        assert OCRAgent._get_ocr_agent_cls_qname() == expected_value

    @pytest.mark.parametrize("OCR_AGENT", [OCR_AGENT_PADDLE_OLD, OCR_AGENT_TESSERACT_OLD])
    def and_it_logs_a_warning_when_the_OCR_AGENT_module_name_is_obsolete(
        self, caplog: LogCaptureFixture, OCR_AGENT: str, OCR_AGENT_prop_: Mock
    ):
        OCR_AGENT_prop_.return_value = OCR_AGENT

        OCRAgent._get_ocr_agent_cls_qname()

        assert f"OCR agent name {OCR_AGENT} is outdated " in caplog.text

    # -- fixtures --------------------------------------------------------------------------------

    @pytest.fixture()
    def _clear_cache(self):
        """Empty the `OCRAgent.get_instance()` cache before and after the requesting test."""
        # Clear the cache created by @functools.lru_cache(maxsize=None) on OCRAgent.get_instance()
        # before each test
        OCRAgent.get_instance.cache_clear()
        yield
        # Clear the cache created by @functools.lru_cache(maxsize=None) on OCRAgent.get_instance()
        # after each test (just in case)
        OCRAgent.get_instance.cache_clear()

    @pytest.fixture()
    def get_instance_(self, request: FixtureRequest):
        """Mock replacing `OCRAgent.get_instance`, installed via `method_mock()`."""
        return method_mock(request, OCRAgent, "get_instance")

    @pytest.fixture()
    def _get_ocr_agent_cls_qname_(self, request: FixtureRequest):  # noqa: PT005
        """Mock replacing `OCRAgent._get_ocr_agent_cls_qname`, installed via `method_mock()`."""
        return method_mock(request, OCRAgent, "_get_ocr_agent_cls_qname")

    @pytest.fixture()
    def ocr_agent_(self, request: FixtureRequest):
        """Mock `OCRAgent` instance, created via `instance_mock()`."""
        return instance_mock(request, OCRAgent)

    @pytest.fixture()
    def OCR_AGENT_prop_(self, request: FixtureRequest):
        """Mock of the `ENVConfig.OCR_AGENT` property, created via `property_mock()`."""
        return property_mock(request, ENVConfig, "OCR_AGENT")