forked from pydantic/pydantic-ai
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcerebras.py
More file actions
140 lines (107 loc) · 5.02 KB
/
cerebras.py
File metadata and controls
140 lines (107 loc) · 5.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
"""Cerebras model implementation using OpenAI-compatible API."""
from __future__ import annotations as _annotations
from dataclasses import dataclass
from typing import Any, Literal, cast
from typing_extensions import override
from ..profiles import ModelProfileSpec
from ..providers import Provider
from ..settings import ModelSettings
from . import ModelRequestParameters
try:
from openai import AsyncOpenAI
from .openai import OpenAIChatModel, OpenAIChatModelSettings
except ImportError as _import_error:
raise ImportError(
'Please install the `openai` package to use the Cerebras model, '
'you can use the `cerebras` optional group — `pip install "pydantic-ai-slim[cerebras]"'
) from _import_error
__all__ = ('CerebrasModel', 'CerebrasModelName', 'CerebrasModelSettings')
LatestCerebrasModelNames = Literal[
'gpt-oss-120b',
'llama-3.3-70b',
'llama3.1-8b',
'qwen-3-235b-a22b-instruct-2507',
'qwen-3-32b',
'zai-glm-4.6',
]
CerebrasModelName = str | LatestCerebrasModelNames
"""Possible Cerebras model names.
Since Cerebras supports a variety of models and the list changes frequently, we explicitly list known models
but allow any name in the type hints.
See <https://inference-docs.cerebras.ai/models/overview> for an up to date list of models.
"""
# `total=False` makes every field optional (ModelSettings is presumably a
# TypedDict — confirm against ..settings).
class CerebrasModelSettings(ModelSettings, total=False):
    """Settings used for a Cerebras model request.

    ALL FIELDS MUST BE `cerebras_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
    """

    # Consumed (popped) by `_cerebras_settings_to_openai_settings`, which maps it
    # to `extra_body['disable_reasoning']` on the outgoing request.
    cerebras_disable_reasoning: bool
    """Disable reasoning for the model.

    This setting is only supported on reasoning models: `zai-glm-4.6` and `gpt-oss-120b`.

    See [the Cerebras docs](https://inference-docs.cerebras.ai/resources/openai#passing-non-standard-parameters) for more details.
    """
@dataclass(init=False)
class CerebrasModel(OpenAIChatModel):
    """A model that uses Cerebras's OpenAI-compatible API.

    Cerebras provides ultra-fast inference powered by the Wafer-Scale Engine (WSE).

    Apart from `__init__`, all methods are private or match those of the base class.
    """

    def __init__(
        self,
        model_name: CerebrasModelName,
        *,
        provider: Literal['cerebras'] | Provider[AsyncOpenAI] = 'cerebras',
        profile: ModelProfileSpec | None = None,
        settings: CerebrasModelSettings | None = None,
    ):
        """Initialize a Cerebras model.

        Args:
            model_name: The name of the Cerebras model to use.
            provider: The provider to use. Defaults to 'cerebras'.
            profile: The model profile to use. Defaults to a profile based on the model name.
            settings: Model-specific settings that will be used as defaults for this model.
        """
        super().__init__(model_name, provider=provider, profile=profile, settings=settings)

    @override
    def _translate_thinking(
        self,
        model_settings: OpenAIChatModelSettings,
        model_request_parameters: ModelRequestParameters,
    ) -> Any:
        """Cerebras handles reasoning via extra_body['disable_reasoning'], not reasoning_effort."""
        from openai import omit

        # Unified thinking is mapped to `disable_reasoning` in
        # `_cerebras_settings_to_openai_settings`; here we only forward a
        # truthy, explicitly-set `openai_reasoning_effort`, otherwise omit it.
        effort = model_settings.get('openai_reasoning_effort')
        return effort if effort else omit

    @override
    def prepare_request(
        self,
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> tuple[ModelSettings | None, ModelRequestParameters]:
        """Merge settings via the base class, then map Cerebras-specific keys to OpenAI ones."""
        base_settings, prepared_params = super().prepare_request(model_settings, model_request_parameters)
        cerebras_settings = cast(CerebrasModelSettings, base_settings or {})
        openai_settings = _cerebras_settings_to_openai_settings(cerebras_settings, prepared_params)
        return openai_settings, prepared_params
def _cerebras_settings_to_openai_settings(
    model_settings: CerebrasModelSettings, model_request_parameters: ModelRequestParameters
) -> OpenAIChatModelSettings:
    """Transforms a 'CerebrasModelSettings' object into an 'OpenAIChatModelSettings' object.

    Args:
        model_settings: The 'CerebrasModelSettings' object to transform. Not mutated.
        model_request_parameters: The 'ModelRequestParameters' object to use for the transformation.

    Returns:
        An 'OpenAIChatModelSettings' object with equivalent settings.
    """
    # Work on shallow copies: the original code popped `cerebras_disable_reasoning`
    # from the caller's dict and wrote `disable_reasoning` into the caller's nested
    # `extra_body` dict, which can be shared across requests when it comes from
    # model-level default settings — the first request would permanently pollute it.
    settings = cast('dict[str, Any]', dict(model_settings))
    # `or {}` also tolerates an explicit `extra_body=None` instead of crashing below.
    extra_body = dict(cast('dict[str, Any]', settings.get('extra_body') or {}))

    disable_reasoning = settings.pop('cerebras_disable_reasoning', None)
    if disable_reasoning is not None:
        # An explicit Cerebras setting always wins over the unified thinking flag.
        extra_body['disable_reasoning'] = disable_reasoning
    elif model_request_parameters.thinking is False:
        extra_body['disable_reasoning'] = True
    elif model_request_parameters.thinking:
        extra_body['disable_reasoning'] = False
    # else: thinking is unset (e.g. None) — leave `disable_reasoning` unspecified.

    if extra_body:
        settings['extra_body'] = extra_body
    return OpenAIChatModelSettings(**settings)  # type: ignore[reportCallIssue]