Commit c8beba7

[RemoteModels] HuggingFace batch Support (mlrun#9206)

1 parent 75f73c2 commit c8beba7

6 files changed: +352 -111 lines changed

mlrun/config.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -202,6 +202,7 @@
         "openai_default_model": "gpt-4o",
         "openai_batch_max_concurrent": 10,
         "huggingface_default_model": "microsoft/Phi-3-mini-4k-instruct",
+        "huggingface_default_batch_size": 8,
     },
     # default node selector to be applied to all functions - json string base64 encoded format
     "default_function_node_selector": "e30=",
```

mlrun/datastore/model_provider/huggingface_provider.py

Lines changed: 74 additions & 5 deletions

```diff
@@ -179,6 +179,7 @@ def _response_handler(
             return str_response
         if invoke_response_format == InvokeResponseFormat.USAGE:
             tokenizer = self.client.tokenizer
+            # Messages may already be a formatted prompt string
             if not isinstance(messages, str):
                 try:
                     messages = tokenizer.apply_chat_template(
@@ -292,23 +293,61 @@ def custom_invoke(
         else:
             return self.client(**invoke_kwargs)
 
+    def _batch_invoke(
+        self,
+        messages_list: list[list[dict]],
+        invoke_response_format: InvokeResponseFormat = InvokeResponseFormat.FULL,
+        **invoke_kwargs,
+    ) -> list[Union[str, dict, list]]:
+        """
+        Internal batch processing for multiple message lists.
+
+        :param messages_list: List of message lists to process in batch.
+        :param invoke_response_format: Response format (STRING, USAGE, or FULL).
+        :param invoke_kwargs: Additional kwargs for the pipeline.
+
+        :return: List of processed responses.
+        """
+        if "batch_size" not in invoke_kwargs:
+            invoke_kwargs["batch_size"] = (
+                mlrun.mlconf.model_providers.huggingface_default_batch_size
+            )
+
+        batch_response = self.custom_invoke(text_inputs=messages_list, **invoke_kwargs)
+
+        results = []
+        for i, single_response in enumerate(batch_response):
+            processed = self._response_handler(
+                messages=messages_list[i],
+                response=single_response,
+                invoke_response_format=invoke_response_format,
+            )
+            results.append(processed)
+
+        return results
+
     def invoke(
         self,
-        messages: Union[str, list[str], "ChatType", list["ChatType"]],
+        messages: Union["ChatType", list["ChatType"]],
         invoke_response_format: InvokeResponseFormat = InvokeResponseFormat.FULL,
         **invoke_kwargs,
     ) -> Union[str, list, dict[str, Any]]:
         """
         HuggingFace-specific implementation of model invocation using the synchronous pipeline client.
         Invokes a HuggingFace model operation for text generation tasks.
 
+        Supports both single and batch invocations:
+
+        - Single invocation: Pass a single ChatType (string or chat format messages)
+        - Batch invocation: Pass a list of ChatType objects for batch processing
+
         Note: Ensure your environment has sufficient computational resources (CPU/GPU and memory) to run the model.
 
         :param messages:
             Input for the text generation model. Can be provided in multiple formats:
 
+            **Single invocation:**
+
             - A single string: Direct text input for generation
-            - A list of strings: Multiple text inputs for batch processing
             - Chat format: A list of dictionaries with "role" and "content" keys:
 
             .. code-block:: json
@@ -318,11 +357,27 @@ def invoke(
                     {"role": "user", "content": "What is the capital of France?"}
                 ]
 
+            **Batch invocation:**
+
+            - List of chat format messages: Multiple chat conversations for batch processing:
+
+            .. code-block:: json
+
+                [
+                    [
+                        {"role": "user", "content": "What is the capital of France?"}
+                    ],
+                    [
+                        {"role": "user", "content": "What is the capital of Germany?"}
+                    ]
+                ]
+
         :param invoke_response_format: InvokeResponseFormat
             Specifies the format of the returned response. Options:
 
-            - "string": Returns only the generated text content, extracted from a single response.
-            - "usage": Combines the generated text with metadata (e.g., token usage), returning a dictionary:
+            - "string": Returns only the generated text content. For batch invocations, returns a list of strings.
+            - "usage": Combines the generated text with metadata (e.g., token usage). For batch invocations,
+              returns a list of dictionaries:
 
             .. code-block:: json
                 {
@@ -342,9 +397,12 @@ def invoke(
 
         :param invoke_kwargs:
             Additional keyword arguments passed to the HuggingFace pipeline.
+            For batch invocations, you can specify 'batch_size' to control the batch processing size.
+            If not provided, defaults to mlrun.mlconf.model_providers.huggingface_default_batch_size.
 
         :return:
-            A string, dictionary, or list of model outputs, depending on `invoke_response_format`.
+            - Single invocation: A string, dictionary, or list depending on `invoke_response_format`.
+            - Batch invocation: A list of strings, dictionaries, or lists depending on `invoke_response_format`.
 
         :raises MLRunInvalidArgumentError:
             If the pipeline task is not "text-generation" or if the response contains multiple outputs when extracting
@@ -356,8 +414,19 @@ def invoke(
             raise mlrun.errors.MLRunInvalidArgumentError(
                 "HuggingFaceProvider.invoke supports text-generation task only"
             )
+
         if InvokeResponseFormat.is_str_response(invoke_response_format.value):
             invoke_kwargs["return_full_text"] = False
+
+        is_batch = self._validate_and_detect_batch_invocation(messages)
+
+        if is_batch:
+            return self._batch_invoke(
+                messages_list=messages,
+                invoke_response_format=invoke_response_format,
+                **invoke_kwargs,
+            )
+
         response = self.custom_invoke(text_inputs=messages, **invoke_kwargs)
         response = self._response_handler(
             messages=messages,
```
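Taken together, the new code paths let a caller submit several conversations in one `invoke` call. A hypothetical usage sketch (`provider` stands for a configured HuggingFaceProvider instance; provider construction and the `InvokeResponseFormat` import path are assumed):

```python
# Hypothetical usage of the new batch path; provider construction is elided.
batch_messages = [
    [{"role": "user", "content": "What is the capital of France?"}],
    [{"role": "user", "content": "What is the capital of Germany?"}],
]

# A list of message lists is detected as a batch invocation; a plain list of
# dicts would go through the original single-invocation path.
answers = provider.invoke(
    messages=batch_messages,
    invoke_response_format=InvokeResponseFormat.STRING,
    batch_size=2,  # optional; defaults to huggingface_default_batch_size (8)
)
# answers -> one generated string per conversation, e.g. ["Paris ...", "Berlin ..."]
```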

mlrun/datastore/model_provider/model_provider.py

Lines changed: 37 additions & 0 deletions

```diff
@@ -82,6 +82,43 @@ def __init__(
         self._client = None
         self._async_client = None
 
+    @staticmethod
+    def _validate_and_detect_batch_invocation(
+        messages: Union[list[dict], list[list[dict]]],
+    ) -> bool:
+        """
+        Validate messages format and detect if this is a batch invocation.
+
+        :param messages: Either a list of message dicts (single) or list of message lists (batch)
+        :return: True if batch invocation, False if single invocation
+        :raises MLRunInvalidArgumentError: If messages format is invalid (mixed types or strings)
+        """
+        if not messages or not isinstance(messages, list):
+            raise mlrun.errors.MLRunInvalidArgumentError(
+                "Messages must be a non-empty list of dictionaries or list of lists of dictionaries."
+            )
+
+        # Check if user mistakenly passed a list of strings
+        has_str = any(isinstance(item, str) for item in messages)
+        if has_str:
+            raise mlrun.errors.MLRunInvalidArgumentError(
+                "Invalid messages format: list of strings is not supported. "
+                "Messages must be a list of dicts (single invocation) or list of lists of dicts (batch invocation)."
+            )
+
+        has_list = any(isinstance(item, list) for item in messages)
+        has_dict = any(isinstance(item, dict) for item in messages)
+
+        if has_list and has_dict:
+            raise mlrun.errors.MLRunInvalidArgumentError(
+                "Invalid messages format: cannot mix list and dict items. "
+                "Use either all lists for batch invocation or all dicts for single invocation."
+            )
+
+        if has_list:
+            return True
+        return False
+
     @staticmethod
     def _extract_string_output(response: Any) -> str:
         """
```

mlrun/datastore/model_provider/openai_provider.py

Lines changed: 0 additions & 36 deletions

```diff
@@ -375,42 +375,6 @@ async def _async_single_invoke(
             response=response,
         )
 
-    def _validate_and_detect_batch_invocation(
-        self, messages: Union[list[dict], list[list[dict]]]
-    ) -> bool:
-        """
-        Validate messages format and detect if this is a batch invocation.
-
-        :param messages: Either a list of message dicts (single) or list of message lists (batch)
-        :return: True if batch invocation, False if single invocation
-        :raises MLRunInvalidArgumentError: If messages format is invalid (mixed types or strings)
-        """
-        if not messages or not isinstance(messages, list):
-            raise mlrun.errors.MLRunInvalidArgumentError(
-                "Messages must be a non-empty list of dictionaries or list of lists of dictionaries."
-            )
-
-        # Check if user mistakenly passed a list of strings
-        has_str = any(isinstance(item, str) for item in messages)
-        if has_str:
-            raise mlrun.errors.MLRunInvalidArgumentError(
-                "Invalid messages format: list of strings is not supported. "
-                "Messages must be a list of dicts (single invocation) or list of lists of dicts (batch invocation)."
-            )
-
-        has_list = any(isinstance(item, list) for item in messages)
-        has_dict = any(isinstance(item, dict) for item in messages)
-
-        if has_list and has_dict:
-            raise mlrun.errors.MLRunInvalidArgumentError(
-                "Invalid messages format: cannot mix list and dict items. "
-                "Use either all lists for batch invocation or all dicts for single invocation."
-            )
-
-        if has_list:
-            return True
-        return False
-
     def invoke(
         self,
         messages: Union[list[dict], list[list[dict]]],
```
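The removed body is essentially the validator added to the base class above (converted to a staticmethod), so batch detection is no longer OpenAI-specific. A schematic of the resulting relationship, with class names assumed from the file paths:

```python
class ModelProvider:  # model_provider.py
    @staticmethod
    def _validate_and_detect_batch_invocation(messages) -> bool:
        ...  # shared implementation shown above

class OpenAIProvider(ModelProvider):  # openai_provider.py
    ...  # its private copy is removed; the base-class method is inherited

class HuggingFaceProvider(ModelProvider):  # huggingface_provider.py
    ...  # invoke() now calls self._validate_and_detect_batch_invocation()
```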

tests/datastore/remote_model/remote_model_utils.py

Lines changed: 12 additions & 7 deletions

```diff
@@ -35,38 +35,43 @@
 }
 INPUT_DATA = [
     {
-        "question": "What is the capital of France? Answer with one word first, then provide a historical overview.",
+        "question": "What is the capital of France? Answer with one word first, then provide a historical overview."
+        " Answer in detail with at least 200 words.",
         "depth_level": "detailed",
         "persona": "teacher",
         "tone": "casual",
     },
     {
-        "question": "What is 2 + 2? Answer shortly and then explain with details.",
+        "question": "What is the largest planet in our solar system? First give a one-word answer, "
+        "then provide a detailed explanation in at least 200 words.",
         "depth_level": "basic",
-        "persona": "math teacher",
+        "persona": "astronomy teacher",
         "tone": "simple",
     },
     {
-        "question": "Who wrote Hamlet? Answer shortly and then explain with details.",
+        "question": "Who wrote Hamlet? Answer shortly and then explain with details. "
+        "Answer in detail with at least 200 words.",
         "depth_level": "basic",
         "persona": "literature professor",
         "tone": "formal",
     },
     {
-        "question": "What color is the sky on a clear day? Answer shortly and then explain with details.",
+        "question": "What color is the sky on a clear day? Answer shortly and then "
+        "answer in detail with at least 200 words.",
         "depth_level": "basic",
         "persona": "child",
         "tone": "fun",
     },
     {
-        "question": "What planet do we live on? Answer shortly and then explain with details.",
+        "question": "What planet do we live on? Answer shortly and then explain with details. "
+        "Answer in detail with at least 200 words.",
         "depth_level": "basic",
         "persona": "astronaut",
         "tone": "educational",
     },
 ]
 
-EXPECTED_RESULTS = ["paris", "4", "shakespeare", "blue", "earth"]
+EXPECTED_RESULTS = ["paris", "jupiter", "shakespeare", "blue", "earth"]
 
 PROMPT_TEMPLATE = [
     {
```
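For context on how these fixtures line up, a hypothetical assertion loop (the actual test functions are outside this diff; `provider` stands for a configured provider instance):

```python
# Hypothetical check: batch-invoke every question and verify the expected
# keyword appears in the matching answer; EXPECTED_RESULTS is index-aligned
# with INPUT_DATA.
questions = [[{"role": "user", "content": item["question"]}] for item in INPUT_DATA]
answers = provider.invoke(
    messages=questions,
    invoke_response_format=InvokeResponseFormat.STRING,
)
for answer, expected in zip(answers, EXPECTED_RESULTS):
    assert expected in answer.lower()
```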
