Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,8 @@ def _prepare_request_params(
"""
# update generation kwargs by merging with the generation kwargs passed to the run method
generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
generation_kwargs = self._resolve_flattened_generation_kwargs(generation_kwargs)

disallowed_params = set(generation_kwargs) - set(self.ALLOWED_PARAMS)
if disallowed_params:
logger.warning(
Expand Down Expand Up @@ -275,6 +277,32 @@ def _prepare_request_params(

return system_messages, non_system_messages, generation_kwargs, anthropic_tools

def _resolve_flattened_generation_kwargs(self, generation_kwargs: dict[str, Any]) -> dict[str, Any]:
generation_kwargs = generation_kwargs.copy()
if "disable_parallel_tool_use" in generation_kwargs:
disable_parallel_tool_use = generation_kwargs.pop("disable_parallel_tool_use")
tool_choice = generation_kwargs.setdefault("tool_choice", {})
tool_choice["disable_parallel_tool_use"] = disable_parallel_tool_use

if "parallel_tool_use" in generation_kwargs:
parallel_tool_use = generation_kwargs.pop("parallel_tool_use")
disable_parallel_tool_use = not parallel_tool_use
tool_choice = generation_kwargs.setdefault("tool_choice", {})
tool_choice["disable_parallel_tool_use"] = disable_parallel_tool_use

if "tool_choice_type" in generation_kwargs:
tool_choice_type = generation_kwargs.pop("tool_choice_type")
tool_choice = generation_kwargs.setdefault("tool_choice", {})
tool_choice["type"] = tool_choice_type

if "thinking_budget_tokens" in generation_kwargs:
thinking_budget_tokens = generation_kwargs.pop("thinking_budget_tokens")
thinking = generation_kwargs.setdefault("thinking", {})
thinking["budget_tokens"] = thinking_budget_tokens
thinking["type"] = "enabled"

return generation_kwargs

def _process_response(
self,
response: Message | Stream[RawMessageStreamEvent],
Expand Down
21 changes: 21 additions & 0 deletions integrations/anthropic/tests/test_chat_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,27 @@ def test_run_with_params(self, chat_messages, mock_anthropic_completion):
assert response["replies"][0].meta["model"] == "claude-sonnet-4-5"
assert response["replies"][0].meta["finish_reason"] == "stop"

def test_run_with_flattened_generation_kwargs(self, chat_messages, mock_anthropic_completion):
"""
Test that the AnthropicChatGenerator component can run with parameters.
"""
component = AnthropicChatGenerator(
api_key=Secret.from_token("test-api-key"),
generation_kwargs={
"max_tokens": 10,
"thinking_budget_tokens": 1024,
"parallel_tool_use": False,
"tool_choice_type": "any",
},
)
component.run(chat_messages)

# Check that the component calls the Anthropic API with the correct parameters
_, kwargs = mock_anthropic_completion.call_args
assert kwargs["max_tokens"] == 10
assert kwargs["thinking"] == {"budget_tokens": 1024, "type": "enabled"}
assert kwargs["tool_choice"] == {"disable_parallel_tool_use": True, "type": "any"}

def test_check_duplicate_tool_names(self, tools):
"""Test that the AnthropicChatGenerator component fails to initialize with duplicate tool names."""
with pytest.raises(ValueError):
Expand Down
Loading