Skip to content

Commit 4e744a8

Browse files
fix(drivers-prompt-anthropic): make top_p and top_k optional to avoid API conflict (#2070)
* fix(drivers-prompt-anthropic): make top_p and top_k optional to avoid API conflict

  Newer Anthropic models (Claude 4.6+) reject requests that specify both `temperature` and `top_p`. Change `top_p` and `top_k` to Optional fields defaulting to None, and only include them in the API params when explicitly set. This follows Anthropic's recommendation to use `temperature` alone for most use cases.

  Made-with: Cursor

* fix: exclude temperature from API params when top_p is explicitly set

  When top_p is provided, temperature is now omitted from the Anthropic API request to avoid the mutual-exclusivity conflict on newer models. This ensures callers who prefer top_p-based sampling don't hit a 400.

  Made-with: Cursor
1 parent c3b8f01 commit 4e744a8

File tree

3 files changed

+40
-11
lines changed

3 files changed

+40
-11
lines changed

griptape/drivers/prompt/anthropic_prompt_driver.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ class AnthropicPromptDriver(BasePromptDriver):
6565
default=Factory(lambda self: AnthropicTokenizer(model=self.model), takes_self=True),
6666
kw_only=True,
6767
)
68-
top_p: float = field(default=0.999, kw_only=True, metadata={"serializable": True})
69-
top_k: int = field(default=250, kw_only=True, metadata={"serializable": True})
68+
top_p: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True})
69+
top_k: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
7070
tool_choice: dict = field(default=Factory(lambda: {"type": "auto"}), kw_only=True, metadata={"serializable": False})
7171
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
7272
structured_output_strategy: StructuredOutputStrategy = field(
@@ -123,12 +123,11 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
123123

124124
params = {
125125
"model": self.model,
126-
"temperature": self.temperature,
127126
"stop_sequences": self.tokenizer.stop_sequences,
128-
"top_p": self.top_p,
129-
"top_k": self.top_k,
130127
"max_tokens": self.max_tokens,
131128
"messages": messages,
129+
**({"top_p": self.top_p} if self.top_p is not None else {"temperature": self.temperature}),
130+
**({"top_k": self.top_k} if self.top_k is not None else {}),
132131
**({"system": system_message} if system_message else {}),
133132
**self.extra_params,
134133
}

tests/unit/configs/drivers/test_anthropic_drivers_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ def test_to_dict(self, config):
2222
"max_tokens": 1000,
2323
"stream": False,
2424
"model": "claude-3-7-sonnet-latest",
25-
"top_p": 0.999,
26-
"top_k": 250,
25+
"top_p": None,
26+
"top_k": None,
2727
"use_native_tools": True,
2828
"structured_output_strategy": "tool",
2929
"extra_params": {},

tests/unit/drivers/prompt/test_anthropic_prompt_driver.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -382,8 +382,6 @@ def test_try_run(self, mock_client, prompt_stack, messages, use_native_tools, st
382382
model=driver.model,
383383
max_tokens=1000,
384384
temperature=0.1,
385-
top_p=0.999,
386-
top_k=250,
387385
**{"system": "system-input"} if prompt_stack.system_messages else {},
388386
**{
389387
"tools": self.ANTHROPIC_TOOLS if use_native_tools else {},
@@ -430,8 +428,6 @@ def test_try_stream_run(
430428
max_tokens=1000,
431429
temperature=0.1,
432430
stream=True,
433-
top_p=0.999,
434-
top_k=250,
435431
**{"system": "system-input"} if prompt_stack.system_messages else {},
436432
**{
437433
"tools": self.ANTHROPIC_TOOLS if use_native_tools else {},
@@ -464,6 +460,40 @@ def test_try_stream_run(
464460
event = next(stream)
465461
assert event.usage.output_tokens == 10
466462

463+
def test_try_run_with_top_p_and_top_k(self, mock_client, prompt_stack, messages):
464+
# Given
465+
driver = AnthropicPromptDriver(
466+
model="claude-3-haiku",
467+
api_key="api-key",
468+
top_p=0.9,
469+
top_k=100,
470+
)
471+
472+
# When
473+
driver.try_run(prompt_stack)
474+
475+
# Then
476+
call_kwargs = mock_client.return_value.messages.create.call_args
477+
assert call_kwargs.kwargs["top_p"] == 0.9
478+
assert call_kwargs.kwargs["top_k"] == 100
479+
assert "temperature" not in call_kwargs.kwargs
480+
481+
def test_try_run_without_top_p_and_top_k(self, mock_client, prompt_stack, messages):
482+
# Given
483+
driver = AnthropicPromptDriver(
484+
model="claude-3-haiku",
485+
api_key="api-key",
486+
)
487+
488+
# When
489+
driver.try_run(prompt_stack)
490+
491+
# Then
492+
call_kwargs = mock_client.return_value.messages.create.call_args
493+
assert "top_p" not in call_kwargs.kwargs
494+
assert "top_k" not in call_kwargs.kwargs
495+
assert call_kwargs.kwargs["temperature"] == 0.1
496+
467497
def test_verify_structured_output_strategy(self):
468498
assert AnthropicPromptDriver(model="foo", structured_output_strategy="tool")
469499

0 commit comments

Comments (0)