Skip to content

Commit 23e1a87

Browse files
authored
Add support for claude thinking models (#103)
- New class added for claude reasoning model - Bump anthropic pip requirement to latest to support thinking requests Note: conda environment files remain unchanged. Warning: The current claude API returns several content blocks (thinking, text, ...) instead of just text. This will likely cause compatibility issues with multi-turn workflows.
1 parent ede9ee1 commit 23e1a87

File tree

5 files changed

+61
-2
lines changed

5 files changed

+61
-2
lines changed

eureka_ml_insights/configs/model_configs.py

+14
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from eureka_ml_insights.models import (
66
AzureOpenAIOModel,
77
ClaudeModel,
8+
ClaudeReasoningModel,
89
DirectOpenAIModel,
910
DirectOpenAIOModel,
1011
GeminiModel,
@@ -201,6 +202,19 @@
201202
},
202203
)
203204

205+
# Configuration for Claude 3.7 Sonnet with extended thinking enabled.
CLAUDE_3_7_SONNET_THINKING_CONFIG = ModelConfig(
    ClaudeReasoningModel,
    {
        "secret_key_params": CLAUDE_SECRET_KEY_PARAMS,
        "model_name": "claude-3-7-sonnet-20250219",
        "thinking_enabled": True,
        "thinking_budget": 16000,
        # max_tokens must always be strictly greater than thinking_budget,
        # so the model has room to emit a final text answer after thinking.
        "max_tokens": 20000,
        # As of 03/08/2025, the thinking API only accepts temperature 1.0.
        "temperature": 1.0,
        # Allow up to 10 minutes for a thinking request to complete.
        "timeout": 600,
    },
)
217+
204218
CLAUDE_3_5_SONNET_20241022_CONFIG = ModelConfig(
205219
ClaudeModel,
206220
{

eureka_ml_insights/data_utils/transform.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from eureka_ml_insights.models import (
1616
ClaudeModel,
17+
ClaudeReasoningModel,
1718
GeminiModel,
1819
LlamaServerlessAzureRestEndpointModel,
1920
MistralServerlessAzureRestEndpointModel,
@@ -452,7 +453,8 @@ def transform(self, df: pd.DataFrame) -> pd.DataFrame:
452453
usage_completion_read_col = None
453454
if (self.model_config.class_name is GeminiModel):
454455
usage_completion_read_col = "candidates_token_count"
455-
elif (self.model_config.class_name is ClaudeModel):
456+
elif (self.model_config.class_name is ClaudeModel
457+
or self.model_config.class_name is ClaudeReasoningModel):
456458
usage_completion_read_col = "output_tokens"
457459
elif (self.model_config.class_name is AzureOpenAIOModel
458460
or self.model_config.class_name is AzureOpenAIModel
@@ -463,6 +465,8 @@ def transform(self, df: pd.DataFrame) -> pd.DataFrame:
463465
or self.model_config.class_name is DirectOpenAIOModel
464466
or self.model_config.class_name is TogetherModel):
465467
usage_completion_read_col = "completion_tokens"
468+
else:
469+
logging.warn(f"Model {self.model_config.class_name} is not recognized for extracting completion token usage.")
466470
# if the model is one for which the usage of completion tokens is known, use that corresponding column for the model
467471
# otherwise, use the default "n_output_tokens" which is computed with a universal tokenizer as shown in TokenCounterTransform()
468472
if usage_completion_read_col:

eureka_ml_insights/models/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
AzureOpenAIModel,
33
AzureOpenAIOModel,
44
ClaudeModel,
5+
ClaudeReasoningModel,
56
DirectOpenAIModel,
67
DirectOpenAIOModel,
78
GeminiModel,
@@ -32,6 +33,7 @@
3233
AzureOpenAIModel,
3334
GeminiModel,
3435
ClaudeModel,
36+
ClaudeReasoningModel,
3537
MistralServerlessAzureRestEndpointModel,
3638
LlamaServerlessAzureRestEndpointModel,
3739
DeepseekR1ServerlessAzureRestEndpointModel,

eureka_ml_insights/models/models.py

+39
Original file line numberDiff line numberDiff line change
@@ -1269,6 +1269,45 @@ def get_response(self, request):
12691269
def handle_request_error(self, e):
12701270
return False
12711271

1272+
@dataclass
class ClaudeReasoningModel(ClaudeModel):
    """Interacts with Claude extended-thinking (reasoning) models through the python API.

    Attributes:
        model_name: Name of the Claude model to call (e.g. "claude-3-7-sonnet-20250219").
        temperature: Sampling temperature. As of 03/08/2025, thinking requests
            only work with temperature 1.0.
        max_tokens: Maximum tokens in the response. Must be strictly greater than
            thinking_budget so the model can emit text after thinking.
        timeout: Request timeout in seconds.
        thinking_enabled: Whether to request extended thinking from the API.
        thinking_budget: Token budget allotted to the thinking phase.
        top_p: Not supported by thinking models; kept for interface
            compatibility and ignored with a warning if set.
    """

    model_name: str = None
    temperature: float = 1.0
    max_tokens: int = 20000
    timeout: int = 600
    thinking_enabled: bool = True
    thinking_budget: int = 16000
    top_p: float = None

    def get_response(self, request):
        """Send the request to the model and collect its content blocks.

        Side effects: sets self.model_output (text block), self.thinking_output
        (thinking block), self.redacted_thinking_output (redacted_thinking block)
        for whichever block types appear in the response, and self.response_time.

        Args:
            request: Keyword arguments forwarded to client.messages.create
                (e.g. the messages payload).

        Returns:
            {"usage": <usage dict>} when the API response carries usage
            information; implicitly None otherwise.
        """
        if self.top_p is not None:
            logging.warning("top_p is not supported for claude reasoning models as of 03/08/2025. It will be ignored.")

        # Only pass `thinking` when it is enabled: the anthropic SDK expects the
        # parameter to be omitted entirely (NOT_GIVEN) rather than set to None,
        # so passing thinking=None when disabled would fail request validation.
        optional_kwargs = {}
        if self.thinking_enabled:
            optional_kwargs["thinking"] = {"type": "enabled", "budget_tokens": self.thinking_budget}

        start_time = time.time()
        completion = self.client.messages.create(
            model=self.model_name,
            **request,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            **optional_kwargs,
        )
        end_time = time.time()

        # The API returns a list of content blocks (text, thinking,
        # redacted_thinking, ...); pick each known kind out of the list.
        for content in completion.content:
            if content.type == 'text':
                self.model_output = content.text
            elif content.type == 'thinking':
                self.thinking_output = content.thinking
            elif content.type == 'redacted_thinking':
                self.redacted_thinking_output = content.data

        self.response_time = end_time - start_time
        if hasattr(completion, "usage"):
            return {"usage": completion.usage.to_dict()}
12721311

12731312
@dataclass
12741313
class TestModel(Model):

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
packages=find_packages(),
1414
include_package_data=True,
1515
install_requires=[
16-
'anthropic>=0.30.0',
16+
'anthropic>=0.49.0',
1717
'azure-ai-textanalytics>=5.3.0',
1818
'azure-core>=1.29.5',
1919
'azure-keyvault-secrets>=4.8.0',

0 commit comments

Comments
 (0)