Skip to content

Commit a8aa3db

Browse files
feat: add file support for more LLM API Flavors
1 parent a39cee7 commit a8aa3db

8 files changed

Lines changed: 260 additions & 65 deletions

File tree

src/uipath_langchain/agent/react/file_type_handler.py

Lines changed: 200 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import base64
22
import re
3-
from enum import StrEnum
43
from typing import Any
54

65
import httpx
6+
from langchain_core.language_models import BaseChatModel
77
from uipath._utils._ssl_context import get_httpx_client_kwargs
88

9+
from uipath_langchain.chat.types import APIFlavor, LLMProvider
10+
911
IMAGE_MIME_TYPES: set[str] = {
1012
"image/png",
1113
"image/jpeg",
@@ -14,13 +16,6 @@
1416
}
1517

1618

17-
class LlmProvider(StrEnum):
18-
OPENAI = "openai"
19-
BEDROCK = "bedrock"
20-
VERTEX = "vertex"
21-
UNKNOWN = "unknown"
22-
23-
2419
def is_pdf(mime_type: str) -> bool:
2520
"""Check if the MIME type represents a PDF document."""
2621
return mime_type.lower() == "application/pdf"
@@ -31,23 +26,28 @@ def is_image(mime_type: str) -> bool:
3126
return mime_type.lower() in IMAGE_MIME_TYPES
3227

3328

34-
def detect_provider(model_name: str) -> LlmProvider:
35-
"""Detect the LLM provider (Bedrock, OpenAI, or Vertex) based on the model name."""
36-
if not model_name:
37-
raise ValueError(f"Unsupported model: {model_name}")
29+
def get_llm_provider(model: BaseChatModel) -> LLMProvider:
    """Get the LLM provider from a model instance.

    Reads the ``llm_provider`` attribute exposed by UiPath LLM classes.
    There is no name-based fallback: a model without a valid
    ``llm_provider`` attribute is rejected immediately so the failure is
    explicit rather than a silent mis-detection.

    Args:
        model: Chat model whose provider should be determined.

    Returns:
        The ``LLMProvider`` value declared on the model.

    Raises:
        ValueError: If the model has no ``llm_provider`` attribute, or the
            attribute is not an ``LLMProvider`` instance.
    """
    # Docstring fix: the previous docstring promised a model-name fallback
    # that this function does not implement — it only trusts the attribute.
    if hasattr(model, "llm_provider") and isinstance(model.llm_provider, LLMProvider):
        return model.llm_provider

    raise ValueError(
        f"Can't determine llm_provider in file_type_handler for model={model}"
    )
43+
def get_api_flavor(model: BaseChatModel) -> APIFlavor:
    """Get the API flavor from a model instance.

    Raises:
        ValueError: If the model does not expose a valid ``api_flavor``.
    """
    # getattr with a None default mirrors hasattr + attribute access:
    # both only swallow AttributeError, so behavior is unchanged.
    flavor = getattr(model, "api_flavor", None)
    if isinstance(flavor, APIFlavor):
        return flavor

    raise ValueError(
        f"Can't determine api_flavor in file_type_handler for model={model}"
    )
5151

5252

5353
def sanitize_filename_for_anthropic(filename: str) -> str:
async def build_message_content_part_from_data(
    url: str,
    filename: str,
    mime_type: str,
    model: BaseChatModel,
) -> dict[str, Any]:
    """Download a file and build a provider-specific message content part.

    The payload format depends on the provider (Bedrock, OpenAI, or Vertex)
    and on the model's API flavor; both are read from the model instance via
    ``get_llm_provider`` / ``get_api_flavor``.

    Raises:
        ValueError: If the provider cannot be mapped to a builder.
    """
    provider = get_llm_provider(model)
    api_flavor = get_api_flavor(model)

    # Select the provider-specific builder, then await it once at the end.
    if provider == LLMProvider.BEDROCK:
        pending = _build_bedrock_content_part_from_data(
            url, mime_type, filename, api_flavor
        )
    elif provider == LLMProvider.OPENAI:
        pending = _build_openai_content_part_from_data(
            url, mime_type, filename, True, api_flavor
        )
    elif provider == LLMProvider.VERTEX:
        pending = _build_vertex_content_part_from_data(
            url, mime_type, True, api_flavor
        )
    else:
        raise ValueError(f"Unsupported provider: {provider}")

    return await pending
108114

