forked from pydantic/pydantic-ai
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgroq.py
More file actions
782 lines (695 loc) · 33.3 KB
/
groq.py
File metadata and controls
782 lines (695 loc) · 33.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
from __future__ import annotations as _annotations
from collections.abc import AsyncIterable, AsyncIterator
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Literal, cast, overload
from pydantic import BaseModel, ValidationError
from pydantic_core import from_json
from typing_extensions import assert_never
from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
from .._output import DEFAULT_OUTPUT_TOOL_NAME, OutputObjectDefinition
from .._run_context import RunContext
from .._thinking_part import split_content_into_text_and_thinking
from .._utils import generate_tool_call_id, guard_tool_call_id as _guard_tool_call_id, number_to_datetime
from ..builtin_tools import AbstractBuiltinTool, WebSearchTool
from ..exceptions import ModelAPIError, UserError
from ..messages import (
AudioUrl,
BinaryContent,
BuiltinToolCallPart,
BuiltinToolReturnPart,
CachePoint,
DocumentUrl,
FilePart,
FinishReason,
ImageUrl,
ModelMessage,
ModelRequest,
ModelResponse,
ModelResponsePart,
ModelResponseStreamEvent,
RetryPromptPart,
SystemPromptPart,
TextPart,
ThinkingPart,
ToolCallPart,
ToolReturnPart,
UploadedFile,
UserContent,
UserPromptPart,
VideoUrl,
)
from ..profiles import ModelProfile, ModelProfileSpec
from ..profiles.groq import GroqModelProfile
from ..providers import Provider, infer_provider
from ..settings import ModelSettings
from ..tools import ToolDefinition
from . import (
Model,
ModelRequestParameters,
StreamedResponse,
check_allow_model_requests,
download_item,
get_user_agent,
)
try:
from groq import NOT_GIVEN, APIConnectionError, APIError, APIStatusError, AsyncGroq, AsyncStream, NotGiven
from groq.types import chat
from groq.types.chat.chat_completion_content_part_image_param import ImageURL
from groq.types.chat.chat_completion_message import ExecutedTool
except ImportError as _import_error:
raise ImportError(
'Please install `groq` to use the Groq model, '
'you can use the `groq` optional group — `pip install "pydantic-ai-slim[groq]"`'
) from _import_error
ProductionGroqModelNames = Literal[
'llama-3.1-8b-instant',
'llama-3.3-70b-versatile',
'meta-llama/llama-guard-4-12b',
'openai/gpt-oss-120b',
'openai/gpt-oss-20b',
'whisper-large-v3',
'whisper-large-v3-turbo',
]
"""Production Groq models from <https://console.groq.com/docs/models#production-models>."""
PreviewGroqModelNames = Literal[
'meta-llama/llama-4-maverick-17b-128e-instruct',
'meta-llama/llama-4-scout-17b-16e-instruct',
'meta-llama/llama-prompt-guard-2-22m',
'meta-llama/llama-prompt-guard-2-86m',
'moonshotai/kimi-k2-instruct-0905',
'openai/gpt-oss-safeguard-20b',
'playai-tts',
'playai-tts-arabic',
'qwen/qwen-3-32b',
]
"""Preview Groq models from <https://console.groq.com/docs/models#preview-models>."""
GroqModelName = str | ProductionGroqModelNames | PreviewGroqModelNames
"""Possible Groq model names.
Since Groq supports a variety of models and the list changes frequencly, we explicitly list the named models as of 2025-03-31
but allow any name in the type hints.
See <https://console.groq.com/docs/models> for an up to date date list of models and more details.
"""
_FINISH_REASON_MAP: dict[Literal['stop', 'length', 'tool_calls', 'content_filter', 'function_call'], FinishReason] = {
'stop': 'stop',
'length': 'length',
'tool_calls': 'tool_call',
'content_filter': 'content_filter',
'function_call': 'tool_call',
}
class GroqModelSettings(ModelSettings, total=False):
    """Settings used for a Groq model request."""

    # ALL FIELDS MUST BE `groq_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.

    groq_reasoning_format: Literal['hidden', 'raw', 'parsed']
    """The format of the reasoning output.

    See [the Groq docs](https://console.groq.com/docs/reasoning#reasoning-format) for more details.
    """
