Skip to content

Commit f6af949

Browse files
author
Krrish Dholakia
committed
feat(transformation.py): proper Responses API tool handling for the guardrail translation layer
1 parent 409c3c5 commit f6af949

File tree

4 files changed

+183
-28
lines changed

4 files changed

+183
-28
lines changed

litellm/llms/openai/responses/guardrail_translation/handler.py

Lines changed: 75 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -96,10 +96,11 @@ async def process_input_messages(
9696
# Handle simple string input
9797
if isinstance(input_data, str):
9898
inputs = GenericGuardrailAPIInputs(texts=[input_data])
99+
original_tools: List[Dict[str, Any]] = []
99100

100101
# Extract and transform tools if present
101-
102102
if "tools" in data and data["tools"]:
103+
original_tools = list(data["tools"])
103104
self._extract_and_transform_tools(data["tools"], tools_to_check)
104105
if tools_to_check:
105106
inputs["tools"] = tools_to_check
@@ -118,9 +119,9 @@ async def process_input_messages(
118119
)
119120
guardrailed_texts = guardrailed_inputs.get("texts", [])
120121
data["input"] = guardrailed_texts[0] if guardrailed_texts else input_data
121-
guardrailed_tools = guardrailed_inputs.get("tools")
122-
if guardrailed_tools is not None:
123-
data["tools"] = guardrailed_tools
122+
self._apply_guardrailed_tools_to_data(
123+
data, original_tools, guardrailed_inputs.get("tools")
124+
)
124125
verbose_proxy_logger.debug("OpenAI Responses API: Processed string input")
125126
return data
126127

@@ -131,8 +132,7 @@ async def process_input_messages(
131132
texts_to_check: List[str] = []
132133
images_to_check: List[str] = []
133134
task_mappings: List[Tuple[int, Optional[int]]] = []
134-
# Track (message_index, content_index) for each text
135-
# content_index is None for string content, int for list content
135+
original_tools_list: List[Dict[str, Any]] = list(data.get("tools") or [])
136136

137137
# Step 1: Extract all text content, images, and tools
138138
for msg_idx, message in enumerate(input_data):
@@ -169,9 +169,11 @@ async def process_input_messages(
169169
)
170170

171171
guardrailed_texts = guardrailed_inputs.get("texts", [])
172-
guardrailed_tools = guardrailed_inputs.get("tools")
173-
if guardrailed_tools is not None:
174-
data["tools"] = guardrailed_tools
172+
self._apply_guardrailed_tools_to_data(
173+
data,
174+
original_tools_list,
175+
guardrailed_inputs.get("tools"),
176+
)
175177

176178
# Step 3: Map guardrail responses back to original input structure
177179
await self._apply_guardrail_responses_to_input(
@@ -209,6 +211,53 @@ def _extract_and_transform_tools(
209211
cast(List[ChatCompletionToolParam], transformed_tools)
210212
)
211213

214+
def _remap_tools_to_responses_api_format(
    self, guardrailed_tools: List[Any]
) -> List[Dict[str, Any]]:
    """
    Convert tools returned by a guardrail (Chat Completion format) back
    into the Responses API request tool format.
    """
    # Delegate the per-tool conversion to the shared transformation config.
    transform = (
        LiteLLMCompletionResponsesConfig.transform_chat_completion_tool_params_to_responses_api_tools
    )
    return transform(guardrailed_tools)  # type: ignore
224+
225+
def _merge_tools_after_guardrail(
226+
self,
227+
original_tools: List[Dict[str, Any]],
228+
remapped: List[Dict[str, Any]],
229+
) -> List[Dict[str, Any]]:
230+
"""
231+
Merge remapped guardrailed tools with original tools that were not sent
232+
to the guardrail (e.g. web_search, web_search_preview), preserving order.
233+
"""
234+
if not original_tools:
235+
return remapped
236+
result: List[Dict[str, Any]] = []
237+
j = 0
238+
for tool in original_tools:
239+
if isinstance(tool, dict) and tool.get("type") in (
240+
"web_search",
241+
"web_search_preview",
242+
):
243+
result.append(tool)
244+
else:
245+
if j < len(remapped):
246+
result.append(remapped[j])
247+
j += 1
248+
return result
249+
250+
def _apply_guardrailed_tools_to_data(
251+
self,
252+
data: dict,
253+
original_tools: List[Dict[str, Any]],
254+
guardrailed_tools: Optional[List[Any]],
255+
) -> None:
256+
"""Remap guardrailed tools to Responses API format and merge with original, then set data['tools']."""
257+
if guardrailed_tools is not None:
258+
remapped = self._remap_tools_to_responses_api_format(guardrailed_tools)
259+
data["tools"] = self._merge_tools_after_guardrail(original_tools, remapped)
260+
212261
def _extract_input_text_and_images(
213262
self,
214263
message: Any, # Can be Dict[str, Any] or ResponseInputParam
@@ -413,7 +462,10 @@ async def process_output_streaming_response(
413462
List[ChatCompletionToolCallChunk], tool_calls
414463
)
415464
# Include model information if available
416-
if hasattr(model_response_stream, "model") and model_response_stream.model:
465+
if (
466+
hasattr(model_response_stream, "model")
467+
and model_response_stream.model
468+
):
417469
inputs["model"] = model_response_stream.model
418470
_guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
419471
inputs=inputs,
@@ -454,15 +506,21 @@ async def process_output_streaming_response(
454506
)
455507
return responses_so_far
456508
else:
457-
verbose_proxy_logger.debug("Skipping output guardrail - model response has no choices")
509+
verbose_proxy_logger.debug(
510+
"Skipping output guardrail - model response has no choices"
511+
)
458512
# model_response_stream = OpenAiResponsesToChatCompletionStreamIterator.translate_responses_chunk_to_openai_stream(final_chunk)
459513
# tool_calls = model_response_stream.choices[0].tool_calls
460514
# convert openai response to model response
461515
string_so_far = self.get_streaming_string_so_far(responses_so_far)
462516
inputs = GenericGuardrailAPIInputs(texts=[string_so_far])
463517
# Try to get model from the final chunk if available
464518
if isinstance(final_chunk, dict):
465-
response_model = final_chunk.get("response", {}).get("model") if isinstance(final_chunk.get("response"), dict) else None
519+
response_model = (
520+
final_chunk.get("response", {}).get("model")
521+
if isinstance(final_chunk.get("response"), dict)
522+
else None
523+
)
466524
if response_model:
467525
inputs["model"] = response_model
468526
_guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
@@ -597,8 +655,8 @@ def _extract_output_text_and_images(
597655
content = generic_response_output_item.content
598656
except Exception:
599657
# Try to extract content directly from output_item if validation fails
600-
if hasattr(output_item, "content") and output_item.content: # type: ignore
601-
content = output_item.content # type: ignore
658+
if hasattr(output_item, "content") and output_item.content: # type: ignore
659+
content = output_item.content # type: ignore
602660
else:
603661
return
604662
elif isinstance(output_item, dict):
@@ -675,10 +733,10 @@ async def _apply_guardrail_responses_to_output(
675733
if isinstance(content_item, OutputText):
676734
content_item.text = guardrail_response
677735
# Update the original response output
678-
if hasattr(output_item, "content") and output_item.content: # type: ignore
679-
original_content = output_item.content[content_idx] # type: ignore
736+
if hasattr(output_item, "content") and output_item.content: # type: ignore
737+
original_content = output_item.content[content_idx] # type: ignore
680738
if hasattr(original_content, "text"):
681-
original_content.text = guardrail_response # type: ignore
739+
original_content.text = guardrail_response # type: ignore
682740
except Exception:
683741
pass
684742
elif isinstance(output_item, dict):

litellm/proxy/_new_secret_config.yaml

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,25 @@ model_list:
1313
- model_name: gpt-4.1-mini
1414
litellm_params:
1515
model: openai/gpt-4.1-mini
16+
- model_name: gpt-5-mini
17+
litellm_params:
18+
model: openai/gpt-5-mini
19+
20+
21+
guardrails:
22+
- guardrail_name: mcp-user-permissions
23+
litellm_params:
24+
guardrail: mcp_end_user_permission
25+
mode: pre_call
26+
default_on: true
27+
28+
mcp_servers:
29+
my_http_server:
30+
url: "http://0.0.0.0:8001/mcp"
31+
transport: "http"
32+
description: "My custom MCP server"
33+
available_on_public_internet: true
34+
35+
general_settings:
36+
store_model_in_db: true
37+
store_prompts_in_spend_logs: true

litellm/proxy/guardrails/guardrail_hooks/mcp_end_user_permission/mcp_end_user_permission.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
- end_user_id, no mcp_servers → allow all (default)
1010
- end_user_id + mcp_servers → allow only those servers
1111
"""
12+
1213
from typing import TYPE_CHECKING, Any, List, Literal, Optional, Type
1314

1415
from litellm._logging import verbose_proxy_logger
@@ -78,8 +79,10 @@ async def _check_request_tools(
7879
if not tools:
7980
return inputs
8081

81-
allowed_mcp_servers = await self._get_allowed_mcp_servers_from_object_permission(
82-
object_permission
82+
allowed_mcp_servers = (
83+
await self._get_allowed_mcp_servers_from_object_permission(
84+
object_permission
85+
)
8386
)
8487
if allowed_mcp_servers is None:
8588
return inputs # No restrictions → pass through unchanged
@@ -93,7 +96,9 @@ async def _check_request_tools(
9396

9497
for tool in tools:
9598
tool_name = self._get_tool_name_from_definition(tool)
96-
server_name = self._extract_mcp_server_name(tool_name) if tool_name else None
99+
server_name = (
100+
self._extract_mcp_server_name(tool_name) if tool_name else None
101+
)
97102

98103
if server_name is None:
99104
# Not an MCP tool (no prefix) or unrecognised format → keep
@@ -138,7 +143,9 @@ async def _resolve_end_user_object_permission(
138143
end_user_object = await MCPEndUserPermissionGuardrail._fetch_end_user_object(
139144
end_user_id
140145
)
141-
return end_user_object.object_permission if end_user_object is not None else None
146+
return (
147+
end_user_object.object_permission if end_user_object is not None else None
148+
)
142149

143150
@staticmethod
144151
def _get_end_user_id_from_request_data(request_data: dict) -> Optional[str]:

litellm/responses/litellm_completion_transformation/transformation.py

Lines changed: 75 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union, cast
77

88
from openai.types.responses import ResponseFunctionToolCall
9+
from openai.types.responses.response_create_params import ResponseInputParam
910
from openai.types.responses.tool_param import FunctionToolParam
1011
from typing_extensions import TypedDict
1112

@@ -32,7 +33,6 @@
3233
OpenAIWebSearchUserLocation,
3334
OutputTokensDetails,
3435
ResponseAPIUsage,
35-
ResponseInputParam,
3636
ResponsesAPIOptionalRequestParams,
3737
ResponsesAPIResponse,
3838
ResponsesAPIStatus,
@@ -738,9 +738,25 @@ def _add_tool_call_to_assistant(
738738

739739
@staticmethod
740740
def _ensure_tool_results_have_corresponding_tool_calls(
741-
messages: List[Union[AllMessageValues, GenericChatCompletionMessage, ChatCompletionResponseMessage]],
741+
messages: Sequence[
742+
Union[
743+
AllMessageValues,
744+
GenericChatCompletionMessage,
745+
ChatCompletionResponseMessage,
746+
ChatCompletionMessageToolCall,
747+
Message,
748+
]
749+
],
742750
tools: Optional[List[Any]] = None,
743-
) -> List[Union[AllMessageValues, GenericChatCompletionMessage, ChatCompletionResponseMessage]]:
751+
) -> List[
752+
Union[
753+
AllMessageValues,
754+
GenericChatCompletionMessage,
755+
ChatCompletionResponseMessage,
756+
ChatCompletionMessageToolCall,
757+
Message,
758+
]
759+
]:
744760
"""
745761
Ensure that tool_result messages have corresponding tool_calls in the previous assistant message.
746762
@@ -755,11 +771,19 @@ def _ensure_tool_results_have_corresponding_tool_calls(
755771
List of messages with tool_calls added to assistant messages when needed
756772
"""
757773
if not messages:
758-
return messages
759-
760-
# Create a deep copy to avoid modifying the original
774+
return list(messages)
775+
776+
# Create a deep copy to avoid modifying the original (use list() so we can mutate and return List)
761777
import copy
762-
fixed_messages = copy.deepcopy(messages)
778+
fixed_messages: List[
779+
Union[
780+
AllMessageValues,
781+
GenericChatCompletionMessage,
782+
ChatCompletionResponseMessage,
783+
ChatCompletionMessageToolCall,
784+
Message,
785+
]
786+
] = list(copy.deepcopy(messages))
763787
messages_to_remove = []
764788

765789
# Count non-tool messages to avoid removing all messages
@@ -1306,6 +1330,50 @@ def transform_responses_api_tools_to_chat_completion_tools(
13061330
chat_completion_tools.append(cast(Union[ChatCompletionToolParam, OpenAIMcpServerTool], tool))
13071331
return chat_completion_tools, web_search_options
13081332

1333+
@staticmethod
1334+
def transform_chat_completion_tool_params_to_responses_api_tools(
1335+
chat_completion_tools: Optional[
1336+
List[Union[ChatCompletionToolParam, OpenAIMcpServerTool]]
1337+
],
1338+
) -> List[Dict[str, Any]]:
1339+
"""
1340+
Transform Chat Completion tool params (e.g. from guardrail output) back to
1341+
Responses API request tool format. Inverse of
1342+
transform_responses_api_tools_to_chat_completion_tools for the tools list.
1343+
"""
1344+
if chat_completion_tools is None or not chat_completion_tools:
1345+
return []
1346+
result: List[Dict[str, Any]] = []
1347+
for tool in chat_completion_tools:
1348+
if not isinstance(tool, dict):
1349+
result.append(tool) # type: ignore
1350+
continue
1351+
if tool.get("type") == "function":
1352+
fn = tool.get("function") or {}
1353+
parameters = dict(fn.get("parameters", {}) or {})
1354+
if not parameters or "type" not in parameters:
1355+
parameters["type"] = "object"
1356+
responses_tool: Dict[str, Any] = {
1357+
"type": "function",
1358+
"name": fn.get("name") or "",
1359+
"description": fn.get("description") or "",
1360+
"parameters": parameters,
1361+
"strict": fn.get("strict", False) or False,
1362+
}
1363+
if tool.get("cache_control") is not None:
1364+
responses_tool["cache_control"] = tool.get("cache_control")
1365+
if tool.get("defer_loading") is not None:
1366+
responses_tool["defer_loading"] = tool.get("defer_loading")
1367+
if tool.get("allowed_callers") is not None:
1368+
responses_tool["allowed_callers"] = tool.get("allowed_callers")
1369+
if tool.get("input_examples") is not None:
1370+
responses_tool["input_examples"] = tool.get("input_examples")
1371+
result.append(responses_tool)
1372+
else:
1373+
# mcp or other: pass through unchanged
1374+
result.append(dict(tool))
1375+
return result
1376+
13091377
@staticmethod
13101378
def transform_chat_completion_tools_to_responses_tools(
13111379
chat_completion_response: ModelResponse,

0 commit comments

Comments (0)