Skip to content

Commit 260cc14

Browse files
fix(drivers-prompt-anthropic): make top_p and top_k optional to avoid API conflict
Newer Anthropic models (Claude 4.6+) reject requests that specify both `temperature` and `top_p`. Change `top_p` and `top_k` to Optional fields defaulting to None, and only include them in the API params when explicitly set. This follows Anthropic's recommendation to use `temperature` alone for most use cases. Made-with: Cursor
1 parent c3b8f01 commit 260cc14

File tree

3 files changed

+38
-10
lines changed

3 files changed

+38
-10
lines changed

griptape/drivers/prompt/anthropic_prompt_driver.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ class AnthropicPromptDriver(BasePromptDriver):
6565
default=Factory(lambda self: AnthropicTokenizer(model=self.model), takes_self=True),
6666
kw_only=True,
6767
)
68-
top_p: float = field(default=0.999, kw_only=True, metadata={"serializable": True})
69-
top_k: int = field(default=250, kw_only=True, metadata={"serializable": True})
68+
top_p: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True})
69+
top_k: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
7070
tool_choice: dict = field(default=Factory(lambda: {"type": "auto"}), kw_only=True, metadata={"serializable": False})
7171
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
7272
structured_output_strategy: StructuredOutputStrategy = field(
@@ -125,10 +125,10 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
125125
"model": self.model,
126126
"temperature": self.temperature,
127127
"stop_sequences": self.tokenizer.stop_sequences,
128-
"top_p": self.top_p,
129-
"top_k": self.top_k,
130128
"max_tokens": self.max_tokens,
131129
"messages": messages,
130+
**({"top_p": self.top_p} if self.top_p is not None else {}),
131+
**({"top_k": self.top_k} if self.top_k is not None else {}),
132132
**({"system": system_message} if system_message else {}),
133133
**self.extra_params,
134134
}

tests/unit/configs/drivers/test_anthropic_drivers_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ def test_to_dict(self, config):
2222
"max_tokens": 1000,
2323
"stream": False,
2424
"model": "claude-3-7-sonnet-latest",
25-
"top_p": 0.999,
26-
"top_k": 250,
25+
"top_p": None,
26+
"top_k": None,
2727
"use_native_tools": True,
2828
"structured_output_strategy": "tool",
2929
"extra_params": {},

tests/unit/drivers/prompt/test_anthropic_prompt_driver.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -382,8 +382,6 @@ def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools, st
382382
model=driver.model,
383383
max_tokens=1000,
384384
temperature=0.1,
385-
top_p=0.999,
386-
top_k=250,
387385
**{"system": "system-input"} if prompt_stack.system_messages else {},
388386
**{
389387
"tools": self.ANTHROPIC_TOOLS if use_native_tools else {},
@@ -430,8 +428,6 @@ def test_try_stream_run(
430428
max_tokens=1000,
431429
temperature=0.1,
432430
stream=True,
433-
top_p=0.999,
434-
top_k=250,
435431
**{"system": "system-input"} if prompt_stack.system_messages else {},
436432
**{
437433
"tools": self.ANTHROPIC_TOOLS if use_native_tools else {},
@@ -464,6 +460,38 @@ def test_try_stream_run(
464460
event = next(stream)
465461
assert event.usage.output_tokens == 10
466462

463+
def test_try_run_with_top_p_and_top_k(self, mock_client, prompt_stack, messages):
464+
# Given
465+
driver = AnthropicPromptDriver(
466+
model="claude-3-haiku",
467+
api_key="api-key",
468+
top_p=0.9,
469+
top_k=100,
470+
)
471+
472+
# When
473+
driver.try_run(prompt_stack)
474+
475+
# Then
476+
call_kwargs = mock_client.return_value.messages.create.call_args
477+
assert call_kwargs.kwargs["top_p"] == 0.9
478+
assert call_kwargs.kwargs["top_k"] == 100
479+
480+
def test_try_run_without_top_p_and_top_k(self, mock_client, prompt_stack, messages):
481+
# Given
482+
driver = AnthropicPromptDriver(
483+
model="claude-3-haiku",
484+
api_key="api-key",
485+
)
486+
487+
# When
488+
driver.try_run(prompt_stack)
489+
490+
# Then
491+
call_kwargs = mock_client.return_value.messages.create.call_args
492+
assert "top_p" not in call_kwargs.kwargs
493+
assert "top_k" not in call_kwargs.kwargs
494+
467495
def test_verify_structured_output_strategy(self):
468496
assert AnthropicPromptDriver(model="foo", structured_output_strategy="tool")
469497

0 commit comments

Comments
 (0)