@dataclass(init=False)
class GroqModel(Model):
    """A model that uses the Groq API.

    Internally, this uses the [Groq Python client](https://github.com/groq/groq-python) to interact with the API.

    Apart from `__init__`, all methods are private or match those of the base class.
    """

    # The underlying Groq SDK client; excluded from repr to avoid leaking credentials/config noise.
    client: AsyncGroq = field(repr=False)

    _model_name: GroqModelName = field(repr=False)
    _provider: Provider[AsyncGroq] = field(repr=False)

    def __init__(
        self,
        model_name: GroqModelName,
        *,
        provider: Literal['groq', 'gateway'] | Provider[AsyncGroq] = 'groq',
        profile: ModelProfileSpec | None = None,
        settings: ModelSettings | None = None,
    ):
        """Initialize a Groq model.

        Args:
            model_name: The name of the Groq model to use. List of model names available
                [here](https://console.groq.com/docs/models).
            provider: The provider to use for authentication and API access. Can be either the string
                'groq' or an instance of `Provider[AsyncGroq]`. If not provided, a new provider will be
                created using the other parameters.
            profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
            settings: Model-specific settings that will be used as defaults for this model.
        """
        self._model_name = model_name
        if isinstance(provider, str):
            provider = infer_provider('gateway/groq' if provider == 'gateway' else provider)
        self._provider = provider
        self.client = provider.client
        super().__init__(settings=settings, profile=profile or provider.model_profile)

    @property
    def base_url(self) -> str:
        """The base URL of the underlying Groq client."""
        return str(self.client.base_url)

    @property
    def model_name(self) -> GroqModelName:
        """The model name."""
        return self._model_name

    @property
    def system(self) -> str:
        """The model provider."""
        return self._provider.name

    @classmethod
    def supported_builtin_tools(cls) -> frozenset[type[AbstractBuiltinTool]]:
        """Return the set of builtin tool types this model can handle."""
        return frozenset({WebSearchTool})

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> ModelResponse:
        """Make a non-streaming request to the model and return the processed response."""
        check_allow_model_requests()
        model_settings, model_request_parameters = self.prepare_request(
            model_settings,
            model_request_parameters,
        )
        try:
            response = await self._completions_create(
                messages, False, cast(GroqModelSettings, model_settings or {}), model_request_parameters
            )
        except ModelHTTPError as e:
            # The Groq SDK tries to be helpful by raising an exception when generated tool arguments don't match the schema,
            # but we'd rather handle it ourselves so we can tell the model to retry the tool call.
            if (failed_generation := _parse_tool_use_failed_error(e.body)) is not None:
                if isinstance(failed_generation, _GroqToolUseFailedGeneration):
                    # The failed generation parsed as a tool call: surface it so the framework can retry it.
                    part = ToolCallPart(
                        tool_name=failed_generation.name,
                        args=failed_generation.arguments,
                    )
                elif failed_generation:
                    # Raw (non-JSON) generation text: surface it as plain text.
                    part = TextPart(content=failed_generation)
                else:  # pragma: no cover
                    part = None
                # Return a synthetic error response instead of raising, so retry logic can kick in.
                return ModelResponse(
                    parts=[part] if part else [],
                    model_name=e.model_name,
                    provider_name=self._provider.name,
                    provider_url=self.base_url,
                    finish_reason='error',
                )
            raise
        model_response = self._process_response(response)
        return model_response

    @asynccontextmanager
    async def request_stream(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
        run_context: RunContext[Any] | None = None,
    ) -> AsyncIterator[StreamedResponse]:
        """Make a streaming request to the model, yielding a `GroqStreamedResponse`."""
        check_allow_model_requests()
        model_settings, model_request_parameters = self.prepare_request(
            model_settings,
            model_request_parameters,
        )
        response = await self._completions_create(
            messages, True, cast(GroqModelSettings, model_settings or {}), model_request_parameters
        )
        async with response:
            yield await self._process_streamed_response(response, model_request_parameters)

    def _translate_thinking(
        self,
        model_settings: GroqModelSettings,
        model_request_parameters: ModelRequestParameters,
    ) -> Literal['hidden', 'raw', 'parsed'] | NotGiven:
        """Get reasoning format, falling back to unified thinking when provider-specific setting is not set."""
        if fmt := model_settings.get('groq_reasoning_format'):
            return fmt
        thinking = model_request_parameters.thinking
        if thinking is False:
            # Groq has no true disable; 'hidden' suppresses reasoning output
            return 'hidden'
        if thinking is not None:
            return 'parsed'
        return NOT_GIVEN

    @overload
    async def _completions_create(
        self,
        messages: list[ModelMessage],
        stream: Literal[True],
        model_settings: GroqModelSettings,
        model_request_parameters: ModelRequestParameters,
    ) -> AsyncStream[chat.ChatCompletionChunk]:
        pass

    @overload
    async def _completions_create(
        self,
        messages: list[ModelMessage],
        stream: Literal[False],
        model_settings: GroqModelSettings,
        model_request_parameters: ModelRequestParameters,
    ) -> chat.ChatCompletion:
        pass

    async def _completions_create(
        self,
        messages: list[ModelMessage],
        stream: bool,
        model_settings: GroqModelSettings,
        model_request_parameters: ModelRequestParameters,
    ) -> chat.ChatCompletion | AsyncStream[chat.ChatCompletionChunk]:
        """Build the request payload and call Groq's chat completions endpoint.

        Translates pydantic-ai messages, tools, output mode and settings into Groq SDK
        arguments, and normalizes SDK errors into `ModelHTTPError`/`ModelAPIError`.
        """
        tools = self._get_tools(model_request_parameters)
        tools += self._get_builtin_tools(model_request_parameters)
        if not tools:
            tool_choice: Literal['none', 'required', 'auto'] | None = None
        elif not model_request_parameters.allow_text_output:
            # Text output is disallowed, so force the model to call a tool.
            tool_choice = 'required'
        else:
            tool_choice = 'auto'
        groq_messages = await self._map_messages(messages, model_request_parameters)
        response_format: chat.completion_create_params.ResponseFormat | None = None
        if model_request_parameters.output_mode == 'native':
            output_object = model_request_parameters.output_object
            assert output_object is not None
            response_format = self._map_json_schema(output_object)
        elif (
            model_request_parameters.output_mode == 'prompted'
            and not tools
            and self.profile.supports_json_object_output
        ):  # pragma: no branch
            response_format = {'type': 'json_object'}
        try:
            extra_headers = model_settings.get('extra_headers', {})
            extra_headers.setdefault('User-Agent', get_user_agent())
            return await self.client.chat.completions.create(
                model=self._model_name,
                messages=groq_messages,
                n=1,
                parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
                tools=tools or NOT_GIVEN,
                tool_choice=tool_choice or NOT_GIVEN,
                stop=model_settings.get('stop_sequences', NOT_GIVEN),
                stream=stream,
                response_format=response_format or NOT_GIVEN,
                max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
                temperature=model_settings.get('temperature', NOT_GIVEN),
                top_p=model_settings.get('top_p', NOT_GIVEN),
                timeout=model_settings.get('timeout', NOT_GIVEN),
                seed=model_settings.get('seed', NOT_GIVEN),
                presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
                reasoning_format=self._translate_thinking(model_settings, model_request_parameters),
                frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
                logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
                extra_headers=extra_headers,
                extra_body=model_settings.get('extra_body'),
            )
        except APIStatusError as e:
            if (status_code := e.status_code) >= 400:
                raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
            raise ModelAPIError(model_name=self.model_name, message=e.message) from e  # pragma: no cover
        except APIConnectionError as e:
            raise ModelAPIError(model_name=self.model_name, message=e.message) from e

    def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
        """Process a non-streamed response, and prepare a message to return."""
        choice = response.choices[0]
        items: list[ModelResponsePart] = []
        if choice.message.reasoning is not None:
            # NOTE: The `reasoning` field is only present if `groq_reasoning_format` is set to `parsed`.
            items.append(ThinkingPart(content=choice.message.reasoning))
        if choice.message.executed_tools:
            # Server-side (builtin) tool executions are mapped to call/return part pairs.
            for tool in choice.message.executed_tools:
                call_part, return_part = _map_executed_tool(tool, self.system)
                if call_part and return_part:  # pragma: no branch
                    items.append(call_part)
                    items.append(return_part)
        if choice.message.content:
            # NOTE: The `<think>` tag is only present if `groq_reasoning_format` is set to `raw`.
            items.extend(split_content_into_text_and_thinking(choice.message.content, self.profile.thinking_tags))
        if choice.message.tool_calls is not None:
            for c in choice.message.tool_calls:
                items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
        raw_finish_reason = choice.finish_reason
        provider_details: dict[str, Any] = {'finish_reason': raw_finish_reason}
        if response.created:  # pragma: no branch
            provider_details['timestamp'] = number_to_datetime(response.created)
        finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
        return ModelResponse(
            parts=items,
            usage=_map_usage(response),
            model_name=response.model,
            provider_response_id=response.id,
            provider_name=self._provider.name,
            provider_url=self.base_url,
            finish_reason=finish_reason,
            provider_details=provider_details,
        )

    async def _process_streamed_response(
        self, response: AsyncStream[chat.ChatCompletionChunk], model_request_parameters: ModelRequestParameters
    ) -> GroqStreamedResponse:
        """Process a streamed response, and prepare a streaming response to return."""
        # Peek at the first chunk to get model name/timestamp without consuming the stream.
        peekable_response = _utils.PeekableAsyncStream(response)
        first_chunk = await peekable_response.peek()
        if isinstance(first_chunk, _utils.Unset):
            raise UnexpectedModelBehavior(  # pragma: no cover
                'Streamed response ended without content or tool calls'
            )
        return GroqStreamedResponse(
            model_request_parameters=model_request_parameters,
            _response=peekable_response,
            _model_name=first_chunk.model,
            _model_profile=self.profile,
            _provider_name=self._provider.name,
            _provider_url=self.base_url,
            _provider_timestamp=number_to_datetime(first_chunk.created),
        )

    def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]:
        """Map the run's tool definitions to Groq tool params."""
        return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()]

    def _get_builtin_tools(
        self, model_request_parameters: ModelRequestParameters
    ) -> list[chat.ChatCompletionToolParam]:
        """Validate requested builtin tools; Groq's web search needs no explicit tool param."""
        tools: list[chat.ChatCompletionToolParam] = []
        for tool in model_request_parameters.builtin_tools:
            if isinstance(tool, WebSearchTool):
                if not GroqModelProfile.from_profile(self.profile).groq_always_has_web_search_builtin_tool:
                    raise UserError('`WebSearchTool` is not supported by Groq')  # pragma: no cover
            else:  # pragma: no cover
                raise UserError(
                    f'`{tool.__class__.__name__}` is not supported by `GroqModel`. If it should be, please file an issue.'
                )
        return tools

    async def _map_messages(
        self, messages: list[ModelMessage], model_request_parameters: ModelRequestParameters
    ) -> list[chat.ChatCompletionMessageParam]:
        """Just maps a `pydantic_ai.Message` to a `groq.types.ChatCompletionMessageParam`."""
        groq_messages: list[chat.ChatCompletionMessageParam] = []
        for message in messages:
            if isinstance(message, ModelRequest):
                async for item in self._map_user_message(message):
                    groq_messages.append(item)
            elif isinstance(message, ModelResponse):
                texts: list[str] = []
                tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
                for item in message.parts:
                    if isinstance(item, TextPart):
                        texts.append(item.content)
                    elif isinstance(item, ToolCallPart):
                        tool_calls.append(self._map_tool_call(item))
                    elif isinstance(item, ThinkingPart):
                        # Thinking is round-tripped as tagged text using the profile's thinking tags.
                        start_tag, end_tag = self.profile.thinking_tags
                        texts.append('\n'.join([start_tag, item.content, end_tag]))
                    elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart):  # pragma: no cover
                        # These are not currently sent back
                        pass
                    elif isinstance(item, FilePart):  # pragma: no cover
                        # Files generated by models are not sent back to models that don't themselves generate files.
                        pass
                    else:
                        assert_never(item)
                message_param = chat.ChatCompletionAssistantMessageParam(role='assistant')
                if texts:
                    # Note: model responses from this model should only have one text item, so the following
                    # shouldn't merge multiple texts into one unless you switch models between runs:
                    message_param['content'] = '\n\n'.join(texts)
                if tool_calls:
                    message_param['tool_calls'] = tool_calls
                groq_messages.append(message_param)
            else:
                assert_never(message)
        if instructions := self._get_instructions(messages, model_request_parameters):
            # Insert the instructions after any existing system prompts so they come last among them.
            system_prompt_count = sum(1 for m in groq_messages if m.get('role') == 'system')
            groq_messages.insert(
                system_prompt_count, chat.ChatCompletionSystemMessageParam(role='system', content=instructions)
            )
        return groq_messages

    @staticmethod
    def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
        """Map a `ToolCallPart` back to Groq's assistant tool-call param."""
        return chat.ChatCompletionMessageToolCallParam(
            id=_guard_tool_call_id(t=t),
            type='function',
            function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
        )

    @staticmethod
    def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
        """Map a pydantic-ai `ToolDefinition` to a Groq function-tool param."""
        return {
            'type': 'function',
            'function': {
                'name': f.name,
                'description': f.description or '',
                'parameters': f.parameters_json_schema,
            },
        }

    def _map_json_schema(self, o: OutputObjectDefinition) -> chat.completion_create_params.ResponseFormat:
        """Build a `json_schema` response format from a native output object definition."""
        response_format_param: chat.completion_create_params.ResponseFormatResponseFormatJsonSchema = {
            'type': 'json_schema',
            'json_schema': {
                'name': o.name or DEFAULT_OUTPUT_TOOL_NAME,
                'schema': o.json_schema,
                'strict': o.strict,
            },
        }
        if o.description:  # pragma: no branch
            response_format_param['json_schema']['description'] = o.description
        return response_format_param

    async def _map_user_message(self, message: ModelRequest) -> AsyncIterable[chat.ChatCompletionMessageParam]:
        """Yield Groq message params for each part of a `ModelRequest`.

        File content returned by tools is collected and appended as a trailing user message,
        since Groq tool messages only accept text.
        """
        file_content: list[UserContent] = []
        for part in message.parts:
            if isinstance(part, SystemPromptPart):
                yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
            elif isinstance(part, UserPromptPart):
                yield await self._map_user_prompt(part)
            elif isinstance(part, ToolReturnPart):
                tool_text, tool_file_content = part.model_response_str_and_user_content()
                file_content.extend(tool_file_content)
                yield chat.ChatCompletionToolMessageParam(
                    role='tool',
                    tool_call_id=_guard_tool_call_id(t=part),
                    content=tool_text,
                )
            elif isinstance(part, RetryPromptPart):  # pragma: no branch
                if part.tool_name is None:
                    yield chat.ChatCompletionUserMessageParam(role='user', content=part.model_response())
                else:
                    yield chat.ChatCompletionToolMessageParam(
                        role='tool',
                        tool_call_id=_guard_tool_call_id(t=part),
                        content=part.model_response(),
                    )
        if file_content:
            yield await self._map_user_prompt(UserPromptPart(content=file_content))

    async def _map_user_prompt(self, part: UserPromptPart) -> chat.ChatCompletionUserMessageParam:
        """Map a `UserPromptPart` to a Groq user message; only text and images are supported."""
        content: str | list[chat.ChatCompletionContentPartParam]
        if isinstance(part.content, str):
            content = part.content
        else:
            content = []
            for item in part.content:
                if isinstance(item, str):
                    content.append(chat.ChatCompletionContentPartTextParam(text=item, type='text'))
                elif isinstance(item, ImageUrl):
                    image_url_str = item.url
                    if item.force_download:
                        # Inline the image as a base64 data URI instead of passing the URL through.
                        downloaded = await download_item(item, data_format='base64_uri')
                        image_url_str = downloaded['data']
                    image_url = ImageURL(url=image_url_str)
                    content.append(chat.ChatCompletionContentPartImageParam(image_url=image_url, type='image_url'))
                elif isinstance(item, BinaryContent):
                    if item.is_image:
                        image_url = ImageURL(url=item.data_uri)
                        content.append(chat.ChatCompletionContentPartImageParam(image_url=image_url, type='image_url'))
                    else:
                        raise NotImplementedError('Only images are supported for BinaryContent in Groq user prompts')
                elif isinstance(item, DocumentUrl):
                    raise NotImplementedError('DocumentUrl is not supported in Groq user prompts')
                elif isinstance(item, AudioUrl):
                    raise NotImplementedError('AudioUrl is not supported in Groq user prompts')
                elif isinstance(item, VideoUrl):
                    raise NotImplementedError('VideoUrl is not supported in Groq user prompts')
                elif isinstance(item, UploadedFile):
                    raise NotImplementedError('UploadedFile is not supported in Groq user prompts')
                elif isinstance(item, CachePoint):
                    # Groq has no prompt-caching markers; cache points are silently dropped.
                    pass
                else:
                    assert_never(item)
        return chat.ChatCompletionUserMessageParam(role='user', content=content)
