@@ -1,10 +1,14 @@
+import asyncio
 import io
 import logging
-from concurrent.futures import Future
+from asyncio import Task
+from collections import Counter
+from typing import Coroutine
 
 import pytest
 import requests
 from requests_toolbelt import MultipartDecoder, MultipartEncoder
+
 from unstructured_client._hooks.custom import form_utils, pdf_utils, request_utils
 from unstructured_client._hooks.custom.form_utils import (
     PARTITION_FORM_CONCURRENCY_LEVEL_KEY,
@@ -18,7 +22,7 @@
     MAX_PAGES_PER_SPLIT,
     MIN_PAGES_PER_SPLIT,
     SplitPdfHook,
-    get_optimal_split_size,
+    get_optimal_split_size, run_tasks,
 )
 from unstructured_client.models import shared
 
@@ -224,7 +228,6 @@ def test_unit_parse_form_data():
         b"--boundary--\r\n"
     )
 
-
     decoded_data = MultipartDecoder(
         test_form_data,
         "multipart/form-data; boundary=boundary",
@@ -361,22 +364,22 @@ def test_get_optimal_split_size(num_pages, concurrency_level, expected_split_size):
         ({}, DEFAULT_CONCURRENCY_LEVEL),  # no value
         ({"split_pdf_concurrency_level": 10}, 10),  # valid number
         (
-        # exceeds max value
-        {"split_pdf_concurrency_level": f"{MAX_CONCURRENCY_LEVEL + 1}"},
-        MAX_CONCURRENCY_LEVEL,
+            # exceeds max value
+            {"split_pdf_concurrency_level": f"{MAX_CONCURRENCY_LEVEL + 1}"},
+            MAX_CONCURRENCY_LEVEL,
         ),
         ({"split_pdf_concurrency_level": -3}, DEFAULT_CONCURRENCY_LEVEL),  # negative value
     ],
 )
 def test_unit_get_split_pdf_concurrency_level_returns_valid_number(form_data, expected_result):
     assert (
-            form_utils.get_split_pdf_concurrency_level_param(
-                form_data,
-                key=PARTITION_FORM_CONCURRENCY_LEVEL_KEY,
-                fallback_value=DEFAULT_CONCURRENCY_LEVEL,
-                max_allowed=MAX_CONCURRENCY_LEVEL,
-            )
-            == expected_result
+        form_utils.get_split_pdf_concurrency_level_param(
+            form_data,
+            key=PARTITION_FORM_CONCURRENCY_LEVEL_KEY,
+            fallback_value=DEFAULT_CONCURRENCY_LEVEL,
+            max_allowed=MAX_CONCURRENCY_LEVEL,
+        )
+        == expected_result
     )
 
 
@@ -404,16 +407,16 @@ def test_unit_get_starting_page_number(starting_page_number, expected_result):
 @pytest.mark.parametrize(
     "page_range, expected_result",
     [
-        (["1", "14"], (1, 14)), # Valid range, start on boundary
-        (["4", "16"], (4, 16)), # Valid range, end on boundary
-        (None, (1, 20)), # Range not specified, defaults to full range
+        (["1", "14"], (1, 14)),  # Valid range, start on boundary
+        (["4", "16"], (4, 16)),  # Valid range, end on boundary
+        (None, (1, 20)),  # Range not specified, defaults to full range
         (["2", "5"], (2, 5)),  # Valid range within boundary
-        (["2", "100"], None), # End page too high
-        (["50", "100"], None), # Range too high
-        (["-50", "5"], None), # Start page too low
-        (["-50", "-2"], None), # Range too low
-        (["10", "2"], None), # Backwards range
-        (["foo", "foo"], None), # Parse error
+        (["2", "100"], None),  # End page too high
+        (["50", "100"], None),  # Range too high
+        (["-50", "5"], None),  # Start page too low
+        (["-50", "-2"], None),  # Range too low
+        (["10", "2"], None),  # Backwards range
+        (["foo", "foo"], None),  # Parse error
     ],
 )
 def test_unit_get_page_range_returns_valid_range(page_range, expected_result):
@@ -432,3 +435,96 @@ def test_unit_get_page_range_returns_valid_range(page_range, expected_result):
         return
 
     assert result == expected_result
+
+
+async def _request_mock(fails: bool, content: str) -> requests.Response:
+    response = requests.Response()
+    response.status_code = 500 if fails else 200
+    response._content = content.encode()
+    return response
+
+
+@pytest.mark.parametrize(
+    ("allow_failed", "tasks", "expected_responses"), [
+        pytest.param(
+            True, [
+                _request_mock(fails=False, content="1"),
+                _request_mock(fails=False, content="2"),
+                _request_mock(fails=False, content="3"),
+                _request_mock(fails=False, content="4"),
+            ],
+            ["1", "2", "3", "4"],
+            id="no failures, fails allowed"
+        ),
+        pytest.param(
+            True, [
+                _request_mock(fails=False, content="1"),
+                _request_mock(fails=True, content="2"),
+                _request_mock(fails=False, content="3"),
+                _request_mock(fails=True, content="4"),
+            ],
+            ["1", "2", "3", "4"],
+            id="failures, fails allowed"
+        ),
+        pytest.param(
+            False, [
+                _request_mock(fails=True, content="failure"),
+                _request_mock(fails=False, content="2"),
+                _request_mock(fails=True, content="failure"),
+                _request_mock(fails=False, content="4"),
+            ],
+            ["failure"],
+            id="failures, fails disallowed"
+        ),
+        pytest.param(
+            False, [
+                _request_mock(fails=False, content="1"),
+                _request_mock(fails=False, content="2"),
+                _request_mock(fails=False, content="3"),
+                _request_mock(fails=False, content="4"),
+            ],
+            ["1", "2", "3", "4"],
+            id="no failures, fails disallowed"
+        ),
+    ]
+)
+@pytest.mark.asyncio
+async def test_unit_disallow_failed_coroutines(
+    allow_failed: bool,
+    tasks: list[Coroutine],
+    expected_responses: list[str],
+):
+    """Test that run_tasks honors the allow_failed flag."""
+    responses = await run_tasks(tasks, allow_failed=allow_failed)
+    response_contents = [response[1].content.decode() for response in responses]
+    assert response_contents == expected_responses
+
+
+async def _fetch_canceller_error(fails: bool, content: str, cancelled_counter: Counter):
+    try:
+        if not fails:
+            await asyncio.sleep(0.01)
+            print("Doesn't fail")
+        else:
+            print("Fails")
+        return await _request_mock(fails=fails, content=content)
+    except asyncio.CancelledError:
+        cancelled_counter.update(["cancelled"])
+        print(cancelled_counter["cancelled"])
+        print("Cancelled")
+
+
+@pytest.mark.asyncio
+async def test_remaining_tasks_cancelled_when_fails_disallowed():
+    cancelled_counter = Counter()
+    tasks = [
+        _fetch_canceller_error(fails=True, content="1", cancelled_counter=cancelled_counter),
+        *[_fetch_canceller_error(fails=False, content=f"{i}", cancelled_counter=cancelled_counter)
+          for i in range(2, 200)],
+    ]
+
+    await run_tasks(tasks, allow_failed=False)
+    # give the event loop some time to actually cancel the remaining tasks
+    await asyncio.sleep(1)
+    print("Cancelled amount: ", cancelled_counter["cancelled"])
+    assert len(tasks) > cancelled_counter["cancelled"] > 0
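
For context, the new tests above pin down a contract for run_tasks rather than its
internals: it awaits the given coroutines, returns pairs whose second element is the
response (hence the response[1].content access), and, when allow_failed=False, returns
only the first failed response and cancels the tasks still pending. The sketch below is
a minimal approximation of that contract under those assumptions; it is not the actual
implementation from split_pdf_hook.py, and the (index, response) pairing and the name
run_tasks_sketch are illustrative only.

# Sketch only -- approximates the behavior asserted by the tests above,
# not the real unstructured_client implementation.
import asyncio

async def run_tasks_sketch(coroutines, allow_failed: bool):
    # Schedule everything up front so later tasks can be cancelled on failure.
    tasks = [asyncio.create_task(coro) for coro in coroutines]
    if allow_failed:
        # Await all tasks; failed responses are returned alongside successes.
        responses = await asyncio.gather(*tasks)
        return list(enumerate(responses, start=1))
    results = []
    for index, task in enumerate(tasks, start=1):
        response = await task
        if response.status_code != 200:
            # First failure: cancel whatever is still running and
            # surface only the failed response.
            for pending in tasks:
                if not pending.done():
                    pending.cancel()
            return [(index, response)]
        results.append((index, response))
    return results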