@@ -2038,6 +2038,9 @@ class WMLInferenceEngineBase(
2038
2038
deployment_id (str, optional):
2039
2039
Deployment ID of a tuned model to be used for
2040
2040
inference. Mutually exclusive with 'model_name'.
2041
+ concurrency_limit (int):
2042
+ Number of concurrent requests sent to a model. Default is 10,
2043
+ which is also the maximum allowed value for generation.
2041
2044
parameters (Union[WMLInferenceEngineParams, WMLGenerationParamsMixin, WMLChatParamsMixin], optional):
2042
2045
Defines inference parameters and their values. Deprecated attribute, please pass respective
2043
2046
parameters directly to the respective class instead.
@@ -2046,6 +2049,7 @@ class WMLInferenceEngineBase(
2046
2049
credentials : Optional [CredentialsWML ] = None
2047
2050
model_name : Optional [str ] = None
2048
2051
deployment_id : Optional [str ] = None
2052
+ concurrency_limit : int = 10
2049
2053
label : str = "wml"
2050
2054
_requirements_list = {
2051
2055
"ibm_watsonx_ai" : "Install ibm-watsonx-ai package using 'pip install --upgrade ibm-watsonx-ai'. "
@@ -2299,11 +2303,6 @@ class WMLInferenceEngineGeneration(WMLInferenceEngineBase, WMLGenerationParamsMi
2299
2303
2300
2304
If you want to include images in your input, please use 'WMLInferenceEngineChat' instead.
2301
2305
2302
- Args:
2303
- concurrency_limit (int):
2304
- Number of concurrent requests sent to a model. Default is 10,
2305
- which is also the maximum value.
2306
-
2307
2306
Examples:
2308
2307
.. code-block:: python
2309
2308
@@ -2327,8 +2326,6 @@ class WMLInferenceEngineGeneration(WMLInferenceEngineBase, WMLGenerationParamsMi
2327
2326
results = wml_inference.infer(dataset["test"])
2328
2327
"""
2329
2328
2330
- concurrency_limit : int = 10
2331
-
2332
2329
def verify (self ):
2333
2330
super ().verify ()
2334
2331
@@ -2580,6 +2577,32 @@ def to_messages(self, instance: Union[Dict, List]) -> List[List[Dict[str, Any]]]
2580
2577
# images as SDK allows sending only one image per message.
2581
2578
return [messages ]
2582
2579
2580
+ def _handle_async_requests (
2581
+ self ,
2582
+ messages : List [List [Dict [str , Any ]]],
2583
+ params : Dict [str , Any ],
2584
+ ) -> List [Dict [str , Any ]]:
2585
+ async def handle_async_requests (start_idx , end_idx ):
2586
+ coroutines = [
2587
+ self ._model .achat (messages = messages [idx ], params = params )
2588
+ for idx in range (start_idx , end_idx )
2589
+ ]
2590
+ batch_results = await asyncio .gather (* coroutines )
2591
+ return list (batch_results )
2592
+
2593
+ loop = asyncio .get_event_loop ()
2594
+ results = []
2595
+
2596
+ for batch_idx in range (0 , len (messages ), self .concurrency_limit ):
2597
+ batch_results = loop .run_until_complete (
2598
+ handle_async_requests (
2599
+ batch_idx , min (batch_idx + self .concurrency_limit , len (messages ))
2600
+ )
2601
+ )
2602
+ results .extend (batch_results )
2603
+
2604
+ return results
2605
+
2583
2606
def _send_requests (
2584
2607
self ,
2585
2608
dataset : Union [List [Dict [str , Any ]], Dataset ],
@@ -2595,27 +2618,25 @@ def _send_requests(
2595
2618
output_type = "message"
2596
2619
params ["logprobs" ] = False
2597
2620
2598
- final_results = []
2599
-
2600
- for instance in dataset :
2601
- messages = self .to_messages (instance )
2602
-
2603
- for message in messages :
2604
- result = self ._model .chat (
2605
- messages = message ,
2606
- params = params ,
2607
- )
2621
+ indexed_messages = [
2622
+ (i , message )
2623
+ for i in range (len (dataset ))
2624
+ for message in self .to_messages (dataset [i ])
2625
+ ]
2608
2626
2609
- final_results .append (
2610
- self .get_return_object (
2611
- result ["choices" ][0 ][output_type ]["content" ],
2612
- result ,
2613
- instance ["source" ],
2614
- return_meta_data ,
2615
- )
2616
- )
2627
+ results = self ._handle_async_requests (
2628
+ [msg [1 ] for msg in indexed_messages ], params
2629
+ )
2617
2630
2618
- return final_results
2631
+ return [
2632
+ self .get_return_object (
2633
+ result ["choices" ][0 ][output_type ]["content" ],
2634
+ result ,
2635
+ dataset [idx [0 ]]["source" ],
2636
+ return_meta_data ,
2637
+ )
2638
+ for result , idx in zip (results , indexed_messages )
2639
+ ]
2619
2640
2620
2641
def get_return_object (self , predict_result , result , input_text , return_meta_data ):
2621
2642
if return_meta_data :
0 commit comments