7
7
from mindee .input import WorkflowOptions
8
8
from mindee .input .local_response import LocalResponse
9
9
from mindee .input .page_options import PageOptions
10
+ from mindee .input .predict_options import AsyncPredictOptions , PredictOptions
10
11
from mindee .input .sources .base_64_input import Base64Input
11
12
from mindee .input .sources .bytes_input import BytesInput
12
13
from mindee .input .sources .file_input import FileInput
@@ -123,14 +124,13 @@ def parse(
123
124
page_options .on_min_pages ,
124
125
page_options .page_indexes ,
125
126
)
127
+ options = PredictOptions (cropper , full_text , include_words )
126
128
return self ._make_request (
127
129
product_class ,
128
130
input_source ,
129
131
endpoint ,
130
- include_words ,
132
+ options ,
131
133
close_file ,
132
- cropper ,
133
- full_text ,
134
134
)
135
135
136
136
def enqueue (
@@ -143,6 +143,8 @@ def enqueue(
143
143
cropper : bool = False ,
144
144
endpoint : Optional [Endpoint ] = None ,
145
145
full_text : bool = False ,
146
+ workflow_id : Optional [str ] = None ,
147
+ rag : bool = False ,
146
148
) -> AsyncPredictResponse :
147
149
"""
148
150
Enqueues a document to an asynchronous endpoint.
@@ -169,6 +171,11 @@ def enqueue(
169
171
:param endpoint: For custom endpoints, an endpoint has to be given.
170
172
171
173
:param full_text: Whether to include the full OCR text response in compatible APIs.
174
+
175
+ :param workflow_id: Workflow ID.
176
+
177
+ :param rag: If set, will enable Retrieval-Augmented Generation.
178
+ Only works if a valid ``workflow_id`` is set.
172
179
"""
173
180
if input_source is None :
174
181
raise MindeeClientError ("No input document provided." )
@@ -185,14 +192,15 @@ def enqueue(
185
192
page_options .on_min_pages ,
186
193
page_options .page_indexes ,
187
194
)
195
+ options = AsyncPredictOptions (
196
+ cropper , full_text , include_words , workflow_id , rag
197
+ )
188
198
return self ._predict_async (
189
199
product_class ,
190
200
input_source ,
201
+ options ,
191
202
endpoint ,
192
- include_words ,
193
203
close_file ,
194
- cropper ,
195
- full_text ,
196
204
)
197
205
198
206
def load_prediction (
@@ -246,8 +254,9 @@ def execute_workflow(
246
254
:param input_source: The document/source file to use.
247
255
Has to be created beforehand.
248
256
:param workflow_id: ID of the workflow.
249
- :param page_options: If set, remove pages from the document as specified. This is done before sending the file\
250
- to the server. It is useful to avoid page limitations.
257
+ :param page_options: If set, remove pages from the document as specified.
258
+ This is done before sending the file to the server.
259
+ It is useful to avoid page limitations.
251
260
:param options: Options for the workflow.
252
261
:return:
253
262
"""
@@ -259,13 +268,11 @@ def execute_workflow(
259
268
page_options .page_indexes ,
260
269
)
261
270
262
- logger .debug ("Sending document to workflow: %s" , workflow_id )
263
-
264
271
if not options :
265
272
options = WorkflowOptions (
266
273
alias = None , priority = None , full_text = False , public_url = None , rag = False
267
274
)
268
-
275
+ logger . debug ( "Sending document to workflow: %s" , workflow_id )
269
276
return self ._send_to_workflow (GeneratedV1 , input_source , workflow_id , options )
270
277
271
278
def _validate_async_params (
@@ -285,7 +292,7 @@ def _validate_async_params(
285
292
if max_retries < min_retries :
286
293
raise MindeeClientError (f"Cannot set retries to less than { min_retries } ." )
287
294
288
- def enqueue_and_parse (
295
+ def enqueue_and_parse ( # pylint: disable=too-many-locals
289
296
self ,
290
297
product_class : Type [Inference ],
291
298
input_source : Union [LocalInputSource , UrlInputSource ],
@@ -298,40 +305,51 @@ def enqueue_and_parse(
298
305
delay_sec : float = 1.5 ,
299
306
max_retries : int = 80 ,
300
307
full_text : bool = False ,
308
+ workflow_id : Optional [str ] = None ,
309
+ rag : bool = False ,
301
310
) -> AsyncPredictResponse :
302
311
"""
303
312
Enqueues to an asynchronous endpoint and automatically polls for a response.
304
313
305
- :param product_class: The document class to use. The response object will be instantiated based on this \
306
- parameter.
314
+ :param product_class: The document class to use.
315
+ The response object will be instantiated based on this parameter.
307
316
308
- :param input_source: The document/source file to use. Has to be created beforehand.
317
+ :param input_source: The document/source file to use.
318
+ Has to be created beforehand.
309
319
310
- :param include_words: Whether to include the full text for each page. This performs a full OCR operation on \
311
- the server and will increase response time.
320
+ :param include_words: Whether to include the full text for each page.
321
+ This performs a full OCR operation on the server and will increase response time.
312
322
313
- :param close_file: Whether to ``close()`` the file after parsing it. Set to ``False`` if you need to access \
314
- the file after this operation.
323
+ :param close_file: Whether to ``close()`` the file after parsing it.
324
+ Set to ``False`` if you need to access the file after this operation.
315
325
316
- :param page_options: If set, remove pages from the document as specified. This is done before sending the file\
317
- to the server. It is useful to avoid page limitations.
326
+ :param page_options: If set, remove pages from the document as specified.
327
+ This is done before sending the file to the server.
328
+ It is useful to avoid page limitations.
318
329
319
- :param cropper: Whether to include cropper results for each page. This performs a cropping operation on the \
320
- server and will increase response time.
330
+ :param cropper: Whether to include cropper results for each page.
331
+ This performs a cropping operation on the server and will increase response time.
321
332
322
333
:param endpoint: For custom endpoints, an endpoint has to be given.
323
334
324
- :param initial_delay_sec: Delay between each polling attempts This should not be shorter than 1 second.
335
+ :param initial_delay_sec: Delay before the first polling attempt.
336
+ This should not be shorter than 1 second.
325
337
326
- :param delay_sec: Delay between each polling attempts This should not be shorter than 1 second.
338
+ :param delay_sec: Delay between polling attempts.
339
+ This should not be shorter than 1 second.
327
340
328
341
:param max_retries: Total amount of polling attempts.
329
342
330
343
:param full_text: Whether to include the full OCR text response in compatible APIs.
344
+
345
+ :param workflow_id: Workflow ID.
346
+
347
+ :param rag: If set, will enable Retrieval-Augmented Generation.
348
+ Only works if a valid ``workflow_id`` is set.
331
349
"""
332
350
self ._validate_async_params (initial_delay_sec , delay_sec , max_retries )
333
351
if not endpoint :
334
- endpoint = self ._initialize_ots_endpoint (product_class )
352
+ endpoint = self ._initialize_ots_endpoint (product_class = product_class )
335
353
queue_result = self .enqueue (
336
354
product_class ,
337
355
input_source ,
@@ -341,6 +359,8 @@ def enqueue_and_parse(
341
359
cropper ,
342
360
endpoint ,
343
361
full_text ,
362
+ workflow_id ,
363
+ rag ,
344
364
)
345
365
logger .debug (
346
366
"Successfully enqueued document with job id: %s" , queue_result .job .id
@@ -406,15 +426,16 @@ def _make_request(
406
426
product_class : Type [Inference ],
407
427
input_source : Union [LocalInputSource , UrlInputSource ],
408
428
endpoint : Endpoint ,
409
- include_words : bool ,
429
+ options : PredictOptions ,
410
430
close_file : bool ,
411
- cropper : bool ,
412
- full_text : bool ,
413
431
) -> PredictResponse :
414
432
response = endpoint .predict_req_post (
415
- input_source , include_words , close_file , cropper , full_text
433
+ input_source ,
434
+ options .include_words ,
435
+ close_file ,
436
+ options .cropper ,
437
+ options .full_text ,
416
438
)
417
-
418
439
dict_response = response .json ()
419
440
420
441
if not is_valid_sync_response (response ):
@@ -423,28 +444,30 @@ def _make_request(
423
444
str (product_class .endpoint_name ),
424
445
clean_response ,
425
446
)
426
-
427
447
return PredictResponse (product_class , dict_response )
428
448
429
449
def _predict_async (
430
450
self ,
431
451
product_class : Type [Inference ],
432
452
input_source : Union [LocalInputSource , UrlInputSource ],
453
+ options : AsyncPredictOptions ,
433
454
endpoint : Optional [Endpoint ] = None ,
434
- include_words : bool = False ,
435
455
close_file : bool = True ,
436
- cropper : bool = False ,
437
- full_text : bool = False ,
438
456
) -> AsyncPredictResponse :
439
457
"""Sends a document to the queue, and sends back an asynchronous predict response."""
440
458
if input_source is None :
441
459
raise MindeeClientError ("No input document provided" )
442
460
if not endpoint :
443
461
endpoint = self ._initialize_ots_endpoint (product_class )
444
462
response = endpoint .predict_async_req_post (
445
- input_source , include_words , close_file , cropper , full_text
463
+ input_source = input_source ,
464
+ include_words = options .include_words ,
465
+ close_file = close_file ,
466
+ cropper = options .cropper ,
467
+ full_text = options .full_text ,
468
+ workflow_id = options .workflow_id ,
469
+ rag = options .rag ,
446
470
)
447
-
448
471
dict_response = response .json ()
449
472
450
473
if not is_valid_async_response (response ):
0 commit comments