19 changes: 12 additions & 7 deletions litellm/llms/a2a/chat/guardrail_translation/handler.py
@@ -111,6 +111,7 @@ async def process_output_response(
guardrail_to_apply: "CustomGuardrail",
litellm_logging_obj: Optional["LiteLLMLoggingObj"] = None,
user_api_key_dict: Optional["UserAPIKeyAuth"] = None,
+ request_data: Optional[dict] = None,
) -> Any:
"""
Process A2A output response by applying guardrails to text content.
@@ -166,19 +167,19 @@ async def process_output_response(
return response

# Step 2: Apply guardrail to all texts in batch
- # Create a request_data dict with response info and user API key metadata
- request_data: dict = {"response": response_dict}
+ # Create a local_request_data dict with response info and user API key metadata
+ local_request_data: dict = {**(request_data or {}), "response": response_dict}

# Add user API key metadata with prefixed keys
user_metadata = self.transform_user_api_key_dict_to_metadata(user_api_key_dict)
if user_metadata:
- request_data["litellm_metadata"] = user_metadata
+ local_request_data["litellm_metadata"] = user_metadata

inputs = GenericGuardrailAPIInputs(texts=texts_to_check)

guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
inputs=inputs,
- request_data=request_data,
+ request_data=local_request_data,
input_type="response",
logging_obj=litellm_logging_obj,
)
@@ -213,6 +214,7 @@ async def process_output_streaming_response(
guardrail_to_apply: "CustomGuardrail",
litellm_logging_obj: Optional["LiteLLMLoggingObj"] = None,
user_api_key_dict: Optional["UserAPIKeyAuth"] = None,
+ request_data: Optional[dict] = None,
) -> List[Any]:
"""
Process A2A streaming output by applying guardrails to accumulated text.
@@ -258,15 +260,18 @@
if not combined_text:
return responses_so_far

- request_data: dict = {"responses_so_far": responses_so_far}
+ local_request_data: dict = {
+ **(request_data or {}),
+ "responses_so_far": responses_so_far,
+ }
user_metadata = self.transform_user_api_key_dict_to_metadata(user_api_key_dict)
if user_metadata:
- request_data["litellm_metadata"] = user_metadata
+ local_request_data["litellm_metadata"] = user_metadata

inputs = GenericGuardrailAPIInputs(texts=[combined_text])
guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
inputs=inputs,
- request_data=request_data,
+ request_data=local_request_data,
input_type="response",
logging_obj=litellm_logging_obj,
)
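Aside on the change above: the handler used to build a fresh request_data dict, which would now shadow the incoming parameter, so the local is renamed to local_request_data and merged via {**(request_data or {}), ...}. A minimal runnable sketch of that merge pattern (build_local_request_data is an illustrative name, not part of this PR):

from typing import Optional

def build_local_request_data(
    incoming: Optional[dict],
    response_dict: dict,
    user_metadata: Optional[dict] = None,
) -> dict:
    # Copy-and-merge instead of mutating `incoming`, so the proxy's original
    # request payload is never modified in place; "response" is overlaid on top.
    local_request_data: dict = {**(incoming or {}), "response": response_dict}
    if user_metadata:
        local_request_data["litellm_metadata"] = user_metadata
    return local_request_data

# Keys from the original request survive alongside the response; a None
# incoming dict degrades to just the response entry.
assert build_local_request_data({"model": "gpt-4o"}, {"id": "r1"}) == {
    "model": "gpt-4o",
    "response": {"id": "r1"},
}
assert build_local_request_data(None, {"id": "r1"}) == {"response": {"id": "r1"}}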
20 changes: 11 additions & 9 deletions litellm/llms/anthropic/chat/guardrail_translation/handler.py
@@ -87,9 +87,9 @@ async def process_input_messages(

texts_to_check: List[str] = []
images_to_check: List[str] = []
- tools_to_check: List[
- ChatCompletionToolParam
- ] = chat_completion_compatible_request.get("tools", [])
+ tools_to_check: List[ChatCompletionToolParam] = (
+ chat_completion_compatible_request.get("tools", [])
+ )
task_mappings: List[Tuple[int, Optional[int]]] = []
# Track (message_index, content_index) for each text
# content_index is None for string content, int for list content
@@ -252,6 +252,7 @@ async def process_output_response(
guardrail_to_apply: "CustomGuardrail",
litellm_logging_obj: Optional[Any] = None,
user_api_key_dict: Optional[Any] = None,
+ request_data: Optional[dict] = None,
) -> Any:
"""
Process output response by applying guardrails to text content and tool calls.
@@ -323,15 +324,15 @@

# Step 2: Apply guardrail to all texts in batch
if texts_to_check or tool_calls_to_check:
- # Create a request_data dict with response info and user API key metadata
- request_data: dict = {"response": response}
+ # Create a local_request_data dict with response info and user API key metadata
+ local_request_data: dict = {**(request_data or {}), "response": response}

# Add user API key metadata with prefixed keys
user_metadata = self.transform_user_api_key_dict_to_metadata(
user_api_key_dict
)
if user_metadata:
- request_data["litellm_metadata"] = user_metadata
+ local_request_data["litellm_metadata"] = user_metadata

inputs = GenericGuardrailAPIInputs(texts=texts_to_check)
if images_to_check:
@@ -349,7 +350,7 @@

guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
inputs=inputs,
- request_data=request_data,
+ request_data=local_request_data,
input_type="response",
logging_obj=litellm_logging_obj,
)
@@ -375,6 +376,7 @@ async def process_output_streaming_response(
guardrail_to_apply: "CustomGuardrail",
litellm_logging_obj: Optional[Any] = None,
user_api_key_dict: Optional[Any] = None,
+ request_data: Optional[dict] = None,
) -> List[Any]:
"""
Process output streaming response by applying guardrails to text content.
@@ -413,7 +415,7 @@

_guardrailed_inputs = await guardrail_to_apply.apply_guardrail( # allow rejecting the response, if invalid
inputs=guardrail_inputs,
- request_data={},
+ request_data=request_data or {},
input_type="response",
logging_obj=litellm_logging_obj,
)
@@ -426,7 +428,7 @@
string_so_far = self.get_streaming_string_so_far(responses_so_far)
_guardrailed_inputs = await guardrail_to_apply.apply_guardrail( # allow rejecting the response, if invalid
inputs={"texts": [string_so_far]},
- request_data={},
+ request_data=request_data or {},
input_type="response",
logging_obj=litellm_logging_obj,
)
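Note that the streaming paths above pass request_data or {} straight through rather than building a merged dict, since per-chunk checks have no assembled response to attach. A tiny sketch of the fallback (a hypothetical helper, not LiteLLM code):

from typing import Optional

def resolve_request_data(request_data: Optional[dict]) -> dict:
    # apply_guardrail expects a dict here; falling back to {} preserves that
    # contract when no request payload was threaded through by the caller.
    return request_data or {}

assert resolve_request_data(None) == {}
assert resolve_request_data({"messages": []}) == {"messages": []}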
@@ -73,6 +73,7 @@ async def process_output_response(
guardrail_to_apply: "CustomGuardrail",
litellm_logging_obj: Optional["LiteLLMLoggingObj"] = None,
user_api_key_dict: Optional["UserAPIKeyAuth"] = None,
+ request_data: Optional[dict] = None,
) -> Any:
"""
Process output response with guardrails.
@@ -82,6 +83,7 @@ async def process_output_response(
guardrail_to_apply: The guardrail instance to apply
litellm_logging_obj: Optional logging object
user_api_key_dict: User API key metadata (passed separately since response doesn't contain it)
+ request_data: Optional original request data dict from the proxy
"""
pass

@@ -91,11 +93,15 @@ async def process_output_streaming_response(
guardrail_to_apply: "CustomGuardrail",
litellm_logging_obj: Optional["LiteLLMLoggingObj"] = None,
user_api_key_dict: Optional["UserAPIKeyAuth"] = None,
+ request_data: Optional[dict] = None,
) -> Any:
"""
Process output streaming response with guardrails.

Optional to override in subclasses.
+
+ Args:
+ request_data: Optional original request data dict from the proxy
"""
return responses_so_far

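Because the new parameter defaults to None in the base class, the change is backward compatible: existing overrides and call sites that never mention request_data keep working. A self-contained sketch of the idea (ToyHandler and EchoGuardrail are stand-ins, and the leading parameters elided by the hunks above are assumed to include the response):

import asyncio
from typing import Any, Optional

class EchoGuardrail:
    async def apply_guardrail(self, inputs, request_data, input_type, logging_obj):
        # Stand-in guardrail: returns its inputs unchanged.
        return inputs

class ToyHandler:
    async def process_output_response(
        self,
        response: Any,
        guardrail_to_apply: Any,
        litellm_logging_obj: Optional[Any] = None,
        user_api_key_dict: Optional[Any] = None,
        request_data: Optional[dict] = None,  # new optional keyword from this PR
    ) -> Any:
        local_request_data = {**(request_data or {}), "response": response}
        return await guardrail_to_apply.apply_guardrail(
            inputs={"texts": [str(response)]},
            request_data=local_request_data,
            input_type="response",
            logging_obj=litellm_logging_obj,
        )

async def main() -> None:
    handler = ToyHandler()
    guardrail = EchoGuardrail()
    # Pre-existing call site: still valid, request_data simply defaults to None.
    await handler.process_output_response("hello", guardrail)
    # New call site: threads the proxy's original request body through.
    await handler.process_output_response(
        "hello", guardrail, request_data={"model": "gpt-4o", "messages": []}
    )

asyncio.run(main())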
@@ -83,6 +83,7 @@ async def process_output_response(
guardrail_to_apply: "CustomGuardrail",
litellm_logging_obj: Optional[Any] = None,
user_api_key_dict: Optional[Any] = None,
+ request_data: Optional[dict] = None,
) -> Any:
"""
Process output response - not applicable for rerank.
3 changes: 2 additions & 1 deletion litellm/llms/mistral/ocr/guardrail_translation/handler.py
@@ -91,6 +91,7 @@ async def process_output_response(
guardrail_to_apply: "CustomGuardrail",
litellm_logging_obj: Optional[Any] = None,
user_api_key_dict: Optional[Any] = None,
+ request_data: Optional[dict] = None,
) -> Any:
"""
Process OCR output by applying guardrails to extracted page text.
@@ -134,7 +135,7 @@

guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
inputs=inputs,
- request_data={},
+ request_data=request_data or {},
input_type="response",
logging_obj=litellm_logging_obj,
)
28 changes: 17 additions & 11 deletions litellm/llms/openai/chat/guardrail_translation/handler.py
@@ -86,9 +86,9 @@ async def process_input_messages(
if tool_calls_to_check:
inputs["tool_calls"] = tool_calls_to_check # type: ignore
if messages:
- inputs[
- "structured_messages"
- ] = messages  # pass the openai /chat/completions messages to the guardrail, as-is
+ inputs["structured_messages"] = (
+ messages  # pass the openai /chat/completions messages to the guardrail, as-is
+ )
# Pass tools (function definitions) to the guardrail
tools = data.get("tools")
if tools:
@@ -260,6 +260,7 @@ async def process_output_response(
guardrail_to_apply: "CustomGuardrail",
litellm_logging_obj: Optional[Any] = None,
user_api_key_dict: Optional[Any] = None,
+ request_data: Optional[dict] = None,
) -> Any:
"""
Process output response by applying guardrails to text content.
@@ -308,15 +309,15 @@

# Step 2: Apply guardrail to all texts and tool calls in batch
if texts_to_check or tool_calls_to_check:
- # Create a request_data dict with response info and user API key metadata
- request_data: dict = {"response": response}
+ # Create a local_request_data dict with response info and user API key metadata
+ local_request_data: dict = {**(request_data or {}), "response": response}

# Add user API key metadata with prefixed keys
user_metadata = self.transform_user_api_key_dict_to_metadata(
user_api_key_dict
)
if user_metadata:
- request_data["litellm_metadata"] = user_metadata
+ local_request_data["litellm_metadata"] = user_metadata

inputs = GenericGuardrailAPIInputs(texts=texts_to_check)
if images_to_check:
@@ -329,7 +330,7 @@

guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
inputs=inputs,
- request_data=request_data,
+ request_data=local_request_data,
input_type="response",
logging_obj=litellm_logging_obj,
)
@@ -364,6 +365,7 @@ async def process_output_streaming_response(
guardrail_to_apply: "CustomGuardrail",
litellm_logging_obj: Optional[Any] = None,
user_api_key_dict: Optional[Any] = None,
+ request_data: Optional[dict] = None,
) -> List["ModelResponseStream"]:
"""
Process output streaming responses by applying guardrails to text content.
@@ -402,6 +404,7 @@
guardrail_to_apply=guardrail_to_apply,
litellm_logging_obj=litellm_logging_obj,
user_api_key_dict=user_api_key_dict,
+ request_data=request_data,
)

return responses_so_far
@@ -436,15 +439,18 @@

# Step 3: Apply guardrail to all combined texts in batch
if texts_to_check:
- # Create a request_data dict with response info and user API key metadata
- request_data: dict = {"responses": responses_so_far}
+ # Create a local_request_data dict with response info and user API key metadata
+ local_request_data: dict = {
+ **(request_data or {}),
+ "responses": responses_so_far,
+ }

# Add user API key metadata with prefixed keys
user_metadata = self.transform_user_api_key_dict_to_metadata(
user_api_key_dict
)
if user_metadata:
- request_data["litellm_metadata"] = user_metadata
+ local_request_data["litellm_metadata"] = user_metadata

inputs = GenericGuardrailAPIInputs(texts=texts_to_check)
if images_to_check:
@@ -458,7 +464,7 @@
inputs["model"] = responses_so_far[0].model
guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
inputs=inputs,
- request_data=request_data,
+ request_data=local_request_data,
input_type="response",
logging_obj=litellm_logging_obj,
)
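The streaming handler above has two paths: when the guardrail supports per-chunk checks it forwards the same request_data to each chunk handler (the request_data=request_data line), otherwise it accumulates the text and checks once with a merged local_request_data. A rough sketch of that control flow (guard_stream and noop_guardrail are illustrative, not LiteLLM internals):

import asyncio
from typing import Awaitable, Callable, List, Optional

async def guard_stream(
    chunks: List[str],
    apply_guardrail: Callable[[List[str], dict], Awaitable[None]],
    request_data: Optional[dict] = None,
    supports_per_chunk: bool = False,
) -> List[str]:
    if supports_per_chunk:
        # Per-chunk path: every check receives the same forwarded request_data.
        for chunk in chunks:
            await apply_guardrail([chunk], request_data or {})
        return chunks
    # Accumulated path: check the combined text once, merging the original
    # request payload with the collected chunks (cf. local_request_data above).
    combined = "".join(chunks)
    local_request_data = {**(request_data or {}), "responses": chunks}
    await apply_guardrail([combined], local_request_data)
    return chunks

async def main() -> None:
    async def noop_guardrail(texts: List[str], request_data: dict) -> None:
        pass  # permissive stand-in that never rejects

    await guard_stream(["Hel", "lo"], noop_guardrail, {"model": "gpt-4o"})

asyncio.run(main())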
@@ -125,6 +125,7 @@ async def process_output_response(
guardrail_to_apply: "CustomGuardrail",
litellm_logging_obj: Optional[Any] = None,
user_api_key_dict: Optional[Any] = None,
+ request_data: Optional[dict] = None,
) -> Any:
"""
Process output response by applying guardrails to completion text.
@@ -155,23 +156,23 @@

# Apply guardrails in batch
if texts_to_check:
- # Create a request_data dict with response info and user API key metadata
- request_data: dict = {"response": response}
+ # Create a local_request_data dict with response info and user API key metadata
+ local_request_data: dict = {**(request_data or {}), "response": response}

# Add user API key metadata with prefixed keys
user_metadata = self.transform_user_api_key_dict_to_metadata(
user_api_key_dict
)
if user_metadata:
- request_data["litellm_metadata"] = user_metadata
+ local_request_data["litellm_metadata"] = user_metadata

inputs = GenericGuardrailAPIInputs(texts=texts_to_check)
# Include model information from the response if available
if hasattr(response, "model") and response.model:
inputs["model"] = response.model
guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
inputs=inputs,
- request_data=request_data,
+ request_data=local_request_data,
input_type="response",
logging_obj=litellm_logging_obj,
)
@@ -155,6 +155,7 @@ async def process_output_response(
guardrail_to_apply: "CustomGuardrail",
litellm_logging_obj: Optional[Any] = None,
user_api_key_dict: Optional[Any] = None,
+ request_data: Optional[dict] = None,
) -> Any:
"""
Process output response - embeddings responses contain vectors, not text.
@@ -87,6 +87,7 @@ async def process_output_response(
guardrail_to_apply: "CustomGuardrail",
litellm_logging_obj: Optional[Any] = None,
user_api_key_dict: Optional[Any] = None,
+ request_data: Optional[dict] = None,
) -> Any:
"""
Process output response - typically not needed for image generation.
12 changes: 7 additions & 5 deletions litellm/llms/openai/responses/guardrail_translation/handler.py
@@ -347,6 +347,7 @@ async def process_output_response(
guardrail_to_apply: "CustomGuardrail",
litellm_logging_obj: Optional[Any] = None,
user_api_key_dict: Optional[Any] = None,
+ request_data: Optional[dict] = None,
) -> Any:
"""
Process output response by applying guardrails to text content and tool calls.
@@ -402,15 +403,15 @@

# Step 2: Apply guardrail to all texts in batch
if texts_to_check or tool_calls_to_check:
- # Create a request_data dict with response info and user API key metadata
- request_data: dict = {"response": response}
+ # Create a local_request_data dict with response info and user API key metadata
+ local_request_data: dict = {**(request_data or {}), "response": response}

# Add user API key metadata with prefixed keys
user_metadata = self.transform_user_api_key_dict_to_metadata(
user_api_key_dict
)
if user_metadata:
- request_data["litellm_metadata"] = user_metadata
+ local_request_data["litellm_metadata"] = user_metadata

inputs = GenericGuardrailAPIInputs(texts=texts_to_check)
if images_to_check:
@@ -428,7 +429,7 @@

guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
inputs=inputs,
- request_data=request_data,
+ request_data=local_request_data,
input_type="response",
logging_obj=litellm_logging_obj,
)
@@ -454,6 +455,7 @@ async def process_output_streaming_response(
guardrail_to_apply: "CustomGuardrail",
litellm_logging_obj: Optional[Any] = None,
user_api_key_dict: Optional[Any] = None,
+ request_data: Optional[dict] = None,
) -> List[Any]:
"""
Process output streaming response by applying guardrails to text content.
@@ -481,7 +483,7 @@
inputs["model"] = model_response_stream.model
_guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
inputs=inputs,
- request_data={},
+ request_data=request_data or {},
input_type="response",
logging_obj=litellm_logging_obj,
)