@dataclass
class GroqStreamedResponse(StreamedResponse):
    """Implementation of `StreamedResponse` for Groq models."""

    _model_name: GroqModelName
    _model_profile: ModelProfile
    _response: AsyncIterable[chat.ChatCompletionChunk]
    _provider_name: str
    _provider_url: str
    # Timestamp taken from the first stream chunk, when available.
    _provider_timestamp: datetime | None = None
    _timestamp: datetime = field(default_factory=_utils.now_utc)

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:  # noqa: C901
        """Translate raw Groq stream chunks into pydantic-ai stream events."""
        try:
            # Tracks the call id of an in-flight executed (builtin) tool across chunks.
            executed_tool_call_id: str | None = None
            # Each contiguous run of reasoning deltas gets its own vendor part id.
            reasoning_index = 0
            reasoning = False
            if self._provider_timestamp is not None:  # pragma: no branch
                self.provider_details = {'timestamp': self._provider_timestamp}
            async for chunk in self._response:
                self._usage += _map_usage(chunk)
                if chunk.id:  # pragma: no branch
                    self.provider_response_id = chunk.id
                try:
                    choice = chunk.choices[0]
                except IndexError:
                    # Chunks with no choices (e.g. usage-only) carry nothing further to emit.
                    continue
                if raw_finish_reason := choice.finish_reason:
                    self.provider_details = {**(self.provider_details or {}), 'finish_reason': raw_finish_reason}
                    self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
                if choice.delta.reasoning is not None:
                    if not reasoning:
                        reasoning_index += 1
                        reasoning = True
                    # NOTE: The `reasoning` field is only present if `groq_reasoning_format` is set to `parsed`.
                    for event in self._parts_manager.handle_thinking_delta(
                        vendor_part_id=f'reasoning-{reasoning_index}', content=choice.delta.reasoning
                    ):
                        yield event
                else:
                    reasoning = False
                if choice.delta.executed_tools:
                    for tool in choice.delta.executed_tools:
                        call_part, return_part = _map_executed_tool(
                            tool, self.provider_name, streaming=True, tool_call_id=executed_tool_call_id
                        )
                        if call_part:
                            executed_tool_call_id = call_part.tool_call_id
                            yield self._parts_manager.handle_part(
                                vendor_part_id=f'executed_tools-{tool.index}-call', part=call_part
                            )
                        if return_part:
                            executed_tool_call_id = None
                            yield self._parts_manager.handle_part(
                                vendor_part_id=f'executed_tools-{tool.index}-return', part=return_part
                            )
                # Handle the text part of the response
                content = choice.delta.content
                if content:
                    for event in self._parts_manager.handle_text_delta(
                        vendor_part_id='content',
                        content=content,
                        thinking_tags=self._model_profile.thinking_tags,
                        ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
                    ):
                        yield event
                # Handle the tool calls
                for dtc in choice.delta.tool_calls or []:
                    maybe_event = self._parts_manager.handle_tool_call_delta(
                        vendor_part_id=dtc.index,
                        tool_name=dtc.function and dtc.function.name,
                        args=dtc.function and dtc.function.arguments,
                        tool_call_id=dtc.id,
                    )
                    if maybe_event is not None:
                        yield maybe_event
        except APIError as e:
            # The Groq SDK tries to be helpful by raising an exception when generated tool arguments don't match the schema,
            # but we'd rather handle it ourselves so we can tell the model to retry the tool call
            if (failed_generation := _parse_tool_use_failed_error(e.body)) is not None:
                if isinstance(failed_generation, _GroqToolUseFailedGeneration):
                    yield self._parts_manager.handle_tool_call_part(
                        vendor_part_id='tool_use_failed',
                        tool_name=failed_generation.name,
                        args=failed_generation.arguments,
                    )
                elif failed_generation:  # pragma: no cover
                    # This branch is not covered because when streaming, the non-tool call text would already
                    # have streamed before the `tool_use_failed` error which comes with `failed_generation=''`,
                    # but we keep this here for (hypothetical?) cases where that field would not be empty.
                    for event in self._parts_manager.handle_text_delta(
                        vendor_part_id='tool_use_failed',
                        content=failed_generation,
                        thinking_tags=self._model_profile.thinking_tags,
                        ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
                    ):
                        yield event
                return
            raise  # pragma: no cover

    @property
    def model_name(self) -> GroqModelName:
        """Get the model name of the response."""
        return self._model_name

    @property
    def provider_name(self) -> str:
        """Get the provider name."""
        return self._provider_name

    @property
    def provider_url(self) -> str:
        """Get the provider base URL."""
        return self._provider_url

    @property
    def timestamp(self) -> datetime:
        """Get the timestamp of the response."""
        return self._timestamp
