Skip to content

Commit ecae5f6

Browse files
authored
✨ add support for workflow polling (#323)
1 parent 9e4885a commit ecae5f6

File tree

13 files changed

+225
-58
lines changed

13 files changed

+225
-58
lines changed

mindee/client.py

Lines changed: 60 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from mindee.input import WorkflowOptions
88
from mindee.input.local_response import LocalResponse
99
from mindee.input.page_options import PageOptions
10+
from mindee.input.predict_options import AsyncPredictOptions, PredictOptions
1011
from mindee.input.sources.base_64_input import Base64Input
1112
from mindee.input.sources.bytes_input import BytesInput
1213
from mindee.input.sources.file_input import FileInput
@@ -123,14 +124,13 @@ def parse(
123124
page_options.on_min_pages,
124125
page_options.page_indexes,
125126
)
127+
options = PredictOptions(cropper, full_text, include_words)
126128
return self._make_request(
127129
product_class,
128130
input_source,
129131
endpoint,
130-
include_words,
132+
options,
131133
close_file,
132-
cropper,
133-
full_text,
134134
)
135135

136136
def enqueue(
@@ -143,6 +143,8 @@ def enqueue(
143143
cropper: bool = False,
144144
endpoint: Optional[Endpoint] = None,
145145
full_text: bool = False,
146+
workflow_id: Optional[str] = None,
147+
rag: bool = False,
146148
) -> AsyncPredictResponse:
147149
"""
148150
Enqueues a document to an asynchronous endpoint.
@@ -169,6 +171,11 @@ def enqueue(
169171
:param endpoint: For custom endpoints, an endpoint has to be given.
170172
171173
:param full_text: Whether to include the full OCR text response in compatible APIs.
174+
175+
:param workflow_id: Workflow ID.
176+
177+
:param rag: If set, will enable Retrieval-Augmented Generation.
178+
Only works if a valid ``workflow_id`` is set.
172179
"""
173180
if input_source is None:
174181
raise MindeeClientError("No input document provided.")
@@ -185,14 +192,15 @@ def enqueue(
185192
page_options.on_min_pages,
186193
page_options.page_indexes,
187194
)
195+
options = AsyncPredictOptions(
196+
cropper, full_text, include_words, workflow_id, rag
197+
)
188198
return self._predict_async(
189199
product_class,
190200
input_source,
201+
options,
191202
endpoint,
192-
include_words,
193203
close_file,
194-
cropper,
195-
full_text,
196204
)
197205

198206
def load_prediction(
@@ -246,8 +254,9 @@ def execute_workflow(
246254
:param input_source: The document/source file to use.
247255
Has to be created beforehand.
248256
:param workflow_id: ID of the workflow.
249-
:param page_options: If set, remove pages from the document as specified. This is done before sending the file\
250-
to the server. It is useful to avoid page limitations.
257+
:param page_options: If set, remove pages from the document as specified.
258+
This is done before sending the file to the server.
259+
It is useful to avoid page limitations.
251260
:param options: Options for the workflow.
252261
:return:
253262
"""
@@ -259,13 +268,11 @@ def execute_workflow(
259268
page_options.page_indexes,
260269
)
261270

262-
logger.debug("Sending document to workflow: %s", workflow_id)
263-
264271
if not options:
265272
options = WorkflowOptions(
266273
alias=None, priority=None, full_text=False, public_url=None, rag=False
267274
)
268-
275+
logger.debug("Sending document to workflow: %s", workflow_id)
269276
return self._send_to_workflow(GeneratedV1, input_source, workflow_id, options)
270277

271278
def _validate_async_params(
@@ -285,7 +292,7 @@ def _validate_async_params(
285292
if max_retries < min_retries:
286293
raise MindeeClientError(f"Cannot set retries to less than {min_retries}.")
287294

288-
def enqueue_and_parse(
295+
def enqueue_and_parse( # pylint: disable=too-many-locals
289296
self,
290297
product_class: Type[Inference],
291298
input_source: Union[LocalInputSource, UrlInputSource],
@@ -298,40 +305,51 @@ def enqueue_and_parse(
298305
delay_sec: float = 1.5,
299306
max_retries: int = 80,
300307
full_text: bool = False,
308+
workflow_id: Optional[str] = None,
309+
rag: bool = False,
301310
) -> AsyncPredictResponse:
302311
"""
303312
Enqueues to an asynchronous endpoint and automatically polls for a response.
304313
305-
:param product_class: The document class to use. The response object will be instantiated based on this\
306-
parameter.
314+
:param product_class: The document class to use.
315+
The response object will be instantiated based on this parameter.
307316
308-
:param input_source: The document/source file to use. Has to be created beforehand.
317+
:param input_source: The document/source file to use.
318+
Has to be created beforehand.
309319
310-
:param include_words: Whether to include the full text for each page. This performs a full OCR operation on\
311-
the server and will increase response time.
320+
:param include_words: Whether to include the full text for each page.
321+
This performs a full OCR operation on the server and will increase response time.
312322
313-
:param close_file: Whether to ``close()`` the file after parsing it. Set to ``False`` if you need to access\
314-
the file after this operation.
323+
:param close_file: Whether to ``close()`` the file after parsing it.
324+
Set to ``False`` if you need to access the file after this operation.
315325
316-
:param page_options: If set, remove pages from the document as specified. This is done before sending the file\
317-
to the server. It is useful to avoid page limitations.
326+
:param page_options: If set, remove pages from the document as specified.
327+
This is done before sending the file to the server.
328+
It is useful to avoid page limitations.
318329
319-
:param cropper: Whether to include cropper results for each page. This performs a cropping operation on the\
320-
server and will increase response time.
330+
:param cropper: Whether to include cropper results for each page.
331+
This performs a cropping operation on the server and will increase response time.
321332
322333
:param endpoint: For custom endpoints, an endpoint has to be given.
323334
324-
:param initial_delay_sec: Delay between each polling attempts This should not be shorter than 1 second.
335+
:param initial_delay_sec: Delay before the first polling attempt.
336+
This should not be shorter than 1 second.
325337
326-
:param delay_sec: Delay between each polling attempts This should not be shorter than 1 second.
338+
:param delay_sec: Delay between polling attempts.
339+
This should not be shorter than 1 second.
327340
328341
:param max_retries: Total amount of polling attempts.
329342
330343
:param full_text: Whether to include the full OCR text response in compatible APIs.
344+
345+
:param workflow_id: Workflow ID.
346+
347+
:param rag: If set, will enable Retrieval-Augmented Generation.
348+
Only works if a valid ``workflow_id`` is set.
331349
"""
332350
self._validate_async_params(initial_delay_sec, delay_sec, max_retries)
333351
if not endpoint:
334-
endpoint = self._initialize_ots_endpoint(product_class)
352+
endpoint = self._initialize_ots_endpoint(product_class=product_class)
335353
queue_result = self.enqueue(
336354
product_class,
337355
input_source,
@@ -341,6 +359,8 @@ def enqueue_and_parse(
341359
cropper,
342360
endpoint,
343361
full_text,
362+
workflow_id,
363+
rag,
344364
)
345365
logger.debug(
346366
"Successfully enqueued document with job id: %s", queue_result.job.id
@@ -406,15 +426,16 @@ def _make_request(
406426
product_class: Type[Inference],
407427
input_source: Union[LocalInputSource, UrlInputSource],
408428
endpoint: Endpoint,
409-
include_words: bool,
429+
options: PredictOptions,
410430
close_file: bool,
411-
cropper: bool,
412-
full_text: bool,
413431
) -> PredictResponse:
414432
response = endpoint.predict_req_post(
415-
input_source, include_words, close_file, cropper, full_text
433+
input_source,
434+
options.include_words,
435+
close_file,
436+
options.cropper,
437+
options.full_text,
416438
)
417-
418439
dict_response = response.json()
419440

420441
if not is_valid_sync_response(response):
@@ -423,28 +444,30 @@ def _make_request(
423444
str(product_class.endpoint_name),
424445
clean_response,
425446
)
426-
427447
return PredictResponse(product_class, dict_response)
428448

429449
def _predict_async(
430450
self,
431451
product_class: Type[Inference],
432452
input_source: Union[LocalInputSource, UrlInputSource],
453+
options: AsyncPredictOptions,
433454
endpoint: Optional[Endpoint] = None,
434-
include_words: bool = False,
435455
close_file: bool = True,
436-
cropper: bool = False,
437-
full_text: bool = False,
438456
) -> AsyncPredictResponse:
439457
"""Sends a document to the queue, and sends back an asynchronous predict response."""
440458
if input_source is None:
441459
raise MindeeClientError("No input document provided")
442460
if not endpoint:
443461
endpoint = self._initialize_ots_endpoint(product_class)
444462
response = endpoint.predict_async_req_post(
445-
input_source, include_words, close_file, cropper, full_text
463+
input_source=input_source,
464+
include_words=options.include_words,
465+
close_file=close_file,
466+
cropper=options.cropper,
467+
full_text=options.full_text,
468+
workflow_id=options.workflow_id,
469+
rag=options.rag,
446470
)
447-
448471
dict_response = response.json()
449472

450473
if not is_valid_async_response(response):

mindee/input/predict_options.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from typing import Optional
2+
3+
4+
class PredictOptions:
    """
    Options to pass to a prediction.

    Plain holder for the flags used when building a prediction request.

    :param cropper: Whether to include cropper results for each page.
    :param full_text: Whether to include the full OCR text response in compatible APIs.
    :param include_words: Whether to include the full text for each page.
    """

    def __init__(
        self,
        cropper: bool = False,
        full_text: bool = False,
        include_words: bool = False,
    ):
        self.cropper = cropper
        self.full_text = full_text
        self.include_words = include_words

    def __repr__(self) -> str:
        """Return a constructor-like summary of every option currently set."""
        # vars() is used rather than a fixed field list so that subclasses
        # adding attributes are shown without having to override this method.
        fields = ", ".join(f"{name}={value!r}" for name, value in vars(self).items())
        return f"{type(self).__name__}({fields})"
16+
17+
18+
class AsyncPredictOptions(PredictOptions):
    """
    Options to pass to an asynchronous prediction.

    Adds the workflow-related settings on top of the base prediction options.

    :param cropper: Whether to include cropper results for each page.
    :param full_text: Whether to include the full OCR text response in compatible APIs.
    :param include_words: Whether to include the full text for each page.
    :param workflow_id: Workflow ID.
    :param rag: If set, will enable Retrieval-Augmented Generation.
    """

    def __init__(
        self,
        cropper: bool = False,
        full_text: bool = False,
        include_words: bool = False,
        workflow_id: Optional[str] = None,
        rag: bool = False,
    ):
        # The shared flags are stored by the base class; only the
        # workflow-specific settings are kept here.
        super().__init__(
            cropper=cropper, full_text=full_text, include_words=include_words
        )
        self.workflow_id = workflow_id
        self.rag = rag

mindee/mindee_http/endpoint.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import json
2-
from typing import Union
2+
from typing import Optional, Union
33

44
import requests
55
from requests import Response
@@ -60,6 +60,8 @@ def predict_async_req_post(
6060
close_file: bool = True,
6161
cropper: bool = False,
6262
full_text: bool = False,
63+
workflow_id: Optional[str] = None,
64+
rag: bool = False,
6365
) -> requests.Response:
6466
"""
6567
Make an asynchronous request to POST a document for prediction.
@@ -69,10 +71,19 @@ def predict_async_req_post(
6971
:param close_file: Whether to `close()` the file after parsing it.
7072
:param cropper: Including Mindee cropping results.
7173
:param full_text: Whether to include the full OCR text response in compatible APIs.
74+
:param workflow_id: Workflow ID.
75+
:param rag: If set, will enable Retrieval-Augmented Generation.
7276
:return: requests response
7377
"""
7478
return self._custom_request(
75-
"predict_async", input_source, include_words, close_file, cropper, full_text
79+
"predict_async",
80+
input_source,
81+
include_words,
82+
close_file,
83+
cropper,
84+
full_text,
85+
workflow_id,
86+
rag,
7687
)
7788

7889
def _custom_request(
@@ -83,6 +94,8 @@ def _custom_request(
8394
close_file: bool = True,
8495
cropper: bool = False,
8596
full_text: bool = False,
97+
workflow_id: Optional[str] = None,
98+
rag: bool = False,
8699
):
87100
data = {}
88101
if include_words:
@@ -93,11 +106,18 @@ def _custom_request(
93106
params["full_text_ocr"] = "true"
94107
if cropper:
95108
params["cropper"] = "true"
109+
if rag:
110+
params["rag"] = "true"
111+
112+
if workflow_id:
113+
url = f"{self.settings.base_url}/workflows/{workflow_id}/{route}"
114+
else:
115+
url = f"{self.settings.url_root}/{route}"
96116

97117
if isinstance(input_source, UrlInputSource):
98118
data["document"] = input_source.url
99119
response = requests.post(
100-
f"{self.settings.url_root}/{route}",
120+
url=url,
101121
headers=self.settings.base_headers,
102122
data=data,
103123
params=params,
@@ -106,7 +126,7 @@ def _custom_request(
106126
else:
107127
files = {"document": input_source.read_contents(close_file)}
108128
response = requests.post(
109-
f"{self.settings.url_root}/{route}",
129+
url=url,
110130
files=files,
111131
headers=self.settings.base_headers,
112132
data=data,

mindee/parsing/common/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from mindee.parsing.common.page import Page
1616
from mindee.parsing.common.predict_response import PredictResponse
1717
from mindee.parsing.common.prediction import Prediction
18-
from mindee.parsing.common.string_dict import StringDict
1918
from mindee.parsing.common.summary_helper import (
2019
clean_out_string,
2120
format_for_display,

mindee/parsing/common/async_predict_response.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class AsyncPredictResponse(Generic[TypeInference], ApiResponse):
1515

1616
job: Job
1717
"""Job object link to the prediction. As long as it isn't complete, the prediction doesn't exist."""
18-
document: Optional[Document]
18+
document: Optional[Document] = None
1919

2020
def __init__(
2121
self, inference_type: Type[TypeInference], raw_response: StringDict

mindee/parsing/common/document.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ class Document(Generic[TypePrediction, TypePage]):
2929
"""Result of the base inference"""
3030
id: str
3131
"""Id of the document as sent back by the server"""
32-
extras: Optional[Extras]
32+
extras: Optional[Extras] = None
3333
"""Potential Extras fields sent back along the prediction"""
34-
ocr: Optional[Ocr]
34+
ocr: Optional[Ocr] = None
3535
"""Potential raw text results read by the OCR (limited feature)"""
3636
n_pages: int
3737
"""Amount of pages in the document"""

0 commit comments

Comments
 (0)