async def _build_bedrock_content_part_from_data(
    url: str,
    mime_type: str,
    filename: str,
    api_flavor: APIFlavor,
) -> dict[str, Any]:
    """Build a content part for AWS Bedrock (Anthropic Claude models).

    Converse API uses raw bytes, Invoke API uses base64-encoded content.

    Raises:
        ValueError: If ``api_flavor`` is not a supported Bedrock flavor.
    """
    # Dispatch table keeps the flavor -> builder mapping in one place.
    builders = {
        APIFlavor.AWS_BEDROCK_CONVERSE: _build_bedrock_converse_content_part,
        APIFlavor.AWS_BEDROCK_INVOKE: _build_bedrock_invoke_content_part,
    }
    builder = builders.get(api_flavor)
    if builder is None:
        raise ValueError(f"Unsupported Bedrock api_flavor: {api_flavor}")
    return await builder(url, mime_type, filename)
132+
133+
134+
async def _build_bedrock_converse_content_part(
135+
url: str,
136+
mime_type: str,
137+
filename: str,
138+
) -> dict[str, Any]:
139+
"""Build content part for Bedrock Converse API (uses raw bytes)."""
116140
if is_pdf(mime_type):
117141
file_bytes = await _download_file_bytes(url)
118142
name = filename.rsplit(".", 1)[0] if "." in filename else filename
@@ -143,39 +167,109 @@ async def _build_bedrock_content_part_from_data(
143167
raise ValueError(f"Unsupported mime_type: {mime_type}")
144168

145169

170+
async def _build_bedrock_invoke_content_part(
    url: str,
    mime_type: str,
    filename: str,
) -> dict[str, Any]:
    """Build content part for Bedrock Invoke API (uses base64-encoded content).

    ``filename`` is accepted for signature parity with the Converse builder
    but is not used by the Invoke payload.

    Raises:
        ValueError: If ``mime_type`` is neither a PDF nor a supported image.
    """
    encoded = await _download_file(url)

    # PDF and image payloads share the same base64 "source" shape;
    # only the top-level block type differs.
    if is_pdf(mime_type):
        part_type = "document"
    elif is_image(mime_type):
        part_type = "image"
    else:
        raise ValueError(f"Unsupported mime_type: {mime_type}")

    return {
        "type": part_type,
        "source": {
            "type": "base64",
            "media_type": mime_type,
            "data": encoded,
        },
    }
199+
200+
146201
async def _build_openai_content_part_from_data(
    url: str,
    mime_type: str,
    filename: str,
    download_file: bool,
    api_flavor: APIFlavor,
) -> dict[str, Any]:
    """Build a content part for OpenAI models.

    Dispatches to the Responses or Completions builder based on
    ``api_flavor``.

    Raises:
        ValueError: If ``api_flavor`` is not a supported OpenAI flavor.
    """
    if api_flavor == APIFlavor.OPENAI_RESPONSES:
        builder = _build_openai_responses_content_part
    elif api_flavor == APIFlavor.OPENAI_COMPLETIONS:
        builder = _build_openai_completions_content_part
    else:
        raise ValueError(f"Unsupported OpenAI api_flavor: {api_flavor}")

    return await builder(url, mime_type, filename, download_file)
219+
220+
221+
async def _build_openai_responses_content_part(
    url: str,
    mime_type: str,
    filename: str,
    download_file: bool,
) -> dict[str, Any]:
    """Build content part for OpenAI Responses API.

    When ``download_file`` is False the part references the URL directly;
    otherwise the file is fetched and inlined as base64.
    """
    if not download_file:
        return _build_openai_responses_from_url(url, mime_type)
    return await _build_openai_responses_downloaded(url, mime_type, filename)
161231

162-
if is_pdf(mime_type):
163-
data = f"data:application/pdf;base64,{base64_content}"
164-
return {
165-
"type": "file",
166-
"file": {
167-
"filename": filename,
168-
"file_data": data,
169-
},
170-
}
171232

172-
elif is_image(mime_type):
233+
async def _build_openai_responses_downloaded(
    url: str,
    mime_type: str,
    filename: str,
) -> dict[str, Any]:
    """Build content part for OpenAI Responses API with downloaded file.

    Raises:
        ValueError: If ``mime_type`` is neither a supported image nor a PDF.
    """
    encoded = await _download_file(url)

    if is_image(mime_type):
        return {
            "type": "input_image",
            "image_url": f"data:{mime_type};base64,{encoded}",
        }

    if is_pdf(mime_type):
        return {
            "type": "file",
            "file": {
                "filename": filename,
                # PDF bytes travel inline as a data URL.
                "file_data": f"data:application/pdf;base64,{encoded}",
            },
        }

    raise ValueError(f"Unsupported mime_type: {mime_type}")
259+
260+
261+
def _build_openai_responses_from_url(
    url: str,
    mime_type: str,
) -> dict[str, Any]:
    """Build content part for OpenAI Responses API with URL reference.

    Raises:
        ValueError: If ``mime_type`` is neither a supported image nor a PDF.
    """
    if is_image(mime_type):
        part: dict[str, Any] = {"type": "input_image", "image_url": url}
    elif is_pdf(mime_type):
        part = {"type": "input_file", "file_url": url}
    else:
        raise ValueError(f"Unsupported mime_type: {mime_type}")
    return part
185279

186280

281+
async def _build_openai_completions_content_part(
    url: str,
    mime_type: str,
    filename: str,
    download_file: bool,
) -> dict[str, Any]:
    """Build content part for OpenAI Completions API.

    Selects between the downloaded (inline base64) and URL-reference
    builders based on ``download_file``.
    """
    builder = (
        _build_openai_completions_downloaded
        if download_file
        else _build_openai_completions_from_url
    )
    return await builder(url, mime_type, filename)
291+
292+
293+
async def _build_openai_completions_downloaded(
    url: str,
    mime_type: str,
    filename: str,
) -> dict[str, Any]:
    """Build content part for OpenAI Completions API with downloaded file.

    ``filename`` is accepted for signature parity with the other builders
    but is not used by the Completions payload.

    Returns:
        A Chat Completions ``image_url`` content part carrying a base64
        data URL.

    Raises:
        ValueError: For PDFs (not supported by the Completions API) and for
            any other non-image MIME type.
    """
    # Validate the MIME type *before* downloading so an unsupported file
    # no longer triggers a pointless network fetch (previously the file
    # was downloaded first and then rejected).
    if is_pdf(mime_type):
        # Typo fix in the user-facing message: "OpenAi" -> "OpenAI".
        raise ValueError("PDFs are not supported when using the OpenAI completions API")

    if not is_image(mime_type):
        raise ValueError(f"Unsupported mime_type: {mime_type}")

    base64_content = await _download_file(url)
    data_url = f"data:{mime_type};base64,{base64_content}"
    return {
        "type": "image_url",
        "image_url": {"url": data_url},
    }
312+
313+
314+
async def _build_openai_completions_from_url(
    url: str,
    mime_type: str,
    filename: str,
) -> dict[str, Any]:
    """Build content part for OpenAI Completions API with URL reference.

    ``filename`` is accepted for signature parity with the other builders
    but is not used by the Completions payload.

    Returns:
        A Chat Completions ``image_url`` content part referencing ``url``.

    Raises:
        ValueError: For PDFs (not supported by the Completions API) and for
            any other non-image MIME type.
    """
    if is_image(mime_type):
        return {
            "type": "image_url",
            "image_url": {"url": url},
        }

    if is_pdf(mime_type):
        # Typo fix in the user-facing message: "OpenAi" -> "OpenAI".
        raise ValueError("PDFs are not supported when using the OpenAI completions API")

    raise ValueError(f"Unsupported mime_type: {mime_type}")
330+
331+
187332
async def _build_vertex_content_part_from_data(
188333
url: str,
189334
mime_type: str,
190335
download_file: bool,
336+
api_flavor: APIFlavor,
191337
) -> dict[str, Any]:
192338
"""Build a content part for Google Vertex AI / Gemini models."""
339+
340+
if api_flavor != APIFlavor.VERTEX_GEMINI_GENERATE_CONTENT:
341+
raise ValueError(f"Unsupported api_flavor={api_flavor} for building file content parts")
342+
193343
if download_file:
194344
base64_content = await _download_file(url)
195345
if is_image(mime_type) or is_pdf(mime_type):

src/uipath_langchain/agent/react/llm_node.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
"""LLM node for ReAct Agent graph."""
22

3-
from typing import Literal, Sequence
3+
from typing import Any, Sequence
44

55
from langchain_core.language_models import BaseChatModel
66
from langchain_core.messages import AIMessage, AnyMessage
77
from langchain_core.tools import BaseTool
88

9+
from uipath_langchain.chat.types import APIFlavor
10+
911
from .constants import MAX_CONSECUTIVE_THINKING_MESSAGES
1012
from .types import AgentGraphState
1113
from .utils import count_consecutive_thinking_messages
@@ -21,15 +23,22 @@
2123

2224
def _get_required_tool_choice_by_model(
    model: BaseChatModel,
) -> str | dict[str, Any]:
    """Get the appropriate tool_choice value to enforce tool usage based on model type.

    Returns:
        - "required" for OpenAI compatible models
        - "any" for Bedrock Converse and Vertex models (string format)
        - {"type": "any"} for Bedrock Invoke API (dict format required)
    """
    # OpenAI-compatible models are recognised by class name.
    if model.__class__.__name__ in OPENAI_COMPATIBLE_CHAT_MODELS:
        return "required"

    # Bedrock Invoke requires the dict form; everything else takes "any".
    if getattr(model, "api_flavor", None) == APIFlavor.AWS_BEDROCK_INVOKE:
        return {"type": "any"}

    return "any"
3443

3544

0 commit comments

Comments
 (0)