def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> usage.RequestUsage:
    """Extract token usage from a Groq completion or stream chunk.

    Non-streamed completions carry `usage` directly; stream chunks expose it through
    the `x_groq` extension field. Returns empty usage when neither is present.
    """
    if isinstance(completion, chat.ChatCompletion):
        groq_usage = completion.usage
    else:
        groq_usage = completion.x_groq.usage if completion.x_groq is not None else None

    if groq_usage is None:
        return usage.RequestUsage()
    return usage.RequestUsage(
        input_tokens=groq_usage.prompt_tokens,
        output_tokens=groq_usage.completion_tokens,
    )
class _GroqToolUseFailedGeneration(BaseModel):
    """A tool call parsed from the `failed_generation` JSON of a Groq `tool_use_failed` error."""

    name: str
    arguments: dict[str, Any]
class _GroqToolUseFailedInnerError(BaseModel):
    """The inner `error` object of a Groq `tool_use_failed` error payload."""

    message: str
    type: Literal['invalid_request_error']
    code: Literal['tool_use_failed']
    failed_generation: str
class _GroqToolUseFailedError(BaseModel):
    """Top-level wrapper for a Groq `tool_use_failed` error body."""

    # The Groq SDK tries to be helpful by raising an exception when generated tool arguments don't match the schema,
    # but we'd rather handle it ourselves so we can tell the model to retry the tool call.
    # Example payload from `exception.body`:
    # {
    #     'error': {
    #         'message': "Tool call validation failed: tool call validation failed: parameters for tool get_something_by_name did not match schema: errors: [missing properties: 'name', additionalProperties 'foo' not allowed]",
    #         'type': 'invalid_request_error',
    #         'code': 'tool_use_failed',
    #         'failed_generation': '{"name": "get_something_by_name", "arguments": {\n "foo": "bar"\n}}',
    #     }
    # }
    error: _GroqToolUseFailedInnerError
