@@ -96,10 +96,11 @@ async def process_input_messages(
9696 # Handle simple string input
9797 if isinstance (input_data , str ):
9898 inputs = GenericGuardrailAPIInputs (texts = [input_data ])
99+ original_tools : List [Dict [str , Any ]] = []
99100
100101 # Extract and transform tools if present
101-
102102 if "tools" in data and data ["tools" ]:
103+ original_tools = list (data ["tools" ])
103104 self ._extract_and_transform_tools (data ["tools" ], tools_to_check )
104105 if tools_to_check :
105106 inputs ["tools" ] = tools_to_check
@@ -118,9 +119,9 @@ async def process_input_messages(
118119 )
119120 guardrailed_texts = guardrailed_inputs .get ("texts" , [])
120121 data ["input" ] = guardrailed_texts [0 ] if guardrailed_texts else input_data
121- guardrailed_tools = guardrailed_inputs . get ( "tools" )
122- if guardrailed_tools is not None :
123- data [ "tools" ] = guardrailed_tools
122+ self . _apply_guardrailed_tools_to_data (
123+ data , original_tools , guardrailed_inputs . get ( "tools" )
124+ )
124125 verbose_proxy_logger .debug ("OpenAI Responses API: Processed string input" )
125126 return data
126127
@@ -131,8 +132,7 @@ async def process_input_messages(
131132 texts_to_check : List [str ] = []
132133 images_to_check : List [str ] = []
133134 task_mappings : List [Tuple [int , Optional [int ]]] = []
134- # Track (message_index, content_index) for each text
135- # content_index is None for string content, int for list content
135+ original_tools_list : List [Dict [str , Any ]] = list (data .get ("tools" ) or [])
136136
137137 # Step 1: Extract all text content, images, and tools
138138 for msg_idx , message in enumerate (input_data ):
@@ -169,9 +169,11 @@ async def process_input_messages(
169169 )
170170
171171 guardrailed_texts = guardrailed_inputs .get ("texts" , [])
172- guardrailed_tools = guardrailed_inputs .get ("tools" )
173- if guardrailed_tools is not None :
174- data ["tools" ] = guardrailed_tools
172+ self ._apply_guardrailed_tools_to_data (
173+ data ,
174+ original_tools_list ,
175+ guardrailed_inputs .get ("tools" ),
176+ )
175177
176178 # Step 3: Map guardrail responses back to original input structure
177179 await self ._apply_guardrail_responses_to_input (
@@ -209,6 +211,53 @@ def _extract_and_transform_tools(
209211 cast (List [ChatCompletionToolParam ], transformed_tools )
210212 )
211213
214+ def _remap_tools_to_responses_api_format (
215+ self , guardrailed_tools : List [Any ]
216+ ) -> List [Dict [str , Any ]]:
217+ """
218+ Remap guardrail-returned tools (Chat Completion format) back to
219+ Responses API request tool format.
220+ """
221+ return LiteLLMCompletionResponsesConfig .transform_chat_completion_tool_params_to_responses_api_tools (
222+ guardrailed_tools # type: ignore
223+ )
224+
225+ def _merge_tools_after_guardrail (
226+ self ,
227+ original_tools : List [Dict [str , Any ]],
228+ remapped : List [Dict [str , Any ]],
229+ ) -> List [Dict [str , Any ]]:
230+ """
231+ Merge remapped guardrailed tools with original tools that were not sent
232+ to the guardrail (e.g. web_search, web_search_preview), preserving order.
233+ """
234+ if not original_tools :
235+ return remapped
236+ result : List [Dict [str , Any ]] = []
237+ j = 0
238+ for tool in original_tools :
239+ if isinstance (tool , dict ) and tool .get ("type" ) in (
240+ "web_search" ,
241+ "web_search_preview" ,
242+ ):
243+ result .append (tool )
244+ else :
245+ if j < len (remapped ):
246+ result .append (remapped [j ])
247+ j += 1
248+ return result
249+
250+ def _apply_guardrailed_tools_to_data (
251+ self ,
252+ data : dict ,
253+ original_tools : List [Dict [str , Any ]],
254+ guardrailed_tools : Optional [List [Any ]],
255+ ) -> None :
256+ """Remap guardrailed tools to Responses API format and merge with original, then set data['tools']."""
257+ if guardrailed_tools is not None :
258+ remapped = self ._remap_tools_to_responses_api_format (guardrailed_tools )
259+ data ["tools" ] = self ._merge_tools_after_guardrail (original_tools , remapped )
260+
212261 def _extract_input_text_and_images (
213262 self ,
214263 message : Any , # Can be Dict[str, Any] or ResponseInputParam
@@ -413,7 +462,10 @@ async def process_output_streaming_response(
413462 List [ChatCompletionToolCallChunk ], tool_calls
414463 )
415464 # Include model information if available
416- if hasattr (model_response_stream , "model" ) and model_response_stream .model :
465+ if (
466+ hasattr (model_response_stream , "model" )
467+ and model_response_stream .model
468+ ):
417469 inputs ["model" ] = model_response_stream .model
418470 _guardrailed_inputs = await guardrail_to_apply .apply_guardrail (
419471 inputs = inputs ,
@@ -454,15 +506,21 @@ async def process_output_streaming_response(
454506 )
455507 return responses_so_far
456508 else :
457- verbose_proxy_logger .debug ("Skipping output guardrail - model response has no choices" )
509+ verbose_proxy_logger .debug (
510+ "Skipping output guardrail - model response has no choices"
511+ )
458512 # model_response_stream = OpenAiResponsesToChatCompletionStreamIterator.translate_responses_chunk_to_openai_stream(final_chunk)
459513 # tool_calls = model_response_stream.choices[0].tool_calls
460514 # convert openai response to model response
461515 string_so_far = self .get_streaming_string_so_far (responses_so_far )
462516 inputs = GenericGuardrailAPIInputs (texts = [string_so_far ])
463517 # Try to get model from the final chunk if available
464518 if isinstance (final_chunk , dict ):
465- response_model = final_chunk .get ("response" , {}).get ("model" ) if isinstance (final_chunk .get ("response" ), dict ) else None
519+ response_model = (
520+ final_chunk .get ("response" , {}).get ("model" )
521+ if isinstance (final_chunk .get ("response" ), dict )
522+ else None
523+ )
466524 if response_model :
467525 inputs ["model" ] = response_model
468526 _guardrailed_inputs = await guardrail_to_apply .apply_guardrail (
@@ -597,8 +655,8 @@ def _extract_output_text_and_images(
597655 content = generic_response_output_item .content
598656 except Exception :
599657 # Try to extract content directly from output_item if validation fails
600- if hasattr (output_item , "content" ) and output_item .content : # type: ignore
601- content = output_item .content # type: ignore
658+ if hasattr (output_item , "content" ) and output_item .content : # type: ignore
659+ content = output_item .content # type: ignore
602660 else :
603661 return
604662 elif isinstance (output_item , dict ):
@@ -675,10 +733,10 @@ async def _apply_guardrail_responses_to_output(
675733 if isinstance (content_item , OutputText ):
676734 content_item .text = guardrail_response
677735 # Update the original response output
678- if hasattr (output_item , "content" ) and output_item .content : # type: ignore
679- original_content = output_item .content [content_idx ] # type: ignore
736+ if hasattr (output_item , "content" ) and output_item .content : # type: ignore
737+ original_content = output_item .content [content_idx ] # type: ignore
680738 if hasattr (original_content , "text" ):
681- original_content .text = guardrail_response # type: ignore
739+ original_content .text = guardrail_response # type: ignore
682740 except Exception :
683741 pass
684742 elif isinstance (output_item , dict ):
0 commit comments