Skip to content

Commit f4bc8e1

Browse files
feat: address PR comments
1 parent e625887 commit f4bc8e1

22 files changed

Lines changed: 230 additions & 102 deletions

src/uipath_langchain/agent/react/llm_node.py

Lines changed: 5 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,22 @@
11
"""LLM node for ReAct Agent graph."""
22

3-
from typing import Any, Sequence
3+
from typing import Sequence
44

55
from langchain_core.language_models import BaseChatModel
66
from langchain_core.messages import AIMessage, AnyMessage, ToolCall
77
from langchain_core.tools import BaseTool
88
from uipath.runtime.errors import UiPathErrorCategory, UiPathErrorCode
99

10+
from uipath_langchain.llm import get_payload_handler
11+
1012
from ..exceptions import AgentTerminationException
1113
from .constants import (
1214
DEFAULT_MAX_CONSECUTIVE_THINKING_MESSAGES,
1315
DEFAULT_MAX_LLM_MESSAGES,
1416
)
1517
from .types import FLOW_CONTROL_TOOLS, AgentGraphState
16-
from uipath_langchain.chat.types import APIFlavor
17-
18-
from .constants import MAX_CONSECUTIVE_THINKING_MESSAGES
19-
from .types import AgentGraphState
2018
from .utils import count_consecutive_thinking_messages
2119

22-
OPENAI_COMPATIBLE_CHAT_MODELS = (
23-
"UiPathChatOpenAI",
24-
"AzureChatOpenAI",
25-
"ChatOpenAI",
26-
"UiPathChat",
27-
"UiPathAzureChatOpenAI",
28-
)
29-
30-
31-
def _get_required_tool_choice_by_model(
32-
model: BaseChatModel,
33-
) -> str | dict[str, Any]:
34-
"""Get the appropriate tool_choice value to enforce tool usage based on model type.
35-
36-
Returns:
37-
- "required" for OpenAI compatible models
38-
- "any" for Bedrock Converse and Vertex models (string format)
39-
- {"type": "any"} for Bedrock Invoke API (dict format required)
40-
"""
41-
model_class_name = model.__class__.__name__
42-
if model_class_name in OPENAI_COMPATIBLE_CHAT_MODELS:
43-
return "required"
44-
45-
api_flavor = getattr(model, "api_flavor", None)
46-
if api_flavor == APIFlavor.AWS_BEDROCK_INVOKE:
47-
return {"type": "any"}
48-
49-
return "any"
50-
5120

