
Commit eedd5bc

Feat: Add Cerebras Platform Models. (#3424)

Authored by sk5268
Co-authored-by: Saed Bhati <[email protected]>
Co-authored-by: JINO ROHIT <[email protected]>
1 parent e6f6105 commit eedd5bc

File tree

9 files changed: +304 -0 lines changed

.env.example

Lines changed: 3 additions & 0 deletions

@@ -10,6 +10,9 @@
 # Models API
 #===========================================

+# Cerebras API (https://chat.cerebras.ai/)
+# CEREBRAS_API_KEY="Fill your API key here"
+
 # OpenAI API (https://platform.openai.com/signup)
 # OPENAI_API_KEY="Fill your API key here"
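The new backend (camel/models/cerebras_model.py below) reads this variable from the environment whenever no api_key argument is passed, so exporting it in the shell or loading the .env file is enough. An illustrative snippet for setting it in-process (the key value is a placeholder):

import os

# Set the key in-process before any CerebrasModel is constructed;
# exporting CEREBRAS_API_KEY in the shell works the same way.
os.environ.setdefault("CEREBRAS_API_KEY", "your-api-key-here")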
camel/configs/__init__.py

Lines changed: 3 additions & 0 deletions

@@ -17,6 +17,7 @@
 from .anthropic_config import ANTHROPIC_API_PARAMS, AnthropicConfig
 from .base_config import BaseConfig
 from .bedrock_config import BEDROCK_API_PARAMS, BedrockConfig
+from .cerebras_config import CEREBRAS_API_PARAMS, CerebrasConfig
 from .cohere_config import COHERE_API_PARAMS, CohereConfig
 from .cometapi_config import COMETAPI_API_PARAMS, CometAPIConfig
 from .crynux_config import CRYNUX_API_PARAMS, CrynuxConfig

@@ -93,6 +94,8 @@
     'SAMBA_CLOUD_API_PARAMS',
     'TogetherAIConfig',
     'TOGETHERAI_API_PARAMS',
+    'CerebrasConfig',
+    'CEREBRAS_API_PARAMS',
     'CohereConfig',
     'COHERE_API_PARAMS',
     'CometAPIConfig',
camel/configs/cerebras_config.py

Lines changed: 96 additions & 0 deletions

@@ -0,0 +1,96 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations

from typing import Dict, Optional, Sequence, Union

from camel.configs.base_config import BaseConfig


class CerebrasConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    Cerebras OpenAI-compatible API.

    Reference: https://inference-docs.cerebras.ai/resources/openai

    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`None`)
        top_p (float, optional): An alternative to sampling with temperature,
            called nucleus sampling, where the model considers the results of
            the tokens with top_p probability mass. So :obj:`0.1` means only
            the tokens comprising the top 10% probability mass are considered.
            (default: :obj:`None`)
        response_format (object, optional): An object specifying the format
            that the model must output. Setting to {"type": "json_object"}
            enables JSON mode, which guarantees the message the model
            generates is valid JSON. (default: :obj:`None`)
        stream (bool, optional): If True, partial message deltas will be sent
            as data-only server-sent events as they become available.
            (default: :obj:`None`)
        stop (str or list, optional): Up to :obj:`4` sequences where the API
            will stop generating further tokens. (default: :obj:`None`)
        max_tokens (int, optional): The maximum number of tokens to generate
            in the chat completion. The total length of input tokens and
            generated tokens is limited by the model's context length.
            (default: :obj:`None`)
        user (str, optional): A unique identifier representing your end-user,
            which can help to monitor and detect abuse.
            (default: :obj:`None`)
        tools (list[FunctionTool], optional): A list of tools the model may
            call. Currently, only functions are supported as a tool. Use this
            to provide a list of functions the model may generate JSON inputs
            for. A maximum of 128 functions is supported.
        tool_choice (Union[dict[str, str], str], optional): Controls which (if
            any) tool is called by the model. :obj:`"none"` means the model
            will not call any tool and instead generates a message.
            :obj:`"auto"` means the model can pick between generating a
            message or calling one or more tools. :obj:`"required"` means the
            model must call one or more tools. Specifying a particular tool
            via {"type": "function", "function": {"name": "my_function"}}
            forces the model to call that tool. :obj:`"none"` is the default
            when no tools are present. :obj:`"auto"` is the default if tools
            are present.
        reasoning_effort (str, optional): A parameter specifying the level of
            reasoning used by certain model types. Valid values are
            :obj:`"low"`, :obj:`"medium"`, or :obj:`"high"`. If not provided,
            or if the model type does not support it, this parameter is
            ignored. (default: :obj:`None`)
    """

    temperature: Optional[float] = None
    top_p: Optional[float] = None
    stream: Optional[bool] = None
    stop: Optional[Union[str, Sequence[str]]] = None
    max_tokens: Optional[int] = None
    response_format: Optional[Dict] = None
    user: Optional[str] = None
    tool_choice: Optional[Union[Dict[str, str], str]] = None
    reasoning_effort: Optional[str] = None


CEREBRAS_API_PARAMS = {param for param in CerebrasConfig.model_fields.keys()}
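All fields default to None, so only explicitly set parameters end up in the request payload. A minimal sketch of building a config and checking it against the exported parameter set (illustrative values; as_dict() is inherited from BaseConfig, as in the other config classes):

from camel.configs import CEREBRAS_API_PARAMS, CerebrasConfig

config = CerebrasConfig(
    temperature=0.3,
    max_tokens=512,
    stop=["\n\n"],
)

# CEREBRAS_API_PARAMS mirrors the fields declared on the class, so it can
# be used to reject unknown keys in user-supplied config dicts.
payload = config.as_dict()
unknown = set(payload) - CEREBRAS_API_PARAMS
print(sorted(CEREBRAS_API_PARAMS))
print(unknown or "no unknown parameters")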

camel/models/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -19,6 +19,7 @@
 from .azure_openai_model import AzureOpenAIModel
 from .base_audio_model import BaseAudioModel
 from .base_model import BaseModelBackend
+from .cerebras_model import CerebrasModel
 from .cohere_model import CohereModel
 from .cometapi_model import CometAPIModel
 from .crynux_model import CrynuxModel

@@ -71,6 +72,7 @@
     'GroqModel',
     'StubModel',
     'ZhipuAIModel',
+    'CerebrasModel',
     'CohereModel',
     'CometAPIModel',
     'ModelFactory',
camel/models/cerebras_model.py

Lines changed: 83 additions & 0 deletions

@@ -0,0 +1,83 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import os
from typing import Any, Dict, Optional, Union

from camel.configs import CerebrasConfig
from camel.models.openai_compatible_model import OpenAICompatibleModel
from camel.types import ModelType
from camel.utils import (
    BaseTokenCounter,
    api_keys_required,
)


class CerebrasModel(OpenAICompatibleModel):
    r"""LLM API served by Cerebras in a unified OpenAICompatibleModel
    interface.

    Args:
        model_type (Union[ModelType, str]): Model for which a backend is
            created.
        model_config_dict (Optional[Dict[str, Any]], optional): A dictionary
            that will be fed into :obj:`openai.ChatCompletion.create()`.
            If :obj:`None`, :obj:`CerebrasConfig().as_dict()` will be used.
            (default: :obj:`None`)
        api_key (Optional[str], optional): The API key for authenticating
            with the Cerebras service. (default: :obj:`None`)
        url (Optional[str], optional): The URL of the Cerebras service.
            (default: :obj:`None`)
        token_counter (Optional[BaseTokenCounter], optional): Token counter to
            use for the model. If not provided, :obj:`OpenAITokenCounter(
            ModelType.GPT_4O_MINI)` will be used.
            (default: :obj:`None`)
        timeout (Optional[float], optional): The timeout value in seconds for
            API calls. If not provided, will fall back to the MODEL_TIMEOUT
            environment variable or default to 180 seconds.
            (default: :obj:`None`)
        max_retries (int, optional): Maximum number of retries for API calls.
            (default: :obj:`3`)
        **kwargs (Any): Additional arguments to pass to the client
            initialization.
    """

    @api_keys_required([("api_key", "CEREBRAS_API_KEY")])
    def __init__(
        self,
        model_type: Union[ModelType, str],
        model_config_dict: Optional[Dict[str, Any]] = None,
        api_key: Optional[str] = None,
        url: Optional[str] = None,
        token_counter: Optional[BaseTokenCounter] = None,
        timeout: Optional[float] = None,
        max_retries: int = 3,
        **kwargs: Any,
    ) -> None:
        if model_config_dict is None:
            model_config_dict = CerebrasConfig().as_dict()
        api_key = api_key or os.environ.get("CEREBRAS_API_KEY")
        url = url or os.environ.get(
            "CEREBRAS_API_BASE_URL", "https://api.cerebras.ai/v1"
        )
        timeout = timeout or float(os.environ.get("MODEL_TIMEOUT", 180))
        super().__init__(
            model_type=model_type,
            model_config_dict=model_config_dict,
            api_key=api_key,
            url=url,
            token_counter=token_counter,
            timeout=timeout,
            max_retries=max_retries,
            **kwargs,
        )
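The backend can also be constructed directly rather than through ModelFactory; a minimal sketch, assuming CEREBRAS_API_KEY is exported (the api_keys_required decorator on __init__ raises otherwise):

from camel.models import CerebrasModel
from camel.types import ModelType

# api_key falls back to the CEREBRAS_API_KEY environment variable and
# url to https://api.cerebras.ai/v1 (overridable via CEREBRAS_API_BASE_URL).
model = CerebrasModel(
    model_type=ModelType.CEREBRAS_QWEN_3_32B,
    timeout=60,  # seconds; otherwise MODEL_TIMEOUT or 180
)
print(model.model_type, model.model_config_dict)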

camel/models/model_factory.py

Lines changed: 2 additions & 0 deletions

@@ -22,6 +22,7 @@
 from camel.models.aws_bedrock_model import AWSBedrockModel
 from camel.models.azure_openai_model import AzureOpenAIModel
 from camel.models.base_model import BaseModelBackend
+from camel.models.cerebras_model import CerebrasModel
 from camel.models.cohere_model import CohereModel
 from camel.models.cometapi_model import CometAPIModel
 from camel.models.crynux_model import CrynuxModel

@@ -89,6 +90,7 @@ class ModelFactory:
         ModelPlatformType.AZURE: AzureOpenAIModel,
         ModelPlatformType.ANTHROPIC: AnthropicModel,
         ModelPlatformType.GROQ: GroqModel,
+        ModelPlatformType.CEREBRAS: CerebrasModel,
         ModelPlatformType.COMETAPI: CometAPIModel,
         ModelPlatformType.NEBIUS: NebiusModel,
         ModelPlatformType.LMSTUDIO: LMStudioModel,
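With this mapping in place, ModelFactory.create resolves ModelPlatformType.CEREBRAS to the new backend class. A quick sketch, assuming CEREBRAS_API_KEY is set in the environment:

from camel.models import CerebrasModel, ModelFactory
from camel.types import ModelPlatformType, ModelType

model = ModelFactory.create(
    model_platform=ModelPlatformType.CEREBRAS,
    model_type=ModelType.CEREBRAS_GPT_OSS_120B,
)
# The factory hands back an instance of the class registered above.
assert isinstance(model, CerebrasModel)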

camel/types/enums.py

Lines changed: 27 additions & 0 deletions

@@ -91,6 +91,12 @@ class ModelType(UnifiedModelType, Enum):
     GROQ_MIXTRAL_8_7B = "mixtral-8x7b-32768"
     GROQ_GEMMA_2_9B_IT = "gemma2-9b-it"

+    # Cerebras platform models
+    CEREBRAS_GPT_OSS_120B = "gpt-oss-120b"
+    CEREBRAS_LLAMA_3_1_8B = "llama3.1-8b"
+    CEREBRAS_LLAMA_3_3_70B = "llama3.3-70b"
+    CEREBRAS_QWEN_3_32B = "qwen-3-32b"
+
     # Nebius AI Studio platform models
     NEBIUS_GPT_OSS_120B = "gpt-oss-120b"
     NEBIUS_GPT_OSS_20B = "gpt-oss-20b"

@@ -554,6 +560,7 @@ def support_native_tool_calling(self) -> bool:
             self.is_together,
             self.is_sambanova,
             self.is_groq,
+            self.is_cerebras,
             self.is_openrouter,
             self.is_lmstudio,
             self.is_sglang,

@@ -696,6 +703,16 @@ def is_groq(self) -> bool:
             ModelType.GROQ_GEMMA_2_9B_IT,
         }

+    @property
+    def is_cerebras(self) -> bool:
+        r"""Returns whether this type of models is served by Cerebras."""
+        return self in {
+            ModelType.CEREBRAS_GPT_OSS_120B,
+            ModelType.CEREBRAS_LLAMA_3_1_8B,
+            ModelType.CEREBRAS_LLAMA_3_3_70B,
+            ModelType.CEREBRAS_QWEN_3_32B,
+        }
+
     @property
     def is_nebius(self) -> bool:
         r"""Returns whether this type of models is served by Nebius AI

@@ -1165,6 +1182,7 @@ def token_limit(self) -> int:
         }:
             return 4_096
         elif self in {
+            ModelType.CEREBRAS_LLAMA_3_1_8B,
             ModelType.GPT_4,
             ModelType.GROQ_LLAMA_3_8B,
             ModelType.GROQ_LLAMA_3_70B,

@@ -1312,6 +1330,9 @@ def token_limit(self) -> int:
             return 32_768
         elif self in {
             ModelType.MISTRAL_MIXTRAL_8x22B,
+            ModelType.CEREBRAS_GPT_OSS_120B,
+            ModelType.CEREBRAS_LLAMA_3_3_70B,
+            ModelType.CEREBRAS_QWEN_3_32B,
             ModelType.DEEPSEEK_CHAT,
             ModelType.DEEPSEEK_REASONER,
             ModelType.PPIO_DEEPSEEK_R1_TURBO,

@@ -1801,6 +1822,7 @@ class ModelPlatformType(Enum):
     CRYNUX = "crynux"
     AIHUBMIX = "aihubmix"
     MINIMAX = "minimax"
+    CEREBRAS = "cerebras"

     @classmethod
     def from_name(cls, name):

@@ -1991,6 +2013,11 @@ def is_minimax(self) -> bool:
         r"""Returns whether this platform is Minimax M2."""
         return self is ModelPlatformType.MINIMAX

+    @property
+    def is_cerebras(self) -> bool:
+        r"""Returns whether this platform is Cerebras."""
+        return self is ModelPlatformType.CEREBRAS
+

 class AudioModelType(Enum):
     TTS_1 = "tts-1"
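A quick illustrative check of the new enum members; the exact token_limit value follows from whichever elif branch each member was added to above, so it is printed rather than asserted here (this assumes token_limit and support_native_tool_calling are properties, as the surrounding code suggests):

from camel.types import ModelPlatformType, ModelType

mt = ModelType.CEREBRAS_LLAMA_3_3_70B
print(mt.value)                        # "llama3.3-70b"
print(mt.is_cerebras)                  # True
print(mt.support_native_tool_calling)  # True, via the is_cerebras check
print(mt.token_limit)                  # context length for this member

print(ModelPlatformType.CEREBRAS.value)        # "cerebras"
print(ModelPlatformType.CEREBRAS.is_cerebras)  # True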
Lines changed: 48 additions & 0 deletions

@@ -0,0 +1,48 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

from camel.agents import ChatAgent
from camel.configs import CerebrasConfig
from camel.models import ModelFactory
from camel.types import ModelPlatformType, ModelType

# Define the model backend
model = ModelFactory.create(
    model_platform=ModelPlatformType.CEREBRAS,
    model_type=ModelType.CEREBRAS_LLAMA_3_3_70B,
    model_config_dict=CerebrasConfig(temperature=0.2).as_dict(),
)

sys_msg = "You are a helpful assistant."

# Set agent
camel_agent = ChatAgent(system_message=sys_msg, model=model)

user_msg = """Say hi to CAMEL AI, one open-source community
dedicated to the study of autonomous and communicative agents."""

# Get response information
response = camel_agent.step(user_msg)
print(response.msgs[0].content)

'''
===============================================================================
Hello to the CAMEL AI community. It's great to see a group of like-minded
individuals coming together to explore and advance the field of autonomous and
communicative agents. Your open-source approach is truly commendable, as it
fosters collaboration, innovation, and transparency. I'm excited to learn more
about your projects and initiatives, and I'm happy to help in any way I can.
Keep pushing the boundaries of AI research and development!
===============================================================================
'''
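The other Cerebras model types added in this commit drop into the same pattern. A hedged variant below swaps in CEREBRAS_QWEN_3_32B and enables JSON mode through the response_format field documented in CerebrasConfig (assuming the Cerebras endpoint honors it as described there; the system message and prompt are illustrative):

from camel.agents import ChatAgent
from camel.configs import CerebrasConfig
from camel.models import ModelFactory
from camel.types import ModelPlatformType, ModelType

json_model = ModelFactory.create(
    model_platform=ModelPlatformType.CEREBRAS,
    model_type=ModelType.CEREBRAS_QWEN_3_32B,
    model_config_dict=CerebrasConfig(
        temperature=0.0,
        response_format={"type": "json_object"},
    ).as_dict(),
)

agent = ChatAgent(
    system_message="Reply only with a single JSON object.",
    model=json_model,
)
response = agent.step("List three open-source agent frameworks as JSON.")
print(response.msgs[0].content)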