def _parse_tool_use_failed_error(body: Any) -> _GroqToolUseFailedGeneration | str | None:
    """Extract the failed generation from a Groq `tool_use_failed` error body.

    Returns a parsed `_GroqToolUseFailedGeneration` when the `failed_generation`
    field is valid JSON for a tool call, the raw `failed_generation` string when it
    is not, and `None` when `body` is not a recognizable tool-use-failed payload.
    """
    if not isinstance(body, dict):
        return None  # pragma: no cover

    # The payload may arrive wrapped in an outer `error` object or as the bare inner error.
    inner: _GroqToolUseFailedInnerError
    try:
        inner = _GroqToolUseFailedError.model_validate(body).error
    except ValidationError:
        try:
            inner = _GroqToolUseFailedInnerError.model_validate(body)
        except ValidationError:
            return None

    try:
        return _GroqToolUseFailedGeneration.model_validate_json(inner.failed_generation)
    except ValidationError:
        return inner.failed_generation
def _map_executed_tool(
    tool: ExecutedTool, provider_name: str, streaming: bool = False, tool_call_id: str | None = None
) -> tuple[BuiltinToolCallPart | None, BuiltinToolReturnPart | None]:
    """Map a Groq server-side executed tool to builtin tool call/return parts.

    Only `search` tools are handled; anything else maps to `(None, None)`.
    When `streaming`, the call part is emitted while results are still absent and the
    return part once results arrive, reusing `tool_call_id` to link the two.
    """
    if tool.type != 'search':  # pragma: no cover
        return None, None

    # Prefer structured search results when present; otherwise fall back to the raw output.
    if tool.search_results and (tool.search_results.images or tool.search_results.results):
        results = tool.search_results.model_dump(mode='json')
    else:
        results = tool.output

    tool_call_id = tool_call_id or generate_tool_call_id()
    call_part = BuiltinToolCallPart(
        tool_name=WebSearchTool.kind,
        args=from_json(tool.arguments),
        provider_name=provider_name,
        tool_call_id=tool_call_id,
    )
    return_part = BuiltinToolReturnPart(
        tool_name=WebSearchTool.kind,
        content=results,
        provider_name=provider_name,
        tool_call_id=tool_call_id,
    )

    if not streaming:
        return call_part, return_part
    return (None, return_part) if results else (call_part, None)