
Commit cdf5b82

pawelknes and yoavkatz authored
Support for asynchronous requests for watsonx.ai chat (#1666)
* support for asynchronous requests in wml chat

Signed-off-by: Paweł Knes <[email protected]>

* update ibm-watsonx-ai version

Signed-off-by: Paweł Knes <[email protected]>

---------

Signed-off-by: Paweł Knes <[email protected]>
Co-authored-by: Yoav Katz <[email protected]>
1 parent 30a5d19 commit cdf5b82

File tree

pyproject.toml
src/unitxt/inference.py

2 files changed: +48 -27 lines changed

pyproject.toml (+1 -1)

@@ -112,7 +112,7 @@ ui = [
     "transformers"
 ]
 watsonx = [
-    "ibm-watsonx-ai==1.1.14"
+    "ibm-watsonx-ai==1.2.10"
 ]
 inference-tests = [
     "litellm>=1.52.9",

src/unitxt/inference.py (+47 -26)
@@ -2038,6 +2038,9 @@ class WMLInferenceEngineBase(
         deployment_id (str, optional):
             Deployment ID of a tuned model to be used for
             inference. Mutually exclusive with 'model_name'.
+        concurrency_limit (int):
+            Number of concurrent requests sent to a model. Default is 10,
+            which is also the maximum value for the generation.
         parameters (Union[WMLInferenceEngineParams, WMLGenerationParamsMixin, WMLChatParamsMixin], optional):
             Defines inference parameters and their values. Deprecated attribute, please pass respective
             parameters directly to the respective class instead.
@@ -2046,6 +2049,7 @@ class WMLInferenceEngineBase(
     credentials: Optional[CredentialsWML] = None
     model_name: Optional[str] = None
     deployment_id: Optional[str] = None
+    concurrency_limit: int = 10
     label: str = "wml"
     _requirements_list = {
         "ibm_watsonx_ai": "Install ibm-watsonx-ai package using 'pip install --upgrade ibm-watsonx-ai'. "
@@ -2299,11 +2303,6 @@ class WMLInferenceEngineGeneration(WMLInferenceEngineBase, WMLGenerationParamsMixin
 
     If you want to include images in your input, please use 'WMLInferenceEngineChat' instead.
 
-    Args:
-        concurrency_limit (int):
-            Number of concurrent requests sent to a model. Default is 10,
-            which is also the maximum value.
-
     Examples:
         .. code-block:: python
@@ -2327,8 +2326,6 @@ class WMLInferenceEngineGeneration(WMLInferenceEngineBase, WMLGenerationParamsMixin
         results = wml_inference.infer(dataset["test"])
     """
 
-    concurrency_limit: int = 10
-
     def verify(self):
         super().verify()
 
@@ -2580,6 +2577,32 @@ def to_messages(self, instance: Union[Dict, List]) -> List[List[Dict[str, Any]]]:
         # images as SDK allows sending only one image per message.
         return [messages]
 
+    def _handle_async_requests(
+        self,
+        messages: List[List[Dict[str, Any]]],
+        params: Dict[str, Any],
+    ) -> List[Dict[str, Any]]:
+        async def handle_async_requests(start_idx, end_idx):
+            coroutines = [
+                self._model.achat(messages=messages[idx], params=params)
+                for idx in range(start_idx, end_idx)
+            ]
+            batch_results = await asyncio.gather(*coroutines)
+            return list(batch_results)
+
+        loop = asyncio.get_event_loop()
+        results = []
+
+        for batch_idx in range(0, len(messages), self.concurrency_limit):
+            batch_results = loop.run_until_complete(
+                handle_async_requests(
+                    batch_idx, min(batch_idx + self.concurrency_limit, len(messages))
+                )
+            )
+            results.extend(batch_results)
+
+        return results
+
     def _send_requests(
         self,
         dataset: Union[List[Dict[str, Any]], Dataset],
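
The new method caps concurrency by slicing the message list into chunks of concurrency_limit and awaiting each chunk with asyncio.gather before submitting the next. A self-contained sketch of the same chunked-gather idiom, with a stub coroutine standing in for the SDK's achat call (all names here are illustrative, and asyncio.run is used instead of an explicit event loop for brevity):

import asyncio
from typing import Any, Dict, List

async def fake_achat(message: Dict[str, Any]) -> Dict[str, Any]:
    # Stub standing in for the async chat call (self._model.achat above).
    await asyncio.sleep(0.01)
    return {"echo": message}

def run_in_batches(messages: List[Dict[str, Any]], concurrency_limit: int = 10) -> List[Dict[str, Any]]:
    async def gather_batch(start: int, end: int):
        # Every coroutine in this slice runs concurrently; gather keeps input order.
        return await asyncio.gather(*(fake_achat(messages[i]) for i in range(start, end)))

    results: List[Dict[str, Any]] = []
    for start in range(0, len(messages), concurrency_limit):
        end = min(start + concurrency_limit, len(messages))
        # Each batch completes fully before the next batch is submitted.
        results.extend(asyncio.run(gather_batch(start, end)))
    return results

print(len(run_in_batches([{"id": i} for i in range(25)])))  # prints 25

A semaphore-based scheme would keep the pipeline fuller between batches, but fixed-size batches match what this commit implements.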
@@ -2595,27 +2618,25 @@ def _send_requests(
         output_type = "message"
         params["logprobs"] = False
 
-        final_results = []
-
-        for instance in dataset:
-            messages = self.to_messages(instance)
-
-            for message in messages:
-                result = self._model.chat(
-                    messages=message,
-                    params=params,
-                )
+        indexed_messages = [
+            (i, message)
+            for i in range(len(dataset))
+            for message in self.to_messages(dataset[i])
+        ]
 
-                final_results.append(
-                    self.get_return_object(
-                        result["choices"][0][output_type]["content"],
-                        result,
-                        instance["source"],
-                        return_meta_data,
-                    )
-                )
+        results = self._handle_async_requests(
+            [msg[1] for msg in indexed_messages], params
+        )
 
-        return final_results
+        return [
+            self.get_return_object(
+                result["choices"][0][output_type]["content"],
+                result,
+                dataset[idx[0]]["source"],
+                return_meta_data,
+            )
+            for result, idx in zip(results, indexed_messages)
+        ]
 
     def get_return_object(self, predict_result, result, input_text, return_meta_data):
         if return_meta_data:
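
The rewrite flattens all (instance index, message) pairs up front so the batched async call can run over one flat list, then zips the results back to their source instances; this is safe because asyncio.gather returns results in submission order. A toy illustration of the same bookkeeping (all names below are made up for the example):

# Toy data standing in for a dataset whose instances expand to 1+ messages each.
dataset = [{"source": "a", "msgs": [1, 2]}, {"source": "b", "msgs": [3]}]

indexed_messages = [
    (i, msg) for i, inst in enumerate(dataset) for msg in inst["msgs"]
]  # -> [(0, 1), (0, 2), (1, 3)]

# Pretend batched inference; outputs arrive in the same order as inputs.
results = [msg * 10 for _, msg in indexed_messages]

# Match each result back to its originating instance via the saved index.
paired = [
    (dataset[i]["source"], res) for res, (i, _) in zip(results, indexed_messages)
]
print(paired)  # [('a', 10), ('a', 20), ('b', 30)]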