5221
def _filter_control_flow_tool_calls(
5322
tool_calls: list[ToolCall],
@@ -81,7 +50,8 @@ def create_llm_node(
8150
"""
8251
bindable_tools = list(tools) if tools else []
8352
base_llm = model.bind_tools(bindable_tools) if bindable_tools else model
84-
tool_choice_required_value = _get_required_tool_choice_by_model(model)
53+
payload_handler = get_payload_handler(model)
54+
tool_choice_required_value = payload_handler.get_required_tool_choice()
8555

8656
async def llm_node(state: AgentGraphState):
8757
messages: list[AnyMessage] = state.messages

src/uipath_langchain/chat/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def __getattr__(name):
2929
from . import supported_models
3030

3131
return getattr(supported_models, name)
32-
if name in ("LLMProvider", "APIFlavor"):
32+
if name in ("LLMProvider", "APIFlavor", "UiPathChatModel"):
3333
from . import types
3434

3535
return getattr(types, name)
@@ -40,6 +40,7 @@ def __getattr__(name):
4040
"UiPathChat",
4141
"UiPathAzureChatOpenAI",
4242
"UiPathChatOpenAI",
43+
"UiPathChatModel",
4344
"OpenAIModels",
4445
"BedrockModels",
4546
"GeminiModels",

src/uipath_langchain/chat/models.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,6 @@
2929
class UiPathAzureChatOpenAI(UiPathRequestMixin, AzureChatOpenAI):
3030
"""Custom LLM connector for LangChain integration with UiPath."""
3131

32-
llm_provider: LLMProvider = LLMProvider.OPENAI
33-
api_flavor: APIFlavor = APIFlavor.OPENAI_COMPLETIONS
34-
3532
def _generate(
3633
self,
3734
messages: list[BaseMessage],
@@ -174,9 +171,6 @@ def endpoint(self) -> str:
174171
class UiPathChat(UiPathRequestMixin, AzureChatOpenAI):
175172
"""Custom LLM connector for LangChain integration with UiPath Normalized."""
176173

177-
llm_provider: LLMProvider = LLMProvider.OPENAI
178-
api_flavor: APIFlavor = APIFlavor.OPENAI_COMPLETIONS
179-
180174
def _create_chat_result(
181175
self,
182176
response: Union[dict[str, Any], BaseModel],

src/uipath_langchain/chat/types.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from enum import StrEnum
2+
from typing import Protocol, runtime_checkable
23

34

45
class LLMProvider(StrEnum):
@@ -18,3 +19,23 @@ class APIFlavor(StrEnum):
1819
AWS_BEDROCK_INVOKE = "AwsBedrockInvoke"
1920
VERTEX_GEMINI_GENERATE_CONTENT = "GeminiGenerateContent"
2021
VERTEX_ANTHROPIC_CLAUDE = "AnthropicClaude"
22+
23+
24+
@runtime_checkable
25+
class UiPathChatModel(Protocol):
26+
"""Protocol for UiPath chat models with provider and flavor information.
27+
28+
All UiPath chat model classes (UiPathChatOpenAI, UiPathChatBedrock,
29+
UiPathChatBedrockConverse, UiPathChatVertex, UiPathChat, UiPathAzureChatOpenAI)
30+
implement this protocol.
31+
"""
32+
33+
@property
34+
def llm_provider(self) -> LLMProvider:
35+
"""The LLM provider for this model."""
36+
...
37+
38+
@property
39+
def api_flavor(self) -> APIFlavor:
40+
"""The API flavor for this model."""
41+
...
Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
11
"""Common LLM provider utilities."""
22

3+
from uipath_langchain.chat.types import UiPathChatModel
4+
35
from .builders import MessageContentBuilder
46
from .content_builder import get_content_builder
5-
from .provider import get_api_flavor, get_llm_provider
7+
from .handlers import ModelPayloadHandler
8+
from .payload_handler import get_payload_handler
9+
from .utils import get_api_flavor, get_llm_provider
610

711
__all__ = [
812
"MessageContentBuilder",
13+
"ModelPayloadHandler",
14+
"UiPathChatModel",
915
"get_api_flavor",
1016
"get_content_builder",
1117
"get_llm_provider",
18+
"get_payload_handler",
1219
]

src/uipath_langchain/llm/builders/base.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
11
"""Abstract base class for message content builders."""
22

3-
import base64
43
import re
54
from abc import ABC, abstractmethod
65
from typing import Any
76

8-
import httpx
9-
from uipath._utils._ssl_context import get_httpx_client_kwargs
10-
117
IMAGE_MIME_TYPES: set[str] = {
128
"image/png",
139
"image/jpeg",
@@ -38,20 +34,6 @@ def sanitize_filename_for_anthropic(filename: str) -> str:
3834
return sanitized if sanitized else "document"
3935

4036

41-
async def download_file_bytes(url: str) -> bytes:
42-
"""Download a file from a URL and return its content bytes."""
43-
async with httpx.AsyncClient(**get_httpx_client_kwargs()) as client:
44-
response = await client.get(url)
45-
response.raise_for_status()
46-
return response.content
47-
48-
49-
async def download_file_base64(url: str) -> str:
50-
"""Download a file from a URL and return its content as a base64 string."""
51-
file_content = await download_file_bytes(url)
52-
return base64.b64encode(file_content).decode("utf-8")
53-
54-
5537
class MessageContentBuilder(ABC):
5638
"""Abstract base class for building provider-specific message content parts."""
5739

src/uipath_langchain/llm/builders/bedrock_converse.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,11 @@
44

55
from .base import (
66
MessageContentBuilder,
7-
download_file_base64,
8-
download_file_bytes,
97
is_image,
108
is_pdf,
119
sanitize_filename_for_anthropic,
1210
)
11+
from .download import download_file_base64, download_file_bytes
1312

1413

1514
class BedrockConverseBuilder(MessageContentBuilder):

src/uipath_langchain/llm/builders/bedrock_invoke.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,8 @@
22

33
from typing import Any
44

5-
from .base import (
6-
MessageContentBuilder,
7-
download_file_base64,
8-
is_image,
9-
is_pdf,
10-
)
5+
from .base import MessageContentBuilder, is_image, is_pdf
6+
from .download import download_file_base64
117

128

139
class BedrockInvokeBuilder(MessageContentBuilder):
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
"""Utilities for downloading files from URLs."""
2+
3+
import base64
4+
5+
import httpx
6+
from uipath._utils._ssl_context import get_httpx_client_kwargs
7+
8+
9+
async def download_file_bytes(url: str) -> bytes:
10+
"""Download a file from a URL and return its content bytes."""
11+
async with httpx.AsyncClient(**get_httpx_client_kwargs()) as client:
12+
response = await client.get(url)
13+
response.raise_for_status()
14+
return response.content
15+
16+
17+
async def download_file_base64(url: str) -> str:
18+
"""Download a file from a URL and return its content as a base64 string."""
19+
file_content = await download_file_bytes(url)
20+
return base64.b64encode(file_content).decode("utf-8")

src/uipath_langchain/llm/builders/openai_completions.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,8 @@
22

33
from typing import Any
44

5-
from .base import (
6-
MessageContentBuilder,
7-
download_file_base64,
8-
is_image,
9-
is_pdf,
10-
)
5+
from .base import MessageContentBuilder, is_image, is_pdf
6+
from .download import download_file_base64
117

128

139
class OpenAICompletionsBuilder(MessageContentBuilder):

0 commit comments

Comments
 (0)