Skip to content

Commit 40c4443

Browse files
Feat/add allow failed param when pdf split is used (#134)
This PR: - adds `split_pdf_allow_failed` parameter
1 parent 7c3f38b commit 40c4443

File tree

11 files changed

+359
-79
lines changed

11 files changed

+359
-79
lines changed

Diff for: Makefile

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ DOCKER_IMAGE ?= downloads.unstructured.io/unstructured-io/unstructured-api:lates
99

1010
.PHONY: install-test
1111
install-test:
12-
pip install pytest requests_mock pypdf deepdiff requests-toolbelt
12+
pip install pytest pytest-asyncio pytest-mock requests_mock pypdf deepdiff requests-toolbelt
1313

1414
.PHONY: install-dev
1515
install-dev:

Diff for: README.md

+15
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,21 @@ req = shared.PartitionParameters(
109109
)
110110
```
111111

112+
#### Splitting PDF by pages - strict mode
113+
114+
When `split_pdf_allow_failed=False` (the default), any errors encountered while sending parallel requests will break the process and raise an exception.
115+
When `split_pdf_allow_failed=True`, the process will continue even if some requests fail, and the results will be combined at the end (the output from the errored pages will not be included).
116+
117+
Example:
118+
```python
119+
req = shared.PartitionParameters(
120+
files=files,
121+
strategy="fast",
122+
languages=["eng"],
123+
split_pdf_allow_failed=True,
124+
)
125+
```
126+
112127
<!-- Start Retries [retries] -->
113128
## Retries
114129

Diff for: USAGE.md

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ res = s.general.partition(request=operations.PartitionRequest(
1919
1,
2020
10,
2121
],
22+
split_pdf_allow_failed=False,
2223
strategy=shared.Strategy.HI_RES,
2324
),
2425
))

Diff for: _test_unstructured_client/integration/test_decorators.py

+73
Original file line numberDiff line numberDiff line change
@@ -180,3 +180,76 @@ def test_integration_split_pdf_with_page_range(
180180

181181
assert min(page_numbers) == min_page_number, f"Result should start at page {min_page_number}"
182182
assert max(page_numbers) == max_page_number, f"Result should end at page {max_page_number}"
183+
184+
185+
@pytest.mark.parametrize("concurrency_level", [2, 3])
186+
@pytest.mark.parametrize("allow_failed", [True, False])
187+
@pytest.mark.parametrize(
188+
("filename", "expected_ok", "strategy"),
189+
[
190+
("_sample_docs/list-item-example-1.pdf", True, "fast"), # 1 page
191+
("_sample_docs/layout-parser-paper-fast.pdf", True, "fast"), # 2 pages
192+
("_sample_docs/layout-parser-paper.pdf", True, shared.Strategy.HI_RES), # 16 pages
193+
],
194+
)
195+
def test_integration_split_pdf_strict_mode(
196+
concurrency_level: int,
197+
allow_failed: bool,
198+
filename: str,
199+
expected_ok: bool,
200+
strategy: shared.Strategy,
201+
caplog
202+
):
203+
"""Test strict mode (allow failed = False) for split_pdf."""
204+
try:
205+
response = requests.get("http://localhost:8000/general/docs")
206+
assert response.status_code == 200, "The unstructured-api is not running on localhost:8000"
207+
except requests.exceptions.ConnectionError:
208+
assert False, "The unstructured-api is not running on localhost:8000"
209+
210+
client = UnstructuredClient(api_key_auth=FAKE_KEY, server_url="localhost:8000")
211+
212+
with open(filename, "rb") as f:
213+
files = shared.Files(
214+
content=f.read(),
215+
file_name=filename,
216+
)
217+
218+
if not expected_ok:
219+
# Appending .pdf to the filename fools the first stage of filetype detection, simulating a decoding error
220+
files.file_name += ".pdf"
221+
222+
req = shared.PartitionParameters(
223+
files=files,
224+
strategy=strategy,
225+
languages=["eng"],
226+
split_pdf_page=True,
227+
split_pdf_concurrency_level=concurrency_level,
228+
split_pdf_allow_failed=allow_failed,
229+
)
230+
231+
try:
232+
resp_split = client.general.partition(req)
233+
except (HTTPValidationError, AttributeError) as exc:
234+
if not expected_ok:
235+
assert "The file does not appear to be a valid PDF." in caplog.text
236+
assert "File does not appear to be a valid PDF" in str(exc)
237+
return
238+
else:
239+
assert exc is None
240+
241+
req.split_pdf_page = False
242+
resp_single = client.general.partition(req)
243+
244+
assert len(resp_split.elements) == len(resp_single.elements)
245+
assert resp_split.content_type == resp_single.content_type
246+
assert resp_split.status_code == resp_single.status_code
247+
248+
diff = DeepDiff(
249+
t1=resp_split.elements,
250+
t2=resp_single.elements,
251+
exclude_regex_paths=[
252+
r"root\[\d+\]\['metadata'\]\['parent_id'\]",
253+
],
254+
)
255+
assert len(diff) == 0

Diff for: _test_unstructured_client/unit/test_split_pdf_hook.py

+118-22
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
1+
import asyncio
12
import io
23
import logging
3-
from concurrent.futures import Future
4+
from asyncio import Task
5+
from collections import Counter
6+
from typing import Coroutine
47

58
import pytest
69
import requests
710
from requests_toolbelt import MultipartDecoder, MultipartEncoder
11+
812
from unstructured_client._hooks.custom import form_utils, pdf_utils, request_utils
913
from unstructured_client._hooks.custom.form_utils import (
1014
PARTITION_FORM_CONCURRENCY_LEVEL_KEY,
@@ -18,7 +22,7 @@
1822
MAX_PAGES_PER_SPLIT,
1923
MIN_PAGES_PER_SPLIT,
2024
SplitPdfHook,
21-
get_optimal_split_size,
25+
get_optimal_split_size, run_tasks,
2226
)
2327
from unstructured_client.models import shared
2428

@@ -224,7 +228,6 @@ def test_unit_parse_form_data():
224228
b"--boundary--\r\n"
225229
)
226230

227-
228231
decoded_data = MultipartDecoder(
229232
test_form_data,
230233
"multipart/form-data; boundary=boundary",
@@ -361,22 +364,22 @@ def test_get_optimal_split_size(num_pages, concurrency_level, expected_split_siz
361364
({}, DEFAULT_CONCURRENCY_LEVEL), # no value
362365
({"split_pdf_concurrency_level": 10}, 10), # valid number
363366
(
364-
# exceeds max value
365-
{"split_pdf_concurrency_level": f"{MAX_CONCURRENCY_LEVEL+1}"},
366-
MAX_CONCURRENCY_LEVEL,
367+
# exceeds max value
368+
{"split_pdf_concurrency_level": f"{MAX_CONCURRENCY_LEVEL + 1}"},
369+
MAX_CONCURRENCY_LEVEL,
367370
),
368371
({"split_pdf_concurrency_level": -3}, DEFAULT_CONCURRENCY_LEVEL), # negative value
369372
],
370373
)
371374
def test_unit_get_split_pdf_concurrency_level_returns_valid_number(form_data, expected_result):
372375
assert (
373-
form_utils.get_split_pdf_concurrency_level_param(
374-
form_data,
375-
key=PARTITION_FORM_CONCURRENCY_LEVEL_KEY,
376-
fallback_value=DEFAULT_CONCURRENCY_LEVEL,
377-
max_allowed=MAX_CONCURRENCY_LEVEL,
378-
)
379-
== expected_result
376+
form_utils.get_split_pdf_concurrency_level_param(
377+
form_data,
378+
key=PARTITION_FORM_CONCURRENCY_LEVEL_KEY,
379+
fallback_value=DEFAULT_CONCURRENCY_LEVEL,
380+
max_allowed=MAX_CONCURRENCY_LEVEL,
381+
)
382+
== expected_result
380383
)
381384

382385

@@ -404,16 +407,16 @@ def test_unit_get_starting_page_number(starting_page_number, expected_result):
404407
@pytest.mark.parametrize(
405408
"page_range, expected_result",
406409
[
407-
(["1", "14"], (1, 14)), # Valid range, start on boundary
408-
(["4", "16"], (4, 16)), # Valid range, end on boundary
409-
(None, (1, 20)), # Range not specified, defaults to full range
410+
(["1", "14"], (1, 14)), # Valid range, start on boundary
411+
(["4", "16"], (4, 16)), # Valid range, end on boundary
412+
(None, (1, 20)), # Range not specified, defaults to full range
410413
(["2", "5"], (2, 5)), # Valid range within boundary
411-
(["2", "100"], None), # End page too high
412-
(["50", "100"], None), # Range too high
413-
(["-50", "5"], None), # Start page too low
414-
(["-50", "-2"], None), # Range too low
415-
(["10", "2"], None), # Backwards range
416-
(["foo", "foo"], None), # Parse error
414+
(["2", "100"], None), # End page too high
415+
(["50", "100"], None), # Range too high
416+
(["-50", "5"], None), # Start page too low
417+
(["-50", "-2"], None), # Range too low
418+
(["10", "2"], None), # Backwards range
419+
(["foo", "foo"], None), # Parse error
417420
],
418421
)
419422
def test_unit_get_page_range_returns_valid_range(page_range, expected_result):
@@ -432,3 +435,96 @@ def test_unit_get_page_range_returns_valid_range(page_range, expected_result):
432435
return
433436

434437
assert result == expected_result
438+
439+
440+
async def _request_mock(fails: bool, content: str) -> requests.Response:
441+
response = requests.Response()
442+
response.status_code = 500 if fails else 200
443+
response._content = content.encode()
444+
return response
445+
446+
447+
@pytest.mark.parametrize(
448+
("allow_failed", "tasks", "expected_responses"), [
449+
pytest.param(
450+
True, [
451+
_request_mock(fails=False, content="1"),
452+
_request_mock(fails=False, content="2"),
453+
_request_mock(fails=False, content="3"),
454+
_request_mock(fails=False, content="4"),
455+
],
456+
["1", "2", "3", "4"],
457+
id="no failures, fails allowed"
458+
),
459+
pytest.param(
460+
True, [
461+
_request_mock(fails=False, content="1"),
462+
_request_mock(fails=True, content="2"),
463+
_request_mock(fails=False, content="3"),
464+
_request_mock(fails=True, content="4"),
465+
],
466+
["1", "2", "3", "4"],
467+
id="failures, fails allowed"
468+
),
469+
pytest.param(
470+
False, [
471+
_request_mock(fails=True, content="failure"),
472+
_request_mock(fails=False, content="2"),
473+
_request_mock(fails=True, content="failure"),
474+
_request_mock(fails=False, content="4"),
475+
],
476+
["failure"],
477+
id="failures, fails disallowed"
478+
),
479+
pytest.param(
480+
False, [
481+
_request_mock(fails=False, content="1"),
482+
_request_mock(fails=False, content="2"),
483+
_request_mock(fails=False, content="3"),
484+
_request_mock(fails=False, content="4"),
485+
],
486+
["1", "2", "3", "4"],
487+
id="no failures, fails disallowed"
488+
),
489+
]
490+
)
491+
@pytest.mark.asyncio
492+
async def test_unit_disallow_failed_coroutines(
493+
allow_failed: bool,
494+
tasks: list[Task],
495+
expected_responses: list[str],
496+
):
497+
"""Test disallow failed coroutines method properly sets the flag to False."""
498+
responses = await run_tasks(tasks, allow_failed=allow_failed)
499+
response_contents = [response[1].content.decode() for response in responses]
500+
assert response_contents == expected_responses
501+
502+
503+
async def _fetch_canceller_error(fails: bool, content: str, cancelled_counter: Counter):
504+
try:
505+
if not fails:
506+
await asyncio.sleep(0.01)
507+
print("Doesn't fail")
508+
else:
509+
print("Fails")
510+
return await _request_mock(fails=fails, content=content)
511+
except asyncio.CancelledError:
512+
cancelled_counter.update(["cancelled"])
513+
print(cancelled_counter["cancelled"])
514+
print("Cancelled")
515+
516+
517+
@pytest.mark.asyncio
518+
async def test_remaining_tasks_cancelled_when_fails_disallowed():
519+
cancelled_counter = Counter()
520+
tasks = [
521+
_fetch_canceller_error(fails=True, content="1", cancelled_counter=cancelled_counter),
522+
*[_fetch_canceller_error(fails=False, content=f"{i}", cancelled_counter=cancelled_counter)
523+
for i in range(2, 200)],
524+
]
525+
526+
await run_tasks(tasks, allow_failed=False)
527+
# give some time to actually cancel the tasks in background
528+
await asyncio.sleep(1)
529+
print("Cancelled amount: ", cancelled_counter["cancelled"])
530+
assert len(tasks) > cancelled_counter["cancelled"] > 0

0 commit comments

Comments
 (0)