Skip to content

Commit b849772

Browse files
authored
Merge pull request #13 from mindsdb/feat/add-timeout-explicitaly
feat: add timeout args
2 parents 38650dc + d27fce6 commit b849772

File tree

2 files changed

+36
-31
lines changed

2 files changed

+36
-31
lines changed

src/aipdf/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from .ocr import ocr, ocr_async
22

3-
__version__ = "0.0.6.2"
3+
__version__ = "0.0.6.3"
44

55
__all__ = ["__version__", "ocr", "ocr_async"]

src/aipdf/ocr.py

Lines changed: 35 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
import os
77

88
import fitz
9+
from httpx import Timeout
910
from openai import OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI
1011

11-
1212
DEFAULT_PROMPT = """
1313
Extract the full markdown text from the given image, following these guidelines:
1414
- Respond only with markdown, no additional commentary.
@@ -22,14 +22,15 @@
2222
DEFAULT_GAP_THRESHOLD = 10 # 10 points
2323

2424

25-
def get_openai_client(api_key=None, base_url='https://api.openai.com/v1', is_async=False, **kwargs):
25+
def get_openai_client(api_key=None, base_url='https://api.openai.com/v1', is_async=False, timeout=Timeout(None), **kwargs):
2626
"""
2727
Get an OpenAI client instance.
2828
2929
Args:
3030
api_key (str): The OpenAI API key.
3131
base_url (str): The base URL for the OpenAI API.
3232
is_async (bool): Whether to create an asynchronous client.
33+
timeout (Timeout): Timeout for the OpenAI API calls.
3334
**kwargs: Additional keyword arguments.
3435
3536
Returns:
@@ -40,32 +41,32 @@ def get_openai_client(api_key=None, base_url='https://api.openai.com/v1', is_asy
4041

4142
if not api_key:
4243
raise ValueError("API key is required. Please provide it as an argument or set the AIPDF_API_KEY environment variable.")
43-
44+
4445
if base_url and "openai.azure.com" in base_url:
4546
if is_async:
4647
return AsyncAzureOpenAI(api_key=api_key, azure_endpoint=base_url, **kwargs)
4748
else:
4849
return AzureOpenAI(api_key=api_key, azure_endpoint=base_url, **kwargs)
4950

5051
if is_async:
51-
return AsyncOpenAI(api_key=api_key, base_url=base_url, **kwargs)
52+
return AsyncOpenAI(api_key=api_key, base_url=base_url, timeout=timeout, **kwargs)
5253
else:
53-
return OpenAI(api_key=api_key, base_url=base_url, **kwargs)
54+
return OpenAI(api_key=api_key, base_url=base_url, timeout=timeout, **kwargs)
5455

5556

5657
def _prepare_image_messages(file_object, prompt):
5758
"""
5859
Helper function to prepare messages for OpenAI API call.
59-
60+
6061
Args:
6162
file_object (io.BytesIO): The image file object.
6263
prompt (str): The prompt to send to the API.
63-
64+
6465
Returns:
6566
list: The messages list for the API call.
6667
"""
6768
base64_image = base64.b64encode(file_object.read()).decode('utf-8')
68-
69+
6970
return [
7071
{
7172
"role": "user",
@@ -88,42 +89,42 @@ def _prepare_image_messages(file_object, prompt):
8889
def _validate_and_extract_content(response):
8990
"""
9091
Helper function to validate OpenAI API response and extract content.
91-
92+
9293
Args:
9394
response: The response object from OpenAI API.
94-
95+
9596
Returns:
9697
str or None: The extracted content, or None if validation fails.
9798
"""
9899
# Validate the response structure before accessing choices
99100
if not response:
100101
logging.error(f"Received empty response from OpenAI API: {response}")
101102
return None
102-
103+
103104
if not hasattr(response, 'choices') or not response.choices:
104105
logging.error(f"Response does not contain choices or choices is empty. Response: {response}")
105106
return None
106-
107+
107108
if len(response.choices) == 0:
108109
logging.error(f"Response choices list is empty. Response: {response}")
109110
return None
110-
111+
111112
first_choice = response.choices[0]
112113
if not hasattr(first_choice, 'message') or not first_choice.message:
113114
logging.error(f"Response choice does not contain message. First choice: {first_choice}")
114115
return None
115-
116+
116117
if not hasattr(first_choice.message, 'content'):
117118
logging.error(f"Response message does not contain content. Message: {first_choice.message}")
118119
return None
119-
120+
120121
markdown_content = first_choice.message.content
121-
122+
122123
# Additional check for empty or None content
123124
if not markdown_content:
124125
logging.warning(f"Response content is empty or None. Content: {repr(markdown_content)}")
125126
return None
126-
127+
127128
return markdown_content
128129

129130

@@ -152,7 +153,7 @@ def image_to_markdown(file_object, client, model="gpt-4o", prompt=DEFAULT_PROMP
152153
)
153154

154155
markdown_content = _validate_and_extract_content(response)
155-
156+
156157
if markdown_content:
157158
logging.debug("Page processed successfully")
158159
return markdown_content
@@ -163,7 +164,7 @@ def image_to_markdown(file_object, client, model="gpt-4o", prompt=DEFAULT_PROMP
163164
except Exception as e:
164165
logging.error(f"An error occurred while processing the image: {e}")
165166
return None
166-
167+
167168

168169
async def image_to_markdown_async(file_object, client, model="gpt-4o", prompt=DEFAULT_PROMPT):
169170
"""
@@ -190,7 +191,7 @@ async def image_to_markdown_async(file_object, client, model="gpt-4o", prompt=DE
190191
)
191192

192193
markdown_content = _validate_and_extract_content(response)
193-
194+
194195
if markdown_content:
195196
logging.debug("Page processed successfully")
196197
return markdown_content
@@ -339,16 +340,17 @@ def process_pages(pdf_file, pages_list=None, use_llm_for_all=False, drawing_area
339340

340341

341342
def ocr(
342-
pdf_file,
343+
pdf_file,
343344
api_key = None,
344-
model="gpt-4o",
345-
base_url='https://api.openai.com/v1',
346-
prompt=DEFAULT_PROMPT,
345+
model="gpt-4o",
346+
base_url='https://api.openai.com/v1',
347+
prompt=DEFAULT_PROMPT,
347348
pages_list=None,
348349
use_llm_for_all=False,
349350
drawing_area_threshold=DEFAULT_DRAWING_AREA_THRESHOLD,
350351
gap_threshold=DEFAULT_GAP_THRESHOLD,
351352
logging_level=logging.INFO,
353+
timeout=Timeout(None),
352354
**kwargs
353355
):
354356
"""
@@ -367,6 +369,7 @@ def ocr(
367369
drawing_area_threshold (float): Minimum fraction of page area that drawings must cover to be visual.
368370
gap_threshold (int): The threshold for vertical gaps between text blocks.
369371
logging_level (int): The logging level. Defaults to logging.INFO.
372+
timeout (Timeout): Timeout for the OpenAI API calls.
370373
**kwargs: Additional keyword arguments.
371374
372375
Returns:
@@ -375,8 +378,8 @@ def ocr(
375378
# Set up logging
376379
logging.basicConfig(level=logging_level, format='%(asctime)s - %(levelname)s - %(message)s')
377380

378-
client = get_openai_client(api_key=api_key, base_url=base_url, **kwargs)
379-
381+
client = get_openai_client(api_key=api_key, base_url=base_url, timeout=timeout, **kwargs)
382+
380383
# Identify the maximum number of workers for parallel processing
381384
max_workers = os.getenv("AIPDF_MAX_CONCURRENT_REQUESTS", None)
382385
if max_workers:
@@ -400,9 +403,9 @@ def ocr(
400403
# Process each image file in parallel
401404
with executor:
402405
# Submit tasks for each image file
403-
future_to_page = {executor.submit(image_to_markdown, img_file, client, model, prompt): page_num
406+
future_to_page = {executor.submit(image_to_markdown, img_file, client, model, prompt): page_num
404407
for page_num, img_file in image_files.items()}
405-
408+
406409
# Collect results as they complete
407410
for future in concurrent.futures.as_completed(future_to_page):
408411
page_num = future_to_page[future]
@@ -420,7 +423,7 @@ def ocr(
420423

421424

422425
async def ocr_async(
423-
pdf_file,
426+
pdf_file,
424427
api_key = None,
425428
model="gpt-4o",
426429
base_url='https://api.openai.com/v1',
@@ -430,6 +433,7 @@ async def ocr_async(
430433
drawing_area_threshold=DEFAULT_DRAWING_AREA_THRESHOLD,
431434
gap_threshold=DEFAULT_GAP_THRESHOLD,
432435
logging_level=logging.INFO,
436+
timeout=Timeout(None),
433437
**kwargs
434438
):
435439
"""
@@ -448,6 +452,7 @@ async def ocr_async(
448452
drawing_area_threshold (float): Minimum fraction of page area that drawings must cover to be visual.
449453
gap_threshold (int): The threshold for vertical gaps between text blocks.
450454
logging_level (int): The logging level. Defaults to logging.INFO.
455+
timeout (Timeout): Timeout for the OpenAI API calls.
451456
**kwargs: Additional keyword arguments.
452457
453458
Returns:
@@ -456,7 +461,7 @@ async def ocr_async(
456461
# Set up logging
457462
logging.basicConfig(level=logging_level, format='%(asctime)s - %(levelname)s - %(message)s')
458463

459-
client = get_openai_client(api_key=api_key, base_url=base_url, is_async=True, **kwargs)
464+
client = get_openai_client(api_key=api_key, base_url=base_url, is_async=True, timeout=timeout, **kwargs)
460465

461466
# Set up a semaphore for limiting concurrent requests if specified
462467
semaphore = None

0 commit comments

Comments
 (0)