Skip to content

Commit f7c1c94

Browse files
feat: improved memory management during long split pdf processing (#202)
This PR: - improves the memory management by freeing the memory early when processing larger files, thus holding the data in memory for a long time - instead, it uses tempfiles to keep: - pdf_chunks for split_pdf_page functionality - partial response elements as json files
1 parent 25010eb commit f7c1c94

File tree

10 files changed

+959
-458
lines changed

10 files changed

+959
-458
lines changed

Diff for: _test_unstructured_client/integration/test_decorators.py

+84
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
from __future__ import annotations
22

3+
import tempfile
4+
from pathlib import Path
5+
36
import httpx
47
import json
58
import pytest
@@ -102,6 +105,87 @@ def test_integration_split_pdf_has_same_output_as_non_split(
102105
)
103106
assert len(diff) == 0
104107

108+
@pytest.mark.parametrize( ("filename", "expected_ok", "strategy"), [
109+
("_sample_docs/layout-parser-paper.pdf", True, "hi_res"), # 16
110+
]# pages
111+
)
112+
@pytest.mark.parametrize( ("use_caching", "cache_dir"), [
113+
(True, None), # Use default cache dir
114+
(True, Path(tempfile.gettempdir()) / "test_integration_unstructured_client1"), # Use custom cache dir
115+
(False, None), # Don't use caching
116+
(False, Path(tempfile.gettempdir()) / "test_integration_unstructured_client2"), # Don't use caching, use custom cache dir
117+
])
118+
def test_integration_split_pdf_with_caching(
119+
filename: str, expected_ok: bool, strategy: str, use_caching: bool,
120+
cache_dir: Path | None
121+
):
122+
try:
123+
response = requests.get("http://localhost:8000/general/docs")
124+
assert response.status_code == 200, "The unstructured-api is not running on localhost:8000"
125+
except requests.exceptions.ConnectionError:
126+
assert False, "The unstructured-api is not running on localhost:8000"
127+
128+
client = UnstructuredClient(api_key_auth=FAKE_KEY, server_url="localhost:8000")
129+
130+
with open(filename, "rb") as f:
131+
files = shared.Files(
132+
content=f.read(),
133+
file_name=filename,
134+
)
135+
136+
if not expected_ok:
137+
# This will append .pdf to filename to fool first line of filetype detection, to simulate decoding error
138+
files.file_name += ".pdf"
139+
140+
parameters = shared.PartitionParameters(
141+
files=files,
142+
strategy=strategy,
143+
languages=["eng"],
144+
split_pdf_page=True,
145+
split_pdf_cache_tmp_data=use_caching,
146+
split_pdf_cache_dir=cache_dir,
147+
)
148+
149+
req = operations.PartitionRequest(
150+
partition_parameters=parameters
151+
)
152+
153+
try:
154+
resp_split = client.general.partition(request=req)
155+
except (HTTPValidationError, AttributeError) as exc:
156+
if not expected_ok:
157+
assert "File does not appear to be a valid PDF" in str(exc)
158+
return
159+
else:
160+
assert exc is None
161+
162+
parameters.split_pdf_page = False
163+
164+
req = operations.PartitionRequest(
165+
partition_parameters=parameters
166+
)
167+
168+
resp_single = client.general.partition(request=req)
169+
170+
assert len(resp_split.elements) == len(resp_single.elements)
171+
assert resp_split.content_type == resp_single.content_type
172+
assert resp_split.status_code == resp_single.status_code
173+
174+
diff = DeepDiff(
175+
t1=resp_split.elements,
176+
t2=resp_single.elements,
177+
exclude_regex_paths=[
178+
r"root\[\d+\]\['metadata'\]\['parent_id'\]",
179+
r"root\[\d+\]\['element_id'\]",
180+
],
181+
)
182+
assert len(diff) == 0
183+
184+
# make sure the cache dir was cleaned if passed explicitly
185+
if cache_dir:
186+
assert not Path(cache_dir).exists()
187+
188+
105189

106190
def test_integration_split_pdf_for_file_with_no_name():
107191
"""

Diff for: _test_unstructured_client/unit/test_request_utils.py

+72
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# Get unit tests for request_utils.py module
2+
import httpx
3+
import pytest
4+
5+
from unstructured_client._hooks.custom.request_utils import create_pdf_chunk_request_params, get_multipart_stream_fields
6+
from unstructured_client.models import shared
7+
8+
9+
# make the above test using @pytest.mark.parametrize
10+
@pytest.mark.parametrize(("input_request", "expected"), [
11+
(httpx.Request("POST", "http://localhost:8000", data={}, headers={"Content-Type": "multipart/form-data"}), {}),
12+
(httpx.Request("POST", "http://localhost:8000", data={"hello": "world"}, headers={"Content-Type": "application/json"}), {}),
13+
(httpx.Request(
14+
"POST",
15+
"http://localhost:8000",
16+
data={"hello": "world"},
17+
files={"files": ("hello.pdf", b"hello", "application/pdf")},
18+
headers={"Content-Type": "multipart/form-data; boundary=----WebKitFormBoundary7MA4YWxkTrZu0gW"}),
19+
{
20+
"hello": "world",
21+
"files": {
22+
"content_type":"application/pdf",
23+
"filename": "hello.pdf",
24+
"file": b"hello",
25+
}
26+
}
27+
),
28+
])
29+
def test_get_multipart_stream_fields(input_request, expected):
30+
fields = get_multipart_stream_fields(input_request)
31+
assert fields == expected
32+
33+
def test_multipart_stream_fields_raises_value_error_when_filename_is_not_set():
34+
with pytest.raises(ValueError):
35+
get_multipart_stream_fields(httpx.Request(
36+
"POST",
37+
"http://localhost:8000",
38+
data={"hello": "world"},
39+
files={"files": ("", b"hello", "application/pdf")},
40+
headers={"Content-Type": "multipart/form-data; boundary=----WebKitFormBoundary7MA4YWxkTrZu0gW"}),
41+
)
42+
43+
@pytest.mark.parametrize(("input_form_data", "page_number", "expected_form_data"), [
44+
(
45+
{"hello": "world"},
46+
2,
47+
{"hello": "world", "split_pdf_page": "false", "starting_page_number": "2"}
48+
),
49+
(
50+
{"hello": "world", "split_pdf_page": "true"},
51+
2,
52+
{"hello": "world", "split_pdf_page": "false", "starting_page_number": "2"}
53+
),
54+
(
55+
{"hello": "world", "split_pdf_page": "true", "files": "dummy_file"},
56+
3,
57+
{"hello": "world", "split_pdf_page": "false", "starting_page_number": "3"}
58+
),
59+
(
60+
{"split_pdf_page_range[]": [1, 3], "hello": "world", "split_pdf_page": "true", "files": "dummy_file"},
61+
3,
62+
{"hello": "world", "split_pdf_page": "false", "starting_page_number": "3"}
63+
),
64+
(
65+
{"split_pdf_page_range": [1, 3], "hello": "world", "split_pdf_page": "true", "files": "dummy_file"},
66+
4,
67+
{"hello": "world", "split_pdf_page": "false", "starting_page_number": "4"}
68+
),
69+
])
70+
def test_create_pdf_chunk_request_params(input_form_data, page_number, expected_form_data):
71+
form_data = create_pdf_chunk_request_params(input_form_data, page_number)
72+
assert form_data == expected_form_data

Diff for: _test_unstructured_client/unit/test_split_pdf_hook.py

+40-85
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import logging
66
from asyncio import Task
77
from collections import Counter
8+
from functools import partial
89
from typing import Coroutine
910

1011
import httpx
@@ -53,29 +54,6 @@ async def example():
5354
assert hook.api_successful_responses.get(operation_id) is None
5455

5556

56-
def test_unit_prepare_request_payload():
57-
"""Test prepare request payload method properly sets split_pdf_page to 'false'
58-
and removes files key."""
59-
test_form_data = {
60-
"files": ("test_file.pdf", b"test_file_content"),
61-
"split_pdf_page": "true",
62-
"parameter_1": "value_1",
63-
"parameter_2": "value_2",
64-
"parameter_3": "value_3",
65-
}
66-
expected_form_data = {
67-
"split_pdf_page": "false",
68-
"parameter_1": "value_1",
69-
"parameter_2": "value_2",
70-
"parameter_3": "value_3",
71-
}
72-
73-
payload = request_utils.prepare_request_payload(test_form_data)
74-
75-
assert payload != test_form_data
76-
assert payload, expected_form_data
77-
78-
7957
def test_unit_prepare_request_headers():
8058
"""Test prepare request headers method properly removes Content-Type and Content-Length headers."""
8159
test_headers = {
@@ -224,61 +202,31 @@ def test_unit_parse_form_data_none_filename_error():
224202
form_utils.parse_form_data(decoded_data)
225203

226204

227-
def test_unit_is_pdf_valid_pdf():
228-
"""Test is pdf method returns True for valid pdf file with filename."""
205+
def test_unit_is_pdf_valid_pdf_when_passing_file_object():
206+
"""Test is pdf method returns pdf object for valid pdf file with filename."""
229207
filename = "_sample_docs/layout-parser-paper-fast.pdf"
230208

231209
with open(filename, "rb") as f:
232-
file = shared.Files(
233-
content=f.read(),
234-
file_name=filename,
235-
)
236-
237-
result = pdf_utils.is_pdf(file)
210+
result = pdf_utils.read_pdf(f)
238211

239-
assert result is True
212+
assert result is not None
240213

241214

242-
def test_unit_is_pdf_valid_pdf_without_file_extension():
243-
"""Test is pdf method returns True for file with valid pdf content without basing on file extension."""
215+
def test_unit_is_pdf_valid_pdf_when_passing_binary_content():
216+
"""Test is pdf method returns pdf object for file with valid pdf content"""
244217
filename = "_sample_docs/layout-parser-paper-fast.pdf"
245218

246219
with open(filename, "rb") as f:
247-
file = shared.Files(
248-
content=f.read(),
249-
file_name="uuid1234",
250-
)
251-
252-
result = pdf_utils.is_pdf(file)
253-
254-
assert result is True
255-
256-
257-
def test_unit_is_pdf_invalid_extension():
258-
"""Test is pdf method returns False for file with invalid extension."""
259-
file = shared.Files(content=b"txt_content", file_name="test_file.txt")
260-
261-
result = pdf_utils.is_pdf(file)
220+
result = pdf_utils.read_pdf(f.read())
262221

263-
assert result is False
222+
assert result is not None
264223

265224

266225
def test_unit_is_pdf_invalid_pdf():
267-
"""Test is pdf method returns False for file with invalid pdf content."""
268-
file = shared.Files(content=b"invalid_pdf_content", file_name="test_file.pdf")
269-
270-
result = pdf_utils.is_pdf(file)
271-
272-
assert result is False
273-
274-
275-
def test_unit_is_pdf_invalid_pdf_without_file_extension():
276-
"""Test is pdf method returns False for file with invalid pdf content without basing on file extension."""
277-
file = shared.Files(content=b"invalid_pdf_content", file_name="uuid1234")
278-
279-
result = pdf_utils.is_pdf(file)
226+
"""Test is pdf method returns False for file with invalid extension."""
227+
result = pdf_utils.read_pdf(b"txt_content")
280228

281-
assert result is False
229+
assert result is None
282230

283231

284232
def test_unit_get_starting_page_number_missing_key():
@@ -388,7 +336,10 @@ def test_unit_get_page_range_returns_valid_range(page_range, expected_result):
388336
assert result == expected_result
389337

390338

391-
async def _request_mock(fails: bool, content: str) -> requests.Response:
339+
async def _request_mock(
340+
async_client: httpx.AsyncClient, # not used by mock
341+
fails: bool,
342+
content: str) -> requests.Response:
392343
response = requests.Response()
393344
response.status_code = 500 if fails else 200
394345
response._content = content.encode()
@@ -399,40 +350,40 @@ async def _request_mock(fails: bool, content: str) -> requests.Response:
399350
("allow_failed", "tasks", "expected_responses"), [
400351
pytest.param(
401352
True, [
402-
_request_mock(fails=False, content="1"),
403-
_request_mock(fails=False, content="2"),
404-
_request_mock(fails=False, content="3"),
405-
_request_mock(fails=False, content="4"),
353+
partial(_request_mock, fails=False, content="1"),
354+
partial(_request_mock, fails=False, content="2"),
355+
partial(_request_mock, fails=False, content="3"),
356+
partial(_request_mock, fails=False, content="4"),
406357
],
407358
["1", "2", "3", "4"],
408359
id="no failures, fails allower"
409360
),
410361
pytest.param(
411362
True, [
412-
_request_mock(fails=False, content="1"),
413-
_request_mock(fails=True, content="2"),
414-
_request_mock(fails=False, content="3"),
415-
_request_mock(fails=True, content="4"),
363+
partial(_request_mock, fails=False, content="1"),
364+
partial(_request_mock, fails=True, content="2"),
365+
partial(_request_mock, fails=False, content="3"),
366+
partial(_request_mock, fails=True, content="4"),
416367
],
417368
["1", "2", "3", "4"],
418369
id="failures, fails allowed"
419370
),
420371
pytest.param(
421372
False, [
422-
_request_mock(fails=True, content="failure"),
423-
_request_mock(fails=False, content="2"),
424-
_request_mock(fails=True, content="failure"),
425-
_request_mock(fails=False, content="4"),
373+
partial(_request_mock, fails=True, content="failure"),
374+
partial(_request_mock, fails=False, content="2"),
375+
partial(_request_mock, fails=True, content="failure"),
376+
partial(_request_mock, fails=False, content="4"),
426377
],
427378
["failure"],
428379
id="failures, fails disallowed"
429380
),
430381
pytest.param(
431382
False, [
432-
_request_mock(fails=False, content="1"),
433-
_request_mock(fails=False, content="2"),
434-
_request_mock(fails=False, content="3"),
435-
_request_mock(fails=False, content="4"),
383+
partial(_request_mock, fails=False, content="1"),
384+
partial(_request_mock, fails=False, content="2"),
385+
partial(_request_mock, fails=False, content="3"),
386+
partial(_request_mock, fails=False, content="4"),
436387
],
437388
["1", "2", "3", "4"],
438389
id="no failures, fails disallowed"
@@ -451,14 +402,18 @@ async def test_unit_disallow_failed_coroutines(
451402
assert response_contents == expected_responses
452403

453404

454-
async def _fetch_canceller_error(fails: bool, content: str, cancelled_counter: Counter):
405+
async def _fetch_canceller_error(
406+
async_client: httpx.AsyncClient, # not used by mock
407+
fails: bool,
408+
content: str,
409+
cancelled_counter: Counter):
455410
try:
456411
if not fails:
457412
await asyncio.sleep(0.01)
458413
print("Doesn't fail")
459414
else:
460415
print("Fails")
461-
return await _request_mock(fails=fails, content=content)
416+
return await _request_mock(async_client=async_client, fails=fails, content=content)
462417
except asyncio.CancelledError:
463418
cancelled_counter.update(["cancelled"])
464419
print(cancelled_counter["cancelled"])
@@ -469,8 +424,8 @@ async def _fetch_canceller_error(fails: bool, content: str, cancelled_counter: C
469424
async def test_remaining_tasks_cancelled_when_fails_disallowed():
470425
cancelled_counter = Counter()
471426
tasks = [
472-
_fetch_canceller_error(fails=True, content="1", cancelled_counter=cancelled_counter),
473-
*[_fetch_canceller_error(fails=False, content=f"{i}", cancelled_counter=cancelled_counter)
427+
partial(_fetch_canceller_error, fails=True, content="1", cancelled_counter=cancelled_counter),
428+
*[partial(_fetch_canceller_error, fails=False, content=f"{i}", cancelled_counter=cancelled_counter)
474429
for i in range(2, 200)],
475430
]
476431

0 commit comments

Comments
 (0)