|
1 | 1 | """LLM node for ReAct Agent graph.""" |
2 | 2 |
|
3 | | -from typing import Any, Sequence |
| 3 | +from typing import Sequence |
4 | 4 |
|
5 | 5 | from langchain_core.language_models import BaseChatModel |
6 | 6 | from langchain_core.messages import AIMessage, AnyMessage, ToolCall |
7 | 7 | from langchain_core.tools import BaseTool |
8 | 8 | from uipath.runtime.errors import UiPathErrorCategory, UiPathErrorCode |
9 | 9 |
|
| 10 | +from uipath_langchain.llm import get_payload_handler |
| 11 | + |
10 | 12 | from ..exceptions import AgentTerminationException |
11 | 13 | from .constants import ( |
12 | 14 | DEFAULT_MAX_CONSECUTIVE_THINKING_MESSAGES, |
13 | 15 | DEFAULT_MAX_LLM_MESSAGES, |
14 | 16 | ) |
15 | 17 | from .types import FLOW_CONTROL_TOOLS, AgentGraphState |
16 | | -from uipath_langchain.chat.types import APIFlavor |
17 | | - |
18 | | -from .constants import MAX_CONSECUTIVE_THINKING_MESSAGES |
19 | | -from .types import AgentGraphState |
20 | 18 | from .utils import count_consecutive_thinking_messages |
21 | 19 |
|
22 | | -OPENAI_COMPATIBLE_CHAT_MODELS = ( |
23 | | - "UiPathChatOpenAI", |
24 | | - "AzureChatOpenAI", |
25 | | - "ChatOpenAI", |
26 | | - "UiPathChat", |
27 | | - "UiPathAzureChatOpenAI", |
28 | | -) |
29 | | - |
30 | | - |
31 | | -def _get_required_tool_choice_by_model( |
32 | | - model: BaseChatModel, |
33 | | -) -> str | dict[str, Any]: |
34 | | - """Get the appropriate tool_choice value to enforce tool usage based on model type. |
35 | | -
|
36 | | - Returns: |
37 | | - - "required" for OpenAI compatible models |
38 | | - - "any" for Bedrock Converse and Vertex models (string format) |
39 | | - - {"type": "any"} for Bedrock Invoke API (dict format required) |
40 | | - """ |
41 | | - model_class_name = model.__class__.__name__ |
42 | | - if model_class_name in OPENAI_COMPATIBLE_CHAT_MODELS: |
43 | | - return "required" |
44 | | - |
45 | | - api_flavor = getattr(model, "api_flavor", None) |
46 | | - if api_flavor == APIFlavor.AWS_BEDROCK_INVOKE: |
47 | | - return {"type": "any"} |
48 | | - |
49 | | - return "any" |
50 | | - |
51 | 20 |
|
52 | 21 | def _filter_control_flow_tool_calls( |
53 | 22 | tool_calls: list[ToolCall], |
@@ -81,7 +50,8 @@ def create_llm_node( |
81 | 50 | """ |
82 | 51 | bindable_tools = list(tools) if tools else [] |
83 | 52 | base_llm = model.bind_tools(bindable_tools) if bindable_tools else model |
84 | | - tool_choice_required_value = _get_required_tool_choice_by_model(model) |
| 53 | + payload_handler = get_payload_handler(model) |
| 54 | + tool_choice_required_value = payload_handler.get_required_tool_choice() |
85 | 55 |
|
86 | 56 | async def llm_node(state: AgentGraphState): |
87 | 57 | messages: list[AnyMessage] = state.messages |
|
0 commit comments