From 34fb9010258e1c2259404f646e862821ad2cf000 Mon Sep 17 00:00:00 2001 From: Audrey Kadjar Date: Sat, 21 Mar 2026 16:39:38 +0100 Subject: [PATCH 01/13] fix: handle Pydantic MockValSer bug in streaming responses (#18801) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Problem TypeError: 'MockValSer' object cannot be converted to 'SchemaSerializer' when handling streaming responses with SAP AI Core and other providers. Pydantic 2.11+ has a bug where the internal MockValSer sentinel is not properly converted to a real SchemaSerializer in certain streaming scenarios. When LiteLLM tries to serialize chunks using model_dump(), it hits this corrupted serializer state. ## Solution Added try-catch fallback that uses __dict__ extraction when model_dump() fails with TypeError. This bypasses Pydantic's serialization entirely while maintaining functionality. ## Changes - litellm/litellm_core_utils/streaming_handler.py: Added fallback in 2 locations - litellm/litellm_core_utils/core_helpers.py: Added fallback in preserve_upstream_non_openai_attributes - tests/test_litellm/litellm_core_utils/test_streaming_handler.py: Added regression test ## Testing ✅ All 49 streaming handler tests pass ✅ Regression test verifies fallback behavior Related: https://github.com/BerriAI/litellm/issues/18801 Related: https://github.com/pydantic/pydantic/issues/7713 --- litellm/litellm_core_utils/core_helpers.py | 7 ++- .../litellm_core_utils/streaming_handler.py | 12 ++++- .../test_streaming_handler.py | 49 +++++++++++++++++++ 3 files changed, 65 insertions(+), 3 deletions(-) diff --git a/litellm/litellm_core_utils/core_helpers.py b/litellm/litellm_core_utils/core_helpers.py index 256b16ff312..2314f8bcb04 100644 --- a/litellm/litellm_core_utils/core_helpers.py +++ b/litellm/litellm_core_utils/core_helpers.py @@ -270,7 +270,12 @@ def preserve_upstream_non_openai_attributes( """ # Access model_fields on the class, not the instance, to avoid Pydantic 2.11+ deprecation warnings expected_keys = set(type(model_response).model_fields.keys()).union({"usage"}) - for key, value in original_chunk.model_dump().items(): + try: + obj_dict = original_chunk.model_dump() + except TypeError: + # Fallback for Pydantic MockValSer bug (issue #18801) + obj_dict = dict(original_chunk.__dict__) if hasattr(original_chunk, '__dict__') else {} + for key, value in obj_dict.items(): if key not in expected_keys: setattr(model_response, key, value) diff --git a/litellm/litellm_core_utils/streaming_handler.py b/litellm/litellm_core_utils/streaming_handler.py index 96e70845b28..210ed8eec37 100644 --- a/litellm/litellm_core_utils/streaming_handler.py +++ b/litellm/litellm_core_utils/streaming_handler.py @@ -1859,7 +1859,11 @@ def __next__(self) -> "ModelResponseStream": # noqa: PLR0915 response, "usage" ): # remove usage from chunk, only send on final chunk # Convert the object to a dictionary - obj_dict = response.model_dump() + try: + obj_dict = response.model_dump() + except TypeError as e: + # Fallback: manually extract dict from __dict__ to bypass Pydantic serializer + obj_dict = dict(response.__dict__) if hasattr(response, '__dict__') else {} # Remove an attribute (e.g., 'attr2') if "usage" in obj_dict: @@ -2047,7 +2051,11 @@ async def __anext__(self) -> "ModelResponseStream": # noqa: PLR0915 # Strip usage from the outgoing chunk so it's not sent twice # (once in the chunk, once in _hidden_params). - obj_dict = processed_chunk.model_dump() + try: + obj_dict = processed_chunk.model_dump() + except TypeError as e: + # Fallback: manually extract dict from __dict__ to bypass Pydantic serializer + obj_dict = dict(processed_chunk.__dict__) if hasattr(processed_chunk, '__dict__') else {} if "usage" in obj_dict: del obj_dict["usage"] processed_chunk = self.model_response_creator( diff --git a/tests/test_litellm/litellm_core_utils/test_streaming_handler.py b/tests/test_litellm/litellm_core_utils/test_streaming_handler.py index e0862629947..2ef60ac84a7 100644 --- a/tests/test_litellm/litellm_core_utils/test_streaming_handler.py +++ b/tests/test_litellm/litellm_core_utils/test_streaming_handler.py @@ -1679,3 +1679,52 @@ def test_tool_use_not_dropped_when_finish_reason_already_set( ) assert tool_calls[0].id == "call_1" assert tool_calls[0].function.name == "get_weather" + + +def test_model_dump_fallback_handles_pydantic_serializer_bug( + initialized_custom_stream_wrapper: CustomStreamWrapper, +): + """ + Regression test for #18801: MockValSer TypeError in streaming responses. + + Pydantic 2.11+ has a bug where MockValSer sentinel is not converted to + SchemaSerializer in certain scenarios. The fix catches TypeError and falls + back to __dict__ extraction. + """ + # Create a chunk with usage that will be stripped + chunk_with_usage = ModelResponseStream( + id="test-chunk", + created=1742056047, + model="sap-ai-core/test-model", + object="chat.completion.chunk", + choices=[ + StreamingChoices( + finish_reason=None, + index=0, + delta=Delta(content="test content", role="assistant"), + ) + ], + usage=Usage(prompt_tokens=10, completion_tokens=5, total_tokens=15), + ) + + # Mock model_dump to raise TypeError (simulating MockValSer bug) + original_model_dump = chunk_with_usage.model_dump + + def mock_model_dump(*args, **kwargs): + raise TypeError("'MockValSer' object cannot be converted to 'SchemaSerializer'") + + chunk_with_usage.model_dump = mock_model_dump + + # The code should gracefully fall back to __dict__ and not crash + initialized_custom_stream_wrapper.chunks.append(chunk_with_usage) + + # Process the chunk through return_processed_chunk_logic which calls model_dump + result = initialized_custom_stream_wrapper.return_processed_chunk_logic( + completion_obj={"content": "test content"}, + response_obj={"original_chunk": chunk_with_usage}, + model_response=chunk_with_usage, + ) + + # Should not raise TypeError and should successfully process the chunk + assert result is not None + assert result.choices[0].delta.content == "test content" From 14fdbf8054ceb3519b99247c575f5aa6649ae2c9 Mon Sep 17 00:00:00 2001 From: Audrey Kadjar Date: Sun, 22 Mar 2026 12:11:32 +0100 Subject: [PATCH 02/13] Address PR review comments - narrow TypeError handling - Check for 'MockValSer' in error message before applying fallback - Re-raise non-MockValSer TypeErrors to avoid masking real bugs - Add try-finally block in test for proper cleanup - Addresses review feedback from greptile-apps bot --- litellm/litellm_core_utils/core_helpers.py | 4 ++- .../litellm_core_utils/streaming_handler.py | 8 ++++-- .../test_streaming_handler.py | 28 +++++++++++-------- 3 files changed, 25 insertions(+), 15 deletions(-) diff --git a/litellm/litellm_core_utils/core_helpers.py b/litellm/litellm_core_utils/core_helpers.py index 2314f8bcb04..bf92fad415e 100644 --- a/litellm/litellm_core_utils/core_helpers.py +++ b/litellm/litellm_core_utils/core_helpers.py @@ -272,7 +272,9 @@ def preserve_upstream_non_openai_attributes( expected_keys = set(type(model_response).model_fields.keys()).union({"usage"}) try: obj_dict = original_chunk.model_dump() - except TypeError: + except TypeError as e: + if "MockValSer" not in str(e): + raise # Fallback for Pydantic MockValSer bug (issue #18801) obj_dict = dict(original_chunk.__dict__) if hasattr(original_chunk, '__dict__') else {} for key, value in obj_dict.items(): diff --git a/litellm/litellm_core_utils/streaming_handler.py b/litellm/litellm_core_utils/streaming_handler.py index 210ed8eec37..a8a4a593b59 100644 --- a/litellm/litellm_core_utils/streaming_handler.py +++ b/litellm/litellm_core_utils/streaming_handler.py @@ -1862,7 +1862,9 @@ def __next__(self) -> "ModelResponseStream": # noqa: PLR0915 try: obj_dict = response.model_dump() except TypeError as e: - # Fallback: manually extract dict from __dict__ to bypass Pydantic serializer + if "MockValSer" not in str(e): + raise + # Fallback for Pydantic MockValSer bug (issue #18801) obj_dict = dict(response.__dict__) if hasattr(response, '__dict__') else {} # Remove an attribute (e.g., 'attr2') @@ -2054,7 +2056,9 @@ async def __anext__(self) -> "ModelResponseStream": # noqa: PLR0915 try: obj_dict = processed_chunk.model_dump() except TypeError as e: - # Fallback: manually extract dict from __dict__ to bypass Pydantic serializer + if "MockValSer" not in str(e): + raise + # Fallback for Pydantic MockValSer bug (issue #18801) obj_dict = dict(processed_chunk.__dict__) if hasattr(processed_chunk, '__dict__') else {} if "usage" in obj_dict: del obj_dict["usage"] diff --git a/tests/test_litellm/litellm_core_utils/test_streaming_handler.py b/tests/test_litellm/litellm_core_utils/test_streaming_handler.py index 2ef60ac84a7..5d11ada96aa 100644 --- a/tests/test_litellm/litellm_core_utils/test_streaming_handler.py +++ b/tests/test_litellm/litellm_core_utils/test_streaming_handler.py @@ -1715,16 +1715,20 @@ def mock_model_dump(*args, **kwargs): chunk_with_usage.model_dump = mock_model_dump - # The code should gracefully fall back to __dict__ and not crash - initialized_custom_stream_wrapper.chunks.append(chunk_with_usage) - - # Process the chunk through return_processed_chunk_logic which calls model_dump - result = initialized_custom_stream_wrapper.return_processed_chunk_logic( - completion_obj={"content": "test content"}, - response_obj={"original_chunk": chunk_with_usage}, - model_response=chunk_with_usage, - ) + try: + # The code should gracefully fall back to __dict__ and not crash + initialized_custom_stream_wrapper.chunks.append(chunk_with_usage) + + # Process the chunk through return_processed_chunk_logic which calls model_dump + result = initialized_custom_stream_wrapper.return_processed_chunk_logic( + completion_obj={"content": "test content"}, + response_obj={"original_chunk": chunk_with_usage}, + model_response=chunk_with_usage, + ) - # Should not raise TypeError and should successfully process the chunk - assert result is not None - assert result.choices[0].delta.content == "test content" + # Should not raise TypeError and should successfully process the chunk + assert result is not None + assert result.choices[0].delta.content == "test content" + finally: + # Restore original method + chunk_with_usage.model_dump = original_model_dump From e03f86e62593d717853c5a1f4866a85328084d54 Mon Sep 17 00:00:00 2001 From: Audrey Kadjar Date: Sun, 22 Mar 2026 12:32:46 +0100 Subject: [PATCH 03/13] Add logging and preserve __pydantic_extra__ in fallback - Add warning logs when MockValSer fallback is triggered for observability - Merge __pydantic_extra__ with __dict__ to preserve dynamic provider fields - Update test to verify extra attributes survive the fallback - Addresses all P1 and P2 review feedback from greptile-apps bot --- litellm/litellm_core_utils/core_helpers.py | 11 +++++++++- .../litellm_core_utils/streaming_handler.py | 22 +++++++++++++++++-- .../test_streaming_handler.py | 6 +++++ 3 files changed, 36 insertions(+), 3 deletions(-) diff --git a/litellm/litellm_core_utils/core_helpers.py b/litellm/litellm_core_utils/core_helpers.py index bf92fad415e..3f9c6a9c594 100644 --- a/litellm/litellm_core_utils/core_helpers.py +++ b/litellm/litellm_core_utils/core_helpers.py @@ -276,7 +276,16 @@ def preserve_upstream_non_openai_attributes( if "MockValSer" not in str(e): raise # Fallback for Pydantic MockValSer bug (issue #18801) - obj_dict = dict(original_chunk.__dict__) if hasattr(original_chunk, '__dict__') else {} + import logging + logging.getLogger("LiteLLM").warning( + "Pydantic MockValSer bug detected (issue #18801); falling back to __dict__ extraction. " + "Upgrade/downgrade pydantic if this persists. Error: %s", e + ) + # Merge __dict__ with __pydantic_extra__ to preserve dynamically-added provider fields + obj_dict = { + **dict(original_chunk.__dict__), + **(getattr(original_chunk, '__pydantic_extra__', None) or {}), + } for key, value in obj_dict.items(): if key not in expected_keys: setattr(model_response, key, value) diff --git a/litellm/litellm_core_utils/streaming_handler.py b/litellm/litellm_core_utils/streaming_handler.py index a8a4a593b59..848f52355b0 100644 --- a/litellm/litellm_core_utils/streaming_handler.py +++ b/litellm/litellm_core_utils/streaming_handler.py @@ -1865,7 +1865,16 @@ def __next__(self) -> "ModelResponseStream": # noqa: PLR0915 if "MockValSer" not in str(e): raise # Fallback for Pydantic MockValSer bug (issue #18801) - obj_dict = dict(response.__dict__) if hasattr(response, '__dict__') else {} + import logging + logging.getLogger("LiteLLM").warning( + "Pydantic MockValSer bug detected (issue #18801); falling back to __dict__ extraction. " + "Upgrade/downgrade pydantic if this persists. Error: %s", e + ) + # Merge __dict__ with __pydantic_extra__ to preserve dynamically-added provider fields + obj_dict = { + **dict(response.__dict__), + **(getattr(response, '__pydantic_extra__', None) or {}), + } # Remove an attribute (e.g., 'attr2') if "usage" in obj_dict: @@ -2059,7 +2068,16 @@ async def __anext__(self) -> "ModelResponseStream": # noqa: PLR0915 if "MockValSer" not in str(e): raise # Fallback for Pydantic MockValSer bug (issue #18801) - obj_dict = dict(processed_chunk.__dict__) if hasattr(processed_chunk, '__dict__') else {} + import logging + logging.getLogger("LiteLLM").warning( + "Pydantic MockValSer bug detected (issue #18801); falling back to __dict__ extraction. " + "Upgrade/downgrade pydantic if this persists. Error: %s", e + ) + # Merge __dict__ with __pydantic_extra__ to preserve dynamically-added provider fields + obj_dict = { + **dict(processed_chunk.__dict__), + **(getattr(processed_chunk, '__pydantic_extra__', None) or {}), + } if "usage" in obj_dict: del obj_dict["usage"] processed_chunk = self.model_response_creator( diff --git a/tests/test_litellm/litellm_core_utils/test_streaming_handler.py b/tests/test_litellm/litellm_core_utils/test_streaming_handler.py index 5d11ada96aa..3c47e45be10 100644 --- a/tests/test_litellm/litellm_core_utils/test_streaming_handler.py +++ b/tests/test_litellm/litellm_core_utils/test_streaming_handler.py @@ -1707,6 +1707,9 @@ def test_model_dump_fallback_handles_pydantic_serializer_bug( usage=Usage(prompt_tokens=10, completion_tokens=5, total_tokens=15), ) + # Add a provider-specific field to test __pydantic_extra__ preservation + chunk_with_usage.sap_extra_field = "sap-value" + # Mock model_dump to raise TypeError (simulating MockValSer bug) original_model_dump = chunk_with_usage.model_dump @@ -1729,6 +1732,9 @@ def mock_model_dump(*args, **kwargs): # Should not raise TypeError and should successfully process the chunk assert result is not None assert result.choices[0].delta.content == "test content" + + # Verify that extra/provider attributes survive the fallback + assert getattr(result, "sap_extra_field", None) == "sap-value" finally: # Restore original method chunk_with_usage.model_dump = original_model_dump From 168a41c1cb6ee3c3f4e46e97a873574f0a117150 Mon Sep 17 00:00:00 2001 From: Audrey Kadjar Date: Sun, 22 Mar 2026 19:19:02 +0100 Subject: [PATCH 04/13] Move logging imports to top and improve warning message - Add logging import at module level in core_helpers.py - Remove inline 'import logging' statements from exception handlers - Update warning message to only suggest upgrading pydantic - Reference pydantic issue #7713 instead of litellm issue #18801 - Addresses style feedback from greptile-apps bot --- litellm/litellm_core_utils/core_helpers.py | 8 ++++---- litellm/litellm_core_utils/streaming_handler.py | 14 ++++++-------- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/litellm/litellm_core_utils/core_helpers.py b/litellm/litellm_core_utils/core_helpers.py index 3f9c6a9c594..fff11a42682 100644 --- a/litellm/litellm_core_utils/core_helpers.py +++ b/litellm/litellm_core_utils/core_helpers.py @@ -1,5 +1,6 @@ # What is this? ## Helper utilities +import logging from typing import TYPE_CHECKING, Any, Iterable, List, Literal, Optional, Union import httpx @@ -275,11 +276,10 @@ def preserve_upstream_non_openai_attributes( except TypeError as e: if "MockValSer" not in str(e): raise - # Fallback for Pydantic MockValSer bug (issue #18801) - import logging + # Fallback for Pydantic MockValSer bug (pydantic issue #7713) logging.getLogger("LiteLLM").warning( - "Pydantic MockValSer bug detected (issue #18801); falling back to __dict__ extraction. " - "Upgrade/downgrade pydantic if this persists. Error: %s", e + "Pydantic MockValSer bug detected (pydantic issue #7713); falling back to __dict__ extraction. " + "Upgrading pydantic may resolve this. Error: %s", e ) # Merge __dict__ with __pydantic_extra__ to preserve dynamically-added provider fields obj_dict = { diff --git a/litellm/litellm_core_utils/streaming_handler.py b/litellm/litellm_core_utils/streaming_handler.py index 848f52355b0..1832926a651 100644 --- a/litellm/litellm_core_utils/streaming_handler.py +++ b/litellm/litellm_core_utils/streaming_handler.py @@ -1864,11 +1864,10 @@ def __next__(self) -> "ModelResponseStream": # noqa: PLR0915 except TypeError as e: if "MockValSer" not in str(e): raise - # Fallback for Pydantic MockValSer bug (issue #18801) - import logging + # Fallback for Pydantic MockValSer bug (pydantic issue #7713) logging.getLogger("LiteLLM").warning( - "Pydantic MockValSer bug detected (issue #18801); falling back to __dict__ extraction. " - "Upgrade/downgrade pydantic if this persists. Error: %s", e + "Pydantic MockValSer bug detected (pydantic issue #7713); falling back to __dict__ extraction. " + "Upgrading pydantic may resolve this. Error: %s", e ) # Merge __dict__ with __pydantic_extra__ to preserve dynamically-added provider fields obj_dict = { @@ -2067,11 +2066,10 @@ async def __anext__(self) -> "ModelResponseStream": # noqa: PLR0915 except TypeError as e: if "MockValSer" not in str(e): raise - # Fallback for Pydantic MockValSer bug (issue #18801) - import logging + # Fallback for Pydantic MockValSer bug (pydantic issue #7713) logging.getLogger("LiteLLM").warning( - "Pydantic MockValSer bug detected (issue #18801); falling back to __dict__ extraction. " - "Upgrade/downgrade pydantic if this persists. Error: %s", e + "Pydantic MockValSer bug detected (pydantic issue #7713); falling back to __dict__ extraction. " + "Upgrading pydantic may resolve this. Error: %s", e ) # Merge __dict__ with __pydantic_extra__ to preserve dynamically-added provider fields obj_dict = { From db995f4618a63fbd22ec4e87d0970fef8451a466 Mon Sep 17 00:00:00 2001 From: Audrey Kadjar Date: Sun, 22 Mar 2026 19:51:32 +0100 Subject: [PATCH 05/13] Use verbose_logger and fix test to verify attribute copying - Replace logging.getLogger('LiteLLM') with verbose_logger for consistency - Remove logging import from core_helpers.py (no longer needed) - Fix test to use distinct objects for model_response and original_chunk - Now properly verifies __pydantic_extra__ attributes are COPIED not just preserved - Addresses P1 and P2 feedback from greptile-apps bot --- litellm/litellm_core_utils/core_helpers.py | 3 +-- .../litellm_core_utils/streaming_handler.py | 4 ++-- .../test_streaming_handler.py | 21 ++++++++++++++++--- 3 files changed, 21 insertions(+), 7 deletions(-) diff --git a/litellm/litellm_core_utils/core_helpers.py b/litellm/litellm_core_utils/core_helpers.py index fff11a42682..8d19a55bc8e 100644 --- a/litellm/litellm_core_utils/core_helpers.py +++ b/litellm/litellm_core_utils/core_helpers.py @@ -1,6 +1,5 @@ # What is this? ## Helper utilities -import logging from typing import TYPE_CHECKING, Any, Iterable, List, Literal, Optional, Union import httpx @@ -277,7 +276,7 @@ def preserve_upstream_non_openai_attributes( if "MockValSer" not in str(e): raise # Fallback for Pydantic MockValSer bug (pydantic issue #7713) - logging.getLogger("LiteLLM").warning( + verbose_logger.warning( "Pydantic MockValSer bug detected (pydantic issue #7713); falling back to __dict__ extraction. " "Upgrading pydantic may resolve this. Error: %s", e ) diff --git a/litellm/litellm_core_utils/streaming_handler.py b/litellm/litellm_core_utils/streaming_handler.py index 1832926a651..716b1cca3eb 100644 --- a/litellm/litellm_core_utils/streaming_handler.py +++ b/litellm/litellm_core_utils/streaming_handler.py @@ -1865,7 +1865,7 @@ def __next__(self) -> "ModelResponseStream": # noqa: PLR0915 if "MockValSer" not in str(e): raise # Fallback for Pydantic MockValSer bug (pydantic issue #7713) - logging.getLogger("LiteLLM").warning( + verbose_logger.warning( "Pydantic MockValSer bug detected (pydantic issue #7713); falling back to __dict__ extraction. " "Upgrading pydantic may resolve this. Error: %s", e ) @@ -2067,7 +2067,7 @@ async def __anext__(self) -> "ModelResponseStream": # noqa: PLR0915 if "MockValSer" not in str(e): raise # Fallback for Pydantic MockValSer bug (pydantic issue #7713) - logging.getLogger("LiteLLM").warning( + verbose_logger.warning( "Pydantic MockValSer bug detected (pydantic issue #7713); falling back to __dict__ extraction. " "Upgrading pydantic may resolve this. Error: %s", e ) diff --git a/tests/test_litellm/litellm_core_utils/test_streaming_handler.py b/tests/test_litellm/litellm_core_utils/test_streaming_handler.py index 3c47e45be10..9655d37da5c 100644 --- a/tests/test_litellm/litellm_core_utils/test_streaming_handler.py +++ b/tests/test_litellm/litellm_core_utils/test_streaming_handler.py @@ -1722,18 +1722,33 @@ def mock_model_dump(*args, **kwargs): # The code should gracefully fall back to __dict__ and not crash initialized_custom_stream_wrapper.chunks.append(chunk_with_usage) + # Use a DIFFERENT object as the target model_response + fresh_model_response = ModelResponseStream( + id="fresh", + created=1742056047, + model="sap-ai-core/test-model", + object="chat.completion.chunk", + choices=[ + StreamingChoices( + finish_reason=None, + index=0, + delta=Delta(content="test content", role="assistant"), + ) + ], + ) + # Process the chunk through return_processed_chunk_logic which calls model_dump result = initialized_custom_stream_wrapper.return_processed_chunk_logic( completion_obj={"content": "test content"}, - response_obj={"original_chunk": chunk_with_usage}, - model_response=chunk_with_usage, + response_obj={"original_chunk": chunk_with_usage}, # source of extra attrs + model_response=fresh_model_response, # target (different object) ) # Should not raise TypeError and should successfully process the chunk assert result is not None assert result.choices[0].delta.content == "test content" - # Verify that extra/provider attributes survive the fallback + # Now this assertion is meaningful: it verifies the attribute was actually COPIED assert getattr(result, "sap_extra_field", None) == "sap-value" finally: # Restore original method From aa0e3561d5adf5d6258e5d278e93e4e5c315f43e Mon Sep 17 00:00:00 2001 From: Audrey Kadjar Date: Sun, 22 Mar 2026 20:06:54 +0100 Subject: [PATCH 06/13] Filter private attributes and fix test mock pollution - Filter underscore-prefixed keys from fallback dict to match model_dump() behavior - Prevents leaking private attributes like _hidden_params to target models - Use unittest.mock.patch.object for class-level mocking in test - Prevents mock from appearing in __dict__ or __pydantic_extra__ - Add assertion that result.model_dump() still works after processing - Addresses P1 feedback from greptile-apps bot --- litellm/litellm_core_utils/core_helpers.py | 9 ++++++-- .../litellm_core_utils/streaming_handler.py | 18 ++++++++++++---- .../test_streaming_handler.py | 21 ++++++++++--------- 3 files changed, 32 insertions(+), 16 deletions(-) diff --git a/litellm/litellm_core_utils/core_helpers.py b/litellm/litellm_core_utils/core_helpers.py index 8d19a55bc8e..e85679f3890 100644 --- a/litellm/litellm_core_utils/core_helpers.py +++ b/litellm/litellm_core_utils/core_helpers.py @@ -281,9 +281,14 @@ def preserve_upstream_non_openai_attributes( "Upgrading pydantic may resolve this. Error: %s", e ) # Merge __dict__ with __pydantic_extra__ to preserve dynamically-added provider fields + # Filter out underscore-prefixed private attributes to match model_dump() behavior obj_dict = { - **dict(original_chunk.__dict__), - **(getattr(original_chunk, '__pydantic_extra__', None) or {}), + k: v + for k, v in { + **dict(original_chunk.__dict__), + **(getattr(original_chunk, '__pydantic_extra__', None) or {}), + }.items() + if not k.startswith('_') } for key, value in obj_dict.items(): if key not in expected_keys: diff --git a/litellm/litellm_core_utils/streaming_handler.py b/litellm/litellm_core_utils/streaming_handler.py index 716b1cca3eb..be696b89212 100644 --- a/litellm/litellm_core_utils/streaming_handler.py +++ b/litellm/litellm_core_utils/streaming_handler.py @@ -1870,9 +1870,14 @@ def __next__(self) -> "ModelResponseStream": # noqa: PLR0915 "Upgrading pydantic may resolve this. Error: %s", e ) # Merge __dict__ with __pydantic_extra__ to preserve dynamically-added provider fields + # Filter out underscore-prefixed private attributes to match model_dump() behavior obj_dict = { - **dict(response.__dict__), - **(getattr(response, '__pydantic_extra__', None) or {}), + k: v + for k, v in { + **dict(response.__dict__), + **(getattr(response, '__pydantic_extra__', None) or {}), + }.items() + if not k.startswith('_') } # Remove an attribute (e.g., 'attr2') @@ -2072,9 +2077,14 @@ async def __anext__(self) -> "ModelResponseStream": # noqa: PLR0915 "Upgrading pydantic may resolve this. Error: %s", e ) # Merge __dict__ with __pydantic_extra__ to preserve dynamically-added provider fields + # Filter out underscore-prefixed private attributes to match model_dump() behavior obj_dict = { - **dict(processed_chunk.__dict__), - **(getattr(processed_chunk, '__pydantic_extra__', None) or {}), + k: v + for k, v in { + **dict(processed_chunk.__dict__), + **(getattr(processed_chunk, '__pydantic_extra__', None) or {}), + }.items() + if not k.startswith('_') } if "usage" in obj_dict: del obj_dict["usage"] diff --git a/tests/test_litellm/litellm_core_utils/test_streaming_handler.py b/tests/test_litellm/litellm_core_utils/test_streaming_handler.py index 9655d37da5c..a6ccaedf662 100644 --- a/tests/test_litellm/litellm_core_utils/test_streaming_handler.py +++ b/tests/test_litellm/litellm_core_utils/test_streaming_handler.py @@ -1691,6 +1691,8 @@ def test_model_dump_fallback_handles_pydantic_serializer_bug( SchemaSerializer in certain scenarios. The fix catches TypeError and falls back to __dict__ extraction. """ + from unittest.mock import patch + # Create a chunk with usage that will be stripped chunk_with_usage = ModelResponseStream( id="test-chunk", @@ -1710,15 +1712,12 @@ def test_model_dump_fallback_handles_pydantic_serializer_bug( # Add a provider-specific field to test __pydantic_extra__ preservation chunk_with_usage.sap_extra_field = "sap-value" - # Mock model_dump to raise TypeError (simulating MockValSer bug) - original_model_dump = chunk_with_usage.model_dump - - def mock_model_dump(*args, **kwargs): + # Mock model_dump at class level to avoid polluting instance attributes + def mock_model_dump(self, *args, **kwargs): raise TypeError("'MockValSer' object cannot be converted to 'SchemaSerializer'") - chunk_with_usage.model_dump = mock_model_dump - - try: + # Use class-level patching to avoid the mock appearing in __dict__ or __pydantic_extra__ + with patch.object(type(chunk_with_usage), 'model_dump', mock_model_dump): # The code should gracefully fall back to __dict__ and not crash initialized_custom_stream_wrapper.chunks.append(chunk_with_usage) @@ -1750,6 +1749,8 @@ def mock_model_dump(*args, **kwargs): # Now this assertion is meaningful: it verifies the attribute was actually COPIED assert getattr(result, "sap_extra_field", None) == "sap-value" - finally: - # Restore original method - chunk_with_usage.model_dump = original_model_dump + + # Verify that result can still be serialized (model_dump method is not corrupted) + # This call is outside the patch context so it should work normally + result_dict = result.model_dump() + assert isinstance(result_dict, dict) From a7dbef1c522b72b49e397b9c933ad540cdb034f3 Mon Sep 17 00:00:00 2001 From: Audrey Kadjar Date: Mon, 23 Mar 2026 08:45:40 +0100 Subject: [PATCH 07/13] Strengthen serialization test to verify field preservation - Add assertion that sap_extra_field appears in model_dump() output - Verifies __pydantic_extra__ fields survive serialization - Guards against silent field dropping in model_dump() - Addresses P2 feedback from greptile-apps bot --- tests/test_litellm/litellm_core_utils/test_streaming_handler.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_litellm/litellm_core_utils/test_streaming_handler.py b/tests/test_litellm/litellm_core_utils/test_streaming_handler.py index a6ccaedf662..de41efd9786 100644 --- a/tests/test_litellm/litellm_core_utils/test_streaming_handler.py +++ b/tests/test_litellm/litellm_core_utils/test_streaming_handler.py @@ -1754,3 +1754,5 @@ def mock_model_dump(self, *args, **kwargs): # This call is outside the patch context so it should work normally result_dict = result.model_dump() assert isinstance(result_dict, dict) + # Extra provider field should survive into the serialized output + assert result_dict.get("sap_extra_field") == "sap-value" From 7dffee2ab8e55a8b0d81819ca9b2f11d3e5de356 Mon Sep 17 00:00:00 2001 From: Audrey Kadjar Date: Mon, 23 Mar 2026 09:09:25 +0100 Subject: [PATCH 08/13] Extract safe_model_dump helper and remove dead test code - Create safe_model_dump() helper in core_helpers.py to eliminate duplication - Replace ~20-line try-except blocks in 3 locations with single function call - Remove dead chunks.append() from test (doesn't affect test outcome) - Future fallback fixes now only need to be made in one place - Addresses P2 feedback from greptile-apps bot Changes: - Added safe_model_dump() to core_helpers.py - Updated streaming_handler.py to import and use safe_model_dump - Simplified preserve_upstream_non_openai_attributes to use helper - Removed misleading chunks.append from test --- litellm/litellm_core_utils/core_helpers.py | 36 ++++++++------- .../litellm_core_utils/streaming_handler.py | 44 ++----------------- .../test_streaming_handler.py | 3 -- 3 files changed, 24 insertions(+), 59 deletions(-) diff --git a/litellm/litellm_core_utils/core_helpers.py b/litellm/litellm_core_utils/core_helpers.py index 95640bc5e94..a75b1e28d87 100644 --- a/litellm/litellm_core_utils/core_helpers.py +++ b/litellm/litellm_core_utils/core_helpers.py @@ -263,34 +263,40 @@ def process_response_headers(response_headers: Union[httpx.Headers, dict]) -> di return additional_headers -def preserve_upstream_non_openai_attributes( - model_response: "ModelResponseStream", original_chunk: "ModelResponseStream" -): +def safe_model_dump(obj: "ModelResponseStream") -> dict: """ - Preserve non-OpenAI attributes from the original chunk. + Safely call model_dump(), falling back to __dict__ if the Pydantic + MockValSer bug (pydantic issue #7713) is encountered. """ - # Access model_fields on the class, not the instance, to avoid Pydantic 2.11+ deprecation warnings - expected_keys = set(type(model_response).model_fields.keys()).union({"usage"}) try: - obj_dict = original_chunk.model_dump() + return obj.model_dump() except TypeError as e: if "MockValSer" not in str(e): raise - # Fallback for Pydantic MockValSer bug (pydantic issue #7713) verbose_logger.warning( - "Pydantic MockValSer bug detected (pydantic issue #7713); falling back to __dict__ extraction. " + "Pydantic MockValSer bug detected (pydantic issue #7713); " + "falling back to __dict__ extraction. " "Upgrading pydantic may resolve this. Error: %s", e ) - # Merge __dict__ with __pydantic_extra__ to preserve dynamically-added provider fields - # Filter out underscore-prefixed private attributes to match model_dump() behavior - obj_dict = { + return { k: v for k, v in { - **dict(original_chunk.__dict__), - **(getattr(original_chunk, '__pydantic_extra__', None) or {}), + **dict(obj.__dict__), + **(getattr(obj, "__pydantic_extra__", None) or {}), }.items() - if not k.startswith('_') + if not k.startswith("_") } + + +def preserve_upstream_non_openai_attributes( + model_response: "ModelResponseStream", original_chunk: "ModelResponseStream" +): + """ + Preserve non-OpenAI attributes from the original chunk. + """ + # Access model_fields on the class, not the instance, to avoid Pydantic 2.11+ deprecation warnings + expected_keys = set(type(model_response).model_fields.keys()).union({"usage"}) + obj_dict = safe_model_dump(original_chunk) for key, value in obj_dict.items(): if key not in expected_keys: setattr(model_response, key, value) diff --git a/litellm/litellm_core_utils/streaming_handler.py b/litellm/litellm_core_utils/streaming_handler.py index 83659916134..16f1b7f6ef3 100644 --- a/litellm/litellm_core_utils/streaming_handler.py +++ b/litellm/litellm_core_utils/streaming_handler.py @@ -46,7 +46,7 @@ ) from ..exceptions import OpenAIError -from .core_helpers import map_finish_reason, process_response_headers +from .core_helpers import map_finish_reason, process_response_headers, safe_model_dump from .exception_mapping_utils import exception_type from .llm_response_utils.get_api_base import get_api_base from .rules import Rules @@ -1875,26 +1875,7 @@ def __next__(self) -> "ModelResponseStream": # noqa: PLR0915 response, "usage" ): # remove usage from chunk, only send on final chunk # Convert the object to a dictionary - try: - obj_dict = response.model_dump() - except TypeError as e: - if "MockValSer" not in str(e): - raise - # Fallback for Pydantic MockValSer bug (pydantic issue #7713) - verbose_logger.warning( - "Pydantic MockValSer bug detected (pydantic issue #7713); falling back to __dict__ extraction. " - "Upgrading pydantic may resolve this. Error: %s", e - ) - # Merge __dict__ with __pydantic_extra__ to preserve dynamically-added provider fields - # Filter out underscore-prefixed private attributes to match model_dump() behavior - obj_dict = { - k: v - for k, v in { - **dict(response.__dict__), - **(getattr(response, '__pydantic_extra__', None) or {}), - }.items() - if not k.startswith('_') - } + obj_dict = safe_model_dump(response) # Remove an attribute (e.g., 'attr2') if "usage" in obj_dict: @@ -2082,26 +2063,7 @@ async def __anext__(self) -> "ModelResponseStream": # noqa: PLR0915 # Strip usage from the outgoing chunk so it's not sent twice # (once in the chunk, once in _hidden_params). - try: - obj_dict = processed_chunk.model_dump() - except TypeError as e: - if "MockValSer" not in str(e): - raise - # Fallback for Pydantic MockValSer bug (pydantic issue #7713) - verbose_logger.warning( - "Pydantic MockValSer bug detected (pydantic issue #7713); falling back to __dict__ extraction. " - "Upgrading pydantic may resolve this. Error: %s", e - ) - # Merge __dict__ with __pydantic_extra__ to preserve dynamically-added provider fields - # Filter out underscore-prefixed private attributes to match model_dump() behavior - obj_dict = { - k: v - for k, v in { - **dict(processed_chunk.__dict__), - **(getattr(processed_chunk, '__pydantic_extra__', None) or {}), - }.items() - if not k.startswith('_') - } + obj_dict = safe_model_dump(processed_chunk) if "usage" in obj_dict: del obj_dict["usage"] processed_chunk = self.model_response_creator( diff --git a/tests/test_litellm/litellm_core_utils/test_streaming_handler.py b/tests/test_litellm/litellm_core_utils/test_streaming_handler.py index a9bee0ae3ea..737961aab10 100644 --- a/tests/test_litellm/litellm_core_utils/test_streaming_handler.py +++ b/tests/test_litellm/litellm_core_utils/test_streaming_handler.py @@ -1718,9 +1718,6 @@ def mock_model_dump(self, *args, **kwargs): # Use class-level patching to avoid the mock appearing in __dict__ or __pydantic_extra__ with patch.object(type(chunk_with_usage), 'model_dump', mock_model_dump): - # The code should gracefully fall back to __dict__ and not crash - initialized_custom_stream_wrapper.chunks.append(chunk_with_usage) - # Use a DIFFERENT object as the target model_response fresh_model_response = ModelResponseStream( id="fresh", From a3f5d035fb0ee7c36da4885a1ce4a070adcfa826 Mon Sep 17 00:00:00 2001 From: Audrey Kadjar Date: Mon, 23 Mar 2026 09:36:01 +0100 Subject: [PATCH 09/13] Run Black formatter to fix linting - Format core_helpers.py (split long logging line) - Format audit_logs.py (from upstream merge) - Fixes CI linting errors --- litellm/litellm_core_utils/core_helpers.py | 3 ++- litellm/proxy/management_helpers/audit_logs.py | 14 +++++++++----- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/litellm/litellm_core_utils/core_helpers.py b/litellm/litellm_core_utils/core_helpers.py index a75b1e28d87..c33ab2c7177 100644 --- a/litellm/litellm_core_utils/core_helpers.py +++ b/litellm/litellm_core_utils/core_helpers.py @@ -276,7 +276,8 @@ def safe_model_dump(obj: "ModelResponseStream") -> dict: verbose_logger.warning( "Pydantic MockValSer bug detected (pydantic issue #7713); " "falling back to __dict__ extraction. " - "Upgrading pydantic may resolve this. Error: %s", e + "Upgrading pydantic may resolve this. Error: %s", + e, ) return { k: v diff --git a/litellm/proxy/management_helpers/audit_logs.py b/litellm/proxy/management_helpers/audit_logs.py index b9020222f1f..7599e11bdef 100644 --- a/litellm/proxy/management_helpers/audit_logs.py +++ b/litellm/proxy/management_helpers/audit_logs.py @@ -51,7 +51,11 @@ def _build_audit_log_payload( if request_data.updated_at is not None: updated_at = request_data.updated_at.isoformat() - table_name_str: str = request_data.table_name.value if isinstance(request_data.table_name, LitellmTableNames) else str(request_data.table_name) + table_name_str: str = ( + request_data.table_name.value + if isinstance(request_data.table_name, LitellmTableNames) + else str(request_data.table_name) + ) return StandardAuditLogPayload( id=request_data.id, @@ -89,7 +93,9 @@ async def _dispatch_audit_log_to_callbacks( for callback in litellm.audit_log_callbacks: try: - resolved: Optional[CustomLogger] = callback if isinstance(callback, CustomLogger) else None + resolved: Optional[CustomLogger] = ( + callback if isinstance(callback, CustomLogger) else None + ) if isinstance(callback, str): resolved = _resolve_audit_log_callback(callback) if resolved is None: @@ -138,9 +144,7 @@ async def create_object_audit_log( return _changed_by = ( - litellm_changed_by - or user_api_key_dict.user_id - or litellm_proxy_admin_name + litellm_changed_by or user_api_key_dict.user_id or litellm_proxy_admin_name ) await create_audit_log_for_update( From 17e4e6a448d18e31d832333f64447a0bcc60c34c Mon Sep 17 00:00:00 2001 From: Audrey Kadjar Date: Mon, 23 Mar 2026 09:57:52 +0100 Subject: [PATCH 10/13] Address P1 PR review comments: Use subclass mocking and document type divergence - Replace instance-level model_dump mocking with dedicated subclass approach to isolate the bug without polluting shared ModelResponseStream state - Add detailed docstring to safe_model_dump() explaining the type divergence between normal path (recursive dict conversion) and fallback path (nested Pydantic instances) - Verify fresh_model_response.model_dump() still works correctly during test - Apply Black formatting to all modified files Addresses PR review comments from greptile-apps bot: - P1: Class-level patch silently breaks fresh_model_response.model_dump() - P1: Document that the fallback returns nested Pydantic instances Test verification: - All 51 tests passing - Black formatting satisfied --- litellm/litellm_core_utils/core_helpers.py | 11 ++ .../litellm_core_utils/streaming_handler.py | 12 +- .../test_streaming_handler.py | 109 ++++++++++++------ 3 files changed, 92 insertions(+), 40 deletions(-) diff --git a/litellm/litellm_core_utils/core_helpers.py b/litellm/litellm_core_utils/core_helpers.py index c33ab2c7177..8d3a7d2c4af 100644 --- a/litellm/litellm_core_utils/core_helpers.py +++ b/litellm/litellm_core_utils/core_helpers.py @@ -267,6 +267,17 @@ def safe_model_dump(obj: "ModelResponseStream") -> dict: """ Safely call model_dump(), falling back to __dict__ if the Pydantic MockValSer bug (pydantic issue #7713) is encountered. + + Returns: + A dictionary representation of the model. Note that the normal path + (model_dump()) recursively converts nested Pydantic models to plain + dicts/primitives, while the fallback path (__dict__) returns nested + Pydantic model instances (e.g., StreamingChoices, Delta) as-is. + + Pydantic v2 initialization handles both representations when passed to + ModelResponseStream(**dict), so this divergence is typically transparent. + However, code that type-checks or iterates dict values may observe + different types on the fallback path. """ try: return obj.model_dump() diff --git a/litellm/litellm_core_utils/streaming_handler.py b/litellm/litellm_core_utils/streaming_handler.py index 16f1b7f6ef3..561b7e51e79 100644 --- a/litellm/litellm_core_utils/streaming_handler.py +++ b/litellm/litellm_core_utils/streaming_handler.py @@ -127,9 +127,9 @@ def __init__( self.system_fingerprint: Optional[str] = None self.received_finish_reason: Optional[str] = None - self.intermittent_finish_reason: Optional[ - str - ] = None # finish reasons that show up mid-stream + self.intermittent_finish_reason: Optional[str] = ( + None # finish reasons that show up mid-stream + ) self.special_tokens = [ "<|assistant|>", "<|system|>", @@ -1516,9 +1516,9 @@ def chunk_creator(self, chunk: Any): # type: ignore # noqa: PLR0915 t.function.arguments = "" _json_delta = delta.model_dump() if "role" not in _json_delta or _json_delta["role"] is None: - _json_delta[ - "role" - ] = "assistant" # mistral's api returns role as None + _json_delta["role"] = ( + "assistant" # mistral's api returns role as None + ) if "tool_calls" in _json_delta and isinstance( _json_delta["tool_calls"], list ): diff --git a/tests/test_litellm/litellm_core_utils/test_streaming_handler.py b/tests/test_litellm/litellm_core_utils/test_streaming_handler.py index 737961aab10..97643bfd082 100644 --- a/tests/test_litellm/litellm_core_utils/test_streaming_handler.py +++ b/tests/test_litellm/litellm_core_utils/test_streaming_handler.py @@ -770,7 +770,9 @@ async def test_vertex_streaming_bad_request_not_midstream(logging_obj: Logging): from litellm.llms.vertex_ai.common_utils import VertexAIError async def _raise_bad_request(**kwargs): - raise VertexAIError(status_code=400, message="invalid maxOutputTokens", headers=None) + raise VertexAIError( + status_code=400, message="invalid maxOutputTokens", headers=None + ) response = CustomStreamWrapper( completion_stream=None, @@ -788,7 +790,9 @@ async def _raise_bad_request(**kwargs): @pytest.mark.asyncio -async def test_vertex_streaming_rate_limit_triggers_midstream_fallback(logging_obj: Logging): +async def test_vertex_streaming_rate_limit_triggers_midstream_fallback( + logging_obj: Logging, +): """Ensure Vertex 429 rate-limit errors raise MidStreamFallbackError, not RateLimitError. Regression test for https://github.com/BerriAI/litellm/issues/20870 @@ -797,7 +801,9 @@ async def test_vertex_streaming_rate_limit_triggers_midstream_fallback(logging_o from litellm.llms.vertex_ai.common_utils import VertexAIError async def _raise_rate_limit(**kwargs): - raise VertexAIError(status_code=429, message="Resource exhausted.", headers=None) + raise VertexAIError( + status_code=429, message="Resource exhausted.", headers=None + ) response = CustomStreamWrapper( completion_stream=None, @@ -825,7 +831,9 @@ def test_sync_streaming_rate_limit_triggers_midstream_fallback(logging_obj: Logg from litellm.llms.vertex_ai.common_utils import VertexAIError def _raise_rate_limit(**kwargs): - raise VertexAIError(status_code=429, message="Resource exhausted.", headers=None) + raise VertexAIError( + status_code=429, message="Resource exhausted.", headers=None + ) response = CustomStreamWrapper( completion_stream=None, @@ -850,7 +858,9 @@ def test_sync_streaming_bad_request_not_midstream(logging_obj: Logging): from litellm.llms.vertex_ai.common_utils import VertexAIError def _raise_bad_request(**kwargs): - raise VertexAIError(status_code=400, message="invalid maxOutputTokens", headers=None) + raise VertexAIError( + status_code=400, message="invalid maxOutputTokens", headers=None + ) response = CustomStreamWrapper( completion_stream=None, @@ -1363,6 +1373,7 @@ def _build_chunks(pattern: list[str], N: int) -> list[ModelResponseStream]: chunks.append(_make_chunk(p)) return chunks + _REPETITION_TEST_CASES = [ # Basic cases pytest.param( @@ -1419,7 +1430,14 @@ def _build_chunks(pattern: list[str], N: int) -> list[ModelResponseStream]: id="last_chunk_different_no_raise", ), pytest.param( - ["same"] * (litellm.REPEATED_STREAMING_CHUNK_LIMIT // 2 + 1) + ["different_mid"] + ["same"] * (litellm.REPEATED_STREAMING_CHUNK_LIMIT - litellm.REPEATED_STREAMING_CHUNK_LIMIT // 2 + 1), + ["same"] * (litellm.REPEATED_STREAMING_CHUNK_LIMIT // 2 + 1) + + ["different_mid"] + + ["same"] + * ( + litellm.REPEATED_STREAMING_CHUNK_LIMIT + - litellm.REPEATED_STREAMING_CHUNK_LIMIT // 2 + + 1 + ), False, id="middle_chunk_different_no_raise", ), @@ -1429,7 +1447,9 @@ def _build_chunks(pattern: list[str], N: int) -> list[ModelResponseStream]: id="last_two_different_no_raise", ), pytest.param( - ["diff"] * litellm.REPEATED_STREAMING_CHUNK_LIMIT + ["same"] * litellm.REPEATED_STREAMING_CHUNK_LIMIT + ["diff"], + ["diff"] * litellm.REPEATED_STREAMING_CHUNK_LIMIT + + ["same"] * litellm.REPEATED_STREAMING_CHUNK_LIMIT + + ["diff"], True, id="in_between_same_and_diff_raise", ), @@ -1455,6 +1475,8 @@ def test_raise_on_model_repetition( for chunk in chunks: wrapper.chunks.append(chunk) wrapper.raise_on_model_repetition() + + def test_usage_chunk_after_finish_reason_updates_hidden_params(logging_obj): """ Test that provider-reported usage from a post-finish_reason chunk @@ -1536,12 +1558,13 @@ def test_usage_chunk_after_finish_reason_updates_hidden_params(logging_obj): last_chunk = collected[-1] hidden_usage = last_chunk._hidden_params.get("usage") assert hidden_usage is not None, "Expected usage in _hidden_params" - assert hidden_usage.prompt_tokens == 20, ( - f"Expected prompt_tokens=20 from provider, got {hidden_usage.prompt_tokens}" - ) - assert hidden_usage.completion_tokens == 135, ( - f"Expected completion_tokens=135 from provider, got {hidden_usage.completion_tokens}" - ) + assert ( + hidden_usage.prompt_tokens == 20 + ), f"Expected prompt_tokens=20 from provider, got {hidden_usage.prompt_tokens}" + assert ( + hidden_usage.completion_tokens == 135 + ), f"Expected completion_tokens=135 from provider, got {hidden_usage.completion_tokens}" + @pytest.mark.asyncio async def test_custom_stream_wrapper_aclose(): @@ -1615,9 +1638,9 @@ def test_content_not_dropped_when_finish_reason_already_set( result = initialized_custom_stream_wrapper.chunk_creator(chunk=content_chunk) - assert result is not None, ( - "chunk_creator() returned None — content was dropped (issue #22098)" - ) + assert ( + result is not None + ), "chunk_creator() returned None — content was dropped (issue #22098)" assert result.choices[0].delta.content == "world!" @@ -1669,14 +1692,14 @@ def test_tool_use_not_dropped_when_finish_reason_already_set( result = initialized_custom_stream_wrapper.chunk_creator(chunk=tool_chunk) - assert result is not None, ( - "chunk_creator() returned None — tool_use data was dropped" - ) + assert ( + result is not None + ), "chunk_creator() returned None — tool_use data was dropped" tool_calls = result.choices[0].delta.tool_calls - assert tool_calls is not None and len(tool_calls) > 0, ( - "tool_calls should contain at least one tool call" - ) + assert ( + tool_calls is not None and len(tool_calls) > 0 + ), "tool_calls should contain at least one tool call" assert tool_calls[0].id == "call_1" assert tool_calls[0].function.name == "get_weather" @@ -1691,10 +1714,17 @@ def test_model_dump_fallback_handles_pydantic_serializer_bug( SchemaSerializer in certain scenarios. The fix catches TypeError and falls back to __dict__ extraction. """ - from unittest.mock import patch # Create a chunk with usage that will be stripped - chunk_with_usage = ModelResponseStream( + # Create a subclass with corrupted model_dump to isolate the bug to a single instance + class CorruptedModelResponseStream(ModelResponseStream): + def model_dump(self, **kwargs): + raise TypeError( + "'MockValSer' object cannot be converted to 'SchemaSerializer'" + ) + + # Create an instance of the corrupted subclass + corrupted_chunk = CorruptedModelResponseStream( id="test-chunk", created=1742056047, model="sap-ai-core/test-model", @@ -1710,14 +1740,9 @@ def test_model_dump_fallback_handles_pydantic_serializer_bug( ) # Add a provider-specific field to test __pydantic_extra__ preservation - chunk_with_usage.sap_extra_field = "sap-value" - - # Mock model_dump at class level to avoid polluting instance attributes - def mock_model_dump(self, *args, **kwargs): - raise TypeError("'MockValSer' object cannot be converted to 'SchemaSerializer'") + corrupted_chunk.sap_extra_field = "sap-value" - # Use class-level patching to avoid the mock appearing in __dict__ or __pydantic_extra__ - with patch.object(type(chunk_with_usage), 'model_dump', mock_model_dump): + try: # Use a DIFFERENT object as the target model_response fresh_model_response = ModelResponseStream( id="fresh", @@ -1733,11 +1758,17 @@ def mock_model_dump(self, *args, **kwargs): ], ) + # Verify fresh_model_response.model_dump() still works normally + # (proves instance-level mock doesn't affect other instances) + assert callable(fresh_model_response.model_dump) + test_dump = fresh_model_response.model_dump() + assert isinstance(test_dump, dict) + # Process the chunk through return_processed_chunk_logic which calls model_dump result = initialized_custom_stream_wrapper.return_processed_chunk_logic( completion_obj={"content": "test content"}, - response_obj={"original_chunk": chunk_with_usage}, # source of extra attrs - model_response=fresh_model_response, # target (different object) + response_obj={"original_chunk": corrupted_chunk}, # source of extra attrs + model_response=fresh_model_response, # target (different object) ) # Should not raise TypeError and should successfully process the chunk @@ -1747,13 +1778,23 @@ def mock_model_dump(self, *args, **kwargs): # Now this assertion is meaningful: it verifies the attribute was actually COPIED assert getattr(result, "sap_extra_field", None) == "sap-value" + finally: + # No cleanup needed - subclass approach doesn't mutate shared state + pass + # Verify that result can still be serialized (model_dump method is not corrupted) - # This call is outside the patch context so it should work normally result_dict = result.model_dump() assert isinstance(result_dict, dict) # Extra provider field should survive into the serialized output assert result_dict.get("sap_extra_field") == "sap-value" + # Verify nested structures are accessible (validates fallback behavior) + # The fallback returns nested Pydantic instances instead of plain dicts, + # but they should still be accessible and functional + assert result.choices is not None + assert len(result.choices) > 0 + assert result.choices[0].delta.content == "test content" + @pytest.mark.asyncio async def test_custom_stream_wrapper_anext_does_not_block_event_loop_for_sync_iterators( From 580a0b3793c7ca1ab5d3d5b7c5f917b84a09b4f6 Mon Sep 17 00:00:00 2001 From: Audrey Kadjar Date: Mon, 23 Mar 2026 10:20:13 +0100 Subject: [PATCH 11/13] Fix Black formatting: Split long import into multiple lines --- litellm/litellm_core_utils/streaming_handler.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/litellm/litellm_core_utils/streaming_handler.py b/litellm/litellm_core_utils/streaming_handler.py index 561b7e51e79..8ff5b44898d 100644 --- a/litellm/litellm_core_utils/streaming_handler.py +++ b/litellm/litellm_core_utils/streaming_handler.py @@ -46,7 +46,11 @@ ) from ..exceptions import OpenAIError -from .core_helpers import map_finish_reason, process_response_headers, safe_model_dump +from .core_helpers import ( + map_finish_reason, + process_response_headers, + safe_model_dump, +) from .exception_mapping_utils import exception_type from .llm_response_utils.get_api_base import get_api_base from .rules import Rules From 4c51af0f14ae77495529f0f0c733ceb63eb4ba7a Mon Sep 17 00:00:00 2001 From: Audrey Kadjar Date: Mon, 23 Mar 2026 10:24:27 +0100 Subject: [PATCH 12/13] Address P2 comment: Remove unnecessary try/finally block Since the subclass approach doesn't require any cleanup, the try/finally block was a no-op. Removing it makes the test structure cleaner and avoids potential UnboundLocalError if an exception occurs before result is assigned. --- .../test_streaming_handler.py | 65 +++++++++---------- 1 file changed, 30 insertions(+), 35 deletions(-) diff --git a/tests/test_litellm/litellm_core_utils/test_streaming_handler.py b/tests/test_litellm/litellm_core_utils/test_streaming_handler.py index 97643bfd082..2f9cd7ae43d 100644 --- a/tests/test_litellm/litellm_core_utils/test_streaming_handler.py +++ b/tests/test_litellm/litellm_core_utils/test_streaming_handler.py @@ -1742,45 +1742,40 @@ def model_dump(self, **kwargs): # Add a provider-specific field to test __pydantic_extra__ preservation corrupted_chunk.sap_extra_field = "sap-value" - try: - # Use a DIFFERENT object as the target model_response - fresh_model_response = ModelResponseStream( - id="fresh", - created=1742056047, - model="sap-ai-core/test-model", - object="chat.completion.chunk", - choices=[ - StreamingChoices( - finish_reason=None, - index=0, - delta=Delta(content="test content", role="assistant"), - ) - ], - ) + # Use a DIFFERENT object as the target model_response + fresh_model_response = ModelResponseStream( + id="fresh", + created=1742056047, + model="sap-ai-core/test-model", + object="chat.completion.chunk", + choices=[ + StreamingChoices( + finish_reason=None, + index=0, + delta=Delta(content="test content", role="assistant"), + ) + ], + ) - # Verify fresh_model_response.model_dump() still works normally - # (proves instance-level mock doesn't affect other instances) - assert callable(fresh_model_response.model_dump) - test_dump = fresh_model_response.model_dump() - assert isinstance(test_dump, dict) - - # Process the chunk through return_processed_chunk_logic which calls model_dump - result = initialized_custom_stream_wrapper.return_processed_chunk_logic( - completion_obj={"content": "test content"}, - response_obj={"original_chunk": corrupted_chunk}, # source of extra attrs - model_response=fresh_model_response, # target (different object) - ) + # Verify fresh_model_response.model_dump() still works normally + # (proves subclass approach doesn't affect other instances) + assert callable(fresh_model_response.model_dump) + test_dump = fresh_model_response.model_dump() + assert isinstance(test_dump, dict) - # Should not raise TypeError and should successfully process the chunk - assert result is not None - assert result.choices[0].delta.content == "test content" + # Process the chunk through return_processed_chunk_logic which calls model_dump + result = initialized_custom_stream_wrapper.return_processed_chunk_logic( + completion_obj={"content": "test content"}, + response_obj={"original_chunk": corrupted_chunk}, # source of extra attrs + model_response=fresh_model_response, # target (different object) + ) - # Now this assertion is meaningful: it verifies the attribute was actually COPIED - assert getattr(result, "sap_extra_field", None) == "sap-value" + # Should not raise TypeError and should successfully process the chunk + assert result is not None + assert result.choices[0].delta.content == "test content" - finally: - # No cleanup needed - subclass approach doesn't mutate shared state - pass + # Now this assertion is meaningful: it verifies the attribute was actually COPIED + assert getattr(result, "sap_extra_field", None) == "sap-value" # Verify that result can still be serialized (model_dump method is not corrupted) result_dict = result.model_dump() From 417458b4e3416272c4c9b1a13de70ecd4e42355a Mon Sep 17 00:00:00 2001 From: Audrey Kadjar Date: Mon, 23 Mar 2026 10:33:01 +0100 Subject: [PATCH 13/13] Run Black formatter on all litellm files to fix CI linting Applied Black formatting to all Python files in the litellm directory that were merged from upstream and not properly formatted. --- litellm/__init__.py | 120 ++++---- litellm/_lazy_imports.py | 1 + litellm/_uuid.py | 1 - litellm/a2a_protocol/main.py | 6 +- litellm/a2a_protocol/streaming_iterator.py | 6 +- litellm/anthropic_interface/__init__.py | 1 + .../exceptions/exception_mapping_utils.py | 1 - .../exceptions/exceptions.py | 1 - litellm/caching/caching_handler.py | 10 +- litellm/caching/gcs_cache.py | 1 + .../handler.py | 4 +- .../transformation.py | 20 +- litellm/containers/endpoint_factory.py | 8 +- litellm/containers/main.py | 78 +++-- litellm/cost_calculator.py | 6 +- litellm/evals/main.py | 88 +++--- litellm/exceptions.py | 2 +- litellm/google_genai/adapters/__init__.py | 4 +- litellm/images/main.py | 39 +-- .../SlackAlerting/batching_handler.py | 6 +- .../SlackAlerting/hanging_request_check.py | 8 +- .../SlackAlerting/slack_alerting.py | 18 +- litellm/integrations/SlackAlerting/utils.py | 2 +- .../integrations/additional_logging_utils.py | 2 +- litellm/integrations/agentops/agentops.py | 1 + .../anthropic_cache_control_hook.py | 8 +- litellm/integrations/arize/arize_phoenix.py | 6 +- .../azure_storage/azure_storage.py | 12 +- litellm/integrations/braintrust_logging.py | 6 +- litellm/integrations/cloudzero/cloudzero.py | 8 +- litellm/integrations/cloudzero/transform.py | 12 +- litellm/integrations/custom_batch_logger.py | 2 +- litellm/integrations/custom_logger.py | 6 +- .../integrations/datadog/datadog_llm_obs.py | 24 +- litellm/integrations/focus/transformer.py | 1 - .../gcs_bucket/gcs_bucket_base.py | 6 +- litellm/integrations/humanloop.py | 6 +- litellm/integrations/langfuse/langfuse.py | 6 +- .../integrations/langfuse/langfuse_handler.py | 4 +- .../langfuse/langfuse_prompt_management.py | 6 +- litellm/integrations/langsmith.py | 18 +- litellm/integrations/mock_client_factory.py | 6 +- litellm/integrations/opentelemetry.py | 6 +- litellm/integrations/opik/utils.py | 2 +- litellm/integrations/posthog.py | 6 +- litellm/integrations/prometheus.py | 34 ++- litellm/integrations/s3_v2.py | 12 +- .../integrations/vantage/vantage_logger.py | 16 +- .../vector_store_pre_call_hook.py | 36 +-- .../websearch_interception/transformation.py | 9 +- litellm/integrations/weights_biases.py | 3 +- litellm/interactions/__init__.py | 12 +- .../transformation.py | 8 +- litellm/interactions/main.py | 10 +- .../litellm_core_utils/default_encoding.py | 6 +- litellm/litellm_core_utils/litellm_logging.py | 284 +++++++++--------- .../convert_dict_to_response.py | 12 +- .../litellm_core_utils/model_param_helper.py | 6 +- .../prompt_templates/factory.py | 52 ++-- .../litellm_core_utils/realtime_streaming.py | 12 +- litellm/litellm_core_utils/redact_messages.py | 6 +- litellm/litellm_core_utils/safe_json_loads.py | 1 + .../specialty_caches/dynamic_logging_cache.py | 5 +- .../streaming_chunk_builder_utils.py | 18 +- litellm/llms/a2a/__init__.py | 1 + litellm/llms/a2a/chat/__init__.py | 1 + litellm/llms/a2a/chat/streaming_iterator.py | 1 + litellm/llms/a2a/chat/transformation.py | 1 + litellm/llms/a2a/common_utils.py | 1 + .../llms/amazon_nova/chat/transformation.py | 1 + .../llms/anthropic/batches/transformation.py | 12 +- .../chat/guardrail_translation/handler.py | 6 +- litellm/llms/anthropic/chat/handler.py | 40 ++- litellm/llms/anthropic/chat/transformation.py | 38 ++- litellm/llms/anthropic/common_utils.py | 6 +- .../anthropic/completion/transformation.py | 6 +- .../adapters/streaming_iterator.py | 12 +- .../adapters/transformation.py | 32 +- .../messages/transformation.py | 6 +- .../responses_adapters/streaming_iterator.py | 6 +- .../responses_adapters/transformation.py | 10 +- .../llms/anthropic/files/transformation.py | 6 +- .../azure/chat/o_series_transformation.py | 8 +- litellm/llms/azure/fine_tuning/handler.py | 9 +- litellm/llms/azure_ai/anthropic/__init__.py | 1 + litellm/llms/azure_ai/anthropic/handler.py | 1 + .../anthropic/messages_transformation.py | 1 + .../llms/azure_ai/anthropic/transformation.py | 1 + .../azure_ai/azure_model_router/__init__.py | 1 + .../azure_model_router/transformation.py | 1 + .../azure_ai/embed/cohere_transformation.py | 2 +- litellm/llms/azure_ai/ocr/__init__.py | 1 + .../ocr/document_intelligence/__init__.py | 1 + .../document_intelligence/transformation.py | 1 + litellm/llms/azure_ai/ocr/transformation.py | 1 + .../llms/azure_ai/rerank/transformation.py | 2 +- litellm/llms/base_llm/ocr/__init__.py | 1 + litellm/llms/base_llm/ocr/transformation.py | 1 + litellm/llms/base_llm/search/__init__.py | 1 + .../llms/base_llm/search/transformation.py | 1 + .../vector_store_files/transformation.py | 42 +-- litellm/llms/bedrock/batches/handler.py | 8 +- .../bedrock/chat/agentcore/transformation.py | 6 +- litellm/llms/bedrock/chat/converse_handler.py | 6 +- .../bedrock/chat/converse_transformation.py | 34 +-- litellm/llms/bedrock/chat/invoke_handler.py | 36 ++- .../base_invoke_transformation.py | 6 +- litellm/llms/bedrock/common_utils.py | 8 +- .../embed/amazon_titan_g1_transformation.py | 2 +- .../amazon_titan_multimodal_transformation.py | 6 +- .../bedrock/embed/cohere_transformation.py | 2 +- .../amazon_nova_canvas_transformation.py | 6 +- .../anthropic_claude3_transformation.py | 6 +- litellm/llms/bedrock/realtime/handler.py | 4 +- .../llms/bedrock/realtime/transformation.py | 6 +- .../bedrock_mantle/chat/transformation.py | 1 - litellm/llms/chatgpt/chat/streaming_utils.py | 6 +- litellm/llms/chatgpt/common_utils.py | 1 + .../llms/chatgpt/responses/transformation.py | 6 +- litellm/llms/cohere/embed/handler.py | 2 +- .../llms/custom_httpx/async_client_cleanup.py | 1 + litellm/llms/custom_httpx/http_handler.py | 6 +- litellm/llms/custom_httpx/llm_http_handler.py | 27 +- litellm/llms/custom_httpx/mock_transport.py | 1 - litellm/llms/dashscope/chat/transformation.py | 6 +- litellm/llms/dashscope/cost_calculator.py | 2 +- .../llms/databricks/chat/transformation.py | 6 +- litellm/llms/databricks/common_utils.py | 6 +- .../llms/databricks/embed/transformation.py | 6 +- .../llms/dataforseo/search/transformation.py | 1 + litellm/llms/datarobot/chat/transformation.py | 2 +- litellm/llms/deepinfra/chat/transformation.py | 6 +- .../llms/deepinfra/rerank/transformation.py | 2 +- litellm/llms/deepseek/chat/transformation.py | 6 +- litellm/llms/deepseek/cost_calculator.py | 2 +- .../llms/deprecated_providers/aleph_alpha.py | 6 +- .../chat/transformation.py | 6 +- litellm/llms/duckduckgo/search/__init__.py | 1 + .../llms/duckduckgo/search/transformation.py | 1 + .../text_to_speech/transformation.py | 1 - litellm/llms/exa_ai/search/__init__.py | 1 + litellm/llms/exa_ai/search/transformation.py | 1 + litellm/llms/firecrawl/__init__.py | 1 + litellm/llms/firecrawl/search/__init__.py | 1 + .../llms/firecrawl/search/transformation.py | 1 + .../llms/fireworks_ai/chat/transformation.py | 10 +- litellm/llms/gemini/files/transformation.py | 9 +- .../llms/gemini/image_edit/transformation.py | 6 +- .../gemini/image_generation/transformation.py | 10 +- .../llms/gemini/realtime/transformation.py | 22 +- litellm/llms/github_copilot/common_utils.py | 1 + .../embedding/transformation.py | 1 + .../responses/transformation.py | 1 + litellm/llms/google_pse/search/__init__.py | 1 + .../llms/google_pse/search/transformation.py | 9 +- litellm/llms/groq/chat/transformation.py | 15 +- litellm/llms/heroku/chat/transformation.py | 7 +- .../llms/hosted_vllm/chat/transformation.py | 17 +- .../huggingface/embedding/transformation.py | 24 +- .../llms/infinity/rerank/transformation.py | 2 +- litellm/llms/jina_ai/rerank/transformation.py | 2 +- litellm/llms/lemonade/chat/transformation.py | 1 + litellm/llms/lemonade/cost_calculator.py | 1 + litellm/llms/linkup/__init__.py | 1 + litellm/llms/linkup/search/__init__.py | 1 + litellm/llms/linkup/search/transformation.py | 1 + .../llms/lm_studio/embed/transformation.py | 2 +- litellm/llms/minimax/chat/transformation.py | 1 + .../llms/minimax/messages/transformation.py | 1 + litellm/llms/mistral/chat/transformation.py | 6 +- litellm/llms/mistral/ocr/transformation.py | 1 + litellm/llms/moonshot/chat/transformation.py | 6 +- litellm/llms/novita/chat/transformation.py | 2 +- .../llms/nvidia_nim/chat/transformation.py | 5 +- litellm/llms/nvidia_nim/embed.py | 2 +- litellm/llms/oci/chat/transformation.py | 4 +- .../llms/ollama/completion/transformation.py | 6 +- .../llms/openai/chat/gpt_transformation.py | 8 +- .../chat/guardrail_translation/handler.py | 6 +- .../openai/chat/o_series_transformation.py | 18 +- litellm/llms/openai/common_utils.py | 8 +- .../llms/openai/completion/transformation.py | 6 +- litellm/llms/openai/fine_tuning/handler.py | 9 +- litellm/llms/openai/openai.py | 6 +- .../transcriptions/whisper_transformation.py | 6 +- litellm/llms/openai_like/dynamic_config.py | 6 +- .../llms/openrouter/chat/transformation.py | 6 +- .../openrouter/embedding/transformation.py | 1 + .../openrouter/image_edit/transformation.py | 6 +- .../image_generation/transformation.py | 1 - litellm/llms/ovhcloud/chat/transformation.py | 1 + .../llms/ovhcloud/embedding/transformation.py | 1 + litellm/llms/parallel_ai/search/__init__.py | 1 + .../llms/parallel_ai/search/transformation.py | 1 + .../llms/perplexity/search/transformation.py | 1 + .../llms/petals/completion/transformation.py | 6 +- litellm/llms/predibase/chat/transformation.py | 12 +- .../llms/runwayml/text_to_speech/__init__.py | 1 + .../runwayml/text_to_speech/transformation.py | 1 + .../sagemaker/completion/transformation.py | 8 +- .../sagemaker/embedding/transformation.py | 2 +- litellm/llms/sambanova/chat.py | 6 +- .../sambanova/embedding/transformation.py | 1 + litellm/llms/sap/chat/transformation.py | 1 + litellm/llms/searchapi/search/__init__.py | 1 + .../llms/searchapi/search/transformation.py | 1 + litellm/llms/searxng/__init__.py | 1 + litellm/llms/searxng/search/__init__.py | 1 + litellm/llms/searxng/search/transformation.py | 1 + litellm/llms/serper/search/__init__.py | 1 + litellm/llms/serper/search/transformation.py | 1 + litellm/llms/snowflake/chat/transformation.py | 1 - .../image_generation/transformation.py | 6 +- litellm/llms/tavily/search/__init__.py | 1 + litellm/llms/tavily/search/transformation.py | 5 +- litellm/llms/together_ai/chat.py | 2 +- litellm/llms/together_ai/embed.py | 2 +- .../llms/together_ai/rerank/transformation.py | 2 +- .../vercel_ai_gateway/chat/transformation.py | 6 +- litellm/llms/vertex_ai/batches/handler.py | 6 +- .../context_caching/transformation.py | 4 +- .../vertex_ai_context_caching.py | 6 +- litellm/llms/vertex_ai/fine_tuning/handler.py | 12 +- .../llms/vertex_ai/gemini/transformation.py | 7 +- .../vertex_and_google_ai_studio_gemini.py | 76 ++--- .../batch_embed_content_transformation.py | 2 +- .../vertex_gemini_transformation.py | 10 +- litellm/llms/vertex_ai/ocr/__init__.py | 1 + .../vertex_ai/ocr/deepseek_transformation.py | 9 +- litellm/llms/vertex_ai/ocr/transformation.py | 1 + .../text_to_speech/text_to_speech_handler.py | 2 +- .../count_tokens/handler.py | 1 + .../vertex_embeddings/embedding_handler.py | 12 +- .../llms/vllm/completion/transformation.py | 2 +- .../embedding/transformation_contextual.py | 5 +- litellm/main.py | 48 +-- litellm/ocr/__init__.py | 1 + litellm/ocr/main.py | 11 +- litellm/proxy/_experimental/mcp_server/db.py | 24 +- .../mcp_server/discoverable_endpoints.py | 20 +- .../mcp_server/mcp_server_manager.py | 12 +- .../mcp_server/semantic_tool_filter.py | 1 + .../proxy/_experimental/mcp_server/server.py | 33 +- .../proxy/_experimental/mcp_server/utils.py | 1 + litellm/proxy/_types.py | 96 +++--- .../proxy/agent_endpoints/agent_registry.py | 42 +-- litellm/proxy/agent_endpoints/endpoints.py | 19 +- .../agent_endpoints/model_list_helpers.py | 1 + litellm/proxy/auth/auth_checks.py | 7 +- litellm/proxy/auth/user_api_key_auth.py | 12 +- litellm/proxy/client/cli/commands/auth.py | 2 +- litellm/proxy/client/cli/commands/models.py | 8 +- litellm/proxy/client/cli/main.py | 16 +- litellm/proxy/common_request_processing.py | 24 +- .../proxy/common_utils/cache_coordinator.py | 6 +- litellm/proxy/common_utils/callback_utils.py | 18 +- .../proxy/common_utils/custom_openapi_spec.py | 6 +- litellm/proxy/common_utils/debug_utils.py | 44 +-- .../proxy/common_utils/http_parsing_utils.py | 10 +- .../common_utils/openai_endpoint_utils.py | 2 +- .../pass_through_endpoints.py | 2 +- litellm/proxy/db/create_views.py | 6 +- litellm/proxy/db/db_spend_update_writer.py | 14 +- .../db_transaction_queue/base_update_queue.py | 1 + .../daily_spend_update_queue.py | 12 +- .../redis_update_buffer.py | 6 +- .../spend_update_queue.py | 6 +- .../proxy/fine_tuning_endpoints/endpoints.py | 12 +- .../proxy/guardrails/guardrail_endpoints.py | 24 +- .../guardrail_hooks/akto/__init__.py | 1 - .../guardrails/guardrail_hooks/akto/akto.py | 6 +- .../guardrail_hooks/bedrock_guardrails.py | 12 +- .../block_code_execution.py | 6 +- .../generic_guardrail_api.py | 6 +- .../guardrails/guardrail_hooks/lakera_ai.py | 6 +- .../guardrail_hooks/lakera_ai_v2.py | 6 +- .../competitor_intent/airline.py | 6 +- .../litellm_content_filter/content_filter.py | 12 +- .../model_armor/model_armor.py | 4 +- .../panw_prisma_airs/panw_prisma_airs.py | 32 +- .../guardrails/guardrail_hooks/presidio.py | 12 +- .../proxy/guardrails/tool_name_extraction.py | 12 +- litellm/proxy/hooks/batch_rate_limiter.py | 4 +- litellm/proxy/hooks/dynamic_rate_limiter.py | 26 +- .../proxy/hooks/dynamic_rate_limiter_v3.py | 12 +- .../proxy/hooks/key_management_event_hooks.py | 8 +- .../proxy/hooks/litellm_skills/__init__.py | 2 +- litellm/proxy/hooks/litellm_skills/main.py | 8 +- .../hooks/mcp_semantic_filter/__init__.py | 1 + .../proxy/hooks/mcp_semantic_filter/hook.py | 1 + .../proxy/hooks/parallel_request_limiter.py | 32 +- .../hooks/parallel_request_limiter_v3.py | 6 +- .../proxy/hooks/proxy_track_cost_callback.py | 10 +- litellm/proxy/litellm_pre_call_utils.py | 28 +- .../budget_management_endpoints.py | 4 +- .../callback_management_endpoints.py | 1 + .../common_daily_activity.py | 120 ++++---- .../cost_tracking_settings.py | 16 +- .../customer_endpoints.py | 10 +- .../fallback_management_endpoints.py | 1 + .../internal_user_endpoints.py | 26 +- .../key_management_endpoints.py | 38 +-- .../mcp_management_endpoints.py | 18 +- .../organization_endpoints.py | 46 +-- .../management_endpoints/project_endpoints.py | 14 +- .../management_endpoints/scim/scim_v2.py | 14 +- .../sso/custom_microsoft_sso.py | 2 +- .../management_endpoints/team_endpoints.py | 56 ++-- litellm/proxy/management_endpoints/ui_sso.py | 36 +-- .../user_agent_analytics_endpoints.py | 8 +- .../object_permission_utils.py | 8 +- .../middleware/prometheus_auth_middleware.py | 1 + .../openai_files_endpoints/files_endpoints.py | 8 +- .../anthropic_passthrough_logging_handler.py | 6 +- .../cursor_passthrough_logging_handler.py | 1 - .../openai_passthrough_logging_handler.py | 24 +- .../pass_through_endpoints.py | 26 +- .../passthrough_endpoint_router.py | 14 +- .../pass_through_endpoints/success_handler.py | 18 +- .../policy_engine/attachment_registry.py | 14 +- litellm/proxy/policy_engine/init_policies.py | 6 +- litellm/proxy/prompts/prompt_registry.py | 6 +- litellm/proxy/proxy_cli.py | 11 +- litellm/proxy/proxy_server.py | 62 ++-- litellm/proxy/rag_endpoints/endpoints.py | 6 +- litellm/proxy/response_polling/__init__.py | 1 + .../response_polling/background_streaming.py | 13 +- .../proxy/response_polling/polling_handler.py | 1 + .../search_tool_management.py | 7 +- .../search_endpoints/search_tool_registry.py | 1 + .../spend_tracking/cold_storage_handler.py | 15 +- .../spend_management_endpoints.py | 6 +- .../spend_tracking/spend_tracking_utils.py | 6 +- litellm/proxy/utils.py | 32 +- .../proxy/vector_store_endpoints/endpoints.py | 8 +- .../vertex_ai_endpoints/langfuse_endpoints.py | 10 +- .../handler.py | 38 ++- .../session_handler.py | 8 +- .../transformation.py | 6 +- litellm/responses/main.py | 50 +-- .../mcp/litellm_proxy_mcp_handler.py | 12 +- .../responses/mcp/mcp_streaming_iterator.py | 18 +- litellm/responses/streaming_iterator.py | 10 +- litellm/responses/utils.py | 8 +- litellm/router.py | 97 +++--- .../auto_router/auto_router.py | 1 + .../router_strategy/base_routing_strategy.py | 6 +- litellm/router_strategy/budget_limiter.py | 14 +- .../complexity_router/complexity_router.py | 1 + .../evals/eval_complexity_router.py | 17 +- litellm/router_utils/common_utils.py | 6 +- litellm/router_utils/cooldown_callbacks.py | 6 +- litellm/router_utils/get_retry_from_policy.py | 2 +- .../router_utils/pattern_match_deployments.py | 14 +- .../track_deployment_metrics.py | 2 +- litellm/router_utils/search_api_router.py | 6 +- litellm/search/__init__.py | 1 + litellm/search/cost_calculator.py | 1 + litellm/search/main.py | 9 +- litellm/secret_managers/aws_secret_manager.py | 2 +- .../secret_managers/aws_secret_manager_v2.py | 2 +- .../secret_managers/secret_manager_handler.py | 9 +- litellm/setup_wizard.py | 6 +- litellm/skills/main.py | 32 +- litellm/types/integrations/datadog_llm_obs.py | 1 + litellm/types/llms/bedrock.py | 4 +- litellm/types/llms/openai.py | 36 +-- .../cache_settings_endpoints.py | 12 +- .../router_settings_endpoints.py | 6 +- .../types/mcp_server/mcp_server_manager.py | 18 +- .../guardrail_hooks/generic_guardrail_api.py | 6 +- .../openai/openai_moderation.py | 10 +- .../guardrails/guardrail_hooks/pillar.py | 1 + .../internal_user_endpoints.py | 12 +- .../model_management_endpoints.py | 24 +- litellm/types/rerank.py | 6 +- litellm/types/router.py | 22 +- litellm/types/search.py | 1 + litellm/types/videos/utils.py | 1 + litellm/utils.py | 44 +-- litellm/vector_store_files/utils.py | 4 +- .../vector_stores/vector_store_registry.py | 6 +- litellm/videos/main.py | 115 ++++--- 383 files changed, 2283 insertions(+), 2095 deletions(-) diff --git a/litellm/__init__.py b/litellm/__init__.py index e45d926e8db..009d7ffdd25 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -167,12 +167,12 @@ require_auth_for_metrics_endpoint: Optional[bool] = False argilla_batch_size: Optional[int] = None datadog_use_v1: Optional[bool] = False # if you want to use v1 datadog logged payload. -gcs_pub_sub_use_v1: Optional[ - bool -] = False # if you want to use v1 gcs pubsub logged payload -generic_api_use_v1: Optional[ - bool -] = False # if you want to use v1 generic api logged payload +gcs_pub_sub_use_v1: Optional[bool] = ( + False # if you want to use v1 gcs pubsub logged payload +) +generic_api_use_v1: Optional[bool] = ( + False # if you want to use v1 generic api logged payload +) argilla_transformation_object: Optional[Dict[str, Any]] = None _async_input_callback: List[ Union[str, Callable, "CustomLogger"] @@ -192,25 +192,25 @@ pre_call_rules: List[Callable] = [] post_call_rules: List[Callable] = [] turn_off_message_logging: Optional[bool] = False -standard_logging_payload_excluded_fields: Optional[ - List[str] -] = None # Fields to exclude from StandardLoggingPayload before callbacks receive it +standard_logging_payload_excluded_fields: Optional[List[str]] = ( + None # Fields to exclude from StandardLoggingPayload before callbacks receive it +) log_raw_request_response: bool = False redact_messages_in_exceptions: Optional[bool] = False redact_user_api_key_info: Optional[bool] = False filter_invalid_headers: Optional[bool] = False -add_user_information_to_llm_headers: Optional[ - bool -] = None # adds user_id, team_id, token hash (params from StandardLoggingMetadata) to request headers +add_user_information_to_llm_headers: Optional[bool] = ( + None # adds user_id, team_id, token hash (params from StandardLoggingMetadata) to request headers +) store_audit_logs = False # Enterprise feature, allow users to see audit logs ### end of callbacks ############# -email: Optional[ - str -] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 -token: Optional[ - str -] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 +email: Optional[str] = ( + None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 +) +token: Optional[str] = ( + None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 +) telemetry = True max_tokens: int = DEFAULT_MAX_TOKENS # OpenAI Defaults drop_params = bool(os.getenv("LITELLM_DROP_PARAMS", False)) @@ -272,9 +272,9 @@ ssl_verify: Union[str, bool] = True ssl_security_level: Optional[str] = None ssl_certificate: Optional[str] = None -ssl_ecdh_curve: Optional[ - str -] = None # Set to 'X25519' to disable PQC and improve performance +ssl_ecdh_curve: Optional[str] = ( + None # Set to 'X25519' to disable PQC and improve performance +) disable_streaming_logging: bool = False disable_token_counter: bool = False disable_add_transform_inline_image_block: bool = False @@ -327,20 +327,24 @@ enable_caching_on_provider_specific_optional_params: bool = ( False # feature-flag for caching on optional params - e.g. 'top_k' ) -caching: bool = False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 -caching_with_models: bool = False # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 -cache: Optional[ - "Cache" -] = None # cache object <- use this - https://docs.litellm.ai/docs/caching +caching: bool = ( + False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 +) +caching_with_models: bool = ( + False # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 +) +cache: Optional["Cache"] = ( + None # cache object <- use this - https://docs.litellm.ai/docs/caching +) default_in_memory_ttl: Optional[float] = None default_redis_ttl: Optional[float] = None default_redis_batch_cache_expiry: Optional[float] = None model_alias_map: Dict[str, str] = {} model_group_settings: Optional["ModelGroupSettings"] = None max_budget: float = 0.0 # set the max budget across all providers -budget_duration: Optional[ - str -] = None # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d"). +budget_duration: Optional[str] = ( + None # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d"). +) default_soft_budget: float = ( DEFAULT_SOFT_BUDGET # by default all litellm proxy keys have a soft budget of 50.0 ) @@ -349,7 +353,9 @@ _current_cost = 0.0 # private variable, used if max budget is set error_logs: Dict = {} -add_function_to_prompt: bool = False # if function calling not supported by api, append function call details to system prompt +add_function_to_prompt: bool = ( + False # if function calling not supported by api, append function call details to system prompt +) client_session: Optional[httpx.Client] = None aclient_session: Optional[httpx.AsyncClient] = None model_fallbacks: Optional[List] = None # Deprecated for 'litellm.fallbacks' @@ -396,7 +402,9 @@ disable_add_prefix_to_prompt: bool = ( False # used by anthropic, to disable adding prefix to prompt ) -disable_copilot_system_to_assistant: bool = False # If false (default), converts all 'system' role messages to 'assistant' for GitHub Copilot compatibility. Set to true to disable this behavior. +disable_copilot_system_to_assistant: bool = ( + False # If false (default), converts all 'system' role messages to 'assistant' for GitHub Copilot compatibility. Set to true to disable this behavior. +) public_mcp_servers: Optional[List[str]] = None public_model_groups: Optional[List[str]] = None public_agent_groups: Optional[List[str]] = None @@ -405,9 +413,9 @@ # Old format: { "displayName": "url" } (for backward compatibility) public_model_groups_links: Dict[str, Union[str, Dict[str, Any]]] = {} #### REQUEST PRIORITIZATION ####### -priority_reservation: Optional[ - Dict[str, Union[float, "PriorityReservationDict"]] -] = None +priority_reservation: Optional[Dict[str, Union[float, "PriorityReservationDict"]]] = ( + None +) # priority_reservation_settings is lazy-loaded via __getattr__ # Only declare for type checking - at runtime __getattr__ handles it if TYPE_CHECKING: @@ -415,13 +423,17 @@ ######## Networking Settings ######## -use_aiohttp_transport: bool = True # Older variable, aiohttp is now the default. use disable_aiohttp_transport instead. +use_aiohttp_transport: bool = ( + True # Older variable, aiohttp is now the default. use disable_aiohttp_transport instead. +) aiohttp_trust_env: bool = False # set to true to use HTTP_ Proxy settings disable_aiohttp_transport: bool = False # Set this to true to use httpx instead disable_aiohttp_trust_env: bool = ( False # When False, aiohttp will respect HTTP(S)_PROXY env vars ) -force_ipv4: bool = False # when True, litellm will force ipv4 for all LLM requests. Some users have seen httpx ConnectionError when using ipv6. +force_ipv4: bool = ( + False # when True, litellm will force ipv4 for all LLM requests. Some users have seen httpx ConnectionError when using ipv6. +) network_mock: bool = False # When True, use mock transport — no real network calls ####### STOP SEQUENCE LIMIT ####### @@ -436,13 +448,13 @@ content_policy_fallbacks: Optional[List] = None allowed_fails: int = 3 allow_dynamic_callback_disabling: bool = True -num_retries_per_request: Optional[ - int -] = None # for the request overall (incl. fallbacks + model retries) +num_retries_per_request: Optional[int] = ( + None # for the request overall (incl. fallbacks + model retries) +) ####### SECRET MANAGERS ##################### -secret_manager_client: Optional[ - Any -] = None # list of instantiated key management clients - e.g. azure kv, infisical, etc. +secret_manager_client: Optional[Any] = ( + None # list of instantiated key management clients - e.g. azure kv, infisical, etc. +) _google_kms_resource_name: Optional[str] = None _key_management_system: Optional["KeyManagementSystem"] = None # Note: KeyManagementSettings must be eagerly imported because _key_management_settings @@ -455,12 +467,12 @@ from litellm.litellm_core_utils.get_model_cost_map import get_model_cost_map model_cost = get_model_cost_map(url=model_cost_map_url) -cost_discount_config: Dict[ - str, float -] = {} # Provider-specific cost discounts {"vertex_ai": 0.05} = 5% discount -cost_margin_config: Dict[ - str, Union[float, Dict[str, float]] -] = {} # Provider-specific or global cost margins. Examples: +cost_discount_config: Dict[str, float] = ( + {} +) # Provider-specific cost discounts {"vertex_ai": 0.05} = 5% discount +cost_margin_config: Dict[str, Union[float, Dict[str, float]]] = ( + {} +) # Provider-specific or global cost margins. Examples: # Percentage: {"openai": 0.10} = 10% margin # Fixed: {"openai": {"fixed_amount": 0.001}} = $0.001 per request # Global: {"global": 0.05} = 5% global margin on all providers @@ -1309,12 +1321,12 @@ def add_known_models(model_cost_map: Optional[Dict] = None): from .types.llms.custom_llm import CustomLLMItem custom_provider_map: List[CustomLLMItem] = [] -_custom_providers: List[ - str -] = [] # internal helper util, used to track names of custom providers -disable_hf_tokenizer_download: Optional[ - bool -] = None # disable huggingface tokenizer download. Defaults to openai clk100 +_custom_providers: List[str] = ( + [] +) # internal helper util, used to track names of custom providers +disable_hf_tokenizer_download: Optional[bool] = ( + None # disable huggingface tokenizer download. Defaults to openai clk100 +) global_disable_no_log_param: bool = False ### CLI UTILITIES ### diff --git a/litellm/_lazy_imports.py b/litellm/_lazy_imports.py index 3604506d406..4d811c3d7d9 100644 --- a/litellm/_lazy_imports.py +++ b/litellm/_lazy_imports.py @@ -14,6 +14,7 @@ This makes importing litellm much faster because we don't load heavy dependencies until they're actually needed. """ + import importlib import sys from typing import Any, Optional, cast, Callable diff --git a/litellm/_uuid.py b/litellm/_uuid.py index 52acf647dd8..2b7c3b82d35 100644 --- a/litellm/_uuid.py +++ b/litellm/_uuid.py @@ -6,7 +6,6 @@ import fastuuid as _uuid # type: ignore - # Expose a module-like alias so callers can use: uuid.uuid4() uuid = _uuid diff --git a/litellm/a2a_protocol/main.py b/litellm/a2a_protocol/main.py index c86549da77a..96a592f4314 100644 --- a/litellm/a2a_protocol/main.py +++ b/litellm/a2a_protocol/main.py @@ -120,9 +120,9 @@ def _get_a2a_model_info(a2a_client: Any, kwargs: Dict[str, Any]) -> str: litellm_logging_obj.model = model litellm_logging_obj.custom_llm_provider = custom_llm_provider litellm_logging_obj.model_call_details["model"] = model - litellm_logging_obj.model_call_details[ - "custom_llm_provider" - ] = custom_llm_provider + litellm_logging_obj.model_call_details["custom_llm_provider"] = ( + custom_llm_provider + ) return agent_name diff --git a/litellm/a2a_protocol/streaming_iterator.py b/litellm/a2a_protocol/streaming_iterator.py index 98d45cf2ac1..c5ae9bcdc3c 100644 --- a/litellm/a2a_protocol/streaming_iterator.py +++ b/litellm/a2a_protocol/streaming_iterator.py @@ -168,9 +168,9 @@ def _build_logging_result(self, usage: litellm.Usage) -> Dict[str, Any]: result: Dict[str, Any] = { "id": getattr(self.request, "id", "unknown"), "jsonrpc": "2.0", - "usage": usage.model_dump() - if hasattr(usage, "model_dump") - else dict(usage), + "usage": ( + usage.model_dump() if hasattr(usage, "model_dump") else dict(usage) + ), } # Add final chunk result if available diff --git a/litellm/anthropic_interface/__init__.py b/litellm/anthropic_interface/__init__.py index 9902fdc553b..280d70142b1 100644 --- a/litellm/anthropic_interface/__init__.py +++ b/litellm/anthropic_interface/__init__.py @@ -1,6 +1,7 @@ """ Anthropic module for LiteLLM """ + from .messages import acreate, create __all__ = ["acreate", "create"] diff --git a/litellm/anthropic_interface/exceptions/exception_mapping_utils.py b/litellm/anthropic_interface/exceptions/exception_mapping_utils.py index 28020e763f4..4548185bbdc 100644 --- a/litellm/anthropic_interface/exceptions/exception_mapping_utils.py +++ b/litellm/anthropic_interface/exceptions/exception_mapping_utils.py @@ -9,7 +9,6 @@ from .exceptions import AnthropicErrorResponse, AnthropicErrorType - # HTTP status code -> Anthropic error type # Source: https://docs.anthropic.com/en/api/errors ANTHROPIC_ERROR_TYPE_MAP: Dict[int, AnthropicErrorType] = { diff --git a/litellm/anthropic_interface/exceptions/exceptions.py b/litellm/anthropic_interface/exceptions/exceptions.py index 984390fa702..b289e493e6b 100644 --- a/litellm/anthropic_interface/exceptions/exceptions.py +++ b/litellm/anthropic_interface/exceptions/exceptions.py @@ -2,7 +2,6 @@ from typing_extensions import Literal, Required, TypedDict - # Known Anthropic error types # Source: https://docs.anthropic.com/en/api/errors AnthropicErrorType = Literal[ diff --git a/litellm/caching/caching_handler.py b/litellm/caching/caching_handler.py index 7cdbd3fc03d..2bec705946c 100644 --- a/litellm/caching/caching_handler.py +++ b/litellm/caching/caching_handler.py @@ -78,7 +78,9 @@ class CachingHandlerResponse(BaseModel): cached_result: Optional[Any] = None final_embedding_cached_response: Optional[EmbeddingResponse] = None - embedding_all_elements_cache_hit: bool = False # this is set to True when all elements in the list have a cache hit in the embedding cache, if true return the final_embedding_cached_response no need to make an API call + embedding_all_elements_cache_hit: bool = ( + False # this is set to True when all elements in the list have a cache hit in the embedding cache, if true return the final_embedding_cached_response no need to make an API call + ) in_memory_cache_obj = InMemoryCache() @@ -1014,9 +1016,9 @@ def _update_litellm_logging_obj_environment( } if litellm.cache is not None: - litellm_params[ - "preset_cache_key" - ] = litellm.cache._get_preset_cache_key_from_kwargs(**kwargs) + litellm_params["preset_cache_key"] = ( + litellm.cache._get_preset_cache_key_from_kwargs(**kwargs) + ) else: litellm_params["preset_cache_key"] = None diff --git a/litellm/caching/gcs_cache.py b/litellm/caching/gcs_cache.py index a5bd092f154..3327e094bc2 100644 --- a/litellm/caching/gcs_cache.py +++ b/litellm/caching/gcs_cache.py @@ -1,6 +1,7 @@ """GCS Cache implementation Supports syncing responses to Google Cloud Storage Buckets using HTTP requests. """ + import json import asyncio from typing import Optional diff --git a/litellm/completion_extras/litellm_responses_transformation/handler.py b/litellm/completion_extras/litellm_responses_transformation/handler.py index 2164a2c0f01..ce398ee8288 100644 --- a/litellm/completion_extras/litellm_responses_transformation/handler.py +++ b/litellm/completion_extras/litellm_responses_transformation/handler.py @@ -142,9 +142,7 @@ def validate_input_kwargs( custom_llm_provider=custom_llm_provider, ) - def completion( - self, *args, **kwargs - ) -> Union[ + def completion(self, *args, **kwargs) -> Union[ Coroutine[Any, Any, Union["ModelResponse", "CustomStreamWrapper"]], "ModelResponse", "CustomStreamWrapper", diff --git a/litellm/completion_extras/litellm_responses_transformation/transformation.py b/litellm/completion_extras/litellm_responses_transformation/transformation.py index ee4cdbcdf36..f5856ab1f45 100644 --- a/litellm/completion_extras/litellm_responses_transformation/transformation.py +++ b/litellm/completion_extras/litellm_responses_transformation/transformation.py @@ -240,10 +240,10 @@ def _map_optional_params_to_responses_api_request( if key in ("max_tokens", "max_completion_tokens"): responses_api_request["max_output_tokens"] = value elif key == "tools" and value is not None: - responses_api_request[ - "tools" - ] = self._convert_tools_to_responses_format( - cast(List[Dict[str, Any]], value) + responses_api_request["tools"] = ( + self._convert_tools_to_responses_format( + cast(List[Dict[str, Any]], value) + ) ) elif key == "response_format": text_format = self._transform_response_format_to_text_format(value) @@ -1072,9 +1072,9 @@ def translate_responses_chunk_to_openai_stream( # noqa: PLR0915 ) if provider_specific_fields: - function_chunk[ - "provider_specific_fields" - ] = provider_specific_fields + function_chunk["provider_specific_fields"] = ( + provider_specific_fields + ) tool_call_index = parsed_chunk.get("output_index", 0) tool_call_chunk = ChatCompletionToolCallChunk( @@ -1147,9 +1147,9 @@ def translate_responses_chunk_to_openai_stream( # noqa: PLR0915 # Add provider_specific_fields to function if present if provider_specific_fields: - function_chunk[ - "provider_specific_fields" - ] = provider_specific_fields + function_chunk["provider_specific_fields"] = ( + provider_specific_fields + ) tool_call_index = parsed_chunk.get("output_index", 0) tool_call_chunk = ChatCompletionToolCallChunk( diff --git a/litellm/containers/endpoint_factory.py b/litellm/containers/endpoint_factory.py index 1d8e50856fe..a3624a90674 100644 --- a/litellm/containers/endpoint_factory.py +++ b/litellm/containers/endpoint_factory.py @@ -76,10 +76,10 @@ def endpoint_func( # Get provider config litellm_params = GenericLiteLLMParams(**kwargs) - container_provider_config: Optional[ - BaseContainerConfig - ] = ProviderConfigManager.get_provider_container_config( - provider=litellm.LlmProviders(custom_llm_provider), + container_provider_config: Optional[BaseContainerConfig] = ( + ProviderConfigManager.get_provider_container_config( + provider=litellm.LlmProviders(custom_llm_provider), + ) ) if container_provider_config is None: diff --git a/litellm/containers/main.py b/litellm/containers/main.py index 916fc26351b..a7b37d3f469 100644 --- a/litellm/containers/main.py +++ b/litellm/containers/main.py @@ -165,7 +165,10 @@ def create_container( extra_query: Optional[Dict[str, Any]] = None, extra_body: Optional[Dict[str, Any]] = None, **kwargs, -) -> Union[ContainerObject, Coroutine[Any, Any, ContainerObject],]: +) -> Union[ + ContainerObject, + Coroutine[Any, Any, ContainerObject], +]: """Create a container using the OpenAI Container API. Currently supports OpenAI @@ -205,10 +208,10 @@ def create_container( **kwargs, ) # get provider config - container_provider_config: Optional[ - BaseContainerConfig - ] = ProviderConfigManager.get_provider_container_config( - provider=litellm.LlmProviders(custom_llm_provider), + container_provider_config: Optional[BaseContainerConfig] = ( + ProviderConfigManager.get_provider_container_config( + provider=litellm.LlmProviders(custom_llm_provider), + ) ) if container_provider_config is None: @@ -391,7 +394,10 @@ def list_containers( extra_query: Optional[Dict[str, Any]] = None, extra_body: Optional[Dict[str, Any]] = None, **kwargs, -) -> Union[ContainerListResponse, Coroutine[Any, Any, ContainerListResponse],]: +) -> Union[ + ContainerListResponse, + Coroutine[Any, Any, ContainerListResponse], +]: """List containers using the OpenAI Container API. Currently supports OpenAI @@ -420,10 +426,10 @@ def list_containers( **kwargs, ) # get provider config - container_provider_config: Optional[ - BaseContainerConfig - ] = ProviderConfigManager.get_provider_container_config( - provider=litellm.LlmProviders(custom_llm_provider), + container_provider_config: Optional[BaseContainerConfig] = ( + ProviderConfigManager.get_provider_container_config( + provider=litellm.LlmProviders(custom_llm_provider), + ) ) if container_provider_config is None: @@ -587,7 +593,10 @@ def retrieve_container( extra_query: Optional[Dict[str, Any]] = None, extra_body: Optional[Dict[str, Any]] = None, **kwargs, -) -> Union[ContainerObject, Coroutine[Any, Any, ContainerObject],]: +) -> Union[ + ContainerObject, + Coroutine[Any, Any, ContainerObject], +]: """Retrieve a container using the OpenAI Container API. Currently supports OpenAI @@ -616,10 +625,10 @@ def retrieve_container( **kwargs, ) # get provider config - container_provider_config: Optional[ - BaseContainerConfig - ] = ProviderConfigManager.get_provider_container_config( - provider=litellm.LlmProviders(custom_llm_provider), + container_provider_config: Optional[BaseContainerConfig] = ( + ProviderConfigManager.get_provider_container_config( + provider=litellm.LlmProviders(custom_llm_provider), + ) ) if container_provider_config is None: @@ -773,7 +782,10 @@ def delete_container( extra_query: Optional[Dict[str, Any]] = None, extra_body: Optional[Dict[str, Any]] = None, **kwargs, -) -> Union[DeleteContainerResult, Coroutine[Any, Any, DeleteContainerResult],]: +) -> Union[ + DeleteContainerResult, + Coroutine[Any, Any, DeleteContainerResult], +]: """Delete a container using the OpenAI Container API. Currently supports OpenAI @@ -802,10 +814,10 @@ def delete_container( **kwargs, ) # get provider config - container_provider_config: Optional[ - BaseContainerConfig - ] = ProviderConfigManager.get_provider_container_config( - provider=litellm.LlmProviders(custom_llm_provider), + container_provider_config: Optional[BaseContainerConfig] = ( + ProviderConfigManager.get_provider_container_config( + provider=litellm.LlmProviders(custom_llm_provider), + ) ) if container_provider_config is None: @@ -973,7 +985,10 @@ def list_container_files( extra_query: Optional[Dict[str, Any]] = None, extra_body: Optional[Dict[str, Any]] = None, **kwargs, -) -> Union[ContainerFileListResponse, Coroutine[Any, Any, ContainerFileListResponse],]: +) -> Union[ + ContainerFileListResponse, + Coroutine[Any, Any, ContainerFileListResponse], +]: """List files in a container using the OpenAI Container API. Currently supports OpenAI @@ -1002,10 +1017,10 @@ def list_container_files( **kwargs, ) # get provider config - container_provider_config: Optional[ - BaseContainerConfig - ] = ProviderConfigManager.get_provider_container_config( - provider=litellm.LlmProviders(custom_llm_provider), + container_provider_config: Optional[BaseContainerConfig] = ( + ProviderConfigManager.get_provider_container_config( + provider=litellm.LlmProviders(custom_llm_provider), + ) ) if container_provider_config is None: @@ -1190,7 +1205,10 @@ def upload_container_file( extra_query: Optional[Dict[str, Any]] = None, extra_body: Optional[Dict[str, Any]] = None, **kwargs, -) -> Union[ContainerFileObject, Coroutine[Any, Any, ContainerFileObject],]: +) -> Union[ + ContainerFileObject, + Coroutine[Any, Any, ContainerFileObject], +]: """Upload a file to a container using the OpenAI Container API. This endpoint allows uploading files directly to a container session, @@ -1248,10 +1266,10 @@ def upload_container_file( **kwargs, ) # get provider config - container_provider_config: Optional[ - BaseContainerConfig - ] = ProviderConfigManager.get_provider_container_config( - provider=litellm.LlmProviders(custom_llm_provider), + container_provider_config: Optional[BaseContainerConfig] = ( + ProviderConfigManager.get_provider_container_config( + provider=litellm.LlmProviders(custom_llm_provider), + ) ) if container_provider_config is None: diff --git a/litellm/cost_calculator.py b/litellm/cost_calculator.py index 29d28b8c896..46eb1f01d61 100644 --- a/litellm/cost_calculator.py +++ b/litellm/cost_calculator.py @@ -1136,9 +1136,9 @@ def completion_cost( # noqa: PLR0915 or isinstance(completion_response, dict) ): # tts returns a custom class if isinstance(completion_response, dict): - usage_obj: Optional[ - Union[dict, Usage] - ] = completion_response.get("usage", {}) + usage_obj: Optional[Union[dict, Usage]] = ( + completion_response.get("usage", {}) + ) else: usage_obj = getattr(completion_response, "usage", {}) if isinstance(usage_obj, BaseModel) and not _is_known_usage_objects( diff --git a/litellm/evals/main.py b/litellm/evals/main.py index eab909a6b11..df6d3accb82 100644 --- a/litellm/evals/main.py +++ b/litellm/evals/main.py @@ -152,10 +152,10 @@ def create_eval( custom_llm_provider = "openai" # Get provider config - evals_api_provider_config: Optional[ - BaseEvalsAPIConfig - ] = ProviderConfigManager.get_provider_evals_api_config( # type: ignore - provider=litellm.LlmProviders(custom_llm_provider), + evals_api_provider_config: Optional[BaseEvalsAPIConfig] = ( + ProviderConfigManager.get_provider_evals_api_config( # type: ignore + provider=litellm.LlmProviders(custom_llm_provider), + ) ) if evals_api_provider_config is None: @@ -343,10 +343,10 @@ def list_evals( custom_llm_provider = "openai" # Get provider config - evals_api_provider_config: Optional[ - BaseEvalsAPIConfig - ] = ProviderConfigManager.get_provider_evals_api_config( # type: ignore - provider=litellm.LlmProviders(custom_llm_provider), + evals_api_provider_config: Optional[BaseEvalsAPIConfig] = ( + ProviderConfigManager.get_provider_evals_api_config( # type: ignore + provider=litellm.LlmProviders(custom_llm_provider), + ) ) if evals_api_provider_config is None: @@ -513,10 +513,10 @@ def get_eval( custom_llm_provider = "openai" # Get provider config - evals_api_provider_config: Optional[ - BaseEvalsAPIConfig - ] = ProviderConfigManager.get_provider_evals_api_config( # type: ignore - provider=litellm.LlmProviders(custom_llm_provider), + evals_api_provider_config: Optional[BaseEvalsAPIConfig] = ( + ProviderConfigManager.get_provider_evals_api_config( # type: ignore + provider=litellm.LlmProviders(custom_llm_provider), + ) ) if evals_api_provider_config is None: @@ -682,10 +682,10 @@ def update_eval( custom_llm_provider = "openai" # Get provider config - evals_api_provider_config: Optional[ - BaseEvalsAPIConfig - ] = ProviderConfigManager.get_provider_evals_api_config( # type: ignore - provider=litellm.LlmProviders(custom_llm_provider), + evals_api_provider_config: Optional[BaseEvalsAPIConfig] = ( + ProviderConfigManager.get_provider_evals_api_config( # type: ignore + provider=litellm.LlmProviders(custom_llm_provider), + ) ) if evals_api_provider_config is None: @@ -893,10 +893,10 @@ def delete_eval( custom_llm_provider = "openai" # Get provider config - evals_api_provider_config: Optional[ - BaseEvalsAPIConfig - ] = ProviderConfigManager.get_provider_evals_api_config( # type: ignore - provider=litellm.LlmProviders(custom_llm_provider), + evals_api_provider_config: Optional[BaseEvalsAPIConfig] = ( + ProviderConfigManager.get_provider_evals_api_config( # type: ignore + provider=litellm.LlmProviders(custom_llm_provider), + ) ) if evals_api_provider_config is None: @@ -1047,10 +1047,10 @@ def cancel_eval( custom_llm_provider = "openai" # Get provider config - evals_api_provider_config: Optional[ - BaseEvalsAPIConfig - ] = ProviderConfigManager.get_provider_evals_api_config( # type: ignore - provider=litellm.LlmProviders(custom_llm_provider), + evals_api_provider_config: Optional[BaseEvalsAPIConfig] = ( + ProviderConfigManager.get_provider_evals_api_config( # type: ignore + provider=litellm.LlmProviders(custom_llm_provider), + ) ) if evals_api_provider_config is None: @@ -1230,10 +1230,10 @@ def create_run( custom_llm_provider = "openai" # Get provider config - evals_api_provider_config: Optional[ - BaseEvalsAPIConfig - ] = ProviderConfigManager.get_provider_evals_api_config( # type: ignore - provider=litellm.LlmProviders(custom_llm_provider), + evals_api_provider_config: Optional[BaseEvalsAPIConfig] = ( + ProviderConfigManager.get_provider_evals_api_config( # type: ignore + provider=litellm.LlmProviders(custom_llm_provider), + ) ) if evals_api_provider_config is None: @@ -1418,10 +1418,10 @@ def list_runs( custom_llm_provider = "openai" # Get provider config - evals_api_provider_config: Optional[ - BaseEvalsAPIConfig - ] = ProviderConfigManager.get_provider_evals_api_config( # type: ignore - provider=litellm.LlmProviders(custom_llm_provider), + evals_api_provider_config: Optional[BaseEvalsAPIConfig] = ( + ProviderConfigManager.get_provider_evals_api_config( # type: ignore + provider=litellm.LlmProviders(custom_llm_provider), + ) ) if evals_api_provider_config is None: @@ -1592,10 +1592,10 @@ def get_run( custom_llm_provider = "openai" # Get provider config - evals_api_provider_config: Optional[ - BaseEvalsAPIConfig - ] = ProviderConfigManager.get_provider_evals_api_config( # type: ignore - provider=litellm.LlmProviders(custom_llm_provider), + evals_api_provider_config: Optional[BaseEvalsAPIConfig] = ( + ProviderConfigManager.get_provider_evals_api_config( # type: ignore + provider=litellm.LlmProviders(custom_llm_provider), + ) ) if evals_api_provider_config is None: @@ -1752,10 +1752,10 @@ def cancel_run( custom_llm_provider = "openai" # Get provider config - evals_api_provider_config: Optional[ - BaseEvalsAPIConfig - ] = ProviderConfigManager.get_provider_evals_api_config( # type: ignore - provider=litellm.LlmProviders(custom_llm_provider), + evals_api_provider_config: Optional[BaseEvalsAPIConfig] = ( + ProviderConfigManager.get_provider_evals_api_config( # type: ignore + provider=litellm.LlmProviders(custom_llm_provider), + ) ) if evals_api_provider_config is None: @@ -1921,10 +1921,10 @@ def delete_run( custom_llm_provider = "openai" # Get provider config - evals_api_provider_config: Optional[ - BaseEvalsAPIConfig - ] = ProviderConfigManager.get_provider_evals_api_config( # type: ignore - provider=litellm.LlmProviders(custom_llm_provider), + evals_api_provider_config: Optional[BaseEvalsAPIConfig] = ( + ProviderConfigManager.get_provider_evals_api_config( # type: ignore + provider=litellm.LlmProviders(custom_llm_provider), + ) ) if evals_api_provider_config is None: diff --git a/litellm/exceptions.py b/litellm/exceptions.py index abdba09dd8d..3677509f8e4 100644 --- a/litellm/exceptions.py +++ b/litellm/exceptions.py @@ -281,7 +281,7 @@ def __repr__(self): return _message -class PermissionDeniedError(openai.PermissionDeniedError): # type:ignore +class PermissionDeniedError(openai.PermissionDeniedError): # type: ignore def __init__( self, message, diff --git a/litellm/google_genai/adapters/__init__.py b/litellm/google_genai/adapters/__init__.py index bfa9e712678..6fbe7d95a55 100644 --- a/litellm/google_genai/adapters/__init__.py +++ b/litellm/google_genai/adapters/__init__.py @@ -1,10 +1,10 @@ """ Google GenAI Adapters for LiteLLM -This module provides adapters for transforming Google GenAI generate_content requests +This module provides adapters for transforming Google GenAI generate_content requests to/from LiteLLM completion format with full support for: - Text content transformation -- Tool calling (function declarations, function calls, function responses) +- Tool calling (function declarations, function calls, function responses) - Streaming (both regular and tool calling) - Mixed content (text + tool calls) """ diff --git a/litellm/images/main.py b/litellm/images/main.py index a5ae154190a..0d3b2e97294 100644 --- a/litellm/images/main.py +++ b/litellm/images/main.py @@ -210,7 +210,10 @@ def image_generation( # noqa: PLR0915 api_version: Optional[str] = None, custom_llm_provider=None, **kwargs, -) -> Union[ImageResponse, Coroutine[Any, Any, ImageResponse],]: +) -> Union[ + ImageResponse, + Coroutine[Any, Any, ImageResponse], +]: """ Maps the https://api.openai.com/v1/images/generations endpoint. @@ -864,11 +867,11 @@ def image_edit( # noqa: PLR0915 ) # get provider config - image_edit_provider_config: Optional[ - BaseImageEditConfig - ] = ProviderConfigManager.get_provider_image_edit_config( - model=model, - provider=litellm.LlmProviders(custom_llm_provider), + image_edit_provider_config: Optional[BaseImageEditConfig] = ( + ProviderConfigManager.get_provider_image_edit_config( + model=model, + provider=litellm.LlmProviders(custom_llm_provider), + ) ) if image_edit_provider_config is None: @@ -876,20 +879,20 @@ def image_edit( # noqa: PLR0915 local_vars.update(kwargs) # Get ImageEditOptionalRequestParams with only valid parameters - image_edit_optional_params: ImageEditOptionalRequestParams = ( - _get_ImageEditRequestUtils().get_requested_image_edit_optional_param( - local_vars - ) + image_edit_optional_params: ( + ImageEditOptionalRequestParams + ) = _get_ImageEditRequestUtils().get_requested_image_edit_optional_param( + local_vars ) # Get optional parameters for the responses API - image_edit_request_params: Dict = ( - _get_ImageEditRequestUtils().get_optional_params_image_edit( - model=model, - image_edit_provider_config=image_edit_provider_config, - image_edit_optional_params=image_edit_optional_params, - drop_params=kwargs.get("drop_params"), - additional_drop_params=kwargs.get("additional_drop_params"), - ) + image_edit_request_params: ( + Dict + ) = _get_ImageEditRequestUtils().get_optional_params_image_edit( + model=model, + image_edit_provider_config=image_edit_provider_config, + image_edit_optional_params=image_edit_optional_params, + drop_params=kwargs.get("drop_params"), + additional_drop_params=kwargs.get("additional_drop_params"), ) # Pre Call logging diff --git a/litellm/integrations/SlackAlerting/batching_handler.py b/litellm/integrations/SlackAlerting/batching_handler.py index fdce2e04793..828f3eb4175 100644 --- a/litellm/integrations/SlackAlerting/batching_handler.py +++ b/litellm/integrations/SlackAlerting/batching_handler.py @@ -1,9 +1,9 @@ """ -Handles Batching + sending Httpx Post requests to slack +Handles Batching + sending Httpx Post requests to slack -Slack alerts are sent every 10s or when events are greater than X events +Slack alerts are sent every 10s or when events are greater than X events -see custom_batch_logger.py for more details / defaults +see custom_batch_logger.py for more details / defaults """ from typing import TYPE_CHECKING, Any diff --git a/litellm/integrations/SlackAlerting/hanging_request_check.py b/litellm/integrations/SlackAlerting/hanging_request_check.py index b9c485dce82..d2f70c9caf1 100644 --- a/litellm/integrations/SlackAlerting/hanging_request_check.py +++ b/litellm/integrations/SlackAlerting/hanging_request_check.py @@ -102,10 +102,10 @@ async def send_alerts_for_hanging_requests(self): ) for request_id in hanging_requests: - hanging_request_data: Optional[ - HangingRequestData - ] = await self.hanging_request_cache.async_get_cache( - key=request_id, + hanging_request_data: Optional[HangingRequestData] = ( + await self.hanging_request_cache.async_get_cache( + key=request_id, + ) ) if hanging_request_data is None: diff --git a/litellm/integrations/SlackAlerting/slack_alerting.py b/litellm/integrations/SlackAlerting/slack_alerting.py index 013cef74805..0ec17bbea5d 100644 --- a/litellm/integrations/SlackAlerting/slack_alerting.py +++ b/litellm/integrations/SlackAlerting/slack_alerting.py @@ -852,9 +852,9 @@ async def region_outage_alerts( ### UNIQUE CACHE KEY ### cache_key = provider + region_name - outage_value: Optional[ - ProviderRegionOutageModel - ] = await self.internal_usage_cache.async_get_cache(key=cache_key) + outage_value: Optional[ProviderRegionOutageModel] = ( + await self.internal_usage_cache.async_get_cache(key=cache_key) + ) # Convert deployment_ids back to set if it was stored as a list if outage_value is not None: @@ -1443,9 +1443,9 @@ async def send_alert( # noqa: PLR0915 self.alert_to_webhook_url is not None and alert_type in self.alert_to_webhook_url ): - _digest_webhook: Optional[ - Union[str, List[str]] - ] = self.alert_to_webhook_url[alert_type] + _digest_webhook: Optional[Union[str, List[str]]] = ( + self.alert_to_webhook_url[alert_type] + ) elif self.default_webhook_url is not None: _digest_webhook = self.default_webhook_url else: @@ -1499,9 +1499,9 @@ async def send_alert( # noqa: PLR0915 self.alert_to_webhook_url is not None and alert_type in self.alert_to_webhook_url ): - slack_webhook_url: Optional[ - Union[str, List[str]] - ] = self.alert_to_webhook_url[alert_type] + slack_webhook_url: Optional[Union[str, List[str]]] = ( + self.alert_to_webhook_url[alert_type] + ) elif self.default_webhook_url is not None: slack_webhook_url = self.default_webhook_url else: diff --git a/litellm/integrations/SlackAlerting/utils.py b/litellm/integrations/SlackAlerting/utils.py index e695266c88b..e2580768178 100644 --- a/litellm/integrations/SlackAlerting/utils.py +++ b/litellm/integrations/SlackAlerting/utils.py @@ -18,7 +18,7 @@ def process_slack_alerting_variables( - alert_to_webhook_url: Optional[Dict[AlertType, Union[List[str], str]]] + alert_to_webhook_url: Optional[Dict[AlertType, Union[List[str], str]]], ) -> Optional[Dict[AlertType, Union[List[str], str]]]: """ process alert_to_webhook_url diff --git a/litellm/integrations/additional_logging_utils.py b/litellm/integrations/additional_logging_utils.py index 795afd81d41..59319140a18 100644 --- a/litellm/integrations/additional_logging_utils.py +++ b/litellm/integrations/additional_logging_utils.py @@ -1,5 +1,5 @@ """ -Base class for Additional Logging Utils for CustomLoggers +Base class for Additional Logging Utils for CustomLoggers - Health Check for the logging util - Get Request / Response Payload for the logging util diff --git a/litellm/integrations/agentops/agentops.py b/litellm/integrations/agentops/agentops.py index 38b91c06587..4f17806a6b7 100644 --- a/litellm/integrations/agentops/agentops.py +++ b/litellm/integrations/agentops/agentops.py @@ -1,6 +1,7 @@ """ AgentOps integration for LiteLLM - Provides OpenTelemetry tracing for LLM calls """ + import os from dataclasses import dataclass from typing import Optional, Dict, Any diff --git a/litellm/integrations/anthropic_cache_control_hook.py b/litellm/integrations/anthropic_cache_control_hook.py index 0e99537d5db..213622cb43a 100644 --- a/litellm/integrations/anthropic_cache_control_hook.py +++ b/litellm/integrations/anthropic_cache_control_hook.py @@ -106,10 +106,10 @@ def _process_message_injection( targetted_index += len(messages) if 0 <= targetted_index < len(messages): - messages[ - targetted_index - ] = AnthropicCacheControlHook._safe_insert_cache_control_in_message( - messages[targetted_index], control + messages[targetted_index] = ( + AnthropicCacheControlHook._safe_insert_cache_control_in_message( + messages[targetted_index], control + ) ) else: verbose_logger.warning( diff --git a/litellm/integrations/arize/arize_phoenix.py b/litellm/integrations/arize/arize_phoenix.py index 00bc24d4188..b8cd04836c3 100644 --- a/litellm/integrations/arize/arize_phoenix.py +++ b/litellm/integrations/arize/arize_phoenix.py @@ -178,9 +178,9 @@ def _get_phoenix_context(self, kwargs): start_time_val = kwargs.get("start_time", kwargs.get("api_call_start_time")) parent_span = self.tracer.start_span( name="litellm_proxy_request", - start_time=self._to_ns(start_time_val) - if start_time_val is not None - else None, + start_time=( + self._to_ns(start_time_val) if start_time_val is not None else None + ), context=traceparent_ctx, kind=self.span_kind.SERVER, ) diff --git a/litellm/integrations/azure_storage/azure_storage.py b/litellm/integrations/azure_storage/azure_storage.py index 6fc7b9c1048..85f91199c1c 100644 --- a/litellm/integrations/azure_storage/azure_storage.py +++ b/litellm/integrations/azure_storage/azure_storage.py @@ -54,12 +54,12 @@ def __init__( self._service_client_timeout: Optional[float] = None # Internal variables used for Token based authentication - self.azure_auth_token: Optional[ - str - ] = None # the Azure AD token to use for Azure Storage API requests - self.token_expiry: Optional[ - datetime - ] = None # the expiry time of the currentAzure AD token + self.azure_auth_token: Optional[str] = ( + None # the Azure AD token to use for Azure Storage API requests + ) + self.token_expiry: Optional[datetime] = ( + None # the expiry time of the currentAzure AD token + ) asyncio.create_task(self.periodic_flush()) self.flush_lock = asyncio.Lock() diff --git a/litellm/integrations/braintrust_logging.py b/litellm/integrations/braintrust_logging.py index cb1b2bc5531..9b1c5077882 100644 --- a/litellm/integrations/braintrust_logging.py +++ b/litellm/integrations/braintrust_logging.py @@ -52,9 +52,9 @@ def __init__( "Authorization": "Bearer " + self.api_key, "Content-Type": "application/json", } - self._project_id_cache: Dict[ - str, str - ] = {} # Cache mapping project names to IDs + self._project_id_cache: Dict[str, str] = ( + {} + ) # Cache mapping project names to IDs self.global_braintrust_http_handler = get_async_httpx_client( llm_provider=httpxSpecialProvider.LoggingCallback ) diff --git a/litellm/integrations/cloudzero/cloudzero.py b/litellm/integrations/cloudzero/cloudzero.py index 9da8ea52b5c..8decd4ef23f 100644 --- a/litellm/integrations/cloudzero/cloudzero.py +++ b/litellm/integrations/cloudzero/cloudzero.py @@ -402,10 +402,10 @@ async def init_cloudzero_background_job(scheduler: AsyncIOScheduler): from litellm.constants import CLOUDZERO_EXPORT_INTERVAL_MINUTES from litellm.integrations.custom_logger import CustomLogger - prometheus_loggers: List[ - CustomLogger - ] = litellm.logging_callback_manager.get_custom_loggers_for_type( - callback_type=CloudZeroLogger + prometheus_loggers: List[CustomLogger] = ( + litellm.logging_callback_manager.get_custom_loggers_for_type( + callback_type=CloudZeroLogger + ) ) # we need to get the initialized prometheus logger instance(s) and call logger.initialize_remaining_budget_metrics() on them verbose_logger.debug("found %s cloudzero loggers", len(prometheus_loggers)) diff --git a/litellm/integrations/cloudzero/transform.py b/litellm/integrations/cloudzero/transform.py index c1b0d5cf411..2d84796150a 100644 --- a/litellm/integrations/cloudzero/transform.py +++ b/litellm/integrations/cloudzero/transform.py @@ -159,9 +159,9 @@ def _create_cbf_record(self, row: dict[str, Any]) -> CBFRecord: # CloudZero CBF format with proper column names cbf_record = { # Required CBF fields - "time/usage_start": usage_date.isoformat() - if usage_date - else None, # Required: ISO-formatted UTC datetime + "time/usage_start": ( + usage_date.isoformat() if usage_date else None + ), # Required: ISO-formatted UTC datetime "cost/cost": float(row.get("spend", 0.0)), # Required: billed cost "resource/id": resource_id, # CZRN (CloudZero Resource Name) # Usage metrics for token consumption @@ -182,9 +182,9 @@ def _create_cbf_record(self, row: dict[str, Any]) -> CBFRecord: # Add CZRN components that don't have direct CBF column mappings as resource tags cbf_record["resource/tag:provider"] = provider # CZRN provider component - cbf_record[ - "resource/tag:model" - ] = cloud_local_id # CZRN cloud-local-id component (model) + cbf_record["resource/tag:model"] = ( + cloud_local_id # CZRN cloud-local-id component (model) + ) # Add resource tags for all dimensions (using resource/tag: format) for key, value in dimensions.items(): diff --git a/litellm/integrations/custom_batch_logger.py b/litellm/integrations/custom_batch_logger.py index f9d4496c21f..6da7aec5559 100644 --- a/litellm/integrations/custom_batch_logger.py +++ b/litellm/integrations/custom_batch_logger.py @@ -1,5 +1,5 @@ """ -Custom Logger that handles batching logic +Custom Logger that handles batching logic Use this if you want your logs to be stored in memory and flushed periodically. """ diff --git a/litellm/integrations/custom_logger.py b/litellm/integrations/custom_logger.py index cccabf53e51..45c8e2f6262 100644 --- a/litellm/integrations/custom_logger.py +++ b/litellm/integrations/custom_logger.py @@ -874,9 +874,9 @@ def redact_standard_logging_payload_from_model_call_details( model_response_dict = model_response.model_dump() standard_logging_object_copy["response"] = model_response_dict - model_call_details_copy[ - "standard_logging_object" - ] = standard_logging_object_copy + model_call_details_copy["standard_logging_object"] = ( + standard_logging_object_copy + ) return model_call_details_copy async def get_proxy_server_request_from_cold_storage_with_object_key( diff --git a/litellm/integrations/datadog/datadog_llm_obs.py b/litellm/integrations/datadog/datadog_llm_obs.py index ec6c00961b6..201d3fb0a41 100644 --- a/litellm/integrations/datadog/datadog_llm_obs.py +++ b/litellm/integrations/datadog/datadog_llm_obs.py @@ -349,9 +349,9 @@ def _assemble_error_info( if standard_logging_payload.get("status") == "failure": # Try to get structured error information first - error_information: Optional[ - StandardLoggingPayloadErrorInformation - ] = standard_logging_payload.get("error_information") + error_information: Optional[StandardLoggingPayloadErrorInformation] = ( + standard_logging_payload.get("error_information") + ) if error_information: error_info = DDLLMObsError( @@ -621,9 +621,9 @@ def _get_latency_metrics( latency_metrics["litellm_overhead_time_ms"] = litellm_overhead_ms # Guardrail overhead latency - guardrail_info: Optional[ - list[StandardLoggingGuardrailInformation] - ] = standard_logging_payload.get("guardrail_information") + guardrail_info: Optional[list[StandardLoggingGuardrailInformation]] = ( + standard_logging_payload.get("guardrail_information") + ) if guardrail_info is not None: total_duration = 0.0 for info in guardrail_info: @@ -793,15 +793,15 @@ def _tool_calls_kv_pair(tool_calls: List[Dict[str, Any]]) -> Dict[str, Any]: if function_arguments: # Store arguments as JSON string for Datadog if isinstance(function_arguments, str): - kv_pairs[ - f"tool_calls.{idx}.function.arguments" - ] = function_arguments + kv_pairs[f"tool_calls.{idx}.function.arguments"] = ( + function_arguments + ) else: import json - kv_pairs[ - f"tool_calls.{idx}.function.arguments" - ] = json.dumps(function_arguments) + kv_pairs[f"tool_calls.{idx}.function.arguments"] = ( + json.dumps(function_arguments) + ) except (KeyError, TypeError, ValueError) as e: verbose_logger.debug( f"DataDogLLMObs: Error processing tool call {idx}: {str(e)}" diff --git a/litellm/integrations/focus/transformer.py b/litellm/integrations/focus/transformer.py index b7d28e3dbb9..6f4433b4a05 100644 --- a/litellm/integrations/focus/transformer.py +++ b/litellm/integrations/focus/transformer.py @@ -9,7 +9,6 @@ from .schema import FOCUS_NORMALIZED_SCHEMA - _TAG_KEYS = ( "team_id", "team_alias", diff --git a/litellm/integrations/gcs_bucket/gcs_bucket_base.py b/litellm/integrations/gcs_bucket/gcs_bucket_base.py index 923f613291f..2375f5104dc 100644 --- a/litellm/integrations/gcs_bucket/gcs_bucket_base.py +++ b/litellm/integrations/gcs_bucket/gcs_bucket_base.py @@ -146,9 +146,9 @@ async def get_gcs_logging_config( if kwargs is None: kwargs = {} - standard_callback_dynamic_params: Optional[ - StandardCallbackDynamicParams - ] = kwargs.get("standard_callback_dynamic_params", None) + standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = ( + kwargs.get("standard_callback_dynamic_params", None) + ) bucket_name: str path_service_account: Optional[str] diff --git a/litellm/integrations/humanloop.py b/litellm/integrations/humanloop.py index 11414869a65..369df5ee0bd 100644 --- a/litellm/integrations/humanloop.py +++ b/litellm/integrations/humanloop.py @@ -162,7 +162,11 @@ def get_chat_completion_prompt( prompt_version: Optional[int] = None, ignore_prompt_manager_model: Optional[bool] = False, ignore_prompt_manager_optional_params: Optional[bool] = False, - ) -> Tuple[str, List[AllMessageValues], dict,]: + ) -> Tuple[ + str, + List[AllMessageValues], + dict, + ]: humanloop_api_key = dynamic_callback_params.get( "humanloop_api_key" ) or get_secret_str("HUMANLOOP_API_KEY") diff --git a/litellm/integrations/langfuse/langfuse.py b/litellm/integrations/langfuse/langfuse.py index 6ac337d99a9..e691c490c85 100644 --- a/litellm/integrations/langfuse/langfuse.py +++ b/litellm/integrations/langfuse/langfuse.py @@ -572,9 +572,9 @@ def _log_langfuse_v2( # noqa: PLR0915 # we clean out all extra litellm metadata params before logging clean_metadata: Dict[str, Any] = {} if prompt_management_metadata is not None: - clean_metadata[ - "prompt_management_metadata" - ] = prompt_management_metadata + clean_metadata["prompt_management_metadata"] = ( + prompt_management_metadata + ) if isinstance(metadata, dict): for key, value in metadata.items(): # generate langfuse tags - Default Tags sent to Langfuse from LiteLLM Proxy diff --git a/litellm/integrations/langfuse/langfuse_handler.py b/litellm/integrations/langfuse/langfuse_handler.py index f9d27f6cf00..fbadf1a2fc7 100644 --- a/litellm/integrations/langfuse/langfuse_handler.py +++ b/litellm/integrations/langfuse/langfuse_handler.py @@ -86,9 +86,7 @@ def _return_global_langfuse_logger( if globalLangfuseLogger is not None: return globalLangfuseLogger - credentials_dict: Dict[ - str, Any - ] = ( + credentials_dict: Dict[str, Any] = ( {} ) # the global langfuse logger uses Environment Variables, there are no dynamic credentials globalLangfuseLogger = in_memory_dynamic_logger_cache.get_cache( diff --git a/litellm/integrations/langfuse/langfuse_prompt_management.py b/litellm/integrations/langfuse/langfuse_prompt_management.py index bea027aa63d..5f4ced3a5cb 100644 --- a/litellm/integrations/langfuse/langfuse_prompt_management.py +++ b/litellm/integrations/langfuse/langfuse_prompt_management.py @@ -190,7 +190,11 @@ async def async_get_chat_completion_prompt( prompt_version: Optional[int] = None, ignore_prompt_manager_model: Optional[bool] = False, ignore_prompt_manager_optional_params: Optional[bool] = False, - ) -> Tuple[str, List[AllMessageValues], dict,]: + ) -> Tuple[ + str, + List[AllMessageValues], + dict, + ]: return self.get_chat_completion_prompt( model, messages, diff --git a/litellm/integrations/langsmith.py b/litellm/integrations/langsmith.py index b931d7ecfe7..3d4fd39ebe1 100644 --- a/litellm/integrations/langsmith.py +++ b/litellm/integrations/langsmith.py @@ -83,9 +83,9 @@ def __init__( if _batch_size: self.batch_size = int(_batch_size) self.log_queue: List[LangsmithQueueObject] = [] - self._flush_task: Optional[ - asyncio.Task[Any] - ] = self._start_periodic_flush_task() + self._flush_task: Optional[asyncio.Task[Any]] = ( + self._start_periodic_flush_task() + ) def _start_periodic_flush_task(self) -> Optional[asyncio.Task[Any]]: """Start the periodic flush task only when an event loop is already running.""" @@ -501,9 +501,9 @@ def _group_batches_by_credentials(self) -> Dict[CredentialsKey, BatchGroup]: return log_queue_by_credentials def _get_sampling_rate_to_use_for_request(self, kwargs: Dict[str, Any]) -> float: - standard_callback_dynamic_params: Optional[ - StandardCallbackDynamicParams - ] = kwargs.get("standard_callback_dynamic_params", None) + standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = ( + kwargs.get("standard_callback_dynamic_params", None) + ) sampling_rate: float = self.sampling_rate if standard_callback_dynamic_params is not None: _sampling_rate = standard_callback_dynamic_params.get( @@ -523,9 +523,9 @@ def _get_credentials_to_use_for_request( Otherwise, use the default credentials. """ - standard_callback_dynamic_params: Optional[ - StandardCallbackDynamicParams - ] = kwargs.get("standard_callback_dynamic_params", None) + standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = ( + kwargs.get("standard_callback_dynamic_params", None) + ) if standard_callback_dynamic_params is not None: credentials = self.get_credentials_from_env( langsmith_api_key=standard_callback_dynamic_params.get( diff --git a/litellm/integrations/mock_client_factory.py b/litellm/integrations/mock_client_factory.py index 3f2f0ae5b6d..02a927fe64f 100644 --- a/litellm/integrations/mock_client_factory.py +++ b/litellm/integrations/mock_client_factory.py @@ -25,9 +25,9 @@ class MockClientConfig: default_latency_ms: int = 100 # Default mock latency in milliseconds default_status_code: int = 200 # Default HTTP status code default_json_data: Optional[Dict] = None # Default JSON response data - url_matchers: Optional[ - List[str] - ] = None # List of strings to match in URLs (e.g., ["storage.googleapis.com"]) + url_matchers: Optional[List[str]] = ( + None # List of strings to match in URLs (e.g., ["storage.googleapis.com"]) + ) patch_async_handler: bool = True # Whether to patch AsyncHTTPHandler.post patch_sync_client: bool = False # Whether to patch httpx.Client.post patch_http_handler: bool = ( diff --git a/litellm/integrations/opentelemetry.py b/litellm/integrations/opentelemetry.py index 559ed05d30a..ecfb42cea7b 100644 --- a/litellm/integrations/opentelemetry.py +++ b/litellm/integrations/opentelemetry.py @@ -655,9 +655,9 @@ def get_tracer_to_use_for_request(self, kwargs: dict) -> Tracer: def _get_dynamic_otel_headers_from_kwargs(self, kwargs) -> Optional[dict]: """Extract dynamic headers from kwargs if available.""" - standard_callback_dynamic_params: Optional[ - StandardCallbackDynamicParams - ] = kwargs.get("standard_callback_dynamic_params") + standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = ( + kwargs.get("standard_callback_dynamic_params") + ) if not standard_callback_dynamic_params: return None diff --git a/litellm/integrations/opik/utils.py b/litellm/integrations/opik/utils.py index b0ab5991c91..43577505c11 100644 --- a/litellm/integrations/opik/utils.py +++ b/litellm/integrations/opik/utils.py @@ -105,7 +105,7 @@ def _remove_nulls(x: Dict[str, Any]) -> Dict[str, Any]: def get_traces_and_spans_from_payload( - payload: List[Dict[str, Any]] + payload: List[Dict[str, Any]], ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: """ Separate traces and spans from payload. diff --git a/litellm/integrations/posthog.py b/litellm/integrations/posthog.py index 17bb56b8f17..072ae4945a0 100644 --- a/litellm/integrations/posthog.py +++ b/litellm/integrations/posthog.py @@ -349,9 +349,9 @@ def _get_credentials_for_request( Returns: tuple[str, str]: (api_key, api_url) """ - standard_callback_dynamic_params: Optional[ - StandardCallbackDynamicParams - ] = kwargs.get("standard_callback_dynamic_params", None) + standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = ( + kwargs.get("standard_callback_dynamic_params", None) + ) if standard_callback_dynamic_params is not None: api_key = ( diff --git a/litellm/integrations/prometheus.py b/litellm/integrations/prometheus.py index 357e0229fc6..410802c698c 100644 --- a/litellm/integrations/prometheus.py +++ b/litellm/integrations/prometheus.py @@ -974,9 +974,11 @@ async def async_log_success_event(self, kwargs, response_obj, start_time, end_ti ), client_ip=standard_logging_payload["metadata"].get("requester_ip_address"), user_agent=standard_logging_payload["metadata"].get("user_agent"), - stream=str(standard_logging_payload.get("stream")) - if litellm.prometheus_emit_stream_label - else None, + stream=( + str(standard_logging_payload.get("stream")) + if litellm.prometheus_emit_stream_label + else None + ), ) if ( @@ -1629,9 +1631,11 @@ async def async_post_call_failure_hook( client_ip=_metadata.get("requester_ip_address"), user_agent=_metadata.get("user_agent"), model_id=model_id, - stream=str(request_data.get("stream")) - if litellm.prometheus_emit_stream_label - else None, + stream=( + str(request_data.get("stream")) + if litellm.prometheus_emit_stream_label + else None + ), ) _labels = prometheus_label_factory( supported_enum_labels=self.get_labels_for_metric( @@ -1955,9 +1959,9 @@ def set_llm_deployment_success_metrics( ): try: verbose_logger.debug("setting remaining tokens requests metric") - standard_logging_payload: Optional[ - StandardLoggingPayload - ] = request_kwargs.get("standard_logging_object") + standard_logging_payload: Optional[StandardLoggingPayload] = ( + request_kwargs.get("standard_logging_object") + ) if standard_logging_payload is None: return @@ -2469,9 +2473,7 @@ async def _initialize_api_key_budget_metrics(self): ) return - async def fetch_keys( - page_size: int, page: int - ) -> Tuple[ + async def fetch_keys(page_size: int, page: int) -> Tuple[ List[Union[str, UserAPIKeyAuth, LiteLLM_DeletedVerificationToken]], Optional[int], ]: @@ -2995,10 +2997,10 @@ def initialize_budget_metrics_cron_job(scheduler: AsyncIOScheduler): from litellm.constants import PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES from litellm.integrations.custom_logger import CustomLogger - prometheus_loggers: List[ - CustomLogger - ] = litellm.logging_callback_manager.get_custom_loggers_for_type( - callback_type=PrometheusLogger + prometheus_loggers: List[CustomLogger] = ( + litellm.logging_callback_manager.get_custom_loggers_for_type( + callback_type=PrometheusLogger + ) ) # we need to get the initialized prometheus logger instance(s) and call logger.initialize_remaining_budget_metrics() on them verbose_logger.debug("found %s prometheus loggers", len(prometheus_loggers)) diff --git a/litellm/integrations/s3_v2.py b/litellm/integrations/s3_v2.py index 405bf9698cc..b78be59780d 100644 --- a/litellm/integrations/s3_v2.py +++ b/litellm/integrations/s3_v2.py @@ -1,8 +1,8 @@ """ s3 Bucket Logging Integration -async_log_success_event: Processes the event, stores it in memory for DEFAULT_S3_FLUSH_INTERVAL_SECONDS seconds or until DEFAULT_S3_BATCH_SIZE and then flushes to s3 -async_log_failure_event: Processes the event, stores it in memory for DEFAULT_S3_FLUSH_INTERVAL_SECONDS seconds or until DEFAULT_S3_BATCH_SIZE and then flushes to s3 +async_log_success_event: Processes the event, stores it in memory for DEFAULT_S3_FLUSH_INTERVAL_SECONDS seconds or until DEFAULT_S3_BATCH_SIZE and then flushes to s3 +async_log_failure_event: Processes the event, stores it in memory for DEFAULT_S3_FLUSH_INTERVAL_SECONDS seconds or until DEFAULT_S3_BATCH_SIZE and then flushes to s3 NOTE 1: S3 does not provide a BATCH PUT API endpoint, so we create tasks to upload each element individually """ @@ -578,9 +578,11 @@ def upload_data_to_s3(self, batch_logging_element: s3BatchLoggingElement): signed_headers = dict(aws_request.headers.items()) httpx_client = _get_httpx_client( - params={"ssl_verify": self.s3_verify} - if self.s3_verify is not None - else None + params=( + {"ssl_verify": self.s3_verify} + if self.s3_verify is not None + else None + ) ) # Make the request response = httpx_client.put(url, data=json_string, headers=signed_headers) diff --git a/litellm/integrations/vantage/vantage_logger.py b/litellm/integrations/vantage/vantage_logger.py index e0942472bec..1e6e46b36ae 100644 --- a/litellm/integrations/vantage/vantage_logger.py +++ b/litellm/integrations/vantage/vantage_logger.py @@ -83,9 +83,11 @@ def __init__( verbose_logger.debug( "VantageLogger initialized (integration_token=%s)", - resolved_token[:4] + "***" - if resolved_token and len(resolved_token) > 4 - else "***", + ( + resolved_token[:4] + "***" + if resolved_token and len(resolved_token) > 4 + else "***" + ), ) async def initialize_focus_export_job(self) -> None: @@ -124,10 +126,10 @@ async def init_vantage_background_job( scheduler: AsyncIOScheduler, ) -> None: """Register the Vantage export job with the provided scheduler.""" - vantage_loggers: List[ - CustomLogger - ] = litellm.logging_callback_manager.get_custom_loggers_for_type( - callback_type=VantageLogger + vantage_loggers: List[CustomLogger] = ( + litellm.logging_callback_manager.get_custom_loggers_for_type( + callback_type=VantageLogger + ) ) if not vantage_loggers: verbose_logger.debug("No Vantage logger registered; skipping scheduler") diff --git a/litellm/integrations/vector_store_integrations/vector_store_pre_call_hook.py b/litellm/integrations/vector_store_integrations/vector_store_pre_call_hook.py index 50420fb7137..482a19c5d72 100644 --- a/litellm/integrations/vector_store_integrations/vector_store_pre_call_hook.py +++ b/litellm/integrations/vector_store_integrations/vector_store_pre_call_hook.py @@ -88,12 +88,12 @@ async def async_get_chat_completion_prompt( pass # Use database fallback to ensure synchronization across instances - vector_stores_to_run: List[ - LiteLLM_ManagedVectorStore - ] = await litellm.vector_store_registry.pop_vector_stores_to_run_with_db_fallback( - non_default_params=non_default_params, - tools=tools, - prisma_client=prisma_client, + vector_stores_to_run: List[LiteLLM_ManagedVectorStore] = ( + await litellm.vector_store_registry.pop_vector_stores_to_run_with_db_fallback( + non_default_params=non_default_params, + tools=tools, + prisma_client=prisma_client, + ) ) if not vector_stores_to_run: @@ -147,9 +147,9 @@ async def async_get_chat_completion_prompt( # Store search results as-is (already in OpenAI-compatible format) if litellm_logging_obj and all_search_results: - litellm_logging_obj.model_call_details[ - "search_results" - ] = all_search_results + litellm_logging_obj.model_call_details["search_results"] = ( + all_search_results + ) return model, modified_messages, non_default_params @@ -208,9 +208,9 @@ def _append_search_results_to_messages( Returns: Modified list of messages with context appended """ - search_response_data: Optional[ - List[VectorStoreSearchResult] - ] = search_response.get("data") + search_response_data: Optional[List[VectorStoreSearchResult]] = ( + search_response.get("data") + ) if not search_response_data: return messages @@ -268,9 +268,9 @@ async def async_post_call_success_deployment_hook( ) # Get search results from model_call_details (already in OpenAI format) - search_results: Optional[ - List[VectorStoreSearchResponse] - ] = litellm_logging_obj.model_call_details.get("search_results") + search_results: Optional[List[VectorStoreSearchResponse]] = ( + litellm_logging_obj.model_call_details.get("search_results") + ) verbose_logger.debug(f"Search results found: {search_results is not None}") @@ -328,9 +328,9 @@ async def async_post_call_streaming_deployment_hook( ) # Get search results from model_call_details (already in OpenAI format) - search_results: Optional[ - List[VectorStoreSearchResponse] - ] = request_data.get("search_results") + search_results: Optional[List[VectorStoreSearchResponse]] = ( + request_data.get("search_results") + ) verbose_logger.debug( f"Search results found for streaming chunk: {search_results is not None}" diff --git a/litellm/integrations/websearch_interception/transformation.py b/litellm/integrations/websearch_interception/transformation.py index f777a7d7418..00d4829ad39 100644 --- a/litellm/integrations/websearch_interception/transformation.py +++ b/litellm/integrations/websearch_interception/transformation.py @@ -3,6 +3,7 @@ Transforms between Anthropic/OpenAI tool_use format and LiteLLM search format. """ + import json from typing import Any, Dict, List, Optional, Tuple, Union @@ -326,9 +327,11 @@ def _transform_response_openai( "type": "function", "function": { "name": tc["name"], - "arguments": json.dumps(tc["input"]) - if isinstance(tc["input"], dict) - else str(tc["input"]), + "arguments": ( + json.dumps(tc["input"]) + if isinstance(tc["input"], dict) + else str(tc["input"]) + ), }, } for tc in tool_calls diff --git a/litellm/integrations/weights_biases.py b/litellm/integrations/weights_biases.py index 028b6e69a81..e9539d27e97 100644 --- a/litellm/integrations/weights_biases.py +++ b/litellm/integrations/weights_biases.py @@ -21,8 +21,7 @@ class OpenAIResponse(Protocol[K, V]): # type: ignore # contains a (known) object attribute object: Literal["chat.completion", "edit", "text_completion"] - def __getitem__(self, key: K) -> V: - ... # noqa + def __getitem__(self, key: K) -> V: ... # noqa def get(self, key: K, default: Optional[V] = None) -> Optional[V]: # noqa ... # pragma: no cover diff --git a/litellm/interactions/__init__.py b/litellm/interactions/__init__.py index e1125b649a6..0077ac82076 100644 --- a/litellm/interactions/__init__.py +++ b/litellm/interactions/__init__.py @@ -5,28 +5,28 @@ Usage: import litellm - + # Create an interaction with a model response = litellm.interactions.create( model="gemini-2.5-flash", input="Hello, how are you?" ) - + # Create an interaction with an agent response = litellm.interactions.create( agent="deep-research-pro-preview-12-2025", input="Research the current state of cancer research" ) - + # Async version response = await litellm.interactions.acreate(...) - + # Get an interaction response = litellm.interactions.get(interaction_id="...") - + # Delete an interaction result = litellm.interactions.delete(interaction_id="...") - + # Cancel an interaction result = litellm.interactions.cancel(interaction_id="...") diff --git a/litellm/interactions/litellm_responses_transformation/transformation.py b/litellm/interactions/litellm_responses_transformation/transformation.py index b07e61c76dd..100300af7b5 100644 --- a/litellm/interactions/litellm_responses_transformation/transformation.py +++ b/litellm/interactions/litellm_responses_transformation/transformation.py @@ -45,10 +45,10 @@ def transform_interactions_request_to_responses_request( # Transform input if input is not None: - responses_request[ - "input" - ] = LiteLLMResponsesInteractionsConfig._transform_interactions_input_to_responses_input( - input + responses_request["input"] = ( + LiteLLMResponsesInteractionsConfig._transform_interactions_input_to_responses_input( + input + ) ) # Transform system_instruction -> instructions diff --git a/litellm/interactions/main.py b/litellm/interactions/main.py index ab429ef6db5..be9f2b99e0e 100644 --- a/litellm/interactions/main.py +++ b/litellm/interactions/main.py @@ -8,25 +8,25 @@ Usage: import litellm - + # Create an interaction with a model response = litellm.interactions.create( model="gemini-2.5-flash", input="Hello, how are you?" ) - + # Create an interaction with an agent response = litellm.interactions.create( agent="deep-research-pro-preview-12-2025", input="Research the current state of cancer research" ) - + # Async version response = await litellm.interactions.acreate(...) - + # Get an interaction response = litellm.interactions.get(interaction_id="...") - + # Delete an interaction result = litellm.interactions.delete(interaction_id="...") """ diff --git a/litellm/litellm_core_utils/default_encoding.py b/litellm/litellm_core_utils/default_encoding.py index f704ba568de..f58b90c8e72 100644 --- a/litellm/litellm_core_utils/default_encoding.py +++ b/litellm/litellm_core_utils/default_encoding.py @@ -26,9 +26,9 @@ else: cache_dir = filename -os.environ[ - "TIKTOKEN_CACHE_DIR" -] = cache_dir # use local copy of tiktoken b/c of - https://github.com/BerriAI/litellm/issues/1071 +os.environ["TIKTOKEN_CACHE_DIR"] = ( + cache_dir # use local copy of tiktoken b/c of - https://github.com/BerriAI/litellm/issues/1071 +) import tiktoken import time diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index 5323f692b80..a5262505983 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -354,9 +354,9 @@ def __init__( ) self.function_id = function_id self.streaming_chunks: List[Any] = [] # for generating complete stream response - self.sync_streaming_chunks: List[ - Any - ] = [] # for generating complete stream response + self.sync_streaming_chunks: List[Any] = ( + [] + ) # for generating complete stream response self.log_raw_request_response = log_raw_request_response # Initialize dynamic callbacks @@ -801,9 +801,9 @@ def _auto_detect_prompt_management_logger( prompt_spec=prompt_spec, dynamic_callback_params=dynamic_callback_params, ): - self.model_call_details[ - "prompt_integration" - ] = logger.__class__.__name__ + self.model_call_details["prompt_integration"] = ( + logger.__class__.__name__ + ) return logger except Exception: # If check fails, continue to next logger @@ -871,9 +871,9 @@ def get_custom_logger_for_prompt_management( if anthropic_cache_control_logger := AnthropicCacheControlHook.get_custom_logger_for_anthropic_cache_control_hook( non_default_params ): - self.model_call_details[ - "prompt_integration" - ] = anthropic_cache_control_logger.__class__.__name__ + self.model_call_details["prompt_integration"] = ( + anthropic_cache_control_logger.__class__.__name__ + ) return anthropic_cache_control_logger ######################################################### @@ -885,9 +885,9 @@ def get_custom_logger_for_prompt_management( internal_usage_cache=None, llm_router=None, ) - self.model_call_details[ - "prompt_integration" - ] = vector_store_custom_logger.__class__.__name__ + self.model_call_details["prompt_integration"] = ( + vector_store_custom_logger.__class__.__name__ + ) # Add to global callbacks so post-call hooks are invoked if ( vector_store_custom_logger @@ -947,9 +947,9 @@ def _pre_call(self, input, api_key, model=None, additional_args={}): model ): # if model name was changes pre-call, overwrite the initial model call name with the new one self.model_call_details["model"] = model - self.model_call_details["litellm_params"][ - "api_base" - ] = self._get_masked_api_base(additional_args.get("api_base", "")) + self.model_call_details["litellm_params"]["api_base"] = ( + self._get_masked_api_base(additional_args.get("api_base", "")) + ) def pre_call(self, input, api_key, model=None, additional_args={}): # noqa: PLR0915 # Log the exact input to the LLM API @@ -978,9 +978,7 @@ def pre_call(self, input, api_key, model=None, additional_args={}): # noqa: PLR try: # [Non-blocking Extra Debug Information in metadata] if turn_off_message_logging is True: - _metadata[ - "raw_request" - ] = "redacted by litellm. \ + _metadata["raw_request"] = "redacted by litellm. \ 'litellm.turn_off_message_logging=True'" else: curl_command = self._get_request_curl_command( @@ -992,35 +990,31 @@ def pre_call(self, input, api_key, model=None, additional_args={}): # noqa: PLR _metadata["raw_request"] = str(curl_command) # split up, so it's easier to parse in the UI - self.model_call_details[ - "raw_request_typed_dict" - ] = RawRequestTypedDict( - raw_request_api_base=str( - additional_args.get("api_base") or "" - ), - raw_request_body=self._get_raw_request_body( - additional_args.get("complete_input_dict", {}) - ), - # NOTE: setting ignore_sensitive_headers to True will cause - # the Authorization header to be leaked when calls to the health - # endpoint are made and fail. - raw_request_headers=self._get_masked_headers( - additional_args.get("headers", {}) or {}, - ), - error=None, + self.model_call_details["raw_request_typed_dict"] = ( + RawRequestTypedDict( + raw_request_api_base=str( + additional_args.get("api_base") or "" + ), + raw_request_body=self._get_raw_request_body( + additional_args.get("complete_input_dict", {}) + ), + # NOTE: setting ignore_sensitive_headers to True will cause + # the Authorization header to be leaked when calls to the health + # endpoint are made and fail. + raw_request_headers=self._get_masked_headers( + additional_args.get("headers", {}) or {}, + ), + error=None, + ) ) except Exception as e: - self.model_call_details[ - "raw_request_typed_dict" - ] = RawRequestTypedDict( - error=str(e), - ) - _metadata[ - "raw_request" - ] = "Unable to Log \ - raw request: {}".format( - str(e) + self.model_call_details["raw_request_typed_dict"] = ( + RawRequestTypedDict( + error=str(e), + ) ) + _metadata["raw_request"] = "Unable to Log \ + raw request: {}".format(str(e)) if getattr(self, "logger_fn", None) and callable(self.logger_fn): try: self.logger_fn( @@ -1320,13 +1314,13 @@ async def async_post_mcp_tool_call_hook( for callback in callbacks: try: if isinstance(callback, CustomLogger): - response: Optional[ - MCPPostCallResponseObject - ] = await callback.async_post_mcp_tool_call_hook( - kwargs=kwargs, - response_obj=post_mcp_tool_call_response_obj, - start_time=start_time, - end_time=end_time, + response: Optional[MCPPostCallResponseObject] = ( + await callback.async_post_mcp_tool_call_hook( + kwargs=kwargs, + response_obj=post_mcp_tool_call_response_obj, + start_time=start_time, + end_time=end_time, + ) ) ###################################################################### # if any of the callbacks modify the response, use the modified response @@ -1527,9 +1521,9 @@ def _response_cost_calculator( verbose_logger.debug( f"response_cost_failure_debug_information: {debug_info}" ) - self.model_call_details[ - "response_cost_failure_debug_information" - ] = debug_info + self.model_call_details["response_cost_failure_debug_information"] = ( + debug_info + ) return None try: @@ -1555,9 +1549,9 @@ def _response_cost_calculator( verbose_logger.debug( f"response_cost_failure_debug_information: {debug_info}" ) - self.model_call_details[ - "response_cost_failure_debug_information" - ] = debug_info + self.model_call_details["response_cost_failure_debug_information"] = ( + debug_info + ) return None @@ -1706,9 +1700,9 @@ def _merge_hidden_params_from_response_into_metadata( self.model_call_details["litellm_params"].setdefault("metadata", {}) if self.model_call_details["litellm_params"]["metadata"] is None: self.model_call_details["litellm_params"]["metadata"] = {} - self.model_call_details["litellm_params"]["metadata"][ - "hidden_params" - ] = getattr(logging_result, "_hidden_params", {}) + self.model_call_details["litellm_params"]["metadata"]["hidden_params"] = ( + getattr(logging_result, "_hidden_params", {}) + ) def _process_hidden_params_and_response_cost( self, @@ -1737,9 +1731,9 @@ def _process_hidden_params_and_response_cost( result=logging_result ) - self.model_call_details[ - "standard_logging_object" - ] = self._build_standard_logging_payload(logging_result, start_time, end_time) + self.model_call_details["standard_logging_object"] = ( + self._build_standard_logging_payload(logging_result, start_time, end_time) + ) if ( standard_logging_payload := self.model_call_details.get( @@ -1817,9 +1811,9 @@ def _success_handler_helper_fn( end_time = datetime.datetime.now() if self.completion_start_time is None: self.completion_start_time = end_time - self.model_call_details[ - "completion_start_time" - ] = self.completion_start_time + self.model_call_details["completion_start_time"] = ( + self.completion_start_time + ) self.model_call_details["log_event_type"] = "successful_api_call" self.model_call_details["end_time"] = end_time @@ -1856,10 +1850,10 @@ def _success_handler_helper_fn( end_time=end_time, ) elif isinstance(result, dict) or isinstance(result, list): - self.model_call_details[ - "standard_logging_object" - ] = self._build_standard_logging_payload( - result, start_time, end_time + self.model_call_details["standard_logging_object"] = ( + self._build_standard_logging_payload( + result, start_time, end_time + ) ) if ( standard_logging_payload := self.model_call_details.get( @@ -1868,9 +1862,9 @@ def _success_handler_helper_fn( ) is not None: emit_standard_logging_payload(standard_logging_payload) elif standard_logging_object is not None: - self.model_call_details[ - "standard_logging_object" - ] = standard_logging_object + self.model_call_details["standard_logging_object"] = ( + standard_logging_object + ) else: self.model_call_details["response_cost"] = None @@ -2028,20 +2022,20 @@ def success_handler( # noqa: PLR0915 verbose_logger.debug( "Logging Details LiteLLM-Success Call streaming complete" ) - self.model_call_details[ - "complete_streaming_response" - ] = complete_streaming_response - self.model_call_details[ - "response_cost" - ] = self._response_cost_calculator(result=complete_streaming_response) + self.model_call_details["complete_streaming_response"] = ( + complete_streaming_response + ) + self.model_call_details["response_cost"] = ( + self._response_cost_calculator(result=complete_streaming_response) + ) self._merge_hidden_params_from_response_into_metadata( complete_streaming_response ) ## STANDARDIZED LOGGING PAYLOAD - self.model_call_details[ - "standard_logging_object" - ] = self._build_standard_logging_payload( - complete_streaming_response, start_time, end_time + self.model_call_details["standard_logging_object"] = ( + self._build_standard_logging_payload( + complete_streaming_response, start_time, end_time + ) ) if ( standard_logging_payload := self.model_call_details.get( @@ -2375,10 +2369,10 @@ def success_handler( # noqa: PLR0915 ) else: if self.stream and complete_streaming_response: - self.model_call_details[ - "complete_response" - ] = self.model_call_details.get( - "complete_streaming_response", {} + self.model_call_details["complete_response"] = ( + self.model_call_details.get( + "complete_streaming_response", {} + ) ) result = self.model_call_details["complete_response"] openMeterLogger.log_success_event( @@ -2402,10 +2396,10 @@ def success_handler( # noqa: PLR0915 ) else: if self.stream and complete_streaming_response: - self.model_call_details[ - "complete_response" - ] = self.model_call_details.get( - "complete_streaming_response", {} + self.model_call_details["complete_response"] = ( + self.model_call_details.get( + "complete_streaming_response", {} + ) ) result = self.model_call_details["complete_response"] @@ -2544,9 +2538,9 @@ async def async_success_handler( # noqa: PLR0915 if complete_streaming_response is not None: print_verbose("Async success callbacks: Got a complete streaming response") - self.model_call_details[ - "async_complete_streaming_response" - ] = complete_streaming_response + self.model_call_details["async_complete_streaming_response"] = ( + complete_streaming_response + ) try: if self.model_call_details.get("cache_hit", False) is True: @@ -2557,10 +2551,10 @@ async def async_success_handler( # noqa: PLR0915 model_call_details=self.model_call_details ) # base_model defaults to None if not set on model_info - self.model_call_details[ - "response_cost" - ] = self._response_cost_calculator( - result=complete_streaming_response + self.model_call_details["response_cost"] = ( + self._response_cost_calculator( + result=complete_streaming_response + ) ) verbose_logger.debug( @@ -2577,10 +2571,10 @@ async def async_success_handler( # noqa: PLR0915 ) ## STANDARDIZED LOGGING PAYLOAD - self.model_call_details[ - "standard_logging_object" - ] = self._build_standard_logging_payload( - complete_streaming_response, start_time, end_time + self.model_call_details["standard_logging_object"] = ( + self._build_standard_logging_payload( + complete_streaming_response, start_time, end_time + ) ) # print standard logging payload @@ -2607,9 +2601,9 @@ async def async_success_handler( # noqa: PLR0915 # _success_handler_helper_fn if self.model_call_details.get("standard_logging_object") is None: ## STANDARDIZED LOGGING PAYLOAD - self.model_call_details[ - "standard_logging_object" - ] = self._build_standard_logging_payload(result, start_time, end_time) + self.model_call_details["standard_logging_object"] = ( + self._build_standard_logging_payload(result, start_time, end_time) + ) # print standard logging payload if ( @@ -2852,18 +2846,18 @@ def _failure_handler_helper_fn( ## STANDARDIZED LOGGING PAYLOAD - self.model_call_details[ - "standard_logging_object" - ] = get_standard_logging_object_payload( - kwargs=self.model_call_details, - init_response_obj={}, - start_time=start_time, - end_time=end_time, - logging_obj=self, - status="failure", - error_str=str(exception), - original_exception=exception, - standard_built_in_tools_params=self.standard_built_in_tools_params, + self.model_call_details["standard_logging_object"] = ( + get_standard_logging_object_payload( + kwargs=self.model_call_details, + init_response_obj={}, + start_time=start_time, + end_time=end_time, + logging_obj=self, + status="failure", + error_str=str(exception), + original_exception=exception, + standard_built_in_tools_params=self.standard_built_in_tools_params, + ) ) return start_time, end_time @@ -3831,9 +3825,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 service_name=arize_config.project_name, ) - os.environ[ - "OTEL_EXPORTER_OTLP_TRACES_HEADERS" - ] = f"space_id={arize_config.space_key or arize_config.space_id},api_key={arize_config.api_key}" + os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = ( + f"space_id={arize_config.space_key or arize_config.space_id},api_key={arize_config.api_key}" + ) for callback in _in_memory_loggers: if ( isinstance(callback, ArizeLogger) @@ -3859,13 +3853,13 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 existing_attrs = os.environ.get("OTEL_RESOURCE_ATTRIBUTES", "") # Add openinference.project.name attribute if existing_attrs: - os.environ[ - "OTEL_RESOURCE_ATTRIBUTES" - ] = f"{existing_attrs},openinference.project.name={arize_phoenix_config.project_name}" + os.environ["OTEL_RESOURCE_ATTRIBUTES"] = ( + f"{existing_attrs},openinference.project.name={arize_phoenix_config.project_name}" + ) else: - os.environ[ - "OTEL_RESOURCE_ATTRIBUTES" - ] = f"openinference.project.name={arize_phoenix_config.project_name}" + os.environ["OTEL_RESOURCE_ATTRIBUTES"] = ( + f"openinference.project.name={arize_phoenix_config.project_name}" + ) # Set Phoenix project name from environment variable phoenix_project_name = os.environ.get("PHOENIX_PROJECT_NAME", None) @@ -3873,19 +3867,19 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 existing_attrs = os.environ.get("OTEL_RESOURCE_ATTRIBUTES", "") # Add openinference.project.name attribute if existing_attrs: - os.environ[ - "OTEL_RESOURCE_ATTRIBUTES" - ] = f"{existing_attrs},openinference.project.name={phoenix_project_name}" + os.environ["OTEL_RESOURCE_ATTRIBUTES"] = ( + f"{existing_attrs},openinference.project.name={phoenix_project_name}" + ) else: - os.environ[ - "OTEL_RESOURCE_ATTRIBUTES" - ] = f"openinference.project.name={phoenix_project_name}" + os.environ["OTEL_RESOURCE_ATTRIBUTES"] = ( + f"openinference.project.name={phoenix_project_name}" + ) # auth can be disabled on local deployments of arize phoenix if arize_phoenix_config.otlp_auth_headers is not None: - os.environ[ - "OTEL_EXPORTER_OTLP_TRACES_HEADERS" - ] = arize_phoenix_config.otlp_auth_headers + os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = ( + arize_phoenix_config.otlp_auth_headers + ) for callback in _in_memory_loggers: if ( @@ -4072,9 +4066,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 exporter="otlp_http", endpoint="https://langtrace.ai/api/trace", ) - os.environ[ - "OTEL_EXPORTER_OTLP_TRACES_HEADERS" - ] = f"api_key={os.getenv('LANGTRACE_API_KEY')}" + os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = ( + f"api_key={os.getenv('LANGTRACE_API_KEY')}" + ) for callback in _in_memory_loggers: if ( isinstance(callback, OpenTelemetry) @@ -4998,10 +4992,10 @@ def get_hidden_params( for key in StandardLoggingHiddenParams.__annotations__.keys(): if key in hidden_params: if key == "additional_headers": - clean_hidden_params[ - "additional_headers" - ] = StandardLoggingPayloadSetup.get_additional_headers( - hidden_params[key] + clean_hidden_params["additional_headers"] = ( + StandardLoggingPayloadSetup.get_additional_headers( + hidden_params[key] + ) ) else: clean_hidden_params[key] = hidden_params[key] # type: ignore @@ -5640,9 +5634,9 @@ def scrub_sensitive_keys_in_metadata(litellm_params: Optional[dict]): ): for k, v in metadata["user_api_key_metadata"].items(): if k == "logging": # prevent logging user logging keys - cleaned_user_api_key_metadata[ - k - ] = "scrubbed_by_litellm_for_sensitive_keys" + cleaned_user_api_key_metadata[k] = ( + "scrubbed_by_litellm_for_sensitive_keys" + ) else: cleaned_user_api_key_metadata[k] = v diff --git a/litellm/litellm_core_utils/llm_response_utils/convert_dict_to_response.py b/litellm/litellm_core_utils/llm_response_utils/convert_dict_to_response.py index 20cc5746667..78378faa262 100644 --- a/litellm/litellm_core_utils/llm_response_utils/convert_dict_to_response.py +++ b/litellm/litellm_core_utils/llm_response_utils/convert_dict_to_response.py @@ -596,9 +596,9 @@ def convert_to_model_response_object( # noqa: PLR0915 provider_specific_fields["thinking_blocks"] = thinking_blocks if reasoning_content: - provider_specific_fields[ - "reasoning_content" - ] = reasoning_content + provider_specific_fields["reasoning_content"] = ( + reasoning_content + ) message = Message( content=content, @@ -787,9 +787,9 @@ def convert_to_model_response_object( # noqa: PLR0915 # tracking without exposing it in the response body. Must be set # after hidden_params assignment to avoid being overwritten. if "_audio_transcription_duration" in response_object: - model_response_object._hidden_params[ - "audio_transcription_duration" - ] = response_object["_audio_transcription_duration"] + model_response_object._hidden_params["audio_transcription_duration"] = ( + response_object["_audio_transcription_duration"] + ) if _response_headers is not None: model_response_object._response_headers = _response_headers diff --git a/litellm/litellm_core_utils/model_param_helper.py b/litellm/litellm_core_utils/model_param_helper.py index 66b174feac4..4d45c47c224 100644 --- a/litellm/litellm_core_utils/model_param_helper.py +++ b/litellm/litellm_core_utils/model_param_helper.py @@ -93,9 +93,9 @@ def _get_litellm_supported_chat_completion_kwargs() -> Set[str]: streaming_params: Set[str] = set( getattr(CompletionCreateParamsStreaming, "__annotations__", {}).keys() ) - litellm_provider_specific_params: Set[ - str - ] = ModelParamHelper.get_litellm_provider_specific_params_for_chat_params() + litellm_provider_specific_params: Set[str] = ( + ModelParamHelper.get_litellm_provider_specific_params_for_chat_params() + ) all_chat_completion_kwargs: Set[str] = non_streaming_params.union( streaming_params ).union(litellm_provider_specific_params) diff --git a/litellm/litellm_core_utils/prompt_templates/factory.py b/litellm/litellm_core_utils/prompt_templates/factory.py index d29ca1649ff..069e8274e9f 100644 --- a/litellm/litellm_core_utils/prompt_templates/factory.py +++ b/litellm/litellm_core_utils/prompt_templates/factory.py @@ -1393,10 +1393,10 @@ def convert_to_gemini_tool_call_invoke( if tool_calls is not None: for idx, tool in enumerate(tool_calls): if "function" in tool: - gemini_function_call: Optional[ - VertexFunctionCall - ] = _gemini_tool_call_invoke_helper( - function_call_params=tool["function"] + gemini_function_call: Optional[VertexFunctionCall] = ( + _gemini_tool_call_invoke_helper( + function_call_params=tool["function"] + ) ) if gemini_function_call is not None: part_dict: VertexPartType = { @@ -1574,9 +1574,7 @@ def convert_to_gemini_tool_call_result( # noqa: PLR0915 file_data = ( file_content.get("file_data", "") if isinstance(file_content, dict) - else file_content - if isinstance(file_content, str) - else "" + else file_content if isinstance(file_content, str) else "" ) if file_data: @@ -2081,9 +2079,9 @@ def _sanitize_empty_text_content( if isinstance(content, str): if not content or not content.strip(): message = cast(AllMessageValues, dict(message)) # Make a copy - message[ - "content" - ] = "[System: Empty message content sanitised to satisfy protocol]" + message["content"] = ( + "[System: Empty message content sanitised to satisfy protocol]" + ) verbose_logger.debug( f"_sanitize_empty_text_content: Replaced empty text content in {message.get('role')} message" ) @@ -2423,9 +2421,9 @@ def anthropic_messages_pt( # noqa: PLR0915 # Convert ChatCompletionImageUrlObject to dict if needed image_url_value = m["image_url"] if isinstance(image_url_value, str): - image_url_input: Union[ - str, dict[str, Any] - ] = image_url_value + image_url_input: Union[str, dict[str, Any]] = ( + image_url_value + ) else: # ChatCompletionImageUrlObject or dict case - convert to dict image_url_input = { @@ -2452,9 +2450,9 @@ def anthropic_messages_pt( # noqa: PLR0915 ) if "cache_control" in _content_element: - _anthropic_content_element[ - "cache_control" - ] = _content_element["cache_control"] + _anthropic_content_element["cache_control"] = ( + _content_element["cache_control"] + ) user_content.append(_anthropic_content_element) elif m.get("type", "") == "text": m = cast(ChatCompletionTextObject, m) @@ -2514,9 +2512,9 @@ def anthropic_messages_pt( # noqa: PLR0915 ) if "cache_control" in _content_element: - _anthropic_content_text_element[ - "cache_control" - ] = _content_element["cache_control"] + _anthropic_content_text_element["cache_control"] = ( + _content_element["cache_control"] + ) user_content.append(_anthropic_content_text_element) @@ -2649,9 +2647,9 @@ def anthropic_messages_pt( # noqa: PLR0915 original_content_element=dict(assistant_content_block), ) if "cache_control" in _content_element: - _anthropic_text_content_element[ - "cache_control" - ] = _content_element["cache_control"] + _anthropic_text_content_element["cache_control"] = ( + _content_element["cache_control"] + ) text_element = _anthropic_text_content_element # Interleave: each thinking block precedes its server tool group. @@ -2811,9 +2809,9 @@ def anthropic_messages_pt( # noqa: PLR0915 ) if "cache_control" in _content_element: - _anthropic_text_content_element[ - "cache_control" - ] = _content_element["cache_control"] + _anthropic_text_content_element["cache_control"] = ( + _content_element["cache_control"] + ) assistant_content.append(_anthropic_text_content_element) @@ -5255,9 +5253,7 @@ def default_response_schema_prompt(response_schema: dict) -> str: prompt_str = """Use this JSON schema: ```json {} - ```""".format( - response_schema - ) + ```""".format(response_schema) return prompt_str diff --git a/litellm/litellm_core_utils/realtime_streaming.py b/litellm/litellm_core_utils/realtime_streaming.py index 37233680714..4493a58f78b 100644 --- a/litellm/litellm_core_utils/realtime_streaming.py +++ b/litellm/litellm_core_utils/realtime_streaming.py @@ -199,12 +199,12 @@ async def log_messages(self): if self.input_messages: self.logging_obj.model_call_details["messages"] = self.input_messages if self.session_tools or self.tool_calls: - self.logging_obj.model_call_details[ - "realtime_tools" - ] = self.session_tools - self.logging_obj.model_call_details[ - "realtime_tool_calls" - ] = self.tool_calls + self.logging_obj.model_call_details["realtime_tools"] = ( + self.session_tools + ) + self.logging_obj.model_call_details["realtime_tool_calls"] = ( + self.tool_calls + ) ## ASYNC LOGGING # Create an event loop for the new thread asyncio.create_task(self.logging_obj.async_success_handler(self.messages)) diff --git a/litellm/litellm_core_utils/redact_messages.py b/litellm/litellm_core_utils/redact_messages.py index dbeb4111077..f3f560b33b9 100644 --- a/litellm/litellm_core_utils/redact_messages.py +++ b/litellm/litellm_core_utils/redact_messages.py @@ -285,9 +285,9 @@ def _get_turn_off_message_logging_from_dynamic_params( handles boolean and string values of `turn_off_message_logging` """ - standard_callback_dynamic_params: Optional[ - StandardCallbackDynamicParams - ] = model_call_details.get("standard_callback_dynamic_params", None) + standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = ( + model_call_details.get("standard_callback_dynamic_params", None) + ) if standard_callback_dynamic_params: _turn_off_message_logging = standard_callback_dynamic_params.get( "turn_off_message_logging" diff --git a/litellm/litellm_core_utils/safe_json_loads.py b/litellm/litellm_core_utils/safe_json_loads.py index bb4b72cfd97..b0a8e57d552 100644 --- a/litellm/litellm_core_utils/safe_json_loads.py +++ b/litellm/litellm_core_utils/safe_json_loads.py @@ -1,6 +1,7 @@ """ Helper for safe JSON loading in LiteLLM. """ + from typing import Any import json diff --git a/litellm/litellm_core_utils/specialty_caches/dynamic_logging_cache.py b/litellm/litellm_core_utils/specialty_caches/dynamic_logging_cache.py index c2acc708bb5..0a6a4e82c72 100644 --- a/litellm/litellm_core_utils/specialty_caches/dynamic_logging_cache.py +++ b/litellm/litellm_core_utils/specialty_caches/dynamic_logging_cache.py @@ -1,12 +1,13 @@ """ This is a cache for LangfuseLoggers. -Langfuse Python SDK initializes a thread for each client. +Langfuse Python SDK initializes a thread for each client. -This ensures we do +This ensures we do 1. Proper cleanup of Langfuse initialized clients. 2. Re-use created langfuse clients. """ + import hashlib import json from typing import Any, Optional diff --git a/litellm/litellm_core_utils/streaming_chunk_builder_utils.py b/litellm/litellm_core_utils/streaming_chunk_builder_utils.py index 1935372e5df..0829010ddb0 100644 --- a/litellm/litellm_core_utils/streaming_chunk_builder_utils.py +++ b/litellm/litellm_core_utils/streaming_chunk_builder_utils.py @@ -160,9 +160,9 @@ def get_combined_tool_content( # noqa: PLR0915 self, tool_call_chunks: List[Dict[str, Any]] ) -> List[ChatCompletionMessageToolCall]: tool_calls_list: List[ChatCompletionMessageToolCall] = [] - tool_call_map: Dict[ - int, Dict[str, Any] - ] = {} # Map to store tool calls by index + tool_call_map: Dict[int, Dict[str, Any]] = ( + {} + ) # Map to store tool calls by index for chunk in tool_call_chunks: choices = chunk["choices"] @@ -643,12 +643,12 @@ def calculate_usage( web_search_requests: Optional[int] = calculated_usage_per_chunk[ "web_search_requests" ] - completion_tokens_details: Optional[ - CompletionTokensDetails - ] = calculated_usage_per_chunk["completion_tokens_details"] - prompt_tokens_details: Optional[ - PromptTokensDetailsWrapper - ] = calculated_usage_per_chunk["prompt_tokens_details"] + completion_tokens_details: Optional[CompletionTokensDetails] = ( + calculated_usage_per_chunk["completion_tokens_details"] + ) + prompt_tokens_details: Optional[PromptTokensDetailsWrapper] = ( + calculated_usage_per_chunk["prompt_tokens_details"] + ) try: returned_usage.prompt_tokens = prompt_tokens or token_counter( diff --git a/litellm/llms/a2a/__init__.py b/litellm/llms/a2a/__init__.py index 043efa5e8bf..340f45dbab6 100644 --- a/litellm/llms/a2a/__init__.py +++ b/litellm/llms/a2a/__init__.py @@ -1,6 +1,7 @@ """ A2A (Agent-to-Agent) Protocol Provider for LiteLLM """ + from .chat.transformation import A2AConfig __all__ = ["A2AConfig"] diff --git a/litellm/llms/a2a/chat/__init__.py b/litellm/llms/a2a/chat/__init__.py index 76bf4dd71d9..c7cc8a7b0da 100644 --- a/litellm/llms/a2a/chat/__init__.py +++ b/litellm/llms/a2a/chat/__init__.py @@ -1,6 +1,7 @@ """ A2A Chat Completion Implementation """ + from .transformation import A2AConfig __all__ = ["A2AConfig"] diff --git a/litellm/llms/a2a/chat/streaming_iterator.py b/litellm/llms/a2a/chat/streaming_iterator.py index 72902f65f7c..29167d89ae7 100644 --- a/litellm/llms/a2a/chat/streaming_iterator.py +++ b/litellm/llms/a2a/chat/streaming_iterator.py @@ -1,6 +1,7 @@ """ A2A Streaming Response Iterator """ + from typing import Optional, Union from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator diff --git a/litellm/llms/a2a/chat/transformation.py b/litellm/llms/a2a/chat/transformation.py index d0887028632..b9c9f944b3e 100644 --- a/litellm/llms/a2a/chat/transformation.py +++ b/litellm/llms/a2a/chat/transformation.py @@ -1,6 +1,7 @@ """ A2A Protocol Transformation for LiteLLM """ + import uuid from typing import Any, Dict, Iterator, List, Optional, Union diff --git a/litellm/llms/a2a/common_utils.py b/litellm/llms/a2a/common_utils.py index aa817ce0fe6..15ea9f01abd 100644 --- a/litellm/llms/a2a/common_utils.py +++ b/litellm/llms/a2a/common_utils.py @@ -1,6 +1,7 @@ """ Common utilities for A2A (Agent-to-Agent) Protocol """ + from typing import Any, Dict, List from pydantic import BaseModel diff --git a/litellm/llms/amazon_nova/chat/transformation.py b/litellm/llms/amazon_nova/chat/transformation.py index 0fd08e62872..74c7fd234fe 100644 --- a/litellm/llms/amazon_nova/chat/transformation.py +++ b/litellm/llms/amazon_nova/chat/transformation.py @@ -1,6 +1,7 @@ """ Translate from OpenAI's `/v1/chat/completions` to Amazon Nova's `/v1/chat/completions` """ + from typing import Any, List, Optional, Tuple import httpx diff --git a/litellm/llms/anthropic/batches/transformation.py b/litellm/llms/anthropic/batches/transformation.py index 98c0588a091..3f03c744efe 100644 --- a/litellm/llms/anthropic/batches/transformation.py +++ b/litellm/llms/anthropic/batches/transformation.py @@ -229,12 +229,12 @@ def parse_timestamp(ts_str: Optional[str]) -> Optional[int]: completed_at=ended_at if processing_status == "ended" else None, failed_at=None, expired_at=archived_at if archived_at else None, - cancelling_at=cancel_initiated_at - if processing_status == "canceling" - else None, - cancelled_at=ended_at - if processing_status == "canceling" and ended_at - else None, + cancelling_at=( + cancel_initiated_at if processing_status == "canceling" else None + ), + cancelled_at=( + ended_at if processing_status == "canceling" and ended_at else None + ), request_counts=request_counts, metadata={}, ) diff --git a/litellm/llms/anthropic/chat/guardrail_translation/handler.py b/litellm/llms/anthropic/chat/guardrail_translation/handler.py index 5372757cbb6..79d578d1438 100644 --- a/litellm/llms/anthropic/chat/guardrail_translation/handler.py +++ b/litellm/llms/anthropic/chat/guardrail_translation/handler.py @@ -87,9 +87,9 @@ async def process_input_messages( texts_to_check: List[str] = [] images_to_check: List[str] = [] - tools_to_check: List[ - ChatCompletionToolParam - ] = chat_completion_compatible_request.get("tools", []) + tools_to_check: List[ChatCompletionToolParam] = ( + chat_completion_compatible_request.get("tools", []) + ) task_mappings: List[Tuple[int, Optional[int]]] = [] # Track (message_index, content_index) for each text # content_index is None for string content, int for list content diff --git a/litellm/llms/anthropic/chat/handler.py b/litellm/llms/anthropic/chat/handler.py index 9f2ddcae2c7..a2389f44295 100644 --- a/litellm/llms/anthropic/chat/handler.py +++ b/litellm/llms/anthropic/chat/handler.py @@ -578,9 +578,7 @@ def _handle_usage(self, anthropic_usage_chunk: Union[dict, UsageDelta]) -> Usage speed=self.speed, ) - def _content_block_delta_helper( - self, chunk: dict - ) -> Tuple[ + def _content_block_delta_helper(self, chunk: dict) -> Tuple[ str, Optional[ChatCompletionToolCallChunk], List[Union[ChatCompletionThinkingBlock, ChatCompletionRedactedThinkingBlock]], @@ -805,9 +803,9 @@ def chunk_parser(self, chunk: dict) -> ModelResponseStream: # noqa: PLR0915 tool_input = content_block_start["content_block"].get( "input", {} ) - self._server_tool_inputs[ - self._current_server_tool_id - ] = tool_input + self._server_tool_inputs[self._current_server_tool_id] = ( + tool_input + ) # Include caller information if present (for programmatic tool calling) if "caller" in content_block_start["content_block"]: caller_data = content_block_start["content_block"]["caller"] @@ -828,9 +826,9 @@ def chunk_parser(self, chunk: dict) -> ModelResponseStream: # noqa: PLR0915 # Handle compaction blocks # The full content comes in content_block_start self.compaction_blocks.append(content_block_start["content_block"]) - provider_specific_fields[ - "compaction_blocks" - ] = self.compaction_blocks + provider_specific_fields["compaction_blocks"] = ( + self.compaction_blocks + ) provider_specific_fields["compaction_start"] = { "type": "compaction", "content": content_block_start["content_block"].get( @@ -852,9 +850,9 @@ def chunk_parser(self, chunk: dict) -> ModelResponseStream: # noqa: PLR0915 self.web_search_results.append( content_block_start["content_block"] ) - provider_specific_fields[ - "web_search_results" - ] = self.web_search_results + provider_specific_fields["web_search_results"] = ( + self.web_search_results + ) elif content_type == "web_fetch_tool_result": # Capture web_fetch_tool_result for multi-turn reconstruction # The full content comes in content_block_start, not in deltas @@ -862,18 +860,18 @@ def chunk_parser(self, chunk: dict) -> ModelResponseStream: # noqa: PLR0915 self.web_search_results.append( content_block_start["content_block"] ) - provider_specific_fields[ - "web_search_results" - ] = self.web_search_results + provider_specific_fields["web_search_results"] = ( + self.web_search_results + ) elif content_type != "tool_search_tool_result": # Handle other tool results (code execution, etc.) # Skip tool_search_tool_result as it's internal metadata self.tool_results.append(content_block_start["content_block"]) provider_specific_fields["tool_results"] = self.tool_results # Convert to provider-neutral code_interpreter_results - provider_specific_fields[ - "code_interpreter_results" - ] = self._build_code_interpreter_results() + provider_specific_fields["code_interpreter_results"] = ( + self._build_code_interpreter_results() + ) elif type_chunk == "content_block_stop": ContentBlockStop(**chunk) # type: ignore @@ -930,9 +928,9 @@ def chunk_parser(self, chunk: dict) -> ModelResponseStream: # noqa: PLR0915 ) if container_id and self.tool_results: self._container_id = container_id - provider_specific_fields[ - "code_interpreter_results" - ] = self._build_code_interpreter_results() + provider_specific_fields["code_interpreter_results"] = ( + self._build_code_interpreter_results() + ) elif type_chunk == "message_start": """ Anthropic diff --git a/litellm/llms/anthropic/chat/transformation.py b/litellm/llms/anthropic/chat/transformation.py index 73d1b02c76d..483d87c6a60 100644 --- a/litellm/llms/anthropic/chat/transformation.py +++ b/litellm/llms/anthropic/chat/transformation.py @@ -964,11 +964,11 @@ def map_openai_params( # noqa: PLR0915 if mcp_servers: optional_params["mcp_servers"] = mcp_servers elif param == "tool_choice" or param == "parallel_tool_calls": - _tool_choice: Optional[ - AnthropicMessagesToolChoice - ] = self._map_tool_choice( - tool_choice=non_default_params.get("tool_choice"), - parallel_tool_use=non_default_params.get("parallel_tool_calls"), + _tool_choice: Optional[AnthropicMessagesToolChoice] = ( + self._map_tool_choice( + tool_choice=non_default_params.get("tool_choice"), + parallel_tool_use=non_default_params.get("parallel_tool_calls"), + ) ) if _tool_choice is not None: @@ -1066,9 +1066,9 @@ def map_openai_params( # noqa: PLR0915 self.map_openai_context_management_to_anthropic(value) ) if anthropic_context_management is not None: - optional_params[ - "context_management" - ] = anthropic_context_management + optional_params["context_management"] = ( + anthropic_context_management + ) elif param == "speed" and isinstance(value, str): # Pass through Anthropic-specific speed parameter for fast mode optional_params["speed"] = value @@ -1142,9 +1142,9 @@ def translate_system_message( text=system_message_block["content"], ) if "cache_control" in system_message_block: - anthropic_system_message_content[ - "cache_control" - ] = system_message_block["cache_control"] + anthropic_system_message_content["cache_control"] = ( + system_message_block["cache_control"] + ) anthropic_system_message_list.append( anthropic_system_message_content ) @@ -1168,9 +1168,9 @@ def translate_system_message( ) ) if "cache_control" in _content: - anthropic_system_message_content[ - "cache_control" - ] = _content["cache_control"] + anthropic_system_message_content["cache_control"] = ( + _content["cache_control"] + ) anthropic_system_message_list.append( anthropic_system_message_content @@ -1467,9 +1467,7 @@ def _transform_response_for_json_mode( ) return _message - def extract_response_content( - self, completion_response: dict - ) -> Tuple[ + def extract_response_content(self, completion_response: dict) -> Tuple[ str, Optional[List[Any]], Optional[ @@ -1763,9 +1761,9 @@ def _build_provider_specific_fields( code_interpreter_results = self._build_code_interpreter_results( tool_results, code_by_id, container_id ) - provider_specific_fields[ - "code_interpreter_results" - ] = code_interpreter_results + provider_specific_fields["code_interpreter_results"] = ( + code_interpreter_results + ) container = completion_response.get("container") if container is not None: diff --git a/litellm/llms/anthropic/common_utils.py b/litellm/llms/anthropic/common_utils.py index 7d2d0a74961..6ece0079a3d 100644 --- a/litellm/llms/anthropic/common_utils.py +++ b/litellm/llms/anthropic/common_utils.py @@ -464,9 +464,9 @@ def get_anthropic_headers( if web_search_tool_used: from litellm.types.llms.anthropic import ANTHROPIC_BETA_HEADER_VALUES - headers[ - "anthropic-beta" - ] = ANTHROPIC_BETA_HEADER_VALUES.WEB_SEARCH_2025_03_05.value + headers["anthropic-beta"] = ( + ANTHROPIC_BETA_HEADER_VALUES.WEB_SEARCH_2025_03_05.value + ) elif len(betas) > 0: headers["anthropic-beta"] = ",".join(betas) diff --git a/litellm/llms/anthropic/completion/transformation.py b/litellm/llms/anthropic/completion/transformation.py index 576ddb57fb1..a8798cd5d0e 100644 --- a/litellm/llms/anthropic/completion/transformation.py +++ b/litellm/llms/anthropic/completion/transformation.py @@ -55,9 +55,9 @@ class AnthropicTextConfig(BaseConfig): to pass metadata to anthropic, it's {"user_id": "any-relevant-information"} """ - max_tokens_to_sample: Optional[ - int - ] = litellm.max_tokens # anthropic requires a default + max_tokens_to_sample: Optional[int] = ( + litellm.max_tokens + ) # anthropic requires a default stop_sequences: Optional[list] = None temperature: Optional[int] = None top_p: Optional[int] = None diff --git a/litellm/llms/anthropic/experimental_pass_through/adapters/streaming_iterator.py b/litellm/llms/anthropic/experimental_pass_through/adapters/streaming_iterator.py index 6bddad09f21..3a6f0dbb88e 100644 --- a/litellm/llms/anthropic/experimental_pass_through/adapters/streaming_iterator.py +++ b/litellm/llms/anthropic/experimental_pass_through/adapters/streaming_iterator.py @@ -282,16 +282,16 @@ async def __anext__(self): # noqa: PLR0915 hasattr(chunk.usage, "_cache_creation_input_tokens") and chunk.usage._cache_creation_input_tokens > 0 ): - usage_dict[ - "cache_creation_input_tokens" - ] = chunk.usage._cache_creation_input_tokens + usage_dict["cache_creation_input_tokens"] = ( + chunk.usage._cache_creation_input_tokens + ) if ( hasattr(chunk.usage, "_cache_read_input_tokens") and chunk.usage._cache_read_input_tokens > 0 ): - usage_dict[ - "cache_read_input_tokens" - ] = chunk.usage._cache_read_input_tokens + usage_dict["cache_read_input_tokens"] = ( + chunk.usage._cache_read_input_tokens + ) merged_chunk["usage"] = usage_dict # Queue the merged chunk and reset diff --git a/litellm/llms/anthropic/experimental_pass_through/adapters/transformation.py b/litellm/llms/anthropic/experimental_pass_through/adapters/transformation.py index ed49943b7fe..0979ff1ead4 100644 --- a/litellm/llms/anthropic/experimental_pass_through/adapters/transformation.py +++ b/litellm/llms/anthropic/experimental_pass_through/adapters/transformation.py @@ -550,9 +550,9 @@ def translate_anthropic_messages_to_openai( # noqa: PLR0915 ## ASSISTANT MESSAGE ## assistant_message_str: Optional[str] = None - assistant_content_list: List[ - Dict[str, Any] - ] = [] # For content blocks with cache_control + assistant_content_list: List[Dict[str, Any]] = ( + [] + ) # For content blocks with cache_control has_cache_control_in_text = False tool_calls: List[ChatCompletionAssistantToolCall] = [] thinking_blocks: List[ @@ -595,12 +595,12 @@ def translate_anthropic_messages_to_openai( # noqa: PLR0915 function_chunk.get("provider_specific_fields") or {} ) - provider_specific_fields[ - "thought_signature" - ] = signature - function_chunk[ - "provider_specific_fields" - ] = provider_specific_fields + provider_specific_fields["thought_signature"] = ( + signature + ) + function_chunk["provider_specific_fields"] = ( + provider_specific_fields + ) tool_call = ChatCompletionAssistantToolCall( id=content.get("id", ""), @@ -663,7 +663,7 @@ def translate_anthropic_messages_to_openai( # noqa: PLR0915 @staticmethod def translate_anthropic_thinking_to_reasoning_effort( - thinking: Dict[str, Any] + thinking: Dict[str, Any], ) -> Optional[str]: """ Translate Anthropic's thinking parameter to OpenAI's reasoning_effort. @@ -1334,9 +1334,9 @@ def translate_openai_response_to_anthropic( hasattr(usage, "_cache_creation_input_tokens") and usage._cache_creation_input_tokens > 0 ): - anthropic_usage[ - "cache_creation_input_tokens" - ] = usage._cache_creation_input_tokens + anthropic_usage["cache_creation_input_tokens"] = ( + usage._cache_creation_input_tokens + ) if cached_tokens > 0: anthropic_usage["cache_read_input_tokens"] = cached_tokens @@ -1513,9 +1513,9 @@ def translate_streaming_openai_response_to_anthropic( hasattr(litellm_usage_chunk, "_cache_creation_input_tokens") and litellm_usage_chunk._cache_creation_input_tokens > 0 ): - usage_delta[ - "cache_creation_input_tokens" - ] = litellm_usage_chunk._cache_creation_input_tokens + usage_delta["cache_creation_input_tokens"] = ( + litellm_usage_chunk._cache_creation_input_tokens + ) if cached_tokens > 0: usage_delta["cache_read_input_tokens"] = cached_tokens else: diff --git a/litellm/llms/anthropic/experimental_pass_through/messages/transformation.py b/litellm/llms/anthropic/experimental_pass_through/messages/transformation.py index 9b60a58260b..c67cb492a66 100644 --- a/litellm/llms/anthropic/experimental_pass_through/messages/transformation.py +++ b/litellm/llms/anthropic/experimental_pass_through/messages/transformation.py @@ -208,9 +208,9 @@ def transform_anthropic_messages_request( ) ) if transformed_context_management is not None: - anthropic_messages_optional_request_params[ - "context_management" - ] = transformed_context_management + anthropic_messages_optional_request_params["context_management"] = ( + transformed_context_management + ) ####### get required params for all anthropic messages requests ###### verbose_logger.debug(f"TRANSFORMATION DEBUG - Messages: {messages}") diff --git a/litellm/llms/anthropic/experimental_pass_through/responses_adapters/streaming_iterator.py b/litellm/llms/anthropic/experimental_pass_through/responses_adapters/streaming_iterator.py index aa0738a0719..94c5200be64 100644 --- a/litellm/llms/anthropic/experimental_pass_through/responses_adapters/streaming_iterator.py +++ b/litellm/llms/anthropic/experimental_pass_through/responses_adapters/streaming_iterator.py @@ -35,9 +35,9 @@ def __init__( # Map item_id -> content_block_index so we can stop the right block later self._item_id_to_block_index: Dict[str, int] = {} # Track open function_call items by item_id so we can emit tool_use start - self._pending_tool_ids: Dict[ - str, str - ] = {} # item_id -> call_id / name accumulator + self._pending_tool_ids: Dict[str, str] = ( + {} + ) # item_id -> call_id / name accumulator self._sent_message_start = False self._sent_message_stop = False self._chunk_queue: deque = deque() diff --git a/litellm/llms/anthropic/experimental_pass_through/responses_adapters/transformation.py b/litellm/llms/anthropic/experimental_pass_through/responses_adapters/transformation.py index dae7044a5bc..f74dc8149a0 100644 --- a/litellm/llms/anthropic/experimental_pass_through/responses_adapters/transformation.py +++ b/litellm/llms/anthropic/experimental_pass_through/responses_adapters/transformation.py @@ -251,7 +251,7 @@ def translate_context_management_to_responses_api( @staticmethod def translate_thinking_to_reasoning( - thinking: Dict[str, Any] + thinking: Dict[str, Any], ) -> Optional[Dict[str, Any]]: """ Convert Anthropic thinking param to Responses API reasoning param. @@ -337,10 +337,10 @@ def translate_request( # tool_choice tool_choice = anthropic_request.get("tool_choice") if tool_choice: - responses_kwargs[ - "tool_choice" - ] = self.translate_tool_choice_to_responses_api( - cast(AnthropicMessagesToolChoice, tool_choice) + responses_kwargs["tool_choice"] = ( + self.translate_tool_choice_to_responses_api( + cast(AnthropicMessagesToolChoice, tool_choice) + ) ) # thinking -> reasoning diff --git a/litellm/llms/anthropic/files/transformation.py b/litellm/llms/anthropic/files/transformation.py index 0545cefb071..aeaab4e57bf 100644 --- a/litellm/llms/anthropic/files/transformation.py +++ b/litellm/llms/anthropic/files/transformation.py @@ -79,9 +79,9 @@ def get_error_class( return AnthropicError( status_code=status_code, message=error_message, - headers=cast(httpx.Headers, headers) - if isinstance(headers, dict) - else headers, + headers=( + cast(httpx.Headers, headers) if isinstance(headers, dict) else headers + ), ) def validate_environment( diff --git a/litellm/llms/azure/chat/o_series_transformation.py b/litellm/llms/azure/chat/o_series_transformation.py index cae7513245c..0a73597a4e4 100644 --- a/litellm/llms/azure/chat/o_series_transformation.py +++ b/litellm/llms/azure/chat/o_series_transformation.py @@ -4,10 +4,10 @@ https://platform.openai.com/docs/guides/reasoning Translations handled by LiteLLM: -- modalities: image => drop param (if user opts in to dropping param) -- role: system ==> translate to role 'user' -- streaming => faked by LiteLLM -- Tools, response_format => drop param (if user opts in to dropping param) +- modalities: image => drop param (if user opts in to dropping param) +- role: system ==> translate to role 'user' +- streaming => faked by LiteLLM +- Tools, response_format => drop param (if user opts in to dropping param) - Logprobs => drop param (if user opts in to dropping param) - Temperature => drop param (if user opts in to dropping param) """ diff --git a/litellm/llms/azure/fine_tuning/handler.py b/litellm/llms/azure/fine_tuning/handler.py index 429b8349896..3d7cc336fb5 100644 --- a/litellm/llms/azure/fine_tuning/handler.py +++ b/litellm/llms/azure/fine_tuning/handler.py @@ -25,7 +25,14 @@ def get_openai_client( _is_async: bool = False, api_version: Optional[str] = None, litellm_params: Optional[dict] = None, - ) -> Optional[Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI,]]: + ) -> Optional[ + Union[ + OpenAI, + AsyncOpenAI, + AzureOpenAI, + AsyncAzureOpenAI, + ] + ]: # Override to use Azure-specific client initialization if isinstance(client, OpenAI) or isinstance(client, AsyncOpenAI): client = None diff --git a/litellm/llms/azure_ai/anthropic/__init__.py b/litellm/llms/azure_ai/anthropic/__init__.py index 931c71de3b3..5ec22703aec 100644 --- a/litellm/llms/azure_ai/anthropic/__init__.py +++ b/litellm/llms/azure_ai/anthropic/__init__.py @@ -1,6 +1,7 @@ """ Azure Anthropic provider - supports Claude models via Azure Foundry """ + from .handler import AzureAnthropicChatCompletion from .transformation import AzureAnthropicConfig diff --git a/litellm/llms/azure_ai/anthropic/handler.py b/litellm/llms/azure_ai/anthropic/handler.py index a2263e72a14..f3a50b73c1a 100644 --- a/litellm/llms/azure_ai/anthropic/handler.py +++ b/litellm/llms/azure_ai/anthropic/handler.py @@ -1,6 +1,7 @@ """ Azure Anthropic handler - reuses AnthropicChatCompletion logic with Azure authentication """ + import copy import json from typing import TYPE_CHECKING, Callable, Union diff --git a/litellm/llms/azure_ai/anthropic/messages_transformation.py b/litellm/llms/azure_ai/anthropic/messages_transformation.py index 59d8fb02c6d..a81218ab76a 100644 --- a/litellm/llms/azure_ai/anthropic/messages_transformation.py +++ b/litellm/llms/azure_ai/anthropic/messages_transformation.py @@ -1,6 +1,7 @@ """ Azure Anthropic messages transformation config - extends AnthropicMessagesConfig with Azure authentication """ + from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from litellm.llms.anthropic.experimental_pass_through.messages.transformation import ( diff --git a/litellm/llms/azure_ai/anthropic/transformation.py b/litellm/llms/azure_ai/anthropic/transformation.py index 5d8f27b97df..e935aa1c057 100644 --- a/litellm/llms/azure_ai/anthropic/transformation.py +++ b/litellm/llms/azure_ai/anthropic/transformation.py @@ -1,6 +1,7 @@ """ Azure Anthropic transformation config - extends AnthropicConfig with Azure authentication """ + from typing import TYPE_CHECKING, Dict, List, Optional, Union from litellm.llms.anthropic.chat.transformation import AnthropicConfig from litellm.llms.azure.common_utils import BaseAzureLLM diff --git a/litellm/llms/azure_ai/azure_model_router/__init__.py b/litellm/llms/azure_ai/azure_model_router/__init__.py index 0165d60b643..bbee759459f 100644 --- a/litellm/llms/azure_ai/azure_model_router/__init__.py +++ b/litellm/llms/azure_ai/azure_model_router/__init__.py @@ -1,4 +1,5 @@ """Azure AI Foundry Model Router support.""" + from .transformation import AzureModelRouterConfig __all__ = ["AzureModelRouterConfig"] diff --git a/litellm/llms/azure_ai/azure_model_router/transformation.py b/litellm/llms/azure_ai/azure_model_router/transformation.py index 57acb147063..e4174f41ad7 100644 --- a/litellm/llms/azure_ai/azure_model_router/transformation.py +++ b/litellm/llms/azure_ai/azure_model_router/transformation.py @@ -4,6 +4,7 @@ The Model Router is a special Azure AI deployment that automatically routes requests to the best available model. It has specific cost tracking requirements. """ + from typing import Any, List, Optional from httpx import Response diff --git a/litellm/llms/azure_ai/embed/cohere_transformation.py b/litellm/llms/azure_ai/embed/cohere_transformation.py index 64433c21b61..bbbfb60fbde 100644 --- a/litellm/llms/azure_ai/embed/cohere_transformation.py +++ b/litellm/llms/azure_ai/embed/cohere_transformation.py @@ -1,5 +1,5 @@ """ -Transformation logic from OpenAI /v1/embeddings format to Azure AI Cohere's /v1/embed. +Transformation logic from OpenAI /v1/embeddings format to Azure AI Cohere's /v1/embed. Why separate file? Make it easy to see how transformation works diff --git a/litellm/llms/azure_ai/ocr/__init__.py b/litellm/llms/azure_ai/ocr/__init__.py index e49217a5baf..ade1165b848 100644 --- a/litellm/llms/azure_ai/ocr/__init__.py +++ b/litellm/llms/azure_ai/ocr/__init__.py @@ -1,4 +1,5 @@ """Azure AI OCR module.""" + from .common_utils import get_azure_ai_ocr_config from .document_intelligence.transformation import ( AzureDocumentIntelligenceOCRConfig, diff --git a/litellm/llms/azure_ai/ocr/document_intelligence/__init__.py b/litellm/llms/azure_ai/ocr/document_intelligence/__init__.py index fb14fbbf0ac..32d700fd195 100644 --- a/litellm/llms/azure_ai/ocr/document_intelligence/__init__.py +++ b/litellm/llms/azure_ai/ocr/document_intelligence/__init__.py @@ -1,4 +1,5 @@ """Azure Document Intelligence OCR module.""" + from .transformation import AzureDocumentIntelligenceOCRConfig __all__ = ["AzureDocumentIntelligenceOCRConfig"] diff --git a/litellm/llms/azure_ai/ocr/document_intelligence/transformation.py b/litellm/llms/azure_ai/ocr/document_intelligence/transformation.py index 6ef309ca679..81d15bac481 100644 --- a/litellm/llms/azure_ai/ocr/document_intelligence/transformation.py +++ b/litellm/llms/azure_ai/ocr/document_intelligence/transformation.py @@ -7,6 +7,7 @@ Note: Azure Document Intelligence API is async - POST returns 202 Accepted with Operation-Location header. The operation location must be polled until the analysis completes. """ + import asyncio import re import time diff --git a/litellm/llms/azure_ai/ocr/transformation.py b/litellm/llms/azure_ai/ocr/transformation.py index 8f57bb3358b..f661ddb9ebc 100644 --- a/litellm/llms/azure_ai/ocr/transformation.py +++ b/litellm/llms/azure_ai/ocr/transformation.py @@ -1,6 +1,7 @@ """ Azure AI OCR transformation implementation. """ + from typing import Dict, Optional from litellm._logging import verbose_logger diff --git a/litellm/llms/azure_ai/rerank/transformation.py b/litellm/llms/azure_ai/rerank/transformation.py index b5993040ea0..f64133afa8b 100644 --- a/litellm/llms/azure_ai/rerank/transformation.py +++ b/litellm/llms/azure_ai/rerank/transformation.py @@ -1,5 +1,5 @@ """ -Translate between Cohere's `/rerank` format and Azure AI's `/rerank` format. +Translate between Cohere's `/rerank` format and Azure AI's `/rerank` format. """ from typing import Optional diff --git a/litellm/llms/base_llm/ocr/__init__.py b/litellm/llms/base_llm/ocr/__init__.py index 5965af5f2b7..2aea2d67807 100644 --- a/litellm/llms/base_llm/ocr/__init__.py +++ b/litellm/llms/base_llm/ocr/__init__.py @@ -1,4 +1,5 @@ """Base OCR transformation module.""" + from .transformation import ( BaseOCRConfig, DocumentType, diff --git a/litellm/llms/base_llm/ocr/transformation.py b/litellm/llms/base_llm/ocr/transformation.py index 7d16c696dba..b7f4d8e3b2d 100644 --- a/litellm/llms/base_llm/ocr/transformation.py +++ b/litellm/llms/base_llm/ocr/transformation.py @@ -1,6 +1,7 @@ """ Base OCR transformation configuration. """ + from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union import httpx diff --git a/litellm/llms/base_llm/search/__init__.py b/litellm/llms/base_llm/search/__init__.py index f185b4e5955..c423db9ed95 100644 --- a/litellm/llms/base_llm/search/__init__.py +++ b/litellm/llms/base_llm/search/__init__.py @@ -1,6 +1,7 @@ """ Base Search API module. """ + from litellm.llms.base_llm.search.transformation import ( BaseSearchConfig, SearchResponse, diff --git a/litellm/llms/base_llm/search/transformation.py b/litellm/llms/base_llm/search/transformation.py index 1fbc5b670a9..4dfe86685fb 100644 --- a/litellm/llms/base_llm/search/transformation.py +++ b/litellm/llms/base_llm/search/transformation.py @@ -1,6 +1,7 @@ """ Base Search transformation configuration. """ + from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union import httpx diff --git a/litellm/llms/base_llm/vector_store_files/transformation.py b/litellm/llms/base_llm/vector_store_files/transformation.py index f13de563821..02915d013e5 100644 --- a/litellm/llms/base_llm/vector_store_files/transformation.py +++ b/litellm/llms/base_llm/vector_store_files/transformation.py @@ -54,14 +54,12 @@ def map_openai_params( @abstractmethod def get_auth_credentials( self, litellm_params: Dict[str, Any] - ) -> VectorStoreFileAuthCredentials: - ... + ) -> VectorStoreFileAuthCredentials: ... @abstractmethod def get_vector_store_file_endpoints_by_type( self, - ) -> Dict[str, Tuple[Tuple[str, str], ...]]: - ... + ) -> Dict[str, Tuple[Tuple[str, str], ...]]: ... @abstractmethod def validate_environment( @@ -91,16 +89,14 @@ def transform_create_vector_store_file_request( vector_store_id: str, create_request: VectorStoreFileCreateRequest, api_base: str, - ) -> Tuple[str, Dict[str, Any]]: - ... + ) -> Tuple[str, Dict[str, Any]]: ... @abstractmethod def transform_create_vector_store_file_response( self, *, response: httpx.Response, - ) -> VectorStoreFileObject: - ... + ) -> VectorStoreFileObject: ... @abstractmethod def transform_list_vector_store_files_request( @@ -109,16 +105,14 @@ def transform_list_vector_store_files_request( vector_store_id: str, query_params: VectorStoreFileListQueryParams, api_base: str, - ) -> Tuple[str, Dict[str, Any]]: - ... + ) -> Tuple[str, Dict[str, Any]]: ... @abstractmethod def transform_list_vector_store_files_response( self, *, response: httpx.Response, - ) -> VectorStoreFileListResponse: - ... + ) -> VectorStoreFileListResponse: ... @abstractmethod def transform_retrieve_vector_store_file_request( @@ -127,16 +121,14 @@ def transform_retrieve_vector_store_file_request( vector_store_id: str, file_id: str, api_base: str, - ) -> Tuple[str, Dict[str, Any]]: - ... + ) -> Tuple[str, Dict[str, Any]]: ... @abstractmethod def transform_retrieve_vector_store_file_response( self, *, response: httpx.Response, - ) -> VectorStoreFileObject: - ... + ) -> VectorStoreFileObject: ... @abstractmethod def transform_retrieve_vector_store_file_content_request( @@ -145,16 +137,14 @@ def transform_retrieve_vector_store_file_content_request( vector_store_id: str, file_id: str, api_base: str, - ) -> Tuple[str, Dict[str, Any]]: - ... + ) -> Tuple[str, Dict[str, Any]]: ... @abstractmethod def transform_retrieve_vector_store_file_content_response( self, *, response: httpx.Response, - ) -> VectorStoreFileContentResponse: - ... + ) -> VectorStoreFileContentResponse: ... @abstractmethod def transform_update_vector_store_file_request( @@ -164,16 +154,14 @@ def transform_update_vector_store_file_request( file_id: str, update_request: VectorStoreFileUpdateRequest, api_base: str, - ) -> Tuple[str, Dict[str, Any]]: - ... + ) -> Tuple[str, Dict[str, Any]]: ... @abstractmethod def transform_update_vector_store_file_response( self, *, response: httpx.Response, - ) -> VectorStoreFileObject: - ... + ) -> VectorStoreFileObject: ... @abstractmethod def transform_delete_vector_store_file_request( @@ -182,16 +170,14 @@ def transform_delete_vector_store_file_request( vector_store_id: str, file_id: str, api_base: str, - ) -> Tuple[str, Dict[str, Any]]: - ... + ) -> Tuple[str, Dict[str, Any]]: ... @abstractmethod def transform_delete_vector_store_file_response( self, *, response: httpx.Response, - ) -> VectorStoreFileDeleteResponse: - ... + ) -> VectorStoreFileDeleteResponse: ... def get_error_class( self, diff --git a/litellm/llms/bedrock/batches/handler.py b/litellm/llms/bedrock/batches/handler.py index e0c7c088362..f141bbd9ab4 100644 --- a/litellm/llms/bedrock/batches/handler.py +++ b/litellm/llms/bedrock/batches/handler.py @@ -64,9 +64,11 @@ async def _async_get_status(): created_at=status_response["submitTime"], in_progress_at=status_response["lastModifiedTime"], completed_at=status_response.get("endTime"), - failed_at=status_response.get("endTime") - if status_response["status"] == "failed" - else None, + failed_at=( + status_response.get("endTime") + if status_response["status"] == "failed" + else None + ), request_counts=BatchRequestCounts( total=1, completed=1 if status_response["status"] == "completed" else 0, diff --git a/litellm/llms/bedrock/chat/agentcore/transformation.py b/litellm/llms/bedrock/chat/agentcore/transformation.py index d6eb5a734c4..8a3e73a1495 100644 --- a/litellm/llms/bedrock/chat/agentcore/transformation.py +++ b/litellm/llms/bedrock/chat/agentcore/transformation.py @@ -877,9 +877,9 @@ async def get_async_custom_stream_wrapper( ) parsed = self._parse_json_response(response_json) - async def _json_as_async_stream() -> AsyncGenerator[ - ModelResponseStream, None - ]: + async def _json_as_async_stream() -> ( + AsyncGenerator[ModelResponseStream, None] + ): # Content chunk content_chunk = ModelResponseStream( id=f"chatcmpl-{uuid.uuid4()}", diff --git a/litellm/llms/bedrock/chat/converse_handler.py b/litellm/llms/bedrock/chat/converse_handler.py index ef46ae5c189..388947a4e9b 100644 --- a/litellm/llms/bedrock/chat/converse_handler.py +++ b/litellm/llms/bedrock/chat/converse_handler.py @@ -332,9 +332,9 @@ def completion( # noqa: PLR0915 aws_external_id = optional_params.pop("aws_external_id", None) optional_params.pop("aws_region_name", None) - litellm_params[ - "aws_region_name" - ] = aws_region_name # [DO NOT DELETE] important for async calls + litellm_params["aws_region_name"] = ( + aws_region_name # [DO NOT DELETE] important for async calls + ) credentials: Credentials = self.get_credentials( aws_access_key_id=aws_access_key_id, diff --git a/litellm/llms/bedrock/chat/converse_transformation.py b/litellm/llms/bedrock/chat/converse_transformation.py index dd8b1b0a69f..6709e6703a5 100644 --- a/litellm/llms/bedrock/chat/converse_transformation.py +++ b/litellm/llms/bedrock/chat/converse_transformation.py @@ -1758,9 +1758,7 @@ def apply_tool_call_transformation_if_needed( return message, returned_finish_reason - def _translate_message_content( - self, content_blocks: List[ContentBlock] - ) -> Tuple[ + def _translate_message_content(self, content_blocks: List[ContentBlock]) -> Tuple[ str, List[ChatCompletionToolCallChunk], Optional[List[BedrockConverseReasoningContentBlock]], @@ -1777,9 +1775,9 @@ def _translate_message_content( """ content_str = "" tools: List[ChatCompletionToolCallChunk] = [] - reasoningContentBlocks: Optional[ - List[BedrockConverseReasoningContentBlock] - ] = None + reasoningContentBlocks: Optional[List[BedrockConverseReasoningContentBlock]] = ( + None + ) citationsContentBlocks: Optional[List[CitationsContentBlock]] = None for idx, content in enumerate(content_blocks): """ @@ -1990,9 +1988,9 @@ def _transform_response( # noqa: PLR0915 chat_completion_message: ChatCompletionResponseMessage = {"role": "assistant"} content_str = "" tools: List[ChatCompletionToolCallChunk] = [] - reasoningContentBlocks: Optional[ - List[BedrockConverseReasoningContentBlock] - ] = None + reasoningContentBlocks: Optional[List[BedrockConverseReasoningContentBlock]] = ( + None + ) citationsContentBlocks: Optional[List[CitationsContentBlock]] = None if message is not None: @@ -2011,17 +2009,17 @@ def _transform_response( # noqa: PLR0915 provider_specific_fields["citationsContent"] = citationsContentBlocks if provider_specific_fields: - chat_completion_message[ - "provider_specific_fields" - ] = provider_specific_fields + chat_completion_message["provider_specific_fields"] = ( + provider_specific_fields + ) if reasoningContentBlocks is not None: - chat_completion_message[ - "reasoning_content" - ] = self._transform_reasoning_content(reasoningContentBlocks) - chat_completion_message[ - "thinking_blocks" - ] = self._transform_thinking_blocks(reasoningContentBlocks) + chat_completion_message["reasoning_content"] = ( + self._transform_reasoning_content(reasoningContentBlocks) + ) + chat_completion_message["thinking_blocks"] = ( + self._transform_thinking_blocks(reasoningContentBlocks) + ) chat_completion_message["content"] = content_str filtered_tools = self._filter_json_mode_tools( json_mode=json_mode, diff --git a/litellm/llms/bedrock/chat/invoke_handler.py b/litellm/llms/bedrock/chat/invoke_handler.py index 1077731779d..a6952254142 100644 --- a/litellm/llms/bedrock/chat/invoke_handler.py +++ b/litellm/llms/bedrock/chat/invoke_handler.py @@ -199,11 +199,13 @@ async def make_call( if client is None: client = get_async_httpx_client( llm_provider=litellm.LlmProviders.BEDROCK, - params={"ssl_verify": logging_obj.litellm_params.get("ssl_verify")} - if logging_obj - and logging_obj.litellm_params - and logging_obj.litellm_params.get("ssl_verify") - else None, + params=( + {"ssl_verify": logging_obj.litellm_params.get("ssl_verify")} + if logging_obj + and logging_obj.litellm_params + and logging_obj.litellm_params.get("ssl_verify") + else None + ), ) # Create a new client if none provided response = await client.post( @@ -293,11 +295,13 @@ def make_sync_call( try: if client is None: client = _get_httpx_client( - params={"ssl_verify": logging_obj.litellm_params.get("ssl_verify")} - if logging_obj - and logging_obj.litellm_params - and logging_obj.litellm_params.get("ssl_verify") - else None + params=( + {"ssl_verify": logging_obj.litellm_params.get("ssl_verify")} + if logging_obj + and logging_obj.litellm_params + and logging_obj.litellm_params.get("ssl_verify") + else None + ) ) response = client.post( @@ -547,9 +551,9 @@ def process_response( # noqa: PLR0915 content=None, ) model_response.choices[0].message = _message # type: ignore - model_response._hidden_params[ - "original_response" - ] = outputText # allow user to access raw anthropic tool calling response + model_response._hidden_params["original_response"] = ( + outputText # allow user to access raw anthropic tool calling response + ) if ( _is_function_call is True and stream is not None @@ -882,9 +886,9 @@ def completion( # noqa: PLR0915 ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in inference_params[k] = v if stream is True: - inference_params[ - "stream" - ] = True # cohere requires stream = True in inference params + inference_params["stream"] = ( + True # cohere requires stream = True in inference params + ) data = json.dumps({"prompt": prompt, **inference_params}) elif provider == "anthropic": if self.is_claude_messages_api_model(model): diff --git a/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py b/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py index cf8aee6954b..43850440072 100644 --- a/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py +++ b/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py @@ -182,9 +182,9 @@ def transform_request( config = litellm.AmazonCohereConfig.get_config() self._apply_config_to_params(config, inference_params) if stream is True: - inference_params[ - "stream" - ] = True # cohere requires stream = True in inference params + inference_params["stream"] = ( + True # cohere requires stream = True in inference params + ) request_data = {"prompt": prompt, **inference_params} elif provider == "anthropic": transformed_request = ( diff --git a/litellm/llms/bedrock/common_utils.py b/litellm/llms/bedrock/common_utils.py index 9666aa68c99..81ea94e07c7 100644 --- a/litellm/llms/bedrock/common_utils.py +++ b/litellm/llms/bedrock/common_utils.py @@ -1062,9 +1062,11 @@ def sign_aws_request( return ( dict(prepped.headers), - request_data.encode("utf-8") - if isinstance(request_data, str) - else request_data, + ( + request_data.encode("utf-8") + if isinstance(request_data, str) + else request_data + ), ) def generate_unique_job_name(self, model: str, prefix: str = "litellm") -> str: diff --git a/litellm/llms/bedrock/embed/amazon_titan_g1_transformation.py b/litellm/llms/bedrock/embed/amazon_titan_g1_transformation.py index 2747551af81..64a79b73273 100644 --- a/litellm/llms/bedrock/embed/amazon_titan_g1_transformation.py +++ b/litellm/llms/bedrock/embed/amazon_titan_g1_transformation.py @@ -1,5 +1,5 @@ """ -Transformation logic from OpenAI /v1/embeddings format to Bedrock Amazon Titan G1 /invoke format. +Transformation logic from OpenAI /v1/embeddings format to Bedrock Amazon Titan G1 /invoke format. Why separate file? Make it easy to see how transformation works diff --git a/litellm/llms/bedrock/embed/amazon_titan_multimodal_transformation.py b/litellm/llms/bedrock/embed/amazon_titan_multimodal_transformation.py index 07b04734c30..2713f54e623 100644 --- a/litellm/llms/bedrock/embed/amazon_titan_multimodal_transformation.py +++ b/litellm/llms/bedrock/embed/amazon_titan_multimodal_transformation.py @@ -38,9 +38,9 @@ def map_openai_params( ) -> dict: for k, v in non_default_params.items(): if k == "dimensions": - optional_params[ - "embeddingConfig" - ] = AmazonTitanMultimodalEmbeddingConfig(outputEmbeddingLength=v) + optional_params["embeddingConfig"] = ( + AmazonTitanMultimodalEmbeddingConfig(outputEmbeddingLength=v) + ) return optional_params def _transform_request( diff --git a/litellm/llms/bedrock/embed/cohere_transformation.py b/litellm/llms/bedrock/embed/cohere_transformation.py index d00cb74aae0..885c91f9754 100644 --- a/litellm/llms/bedrock/embed/cohere_transformation.py +++ b/litellm/llms/bedrock/embed/cohere_transformation.py @@ -1,5 +1,5 @@ """ -Transformation logic from OpenAI /v1/embeddings format to Bedrock Cohere /invoke format. +Transformation logic from OpenAI /v1/embeddings format to Bedrock Cohere /invoke format. Why separate file? Make it easy to see how transformation works """ diff --git a/litellm/llms/bedrock/image_generation/amazon_nova_canvas_transformation.py b/litellm/llms/bedrock/image_generation/amazon_nova_canvas_transformation.py index 86c005bbfad..87ef469beb5 100644 --- a/litellm/llms/bedrock/image_generation/amazon_nova_canvas_transformation.py +++ b/litellm/llms/bedrock/image_generation/amazon_nova_canvas_transformation.py @@ -103,9 +103,9 @@ def transform_request_body( imageGenerationConfig=image_generation_config_typed, ) if task_type == "COLOR_GUIDED_GENERATION": - color_guided_generation_params: Dict[ - str, Any - ] = image_generation_config.pop("colorGuidedGenerationParams", {}) + color_guided_generation_params: Dict[str, Any] = ( + image_generation_config.pop("colorGuidedGenerationParams", {}) + ) color_guided_generation_params = { "text": text, **color_guided_generation_params, diff --git a/litellm/llms/bedrock/messages/invoke_transformations/anthropic_claude3_transformation.py b/litellm/llms/bedrock/messages/invoke_transformations/anthropic_claude3_transformation.py index e31820d7631..7ab58687983 100644 --- a/litellm/llms/bedrock/messages/invoke_transformations/anthropic_claude3_transformation.py +++ b/litellm/llms/bedrock/messages/invoke_transformations/anthropic_claude3_transformation.py @@ -394,9 +394,9 @@ def transform_anthropic_messages_request( # 1. anthropic_version is required for all claude models if "anthropic_version" not in anthropic_messages_request: - anthropic_messages_request[ - "anthropic_version" - ] = self.DEFAULT_BEDROCK_ANTHROPIC_API_VERSION + anthropic_messages_request["anthropic_version"] = ( + self.DEFAULT_BEDROCK_ANTHROPIC_API_VERSION + ) # 2. `stream` is not allowed in request body for bedrock invoke if "stream" in anthropic_messages_request: diff --git a/litellm/llms/bedrock/realtime/handler.py b/litellm/llms/bedrock/realtime/handler.py index cde9f3e6fce..705ef62389c 100644 --- a/litellm/llms/bedrock/realtime/handler.py +++ b/litellm/llms/bedrock/realtime/handler.py @@ -235,7 +235,9 @@ async def _forward_bedrock_to_client( # Transform Bedrock format to OpenAI format from litellm.types.realtime import RealtimeResponseTransformInput - realtime_response_transform_input: RealtimeResponseTransformInput = { + realtime_response_transform_input: ( + RealtimeResponseTransformInput + ) = { "current_output_item_id": session_state.get( "current_output_item_id" ), diff --git a/litellm/llms/bedrock/realtime/transformation.py b/litellm/llms/bedrock/realtime/transformation.py index 13d5bf35466..9124a8c21b4 100644 --- a/litellm/llms/bedrock/realtime/transformation.py +++ b/litellm/llms/bedrock/realtime/transformation.py @@ -1016,9 +1016,9 @@ def transform_conversation_item_create_tool_result_event( "toolResult": { "promptName": self.prompt_name, "contentName": tool_content_name, - "content": output - if isinstance(output, str) - else json.dumps(output), + "content": ( + output if isinstance(output, str) else json.dumps(output) + ), } } } diff --git a/litellm/llms/bedrock_mantle/chat/transformation.py b/litellm/llms/bedrock_mantle/chat/transformation.py index e413bb22b2d..81a56030a5c 100644 --- a/litellm/llms/bedrock_mantle/chat/transformation.py +++ b/litellm/llms/bedrock_mantle/chat/transformation.py @@ -16,7 +16,6 @@ from ...openai_like.chat.transformation import OpenAILikeChatConfig - BEDROCK_MANTLE_DEFAULT_REGION = "us-east-1" diff --git a/litellm/llms/chatgpt/chat/streaming_utils.py b/litellm/llms/chatgpt/chat/streaming_utils.py index e9cf2d15c20..a08fecd9625 100644 --- a/litellm/llms/chatgpt/chat/streaming_utils.py +++ b/litellm/llms/chatgpt/chat/streaming_utils.py @@ -24,9 +24,9 @@ def __init__(self, stream: Any): self._stream = stream self._seen_ids: Dict[str, int] = {} # tool_call_id -> assigned_index self._next_index: int = 0 - self._last_id: Optional[ - str - ] = None # tracks which tool call the next delta belongs to + self._last_id: Optional[str] = ( + None # tracks which tool call the next delta belongs to + ) def __getattr__(self, name: str) -> Any: return getattr(self._stream, name) diff --git a/litellm/llms/chatgpt/common_utils.py b/litellm/llms/chatgpt/common_utils.py index 9cbcd6a4f46..830414d9cad 100644 --- a/litellm/llms/chatgpt/common_utils.py +++ b/litellm/llms/chatgpt/common_utils.py @@ -1,6 +1,7 @@ """ Constants and helpers for ChatGPT subscription OAuth. """ + import os import platform from typing import Any, Optional, Union diff --git a/litellm/llms/chatgpt/responses/transformation.py b/litellm/llms/chatgpt/responses/transformation.py index 3c59ca16581..66acd933416 100644 --- a/litellm/llms/chatgpt/responses/transformation.py +++ b/litellm/llms/chatgpt/responses/transformation.py @@ -77,9 +77,9 @@ def transform_responses_api_request( existing_instructions = request.get("instructions") if existing_instructions: if base_instructions not in existing_instructions: - request[ - "instructions" - ] = f"{base_instructions}\n\n{existing_instructions}" + request["instructions"] = ( + f"{base_instructions}\n\n{existing_instructions}" + ) else: request["instructions"] = base_instructions request["store"] = False diff --git a/litellm/llms/cohere/embed/handler.py b/litellm/llms/cohere/embed/handler.py index 3ab8baf7ba8..81b6a1c7aec 100644 --- a/litellm/llms/cohere/embed/handler.py +++ b/litellm/llms/cohere/embed/handler.py @@ -1,5 +1,5 @@ """ -Legacy /v1/embedding handler for Bedrock Cohere. +Legacy /v1/embedding handler for Bedrock Cohere. """ import json diff --git a/litellm/llms/custom_httpx/async_client_cleanup.py b/litellm/llms/custom_httpx/async_client_cleanup.py index 22629383ac2..9c1f6af7e9c 100644 --- a/litellm/llms/custom_httpx/async_client_cleanup.py +++ b/litellm/llms/custom_httpx/async_client_cleanup.py @@ -1,6 +1,7 @@ """ Utility functions for cleaning up async HTTP clients to prevent resource leaks. """ + import asyncio diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py index 001547557d4..a14eeb1f899 100644 --- a/litellm/llms/custom_httpx/http_handler.py +++ b/litellm/llms/custom_httpx/http_handler.py @@ -886,9 +886,9 @@ def _create_aiohttp_transport( if AIOHTTP_CONNECTOR_LIMIT > 0: transport_connector_kwargs["limit"] = AIOHTTP_CONNECTOR_LIMIT if AIOHTTP_CONNECTOR_LIMIT_PER_HOST > 0: - transport_connector_kwargs[ - "limit_per_host" - ] = AIOHTTP_CONNECTOR_LIMIT_PER_HOST + transport_connector_kwargs["limit_per_host"] = ( + AIOHTTP_CONNECTOR_LIMIT_PER_HOST + ) return LiteLLMAiohttpTransport( client=lambda: ClientSession( diff --git a/litellm/llms/custom_httpx/llm_http_handler.py b/litellm/llms/custom_httpx/llm_http_handler.py index 4c9abaad908..30320d9adab 100644 --- a/litellm/llms/custom_httpx/llm_http_handler.py +++ b/litellm/llms/custom_httpx/llm_http_handler.py @@ -4495,9 +4495,9 @@ async def _call_agentic_completion_hooks( # Second: Execute agentic loop # Add custom_llm_provider to kwargs so the agentic loop can reconstruct the full model name kwargs_with_provider = kwargs.copy() if kwargs else {} - kwargs_with_provider[ - "custom_llm_provider" - ] = custom_llm_provider + kwargs_with_provider["custom_llm_provider"] = ( + custom_llm_provider + ) agentic_response = await callback.async_run_agentic_loop( tools=tool_calls, model=model, @@ -4613,9 +4613,9 @@ async def _call_agentic_chat_completion_hooks( # Second: Execute agentic loop # Add custom_llm_provider to kwargs so the agentic loop can reconstruct the full model name kwargs_with_provider = kwargs.copy() if kwargs else {} - kwargs_with_provider[ - "custom_llm_provider" - ] = custom_llm_provider + kwargs_with_provider["custom_llm_provider"] = ( + custom_llm_provider + ) agentic_response = ( await callback.async_run_chat_completion_agentic_loop( tools=tool_calls, @@ -5099,7 +5099,10 @@ def image_edit_handler( _is_async: bool = False, fake_stream: bool = False, litellm_metadata: Optional[Dict[str, Any]] = None, - ) -> Union[ImageResponse, Coroutine[Any, Any, ImageResponse],]: + ) -> Union[ + ImageResponse, + Coroutine[Any, Any, ImageResponse], + ]: """ Handles image edit requests. @@ -5311,7 +5314,10 @@ def image_generation_handler( fake_stream: bool = False, litellm_metadata: Optional[Dict[str, Any]] = None, api_key: Optional[str] = None, - ) -> Union[ImageResponse, Coroutine[Any, Any, ImageResponse],]: + ) -> Union[ + ImageResponse, + Coroutine[Any, Any, ImageResponse], + ]: """ Handles image generation requests. When _is_async=True, returns a coroutine instead of making the call directly. @@ -5551,7 +5557,10 @@ def video_generation_handler( fake_stream: bool = False, litellm_metadata: Optional[Dict[str, Any]] = None, api_key: Optional[str] = None, - ) -> Union[VideoObject, Coroutine[Any, Any, VideoObject],]: + ) -> Union[ + VideoObject, + Coroutine[Any, Any, VideoObject], + ]: """ Handles video generation requests. When _is_async=True, returns a coroutine instead of making the call directly. diff --git a/litellm/llms/custom_httpx/mock_transport.py b/litellm/llms/custom_httpx/mock_transport.py index c9844753e0e..ad93cc134ee 100644 --- a/litellm/llms/custom_httpx/mock_transport.py +++ b/litellm/llms/custom_httpx/mock_transport.py @@ -13,7 +13,6 @@ import httpx - # --------------------------------------------------------------------------- # Pre-built response templates # --------------------------------------------------------------------------- diff --git a/litellm/llms/dashscope/chat/transformation.py b/litellm/llms/dashscope/chat/transformation.py index cc5cf991826..4d90b2a1f9f 100644 --- a/litellm/llms/dashscope/chat/transformation.py +++ b/litellm/llms/dashscope/chat/transformation.py @@ -14,8 +14,7 @@ class DashScopeChatConfig(OpenAIGPTConfig): @overload def _transform_messages( self, messages: List[AllMessageValues], model: str, is_async: Literal[True] - ) -> Coroutine[Any, Any, List[AllMessageValues]]: - ... + ) -> Coroutine[Any, Any, List[AllMessageValues]]: ... @overload def _transform_messages( @@ -23,8 +22,7 @@ def _transform_messages( messages: List[AllMessageValues], model: str, is_async: Literal[False] = False, - ) -> List[AllMessageValues]: - ... + ) -> List[AllMessageValues]: ... def _transform_messages( self, messages: List[AllMessageValues], model: str, is_async: bool = False diff --git a/litellm/llms/dashscope/cost_calculator.py b/litellm/llms/dashscope/cost_calculator.py index 9b3e3851162..8bb7f605b82 100644 --- a/litellm/llms/dashscope/cost_calculator.py +++ b/litellm/llms/dashscope/cost_calculator.py @@ -1,5 +1,5 @@ """ -Cost calculator for Dashscope Chat models. +Cost calculator for Dashscope Chat models. Handles tiered pricing and prompt caching scenarios. """ diff --git a/litellm/llms/databricks/chat/transformation.py b/litellm/llms/databricks/chat/transformation.py index 8ae02bd65ed..638d2d2d9e2 100644 --- a/litellm/llms/databricks/chat/transformation.py +++ b/litellm/llms/databricks/chat/transformation.py @@ -353,8 +353,7 @@ def _should_fake_stream(self, optional_params: dict) -> bool: @overload def _transform_messages( self, messages: List[AllMessageValues], model: str, is_async: Literal[True] - ) -> Coroutine[Any, Any, List[AllMessageValues]]: - ... + ) -> Coroutine[Any, Any, List[AllMessageValues]]: ... @overload def _transform_messages( @@ -362,8 +361,7 @@ def _transform_messages( messages: List[AllMessageValues], model: str, is_async: Literal[False] = False, - ) -> List[AllMessageValues]: - ... + ) -> List[AllMessageValues]: ... def _transform_messages( self, messages: List[AllMessageValues], model: str, is_async: bool = False diff --git a/litellm/llms/databricks/common_utils.py b/litellm/llms/databricks/common_utils.py index 608f29a03a7..d39d52d2d59 100644 --- a/litellm/llms/databricks/common_utils.py +++ b/litellm/llms/databricks/common_utils.py @@ -289,9 +289,9 @@ def _get_databricks_credentials( api_base = api_base or f"{databricks_client.config.host}/serving-endpoints" if api_key is None: - databricks_auth_headers: dict[ - str, str - ] = databricks_client.config.authenticate() + databricks_auth_headers: dict[str, str] = ( + databricks_client.config.authenticate() + ) headers = {**databricks_auth_headers, **headers} return api_base, headers diff --git a/litellm/llms/databricks/embed/transformation.py b/litellm/llms/databricks/embed/transformation.py index a113a349cc6..53e3b30dd21 100644 --- a/litellm/llms/databricks/embed/transformation.py +++ b/litellm/llms/databricks/embed/transformation.py @@ -11,9 +11,9 @@ class DatabricksEmbeddingConfig: Reference: https://learn.microsoft.com/en-us/azure/databricks/machine-learning/foundation-models/api-reference#--embedding-task """ - instruction: Optional[ - str - ] = None # An optional instruction to pass to the embedding model. BGE Authors recommend 'Represent this sentence for searching relevant passages:' for retrieval queries + instruction: Optional[str] = ( + None # An optional instruction to pass to the embedding model. BGE Authors recommend 'Represent this sentence for searching relevant passages:' for retrieval queries + ) def __init__(self, instruction: Optional[str] = None) -> None: locals_ = locals().copy() diff --git a/litellm/llms/dataforseo/search/transformation.py b/litellm/llms/dataforseo/search/transformation.py index 940f1ca6007..27c10d740b5 100644 --- a/litellm/llms/dataforseo/search/transformation.py +++ b/litellm/llms/dataforseo/search/transformation.py @@ -3,6 +3,7 @@ DataForSEO API Reference: https://docs.dataforseo.com/v3/serp/google/organic/live/advanced/?bash """ + from typing import Any, Dict, List, Literal, Optional, Union import httpx diff --git a/litellm/llms/datarobot/chat/transformation.py b/litellm/llms/datarobot/chat/transformation.py index 23ce63c25b2..f81e2420930 100644 --- a/litellm/llms/datarobot/chat/transformation.py +++ b/litellm/llms/datarobot/chat/transformation.py @@ -1,5 +1,5 @@ """ -Support for OpenAI's `/v1/chat/completions` endpoint. +Support for OpenAI's `/v1/chat/completions` endpoint. Calls done in OpenAI/openai.py as DataRobot is openai-compatible. """ diff --git a/litellm/llms/deepinfra/chat/transformation.py b/litellm/llms/deepinfra/chat/transformation.py index c36b490abca..a6bd8b4934f 100644 --- a/litellm/llms/deepinfra/chat/transformation.py +++ b/litellm/llms/deepinfra/chat/transformation.py @@ -161,8 +161,7 @@ def _transform_tool_message_content( @overload def _transform_messages( self, messages: List[AllMessageValues], model: str, is_async: Literal[True] - ) -> Coroutine[Any, Any, List[AllMessageValues]]: - ... + ) -> Coroutine[Any, Any, List[AllMessageValues]]: ... @overload def _transform_messages( @@ -170,8 +169,7 @@ def _transform_messages( messages: List[AllMessageValues], model: str, is_async: Literal[False] = False, - ) -> List[AllMessageValues]: - ... + ) -> List[AllMessageValues]: ... def _transform_messages( self, messages: List[AllMessageValues], model: str, is_async: bool = False diff --git a/litellm/llms/deepinfra/rerank/transformation.py b/litellm/llms/deepinfra/rerank/transformation.py index 71e300d258c..1850c11839c 100644 --- a/litellm/llms/deepinfra/rerank/transformation.py +++ b/litellm/llms/deepinfra/rerank/transformation.py @@ -1,5 +1,5 @@ """ -Translate between Cohere's `/rerank` format and Deepinfra's `/rerank` format. +Translate between Cohere's `/rerank` format and Deepinfra's `/rerank` format. """ from typing import Any, Dict, List, Optional, Union diff --git a/litellm/llms/deepseek/chat/transformation.py b/litellm/llms/deepseek/chat/transformation.py index d38ec4d67dd..5cd8d119542 100644 --- a/litellm/llms/deepseek/chat/transformation.py +++ b/litellm/llms/deepseek/chat/transformation.py @@ -65,8 +65,7 @@ def map_openai_params( @overload def _transform_messages( self, messages: List[AllMessageValues], model: str, is_async: Literal[True] - ) -> Coroutine[Any, Any, List[AllMessageValues]]: - ... + ) -> Coroutine[Any, Any, List[AllMessageValues]]: ... @overload def _transform_messages( @@ -74,8 +73,7 @@ def _transform_messages( messages: List[AllMessageValues], model: str, is_async: Literal[False] = False, - ) -> List[AllMessageValues]: - ... + ) -> List[AllMessageValues]: ... def _transform_messages( self, messages: List[AllMessageValues], model: str, is_async: bool = False diff --git a/litellm/llms/deepseek/cost_calculator.py b/litellm/llms/deepseek/cost_calculator.py index 0f4490cb3df..e652ebeac54 100644 --- a/litellm/llms/deepseek/cost_calculator.py +++ b/litellm/llms/deepseek/cost_calculator.py @@ -1,5 +1,5 @@ """ -Cost calculator for DeepSeek Chat models. +Cost calculator for DeepSeek Chat models. Handles prompt caching scenario. """ diff --git a/litellm/llms/deprecated_providers/aleph_alpha.py b/litellm/llms/deprecated_providers/aleph_alpha.py index 4cfede2a1b9..81ad1346414 100644 --- a/litellm/llms/deprecated_providers/aleph_alpha.py +++ b/litellm/llms/deprecated_providers/aleph_alpha.py @@ -77,9 +77,9 @@ class AlephAlphaConfig: - `control_log_additive` (boolean; default value: true): Method of applying control to attention scores. """ - maximum_tokens: Optional[ - int - ] = litellm.max_tokens # aleph alpha requires max tokens + maximum_tokens: Optional[int] = ( + litellm.max_tokens + ) # aleph alpha requires max tokens minimum_tokens: Optional[int] = None echo: Optional[bool] = None temperature: Optional[int] = None diff --git a/litellm/llms/docker_model_runner/chat/transformation.py b/litellm/llms/docker_model_runner/chat/transformation.py index 4b81502bf81..dc03c80f154 100644 --- a/litellm/llms/docker_model_runner/chat/transformation.py +++ b/litellm/llms/docker_model_runner/chat/transformation.py @@ -26,8 +26,7 @@ class DockerModelRunnerChatConfig(OpenAIGPTConfig): @overload def _transform_messages( self, messages: List[AllMessageValues], model: str, is_async: Literal[True] - ) -> Coroutine[Any, Any, List[AllMessageValues]]: - ... + ) -> Coroutine[Any, Any, List[AllMessageValues]]: ... @overload def _transform_messages( @@ -35,8 +34,7 @@ def _transform_messages( messages: List[AllMessageValues], model: str, is_async: Literal[False] = False, - ) -> List[AllMessageValues]: - ... + ) -> List[AllMessageValues]: ... def _transform_messages( self, messages: List[AllMessageValues], model: str, is_async: bool = False diff --git a/litellm/llms/duckduckgo/search/__init__.py b/litellm/llms/duckduckgo/search/__init__.py index c0019637838..7ae8f7b397a 100644 --- a/litellm/llms/duckduckgo/search/__init__.py +++ b/litellm/llms/duckduckgo/search/__init__.py @@ -1,6 +1,7 @@ """ DuckDuckGo Search API module. """ + from litellm.llms.duckduckgo.search.transformation import DuckDuckGoSearchConfig __all__ = ["DuckDuckGoSearchConfig"] diff --git a/litellm/llms/duckduckgo/search/transformation.py b/litellm/llms/duckduckgo/search/transformation.py index c754338153a..e8eda3a37ab 100644 --- a/litellm/llms/duckduckgo/search/transformation.py +++ b/litellm/llms/duckduckgo/search/transformation.py @@ -3,6 +3,7 @@ DuckDuckGo API Reference: https://duckduckgo.com/api """ + from typing import Dict, List, Literal, Optional, TypedDict, Union from urllib.parse import urlencode diff --git a/litellm/llms/elevenlabs/text_to_speech/transformation.py b/litellm/llms/elevenlabs/text_to_speech/transformation.py index 4dac2b8ba92..46a0772251a 100644 --- a/litellm/llms/elevenlabs/text_to_speech/transformation.py +++ b/litellm/llms/elevenlabs/text_to_speech/transformation.py @@ -21,7 +21,6 @@ from ..common_utils import ElevenLabsException - if TYPE_CHECKING: from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.types.llms.openai import HttpxBinaryResponseContent diff --git a/litellm/llms/exa_ai/search/__init__.py b/litellm/llms/exa_ai/search/__init__.py index db1f0804646..80bc10043bb 100644 --- a/litellm/llms/exa_ai/search/__init__.py +++ b/litellm/llms/exa_ai/search/__init__.py @@ -1,6 +1,7 @@ """ Exa AI Search API module. """ + from litellm.llms.exa_ai.search.transformation import ExaAISearchConfig __all__ = ["ExaAISearchConfig"] diff --git a/litellm/llms/exa_ai/search/transformation.py b/litellm/llms/exa_ai/search/transformation.py index fb352f3f93e..7a34ededa6b 100644 --- a/litellm/llms/exa_ai/search/transformation.py +++ b/litellm/llms/exa_ai/search/transformation.py @@ -3,6 +3,7 @@ Exa AI API Reference: https://docs.exa.ai/reference/search """ + from typing import Dict, List, Optional, TypedDict, Union import httpx diff --git a/litellm/llms/firecrawl/__init__.py b/litellm/llms/firecrawl/__init__.py index b43d2da3214..ef8414689a4 100644 --- a/litellm/llms/firecrawl/__init__.py +++ b/litellm/llms/firecrawl/__init__.py @@ -1,6 +1,7 @@ """ Firecrawl API integration module. """ + from litellm.llms.firecrawl.search.transformation import FirecrawlSearchConfig __all__ = ["FirecrawlSearchConfig"] diff --git a/litellm/llms/firecrawl/search/__init__.py b/litellm/llms/firecrawl/search/__init__.py index 46619d05b63..5b28e6a5068 100644 --- a/litellm/llms/firecrawl/search/__init__.py +++ b/litellm/llms/firecrawl/search/__init__.py @@ -1,6 +1,7 @@ """ Firecrawl Search API module. """ + from litellm.llms.firecrawl.search.transformation import FirecrawlSearchConfig __all__ = ["FirecrawlSearchConfig"] diff --git a/litellm/llms/firecrawl/search/transformation.py b/litellm/llms/firecrawl/search/transformation.py index 61b589218cc..ab98cb636cf 100644 --- a/litellm/llms/firecrawl/search/transformation.py +++ b/litellm/llms/firecrawl/search/transformation.py @@ -3,6 +3,7 @@ Firecrawl API Reference: https://docs.firecrawl.dev/api-reference/endpoint/search """ + from typing import Dict, List, Optional, TypedDict, Union import httpx diff --git a/litellm/llms/fireworks_ai/chat/transformation.py b/litellm/llms/fireworks_ai/chat/transformation.py index 6b654ebdfd3..ed6d167a118 100644 --- a/litellm/llms/fireworks_ai/chat/transformation.py +++ b/litellm/llms/fireworks_ai/chat/transformation.py @@ -392,11 +392,11 @@ def transform_response( ## FIREWORKS AI sends tool calls in the content field instead of tool_calls for choice in response.choices: - cast( - Choices, choice - ).message = self._handle_message_content_with_tool_calls( - message=cast(Choices, choice).message, - tool_calls=optional_params.get("tools", None), + cast(Choices, choice).message = ( + self._handle_message_content_with_tool_calls( + message=cast(Choices, choice).message, + tool_calls=optional_params.get("tools", None), + ) ) response._hidden_params = {"additional_headers": additional_headers} diff --git a/litellm/llms/gemini/files/transformation.py b/litellm/llms/gemini/files/transformation.py index bdfb0ee1e52..151f1430a0d 100644 --- a/litellm/llms/gemini/files/transformation.py +++ b/litellm/llms/gemini/files/transformation.py @@ -3,6 +3,7 @@ For vertex ai, check out the vertex_ai/files/handler.py file. """ + import time from typing import Any, List, Literal, Optional @@ -267,9 +268,11 @@ def transform_retrieve_file_response( object="file", purpose="user_data", status=status, - status_details=str(response_json.get("error", "")) - if gemini_state == "FAILED" - else None, + status_details=( + str(response_json.get("error", "")) + if gemini_state == "FAILED" + else None + ), ) except Exception as e: verbose_logger.exception(f"Error parsing file retrieve response: {str(e)}") diff --git a/litellm/llms/gemini/image_edit/transformation.py b/litellm/llms/gemini/image_edit/transformation.py index 5d9b1255d09..d46733e04b2 100644 --- a/litellm/llms/gemini/image_edit/transformation.py +++ b/litellm/llms/gemini/image_edit/transformation.py @@ -111,9 +111,9 @@ def transform_image_edit_request( # type: ignore[override] # Move aspectRatio into imageConfig inside generationConfig if "imageConfig" not in generation_config: generation_config["imageConfig"] = {} - generation_config["imageConfig"][ - "aspectRatio" - ] = image_edit_optional_request_params["aspectRatio"] + generation_config["imageConfig"]["aspectRatio"] = ( + image_edit_optional_request_params["aspectRatio"] + ) if generation_config: request_body["generationConfig"] = generation_config diff --git a/litellm/llms/gemini/image_generation/transformation.py b/litellm/llms/gemini/image_generation/transformation.py index b094fc133d7..9c4cd008b8c 100644 --- a/litellm/llms/gemini/image_generation/transformation.py +++ b/litellm/llms/gemini/image_generation/transformation.py @@ -245,11 +245,11 @@ def transform_image_generation_response( ImageObject( b64_json=inline_data["data"], url=None, - provider_specific_fields={ - "thought_signature": thought_sig - } - if thought_sig - else None, + provider_specific_fields=( + {"thought_signature": thought_sig} + if thought_sig + else None + ), ) ) diff --git a/litellm/llms/gemini/realtime/transformation.py b/litellm/llms/gemini/realtime/transformation.py index 2bb7bcd8b4f..c04f5725cf9 100644 --- a/litellm/llms/gemini/realtime/transformation.py +++ b/litellm/llms/gemini/realtime/transformation.py @@ -186,10 +186,10 @@ def map_openai_params( ) vertex_gemini_config = VertexGeminiConfig() - optional_params["generationConfig"][ - "tools" - ] = vertex_gemini_config._map_function( - value=value, optional_params=optional_params + optional_params["generationConfig"]["tools"] = ( + vertex_gemini_config._map_function( + value=value, optional_params=optional_params + ) ) elif key == "input_audio_transcription" and value is not None: optional_params["inputAudioTranscription"] = {} @@ -201,10 +201,10 @@ def map_openai_params( if ( len(transformed_audio_activity_config) > 0 ): # if the config is not empty, add it to the optional params - optional_params[ - "realtimeInputConfig" - ] = BidiGenerateContentRealtimeInputConfig( - automaticActivityDetection=transformed_audio_activity_config + optional_params["realtimeInputConfig"] = ( + BidiGenerateContentRealtimeInputConfig( + automaticActivityDetection=transformed_audio_activity_config + ) ) if len(optional_params["generationConfig"]) == 0: optional_params.pop("generationConfig") @@ -864,9 +864,9 @@ def transform_realtime_response( # noqa: PLR0915 "session_configuration_request" ] current_item_chunks = realtime_response_transform_input["current_item_chunks"] - current_delta_type: Optional[ - ALL_DELTA_TYPES - ] = realtime_response_transform_input["current_delta_type"] + current_delta_type: Optional[ALL_DELTA_TYPES] = ( + realtime_response_transform_input["current_delta_type"] + ) returned_message: List[OpenAIRealtimeEvents] = [] # Handle transcription events that arrive independently from model diff --git a/litellm/llms/github_copilot/common_utils.py b/litellm/llms/github_copilot/common_utils.py index d3169e3ca94..a9944df51a4 100644 --- a/litellm/llms/github_copilot/common_utils.py +++ b/litellm/llms/github_copilot/common_utils.py @@ -1,6 +1,7 @@ """ Constants for Copilot integration """ + from typing import Optional, Union from uuid import uuid4 diff --git a/litellm/llms/github_copilot/embedding/transformation.py b/litellm/llms/github_copilot/embedding/transformation.py index fa7bd4e3223..5fc6970342a 100644 --- a/litellm/llms/github_copilot/embedding/transformation.py +++ b/litellm/llms/github_copilot/embedding/transformation.py @@ -6,6 +6,7 @@ Implementation based on analysis of the copilot-api project by caozhiyuan: https://github.com/caozhiyuan/copilot-api """ + from typing import TYPE_CHECKING, Any, Optional import httpx diff --git a/litellm/llms/github_copilot/responses/transformation.py b/litellm/llms/github_copilot/responses/transformation.py index 46efc124b1d..d97b56db395 100644 --- a/litellm/llms/github_copilot/responses/transformation.py +++ b/litellm/llms/github_copilot/responses/transformation.py @@ -7,6 +7,7 @@ Implementation based on analysis of the copilot-api project by caozhiyuan: https://github.com/caozhiyuan/copilot-api """ + from typing import TYPE_CHECKING, Any, Dict, Optional, Union from litellm._logging import verbose_logger diff --git a/litellm/llms/google_pse/search/__init__.py b/litellm/llms/google_pse/search/__init__.py index 0fcfff82c38..0d2acafb798 100644 --- a/litellm/llms/google_pse/search/__init__.py +++ b/litellm/llms/google_pse/search/__init__.py @@ -1,6 +1,7 @@ """ Google Programmable Search Engine (PSE) API module. """ + from litellm.llms.google_pse.search.transformation import GooglePSESearchConfig __all__ = ["GooglePSESearchConfig"] diff --git a/litellm/llms/google_pse/search/transformation.py b/litellm/llms/google_pse/search/transformation.py index 2fabbc5d16e..a8aa109cbf0 100644 --- a/litellm/llms/google_pse/search/transformation.py +++ b/litellm/llms/google_pse/search/transformation.py @@ -3,6 +3,7 @@ Google PSE API Reference: https://developers.google.com/custom-search/v1/reference/rest/v1/cse/list """ + from typing import Dict, List, Literal, Optional, TypedDict, Union import httpx @@ -42,10 +43,14 @@ class GooglePSESearchRequest(_GooglePSESearchRequestRequired, total=False): hq: str # Optional - append query terms to query imgSize: str # Optional - returns images of specified size imgType: str # Optional - returns images of specified type - linkSite: str # Optional - specifies all search results should contain a link to a URL + linkSite: ( + str # Optional - specifies all search results should contain a link to a URL + ) lr: str # Optional - language restrict (e.g., 'lang_en', 'lang_es') orTerms: str # Optional - provides additional search terms - relatedSite: str # Optional - specifies all search results should be pages related to URL + relatedSite: ( + str # Optional - specifies all search results should be pages related to URL + ) rights: str # Optional - filters based on licensing safe: str # Optional - search safety level ('active', 'off') searchType: str # Optional - specifies search type ('image') diff --git a/litellm/llms/groq/chat/transformation.py b/litellm/llms/groq/chat/transformation.py index 34ea7b03dd9..d07da006f2d 100644 --- a/litellm/llms/groq/chat/transformation.py +++ b/litellm/llms/groq/chat/transformation.py @@ -1,6 +1,7 @@ """ Translate from OpenAI's `/v1/chat/completions` to Groq's `/v1/chat/completions` """ + from typing import ( Any, Coroutine, @@ -115,8 +116,7 @@ def get_supported_openai_params(self, model: str) -> list: @overload def _transform_messages( self, messages: List[AllMessageValues], model: str, is_async: Literal[True] - ) -> Coroutine[Any, Any, List[AllMessageValues]]: - ... + ) -> Coroutine[Any, Any, List[AllMessageValues]]: ... @overload def _transform_messages( @@ -124,8 +124,7 @@ def _transform_messages( messages: List[AllMessageValues], model: str, is_async: Literal[False] = False, - ) -> List[AllMessageValues]: - ... + ) -> List[AllMessageValues]: ... def _transform_messages( self, messages: List[AllMessageValues], model: str, is_async: bool = False @@ -293,10 +292,10 @@ def transform_response( json_mode=json_mode, ) - mapped_service_tier: Literal[ - "auto", "default", "flex" - ] = self._map_groq_service_tier( - original_service_tier=getattr(model_response, "service_tier") + mapped_service_tier: Literal["auto", "default", "flex"] = ( + self._map_groq_service_tier( + original_service_tier=getattr(model_response, "service_tier") + ) ) setattr(model_response, "service_tier", mapped_service_tier) return model_response diff --git a/litellm/llms/heroku/chat/transformation.py b/litellm/llms/heroku/chat/transformation.py index d95e953636f..fb4cc361189 100644 --- a/litellm/llms/heroku/chat/transformation.py +++ b/litellm/llms/heroku/chat/transformation.py @@ -3,6 +3,7 @@ this is OpenAI compatible - no translation needed / occurs """ + import os from typing import Optional, List, Tuple, Union, Coroutine, Any, Literal, overload @@ -22,8 +23,7 @@ class HerokuChatConfig(OpenAIGPTConfig): @overload def _transform_messages( self, messages: List[AllMessageValues], model: str, is_async: Literal[True] - ) -> Coroutine[Any, Any, List[AllMessageValues]]: - ... + ) -> Coroutine[Any, Any, List[AllMessageValues]]: ... @overload def _transform_messages( @@ -31,8 +31,7 @@ def _transform_messages( messages: List[AllMessageValues], model: str, is_async: Literal[False] = False, - ) -> List[AllMessageValues]: - ... + ) -> List[AllMessageValues]: ... def _transform_messages( self, messages: List[AllMessageValues], model: str, is_async: bool = False diff --git a/litellm/llms/hosted_vllm/chat/transformation.py b/litellm/llms/hosted_vllm/chat/transformation.py index 05db1544a2b..b5a8b25beba 100644 --- a/litellm/llms/hosted_vllm/chat/transformation.py +++ b/litellm/llms/hosted_vllm/chat/transformation.py @@ -121,8 +121,7 @@ def _convert_file_to_video_url( @overload def _transform_messages( self, messages: List[AllMessageValues], model: str, is_async: Literal[True] - ) -> Coroutine[Any, Any, List[AllMessageValues]]: - ... + ) -> Coroutine[Any, Any, List[AllMessageValues]]: ... @overload def _transform_messages( @@ -130,8 +129,7 @@ def _transform_messages( messages: List[AllMessageValues], model: str, is_async: Literal[False] = False, - ) -> List[AllMessageValues]: - ... + ) -> List[AllMessageValues]: ... def _transform_messages( self, messages: List[AllMessageValues], model: str, is_async: bool = False @@ -146,9 +144,14 @@ def _transform_messages( thinking_blocks = message.pop("thinking_blocks", None) # type: ignore if thinking_blocks: new_content: list = [ - {"type": block["type"], "thinking": block.get("thinking", "")} - if block.get("type") == "thinking" - else {"type": block["type"], "data": block.get("data", "")} + ( + { + "type": block["type"], + "thinking": block.get("thinking", ""), + } + if block.get("type") == "thinking" + else {"type": block["type"], "data": block.get("data", "")} + ) for block in thinking_blocks ] existing_content = message.get("content") diff --git a/litellm/llms/huggingface/embedding/transformation.py b/litellm/llms/huggingface/embedding/transformation.py index 03088d6e151..88d42cfcdcc 100644 --- a/litellm/llms/huggingface/embedding/transformation.py +++ b/litellm/llms/huggingface/embedding/transformation.py @@ -40,17 +40,17 @@ class HuggingFaceEmbeddingConfig(BaseConfig): Reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/compat_generate """ - hf_task: Optional[ - hf_tasks - ] = None # litellm-specific param, used to know the api spec to use when calling huggingface api + hf_task: Optional[hf_tasks] = ( + None # litellm-specific param, used to know the api spec to use when calling huggingface api + ) best_of: Optional[int] = None decoder_input_details: Optional[bool] = None details: Optional[bool] = True # enables returning logprobs + best of max_new_tokens: Optional[int] = None repetition_penalty: Optional[float] = None - return_full_text: Optional[ - bool - ] = False # by default don't return the input as part of the output + return_full_text: Optional[bool] = ( + False # by default don't return the input as part of the output + ) seed: Optional[int] = None temperature: Optional[float] = None top_k: Optional[int] = None @@ -120,9 +120,9 @@ def map_openai_params( optional_params["top_p"] = value if param == "n": optional_params["best_of"] = value - optional_params[ - "do_sample" - ] = True # Need to sample if you want best of for hf inference endpoints + optional_params["do_sample"] = ( + True # Need to sample if you want best of for hf inference endpoints + ) if param == "stream": optional_params["stream"] = value if param == "stop": @@ -363,9 +363,9 @@ def validate_environment( "content-type": "application/json", } if api_key is not None: - default_headers[ - "Authorization" - ] = f"Bearer {api_key}" # Huggingface Inference Endpoint default is to accept bearer tokens + default_headers["Authorization"] = ( + f"Bearer {api_key}" # Huggingface Inference Endpoint default is to accept bearer tokens + ) headers = {**headers, **default_headers} return headers diff --git a/litellm/llms/infinity/rerank/transformation.py b/litellm/llms/infinity/rerank/transformation.py index 314bf2f8a36..b9804605454 100644 --- a/litellm/llms/infinity/rerank/transformation.py +++ b/litellm/llms/infinity/rerank/transformation.py @@ -1,5 +1,5 @@ """ -Transformation logic from Cohere's /v1/rerank format to Infinity's `/v1/rerank` format. +Transformation logic from Cohere's /v1/rerank format to Infinity's `/v1/rerank` format. Why separate file? Make it easy to see how transformation works """ diff --git a/litellm/llms/jina_ai/rerank/transformation.py b/litellm/llms/jina_ai/rerank/transformation.py index 48d876f8ea2..e05aaed6c85 100644 --- a/litellm/llms/jina_ai/rerank/transformation.py +++ b/litellm/llms/jina_ai/rerank/transformation.py @@ -1,5 +1,5 @@ """ -Transformation logic from Cohere's /v1/rerank format to Jina AI's `/v1/rerank` format. +Transformation logic from Cohere's /v1/rerank format to Jina AI's `/v1/rerank` format. Why separate file? Make it easy to see how transformation works diff --git a/litellm/llms/lemonade/chat/transformation.py b/litellm/llms/lemonade/chat/transformation.py index a9039388a49..168d51a16d8 100644 --- a/litellm/llms/lemonade/chat/transformation.py +++ b/litellm/llms/lemonade/chat/transformation.py @@ -1,6 +1,7 @@ """ Translate from OpenAI's `/v1/chat/completions` to Lemonade's `/v1/chat/completions` """ + from typing import Any, List, Optional, Tuple, Union import httpx diff --git a/litellm/llms/lemonade/cost_calculator.py b/litellm/llms/lemonade/cost_calculator.py index 2042f6d0d4d..74d62da8759 100644 --- a/litellm/llms/lemonade/cost_calculator.py +++ b/litellm/llms/lemonade/cost_calculator.py @@ -4,6 +4,7 @@ Since Lemonade is a local/self-hosted service, all costs default to 0. This prevents cost calculation errors when using models not in model_prices_and_context_window.json """ + from typing import Tuple from litellm.types.utils import Usage diff --git a/litellm/llms/linkup/__init__.py b/litellm/llms/linkup/__init__.py index c761584b07c..dd242fc5660 100644 --- a/litellm/llms/linkup/__init__.py +++ b/litellm/llms/linkup/__init__.py @@ -1,6 +1,7 @@ """ Linkup API integration module. """ + from litellm.llms.linkup.search.transformation import LinkupSearchConfig __all__ = ["LinkupSearchConfig"] diff --git a/litellm/llms/linkup/search/__init__.py b/litellm/llms/linkup/search/__init__.py index 667c4630238..a8f2c04350c 100644 --- a/litellm/llms/linkup/search/__init__.py +++ b/litellm/llms/linkup/search/__init__.py @@ -1,6 +1,7 @@ """ Linkup Search API module. """ + from litellm.llms.linkup.search.transformation import LinkupSearchConfig __all__ = ["LinkupSearchConfig"] diff --git a/litellm/llms/linkup/search/transformation.py b/litellm/llms/linkup/search/transformation.py index 0554b8ab341..2b17d5642ac 100644 --- a/litellm/llms/linkup/search/transformation.py +++ b/litellm/llms/linkup/search/transformation.py @@ -3,6 +3,7 @@ Linkup API Reference: https://docs.linkup.so/pages/documentation/api-reference/endpoint/post-search """ + from typing import Dict, List, Literal, Optional, TypedDict, Union import httpx diff --git a/litellm/llms/lm_studio/embed/transformation.py b/litellm/llms/lm_studio/embed/transformation.py index 1285550c30f..87f4f6e73d5 100644 --- a/litellm/llms/lm_studio/embed/transformation.py +++ b/litellm/llms/lm_studio/embed/transformation.py @@ -1,5 +1,5 @@ """ -Transformation logic from OpenAI /v1/embeddings format to LM Studio's `/v1/embeddings` format. +Transformation logic from OpenAI /v1/embeddings format to LM Studio's `/v1/embeddings` format. Why separate file? Make it easy to see how transformation works diff --git a/litellm/llms/minimax/chat/transformation.py b/litellm/llms/minimax/chat/transformation.py index 4095e57a8ae..69f228160f6 100644 --- a/litellm/llms/minimax/chat/transformation.py +++ b/litellm/llms/minimax/chat/transformation.py @@ -1,6 +1,7 @@ """ MiniMax OpenAI transformation config - extends OpenAI chat config for MiniMax's OpenAI-compatible API """ + from typing import List, Optional, Tuple import litellm diff --git a/litellm/llms/minimax/messages/transformation.py b/litellm/llms/minimax/messages/transformation.py index 13ed6ad3863..3190a5f5412 100644 --- a/litellm/llms/minimax/messages/transformation.py +++ b/litellm/llms/minimax/messages/transformation.py @@ -1,6 +1,7 @@ """ MiniMax Anthropic transformation config - extends AnthropicConfig for MiniMax's Anthropic-compatible API """ + from typing import Optional import litellm diff --git a/litellm/llms/mistral/chat/transformation.py b/litellm/llms/mistral/chat/transformation.py index 23fbe467fc8..f1ad3708236 100644 --- a/litellm/llms/mistral/chat/transformation.py +++ b/litellm/llms/mistral/chat/transformation.py @@ -344,9 +344,9 @@ def _add_reasoning_system_prompt_if_needed( # Handle both string and list content, preserving original format if isinstance(existing_content, str): # String content - prepend reasoning prompt - new_content: Union[ - str, list - ] = f"{reasoning_prompt}\n\n{existing_content}" + new_content: Union[str, list] = ( + f"{reasoning_prompt}\n\n{existing_content}" + ) elif isinstance(existing_content, list): # List content - prepend reasoning prompt as text block new_content = [ diff --git a/litellm/llms/mistral/ocr/transformation.py b/litellm/llms/mistral/ocr/transformation.py index 3d5e8763027..752df4f349e 100644 --- a/litellm/llms/mistral/ocr/transformation.py +++ b/litellm/llms/mistral/ocr/transformation.py @@ -1,6 +1,7 @@ """ Mistral OCR transformation implementation. """ + from typing import Any, Dict, Optional import httpx diff --git a/litellm/llms/moonshot/chat/transformation.py b/litellm/llms/moonshot/chat/transformation.py index e4d7b5f033b..4eb00fd81d6 100644 --- a/litellm/llms/moonshot/chat/transformation.py +++ b/litellm/llms/moonshot/chat/transformation.py @@ -19,8 +19,7 @@ class MoonshotChatConfig(OpenAIGPTConfig): @overload def _transform_messages( self, messages: List[AllMessageValues], model: str, is_async: Literal[True] - ) -> Coroutine[Any, Any, List[AllMessageValues]]: - ... + ) -> Coroutine[Any, Any, List[AllMessageValues]]: ... @overload def _transform_messages( @@ -28,8 +27,7 @@ def _transform_messages( messages: List[AllMessageValues], model: str, is_async: Literal[False] = False, - ) -> List[AllMessageValues]: - ... + ) -> List[AllMessageValues]: ... def _transform_messages( self, messages: List[AllMessageValues], model: str, is_async: bool = False diff --git a/litellm/llms/novita/chat/transformation.py b/litellm/llms/novita/chat/transformation.py index c05d2d7b2c5..5a64a124ade 100644 --- a/litellm/llms/novita/chat/transformation.py +++ b/litellm/llms/novita/chat/transformation.py @@ -1,5 +1,5 @@ """ -Support for OpenAI's `/v1/chat/completions` endpoint. +Support for OpenAI's `/v1/chat/completions` endpoint. Calls done in OpenAI/openai.py as Novita AI is openai-compatible. diff --git a/litellm/llms/nvidia_nim/chat/transformation.py b/litellm/llms/nvidia_nim/chat/transformation.py index e687229949b..2ef92a90626 100644 --- a/litellm/llms/nvidia_nim/chat/transformation.py +++ b/litellm/llms/nvidia_nim/chat/transformation.py @@ -1,12 +1,13 @@ """ -Nvidia NIM endpoint: https://docs.api.nvidia.com/nim/reference/databricks-dbrx-instruct-infer +Nvidia NIM endpoint: https://docs.api.nvidia.com/nim/reference/databricks-dbrx-instruct-infer -This is OpenAI compatible +This is OpenAI compatible This file only contains param mapping logic API calling is done using the OpenAI SDK with an api_base """ + from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig diff --git a/litellm/llms/nvidia_nim/embed.py b/litellm/llms/nvidia_nim/embed.py index 24c6cc34e4d..61c8e8244e4 100644 --- a/litellm/llms/nvidia_nim/embed.py +++ b/litellm/llms/nvidia_nim/embed.py @@ -1,7 +1,7 @@ """ Nvidia NIM embeddings endpoint: https://docs.api.nvidia.com/nim/reference/nvidia-nv-embedqa-e5-v5-infer -This is OpenAI compatible +This is OpenAI compatible This file only contains param mapping logic diff --git a/litellm/llms/oci/chat/transformation.py b/litellm/llms/oci/chat/transformation.py index b1af7ed2ec3..a972f8666bc 100644 --- a/litellm/llms/oci/chat/transformation.py +++ b/litellm/llms/oci/chat/transformation.py @@ -456,9 +456,7 @@ def _sign_with_manual_credentials( private_key = ( load_private_key_from_str(oci_key_content) if oci_key_content - else load_private_key_from_file(oci_key_file) - if oci_key_file - else None + else load_private_key_from_file(oci_key_file) if oci_key_file else None ) if private_key is None: diff --git a/litellm/llms/ollama/completion/transformation.py b/litellm/llms/ollama/completion/transformation.py index 6a03325e6c7..32981776753 100644 --- a/litellm/llms/ollama/completion/transformation.py +++ b/litellm/llms/ollama/completion/transformation.py @@ -93,9 +93,9 @@ class OllamaConfig(BaseConfig): repeat_penalty: Optional[float] = None temperature: Optional[float] = None seed: Optional[int] = None - stop: Optional[ - list - ] = None # stop is a list based on this - https://github.com/ollama/ollama/pull/442 + stop: Optional[list] = ( + None # stop is a list based on this - https://github.com/ollama/ollama/pull/442 + ) tfs_z: Optional[float] = None num_predict: Optional[int] = None top_k: Optional[int] = None diff --git a/litellm/llms/openai/chat/gpt_transformation.py b/litellm/llms/openai/chat/gpt_transformation.py index c12c6e6ba09..6b7ec4dfb1c 100644 --- a/litellm/llms/openai/chat/gpt_transformation.py +++ b/litellm/llms/openai/chat/gpt_transformation.py @@ -370,10 +370,10 @@ async def _async_transform(): List[OpenAIMessageContentListBlock], message_content ) for i, content_item in enumerate(message_content_types): - message_content_types[ - i - ] = await self._async_transform_content_item( - cast(OpenAIMessageContentListBlock, content_item), + message_content_types[i] = ( + await self._async_transform_content_item( + cast(OpenAIMessageContentListBlock, content_item), + ) ) return messages diff --git a/litellm/llms/openai/chat/guardrail_translation/handler.py b/litellm/llms/openai/chat/guardrail_translation/handler.py index bab4c3b5eb7..fe4026136ec 100644 --- a/litellm/llms/openai/chat/guardrail_translation/handler.py +++ b/litellm/llms/openai/chat/guardrail_translation/handler.py @@ -86,9 +86,9 @@ async def process_input_messages( if tool_calls_to_check: inputs["tool_calls"] = tool_calls_to_check # type: ignore if messages: - inputs[ - "structured_messages" - ] = messages # pass the openai /chat/completions messages to the guardrail, as-is + inputs["structured_messages"] = ( + messages # pass the openai /chat/completions messages to the guardrail, as-is + ) # Pass tools (function definitions) to the guardrail tools = data.get("tools") if tools: diff --git a/litellm/llms/openai/chat/o_series_transformation.py b/litellm/llms/openai/chat/o_series_transformation.py index fe8aec9bc2b..8db7ecf7b3a 100644 --- a/litellm/llms/openai/chat/o_series_transformation.py +++ b/litellm/llms/openai/chat/o_series_transformation.py @@ -1,14 +1,14 @@ """ -Support for o1/o3 model family +Support for o1/o3 model family https://platform.openai.com/docs/guides/reasoning Translations handled by LiteLLM: -- modalities: image => drop param (if user opts in to dropping param) -- role: system ==> translate to role 'user' -- streaming => faked by LiteLLM -- Tools, response_format => drop param (if user opts in to dropping param) -- Logprobs => drop param (if user opts in to dropping param) +- modalities: image => drop param (if user opts in to dropping param) +- role: system ==> translate to role 'user' +- streaming => faked by LiteLLM +- Tools, response_format => drop param (if user opts in to dropping param) +- Logprobs => drop param (if user opts in to dropping param) """ from typing import Any, Coroutine, List, Literal, Optional, Union, cast, overload @@ -141,8 +141,7 @@ def is_model_o_series_model(self, model: str) -> bool: @overload def _transform_messages( self, messages: List[AllMessageValues], model: str, is_async: Literal[True] - ) -> Coroutine[Any, Any, List[AllMessageValues]]: - ... + ) -> Coroutine[Any, Any, List[AllMessageValues]]: ... @overload def _transform_messages( @@ -150,8 +149,7 @@ def _transform_messages( messages: List[AllMessageValues], model: str, is_async: Literal[False] = False, - ) -> List[AllMessageValues]: - ... + ) -> List[AllMessageValues]: ... def _transform_messages( self, messages: List[AllMessageValues], model: str, is_async: bool = False diff --git a/litellm/llms/openai/common_utils.py b/litellm/llms/openai/common_utils.py index 35723ccd637..381f215a13f 100644 --- a/litellm/llms/openai/common_utils.py +++ b/litellm/llms/openai/common_utils.py @@ -201,7 +201,7 @@ def get_openai_client_cache_key( @staticmethod def get_openai_client_initialization_param_fields( - client_type: Literal["openai", "azure"] + client_type: Literal["openai", "azure"], ) -> Tuple[str, ...]: """Returns a tuple of fields that are used to initialize the OpenAI client""" if client_type == "openai": @@ -227,9 +227,9 @@ def _get_async_http_client( return httpx.AsyncClient( verify=ssl_config, transport=AsyncHTTPHandler._create_async_transport( - ssl_context=ssl_config - if isinstance(ssl_config, ssl.SSLContext) - else None, + ssl_context=( + ssl_config if isinstance(ssl_config, ssl.SSLContext) else None + ), ssl_verify=ssl_config if isinstance(ssl_config, bool) else None, shared_session=shared_session, ), diff --git a/litellm/llms/openai/completion/transformation.py b/litellm/llms/openai/completion/transformation.py index 44a4949d455..77dc0b54fe0 100644 --- a/litellm/llms/openai/completion/transformation.py +++ b/litellm/llms/openai/completion/transformation.py @@ -111,9 +111,9 @@ def convert_to_chat_model_response_object( if "model" in response_object: model_response_object.model = response_object["model"] - model_response_object._hidden_params[ - "original_response" - ] = response_object # track original response, if users make a litellm.text_completion() request, we can return the original response + model_response_object._hidden_params["original_response"] = ( + response_object # track original response, if users make a litellm.text_completion() request, we can return the original response + ) return model_response_object except Exception as e: raise e diff --git a/litellm/llms/openai/fine_tuning/handler.py b/litellm/llms/openai/fine_tuning/handler.py index 9804ff3539e..ec83e20a3b6 100644 --- a/litellm/llms/openai/fine_tuning/handler.py +++ b/litellm/llms/openai/fine_tuning/handler.py @@ -28,7 +28,14 @@ def get_openai_client( _is_async: bool = False, api_version: Optional[str] = None, litellm_params: Optional[dict] = None, - ) -> Optional[Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI,]]: + ) -> Optional[ + Union[ + OpenAI, + AsyncOpenAI, + AzureOpenAI, + AsyncAzureOpenAI, + ] + ]: received_args = locals() openai_client: Optional[ Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI] diff --git a/litellm/llms/openai/openai.py b/litellm/llms/openai/openai.py index be542677480..2ac4974bb08 100644 --- a/litellm/llms/openai/openai.py +++ b/litellm/llms/openai/openai.py @@ -562,9 +562,9 @@ async def _call_agentic_completion_hooks_openai( kwargs_with_provider = ( litellm_params.copy() if litellm_params else {} ) - kwargs_with_provider[ - "custom_llm_provider" - ] = custom_llm_provider + kwargs_with_provider["custom_llm_provider"] = ( + custom_llm_provider + ) # For OpenAI Chat Completions, use the chat completion agentic loop method agentic_response = ( diff --git a/litellm/llms/openai/transcriptions/whisper_transformation.py b/litellm/llms/openai/transcriptions/whisper_transformation.py index 1a7f47ae56e..fa507e1bc26 100644 --- a/litellm/llms/openai/transcriptions/whisper_transformation.py +++ b/litellm/llms/openai/transcriptions/whisper_transformation.py @@ -110,9 +110,9 @@ def transform_audio_transcription_request( if "response_format" not in data or ( data["response_format"] == "text" or data["response_format"] == "json" ): - data[ - "response_format" - ] = "verbose_json" # ensures 'duration' is received - used for cost calculation + data["response_format"] = ( + "verbose_json" # ensures 'duration' is received - used for cost calculation + ) return AudioTranscriptionRequestData( data=data, diff --git a/litellm/llms/openai_like/dynamic_config.py b/litellm/llms/openai_like/dynamic_config.py index 3d66556e522..fac453447fa 100644 --- a/litellm/llms/openai_like/dynamic_config.py +++ b/litellm/llms/openai_like/dynamic_config.py @@ -28,8 +28,7 @@ class JSONProviderConfig(base_class): # type: ignore[valid-type,misc] @overload def _transform_messages( self, messages: List[AllMessageValues], model: str, is_async: Literal[True] - ) -> Coroutine[Any, Any, List[AllMessageValues]]: - ... + ) -> Coroutine[Any, Any, List[AllMessageValues]]: ... @overload def _transform_messages( @@ -37,8 +36,7 @@ def _transform_messages( messages: List[AllMessageValues], model: str, is_async: Literal[False] = False, - ) -> List[AllMessageValues]: - ... + ) -> List[AllMessageValues]: ... def _transform_messages( self, messages: List[AllMessageValues], model: str, is_async: bool = False diff --git a/litellm/llms/openrouter/chat/transformation.py b/litellm/llms/openrouter/chat/transformation.py index 86e63fd0c41..0d7850e8c74 100644 --- a/litellm/llms/openrouter/chat/transformation.py +++ b/litellm/llms/openrouter/chat/transformation.py @@ -70,9 +70,9 @@ def map_openai_params( extra_body["models"] = models if route is not None: extra_body["route"] = route - mapped_openai_params[ - "extra_body" - ] = extra_body # openai client supports `extra_body` param + mapped_openai_params["extra_body"] = ( + extra_body # openai client supports `extra_body` param + ) return mapped_openai_params def _supports_cache_control_in_content(self, model: str) -> bool: diff --git a/litellm/llms/openrouter/embedding/transformation.py b/litellm/llms/openrouter/embedding/transformation.py index d1d0e911d16..8b836e8e5d2 100644 --- a/litellm/llms/openrouter/embedding/transformation.py +++ b/litellm/llms/openrouter/embedding/transformation.py @@ -6,6 +6,7 @@ Docs: https://openrouter.ai/docs """ + from typing import TYPE_CHECKING, Any, Optional import httpx diff --git a/litellm/llms/openrouter/image_edit/transformation.py b/litellm/llms/openrouter/image_edit/transformation.py index 9e5e313aad0..fcf066dd5ac 100644 --- a/litellm/llms/openrouter/image_edit/transformation.py +++ b/litellm/llms/openrouter/image_edit/transformation.py @@ -97,9 +97,9 @@ def map_openai_params( if key == "size": if "image_config" not in mapped_params: mapped_params["image_config"] = {} - mapped_params["image_config"][ - "aspect_ratio" - ] = self._map_size_to_aspect_ratio(cast(str, value)) + mapped_params["image_config"]["aspect_ratio"] = ( + self._map_size_to_aspect_ratio(cast(str, value)) + ) elif key == "quality": image_size = self._map_quality_to_image_size(cast(str, value)) if image_size: diff --git a/litellm/llms/openrouter/image_generation/transformation.py b/litellm/llms/openrouter/image_generation/transformation.py index a55716a5e50..9c2293eb3f1 100644 --- a/litellm/llms/openrouter/image_generation/transformation.py +++ b/litellm/llms/openrouter/image_generation/transformation.py @@ -49,7 +49,6 @@ ) from litellm.llms.openrouter.common_utils import OpenRouterException - if TYPE_CHECKING: from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj else: diff --git a/litellm/llms/ovhcloud/chat/transformation.py b/litellm/llms/ovhcloud/chat/transformation.py index 1416b782f17..342ad700e00 100644 --- a/litellm/llms/ovhcloud/chat/transformation.py +++ b/litellm/llms/ovhcloud/chat/transformation.py @@ -4,6 +4,7 @@ Our unified API follows the OpenAI standard. More information on our website: https://endpoints.ai.cloud.ovh.net """ + from typing import Optional, Union, List import httpx diff --git a/litellm/llms/ovhcloud/embedding/transformation.py b/litellm/llms/ovhcloud/embedding/transformation.py index 38e88da125f..6b5c43e2d06 100644 --- a/litellm/llms/ovhcloud/embedding/transformation.py +++ b/litellm/llms/ovhcloud/embedding/transformation.py @@ -2,6 +2,7 @@ This is OpenAI compatible - no transformation is applied """ + from typing import List, Optional, Union import httpx diff --git a/litellm/llms/parallel_ai/search/__init__.py b/litellm/llms/parallel_ai/search/__init__.py index b96914f13dd..23c4b9751d7 100644 --- a/litellm/llms/parallel_ai/search/__init__.py +++ b/litellm/llms/parallel_ai/search/__init__.py @@ -1,6 +1,7 @@ """ Parallel AI Search API module. """ + from litellm.llms.parallel_ai.search.transformation import ParallelAISearchConfig __all__ = ["ParallelAISearchConfig"] diff --git a/litellm/llms/parallel_ai/search/transformation.py b/litellm/llms/parallel_ai/search/transformation.py index e19bc5400d1..12d570f1733 100644 --- a/litellm/llms/parallel_ai/search/transformation.py +++ b/litellm/llms/parallel_ai/search/transformation.py @@ -3,6 +3,7 @@ Parallel AI API Reference: https://docs.parallel.ai/api-reference/search-and-extract-api-beta/search """ + from typing import Dict, List, Optional, TypedDict, Union import httpx diff --git a/litellm/llms/perplexity/search/transformation.py b/litellm/llms/perplexity/search/transformation.py index f89d5565498..ea96f87957c 100644 --- a/litellm/llms/perplexity/search/transformation.py +++ b/litellm/llms/perplexity/search/transformation.py @@ -1,6 +1,7 @@ """ Calls Perplexity's /search endpoint to search the web. """ + from typing import Dict, List, Optional, TypedDict, Union import httpx diff --git a/litellm/llms/petals/completion/transformation.py b/litellm/llms/petals/completion/transformation.py index 24910cba8f3..d50afc4625a 100644 --- a/litellm/llms/petals/completion/transformation.py +++ b/litellm/llms/petals/completion/transformation.py @@ -37,9 +37,9 @@ class PetalsConfig(BaseConfig): """ max_length: Optional[int] = None - max_new_tokens: Optional[ - int - ] = litellm.max_tokens # petals requires max tokens to be set + max_new_tokens: Optional[int] = ( + litellm.max_tokens + ) # petals requires max tokens to be set do_sample: Optional[bool] = None temperature: Optional[float] = None top_k: Optional[int] = None diff --git a/litellm/llms/predibase/chat/transformation.py b/litellm/llms/predibase/chat/transformation.py index 9fbb9d6c9e2..0569318062f 100644 --- a/litellm/llms/predibase/chat/transformation.py +++ b/litellm/llms/predibase/chat/transformation.py @@ -31,9 +31,9 @@ class PredibaseConfig(BaseConfig): DEFAULT_MAX_TOKENS # openai default - requests hang if max_new_tokens not given ) repetition_penalty: Optional[float] = None - return_full_text: Optional[ - bool - ] = False # by default don't return the input as part of the output + return_full_text: Optional[bool] = ( + False # by default don't return the input as part of the output + ) seed: Optional[int] = None stop: Optional[List[str]] = None temperature: Optional[float] = None @@ -100,9 +100,9 @@ def map_openai_params( optional_params["top_p"] = value if param == "n": optional_params["best_of"] = value - optional_params[ - "do_sample" - ] = True # Need to sample if you want best of for hf inference endpoints + optional_params["do_sample"] = ( + True # Need to sample if you want best of for hf inference endpoints + ) if param == "stream": optional_params["stream"] = value if param == "stop": diff --git a/litellm/llms/runwayml/text_to_speech/__init__.py b/litellm/llms/runwayml/text_to_speech/__init__.py index 98337a8321a..cf6e9071bf0 100644 --- a/litellm/llms/runwayml/text_to_speech/__init__.py +++ b/litellm/llms/runwayml/text_to_speech/__init__.py @@ -1,4 +1,5 @@ """RunwayML Text-to-Speech implementation.""" + from .transformation import RunwayMLTextToSpeechConfig __all__ = ["RunwayMLTextToSpeechConfig"] diff --git a/litellm/llms/runwayml/text_to_speech/transformation.py b/litellm/llms/runwayml/text_to_speech/transformation.py index dfcb92bc68b..314a538f7c5 100644 --- a/litellm/llms/runwayml/text_to_speech/transformation.py +++ b/litellm/llms/runwayml/text_to_speech/transformation.py @@ -3,6 +3,7 @@ Maps OpenAI TTS spec to RunwayML Text-to-Speech API """ + import asyncio import time from typing import TYPE_CHECKING, Any, Coroutine, Dict, Optional, Tuple, Union diff --git a/litellm/llms/sagemaker/completion/transformation.py b/litellm/llms/sagemaker/completion/transformation.py index dd7cb603905..8fd32bc4460 100644 --- a/litellm/llms/sagemaker/completion/transformation.py +++ b/litellm/llms/sagemaker/completion/transformation.py @@ -1,7 +1,7 @@ """ Translate from OpenAI's `/v1/chat/completions` to Sagemaker's `/invoke` -In the Huggingface TGI format. +In the Huggingface TGI format. """ import json @@ -100,9 +100,9 @@ def map_openai_params( optional_params["top_p"] = value if param == "n": optional_params["best_of"] = value - optional_params[ - "do_sample" - ] = True # Need to sample if you want best of for hf inference endpoints + optional_params["do_sample"] = ( + True # Need to sample if you want best of for hf inference endpoints + ) if param == "stream": optional_params["stream"] = value if param == "stop": diff --git a/litellm/llms/sagemaker/embedding/transformation.py b/litellm/llms/sagemaker/embedding/transformation.py index 04430171187..09bdb9295e7 100644 --- a/litellm/llms/sagemaker/embedding/transformation.py +++ b/litellm/llms/sagemaker/embedding/transformation.py @@ -1,7 +1,7 @@ """ Translate from OpenAI's `/v1/embeddings` to Sagemaker's `/invoke` -In the Huggingface TGI format. +In the Huggingface TGI format. """ from typing import TYPE_CHECKING, Any, List, Optional, Union diff --git a/litellm/llms/sambanova/chat.py b/litellm/llms/sambanova/chat.py index 3c4003f72e9..2120256f918 100644 --- a/litellm/llms/sambanova/chat.py +++ b/litellm/llms/sambanova/chat.py @@ -100,8 +100,7 @@ def map_openai_params( @overload def _transform_messages( self, messages: List[AllMessageValues], model: str, is_async: Literal[True] - ) -> Coroutine[Any, Any, List[AllMessageValues]]: - ... + ) -> Coroutine[Any, Any, List[AllMessageValues]]: ... @overload def _transform_messages( @@ -109,8 +108,7 @@ def _transform_messages( messages: List[AllMessageValues], model: str, is_async: Literal[False] = False, - ) -> List[AllMessageValues]: - ... + ) -> List[AllMessageValues]: ... def _transform_messages( self, messages: List[AllMessageValues], model: str, is_async: bool = False diff --git a/litellm/llms/sambanova/embedding/transformation.py b/litellm/llms/sambanova/embedding/transformation.py index eca44c7c039..5c88188b84e 100644 --- a/litellm/llms/sambanova/embedding/transformation.py +++ b/litellm/llms/sambanova/embedding/transformation.py @@ -2,6 +2,7 @@ This is OpenAI compatible - no transformation is applied """ + from typing import List, Optional, Union import httpx diff --git a/litellm/llms/sap/chat/transformation.py b/litellm/llms/sap/chat/transformation.py index 7f6bab4a1d5..0d6d387122d 100755 --- a/litellm/llms/sap/chat/transformation.py +++ b/litellm/llms/sap/chat/transformation.py @@ -1,6 +1,7 @@ """ Translate from OpenAI's `/v1/chat/completions` to SAP Generative AI Hub's Orchestration Service`v2/completion` """ + from typing import ( List, Optional, diff --git a/litellm/llms/searchapi/search/__init__.py b/litellm/llms/searchapi/search/__init__.py index 783238c9f73..19a2e652416 100644 --- a/litellm/llms/searchapi/search/__init__.py +++ b/litellm/llms/searchapi/search/__init__.py @@ -1,4 +1,5 @@ """SearchAPI.io search integration for LiteLLM.""" + from litellm.llms.searchapi.search.transformation import SearchAPIConfig __all__ = ["SearchAPIConfig"] diff --git a/litellm/llms/searchapi/search/transformation.py b/litellm/llms/searchapi/search/transformation.py index 92b2814018d..c04e1377f9c 100644 --- a/litellm/llms/searchapi/search/transformation.py +++ b/litellm/llms/searchapi/search/transformation.py @@ -3,6 +3,7 @@ SearchAPI.io API Reference: https://www.searchapi.io/docs/google """ + from typing import Dict, List, Literal, Optional, TypedDict, Union, cast from urllib.parse import urlencode diff --git a/litellm/llms/searxng/__init__.py b/litellm/llms/searxng/__init__.py index f7ad1978c76..320cebb06f2 100644 --- a/litellm/llms/searxng/__init__.py +++ b/litellm/llms/searxng/__init__.py @@ -1,6 +1,7 @@ """ SearXNG API integration module. """ + from litellm.llms.searxng.search.transformation import SearXNGSearchConfig __all__ = ["SearXNGSearchConfig"] diff --git a/litellm/llms/searxng/search/__init__.py b/litellm/llms/searxng/search/__init__.py index 88ac5dc629b..b52b323a2eb 100644 --- a/litellm/llms/searxng/search/__init__.py +++ b/litellm/llms/searxng/search/__init__.py @@ -1,6 +1,7 @@ """ SearXNG Search API module. """ + from litellm.llms.searxng.search.transformation import SearXNGSearchConfig __all__ = ["SearXNGSearchConfig"] diff --git a/litellm/llms/searxng/search/transformation.py b/litellm/llms/searxng/search/transformation.py index bbd3b765010..ee6f3895721 100644 --- a/litellm/llms/searxng/search/transformation.py +++ b/litellm/llms/searxng/search/transformation.py @@ -3,6 +3,7 @@ SearXNG API Reference: https://docs.searxng.org/dev/search_api.html """ + from typing import Dict, List, Optional, TypedDict, Union import httpx diff --git a/litellm/llms/serper/search/__init__.py b/litellm/llms/serper/search/__init__.py index cdb4bd4b53f..3bf59ee8d6b 100644 --- a/litellm/llms/serper/search/__init__.py +++ b/litellm/llms/serper/search/__init__.py @@ -1,6 +1,7 @@ """ Serper Search API module. """ + from litellm.llms.serper.search.transformation import SerperSearchConfig __all__ = ["SerperSearchConfig"] diff --git a/litellm/llms/serper/search/transformation.py b/litellm/llms/serper/search/transformation.py index 34e726dc77d..0daccbe652b 100644 --- a/litellm/llms/serper/search/transformation.py +++ b/litellm/llms/serper/search/transformation.py @@ -3,6 +3,7 @@ Serper API Reference: https://serper.dev """ + from typing import Dict, List, Optional, TypedDict, Union import httpx diff --git a/litellm/llms/snowflake/chat/transformation.py b/litellm/llms/snowflake/chat/transformation.py index 3e590680a75..23bb6f44757 100644 --- a/litellm/llms/snowflake/chat/transformation.py +++ b/litellm/llms/snowflake/chat/transformation.py @@ -14,7 +14,6 @@ from ..utils import SnowflakeBaseConfig - if TYPE_CHECKING: from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj diff --git a/litellm/llms/stability/image_generation/transformation.py b/litellm/llms/stability/image_generation/transformation.py index ac63548bf56..c8c2a16fcd1 100644 --- a/litellm/llms/stability/image_generation/transformation.py +++ b/litellm/llms/stability/image_generation/transformation.py @@ -80,9 +80,9 @@ def map_openai_params( if k in supported_params: # Map size to aspect_ratio if k == "size" and v in OPENAI_SIZE_TO_STABILITY_ASPECT_RATIO: - optional_params[ - "aspect_ratio" - ] = OPENAI_SIZE_TO_STABILITY_ASPECT_RATIO[v] + optional_params["aspect_ratio"] = ( + OPENAI_SIZE_TO_STABILITY_ASPECT_RATIO[v] + ) elif k == "n": # Store n for later, but don't pass to Stability optional_params["_n"] = v diff --git a/litellm/llms/tavily/search/__init__.py b/litellm/llms/tavily/search/__init__.py index 6e3fe1163c7..38ca5b60ae0 100644 --- a/litellm/llms/tavily/search/__init__.py +++ b/litellm/llms/tavily/search/__init__.py @@ -1,6 +1,7 @@ """ Tavily Search API module. """ + from litellm.llms.tavily.search.transformation import TavilySearchConfig __all__ = ["TavilySearchConfig"] diff --git a/litellm/llms/tavily/search/transformation.py b/litellm/llms/tavily/search/transformation.py index 1228433b539..ec96db96f36 100644 --- a/litellm/llms/tavily/search/transformation.py +++ b/litellm/llms/tavily/search/transformation.py @@ -3,6 +3,7 @@ Tavily API Reference: https://docs.tavily.com/documentation/api-reference/endpoint/search """ + from typing import Dict, List, Optional, TypedDict, Union import httpx @@ -32,7 +33,9 @@ class TavilySearchRequest(_TavilySearchRequestRequired, total=False): include_domains: List[str] # Optional - list of domains to include (max 300) exclude_domains: List[str] # Optional - list of domains to exclude (max 150) topic: str # Optional - category of search ('general', 'news', 'finance'), default 'general' - search_depth: str # Optional - depth of search ('basic', 'advanced'), default 'basic' + search_depth: ( + str # Optional - depth of search ('basic', 'advanced'), default 'basic' + ) include_answer: Union[bool, str] # Optional - include LLM-generated answer include_raw_content: Union[bool, str] # Optional - include raw HTML content include_images: bool # Optional - perform image search diff --git a/litellm/llms/together_ai/chat.py b/litellm/llms/together_ai/chat.py index e8a784d2779..706995eaba7 100644 --- a/litellm/llms/together_ai/chat.py +++ b/litellm/llms/together_ai/chat.py @@ -1,5 +1,5 @@ """ -Support for OpenAI's `/v1/chat/completions` endpoint. +Support for OpenAI's `/v1/chat/completions` endpoint. Calls done in OpenAI/openai.py as TogetherAI is openai-compatible. diff --git a/litellm/llms/together_ai/embed.py b/litellm/llms/together_ai/embed.py index 577df0256cc..6a39b94acfc 100644 --- a/litellm/llms/together_ai/embed.py +++ b/litellm/llms/together_ai/embed.py @@ -1,5 +1,5 @@ """ -Support for OpenAI's `/v1/embeddings` endpoint. +Support for OpenAI's `/v1/embeddings` endpoint. Calls done in OpenAI/openai.py as TogetherAI is openai-compatible. diff --git a/litellm/llms/together_ai/rerank/transformation.py b/litellm/llms/together_ai/rerank/transformation.py index 63b593dfe42..f4d642bd25a 100644 --- a/litellm/llms/together_ai/rerank/transformation.py +++ b/litellm/llms/together_ai/rerank/transformation.py @@ -1,5 +1,5 @@ """ -Transformation logic from Cohere's /v1/rerank format to Together AI's `/v1/rerank` format. +Transformation logic from Cohere's /v1/rerank format to Together AI's `/v1/rerank` format. Why separate file? Make it easy to see how transformation works """ diff --git a/litellm/llms/vercel_ai_gateway/chat/transformation.py b/litellm/llms/vercel_ai_gateway/chat/transformation.py index 81a1688b909..fda1c4a77cb 100644 --- a/litellm/llms/vercel_ai_gateway/chat/transformation.py +++ b/litellm/llms/vercel_ai_gateway/chat/transformation.py @@ -63,9 +63,9 @@ def map_openai_params( if provider_options is not None: extra_body["providerOptions"] = provider_options - mapped_openai_params[ - "extra_body" - ] = extra_body # openai client supports `extra_body` param + mapped_openai_params["extra_body"] = ( + extra_body # openai client supports `extra_body` param + ) return mapped_openai_params def transform_request( diff --git a/litellm/llms/vertex_ai/batches/handler.py b/litellm/llms/vertex_ai/batches/handler.py index 2cb02942061..028e02eb0ca 100644 --- a/litellm/llms/vertex_ai/batches/handler.py +++ b/litellm/llms/vertex_ai/batches/handler.py @@ -73,8 +73,10 @@ def create_batch( "Authorization": f"Bearer {access_token}", } - vertex_batch_request: VertexAIBatchPredictionJob = VertexAIBatchTransformation.transform_openai_batch_request_to_vertex_ai_batch_request( - request=create_batch_data + vertex_batch_request: VertexAIBatchPredictionJob = ( + VertexAIBatchTransformation.transform_openai_batch_request_to_vertex_ai_batch_request( + request=create_batch_data + ) ) if _is_async is True: diff --git a/litellm/llms/vertex_ai/context_caching/transformation.py b/litellm/llms/vertex_ai/context_caching/transformation.py index 950edbeb478..ef71357d1d2 100644 --- a/litellm/llms/vertex_ai/context_caching/transformation.py +++ b/litellm/llms/vertex_ai/context_caching/transformation.py @@ -1,5 +1,5 @@ """ -Transformation logic for context caching. +Transformation logic for context caching. Why separate file? Make it easy to see how transformation works """ @@ -19,7 +19,7 @@ def get_first_continuous_block_idx( - filtered_messages: List[Tuple[int, AllMessageValues]] # (idx, message) + filtered_messages: List[Tuple[int, AllMessageValues]], # (idx, message) ) -> int: """ Find the array index that ends the first continuous sequence of message blocks. diff --git a/litellm/llms/vertex_ai/context_caching/vertex_ai_context_caching.py b/litellm/llms/vertex_ai/context_caching/vertex_ai_context_caching.py index b677cf3b1ec..7316f81c9f7 100644 --- a/litellm/llms/vertex_ai/context_caching/vertex_ai_context_caching.py +++ b/litellm/llms/vertex_ai/context_caching/vertex_ai_context_caching.py @@ -93,9 +93,9 @@ def _get_token_and_url_context_caching( model=model, vertex_project=vertex_project, vertex_location=vertex_location, - vertex_api_version="v1beta1" - if custom_llm_provider == "vertex_ai_beta" - else "v1", + vertex_api_version=( + "v1beta1" if custom_llm_provider == "vertex_ai_beta" else "v1" + ), ) def check_cache( diff --git a/litellm/llms/vertex_ai/fine_tuning/handler.py b/litellm/llms/vertex_ai/fine_tuning/handler.py index 77891e245cd..a5971de0e94 100644 --- a/litellm/llms/vertex_ai/fine_tuning/handler.py +++ b/litellm/llms/vertex_ai/fine_tuning/handler.py @@ -65,9 +65,9 @@ def convert_openai_request_to_vertex( ) if create_fine_tuning_job_data.validation_file: - supervised_tuning_spec[ - "validation_dataset" - ] = create_fine_tuning_job_data.validation_file + supervised_tuning_spec["validation_dataset"] = ( + create_fine_tuning_job_data.validation_file + ) _vertex_hyperparameters = ( self._transform_openai_hyperparameters_to_vertex_hyperparameters( @@ -349,9 +349,9 @@ async def pass_through_vertex_ai_POST_request( elif "cachedContents" in request_route: _model = request_data.get("model") if _model is not None and "/publishers/google/models/" not in _model: - request_data[ - "model" - ] = f"projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{_model}" + request_data["model"] = ( + f"projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{_model}" + ) url = f"{base_url}/v1beta1/projects/{vertex_project}/locations/{vertex_location}{request_route}" else: diff --git a/litellm/llms/vertex_ai/gemini/transformation.py b/litellm/llms/vertex_ai/gemini/transformation.py index 7945c44d44c..a3e2b8b43e5 100644 --- a/litellm/llms/vertex_ai/gemini/transformation.py +++ b/litellm/llms/vertex_ai/gemini/transformation.py @@ -3,6 +3,7 @@ Why separate file? Make it easy to see how transformation works """ + import json import os from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union, cast @@ -612,16 +613,14 @@ def _gemini_convert_messages_with_history( # noqa: PLR0915 contents.append(ContentType(role="user", parts=tool_call_responses)) if len(contents) == 0: - verbose_logger.warning( - """ + verbose_logger.warning(""" No contents in messages. Contents are required. See https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.publishers.models/generateContent#request-body. If the original request did not comply to OpenAI API requirements it should have failed by now, but LiteLLM does not check for missing messages. Setting an empty content to prevent an 400 error. Relevant Issue - https://github.com/BerriAI/litellm/issues/9733 - """ - ) + """) contents.append(ContentType(role="user", parts=[PartType(text=" ")])) return contents except Exception as e: diff --git a/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py b/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py index 36f51c5b2f5..13896d136c8 100644 --- a/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py +++ b/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py @@ -500,9 +500,9 @@ def _map_function( # noqa: PLR0915 value = _remove_strict_from_schema(value) for tool in value: - openai_function_object: Optional[ - ChatCompletionToolParamFunctionChunk - ] = None + openai_function_object: Optional[ChatCompletionToolParamFunctionChunk] = ( + None + ) if "function" in tool: # tools list _openai_function_object = ChatCompletionToolParamFunctionChunk( # type: ignore **tool["function"] @@ -634,15 +634,15 @@ def _map_function( # noqa: PLR0915 _tools_list.append(search_tool) if googleSearchRetrieval is not None: retrieval_tool = Tools() - retrieval_tool[ - VertexToolName.GOOGLE_SEARCH_RETRIEVAL.value - ] = googleSearchRetrieval + retrieval_tool[VertexToolName.GOOGLE_SEARCH_RETRIEVAL.value] = ( + googleSearchRetrieval + ) _tools_list.append(retrieval_tool) if enterpriseWebSearch is not None: enterprise_tool = Tools() - enterprise_tool[ - VertexToolName.ENTERPRISE_WEB_SEARCH.value - ] = enterpriseWebSearch + enterprise_tool[VertexToolName.ENTERPRISE_WEB_SEARCH.value] = ( + enterpriseWebSearch + ) _tools_list.append(enterprise_tool) if code_execution is not None: code_tool = Tools() @@ -1089,16 +1089,16 @@ def map_openai_params( # noqa: PLR0915 param_description="thinking_budget", ) if VertexGeminiConfig._is_gemini_3_or_newer(model): - optional_params[ - "thinkingConfig" - ] = VertexGeminiConfig._map_reasoning_effort_to_thinking_level( - effort_value, model + optional_params["thinkingConfig"] = ( + VertexGeminiConfig._map_reasoning_effort_to_thinking_level( + effort_value, model + ) ) else: - optional_params[ - "thinkingConfig" - ] = VertexGeminiConfig._map_reasoning_effort_to_thinking_budget( - effort_value, model + optional_params["thinkingConfig"] = ( + VertexGeminiConfig._map_reasoning_effort_to_thinking_budget( + effort_value, model + ) ) elif param == "thinking": # Validate no conflict with thinking_level @@ -1107,11 +1107,11 @@ def map_openai_params( # noqa: PLR0915 param_name="thinking", param_description="thinking_budget", ) - optional_params[ - "thinkingConfig" - ] = VertexGeminiConfig._map_thinking_param( - cast(AnthropicThinkingParam, value), - model=model, + optional_params["thinkingConfig"] = ( + VertexGeminiConfig._map_thinking_param( + cast(AnthropicThinkingParam, value), + model=model, + ) ) elif param == "modalities" and isinstance(value, list): response_modalities = self.map_response_modalities(value) @@ -1533,10 +1533,10 @@ def _transform_parts( _tool_response_chunk["provider_specific_fields"] = { # type: ignore "thought_signature": thought_signature } - _tool_response_chunk[ - "id" - ] = _encode_tool_call_id_with_signature( - _tool_response_chunk["id"] or "", thought_signature + _tool_response_chunk["id"] = ( + _encode_tool_call_id_with_signature( + _tool_response_chunk["id"] or "", thought_signature + ) ) _tools.append(_tool_response_chunk) cumulative_tool_call_idx += 1 @@ -2383,28 +2383,28 @@ def _transform_google_generate_content_to_openai_model_response( ## ADD METADATA TO RESPONSE ## setattr(model_response, "vertex_ai_grounding_metadata", grounding_metadata) - model_response._hidden_params[ - "vertex_ai_grounding_metadata" - ] = grounding_metadata + model_response._hidden_params["vertex_ai_grounding_metadata"] = ( + grounding_metadata + ) setattr( model_response, "vertex_ai_url_context_metadata", url_context_metadata ) - model_response._hidden_params[ - "vertex_ai_url_context_metadata" - ] = url_context_metadata + model_response._hidden_params["vertex_ai_url_context_metadata"] = ( + url_context_metadata + ) setattr(model_response, "vertex_ai_safety_results", safety_ratings) - model_response._hidden_params[ - "vertex_ai_safety_results" - ] = safety_ratings # older approach - maintaining to prevent regressions + model_response._hidden_params["vertex_ai_safety_results"] = ( + safety_ratings # older approach - maintaining to prevent regressions + ) ## ADD CITATION METADATA ## setattr(model_response, "vertex_ai_citation_metadata", citation_metadata) - model_response._hidden_params[ - "vertex_ai_citation_metadata" - ] = citation_metadata # older approach - maintaining to prevent regressions + model_response._hidden_params["vertex_ai_citation_metadata"] = ( + citation_metadata # older approach - maintaining to prevent regressions + ) ## ADD TRAFFIC TYPE ## traffic_type = completion_response.get("usageMetadata", {}).get( diff --git a/litellm/llms/vertex_ai/gemini_embeddings/batch_embed_content_transformation.py b/litellm/llms/vertex_ai/gemini_embeddings/batch_embed_content_transformation.py index 08831a8215f..830bc504bab 100644 --- a/litellm/llms/vertex_ai/gemini_embeddings/batch_embed_content_transformation.py +++ b/litellm/llms/vertex_ai/gemini_embeddings/batch_embed_content_transformation.py @@ -1,5 +1,5 @@ """ -Transformation logic from OpenAI /v1/embeddings format to Google AI Studio /batchEmbedContents format. +Transformation logic from OpenAI /v1/embeddings format to Google AI Studio /batchEmbedContents format. Why separate file? Make it easy to see how transformation works """ diff --git a/litellm/llms/vertex_ai/image_generation/vertex_gemini_transformation.py b/litellm/llms/vertex_ai/image_generation/vertex_gemini_transformation.py index 98e02743bd2..f4bda8d1bed 100644 --- a/litellm/llms/vertex_ai/image_generation/vertex_gemini_transformation.py +++ b/litellm/llms/vertex_ai/image_generation/vertex_gemini_transformation.py @@ -313,11 +313,11 @@ def transform_image_generation_response( ImageObject( b64_json=inline_data["data"], url=None, - provider_specific_fields={ - "thought_signature": thought_sig - } - if thought_sig - else None, + provider_specific_fields=( + {"thought_signature": thought_sig} + if thought_sig + else None + ), ) ) diff --git a/litellm/llms/vertex_ai/ocr/__init__.py b/litellm/llms/vertex_ai/ocr/__init__.py index 15da24f3089..915fbd49030 100644 --- a/litellm/llms/vertex_ai/ocr/__init__.py +++ b/litellm/llms/vertex_ai/ocr/__init__.py @@ -1,4 +1,5 @@ """Vertex AI OCR module.""" + from .transformation import VertexAIOCRConfig __all__ = ["VertexAIOCRConfig"] diff --git a/litellm/llms/vertex_ai/ocr/deepseek_transformation.py b/litellm/llms/vertex_ai/ocr/deepseek_transformation.py index 953bb51fd1c..516ee03ba55 100644 --- a/litellm/llms/vertex_ai/ocr/deepseek_transformation.py +++ b/litellm/llms/vertex_ai/ocr/deepseek_transformation.py @@ -1,6 +1,7 @@ """ Vertex AI DeepSeek OCR transformation implementation. """ + import json from typing import TYPE_CHECKING, Any, Dict, Optional @@ -314,9 +315,11 @@ def transform_ocr_response( "pages": [ { "index": 0, - "markdown": content - if isinstance(content, str) - else json.dumps(content), + "markdown": ( + content + if isinstance(content, str) + else json.dumps(content) + ), } ], "model": ocr_data.get("model", model), diff --git a/litellm/llms/vertex_ai/ocr/transformation.py b/litellm/llms/vertex_ai/ocr/transformation.py index 6fe88459ea2..cbf15803132 100644 --- a/litellm/llms/vertex_ai/ocr/transformation.py +++ b/litellm/llms/vertex_ai/ocr/transformation.py @@ -1,6 +1,7 @@ """ Vertex AI Mistral OCR transformation implementation. """ + from typing import Dict, Optional from litellm._logging import verbose_logger diff --git a/litellm/llms/vertex_ai/text_to_speech/text_to_speech_handler.py b/litellm/llms/vertex_ai/text_to_speech/text_to_speech_handler.py index 9d9015c2b91..b835ad7d8fa 100644 --- a/litellm/llms/vertex_ai/text_to_speech/text_to_speech_handler.py +++ b/litellm/llms/vertex_ai/text_to_speech/text_to_speech_handler.py @@ -139,7 +139,7 @@ def audio_speech( ########## End of logging ############ ####### Send the request ################### if _is_async is True: - return self.async_audio_speech( # type:ignore + return self.async_audio_speech( # type: ignore logging_obj=logging_obj, url=url, headers=headers, request=request ) sync_handler = _get_httpx_client() diff --git a/litellm/llms/vertex_ai/vertex_ai_partner_models/count_tokens/handler.py b/litellm/llms/vertex_ai/vertex_ai_partner_models/count_tokens/handler.py index 5d94cd42129..f523f814d5f 100644 --- a/litellm/llms/vertex_ai/vertex_ai_partner_models/count_tokens/handler.py +++ b/litellm/llms/vertex_ai/vertex_ai_partner_models/count_tokens/handler.py @@ -5,6 +5,7 @@ Unlike Gemini models which use Google's token counting API, partner models use their respective publisher-specific count-tokens endpoints. """ + from typing import Any, Dict, Optional from litellm.llms.custom_httpx.http_handler import get_async_httpx_client diff --git a/litellm/llms/vertex_ai/vertex_embeddings/embedding_handler.py b/litellm/llms/vertex_ai/vertex_embeddings/embedding_handler.py index 5fffd983c24..18c5ec3d839 100644 --- a/litellm/llms/vertex_ai/vertex_embeddings/embedding_handler.py +++ b/litellm/llms/vertex_ai/vertex_embeddings/embedding_handler.py @@ -90,8 +90,10 @@ def embedding( use_psc_endpoint_format=use_psc_endpoint_format, ) headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers) - vertex_request: VertexEmbeddingRequest = litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request( - input=input, optional_params=optional_params, model=model + vertex_request: VertexEmbeddingRequest = ( + litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request( + input=input, optional_params=optional_params, model=model + ) ) _client_params = {} @@ -184,8 +186,10 @@ async def async_embedding( use_psc_endpoint_format=use_psc_endpoint_format, ) headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers) - vertex_request: VertexEmbeddingRequest = litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request( - input=input, optional_params=optional_params, model=model + vertex_request: VertexEmbeddingRequest = ( + litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request( + input=input, optional_params=optional_params, model=model + ) ) _async_client_params = {} diff --git a/litellm/llms/vllm/completion/transformation.py b/litellm/llms/vllm/completion/transformation.py index ec4c07e95d8..e03b07f9897 100644 --- a/litellm/llms/vllm/completion/transformation.py +++ b/litellm/llms/vllm/completion/transformation.py @@ -1,5 +1,5 @@ """ -Translates from OpenAI's `/v1/chat/completions` to the VLLM sdk `llm.generate`. +Translates from OpenAI's `/v1/chat/completions` to the VLLM sdk `llm.generate`. NOT RECOMMENDED FOR PRODUCTION USE. Use `hosted_vllm/` instead. """ diff --git a/litellm/llms/voyage/embedding/transformation_contextual.py b/litellm/llms/voyage/embedding/transformation_contextual.py index 4df2fa4ba31..1f5ca99f47d 100644 --- a/litellm/llms/voyage/embedding/transformation_contextual.py +++ b/litellm/llms/voyage/embedding/transformation_contextual.py @@ -1,7 +1,8 @@ """ -This module is used to transform the request and response for the Voyage contextualized embeddings API. -This would be used for all the contextualized embeddings models in Voyage. +This module is used to transform the request and response for the Voyage contextualized embeddings API. +This would be used for all the contextualized embeddings models in Voyage. """ + from typing import List, Optional, Union import httpx diff --git a/litellm/main.py b/litellm/main.py index eace9c630ba..605d784afdd 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -3792,9 +3792,9 @@ def completion( # type: ignore # noqa: PLR0915 "aws_region_name" not in optional_params or optional_params["aws_region_name"] is None ): - optional_params[ - "aws_region_name" - ] = aws_bedrock_client.meta.region_name + optional_params["aws_region_name"] = ( + aws_bedrock_client.meta.region_name + ) bedrock_route = BedrockModelInfo.get_bedrock_route(model) if bedrock_route == "converse": @@ -6198,9 +6198,9 @@ def adapter_completion( new_kwargs = translation_obj.translate_completion_input_params(kwargs=kwargs) response: Union[ModelResponse, CustomStreamWrapper] = completion(**new_kwargs) # type: ignore - translated_response: Optional[ - Union[BaseModel, AdapterCompletionStreamWrapper] - ] = None + translated_response: Optional[Union[BaseModel, AdapterCompletionStreamWrapper]] = ( + None + ) if isinstance(response, ModelResponse): translated_response = translation_obj.translate_completion_output_params( response=response @@ -6380,9 +6380,9 @@ async def atranscription(*args, **kwargs) -> TranscriptionResponse: if existing_duration is None: calculated_duration = calculate_request_duration(file) if calculated_duration is not None: - response._hidden_params[ - "audio_transcription_duration" - ] = calculated_duration + response._hidden_params["audio_transcription_duration"] = ( + calculated_duration + ) return response except Exception as e: @@ -6605,9 +6605,9 @@ def transcription( if existing_duration is None: calculated_duration = calculate_request_duration(file) if calculated_duration is not None: - response._hidden_params[ - "audio_transcription_duration" - ] = calculated_duration + response._hidden_params["audio_transcription_duration"] = ( + calculated_duration + ) if response is None: raise ValueError("Unmapped provider passed in. Unable to get the response.") @@ -6911,9 +6911,9 @@ def speech( # noqa: PLR0915 ElevenLabsTextToSpeechConfig.ELEVENLABS_QUERY_PARAMS_KEY ] = query_params - litellm_params_dict[ - ElevenLabsTextToSpeechConfig.ELEVENLABS_VOICE_ID_KEY - ] = voice_id + litellm_params_dict[ElevenLabsTextToSpeechConfig.ELEVENLABS_VOICE_ID_KEY] = ( + voice_id + ) if api_base is not None: litellm_params_dict["api_base"] = api_base @@ -7492,9 +7492,9 @@ def stream_chunk_builder( # noqa: PLR0915 ] if len(content_chunks) > 0: - response["choices"][0]["message"][ - "content" - ] = processor.get_combined_content(content_chunks) + response["choices"][0]["message"]["content"] = ( + processor.get_combined_content(content_chunks) + ) thinking_blocks = [ chunk @@ -7505,9 +7505,9 @@ def stream_chunk_builder( # noqa: PLR0915 ] if len(thinking_blocks) > 0: - response["choices"][0]["message"][ - "thinking_blocks" - ] = processor.get_combined_thinking_content(thinking_blocks) + response["choices"][0]["message"]["thinking_blocks"] = ( + processor.get_combined_thinking_content(thinking_blocks) + ) reasoning_chunks = [ chunk @@ -7518,9 +7518,9 @@ def stream_chunk_builder( # noqa: PLR0915 ] if len(reasoning_chunks) > 0: - response["choices"][0]["message"][ - "reasoning_content" - ] = processor.get_combined_reasoning_content(reasoning_chunks) + response["choices"][0]["message"]["reasoning_content"] = ( + processor.get_combined_reasoning_content(reasoning_chunks) + ) annotation_chunks = [ chunk diff --git a/litellm/ocr/__init__.py b/litellm/ocr/__init__.py index a20b0ef6cad..e97497b2db7 100644 --- a/litellm/ocr/__init__.py +++ b/litellm/ocr/__init__.py @@ -1,4 +1,5 @@ """OCR module for LiteLLM.""" + from .main import aocr, ocr __all__ = ["ocr", "aocr"] diff --git a/litellm/ocr/main.py b/litellm/ocr/main.py index d90a931b59a..5d73ddc8972 100644 --- a/litellm/ocr/main.py +++ b/litellm/ocr/main.py @@ -1,6 +1,7 @@ """ Main OCR function for LiteLLM. """ + import asyncio import base64 import contextvars @@ -262,11 +263,11 @@ def ocr( api_base = dynamic_api_base # Get provider config - ocr_provider_config: Optional[ - BaseOCRConfig - ] = ProviderConfigManager.get_provider_ocr_config( - model=model, - provider=litellm.LlmProviders(custom_llm_provider), + ocr_provider_config: Optional[BaseOCRConfig] = ( + ProviderConfigManager.get_provider_ocr_config( + model=model, + provider=litellm.LlmProviders(custom_llm_provider), + ) ) if ocr_provider_config is None: diff --git a/litellm/proxy/_experimental/mcp_server/db.py b/litellm/proxy/_experimental/mcp_server/db.py index fbef33c32ed..de265f1304c 100644 --- a/litellm/proxy/_experimental/mcp_server/db.py +++ b/litellm/proxy/_experimental/mcp_server/db.py @@ -188,12 +188,12 @@ async def get_mcp_server( """ Returns the matching mcp server from the db iff exists """ - mcp_server: Optional[ - LiteLLM_MCPServerTable - ] = await prisma_client.db.litellm_mcpservertable.find_unique( - where={ - "server_id": server_id, - } + mcp_server: Optional[LiteLLM_MCPServerTable] = ( + await prisma_client.db.litellm_mcpservertable.find_unique( + where={ + "server_id": server_id, + } + ) ) return mcp_server @@ -204,12 +204,12 @@ async def get_mcp_servers( """ Returns the matching mcp servers from the db with the server_ids """ - _mcp_servers: List[ - LiteLLM_MCPServerTable - ] = await prisma_client.db.litellm_mcpservertable.find_many( - where={ - "server_id": {"in": server_ids}, - } + _mcp_servers: List[LiteLLM_MCPServerTable] = ( + await prisma_client.db.litellm_mcpservertable.find_many( + where={ + "server_id": {"in": server_ids}, + } + ) ) final_mcp_servers: List[LiteLLM_MCPServerTable] = [] for _mcp_server in _mcp_servers: diff --git a/litellm/proxy/_experimental/mcp_server/discoverable_endpoints.py b/litellm/proxy/_experimental/mcp_server/discoverable_endpoints.py index 07309eb57f2..e5a8211f94d 100644 --- a/litellm/proxy/_experimental/mcp_server/discoverable_endpoints.py +++ b/litellm/proxy/_experimental/mcp_server/discoverable_endpoints.py @@ -541,9 +541,9 @@ def _build_oauth_protected_resource_response( ) ], "resource": resource_url, - "scopes_supported": mcp_server.scopes - if mcp_server and mcp_server.scopes - else [], + "scopes_supported": ( + mcp_server.scopes if mcp_server and mcp_server.scopes else [] + ), } @@ -653,16 +653,18 @@ def _build_oauth_authorization_server_response( "authorization_endpoint": authorization_endpoint, "token_endpoint": token_endpoint, "response_types_supported": ["code"], - "scopes_supported": mcp_server.scopes - if mcp_server and mcp_server.scopes - else [], + "scopes_supported": ( + mcp_server.scopes if mcp_server and mcp_server.scopes else [] + ), "grant_types_supported": ["authorization_code", "refresh_token"], "code_challenge_methods_supported": ["S256"], "token_endpoint_auth_methods_supported": ["client_secret_post"], # Claude expects a registration endpoint, even if we just fake it - "registration_endpoint": f"{request_base_url}/{mcp_server_name}/register" - if mcp_server_name - else f"{request_base_url}/register", + "registration_endpoint": ( + f"{request_base_url}/{mcp_server_name}/register" + if mcp_server_name + else f"{request_base_url}/register" + ), } diff --git a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py index 1e9d5c5a529..163765e11a7 100644 --- a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py +++ b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py @@ -501,12 +501,12 @@ async def _register_openapi_tools( ) # Update tool name to server name mapping (for both prefixed and base names) - self.tool_name_to_mcp_server_name_mapping[ - base_tool_name - ] = server_prefix - self.tool_name_to_mcp_server_name_mapping[ - prefixed_tool_name - ] = server_prefix + self.tool_name_to_mcp_server_name_mapping[base_tool_name] = ( + server_prefix + ) + self.tool_name_to_mcp_server_name_mapping[prefixed_tool_name] = ( + server_prefix + ) registered_count += 1 verbose_logger.debug( diff --git a/litellm/proxy/_experimental/mcp_server/semantic_tool_filter.py b/litellm/proxy/_experimental/mcp_server/semantic_tool_filter.py index 0bafd7da265..0e32bfd7026 100644 --- a/litellm/proxy/_experimental/mcp_server/semantic_tool_filter.py +++ b/litellm/proxy/_experimental/mcp_server/semantic_tool_filter.py @@ -3,6 +3,7 @@ Filters MCP tools semantically for /chat/completions and /responses endpoints. """ + from typing import TYPE_CHECKING, Any, Dict, List, Optional from litellm._logging import verbose_logger diff --git a/litellm/proxy/_experimental/mcp_server/server.py b/litellm/proxy/_experimental/mcp_server/server.py index cd06de2a2df..74aa4415d79 100644 --- a/litellm/proxy/_experimental/mcp_server/server.py +++ b/litellm/proxy/_experimental/mcp_server/server.py @@ -1,6 +1,7 @@ """ LiteLLM MCP Server Routes """ + # pyright: reportInvalidTypeForm=false, reportArgumentType=false, reportOptionalCall=false import asyncio @@ -1822,9 +1823,9 @@ async def execute_mcp_tool( # noqa: PLR0915 "litellm_logging_obj", None ) if litellm_logging_obj: - litellm_logging_obj.model_call_details[ - "mcp_tool_call_metadata" - ] = standard_logging_mcp_tool_call + litellm_logging_obj.model_call_details["mcp_tool_call_metadata"] = ( + standard_logging_mcp_tool_call + ) litellm_logging_obj.model = f"MCP: {name}" # Resolve the MCP server early so BYOK checks and credential injection # apply to ALL dispatch paths (local tool registry AND managed MCP server). @@ -1836,9 +1837,9 @@ async def execute_mcp_tool( # noqa: PLR0915 mcp_server.mcp_info or {} ).get("mcp_server_cost_info") if litellm_logging_obj: - litellm_logging_obj.model_call_details[ - "mcp_tool_call_metadata" - ] = standard_logging_mcp_tool_call + litellm_logging_obj.model_call_details["mcp_tool_call_metadata"] = ( + standard_logging_mcp_tool_call + ) # BYOK: retrieve the stored per-user credential. A single DB call # both checks existence and fetches the value, avoiding a double query. @@ -2580,17 +2581,15 @@ def set_auth_context( ) auth_context_var.set(auth_user) - def get_auth_context() -> ( - Tuple[ - Optional[UserAPIKeyAuth], - Optional[str], - Optional[List[str]], - Optional[Dict[str, Dict[str, str]]], - Optional[Dict[str, str]], - Optional[Dict[str, str]], - Optional[str], - ] - ): + def get_auth_context() -> Tuple[ + Optional[UserAPIKeyAuth], + Optional[str], + Optional[List[str]], + Optional[Dict[str, Dict[str, str]]], + Optional[Dict[str, str]], + Optional[Dict[str, str]], + Optional[str], + ]: """ Get the UserAPIKeyAuth from the auth context variable. diff --git a/litellm/proxy/_experimental/mcp_server/utils.py b/litellm/proxy/_experimental/mcp_server/utils.py index 8189f212bcb..66d9fde0bb6 100644 --- a/litellm/proxy/_experimental/mcp_server/utils.py +++ b/litellm/proxy/_experimental/mcp_server/utils.py @@ -1,6 +1,7 @@ """ MCP Server Utilities """ + from typing import Any, Dict, Mapping, Optional, Tuple import os diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 91a953c217e..9174a29cbd2 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -883,9 +883,9 @@ class GenerateRequestBase(LiteLLMPydanticObjectBase): allowed_cache_controls: Optional[list] = [] config: Optional[dict] = {} permissions: Optional[dict] = {} - model_max_budget: Optional[ - dict - ] = {} # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {} + model_max_budget: Optional[dict] = ( + {} + ) # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {} model_config = ConfigDict(protected_namespaces=()) model_rpm_limit: Optional[dict] = None @@ -1027,9 +1027,9 @@ class RegenerateKeyRequest(GenerateKeyRequest): spend: Optional[float] = None metadata: Optional[dict] = None new_master_key: Optional[str] = None - grace_period: Optional[ - str - ] = None # Duration to keep old key valid (e.g. "24h", "2d"); None = immediate revoke + grace_period: Optional[str] = ( + None # Duration to keep old key valid (e.g. "24h", "2d"); None = immediate revoke + ) class ResetSpendRequest(LiteLLMPydanticObjectBase): @@ -1539,12 +1539,12 @@ class NewCustomerRequest(BudgetNewRequest): blocked: bool = False # allow/disallow requests for this end-user budget_id: Optional[str] = None # give either a budget_id or max_budget spend: Optional[float] = None - allowed_model_region: Optional[ - AllowedModelRegion - ] = None # require all user requests to use models in this specific region - default_model: Optional[ - str - ] = None # if no equivalent model in allowed region - default all requests to this model + allowed_model_region: Optional[AllowedModelRegion] = ( + None # require all user requests to use models in this specific region + ) + default_model: Optional[str] = ( + None # if no equivalent model in allowed region - default all requests to this model + ) object_permission: Optional[LiteLLM_ObjectPermissionBase] = None @model_validator(mode="before") @@ -1567,12 +1567,12 @@ class UpdateCustomerRequest(LiteLLMPydanticObjectBase): blocked: bool = False # allow/disallow requests for this end-user max_budget: Optional[float] = None budget_id: Optional[str] = None # give either a budget_id or max_budget - allowed_model_region: Optional[ - AllowedModelRegion - ] = None # require all user requests to use models in this specific region - default_model: Optional[ - str - ] = None # if no equivalent model in allowed region - default all requests to this model + allowed_model_region: Optional[AllowedModelRegion] = ( + None # require all user requests to use models in this specific region + ) + default_model: Optional[str] = ( + None # if no equivalent model in allowed region - default all requests to this model + ) object_permission: Optional[LiteLLM_ObjectPermissionBase] = None @@ -1662,15 +1662,15 @@ class NewTeamRequest(TeamBase): ] = None # raise an error if 'guaranteed_throughput' is set and we're overallocating tpm model_tpm_limit: Optional[Dict[str, int]] = None - team_member_budget: Optional[ - float - ] = None # allow user to set a budget for all team members - team_member_rpm_limit: Optional[ - int - ] = None # allow user to set RPM limit for all team members - team_member_tpm_limit: Optional[ - int - ] = None # allow user to set TPM limit for all team members + team_member_budget: Optional[float] = ( + None # allow user to set a budget for all team members + ) + team_member_rpm_limit: Optional[int] = ( + None # allow user to set RPM limit for all team members + ) + team_member_tpm_limit: Optional[int] = ( + None # allow user to set TPM limit for all team members + ) team_member_key_duration: Optional[str] = None # e.g. "1d", "1w", "1m" team_member_budget_duration: Optional[str] = None # e.g. "30d", "1mo" allowed_vector_store_indexes: Optional[List[AllowedVectorStoreIndexItem]] = None @@ -1767,9 +1767,9 @@ class BlockKeyRequest(LiteLLMPydanticObjectBase): class AddTeamCallback(LiteLLMPydanticObjectBase): callback_name: str - callback_type: Optional[ - Literal["success", "failure", "success_and_failure"] - ] = "success_and_failure" + callback_type: Optional[Literal["success", "failure", "success_and_failure"]] = ( + "success_and_failure" + ) callback_vars: Dict[str, str] @model_validator(mode="before") @@ -2109,9 +2109,9 @@ class ConfigList(LiteLLMPydanticObjectBase): stored_in_db: Optional[bool] field_default_value: Any premium_field: bool = False - nested_fields: Optional[ - List[FieldDetail] - ] = None # For nested dictionary or Pydantic fields + nested_fields: Optional[List[FieldDetail]] = ( + None # For nested dictionary or Pydantic fields + ) class UserHeaderMapping(LiteLLMPydanticObjectBase): @@ -2468,9 +2468,9 @@ class UserAPIKeyAuth( user_max_budget: Optional[float] = None request_route: Optional[str] = None user: Optional[Any] = None # Expanded user object when expand=user is used - created_by_user: Optional[ - Any - ] = None # Expanded created_by user when expand=user is used + created_by_user: Optional[Any] = ( + None # Expanded created_by user when expand=user is used + ) end_user_object_permission: Optional[LiteLLM_ObjectPermissionTable] = None # Decoded upstream IdP claims (groups, roles, etc.) propagated by JWT auth machinery # and forwarded into outbound tokens by guardrails such as MCPJWTSigner. @@ -2609,9 +2609,9 @@ class LiteLLM_OrganizationMembershipTable(LiteLLMPydanticObjectBase): budget_id: Optional[str] = None created_at: datetime updated_at: datetime - user: Optional[ - Any - ] = None # You might want to replace 'Any' with a more specific type if available + user: Optional[Any] = ( + None # You might want to replace 'Any' with a more specific type if available + ) litellm_budget_table: Optional[LiteLLM_BudgetTable] = None user_email: Optional[str] = None @@ -3761,9 +3761,9 @@ class TeamModelDeleteRequest(BaseModel): # Organization Member Requests class OrganizationMemberAddRequest(OrgMemberAddRequest): organization_id: str - max_budget_in_organization: Optional[ - float - ] = None # Users max budget within the organization + max_budget_in_organization: Optional[float] = ( + None # Users max budget within the organization + ) class OrganizationMemberDeleteRequest(MemberDeleteRequest): @@ -4014,9 +4014,9 @@ class ProviderBudgetResponse(LiteLLMPydanticObjectBase): Maps provider names to their budget configs. """ - providers: Dict[ - str, ProviderBudgetResponseObject - ] = {} # Dictionary mapping provider names to their budget configurations + providers: Dict[str, ProviderBudgetResponseObject] = ( + {} + ) # Dictionary mapping provider names to their budget configurations class ProxyStateVariables(TypedDict): @@ -4160,9 +4160,9 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase): enforce_rbac: bool = False roles_jwt_field: Optional[str] = None # v2 on role mappings role_mappings: Optional[List[RoleMapping]] = None - object_id_jwt_field: Optional[ - str - ] = None # can be either user / team, inferred from the role mapping + object_id_jwt_field: Optional[str] = ( + None # can be either user / team, inferred from the role mapping + ) scope_mappings: Optional[List[ScopeMapping]] = None enforce_scope_based_access: bool = False enforce_team_based_model_access: bool = False diff --git a/litellm/proxy/agent_endpoints/agent_registry.py b/litellm/proxy/agent_endpoints/agent_registry.py index 436de8b0aff..e69d16e0eac 100644 --- a/litellm/proxy/agent_endpoints/agent_registry.py +++ b/litellm/proxy/agent_endpoints/agent_registry.py @@ -171,13 +171,13 @@ async def add_agent_to_db( created_agent_dict = created_agent.model_dump() if created_agent.object_permission is not None: try: - created_agent_dict[ - "object_permission" - ] = created_agent.object_permission.model_dump() + created_agent_dict["object_permission"] = ( + created_agent.object_permission.model_dump() + ) except Exception: - created_agent_dict[ - "object_permission" - ] = created_agent.object_permission.dict() + created_agent_dict["object_permission"] = ( + created_agent.object_permission.dict() + ) return AgentResponse(**created_agent_dict) # type: ignore except Exception as e: raise Exception(f"Error adding agent to DB: {str(e)}") @@ -283,13 +283,13 @@ async def patch_agent_in_db( patched_agent_dict = patched_agent.model_dump() if patched_agent.object_permission is not None: try: - patched_agent_dict[ - "object_permission" - ] = patched_agent.object_permission.model_dump() + patched_agent_dict["object_permission"] = ( + patched_agent.object_permission.model_dump() + ) except Exception: - patched_agent_dict[ - "object_permission" - ] = patched_agent.object_permission.dict() + patched_agent_dict["object_permission"] = ( + patched_agent.object_permission.dict() + ) return AgentResponse(**patched_agent_dict) # type: ignore except Exception as e: raise Exception(f"Error patching agent in DB: {str(e)}") @@ -384,13 +384,13 @@ async def update_agent_in_db( updated_agent_dict = updated_agent.model_dump() if updated_agent.object_permission is not None: try: - updated_agent_dict[ - "object_permission" - ] = updated_agent.object_permission.model_dump() + updated_agent_dict["object_permission"] = ( + updated_agent.object_permission.model_dump() + ) except Exception: - updated_agent_dict[ - "object_permission" - ] = updated_agent.object_permission.dict() + updated_agent_dict["object_permission"] = ( + updated_agent.object_permission.dict() + ) return AgentResponse(**updated_agent_dict) # type: ignore except Exception as e: raise Exception(f"Error updating agent in DB: {str(e)}") @@ -414,9 +414,9 @@ async def get_all_agents_from_db( # object_permission is eagerly loaded via include above if agent.object_permission is not None: try: - agent_dict[ - "object_permission" - ] = agent.object_permission.model_dump() + agent_dict["object_permission"] = ( + agent.object_permission.model_dump() + ) except Exception: agent_dict["object_permission"] = agent.object_permission.dict() agents.append(agent_dict) diff --git a/litellm/proxy/agent_endpoints/endpoints.py b/litellm/proxy/agent_endpoints/endpoints.py index 6e5d4562b55..af15b8a11c1 100644 --- a/litellm/proxy/agent_endpoints/endpoints.py +++ b/litellm/proxy/agent_endpoints/endpoints.py @@ -177,10 +177,9 @@ async def get_agents( for agent in returned_agents: if agent.litellm_params is None: agent.litellm_params = {} - agent.litellm_params[ - "is_public" - ] = litellm.public_agent_groups is not None and ( - agent.agent_id in litellm.public_agent_groups + agent.litellm_params["is_public"] = ( + litellm.public_agent_groups is not None + and (agent.agent_id in litellm.public_agent_groups) ) if health_check: @@ -378,13 +377,13 @@ async def get_agent_by_id( agent_dict = agent_row.model_dump() if agent_row.object_permission is not None: try: - agent_dict[ - "object_permission" - ] = agent_row.object_permission.model_dump() + agent_dict["object_permission"] = ( + agent_row.object_permission.model_dump() + ) except Exception: - agent_dict[ - "object_permission" - ] = agent_row.object_permission.dict() + agent_dict["object_permission"] = ( + agent_row.object_permission.dict() + ) agent = AgentResponse(**agent_dict) # type: ignore else: # Agent found in memory — refresh spend from DB diff --git a/litellm/proxy/agent_endpoints/model_list_helpers.py b/litellm/proxy/agent_endpoints/model_list_helpers.py index 37308b92f78..b88c602ac34 100644 --- a/litellm/proxy/agent_endpoints/model_list_helpers.py +++ b/litellm/proxy/agent_endpoints/model_list_helpers.py @@ -3,6 +3,7 @@ Used by proxy model endpoints to make agents appear in UI alongside models. """ + from typing import List from litellm._logging import verbose_proxy_logger diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index 815393467de..a19ef1ede21 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -8,6 +8,7 @@ 2. If user is in budget 3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget """ + import asyncio import re import time @@ -414,9 +415,9 @@ async def common_checks( # noqa: PLR0915 model=_model, team_object=team_object, llm_router=llm_router, - team_model_aliases=valid_token.team_model_aliases - if valid_token - else None, + team_model_aliases=( + valid_token.team_model_aliases if valid_token else None + ), ): raise ProxyException( message=f"Team not allowed to access model. Team={team_object.team_id}, Model={_model}. Allowed team models = {team_object.models}", diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index 800cca21dbe..e2f06abc52f 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -920,9 +920,9 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 route=route, ) if _end_user_object is not None: - end_user_params[ - "allowed_model_region" - ] = _end_user_object.allowed_model_region + end_user_params["allowed_model_region"] = ( + _end_user_object.allowed_model_region + ) if _end_user_object.litellm_budget_table is not None: _apply_budget_limits_to_end_user_params( end_user_params=end_user_params, @@ -1499,9 +1499,9 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 if _end_user_object is not None: valid_token_dict.update(end_user_params) - valid_token_dict[ - "end_user_object_permission" - ] = _end_user_object.object_permission + valid_token_dict["end_user_object_permission"] = ( + _end_user_object.object_permission + ) # check if token is from litellm-ui, litellm ui makes keys to allow users to login with sso. These keys can only be used for LiteLLM UI functions # sso/login, ui/login, /key functions and /user functions diff --git a/litellm/proxy/client/cli/commands/auth.py b/litellm/proxy/client/cli/commands/auth.py index aeb59e78a53..ec81fdbe12d 100644 --- a/litellm/proxy/client/cli/commands/auth.py +++ b/litellm/proxy/client/cli/commands/auth.py @@ -241,7 +241,7 @@ def prompt_team_selection(teams: List[Dict[str, Any]]) -> Optional[Dict[str, Any def prompt_team_selection_fallback( - teams: List[Dict[str, Any]] + teams: List[Dict[str, Any]], ) -> Optional[Dict[str, Any]]: """Fallback team selection for non-interactive environments""" if not teams: diff --git a/litellm/proxy/client/cli/commands/models.py b/litellm/proxy/client/cli/commands/models.py index 8acafbd88ab..387979a69a0 100644 --- a/litellm/proxy/client/cli/commands/models.py +++ b/litellm/proxy/client/cli/commands/models.py @@ -129,9 +129,11 @@ def list_models(ctx: click.Context, output_format: Literal["table", "json"]) -> table.add_row( str(model.get("id", "")), str(model.get("object", "model")), - format_timestamp(created) - if isinstance(created, int) - else format_iso_datetime_str(created), + ( + format_timestamp(created) + if isinstance(created, int) + else format_iso_datetime_str(created) + ), str(model.get("owned_by", "")), ) diff --git a/litellm/proxy/client/cli/main.py b/litellm/proxy/client/cli/main.py index 744acf38382..22de5a78614 100644 --- a/litellm/proxy/client/cli/main.py +++ b/litellm/proxy/client/cli/main.py @@ -45,14 +45,16 @@ def print_version(base_url: str, api_key: Optional[str]): expose_value=False, help="Show the LiteLLM Proxy CLI and server version and exit.", callback=lambda ctx, param, value: ( - print_version( - ctx.params.get("base_url") or "http://localhost:4000", - ctx.params.get("api_key"), + ( + print_version( + ctx.params.get("base_url") or "http://localhost:4000", + ctx.params.get("api_key"), + ) + or ctx.exit() ) - or ctx.exit() - ) - if value and not ctx.resilient_parsing - else None, + if value and not ctx.resilient_parsing + else None + ), ) @click.option( "--base-url", diff --git a/litellm/proxy/common_request_processing.py b/litellm/proxy/common_request_processing.py index d1aebe4dceb..d8061e98a31 100644 --- a/litellm/proxy/common_request_processing.py +++ b/litellm/proxy/common_request_processing.py @@ -793,9 +793,11 @@ def _debug_log_request_payload(self) -> None: "Request received by LiteLLM: payload too large to log (%d bytes, limit %d). Keys: %s", len(_payload_str), MAX_PAYLOAD_SIZE_FOR_DEBUG_LOG, - list(self.data.keys()) - if isinstance(self.data, dict) - else type(self.data).__name__, + ( + list(self.data.keys()) + if isinstance(self.data, dict) + else type(self.data).__name__ + ), ) else: verbose_proxy_logger.debug( @@ -1055,9 +1057,9 @@ async def base_process_llm_request( # noqa: PLR0915 # aliasing/routing, but the OpenAI-compatible response `model` field should reflect # what the client sent. if requested_model_from_client: - self.data[ - "_litellm_client_requested_model" - ] = requested_model_from_client + self.data["_litellm_client_requested_model"] = ( + requested_model_from_client + ) # Streaming: attach a closure that CSW.__anext__ will call # at stream end instead of firing logging directly. The @@ -1649,7 +1651,9 @@ async def async_streaming_data_generator( verbose_proxy_logger.debug("inside generator") try: str_so_far = "" - async for chunk in proxy_logging_obj.async_post_call_streaming_iterator_hook( + async for ( + chunk + ) in proxy_logging_obj.async_post_call_streaming_iterator_hook( user_api_key_dict=user_api_key_dict, response=response, request_data=request_data, @@ -1877,9 +1881,9 @@ def _inject_cost_into_usage_dict(obj: dict, model_name: str) -> Optional[dict]: # Add cache-related fields to **params (handled by Usage.__init__) if cache_creation_input_tokens is not None: - usage_kwargs[ - "cache_creation_input_tokens" - ] = cache_creation_input_tokens + usage_kwargs["cache_creation_input_tokens"] = ( + cache_creation_input_tokens + ) if cache_read_input_tokens is not None: usage_kwargs["cache_read_input_tokens"] = cache_read_input_tokens diff --git a/litellm/proxy/common_utils/cache_coordinator.py b/litellm/proxy/common_utils/cache_coordinator.py index ccc73c5e6d8..24da9450ab8 100644 --- a/litellm/proxy/common_utils/cache_coordinator.py +++ b/litellm/proxy/common_utils/cache_coordinator.py @@ -22,11 +22,9 @@ class AsyncCacheProtocol(Protocol): """Protocol for cache backends used by EventDrivenCacheCoordinator.""" - async def async_get_cache(self, key: str, **kwargs: Any) -> Any: - ... + async def async_get_cache(self, key: str, **kwargs: Any) -> Any: ... - async def async_set_cache(self, key: str, value: Any, **kwargs: Any) -> Any: - ... + async def async_set_cache(self, key: str, value: Any, **kwargs: Any) -> Any: ... class EventDrivenCacheCoordinator: diff --git a/litellm/proxy/common_utils/callback_utils.py b/litellm/proxy/common_utils/callback_utils.py index 9ecae363ed7..a206be87a11 100644 --- a/litellm/proxy/common_utils/callback_utils.py +++ b/litellm/proxy/common_utils/callback_utils.py @@ -362,17 +362,17 @@ def get_remaining_tokens_and_requests_from_request_data(data: Dict) -> Dict[str, remaining_requests_variable_name = f"litellm-key-remaining-requests-{model_group}" remaining_requests = _metadata.get(remaining_requests_variable_name, None) if remaining_requests: - headers[ - f"x-litellm-key-remaining-requests-{h11_model_group_name}" - ] = remaining_requests + headers[f"x-litellm-key-remaining-requests-{h11_model_group_name}"] = ( + remaining_requests + ) # Remaining Tokens remaining_tokens_variable_name = f"litellm-key-remaining-tokens-{model_group}" remaining_tokens = _metadata.get(remaining_tokens_variable_name, None) if remaining_tokens: - headers[ - f"x-litellm-key-remaining-tokens-{h11_model_group_name}" - ] = remaining_tokens + headers[f"x-litellm-key-remaining-tokens-{h11_model_group_name}"] = ( + remaining_tokens + ) return headers @@ -472,9 +472,9 @@ def add_guardrail_response_to_standard_logging_object( ): if litellm_logging_obj is None: return - standard_logging_object: Optional[ - StandardLoggingPayload - ] = litellm_logging_obj.model_call_details.get("standard_logging_object") + standard_logging_object: Optional[StandardLoggingPayload] = ( + litellm_logging_obj.model_call_details.get("standard_logging_object") + ) if standard_logging_object is None: return guardrail_information = standard_logging_object.get("guardrail_information", []) diff --git a/litellm/proxy/common_utils/custom_openapi_spec.py b/litellm/proxy/common_utils/custom_openapi_spec.py index a93749c3952..fa3cb02195b 100644 --- a/litellm/proxy/common_utils/custom_openapi_spec.py +++ b/litellm/proxy/common_utils/custom_openapi_spec.py @@ -324,7 +324,7 @@ def add_request_schema( @staticmethod def add_chat_completion_request_schema( - openapi_schema: Dict[str, Any] + openapi_schema: Dict[str, Any], ) -> Dict[str, Any]: """ Add ProxyChatCompletionRequest schema to chat completion endpoints for documentation. @@ -380,7 +380,7 @@ def add_embedding_request_schema(openapi_schema: Dict[str, Any]) -> Dict[str, An @staticmethod def add_responses_api_request_schema( - openapi_schema: Dict[str, Any] + openapi_schema: Dict[str, Any], ) -> Dict[str, Any]: """ Add ResponsesAPIRequestParams schema to responses API endpoints for documentation. @@ -410,7 +410,7 @@ def add_responses_api_request_schema( @staticmethod def add_llm_api_request_schema_body( - openapi_schema: Dict[str, Any] + openapi_schema: Dict[str, Any], ) -> Dict[str, Any]: """ Add LLM API request schema bodies to OpenAPI specification for documentation. diff --git a/litellm/proxy/common_utils/debug_utils.py b/litellm/proxy/common_utils/debug_utils.py index 6f7038377bd..99eeeda1c86 100644 --- a/litellm/proxy/common_utils/debug_utils.py +++ b/litellm/proxy/common_utils/debug_utils.py @@ -245,9 +245,9 @@ async def get_memory_summary( health_status = "healthy" except ImportError: - process_memory[ - "error" - ] = "Install psutil for memory monitoring: pip install psutil" + process_memory["error"] = ( + "Install psutil for memory monitoring: pip install psutil" + ) except Exception as e: process_memory["error"] = str(e) @@ -301,9 +301,9 @@ async def get_memory_summary( # Add warning if garbage collection issues detected if uncollectable > 0: - gc_info[ - "warning" - ] = f"{uncollectable} uncollectable objects (possible memory leak)" + gc_info["warning"] = ( + f"{uncollectable} uncollectable objects (possible memory leak)" + ) return { "worker_pid": os.getpid(), @@ -369,9 +369,11 @@ def _get_uncollectable_objects_info() -> Dict[str, Any]: return { "count": len(uncollectable), "sample_types": [type(obj).__name__ for obj in uncollectable[:10]], - "warning": "If count > 0, you may have reference cycles preventing garbage collection" - if len(uncollectable) > 0 - else None, + "warning": ( + "If count > 0, you may have reference cycles preventing garbage collection" + if len(uncollectable) > 0 + else None + ), } @@ -441,12 +443,16 @@ def _get_cache_memory_stats( if hasattr(redis_usage_cache.redis_client, "connection_pool"): pool_info = redis_usage_cache.redis_client.connection_pool # type: ignore cache_stats["redis_usage_cache"]["connection_pool"] = { - "max_connections": pool_info.max_connections - if hasattr(pool_info, "max_connections") - else None, - "connection_class": pool_info.connection_class.__name__ - if hasattr(pool_info, "connection_class") - else None, + "max_connections": ( + pool_info.max_connections + if hasattr(pool_info, "max_connections") + else None + ), + "connection_class": ( + pool_info.connection_class.__name__ + if hasattr(pool_info, "connection_class") + else None + ), } except Exception as e: verbose_proxy_logger.debug(f"Error getting Redis pool info: {e}") @@ -561,9 +567,11 @@ def _get_process_memory_info( "description": "Percentage of total system RAM being used", }, "open_file_handles": { - "count": process.num_fds() - if hasattr(process, "num_fds") - else "N/A (Windows)", + "count": ( + process.num_fds() + if hasattr(process, "num_fds") + else "N/A (Windows)" + ), "description": "Number of open file descriptors/handles", }, "threads": { diff --git a/litellm/proxy/common_utils/http_parsing_utils.py b/litellm/proxy/common_utils/http_parsing_utils.py index 1dd25262127..ae121117110 100644 --- a/litellm/proxy/common_utils/http_parsing_utils.py +++ b/litellm/proxy/common_utils/http_parsing_utils.py @@ -197,10 +197,10 @@ def check_file_size_under_limit( if llm_router is not None and request_data["model"] in router_model_names: try: - deployment: Optional[ - Deployment - ] = llm_router.get_deployment_by_model_group_name( - model_group_name=request_data["model"] + deployment: Optional[Deployment] = ( + llm_router.get_deployment_by_model_group_name( + model_group_name=request_data["model"] + ) ) if ( deployment @@ -257,7 +257,7 @@ async def get_form_data(request: Request) -> Dict[str, Any]: async def convert_upload_files_to_file_data( - form_data: Dict[str, Any] + form_data: Dict[str, Any], ) -> Dict[str, Any]: """ Convert FastAPI UploadFile objects to file data tuples for litellm. diff --git a/litellm/proxy/common_utils/openai_endpoint_utils.py b/litellm/proxy/common_utils/openai_endpoint_utils.py index 7e5c83500a2..a6fa4b36675 100644 --- a/litellm/proxy/common_utils/openai_endpoint_utils.py +++ b/litellm/proxy/common_utils/openai_endpoint_utils.py @@ -1,5 +1,5 @@ """ -Contains utils used by OpenAI compatible endpoints +Contains utils used by OpenAI compatible endpoints """ from typing import Optional, Set diff --git a/litellm/proxy/config_management_endpoints/pass_through_endpoints.py b/litellm/proxy/config_management_endpoints/pass_through_endpoints.py index 5ff02b8bce0..4ebd989dc53 100644 --- a/litellm/proxy/config_management_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/config_management_endpoints/pass_through_endpoints.py @@ -1,5 +1,5 @@ """ -What is this? +What is this? CRUD endpoints for managing pass-through endpoints """ diff --git a/litellm/proxy/db/create_views.py b/litellm/proxy/db/create_views.py index e9303077b18..82f95164190 100644 --- a/litellm/proxy/db/create_views.py +++ b/litellm/proxy/db/create_views.py @@ -24,8 +24,7 @@ async def create_missing_views(db: _db): # noqa: PLR0915 print("LiteLLM_VerificationTokenView Exists!") # noqa except Exception: # If an error occurs, the view does not exist, so create it - await db.execute_raw( - """ + await db.execute_raw(""" CREATE VIEW "LiteLLM_VerificationTokenView" AS SELECT v.*, @@ -35,8 +34,7 @@ async def create_missing_views(db: _db): # noqa: PLR0915 t.rpm_limit AS team_rpm_limit FROM "LiteLLM_VerificationToken" v LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id; - """ - ) + """) print("LiteLLM_VerificationTokenView Created!") # noqa diff --git a/litellm/proxy/db/db_spend_update_writer.py b/litellm/proxy/db/db_spend_update_writer.py index a305d5be1e6..fc4e35e41d4 100644 --- a/litellm/proxy/db/db_spend_update_writer.py +++ b/litellm/proxy/db/db_spend_update_writer.py @@ -1615,14 +1615,14 @@ async def _update_daily_spend( # Add cache-related fields if they exist if "cache_read_input_tokens" in transaction: - common_data[ - "cache_read_input_tokens" - ] = transaction.get("cache_read_input_tokens", 0) + common_data["cache_read_input_tokens"] = ( + transaction.get("cache_read_input_tokens", 0) + ) if "cache_creation_input_tokens" in transaction: - common_data[ - "cache_creation_input_tokens" - ] = transaction.get( - "cache_creation_input_tokens", 0 + common_data["cache_creation_input_tokens"] = ( + transaction.get( + "cache_creation_input_tokens", 0 + ) ) if entity_type == "tag" and "request_id" in transaction: diff --git a/litellm/proxy/db/db_transaction_queue/base_update_queue.py b/litellm/proxy/db/db_transaction_queue/base_update_queue.py index e37200c02e9..7f1a7474690 100644 --- a/litellm/proxy/db/db_transaction_queue/base_update_queue.py +++ b/litellm/proxy/db/db_transaction_queue/base_update_queue.py @@ -1,6 +1,7 @@ """ Base class for in memory buffer for database transactions """ + import asyncio from typing import Optional diff --git a/litellm/proxy/db/db_transaction_queue/daily_spend_update_queue.py b/litellm/proxy/db/db_transaction_queue/daily_spend_update_queue.py index 75e9b9580b6..f47b694d44e 100644 --- a/litellm/proxy/db/db_transaction_queue/daily_spend_update_queue.py +++ b/litellm/proxy/db/db_transaction_queue/daily_spend_update_queue.py @@ -54,9 +54,9 @@ class DailySpendUpdateQueue(BaseUpdateQueue): def __init__(self): super().__init__() - self.update_queue: asyncio.Queue[ - Dict[str, BaseDailySpendTransaction] - ] = asyncio.Queue(maxsize=LITELLM_ASYNCIO_QUEUE_MAXSIZE) + self.update_queue: asyncio.Queue[Dict[str, BaseDailySpendTransaction]] = ( + asyncio.Queue(maxsize=LITELLM_ASYNCIO_QUEUE_MAXSIZE) + ) async def add_update(self, update: Dict[str, BaseDailySpendTransaction]): """Enqueue an update.""" @@ -73,9 +73,9 @@ async def aggregate_queue_updates(self): Combine all updates in the queue into a single update. This is used to reduce the size of the in-memory queue. """ - updates: List[ - Dict[str, BaseDailySpendTransaction] - ] = await self.flush_all_updates_from_in_memory_queue() + updates: List[Dict[str, BaseDailySpendTransaction]] = ( + await self.flush_all_updates_from_in_memory_queue() + ) aggregated_updates = self.get_aggregated_daily_spend_update_transactions( updates ) diff --git a/litellm/proxy/db/db_transaction_queue/redis_update_buffer.py b/litellm/proxy/db/db_transaction_queue/redis_update_buffer.py index c51c06df2f3..ad59680e1fc 100644 --- a/litellm/proxy/db/db_transaction_queue/redis_update_buffer.py +++ b/litellm/proxy/db/db_transaction_queue/redis_update_buffer.py @@ -71,9 +71,9 @@ def _should_commit_spend_updates_to_redis() -> bool: """ from litellm.proxy.proxy_server import general_settings - _use_redis_transaction_buffer: Optional[ - Union[bool, str] - ] = general_settings.get("use_redis_transaction_buffer", False) + _use_redis_transaction_buffer: Optional[Union[bool, str]] = ( + general_settings.get("use_redis_transaction_buffer", False) + ) if isinstance(_use_redis_transaction_buffer, str): _use_redis_transaction_buffer = str_to_bool(_use_redis_transaction_buffer) if _use_redis_transaction_buffer is None: diff --git a/litellm/proxy/db/db_transaction_queue/spend_update_queue.py b/litellm/proxy/db/db_transaction_queue/spend_update_queue.py index 727e8dc1d5a..8100a1e8a12 100644 --- a/litellm/proxy/db/db_transaction_queue/spend_update_queue.py +++ b/litellm/proxy/db/db_transaction_queue/spend_update_queue.py @@ -53,9 +53,9 @@ async def add_update(self, update: SpendUpdateQueueItem): async def aggregate_queue_updates(self): """Concatenate all updates in the queue to reduce the size of in-memory queue""" - updates: List[ - SpendUpdateQueueItem - ] = await self.flush_all_updates_from_in_memory_queue() + updates: List[SpendUpdateQueueItem] = ( + await self.flush_all_updates_from_in_memory_queue() + ) aggregated_updates = self._get_aggregated_spend_update_queue_item(updates) for update in aggregated_updates: await self.update_queue.put(update) diff --git a/litellm/proxy/fine_tuning_endpoints/endpoints.py b/litellm/proxy/fine_tuning_endpoints/endpoints.py index ff6300a4fa0..17a7c09321a 100644 --- a/litellm/proxy/fine_tuning_endpoints/endpoints.py +++ b/litellm/proxy/fine_tuning_endpoints/endpoints.py @@ -306,9 +306,9 @@ async def retrieve_fine_tuning_job( **data, ), ) - response._hidden_params[ - "unified_finetuning_job_id" - ] = unified_finetuning_job_id + response._hidden_params["unified_finetuning_job_id"] = ( + unified_finetuning_job_id + ) elif custom_llm_provider: # get configs for custom_llm_provider llm_provider_config = get_fine_tuning_provider_config( @@ -595,9 +595,9 @@ async def cancel_fine_tuning_job( **data, ), ) - response._hidden_params[ - "unified_finetuning_job_id" - ] = unified_finetuning_job_id + response._hidden_params["unified_finetuning_job_id"] = ( + unified_finetuning_job_id + ) else: # get configs for custom_llm_provider llm_provider_config = get_fine_tuning_provider_config( diff --git a/litellm/proxy/guardrails/guardrail_endpoints.py b/litellm/proxy/guardrails/guardrail_endpoints.py index 2b20876ba22..179da579e6f 100644 --- a/litellm/proxy/guardrails/guardrail_endpoints.py +++ b/litellm/proxy/guardrails/guardrail_endpoints.py @@ -566,7 +566,9 @@ class GuardrailSubmissionItem(BaseModel): guardrail_name: str status: str # pending_review | active | rejected team_id: Optional[str] = None - team_guardrail: bool = False # True when submitted via team (team_id set); use to distinguish team vs regular guardrails + team_guardrail: bool = ( + False # True when submitted via team (team_id set); use to distinguish team vs regular guardrails + ) litellm_params: Optional[Dict[str, Any]] = None guardrail_info: Optional[Dict[str, Any]] = None submitted_by_user_id: Optional[str] = None @@ -661,9 +663,9 @@ async def register_guardrail( guardrail_info = dict(request.guardrail_info or {}) guardrail_info["submitted_by_user_id"] = user_api_key_dict.user_id guardrail_info["submitted_by_email"] = user_api_key_dict.user_email - guardrail_info[ - "team_guardrail" - ] = True # Mark as team submission for filtering/display + guardrail_info["team_guardrail"] = ( + True # Mark as team submission for filtering/display + ) guardrail_info_str = safe_dumps(guardrail_info) try: @@ -1806,9 +1808,9 @@ async def get_provider_specific_params(): lakera_v2_fields = _get_fields_from_model(LakeraV2GuardrailConfigModel) tool_permission_fields = _get_fields_from_model(ToolPermissionGuardrailConfigModel) - tool_permission_fields[ - "ui_friendly_name" - ] = ToolPermissionGuardrailConfigModel.ui_friendly_name() + tool_permission_fields["ui_friendly_name"] = ( + ToolPermissionGuardrailConfigModel.ui_friendly_name() + ) # Return the provider-specific parameters provider_params = { @@ -2081,10 +2083,10 @@ async def apply_guardrail( from litellm.proxy.utils import handle_exception_on_proxy try: - active_guardrail: Optional[ - CustomGuardrail - ] = GUARDRAIL_REGISTRY.get_initialized_guardrail_callback( - guardrail_name=request.guardrail_name + active_guardrail: Optional[CustomGuardrail] = ( + GUARDRAIL_REGISTRY.get_initialized_guardrail_callback( + guardrail_name=request.guardrail_name + ) ) if active_guardrail is None: raise HTTPException( diff --git a/litellm/proxy/guardrails/guardrail_hooks/akto/__init__.py b/litellm/proxy/guardrails/guardrail_hooks/akto/__init__.py index c4aaea709ba..1e3dd906b9f 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/akto/__init__.py +++ b/litellm/proxy/guardrails/guardrail_hooks/akto/__init__.py @@ -4,7 +4,6 @@ from .akto import AktoGuardrail - if TYPE_CHECKING: from litellm.types.guardrails import Guardrail, LitellmParams diff --git a/litellm/proxy/guardrails/guardrail_hooks/akto/akto.py b/litellm/proxy/guardrails/guardrail_hooks/akto/akto.py index 5058ee348db..ece311666c5 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/akto/akto.py +++ b/litellm/proxy/guardrails/guardrail_hooks/akto/akto.py @@ -91,9 +91,9 @@ def __init__( "akto_api_key is required. Set AKTO_API_KEY or pass it in litellm_params." ) - self.unreachable_fallback: Literal[ - "fail_closed", "fail_open" - ] = unreachable_fallback + self.unreachable_fallback: Literal["fail_closed", "fail_open"] = ( + unreachable_fallback + ) self.guardrail_timeout = guardrail_timeout or DEFAULT_GUARDRAIL_TIMEOUT self.akto_account_id = akto_account_id or os.environ.get( "AKTO_ACCOUNT_ID", "1000000" diff --git a/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py b/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py index 8ef188bb23c..2772446a2b3 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py +++ b/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py @@ -796,9 +796,9 @@ async def async_pre_call_hook( ######################################################### ########## 1. Make the Bedrock API request ########## ######################################################### - bedrock_guardrail_response: Optional[ - Union[BedrockGuardrailResponse, str] - ] = None + bedrock_guardrail_response: Optional[Union[BedrockGuardrailResponse, str]] = ( + None + ) try: bedrock_guardrail_response = await self.make_bedrock_api_request( source="INPUT", messages=filtered_messages, request_data=data @@ -868,9 +868,9 @@ async def async_moderation_hook( ######################################################### ########## 1. Make the Bedrock API request ########## ######################################################### - bedrock_guardrail_response: Optional[ - Union[BedrockGuardrailResponse, str] - ] = None + bedrock_guardrail_response: Optional[Union[BedrockGuardrailResponse, str]] = ( + None + ) try: bedrock_guardrail_response = await self.make_bedrock_api_request( source="INPUT", messages=filtered_messages, request_data=data diff --git a/litellm/proxy/guardrails/guardrail_hooks/block_code_execution/block_code_execution.py b/litellm/proxy/guardrails/guardrail_hooks/block_code_execution/block_code_execution.py index efd781681a8..49c3dc00cd3 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/block_code_execution/block_code_execution.py +++ b/litellm/proxy/guardrails/guardrail_hooks/block_code_execution/block_code_execution.py @@ -347,9 +347,9 @@ def __init__( **kwargs: Any, ) -> None: # Normalize to type expected by CustomGuardrail - _event_hook: Optional[ - Union[GuardrailEventHooks, List[GuardrailEventHooks]] - ] = None + _event_hook: Optional[Union[GuardrailEventHooks, List[GuardrailEventHooks]]] = ( + None + ) if event_hook is not None: if isinstance(event_hook, list): _event_hook = [ diff --git a/litellm/proxy/guardrails/guardrail_hooks/generic_guardrail_api/generic_guardrail_api.py b/litellm/proxy/guardrails/guardrail_hooks/generic_guardrail_api/generic_guardrail_api.py index 18720845085..790ee31f2e0 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/generic_guardrail_api/generic_guardrail_api.py +++ b/litellm/proxy/guardrails/guardrail_hooks/generic_guardrail_api/generic_guardrail_api.py @@ -219,9 +219,9 @@ def __init__( additional_provider_specific_params or {} ) - self.unreachable_fallback: Literal[ - "fail_closed", "fail_open" - ] = unreachable_fallback + self.unreachable_fallback: Literal["fail_closed", "fail_open"] = ( + unreachable_fallback + ) # Set supported event hooks if "supported_event_hooks" not in kwargs: diff --git a/litellm/proxy/guardrails/guardrail_hooks/lakera_ai.py b/litellm/proxy/guardrails/guardrail_hooks/lakera_ai.py index 28f0d830f12..ff802223f21 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/lakera_ai.py +++ b/litellm/proxy/guardrails/guardrail_hooks/lakera_ai.py @@ -150,9 +150,9 @@ async def _check( # noqa: PLR0915 text = "" _json_data: str = "" if "messages" in data and isinstance(data["messages"], list): - prompt_injection_obj: Optional[ - GuardrailItem - ] = litellm.guardrail_name_config_map.get("prompt_injection") + prompt_injection_obj: Optional[GuardrailItem] = ( + litellm.guardrail_name_config_map.get("prompt_injection") + ) if prompt_injection_obj is not None: enabled_roles = prompt_injection_obj.enabled_roles else: diff --git a/litellm/proxy/guardrails/guardrail_hooks/lakera_ai_v2.py b/litellm/proxy/guardrails/guardrail_hooks/lakera_ai_v2.py index 6b917bc794c..7aa26435ba6 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/lakera_ai_v2.py +++ b/litellm/proxy/guardrails/guardrail_hooks/lakera_ai_v2.py @@ -393,9 +393,9 @@ async def async_post_call_success_hook( for idx, msg in enumerate(assistant_messages): if idx < len(choice_indices): choice_idx = choice_indices[idx] - response_dict["choices"][choice_idx]["message"][ - "content" - ] = msg.get("content", "") + response_dict["choices"][choice_idx]["message"]["content"] = ( + msg.get("content", "") + ) add_guardrail_to_applied_guardrails_header( request_data=data, guardrail_name=self.guardrail_name ) diff --git a/litellm/proxy/guardrails/guardrail_hooks/litellm_content_filter/competitor_intent/airline.py b/litellm/proxy/guardrails/guardrail_hooks/litellm_content_filter/competitor_intent/airline.py index 90b45262c4c..9ab5b9c1d5b 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/litellm_content_filter/competitor_intent/airline.py +++ b/litellm/proxy/guardrails/guardrail_hooks/litellm_content_filter/competitor_intent/airline.py @@ -159,9 +159,9 @@ def __init__(self, config: Dict[str, Any]) -> None: if not merged.get("explicit_competitor_marker"): merged["explicit_competitor_marker"] = AIRLINE_EXPLICIT_COMPETITOR_MARKER if not merged.get("explicit_other_meaning_marker"): - merged[ - "explicit_other_meaning_marker" - ] = AIRLINE_EXPLICIT_OTHER_MEANING_MARKER + merged["explicit_other_meaning_marker"] = ( + AIRLINE_EXPLICIT_OTHER_MEANING_MARKER + ) if not merged.get("domain_words"): merged["domain_words"] = ["airline", "airlines", "carrier"] if not merged.get("competitors"): diff --git a/litellm/proxy/guardrails/guardrail_hooks/litellm_content_filter/content_filter.py b/litellm/proxy/guardrails/guardrail_hooks/litellm_content_filter/content_filter.py index e4da1c1ae77..10b17a09ef9 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/litellm_content_filter/content_filter.py +++ b/litellm/proxy/guardrails/guardrail_hooks/litellm_content_filter/content_filter.py @@ -212,17 +212,17 @@ def __init__( self.image_model = image_model # Store loaded categories self.loaded_categories: Dict[str, CategoryConfig] = {} - self.category_keywords: Dict[ - str, Tuple[str, str, ContentFilterAction] - ] = {} # keyword -> (category, severity, action) + self.category_keywords: Dict[str, Tuple[str, str, ContentFilterAction]] = ( + {} + ) # keyword -> (category, severity, action) # Always-block keywords are checked after exceptions (exceptions take precedence) self.always_block_category_keywords: Dict[ str, Tuple[str, str, ContentFilterAction] ] = {} # Store conditional categories (identifier_words + block_words) - self.conditional_categories: Dict[ - str, Dict[str, Any] - ] = {} # category_name -> {identifier_words, block_words, action, severity} + self.conditional_categories: Dict[str, Dict[str, Any]] = ( + {} + ) # category_name -> {identifier_words, block_words, action, severity} # Competitor intent checker (optional; airline uses major_airlines.json, generic requires competitors) self._competitor_intent_checker: Optional[BaseCompetitorIntentChecker] = None diff --git a/litellm/proxy/guardrails/guardrail_hooks/model_armor/model_armor.py b/litellm/proxy/guardrails/guardrail_hooks/model_armor/model_armor.py index 2d3f048f81b..dcb7f2404fc 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/model_armor/model_armor.py +++ b/litellm/proxy/guardrails/guardrail_hooks/model_armor/model_armor.py @@ -295,9 +295,7 @@ def _get_sanitized_content(self, armor_response: dict) -> Optional[str]: filters = ( list(filter_results.values()) if isinstance(filter_results, dict) - else filter_results - if isinstance(filter_results, list) - else [] + else filter_results if isinstance(filter_results, list) else [] ) # Prefer sanitized text from deidentifyResult if present diff --git a/litellm/proxy/guardrails/guardrail_hooks/panw_prisma_airs/panw_prisma_airs.py b/litellm/proxy/guardrails/guardrail_hooks/panw_prisma_airs/panw_prisma_airs.py index 2545693b937..bbffc70ddbf 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/panw_prisma_airs/panw_prisma_airs.py +++ b/litellm/proxy/guardrails/guardrail_hooks/panw_prisma_airs/panw_prisma_airs.py @@ -295,10 +295,10 @@ async def _call_panw_api( # noqa: PLR0915 panw_metadata = { "app_user": ( - metadata.get("app_user") or metadata.get("user") or "litellm_user" - ) - if metadata - else "litellm_user", + (metadata.get("app_user") or metadata.get("user") or "litellm_user") + if metadata + else "litellm_user" + ), "ai_model": metadata.get("model", "unknown") if metadata else "unknown", "app_name": app_name_value, "source": "litellm_builtin_guardrail", @@ -1088,9 +1088,11 @@ async def async_pre_call_hook( guardrail_provider=self._PROVIDER_NAME, guardrail_json_response=scan_result, request_data=data, - guardrail_status="success" - if scan_result.get("action") == "allow" - else "guardrail_intervened", + guardrail_status=( + "success" + if scan_result.get("action") == "allow" + else "guardrail_intervened" + ), start_time=start_time.timestamp(), end_time=end_time.timestamp(), duration=(end_time - start_time).total_seconds(), @@ -1226,9 +1228,11 @@ async def async_post_call_success_hook( guardrail_provider=self._PROVIDER_NAME, guardrail_json_response=scan_result, request_data=data, - guardrail_status="success" - if scan_result.get("action") == "allow" - else "guardrail_intervened", + guardrail_status=( + "success" + if scan_result.get("action") == "allow" + else "guardrail_intervened" + ), start_time=start_time.timestamp(), end_time=end_time.timestamp(), duration=(end_time - start_time).total_seconds(), @@ -1449,9 +1453,11 @@ async def async_post_call_streaming_iterator_hook( guardrail_provider=self._PROVIDER_NAME, guardrail_json_response=scan_result, request_data=request_data, - guardrail_status="success" - if scan_result.get("action") == "allow" - else "guardrail_intervened", + guardrail_status=( + "success" + if scan_result.get("action") == "allow" + else "guardrail_intervened" + ), start_time=start_time.timestamp(), end_time=end_time.timestamp(), duration=(end_time - start_time).total_seconds(), diff --git a/litellm/proxy/guardrails/guardrail_hooks/presidio.py b/litellm/proxy/guardrails/guardrail_hooks/presidio.py index 0f4ebbd4880..41348597e6f 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/presidio.py +++ b/litellm/proxy/guardrails/guardrail_hooks/presidio.py @@ -729,9 +729,9 @@ async def async_pre_call_hook( if messages is None: return data tasks = [] - task_mappings: List[ - Tuple[int, Optional[int]] - ] = [] # Track (message_index, content_index) for each task + task_mappings: List[Tuple[int, Optional[int]]] = ( + [] + ) # Track (message_index, content_index) for each task for msg_idx, m in enumerate(messages): content = m.get("content", None) @@ -832,9 +832,9 @@ async def async_logging_hook( ): # /chat/completions requests messages: Optional[List] = kwargs.get("messages", None) tasks = [] - task_mappings: List[ - Tuple[int, Optional[int]] - ] = [] # Track (message_index, content_index) for each task + task_mappings: List[Tuple[int, Optional[int]]] = ( + [] + ) # Track (message_index, content_index) for each task if messages is None: return kwargs, result diff --git a/litellm/proxy/guardrails/tool_name_extraction.py b/litellm/proxy/guardrails/tool_name_extraction.py index c554c4fc9ac..fb1c0d72ee7 100644 --- a/litellm/proxy/guardrails/tool_name_extraction.py +++ b/litellm/proxy/guardrails/tool_name_extraction.py @@ -40,12 +40,12 @@ def _extract_mcp_tool_names(data: dict) -> List[str]: def _register_standalone_extractors() -> None: if STANDALONE_EXTRACTORS: return - STANDALONE_EXTRACTORS[ - CallTypes.generate_content.value - ] = _extract_generate_content_tool_names - STANDALONE_EXTRACTORS[ - CallTypes.agenerate_content.value - ] = _extract_generate_content_tool_names + STANDALONE_EXTRACTORS[CallTypes.generate_content.value] = ( + _extract_generate_content_tool_names + ) + STANDALONE_EXTRACTORS[CallTypes.agenerate_content.value] = ( + _extract_generate_content_tool_names + ) STANDALONE_EXTRACTORS[CallTypes.call_mcp_tool.value] = _extract_mcp_tool_names diff --git a/litellm/proxy/hooks/batch_rate_limiter.py b/litellm/proxy/hooks/batch_rate_limiter.py index ba8a4672cae..06b7d857896 100644 --- a/litellm/proxy/hooks/batch_rate_limiter.py +++ b/litellm/proxy/hooks/batch_rate_limiter.py @@ -194,9 +194,7 @@ async def _check_and_increment_batch_counters( required_capacity = ( batch_usage.request_count if rate_limit_type == "requests" - else batch_usage.total_tokens - if rate_limit_type == "tokens" - else 0 + else batch_usage.total_tokens if rate_limit_type == "tokens" else 0 ) if required_capacity > limit_remaining: diff --git a/litellm/proxy/hooks/dynamic_rate_limiter.py b/litellm/proxy/hooks/dynamic_rate_limiter.py index 14fde51210d..f1c1d487cc1 100644 --- a/litellm/proxy/hooks/dynamic_rate_limiter.py +++ b/litellm/proxy/hooks/dynamic_rate_limiter.py @@ -103,9 +103,9 @@ async def check_available_usage( """ try: # Get model info first for conversion - model_group_info: Optional[ - ModelGroupInfo - ] = self.llm_router.get_model_group_info(model_group=model) + model_group_info: Optional[ModelGroupInfo] = ( + self.llm_router.get_model_group_info(model_group=model) + ) weight: float = 1 if ( @@ -277,16 +277,16 @@ async def async_post_call_success_hook( ) = await self.check_available_usage( model=model_info["model_name"], priority=key_priority ) - response._hidden_params[ - "additional_headers" - ] = { # Add additional response headers - easier debugging - "x-litellm-model_group": model_info["model_name"], - "x-ratelimit-remaining-litellm-project-tokens": available_tpm, - "x-ratelimit-remaining-litellm-project-requests": available_rpm, - "x-ratelimit-remaining-model-tokens": model_tpm, - "x-ratelimit-remaining-model-requests": model_rpm, - "x-ratelimit-current-active-projects": active_projects, - } + response._hidden_params["additional_headers"] = ( + { # Add additional response headers - easier debugging + "x-litellm-model_group": model_info["model_name"], + "x-ratelimit-remaining-litellm-project-tokens": available_tpm, + "x-ratelimit-remaining-litellm-project-requests": available_rpm, + "x-ratelimit-remaining-model-tokens": model_tpm, + "x-ratelimit-remaining-model-requests": model_rpm, + "x-ratelimit-current-active-projects": active_projects, + } + ) return response return await super().async_post_call_success_hook( diff --git a/litellm/proxy/hooks/dynamic_rate_limiter_v3.py b/litellm/proxy/hooks/dynamic_rate_limiter_v3.py index 5a1d6bec5d5..72483d29cdc 100644 --- a/litellm/proxy/hooks/dynamic_rate_limiter_v3.py +++ b/litellm/proxy/hooks/dynamic_rate_limiter_v3.py @@ -322,9 +322,9 @@ def _create_priority_based_descriptors( return descriptors # Get model group info - model_group_info: Optional[ - ModelGroupInfo - ] = self.llm_router.get_model_group_info(model_group=model) + model_group_info: Optional[ModelGroupInfo] = ( + self.llm_router.get_model_group_info(model_group=model) + ) if model_group_info is None: return descriptors @@ -597,9 +597,9 @@ async def async_pre_call_hook( ) # Get model configuration - model_group_info: Optional[ - ModelGroupInfo - ] = self.llm_router.get_model_group_info(model_group=model) + model_group_info: Optional[ModelGroupInfo] = ( + self.llm_router.get_model_group_info(model_group=model) + ) if model_group_info is None: verbose_proxy_logger.debug( f"No model group info for {model}, allowing request" diff --git a/litellm/proxy/hooks/key_management_event_hooks.py b/litellm/proxy/hooks/key_management_event_hooks.py index 2d61203ad51..5cdd9ddb4bd 100644 --- a/litellm/proxy/hooks/key_management_event_hooks.py +++ b/litellm/proxy/hooks/key_management_event_hooks.py @@ -364,10 +364,10 @@ async def _delete_virtual_keys_from_secret_manager( if key.key_alias is not None: team_id = getattr(key, "team_id", None) if team_id not in team_settings_cache: - team_settings_cache[ - team_id - ] = await KeyManagementEventHooks._get_secret_manager_optional_params( - team_id + team_settings_cache[team_id] = ( + await KeyManagementEventHooks._get_secret_manager_optional_params( + team_id + ) ) optional_params = team_settings_cache[team_id] await litellm.secret_manager_client.async_delete_secret( diff --git a/litellm/proxy/hooks/litellm_skills/__init__.py b/litellm/proxy/hooks/litellm_skills/__init__.py index 057cf3d8b38..1507b652ab4 100644 --- a/litellm/proxy/hooks/litellm_skills/__init__.py +++ b/litellm/proxy/hooks/litellm_skills/__init__.py @@ -6,7 +6,7 @@ Usage: from litellm.proxy.hooks.litellm_skills import SkillsInjectionHook - + # Register hook in proxy litellm.callbacks.append(SkillsInjectionHook()) """ diff --git a/litellm/proxy/hooks/litellm_skills/main.py b/litellm/proxy/hooks/litellm_skills/main.py index 83e419bc23c..7c6bfbd6b2b 100644 --- a/litellm/proxy/hooks/litellm_skills/main.py +++ b/litellm/proxy/hooks/litellm_skills/main.py @@ -439,9 +439,11 @@ def _extract_tool_calls(self, response: Any) -> List[Dict[str, Any]]: { "id": tc.id, "name": tc.function.name, - "input": json.loads(tc.function.arguments) - if tc.function.arguments - else {}, + "input": ( + json.loads(tc.function.arguments) + if tc.function.arguments + else {} + ), } ) diff --git a/litellm/proxy/hooks/mcp_semantic_filter/__init__.py b/litellm/proxy/hooks/mcp_semantic_filter/__init__.py index 36d357d560f..c9ef11c8b0a 100644 --- a/litellm/proxy/hooks/mcp_semantic_filter/__init__.py +++ b/litellm/proxy/hooks/mcp_semantic_filter/__init__.py @@ -4,6 +4,7 @@ Semantic filtering for MCP tools to reduce context window size and improve tool selection accuracy. """ + from litellm.proxy.hooks.mcp_semantic_filter.hook import SemanticToolFilterHook __all__ = ["SemanticToolFilterHook"] diff --git a/litellm/proxy/hooks/mcp_semantic_filter/hook.py b/litellm/proxy/hooks/mcp_semantic_filter/hook.py index 4075641d63b..6343faaa965 100644 --- a/litellm/proxy/hooks/mcp_semantic_filter/hook.py +++ b/litellm/proxy/hooks/mcp_semantic_filter/hook.py @@ -4,6 +4,7 @@ Pre-call hook that filters MCP tools semantically before LLM inference. Reduces context window size and improves tool selection accuracy. """ + from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from litellm._logging import verbose_proxy_logger diff --git a/litellm/proxy/hooks/parallel_request_limiter.py b/litellm/proxy/hooks/parallel_request_limiter.py index fefc6c8af9c..59978405ffa 100644 --- a/litellm/proxy/hooks/parallel_request_limiter.py +++ b/litellm/proxy/hooks/parallel_request_limiter.py @@ -202,9 +202,7 @@ async def async_pre_call_hook( # noqa: PLR0915 if rpm_limit is None: rpm_limit = sys.maxsize - values_to_update_in_cache: List[ - Tuple[Any, Any] - ] = ( + values_to_update_in_cache: List[Tuple[Any, Any]] = ( [] ) # values that need to get updated in cache, will run a batch_set_cache after this function @@ -703,9 +701,9 @@ async def async_log_success_event( # noqa: PLR0915 async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): try: self.print_verbose("Inside Max Parallel Request Failure Hook") - litellm_parent_otel_span: Union[ - Span, None - ] = _get_parent_otel_span_from_kwargs(kwargs=kwargs) + litellm_parent_otel_span: Union[Span, None] = ( + _get_parent_otel_span_from_kwargs(kwargs=kwargs) + ) _metadata = kwargs["litellm_params"].get("metadata", {}) or {} global_max_parallel_requests = _metadata.get( "global_max_parallel_requests", None @@ -832,11 +830,11 @@ async def async_post_call_success_hook( current_minute = datetime.now().strftime("%M") precise_minute = f"{current_date}-{current_hour}-{current_minute}" request_count_api_key = f"{api_key}::{precise_minute}::request_count" - current: Optional[ - CurrentItemRateLimit - ] = await self.internal_usage_cache.async_get_cache( - key=request_count_api_key, - litellm_parent_otel_span=user_api_key_dict.parent_otel_span, + current: Optional[CurrentItemRateLimit] = ( + await self.internal_usage_cache.async_get_cache( + key=request_count_api_key, + litellm_parent_otel_span=user_api_key_dict.parent_otel_span, + ) ) key_remaining_rpm_limit: Optional[int] = None @@ -868,15 +866,15 @@ async def async_post_call_success_hook( _additional_headers = _hidden_params.get("additional_headers", {}) or {} if key_remaining_rpm_limit is not None: - _additional_headers[ - "x-ratelimit-remaining-requests" - ] = key_remaining_rpm_limit + _additional_headers["x-ratelimit-remaining-requests"] = ( + key_remaining_rpm_limit + ) if key_rpm_limit is not None: _additional_headers["x-ratelimit-limit-requests"] = key_rpm_limit if key_remaining_tpm_limit is not None: - _additional_headers[ - "x-ratelimit-remaining-tokens" - ] = key_remaining_tpm_limit + _additional_headers["x-ratelimit-remaining-tokens"] = ( + key_remaining_tpm_limit + ) if key_tpm_limit is not None: _additional_headers["x-ratelimit-limit-tokens"] = key_tpm_limit diff --git a/litellm/proxy/hooks/parallel_request_limiter_v3.py b/litellm/proxy/hooks/parallel_request_limiter_v3.py index 5aaac088dc2..bbc268a0c4f 100644 --- a/litellm/proxy/hooks/parallel_request_limiter_v3.py +++ b/litellm/proxy/hooks/parallel_request_limiter_v3.py @@ -1682,9 +1682,9 @@ async def async_log_failure_event(self, kwargs, response_obj, start_time, end_ti from litellm.types.caching import RedisPipelineIncrementOperation try: - litellm_parent_otel_span: Union[ - Span, None - ] = _get_parent_otel_span_from_kwargs(kwargs) + litellm_parent_otel_span: Union[Span, None] = ( + _get_parent_otel_span_from_kwargs(kwargs) + ) # Get metadata from standard_logging_object - this correctly handles both # 'metadata' and 'litellm_metadata' fields from litellm_params standard_logging_object = kwargs.get("standard_logging_object") or {} diff --git a/litellm/proxy/hooks/proxy_track_cost_callback.py b/litellm/proxy/hooks/proxy_track_cost_callback.py index 43cfd930193..1d7631eb27c 100644 --- a/litellm/proxy/hooks/proxy_track_cost_callback.py +++ b/litellm/proxy/hooks/proxy_track_cost_callback.py @@ -74,11 +74,11 @@ async def async_post_call_failure_hook( ) _metadata["user_api_key"] = user_api_key_dict.api_key _metadata["status"] = "failure" - _metadata[ - "error_information" - ] = StandardLoggingPayloadSetup.get_error_information( - original_exception=original_exception, - traceback_str=traceback_str, + _metadata["error_information"] = ( + StandardLoggingPayloadSetup.get_error_information( + original_exception=original_exception, + traceback_str=traceback_str, + ) ) _metadata = await _ProxyDBLogger._enrich_failure_metadata_with_key_info( diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index daf2867699e..5e08a484cfe 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -196,12 +196,12 @@ def _get_dynamic_logging_metadata( user_api_key_dict: UserAPIKeyAuth, proxy_config: ProxyConfig ) -> Optional[TeamCallbackMetadata]: callback_settings_obj: Optional[TeamCallbackMetadata] = None - key_dynamic_logging_settings: Optional[ - dict - ] = KeyAndTeamLoggingSettings.get_key_dynamic_logging_settings(user_api_key_dict) - team_dynamic_logging_settings: Optional[ - dict - ] = KeyAndTeamLoggingSettings.get_team_dynamic_logging_settings(user_api_key_dict) + key_dynamic_logging_settings: Optional[dict] = ( + KeyAndTeamLoggingSettings.get_key_dynamic_logging_settings(user_api_key_dict) + ) + team_dynamic_logging_settings: Optional[dict] = ( + KeyAndTeamLoggingSettings.get_team_dynamic_logging_settings(user_api_key_dict) + ) ######################################################################################### # Key-based callbacks ######################################################################################### @@ -753,11 +753,11 @@ def add_key_level_controls( ## KEY-LEVEL SPEND LOGS / TAGS if "tags" in key_metadata and key_metadata["tags"] is not None: - data[_metadata_variable_name][ - "tags" - ] = LiteLLMProxyRequestSetup._merge_tags( - request_tags=data[_metadata_variable_name].get("tags"), - tags_to_add=key_metadata["tags"], + data[_metadata_variable_name]["tags"] = ( + LiteLLMProxyRequestSetup._merge_tags( + request_tags=data[_metadata_variable_name].get("tags"), + tags_to_add=key_metadata["tags"], + ) ) if "disable_global_guardrails" in key_metadata and isinstance( key_metadata["disable_global_guardrails"], bool @@ -1053,9 +1053,9 @@ async def add_litellm_data_to_request( # noqa: PLR0915 data[_metadata_variable_name]["litellm_api_version"] = version if general_settings is not None: - data[_metadata_variable_name][ - "global_max_parallel_requests" - ] = general_settings.get("global_max_parallel_requests", None) + data[_metadata_variable_name]["global_max_parallel_requests"] = ( + general_settings.get("global_max_parallel_requests", None) + ) ### KEY-LEVEL Controls key_metadata = user_api_key_dict.metadata diff --git a/litellm/proxy/management_endpoints/budget_management_endpoints.py b/litellm/proxy/management_endpoints/budget_management_endpoints.py index 20c7f9ec412..fbbf417318e 100644 --- a/litellm/proxy/management_endpoints/budget_management_endpoints.py +++ b/litellm/proxy/management_endpoints/budget_management_endpoints.py @@ -1,9 +1,9 @@ """ BUDGET MANAGEMENT -All /budget management endpoints +All /budget management endpoints -/budget/new +/budget/new /budget/info /budget/update /budget/delete diff --git a/litellm/proxy/management_endpoints/callback_management_endpoints.py b/litellm/proxy/management_endpoints/callback_management_endpoints.py index 9132d3fe1d7..f9781f3634c 100644 --- a/litellm/proxy/management_endpoints/callback_management_endpoints.py +++ b/litellm/proxy/management_endpoints/callback_management_endpoints.py @@ -1,6 +1,7 @@ """ Endpoints for managing callbacks """ + import json import os diff --git a/litellm/proxy/management_endpoints/common_daily_activity.py b/litellm/proxy/management_endpoints/common_daily_activity.py index 011d2f7485d..ac66adc26f3 100644 --- a/litellm/proxy/management_endpoints/common_daily_activity.py +++ b/litellm/proxy/management_endpoints/common_daily_activity.py @@ -105,24 +105,26 @@ def update_breakdown_metrics( # Update API key breakdown for this model if record.api_key not in breakdown.models[record.model].api_key_breakdown: - breakdown.models[record.model].api_key_breakdown[ - record.api_key - ] = KeyMetricWithMetadata( - metrics=SpendMetrics(), - metadata=KeyMetadata( - key_alias=api_key_metadata.get(record.api_key, {}).get( - "key_alias", None + breakdown.models[record.model].api_key_breakdown[record.api_key] = ( + KeyMetricWithMetadata( + metrics=SpendMetrics(), + metadata=KeyMetadata( + key_alias=api_key_metadata.get(record.api_key, {}).get( + "key_alias", None + ), + team_id=api_key_metadata.get(record.api_key, {}).get( + "team_id", None + ), ), - team_id=api_key_metadata.get(record.api_key, {}).get( - "team_id", None - ), - ), + ) + ) + breakdown.models[record.model].api_key_breakdown[record.api_key].metrics = ( + update_metrics( + breakdown.models[record.model] + .api_key_breakdown[record.api_key] + .metrics, + record, ) - breakdown.models[record.model].api_key_breakdown[ - record.api_key - ].metrics = update_metrics( - breakdown.models[record.model].api_key_breakdown[record.api_key].metrics, - record, ) # Update model group breakdown @@ -218,22 +220,24 @@ def update_breakdown_metrics( # Update API key breakdown for this provider if record.api_key not in breakdown.providers[provider].api_key_breakdown: - breakdown.providers[provider].api_key_breakdown[ - record.api_key - ] = KeyMetricWithMetadata( - metrics=SpendMetrics(), - metadata=KeyMetadata( - key_alias=api_key_metadata.get(record.api_key, {}).get( - "key_alias", None + breakdown.providers[provider].api_key_breakdown[record.api_key] = ( + KeyMetricWithMetadata( + metrics=SpendMetrics(), + metadata=KeyMetadata( + key_alias=api_key_metadata.get(record.api_key, {}).get( + "key_alias", None + ), + team_id=api_key_metadata.get(record.api_key, {}).get( + "team_id", None + ), ), - team_id=api_key_metadata.get(record.api_key, {}).get("team_id", None), - ), + ) + ) + breakdown.providers[provider].api_key_breakdown[record.api_key].metrics = ( + update_metrics( + breakdown.providers[provider].api_key_breakdown[record.api_key].metrics, + record, ) - breakdown.providers[provider].api_key_breakdown[ - record.api_key - ].metrics = update_metrics( - breakdown.providers[provider].api_key_breakdown[record.api_key].metrics, - record, ) # Update endpoint breakdown @@ -249,18 +253,18 @@ def update_breakdown_metrics( # Update API key breakdown for this endpoint if record.api_key not in breakdown.endpoints[record.endpoint].api_key_breakdown: - breakdown.endpoints[record.endpoint].api_key_breakdown[ - record.api_key - ] = KeyMetricWithMetadata( - metrics=SpendMetrics(), - metadata=KeyMetadata( - key_alias=api_key_metadata.get(record.api_key, {}).get( - "key_alias", None - ), - team_id=api_key_metadata.get(record.api_key, {}).get( - "team_id", None + breakdown.endpoints[record.endpoint].api_key_breakdown[record.api_key] = ( + KeyMetricWithMetadata( + metrics=SpendMetrics(), + metadata=KeyMetadata( + key_alias=api_key_metadata.get(record.api_key, {}).get( + "key_alias", None + ), + team_id=api_key_metadata.get(record.api_key, {}).get( + "team_id", None + ), ), - ), + ) ) breakdown.endpoints[record.endpoint].api_key_breakdown[ record.api_key @@ -307,24 +311,26 @@ def update_breakdown_metrics( # Update API key breakdown for this entity if record.api_key not in breakdown.entities[entity_value].api_key_breakdown: - breakdown.entities[entity_value].api_key_breakdown[ - record.api_key - ] = KeyMetricWithMetadata( - metrics=SpendMetrics(), - metadata=KeyMetadata( - key_alias=api_key_metadata.get(record.api_key, {}).get( - "key_alias", None - ), - team_id=api_key_metadata.get(record.api_key, {}).get( - "team_id", None + breakdown.entities[entity_value].api_key_breakdown[record.api_key] = ( + KeyMetricWithMetadata( + metrics=SpendMetrics(), + metadata=KeyMetadata( + key_alias=api_key_metadata.get(record.api_key, {}).get( + "key_alias", None + ), + team_id=api_key_metadata.get(record.api_key, {}).get( + "team_id", None + ), ), - ), + ) + ) + breakdown.entities[entity_value].api_key_breakdown[record.api_key].metrics = ( + update_metrics( + breakdown.entities[entity_value] + .api_key_breakdown[record.api_key] + .metrics, + record, ) - breakdown.entities[entity_value].api_key_breakdown[ - record.api_key - ].metrics = update_metrics( - breakdown.entities[entity_value].api_key_breakdown[record.api_key].metrics, - record, ) return breakdown diff --git a/litellm/proxy/management_endpoints/cost_tracking_settings.py b/litellm/proxy/management_endpoints/cost_tracking_settings.py index bf24d8924de..b5ae8f93be6 100644 --- a/litellm/proxy/management_endpoints/cost_tracking_settings.py +++ b/litellm/proxy/management_endpoints/cost_tracking_settings.py @@ -69,9 +69,11 @@ def _resolve_model_for_cost_lookup(model: str) -> Tuple[str, Optional[str]]: custom_llm_provider = litellm_params.get("custom_llm_provider") return ( str(base_model), - str(custom_llm_provider) - if custom_llm_provider is not None - else None, + ( + str(custom_llm_provider) + if custom_llm_provider is not None + else None + ), ) resolved_model = litellm_params.get("model") @@ -83,9 +85,11 @@ def _resolve_model_for_cost_lookup(model: str) -> Tuple[str, Optional[str]]: custom_llm_provider = litellm_params.get("custom_llm_provider") return ( str(resolved_model), - str(custom_llm_provider) - if custom_llm_provider is not None - else None, + ( + str(custom_llm_provider) + if custom_llm_provider is not None + else None + ), ) except Exception as e: verbose_proxy_logger.debug( diff --git a/litellm/proxy/management_endpoints/customer_endpoints.py b/litellm/proxy/management_endpoints/customer_endpoints.py index 084c2f47d0f..1fd8320db20 100644 --- a/litellm/proxy/management_endpoints/customer_endpoints.py +++ b/litellm/proxy/management_endpoints/customer_endpoints.py @@ -1,9 +1,9 @@ """ CUSTOMER MANAGEMENT -All /customer management endpoints +All /customer management endpoints -/customer/new +/customer/new /customer/info /customer/update /customer/delete @@ -626,9 +626,9 @@ async def update_end_user( ) ) - update_end_user_table_data[ - "budget_id" - ] = budget_table_data_record.budget_id + update_end_user_table_data["budget_id"] = ( + budget_table_data_record.budget_id + ) else: ## Update existing budget ## budget_table_data_record = ( diff --git a/litellm/proxy/management_endpoints/fallback_management_endpoints.py b/litellm/proxy/management_endpoints/fallback_management_endpoints.py index f91b95acd6c..ffb12111d82 100644 --- a/litellm/proxy/management_endpoints/fallback_management_endpoints.py +++ b/litellm/proxy/management_endpoints/fallback_management_endpoints.py @@ -7,6 +7,7 @@ GET /fallback/{model} - Get fallbacks for a specific model DELETE /fallback/{model} - Delete fallbacks for a specific model """ + # pyright: reportMissingImports=false import json diff --git a/litellm/proxy/management_endpoints/internal_user_endpoints.py b/litellm/proxy/management_endpoints/internal_user_endpoints.py index ca8c345f46c..e694bb0cfed 100644 --- a/litellm/proxy/management_endpoints/internal_user_endpoints.py +++ b/litellm/proxy/management_endpoints/internal_user_endpoints.py @@ -64,9 +64,9 @@ def _update_internal_new_user_params(data_json: dict, data: NewUserRequest) -> d auto_create_key = data_json.pop("auto_create_key", True) if auto_create_key is False: - data_json[ - "table_name" - ] = "user" # only create a user, don't create key if 'auto_create_key' set to False + data_json["table_name"] = ( + "user" # only create a user, don't create key if 'auto_create_key' set to False + ) if litellm.default_internal_user_params and ( data.user_role != LitellmUserRoles.PROXY_ADMIN.value @@ -1035,9 +1035,9 @@ def _update_internal_user_params( "budget_duration" not in non_default_values ): # applies internal user limits, if user role updated if is_internal_user and litellm.internal_user_budget_duration is not None: - non_default_values[ - "budget_duration" - ] = litellm.internal_user_budget_duration + non_default_values["budget_duration"] = ( + litellm.internal_user_budget_duration + ) from litellm.proxy.common_utils.timezone_utils import get_budget_reset_time non_default_values["budget_reset_at"] = get_budget_reset_time( @@ -2297,13 +2297,13 @@ async def ui_view_users( } # Query users with pagination and filters - users: Optional[ - List[BaseModel] - ] = await prisma_client.db.litellm_usertable.find_many( - where=where_conditions, - skip=skip, - take=page_size, - order={"created_at": "desc"}, + users: Optional[List[BaseModel]] = ( + await prisma_client.db.litellm_usertable.find_many( + where=where_conditions, + skip=skip, + take=page_size, + order={"created_at": "desc"}, + ) ) if not users: diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index 831922ec3f9..5a242ace017 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -732,9 +732,9 @@ async def _common_key_generation_helper( # noqa: PLR0915 request_type="key", **data_json, table_name="key" ) - response[ - "soft_budget" - ] = data.soft_budget # include the user-input soft budget in the response + response["soft_budget"] = ( + data.soft_budget + ) # include the user-input soft budget in the response response = GenerateKeyResponse(**response) @@ -3175,10 +3175,10 @@ async def delete_verification_tokens( try: if prisma_client: tokens = [_hash_token_if_needed(token=key) for key in tokens] - _keys_being_deleted: List[ - LiteLLM_VerificationToken - ] = await prisma_client.db.litellm_verificationtoken.find_many( - where={"token": {"in": tokens}} + _keys_being_deleted: List[LiteLLM_VerificationToken] = ( + await prisma_client.db.litellm_verificationtoken.find_many( + where={"token": {"in": tokens}} + ) ) if len(_keys_being_deleted) == 0: @@ -3378,9 +3378,9 @@ async def _rotate_master_key( # noqa: PLR0915 from litellm.proxy.proxy_server import proxy_config try: - models: Optional[ - List - ] = await prisma_client.db.litellm_proxymodeltable.find_many() + models: Optional[List] = ( + await prisma_client.db.litellm_proxymodeltable.find_many() + ) except Exception: models = None # 2. process model table @@ -4020,11 +4020,11 @@ async def validate_key_list_check( param="user_id", code=status.HTTP_403_FORBIDDEN, ) - complete_user_info_db_obj: Optional[ - BaseModel - ] = await prisma_client.db.litellm_usertable.find_unique( - where={"user_id": user_api_key_dict.user_id}, - include={"organization_memberships": True}, + complete_user_info_db_obj: Optional[BaseModel] = ( + await prisma_client.db.litellm_usertable.find_unique( + where={"user_id": user_api_key_dict.user_id}, + include={"organization_memberships": True}, + ) ) if complete_user_info_db_obj is None: @@ -4107,10 +4107,10 @@ async def _fetch_user_team_objects( if complete_user_info is None or not complete_user_info.teams: return [] - teams: Optional[ - List[BaseModel] - ] = await prisma_client.db.litellm_teamtable.find_many( - where={"team_id": {"in": complete_user_info.teams}} + teams: Optional[List[BaseModel]] = ( + await prisma_client.db.litellm_teamtable.find_many( + where={"team_id": {"in": complete_user_info.teams}} + ) ) if teams is None: return [] diff --git a/litellm/proxy/management_endpoints/mcp_management_endpoints.py b/litellm/proxy/management_endpoints/mcp_management_endpoints.py index e4bb288cda9..594f5b1bad9 100644 --- a/litellm/proxy/management_endpoints/mcp_management_endpoints.py +++ b/litellm/proxy/management_endpoints/mcp_management_endpoints.py @@ -424,17 +424,17 @@ def _inherit_credentials_from_existing_server( inherited_credentials["scopes"] = existing_server.scopes # AWS SigV4 fields if existing_server.aws_access_key_id: - inherited_credentials[ - "aws_access_key_id" - ] = existing_server.aws_access_key_id + inherited_credentials["aws_access_key_id"] = ( + existing_server.aws_access_key_id + ) if existing_server.aws_secret_access_key: - inherited_credentials[ - "aws_secret_access_key" - ] = existing_server.aws_secret_access_key + inherited_credentials["aws_secret_access_key"] = ( + existing_server.aws_secret_access_key + ) if existing_server.aws_session_token: - inherited_credentials[ - "aws_session_token" - ] = existing_server.aws_session_token + inherited_credentials["aws_session_token"] = ( + existing_server.aws_session_token + ) if existing_server.aws_region_name: inherited_credentials["aws_region_name"] = existing_server.aws_region_name if existing_server.aws_service_name: diff --git a/litellm/proxy/management_endpoints/organization_endpoints.py b/litellm/proxy/management_endpoints/organization_endpoints.py index edea0c79c96..ffce23e1844 100644 --- a/litellm/proxy/management_endpoints/organization_endpoints.py +++ b/litellm/proxy/management_endpoints/organization_endpoints.py @@ -726,20 +726,20 @@ async def info_organization(organization_id: str): if prisma_client is None: raise HTTPException(status_code=500, detail={"error": "No db connected"}) - response: Optional[ - LiteLLM_OrganizationTableWithMembers - ] = await prisma_client.db.litellm_organizationtable.find_unique( - where={"organization_id": organization_id}, - include={ - "litellm_budget_table": True, - "members": { - "include": { - "user": True, - } + response: Optional[LiteLLM_OrganizationTableWithMembers] = ( + await prisma_client.db.litellm_organizationtable.find_unique( + where={"organization_id": organization_id}, + include={ + "litellm_budget_table": True, + "members": { + "include": { + "user": True, + } + }, + "teams": True, + "object_permission": True, }, - "teams": True, - "object_permission": True, - }, + ) ) if response is None: @@ -1035,16 +1035,16 @@ async def organization_member_update( }, data={"budget_id": budget_id}, ) - final_organization_membership: Optional[ - BaseModel - ] = await prisma_client.db.litellm_organizationmembership.find_unique( - where={ - "user_id_organization_id": { - "user_id": data.user_id, - "organization_id": data.organization_id, - } - }, - include={"litellm_budget_table": True}, + final_organization_membership: Optional[BaseModel] = ( + await prisma_client.db.litellm_organizationmembership.find_unique( + where={ + "user_id_organization_id": { + "user_id": data.user_id, + "organization_id": data.organization_id, + } + }, + include={"litellm_budget_table": True}, + ) ) if final_organization_membership is None: diff --git a/litellm/proxy/management_endpoints/project_endpoints.py b/litellm/proxy/management_endpoints/project_endpoints.py index 8f48f9def78..f6ed7767c46 100644 --- a/litellm/proxy/management_endpoints/project_endpoints.py +++ b/litellm/proxy/management_endpoints/project_endpoints.py @@ -601,9 +601,11 @@ async def update_project( # noqa: PLR0915 user_api_key_dict=user_api_key_dict, team_id=existing_project.team_id, prisma_client=prisma_client, - team_object=LiteLLM_TeamTable(**team_obj_for_checks.model_dump()) - if team_obj_for_checks - else None, + team_object=( + LiteLLM_TeamTable(**team_obj_for_checks.model_dump()) + if team_obj_for_checks + else None + ), ) if not has_permission: @@ -662,9 +664,9 @@ async def update_project( # noqa: PLR0915 data=object_permission_data, ) ) - update_data[ - "object_permission_id" - ] = created_permission.object_permission_id + update_data["object_permission_id"] = ( + created_permission.object_permission_id + ) # Handle metadata fields for field in LiteLLM_ManagementEndpoint_MetadataFields: diff --git a/litellm/proxy/management_endpoints/scim/scim_v2.py b/litellm/proxy/management_endpoints/scim/scim_v2.py index 2d657d96c1b..4c472ed7f21 100644 --- a/litellm/proxy/management_endpoints/scim/scim_v2.py +++ b/litellm/proxy/management_endpoints/scim/scim_v2.py @@ -765,13 +765,13 @@ async def get_users( where_conditions["user_email"] = email # Get users from database - users: List[ - LiteLLM_UserTable - ] = await prisma_client.db.litellm_usertable.find_many( - where=where_conditions, - skip=(startIndex - 1), - take=count, - order={"created_at": "desc"}, + users: List[LiteLLM_UserTable] = ( + await prisma_client.db.litellm_usertable.find_many( + where=where_conditions, + skip=(startIndex - 1), + take=count, + order={"created_at": "desc"}, + ) ) # Get total count for pagination diff --git a/litellm/proxy/management_endpoints/sso/custom_microsoft_sso.py b/litellm/proxy/management_endpoints/sso/custom_microsoft_sso.py index 191212d6f0b..04e44c623d1 100644 --- a/litellm/proxy/management_endpoints/sso/custom_microsoft_sso.py +++ b/litellm/proxy/management_endpoints/sso/custom_microsoft_sso.py @@ -7,7 +7,7 @@ Environment Variables: - MICROSOFT_AUTHORIZATION_ENDPOINT: Custom authorization endpoint URL -- MICROSOFT_TOKEN_ENDPOINT: Custom token endpoint URL +- MICROSOFT_TOKEN_ENDPOINT: Custom token endpoint URL - MICROSOFT_USERINFO_ENDPOINT: Custom userinfo endpoint URL If these are not set, the default Microsoft endpoints are used. diff --git a/litellm/proxy/management_endpoints/team_endpoints.py b/litellm/proxy/management_endpoints/team_endpoints.py index 3643373be65..93a31c6cf58 100644 --- a/litellm/proxy/management_endpoints/team_endpoints.py +++ b/litellm/proxy/management_endpoints/team_endpoints.py @@ -1545,12 +1545,12 @@ async def update_team( # noqa: PLR0915 updated_kv["router_settings"] = safe_dumps(updated_kv["router_settings"]) updated_kv = prisma_client.jsonify_team_object(db_data=updated_kv) - team_row: Optional[ - LiteLLM_TeamTable - ] = await prisma_client.db.litellm_teamtable.update( - where={"team_id": data.team_id}, - data=updated_kv, - include={"litellm_model_table": True}, # type: ignore + team_row: Optional[LiteLLM_TeamTable] = ( + await prisma_client.db.litellm_teamtable.update( + where={"team_id": data.team_id}, + data=updated_kv, + include={"litellm_model_table": True}, # type: ignore + ) ) if team_row is None or team_row.team_id is None: @@ -2297,13 +2297,13 @@ async def team_member_delete( ) # Fetch keys before deletion to persist them - keys_to_delete: List[ - LiteLLM_VerificationToken - ] = await prisma_client.db.litellm_verificationtoken.find_many( - where={ - "user_id": {"in": list(user_ids_to_delete)}, - "team_id": data.team_id, - } + keys_to_delete: List[LiteLLM_VerificationToken] = ( + await prisma_client.db.litellm_verificationtoken.find_many( + where={ + "user_id": {"in": list(user_ids_to_delete)}, + "team_id": data.team_id, + } + ) ) if keys_to_delete: @@ -2687,10 +2687,10 @@ async def delete_team( team_rows: List[LiteLLM_TeamTable] = [] for team_id in data.team_ids: try: - team_row_base: Optional[ - BaseModel - ] = await prisma_client.db.litellm_teamtable.find_unique( - where={"team_id": team_id} + team_row_base: Optional[BaseModel] = ( + await prisma_client.db.litellm_teamtable.find_unique( + where={"team_id": team_id} + ) ) if team_row_base is None: raise Exception @@ -2749,10 +2749,10 @@ async def delete_team( _persist_deleted_verification_tokens, ) - keys_to_delete: List[ - LiteLLM_VerificationToken - ] = await prisma_client.db.litellm_verificationtoken.find_many( - where={"team_id": {"in": data.team_ids}} + keys_to_delete: List[LiteLLM_VerificationToken] = ( + await prisma_client.db.litellm_verificationtoken.find_many( + where={"team_id": {"in": data.team_ids}} + ) ) if keys_to_delete: @@ -2972,11 +2972,11 @@ async def team_info( ) try: - team_info: Optional[ - BaseModel - ] = await prisma_client.db.litellm_teamtable.find_unique( - where={"team_id": team_id}, - include={"object_permission": True}, + team_info: Optional[BaseModel] = ( + await prisma_client.db.litellm_teamtable.find_unique( + where={"team_id": team_id}, + include={"object_permission": True}, + ) ) if team_info is None: raise Exception @@ -3732,9 +3732,7 @@ async def list_team( except Exception as e: team_exception = """Invalid team object for team_id: {}. team_object={}. Error: {} - """.format( - team.team_id, team.model_dump(), str(e) - ) + """.format(team.team_id, team.model_dump(), str(e)) verbose_proxy_logger.exception(team_exception) continue # Sort the responses by team_alias diff --git a/litellm/proxy/management_endpoints/ui_sso.py b/litellm/proxy/management_endpoints/ui_sso.py index d06ce56f816..29ec57c8fb7 100644 --- a/litellm/proxy/management_endpoints/ui_sso.py +++ b/litellm/proxy/management_endpoints/ui_sso.py @@ -701,9 +701,9 @@ async def _setup_role_mappings() -> Optional["RoleMappings"]: import ast try: - generic_user_role_mappings_data: Dict[ - LitellmUserRoles, List[str] - ] = ast.literal_eval(generic_role_mappings) + generic_user_role_mappings_data: Dict[LitellmUserRoles, List[str]] = ( + ast.literal_eval(generic_role_mappings) + ) if isinstance(generic_user_role_mappings_data, dict): from litellm.types.proxy.management_endpoints.ui_sso import RoleMappings @@ -858,9 +858,9 @@ def response_convertor(response, client): verbose_proxy_logger.debug("calling generic_sso.verify_and_process") additional_generic_sso_headers_dict = _parse_generic_sso_headers() - code_verifier: Optional[ - str - ] = None # assigned inside try; initialized for type tracking + code_verifier: Optional[str] = ( + None # assigned inside try; initialized for type tracking + ) try: token_exchange_params = ( @@ -1187,9 +1187,9 @@ def apply_user_info_values_to_sso_user_defined_values( else: # SSO didn't provide a valid role, fall back to DB role or default if user_info is None or user_info.user_role is None: - user_defined_values[ - "user_role" - ] = LitellmUserRoles.INTERNAL_USER_VIEW_ONLY.value + user_defined_values["user_role"] = ( + LitellmUserRoles.INTERNAL_USER_VIEW_ONLY.value + ) verbose_proxy_logger.debug( "No SSO or DB role found, using default: INTERNAL_USER_VIEW_ONLY" ) @@ -1622,9 +1622,9 @@ async def insert_sso_user( if user_defined_values.get("max_budget") is None: user_defined_values["max_budget"] = litellm.max_internal_user_budget if user_defined_values.get("budget_duration") is None: - user_defined_values[ - "budget_duration" - ] = litellm.internal_user_budget_duration + user_defined_values["budget_duration"] = ( + litellm.internal_user_budget_duration + ) if user_defined_values["user_role"] is None: user_defined_values["user_role"] = LitellmUserRoles.INTERNAL_USER_VIEW_ONLY @@ -3252,9 +3252,9 @@ async def get_microsoft_callback_response( # if user is trying to get the raw sso response for debugging, return the raw sso response if return_raw_sso_response: - original_msft_result[ - MicrosoftSSOHandler.GRAPH_API_RESPONSE_KEY - ] = user_team_ids + original_msft_result[MicrosoftSSOHandler.GRAPH_API_RESPONSE_KEY] = ( + user_team_ids + ) original_msft_result["app_roles"] = app_roles return original_msft_result or {} @@ -3373,9 +3373,9 @@ async def get_user_groups_from_graph_api( # Fetch user membership from Microsoft Graph API all_group_ids = [] - next_link: Optional[ - str - ] = MicrosoftSSOHandler.graph_api_user_groups_endpoint + next_link: Optional[str] = ( + MicrosoftSSOHandler.graph_api_user_groups_endpoint + ) auth_headers = {"Authorization": f"Bearer {access_token}"} page_count = 0 diff --git a/litellm/proxy/management_endpoints/user_agent_analytics_endpoints.py b/litellm/proxy/management_endpoints/user_agent_analytics_endpoints.py index 9d3ecdba92f..ebd276fbee5 100644 --- a/litellm/proxy/management_endpoints/user_agent_analytics_endpoints.py +++ b/litellm/proxy/management_endpoints/user_agent_analytics_endpoints.py @@ -3,7 +3,7 @@ This module provides optimized endpoints for tracking user agent activity metrics including: - Daily Active Users (DAU) by tags for configurable number of days -- Weekly Active Users (WAU) by tags for configurable number of weeks +- Weekly Active Users (WAU) by tags for configurable number of weeks - Monthly Active Users (MAU) by tags for configurable number of months - Summary analytics by tags @@ -35,9 +35,9 @@ class TagActiveUsersResponse(BaseModel): tag: str active_users: int date: str # The specific date or period identifier - period_start: Optional[ - str - ] = None # For WAU/MAU, this will be the start of the period + period_start: Optional[str] = ( + None # For WAU/MAU, this will be the start of the period + ) period_end: Optional[str] = None # For WAU/MAU, this will be the end of the period diff --git a/litellm/proxy/management_helpers/object_permission_utils.py b/litellm/proxy/management_helpers/object_permission_utils.py index 8aba8307b9d..164d65c7e64 100644 --- a/litellm/proxy/management_helpers/object_permission_utils.py +++ b/litellm/proxy/management_helpers/object_permission_utils.py @@ -208,10 +208,10 @@ async def _resolve_team_allowed_mcp_servers( ) direct_servers: List[str] = team_object_permission.mcp_servers or [] - access_group_servers: List[ - str - ] = await MCPRequestHandler._get_mcp_servers_from_access_groups( - team_object_permission.mcp_access_groups or [] + access_group_servers: List[str] = ( + await MCPRequestHandler._get_mcp_servers_from_access_groups( + team_object_permission.mcp_access_groups or [] + ) ) raw_tool_perms = team_object_permission.mcp_tool_permissions or {} if isinstance(raw_tool_perms, str): diff --git a/litellm/proxy/middleware/prometheus_auth_middleware.py b/litellm/proxy/middleware/prometheus_auth_middleware.py index 5915e4aa07d..6bdff59da52 100644 --- a/litellm/proxy/middleware/prometheus_auth_middleware.py +++ b/litellm/proxy/middleware/prometheus_auth_middleware.py @@ -1,6 +1,7 @@ """ Prometheus Auth Middleware - Pure ASGI implementation """ + import json from fastapi import Request diff --git a/litellm/proxy/openai_files_endpoints/files_endpoints.py b/litellm/proxy/openai_files_endpoints/files_endpoints.py index 973836b13d8..8e9fcb75197 100644 --- a/litellm/proxy/openai_files_endpoints/files_endpoints.py +++ b/litellm/proxy/openai_files_endpoints/files_endpoints.py @@ -363,10 +363,10 @@ async def create_file( # noqa: PLR0915 expires_after: Optional[FileExpiresAfter] = None form_data_raw = await request.form() form_data_dict: Dict[str, Any] = dict(form_data_raw) - extracted_litellm_metadata: Optional[ - Dict[str, Any] - ] = extract_nested_form_metadata( - form_data=form_data_dict, prefix="litellm_metadata[" + extracted_litellm_metadata: Optional[Dict[str, Any]] = ( + extract_nested_form_metadata( + form_data=form_data_dict, prefix="litellm_metadata[" + ) ) expires_after_anchor = form_data_raw.get("expires_after[anchor]") expires_after_seconds_str = form_data_raw.get("expires_after[seconds]") diff --git a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py index fcb1e0b2e49..216eb61a9d1 100644 --- a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py +++ b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py @@ -168,9 +168,9 @@ def _create_anthropic_response_logging_payload( litellm_model_response.model = model logging_obj.model_call_details["model"] = model if not logging_obj.model_call_details.get("custom_llm_provider"): - logging_obj.model_call_details[ - "custom_llm_provider" - ] = litellm.LlmProviders.ANTHROPIC.value + logging_obj.model_call_details["custom_llm_provider"] = ( + litellm.LlmProviders.ANTHROPIC.value + ) return kwargs except Exception as e: verbose_proxy_logger.exception( diff --git a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/cursor_passthrough_logging_handler.py b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/cursor_passthrough_logging_handler.py index a104f962630..e7696e5a18a 100644 --- a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/cursor_passthrough_logging_handler.py +++ b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/cursor_passthrough_logging_handler.py @@ -18,7 +18,6 @@ from litellm.proxy._types import PassThroughEndpointLoggingTypedDict from litellm.types.utils import StandardPassThroughResponseObject - CURSOR_AGENT_ENDPOINTS: Dict[str, str] = { "POST /v0/agents": "cursor:agent:create", "GET /v0/agents": "cursor:agent:list", diff --git a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/openai_passthrough_logging_handler.py b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/openai_passthrough_logging_handler.py index 38b2734bc26..29bbb37501f 100644 --- a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/openai_passthrough_logging_handler.py +++ b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/openai_passthrough_logging_handler.py @@ -367,9 +367,9 @@ def openai_passthrough_handler( # noqa: PLR0915 kwargs["custom_llm_provider"] = custom_llm_provider # Extract user information for tracking - passthrough_logging_payload: Optional[ - PassthroughStandardLoggingPayload - ] = kwargs.get("passthrough_logging_payload") + passthrough_logging_payload: Optional[PassthroughStandardLoggingPayload] = ( + kwargs.get("passthrough_logging_payload") + ) if passthrough_logging_payload: user = handler_instance._get_user_from_metadata( passthrough_logging_payload=passthrough_logging_payload, @@ -398,9 +398,7 @@ def openai_passthrough_handler( # noqa: PLR0915 endpoint_type = ( "chat_completions" if is_chat_completions - else "image_generation" - if is_image_generation - else "image_editing" + else "image_generation" if is_image_generation else "image_editing" ) verbose_proxy_logger.debug( f"OpenAI passthrough cost tracking - Endpoint: {endpoint_type}, Model: {model}, Cost: ${response_cost:.6f}" @@ -558,10 +556,10 @@ def _handle_logging_openai_collected_chunks( } # Extract user information for tracking - passthrough_logging_payload: Optional[ - PassthroughStandardLoggingPayload - ] = litellm_logging_obj.model_call_details.get( - "passthrough_logging_payload" + passthrough_logging_payload: Optional[PassthroughStandardLoggingPayload] = ( + litellm_logging_obj.model_call_details.get( + "passthrough_logging_payload" + ) ) if passthrough_logging_payload: user = handler_instance._get_user_from_metadata( @@ -584,9 +582,9 @@ def _handle_logging_openai_collected_chunks( # Update logging object with cost information litellm_logging_obj.model_call_details["model"] = model - litellm_logging_obj.model_call_details[ - "custom_llm_provider" - ] = custom_llm_provider + litellm_logging_obj.model_call_details["custom_llm_provider"] = ( + custom_llm_provider + ) litellm_logging_obj.model_call_details["response_cost"] = response_cost verbose_proxy_logger.debug( diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index e2f7646c0aa..bf7648167cc 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -457,10 +457,10 @@ async def make_multipart_http_request( for field_name, field_value in form_data.items(): if isinstance(field_value, (StarletteUploadFile, UploadFile)): - files[ - field_name - ] = await HttpPassThroughEndpointHelpers._build_request_files_from_upload_file( - upload_file=field_value + files[field_name] = ( + await HttpPassThroughEndpointHelpers._build_request_files_from_upload_file( + upload_file=field_value + ) ) else: form_data_dict[field_name] = field_value @@ -554,9 +554,9 @@ def _init_kwargs_for_pass_through_endpoint( "passthrough_logging_payload": passthrough_logging_payload, } - logging_obj.model_call_details[ - "passthrough_logging_payload" - ] = passthrough_logging_payload + logging_obj.model_call_details["passthrough_logging_payload"] = ( + passthrough_logging_payload + ) return kwargs @@ -1485,9 +1485,9 @@ async def forward_client_to_upstream() -> None: ) if extracted_model: kwargs["model"] = extracted_model - kwargs[ - "custom_llm_provider" - ] = "vertex_ai-language-models" + kwargs["custom_llm_provider"] = ( + "vertex_ai-language-models" + ) # Update logging object with correct model logging_obj.model = extracted_model logging_obj.model_call_details[ @@ -1553,9 +1553,9 @@ async def forward_upstream_to_client() -> None: # Update logging object with correct model logging_obj.model = extracted_model logging_obj.model_call_details["model"] = extracted_model - logging_obj.model_call_details[ - "custom_llm_provider" - ] = "vertex_ai_language_models" + logging_obj.model_call_details["custom_llm_provider"] = ( + "vertex_ai_language_models" + ) verbose_proxy_logger.debug( f"WebSocket passthrough ({endpoint}): Successfully extracted model '{extracted_model}' and set provider to 'vertex_ai' from server setup response" ) diff --git a/litellm/proxy/pass_through_endpoints/passthrough_endpoint_router.py b/litellm/proxy/pass_through_endpoints/passthrough_endpoint_router.py index ae2f8edc74f..a32659e45bd 100644 --- a/litellm/proxy/pass_through_endpoints/passthrough_endpoint_router.py +++ b/litellm/proxy/pass_through_endpoints/passthrough_endpoint_router.py @@ -134,9 +134,9 @@ def add_vertex_credentials( vertex_location=location, vertex_credentials=vertex_credentials, ) - self.deployment_key_to_vertex_credentials[ - deployment_key - ] = vertex_pass_through_credentials + self.deployment_key_to_vertex_credentials[deployment_key] = ( + vertex_pass_through_credentials + ) def _get_deployment_key( self, project_id: Optional[str], location: Optional[str] @@ -156,10 +156,10 @@ def get_vector_store_credentials( """ if litellm.vector_store_registry is None: return None - vector_store_to_run: Optional[ - LiteLLM_ManagedVectorStore - ] = litellm.vector_store_registry.get_litellm_managed_vector_store_from_registry( - vector_store_id=vector_store_id + vector_store_to_run: Optional[LiteLLM_ManagedVectorStore] = ( + litellm.vector_store_registry.get_litellm_managed_vector_store_from_registry( + vector_store_id=vector_store_id + ) ) return vector_store_to_run diff --git a/litellm/proxy/pass_through_endpoints/success_handler.py b/litellm/proxy/pass_through_endpoints/success_handler.py index 33819b888d0..c14378ce616 100644 --- a/litellm/proxy/pass_through_endpoints/success_handler.py +++ b/litellm/proxy/pass_through_endpoints/success_handler.py @@ -283,9 +283,9 @@ def normalize_llm_passthrough_logging_payload( standard_logging_response_object = vertex_ai_live_handler_result["result"] kwargs = vertex_ai_live_handler_result["kwargs"] - return_dict[ - "standard_logging_response_object" - ] = standard_logging_response_object + return_dict["standard_logging_response_object"] = ( + standard_logging_response_object + ) return_dict["kwargs"] = kwargs return return_dict @@ -308,9 +308,9 @@ async def pass_through_async_success_handler( standard_logging_response_object: Optional[ PassThroughEndpointLoggingResultValues ] = None - logging_obj.model_call_details[ - "passthrough_logging_payload" - ] = passthrough_logging_payload + logging_obj.model_call_details["passthrough_logging_payload"] = ( + passthrough_logging_payload + ) if self.is_assemblyai_route(url_route): if ( AssemblyAIPassthroughLoggingHandler._should_log_request( @@ -487,8 +487,8 @@ def _set_cost_per_request( kwargs["response_cost"] = passthrough_logging_payload.get( "cost_per_request" ) - logging_obj.model_call_details[ - "response_cost" - ] = passthrough_logging_payload.get("cost_per_request") + logging_obj.model_call_details["response_cost"] = ( + passthrough_logging_payload.get("cost_per_request") + ) return kwargs diff --git a/litellm/proxy/policy_engine/attachment_registry.py b/litellm/proxy/policy_engine/attachment_registry.py index 530e1fca1f5..6d1096d5ee9 100644 --- a/litellm/proxy/policy_engine/attachment_registry.py +++ b/litellm/proxy/policy_engine/attachment_registry.py @@ -464,13 +464,15 @@ async def sync_attachments_from_db( attachment = PolicyAttachment( policy=attachment_response.policy_name, scope=attachment_response.scope, - teams=attachment_response.teams - if attachment_response.teams - else None, + teams=( + attachment_response.teams if attachment_response.teams else None + ), keys=attachment_response.keys if attachment_response.keys else None, - models=attachment_response.models - if attachment_response.models - else None, + models=( + attachment_response.models + if attachment_response.models + else None + ), tags=attachment_response.tags if attachment_response.tags else None, ) self._attachments.append(attachment) diff --git a/litellm/proxy/policy_engine/init_policies.py b/litellm/proxy/policy_engine/init_policies.py index 3167a0fe8b3..d5529f5bfe1 100644 --- a/litellm/proxy/policy_engine/init_policies.py +++ b/litellm/proxy/policy_engine/init_policies.py @@ -264,9 +264,9 @@ def get_policies_summary() -> Dict[str, Any]: "description": policy.description if policy else None, "guardrails_add": policy.guardrails.get_add() if policy else [], "guardrails_remove": policy.guardrails.get_remove() if policy else [], - "condition": policy.condition.model_dump() - if policy and policy.condition - else None, + "condition": ( + policy.condition.model_dump() if policy and policy.condition else None + ), "resolved_guardrails": resolved_policy.guardrails, "inheritance_chain": resolved_policy.inheritance_chain, } diff --git a/litellm/proxy/prompts/prompt_registry.py b/litellm/proxy/prompts/prompt_registry.py index 58df60a42cb..25368c2a834 100644 --- a/litellm/proxy/prompts/prompt_registry.py +++ b/litellm/proxy/prompts/prompt_registry.py @@ -97,9 +97,9 @@ def __init__(self): Prompt id to Prompt object mapping """ - self.prompt_id_to_custom_prompt: Dict[ - str, Optional[CustomPromptManagement] - ] = {} + self.prompt_id_to_custom_prompt: Dict[str, Optional[CustomPromptManagement]] = ( + {} + ) """ Guardrail id to CustomGuardrail object mapping """ diff --git a/litellm/proxy/proxy_cli.py b/litellm/proxy/proxy_cli.py index c638e294268..c3a723dd742 100644 --- a/litellm/proxy/proxy_cli.py +++ b/litellm/proxy/proxy_cli.py @@ -215,9 +215,7 @@ def __init__(self, app, options=None): _endpoint_str = ( f"curl --location 'http://0.0.0.0:{port}/chat/completions' \\" ) - curl_command = ( - _endpoint_str - + """ + curl_command = _endpoint_str + """ --header 'Content-Type: application/json' \\ --data ' { "model": "gpt-3.5-turbo", @@ -230,7 +228,6 @@ def __init__(self, app, options=None): }' \n """ - ) print() # noqa print( # noqa '\033[1;34mLiteLLM: Test your local proxy with: "litellm --test" This runs an openai.ChatCompletion request to your proxy [In a new terminal tab]\033[0m\n' @@ -306,11 +303,9 @@ def _run_ollama_serve(): with open(os.devnull, "w") as devnull: subprocess.Popen(command, stdout=devnull, stderr=devnull) except Exception as e: - print( # noqa - f""" + print(f""" LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve` - """ - ) # noqa + """) # noqa # noqa @staticmethod def _is_port_in_use(port): diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 7d3d2ceb533..94df7d1a39e 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -639,9 +639,9 @@ def generate_feedback_box(): server_root_path = get_server_root_path() _license_check = LicenseCheck() premium_user: bool = _license_check.is_premium() -premium_user_data: Optional[ - "EnterpriseLicenseData" -] = _license_check.airgapped_license_data +premium_user_data: Optional["EnterpriseLicenseData"] = ( + _license_check.airgapped_license_data +) global_max_parallel_request_retries_env: Optional[str] = os.getenv( "LITELLM_GLOBAL_MAX_PARALLEL_REQUEST_RETRIES" ) @@ -1524,9 +1524,9 @@ async def root_redirect(): config_agents: Optional[List[AgentConfig]] = None otel_logging = False prisma_client: Optional[PrismaClient] = None -shared_aiohttp_session: Optional[ - "ClientSession" -] = None # Global shared session for connection reuse +shared_aiohttp_session: Optional["ClientSession"] = ( + None # Global shared session for connection reuse +) user_api_key_cache = DualCache( default_in_memory_ttl=UserAPIKeyCacheTTLEnum.in_memory_cache_ttl.value ) @@ -1534,13 +1534,13 @@ async def root_redirect(): dual_cache=user_api_key_cache ) litellm.logging_callback_manager.add_litellm_callback(model_max_budget_limiter) -redis_usage_cache: Optional[ - RedisCache -] = None # redis cache used for tracking spend, tpm/rpm limits +redis_usage_cache: Optional[RedisCache] = ( + None # redis cache used for tracking spend, tpm/rpm limits +) polling_via_cache_enabled: Union[Literal["all"], List[str], bool] = False -native_background_mode: List[ - str -] = [] # Models that should use native provider background mode instead of polling +native_background_mode: List[str] = ( + [] +) # Models that should use native provider background mode instead of polling polling_cache_ttl: int = 3600 # Default 1 hour TTL for polling cache user_custom_auth = None user_custom_key_generate = None @@ -1900,9 +1900,9 @@ async def _update_team_cache(): _id = "team_id:{}".format(team_id) try: # Fetch the existing cost for the given user - existing_spend_obj: Optional[ - LiteLLM_TeamTable - ] = await user_api_key_cache.async_get_cache(key=_id) + existing_spend_obj: Optional[LiteLLM_TeamTable] = ( + await user_api_key_cache.async_get_cache(key=_id) + ) if existing_spend_obj is None: # do nothing if team not in api key cache return @@ -2023,11 +2023,9 @@ def run_ollama_serve(): with open(os.devnull, "w") as devnull: subprocess.Popen(command, stdout=devnull, stderr=devnull) except Exception as e: - verbose_proxy_logger.debug( - f""" + verbose_proxy_logger.debug(f""" LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve` - """ - ) + """) def _get_process_rss_mb() -> Optional[float]: @@ -5001,10 +4999,10 @@ async def _init_guardrails_in_db(self, prisma_client: PrismaClient): ) try: - guardrails_in_db: List[ - Guardrail - ] = await GuardrailRegistry.get_all_guardrails_from_db( - prisma_client=prisma_client + guardrails_in_db: List[Guardrail] = ( + await GuardrailRegistry.get_all_guardrails_from_db( + prisma_client=prisma_client + ) ) verbose_proxy_logger.debug( "guardrails from the DB %s", str(guardrails_in_db) @@ -5386,9 +5384,9 @@ async def initialize( # noqa: PLR0915 user_api_base = api_base dynamic_config[user_model]["api_base"] = api_base if api_version: - os.environ[ - "AZURE_API_VERSION" - ] = api_version # set this for azure - litellm can read this from the env + os.environ["AZURE_API_VERSION"] = ( + api_version # set this for azure - litellm can read this from the env + ) if max_tokens: # model-specific param dynamic_config[user_model]["max_tokens"] = max_tokens if temperature: # model-specific param @@ -5725,9 +5723,9 @@ def _validate_redis_transaction_buffer_config( """ from litellm.secret_managers.main import str_to_bool - _use_redis_transaction_buffer: Optional[ - Union[bool, str] - ] = general_settings.get("use_redis_transaction_buffer", False) + _use_redis_transaction_buffer: Optional[Union[bool, str]] = ( + general_settings.get("use_redis_transaction_buffer", False) + ) if isinstance(_use_redis_transaction_buffer, str): _use_redis_transaction_buffer = str_to_bool(_use_redis_transaction_buffer) @@ -12322,9 +12320,9 @@ async def get_config_list( hasattr(sub_field_info, "description") and sub_field_info.description is not None ): - nested_fields[ - idx - ].field_description = sub_field_info.description + nested_fields[idx].field_description = ( + sub_field_info.description + ) idx += 1 _stored_in_db = None diff --git a/litellm/proxy/rag_endpoints/endpoints.py b/litellm/proxy/rag_endpoints/endpoints.py index 76136c12be5..95ca51612fc 100644 --- a/litellm/proxy/rag_endpoints/endpoints.py +++ b/litellm/proxy/rag_endpoints/endpoints.py @@ -180,9 +180,9 @@ async def _save_vector_store_to_db_from_rag_ingest( vector_store_name=vector_store_name, vector_store_description=vector_store_description, vector_store_metadata=initial_metadata, - litellm_params=provider_specific_params - if provider_specific_params - else None, + litellm_params=( + provider_specific_params if provider_specific_params else None + ), team_id=user_api_key_dict.team_id, user_id=user_api_key_dict.user_id, ) diff --git a/litellm/proxy/response_polling/__init__.py b/litellm/proxy/response_polling/__init__.py index b500354c373..7ece7099e27 100644 --- a/litellm/proxy/response_polling/__init__.py +++ b/litellm/proxy/response_polling/__init__.py @@ -1,6 +1,7 @@ """ Response Polling Module for Background Responses with Cache """ + from litellm.proxy.response_polling.background_streaming import ( background_streaming_task, ) diff --git a/litellm/proxy/response_polling/background_streaming.py b/litellm/proxy/response_polling/background_streaming.py index bcc98175773..03039d4f441 100644 --- a/litellm/proxy/response_polling/background_streaming.py +++ b/litellm/proxy/response_polling/background_streaming.py @@ -7,6 +7,7 @@ Follows OpenAI Response Streaming format: https://platform.openai.com/docs/api-reference/responses-streaming """ + import asyncio import json from typing import Any, Optional, cast @@ -118,9 +119,9 @@ async def background_streaming_task( # noqa: PLR0915 UPDATE_INTERVAL = 0.150 # 150ms batching interval # Track the terminal event from the stream (may not be "completed") - terminal_status: Optional[ - ResponsesAPIStatus - ] = None # Will be set by response.completed/failed/incomplete/cancelled + terminal_status: Optional[ResponsesAPIStatus] = ( + None # Will be set by response.completed/failed/incomplete/cancelled + ) terminal_error = None _event_to_status = { "response.completed": "completed", @@ -211,9 +212,9 @@ async def flush_state_if_needed(force: bool = False) -> None: if isinstance( content_list[content_index], dict ): - content_list[content_index][ - "text" - ] = accumulated_text[key] + content_list[content_index]["text"] = ( + accumulated_text[key] + ) state_dirty = True elif event_type == "response.content_part.done": diff --git a/litellm/proxy/response_polling/polling_handler.py b/litellm/proxy/response_polling/polling_handler.py index 71e97a46c62..739df3ce673 100644 --- a/litellm/proxy/response_polling/polling_handler.py +++ b/litellm/proxy/response_polling/polling_handler.py @@ -1,6 +1,7 @@ """ Response Polling Handler for Background Responses with Cache """ + import json from datetime import datetime, timezone from typing import Any, Dict, List, Optional diff --git a/litellm/proxy/search_endpoints/search_tool_management.py b/litellm/proxy/search_endpoints/search_tool_management.py index c46bbfddcac..725e83bf96d 100644 --- a/litellm/proxy/search_endpoints/search_tool_management.py +++ b/litellm/proxy/search_endpoints/search_tool_management.py @@ -1,6 +1,7 @@ """ CRUD ENDPOINTS FOR SEARCH TOOLS """ + from datetime import datetime from typing import Any, Dict, List, Union @@ -536,9 +537,9 @@ async def test_search_tool_connection(request: TestSearchToolConnectionRequest): "status": "success", "message": f"Successfully connected to {search_provider} search provider", "test_query": test_query, - "results_count": len(response.results) - if response and response.results - else 0, + "results_count": ( + len(response.results) if response and response.results else 0 + ), } except Exception as e: diff --git a/litellm/proxy/search_endpoints/search_tool_registry.py b/litellm/proxy/search_endpoints/search_tool_registry.py index e9eba1e1799..d4adc2573ea 100644 --- a/litellm/proxy/search_endpoints/search_tool_registry.py +++ b/litellm/proxy/search_endpoints/search_tool_registry.py @@ -1,6 +1,7 @@ """ Search Tool Registry for managing search tool configurations. """ + from datetime import datetime, timezone from typing import List, Optional diff --git a/litellm/proxy/spend_tracking/cold_storage_handler.py b/litellm/proxy/spend_tracking/cold_storage_handler.py index adbbc141234..57c41bafccd 100644 --- a/litellm/proxy/spend_tracking/cold_storage_handler.py +++ b/litellm/proxy/spend_tracking/cold_storage_handler.py @@ -3,6 +3,7 @@ It allows fetching a dict of the proxy server request from s3 or GCS bucket. """ + from typing import Optional import litellm @@ -32,19 +33,19 @@ async def get_proxy_server_request_from_cold_storage_with_object_key( """ # select the custom logger to use for cold storage - custom_logger_name: Optional[ - _custom_logger_compatible_callbacks_literal - ] = self._select_custom_logger_for_cold_storage() + custom_logger_name: Optional[_custom_logger_compatible_callbacks_literal] = ( + self._select_custom_logger_for_cold_storage() + ) # if no custom logger name is configured, return None if custom_logger_name is None: return None # get the active/initialized custom logger - custom_logger: Optional[ - CustomLogger - ] = litellm.logging_callback_manager.get_active_custom_logger_for_callback_name( - custom_logger_name + custom_logger: Optional[CustomLogger] = ( + litellm.logging_callback_manager.get_active_custom_logger_for_callback_name( + custom_logger_name + ) ) # if no custom logger is found, return None diff --git a/litellm/proxy/spend_tracking/spend_management_endpoints.py b/litellm/proxy/spend_tracking/spend_management_endpoints.py index b3b4b55af19..aba36098084 100644 --- a/litellm/proxy/spend_tracking/spend_management_endpoints.py +++ b/litellm/proxy/spend_tracking/spend_management_endpoints.py @@ -3076,16 +3076,14 @@ async def provider_budgets() -> ProviderBudgetResponse: async def get_spend_by_tags( prisma_client: PrismaClient, start_date=None, end_date=None ): - response = await prisma_client.db.query_raw( - """ + response = await prisma_client.db.query_raw(""" SELECT jsonb_array_elements_text(request_tags) AS individual_request_tag, COUNT(*) AS log_count, SUM(spend) AS total_spend FROM "LiteLLM_SpendLogs" GROUP BY individual_request_tag; - """ - ) + """) return response diff --git a/litellm/proxy/spend_tracking/spend_tracking_utils.py b/litellm/proxy/spend_tracking/spend_tracking_utils.py index 3eacc19a6df..0e349764fa9 100644 --- a/litellm/proxy/spend_tracking/spend_tracking_utils.py +++ b/litellm/proxy/spend_tracking/spend_tracking_utils.py @@ -126,9 +126,9 @@ def _get_spend_logs_metadata( clean_metadata["applied_guardrails"] = applied_guardrails clean_metadata["batch_models"] = batch_models clean_metadata["mcp_tool_call_metadata"] = mcp_tool_call_metadata - clean_metadata[ - "vector_store_request_metadata" - ] = _get_vector_store_request_for_spend_logs_payload(vector_store_request_metadata) + clean_metadata["vector_store_request_metadata"] = ( + _get_vector_store_request_for_spend_logs_payload(vector_store_request_metadata) + ) clean_metadata["guardrail_information"] = guardrail_information clean_metadata["usage_object"] = usage_object clean_metadata["model_map_information"] = model_map_information diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 7954f0b6460..b9c1e2f2bc3 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -1898,9 +1898,9 @@ async def _handle_logging_proxy_only_error( normalized_call_type = CallTypes.aembedding.value if normalized_call_type is not None: litellm_logging_obj.call_type = normalized_call_type - litellm_logging_obj.model_call_details[ - "call_type" - ] = normalized_call_type + litellm_logging_obj.model_call_details["call_type"] = ( + normalized_call_type + ) # Pass-through endpoints are logged via the callback loop's # async_post_call_failure_hook — skip pre_call and failure handlers. if litellm_logging_obj.call_type == CallTypes.pass_through.value: @@ -2498,8 +2498,7 @@ async def check_view_exists(self): required_view = "LiteLLM_VerificationTokenView" expected_views_str = ", ".join(f"'{view}'" for view in expected_views) pg_schema = os.getenv("DATABASE_SCHEMA", "public") - ret = await self.db.query_raw( - f""" + ret = await self.db.query_raw(f""" WITH existing_views AS ( SELECT viewname FROM pg_views @@ -2511,8 +2510,7 @@ async def check_view_exists(self): (SELECT COUNT(*) FROM existing_views) AS view_count, ARRAY_AGG(viewname) AS view_names FROM existing_views - """ - ) + """) expected_total_views = len(expected_views) if ret[0]["view_count"] == expected_total_views: verbose_proxy_logger.info("All necessary views exist!") @@ -2521,8 +2519,7 @@ async def check_view_exists(self): ## check if required view exists ## if ret[0]["view_names"] and required_view not in ret[0]["view_names"]: await self.health_check() # make sure we can connect to db - await self.db.execute_raw( - """ + await self.db.execute_raw(""" CREATE VIEW "LiteLLM_VerificationTokenView" AS SELECT v.*, @@ -2532,8 +2529,7 @@ async def check_view_exists(self): t.rpm_limit AS team_rpm_limit FROM "LiteLLM_VerificationToken" v LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id; - """ - ) + """) verbose_proxy_logger.info( "LiteLLM_VerificationTokenView Created in DB!" @@ -2759,7 +2755,7 @@ async def get_data( # noqa: PLR0915 and reset_at is not None ): response = await self.db.litellm_verificationtoken.find_many( - where={ # type:ignore + where={ # type: ignore "OR": [ {"expires": None}, {"expires": {"gt": expires}}, @@ -2819,7 +2815,7 @@ async def get_data( # noqa: PLR0915 ) # type: ignore elif query_type == "find_all" and reset_at is not None: response = await self.db.litellm_usertable.find_many( - where={ # type:ignore + where={ # type: ignore "budget_reset_at": {"lt": reset_at}, } ) @@ -2831,10 +2827,10 @@ async def get_data( # noqa: PLR0915 if expires is not None: response = await self.db.litellm_usertable.find_many( # type: ignore order={"spend": "desc"}, - where={ # type:ignore + where={ # type: ignore "OR": [ - {"expires": None}, # type:ignore - {"expires": {"gt": expires}}, # type:ignore + {"expires": None}, # type: ignore + {"expires": {"gt": expires}}, # type: ignore ], }, ) @@ -2881,7 +2877,7 @@ async def get_data( # noqa: PLR0915 elif table_name == "budget" and reset_at is not None: if query_type == "find_all": response = await self.db.litellm_budgettable.find_many( - where={ # type:ignore + where={ # type: ignore "OR": [ { "AND": [ @@ -2909,7 +2905,7 @@ async def get_data( # noqa: PLR0915 ) elif query_type == "find_all" and reset_at is not None: response = await self.db.litellm_teamtable.find_many( - where={ # type:ignore + where={ # type: ignore "budget_reset_at": {"lt": reset_at}, } ) diff --git a/litellm/proxy/vector_store_endpoints/endpoints.py b/litellm/proxy/vector_store_endpoints/endpoints.py index d4594fb2fd0..b7e3d8de3d3 100644 --- a/litellm/proxy/vector_store_endpoints/endpoints.py +++ b/litellm/proxy/vector_store_endpoints/endpoints.py @@ -68,10 +68,10 @@ def _update_request_data_with_litellm_managed_vector_store_registry( HTTPException: If user doesn't have access to the vector store """ if litellm.vector_store_registry is not None: - vector_store_to_run: Optional[ - LiteLLM_ManagedVectorStore - ] = litellm.vector_store_registry.get_litellm_managed_vector_store_from_registry( - vector_store_id=vector_store_id + vector_store_to_run: Optional[LiteLLM_ManagedVectorStore] = ( + litellm.vector_store_registry.get_litellm_managed_vector_store_from_registry( + vector_store_id=vector_store_id + ) ) if vector_store_to_run is not None: # Check access control if user_api_key_dict is provided diff --git a/litellm/proxy/vertex_ai_endpoints/langfuse_endpoints.py b/litellm/proxy/vertex_ai_endpoints/langfuse_endpoints.py index 627618387d5..679b9c25ef8 100644 --- a/litellm/proxy/vertex_ai_endpoints/langfuse_endpoints.py +++ b/litellm/proxy/vertex_ai_endpoints/langfuse_endpoints.py @@ -1,5 +1,5 @@ """ -What is this? +What is this? Logging Pass-Through Endpoints """ @@ -70,10 +70,10 @@ async def langfuse_proxy_route( request=request, api_key="Bearer {}".format(api_key) ) - callback_settings_obj: Optional[ - TeamCallbackMetadata - ] = _get_dynamic_logging_metadata( - user_api_key_dict=user_api_key_dict, proxy_config=proxy_config + callback_settings_obj: Optional[TeamCallbackMetadata] = ( + _get_dynamic_logging_metadata( + user_api_key_dict=user_api_key_dict, proxy_config=proxy_config + ) ) dynamic_langfuse_public_key: Optional[str] = None diff --git a/litellm/responses/litellm_completion_transformation/handler.py b/litellm/responses/litellm_completion_transformation/handler.py index 5faa8b587c9..f730a089624 100644 --- a/litellm/responses/litellm_completion_transformation/handler.py +++ b/litellm/responses/litellm_completion_transformation/handler.py @@ -38,14 +38,16 @@ def response_api_handler( Any, Any, Union[ResponsesAPIResponse, BaseResponsesAPIStreamingIterator] ], ]: - litellm_completion_request: dict = LiteLLMCompletionResponsesConfig.transform_responses_api_request_to_chat_completion_request( - model=model, - input=input, - responses_api_request=responses_api_request, - custom_llm_provider=custom_llm_provider, - stream=stream, - extra_headers=extra_headers, - **kwargs, + litellm_completion_request: dict = ( + LiteLLMCompletionResponsesConfig.transform_responses_api_request_to_chat_completion_request( + model=model, + input=input, + responses_api_request=responses_api_request, + custom_llm_provider=custom_llm_provider, + stream=stream, + extra_headers=extra_headers, + **kwargs, + ) ) if _is_async: @@ -68,10 +70,12 @@ def response_api_handler( ) if isinstance(litellm_completion_response, ModelResponse): - responses_api_response: ResponsesAPIResponse = LiteLLMCompletionResponsesConfig.transform_chat_completion_response_to_responses_api_response( - chat_completion_response=litellm_completion_response, - request_input=input, - responses_api_request=responses_api_request, + responses_api_response: ResponsesAPIResponse = ( + LiteLLMCompletionResponsesConfig.transform_chat_completion_response_to_responses_api_response( + chat_completion_response=litellm_completion_response, + request_input=input, + responses_api_request=responses_api_request, + ) ) return responses_api_response @@ -116,10 +120,12 @@ async def async_response_api_handler( ) if isinstance(litellm_completion_response, ModelResponse): - responses_api_response: ResponsesAPIResponse = LiteLLMCompletionResponsesConfig.transform_chat_completion_response_to_responses_api_response( - chat_completion_response=litellm_completion_response, - request_input=request_input, - responses_api_request=responses_api_request, + responses_api_response: ResponsesAPIResponse = ( + LiteLLMCompletionResponsesConfig.transform_chat_completion_response_to_responses_api_response( + chat_completion_response=litellm_completion_response, + request_input=request_input, + responses_api_request=responses_api_request, + ) ) return responses_api_response diff --git a/litellm/responses/litellm_completion_transformation/session_handler.py b/litellm/responses/litellm_completion_transformation/session_handler.py index 45ab16b0d4a..71ff2eb7acf 100644 --- a/litellm/responses/litellm_completion_transformation/session_handler.py +++ b/litellm/responses/litellm_completion_transformation/session_handler.py @@ -43,10 +43,10 @@ async def get_chat_completion_message_history_for_previous_response_id( verbose_proxy_logger.debug( "inside get_chat_completion_message_history_for_previous_response_id" ) - all_spend_logs: List[ - SpendLogsPayload - ] = await ResponsesSessionHandler.get_all_spend_logs_for_previous_response_id( - previous_response_id + all_spend_logs: List[SpendLogsPayload] = ( + await ResponsesSessionHandler.get_all_spend_logs_for_previous_response_id( + previous_response_id + ) ) verbose_proxy_logger.debug( "found %s spend logs for this response id", len(all_spend_logs) diff --git a/litellm/responses/litellm_completion_transformation/transformation.py b/litellm/responses/litellm_completion_transformation/transformation.py index b6479a36998..cf18511bfa3 100644 --- a/litellm/responses/litellm_completion_transformation/transformation.py +++ b/litellm/responses/litellm_completion_transformation/transformation.py @@ -2113,9 +2113,9 @@ def _transform_chat_completion_usage_to_responses_usage( hasattr(completion_details, "reasoning_tokens") and completion_details.reasoning_tokens is not None ): - output_details_dict[ - "reasoning_tokens" - ] = completion_details.reasoning_tokens + output_details_dict["reasoning_tokens"] = ( + completion_details.reasoning_tokens + ) else: output_details_dict["reasoning_tokens"] = 0 diff --git a/litellm/responses/main.py b/litellm/responses/main.py index c82574278ba..a91b4056d6e 100644 --- a/litellm/responses/main.py +++ b/litellm/responses/main.py @@ -1115,11 +1115,11 @@ def delete_responses( raise ValueError("custom_llm_provider is required but passed as None") # get provider config - responses_api_provider_config: Optional[ - BaseResponsesAPIConfig - ] = ProviderConfigManager.get_provider_responses_api_config( - model=None, - provider=custom_llm_provider, + responses_api_provider_config: Optional[BaseResponsesAPIConfig] = ( + ProviderConfigManager.get_provider_responses_api_config( + model=None, + provider=custom_llm_provider, + ) ) if responses_api_provider_config is None: @@ -1296,11 +1296,11 @@ def get_responses( raise ValueError("custom_llm_provider is required but passed as None") # get provider config - responses_api_provider_config: Optional[ - BaseResponsesAPIConfig - ] = ProviderConfigManager.get_provider_responses_api_config( - model=None, - provider=custom_llm_provider, + responses_api_provider_config: Optional[BaseResponsesAPIConfig] = ( + ProviderConfigManager.get_provider_responses_api_config( + model=None, + provider=custom_llm_provider, + ) ) if responses_api_provider_config is None: @@ -1454,11 +1454,11 @@ def list_input_items( if custom_llm_provider is None: raise ValueError("custom_llm_provider is required but passed as None") - responses_api_provider_config: Optional[ - BaseResponsesAPIConfig - ] = ProviderConfigManager.get_provider_responses_api_config( - model=None, - provider=custom_llm_provider, + responses_api_provider_config: Optional[BaseResponsesAPIConfig] = ( + ProviderConfigManager.get_provider_responses_api_config( + model=None, + provider=custom_llm_provider, + ) ) if responses_api_provider_config is None: @@ -1613,11 +1613,11 @@ def cancel_responses( raise ValueError("custom_llm_provider is required but passed as None") # get provider config - responses_api_provider_config: Optional[ - BaseResponsesAPIConfig - ] = ProviderConfigManager.get_provider_responses_api_config( - model=None, - provider=custom_llm_provider, + responses_api_provider_config: Optional[BaseResponsesAPIConfig] = ( + ProviderConfigManager.get_provider_responses_api_config( + model=None, + provider=custom_llm_provider, + ) ) if responses_api_provider_config is None: @@ -1801,11 +1801,11 @@ def compact_responses( raise ValueError("custom_llm_provider is required but passed as None") # get provider config - responses_api_provider_config: Optional[ - BaseResponsesAPIConfig - ] = ProviderConfigManager.get_provider_responses_api_config( - model=model, - provider=custom_llm_provider, + responses_api_provider_config: Optional[BaseResponsesAPIConfig] = ( + ProviderConfigManager.get_provider_responses_api_config( + model=model, + provider=custom_llm_provider, + ) ) if responses_api_provider_config is None: diff --git a/litellm/responses/mcp/litellm_proxy_mcp_handler.py b/litellm/responses/mcp/litellm_proxy_mcp_handler.py index 7a3934ffdaa..fff9dbe9b6d 100644 --- a/litellm/responses/mcp/litellm_proxy_mcp_handler.py +++ b/litellm/responses/mcp/litellm_proxy_mcp_handler.py @@ -682,14 +682,14 @@ async def _execute_tool_calls( # noqa: PLR0915 standard_logging_mcp_tool_call["mcp_server_logo_url"] = logo_url cost_info = mcp_info.get("mcp_server_cost_info") if cost_info: - standard_logging_mcp_tool_call[ - "mcp_server_cost_info" - ] = cost_info + standard_logging_mcp_tool_call["mcp_server_cost_info"] = ( + cost_info + ) if litellm_logging_obj: - litellm_logging_obj.model_call_details[ - "mcp_tool_call_metadata" - ] = standard_logging_mcp_tool_call + litellm_logging_obj.model_call_details["mcp_tool_call_metadata"] = ( + standard_logging_mcp_tool_call + ) litellm_logging_obj.model = f"MCP: {tool_name}" litellm_logging_obj.call_type = CallTypes.call_mcp_tool.value diff --git a/litellm/responses/mcp/mcp_streaming_iterator.py b/litellm/responses/mcp/mcp_streaming_iterator.py index 7aed48c2f9a..42c46dff47c 100644 --- a/litellm/responses/mcp/mcp_streaming_iterator.py +++ b/litellm/responses/mcp/mcp_streaming_iterator.py @@ -273,9 +273,9 @@ def __init__( self.finished = False # Event queues and generation flags - self.mcp_discovery_events: List[ - ResponsesAPIStreamingResponse - ] = mcp_events # Pre-generated MCP discovery events + self.mcp_discovery_events: List[ResponsesAPIStreamingResponse] = ( + mcp_events # Pre-generated MCP discovery events + ) self.tool_execution_events: List[ResponsesAPIStreamingResponse] = [] self.mcp_discovery_generated = True # Events are already generated self.mcp_events = ( @@ -284,9 +284,9 @@ def __init__( self.tool_server_map = tool_server_map # Iterator references - self.base_iterator: Optional[ - Union[Any, ResponsesAPIResponse] - ] = base_iterator # Will be created when needed + self.base_iterator: Optional[Union[Any, ResponsesAPIResponse]] = ( + base_iterator # Will be created when needed + ) self.follow_up_iterator: Optional[Any] = None # Response collection for tool execution @@ -582,9 +582,9 @@ async def _create_initial_response_iterator(self) -> None: # Use the pre-fetched all_tools from original_request_params (no re-processing needed) params_for_llm = {} for key, value in params.items(): - params_for_llm[ - key - ] = value # Copy all params as-is since tools are already processed + params_for_llm[key] = ( + value # Copy all params as-is since tools are already processed + ) tools_count = ( len(params_for_llm.get("tools", [])) diff --git a/litellm/responses/streaming_iterator.py b/litellm/responses/streaming_iterator.py index 10a74a5b3c6..57eba4c53ca 100644 --- a/litellm/responses/streaming_iterator.py +++ b/litellm/responses/streaming_iterator.py @@ -188,10 +188,10 @@ def _process_chunk(self, chunk) -> Optional[ResponsesAPIStreamingResponse]: ) if usage_obj is not None: try: - cost: Optional[ - float - ] = self.logging_obj._response_cost_calculator( - result=response_obj + cost: Optional[float] = ( + self.logging_obj._response_cost_calculator( + result=response_obj + ) ) if cost is not None: setattr(usage_obj, "cost", cost) @@ -1034,7 +1034,7 @@ def _extract_response_id(completed_event: Dict[str, Any]) -> Optional[str]: @staticmethod def _extract_output_messages( - completed_event: Dict[str, Any] + completed_event: Dict[str, Any], ) -> List[Dict[str, Any]]: """ Convert the output items in a ``response.completed`` event into diff --git a/litellm/responses/utils.py b/litellm/responses/utils.py index 11097864225..54f1816286b 100644 --- a/litellm/responses/utils.py +++ b/litellm/responses/utils.py @@ -346,10 +346,10 @@ def _update_encrypted_content_item_ids_in_response( if encrypted_content and isinstance(encrypted_content, str): # Always wrap encrypted_content with model_id for redundancy - item[ - "encrypted_content" - ] = ResponsesAPIRequestUtils._wrap_encrypted_content_with_model_id( - encrypted_content, model_id + item["encrypted_content"] = ( + ResponsesAPIRequestUtils._wrap_encrypted_content_with_model_id( + encrypted_content, model_id + ) ) # Also encode the ID if present if item_id and isinstance(item_id, str): diff --git a/litellm/router.py b/litellm/router.py index 25e5c9cb5d9..d28e13a0e0d 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -401,9 +401,9 @@ def __init__( # noqa: PLR0915 ) # names of models under litellm_params. ex. azure/chatgpt-v-2 self.deployment_latency_map = {} ### CACHING ### - cache_type: Literal[ - "local", "redis", "redis-semantic", "s3", "disk" - ] = "local" # default to an in-memory cache + cache_type: Literal["local", "redis", "redis-semantic", "s3", "disk"] = ( + "local" # default to an in-memory cache + ) redis_cache = None cache_config: Dict[str, Any] = {} @@ -451,9 +451,9 @@ def __init__( # noqa: PLR0915 self.default_max_parallel_requests = default_max_parallel_requests self.provider_default_deployment_ids: List[str] = [] self.pattern_router = PatternMatchRouter() - self.team_pattern_routers: Dict[ - str, PatternMatchRouter - ] = {} # {"TEAM_ID": PatternMatchRouter} + self.team_pattern_routers: Dict[str, PatternMatchRouter] = ( + {} + ) # {"TEAM_ID": PatternMatchRouter} self.auto_routers: Dict[str, "AutoRouter"] = {} self.complexity_routers: Dict[str, "ComplexityRouter"] = {} @@ -639,12 +639,12 @@ def __init__( # noqa: PLR0915 ) ) - self.model_group_retry_policy: Optional[ - Dict[str, RetryPolicy] - ] = model_group_retry_policy - self.model_group_affinity_config: Optional[ - Dict[str, List[str]] - ] = model_group_affinity_config + self.model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = ( + model_group_retry_policy + ) + self.model_group_affinity_config: Optional[Dict[str, List[str]]] = ( + model_group_affinity_config + ) self.allowed_fails_policy: Optional[AllowedFailsPolicy] = None if allowed_fails_policy is not None: @@ -2050,7 +2050,10 @@ async def _silent_experiment_acompletion( async def _acompletion( # noqa: PLR0915 self, model: str, messages: List[Dict[str, str]], **kwargs - ) -> Union[ModelResponse, CustomStreamWrapper,]: + ) -> Union[ + ModelResponse, + CustomStreamWrapper, + ]: """ - Get an available deployment - call it with a semaphore over the call @@ -4284,9 +4287,9 @@ async def create_file_for_deployment(deployment: dict) -> OpenAIFileObject: healthy_deployments=healthy_deployments, responses=responses ) returned_response = cast(OpenAIFileObject, responses[0]) - returned_response._hidden_params[ - "model_file_id_mapping" - ] = model_file_id_mapping + returned_response._hidden_params["model_file_id_mapping"] = ( + model_file_id_mapping + ) return returned_response except Exception as e: verbose_router_logger.exception( @@ -5313,11 +5316,11 @@ async def async_function_with_fallbacks_common_utils( # noqa: PLR0915 if isinstance(e, litellm.ContextWindowExceededError): if context_window_fallbacks is not None: - context_window_fallback_model_group: Optional[ - List[str] - ] = self._get_fallback_model_group_from_fallbacks( - fallbacks=context_window_fallbacks, - model_group=model_group, + context_window_fallback_model_group: Optional[List[str]] = ( + self._get_fallback_model_group_from_fallbacks( + fallbacks=context_window_fallbacks, + model_group=model_group, + ) ) if context_window_fallback_model_group is None: raise original_exception @@ -5349,11 +5352,11 @@ async def async_function_with_fallbacks_common_utils( # noqa: PLR0915 e.message += "\n{}".format(error_message) elif isinstance(e, litellm.ContentPolicyViolationError): if content_policy_fallbacks is not None: - content_policy_fallback_model_group: Optional[ - List[str] - ] = self._get_fallback_model_group_from_fallbacks( - fallbacks=content_policy_fallbacks, - model_group=model_group, + content_policy_fallback_model_group: Optional[List[str]] = ( + self._get_fallback_model_group_from_fallbacks( + fallbacks=content_policy_fallbacks, + model_group=model_group, + ) ) if content_policy_fallback_model_group is None: raise original_exception @@ -5575,9 +5578,9 @@ async def async_function_with_retries(self, *args, **kwargs): # noqa: PLR0915 ) ## ADD RETRY TRACKING TO METADATA - used for spend logs retry tracking _metadata["attempted_retries"] = 0 - _metadata[ - "max_retries" - ] = num_retries # Updated after overrides in exception handler + _metadata["max_retries"] = ( + num_retries # Updated after overrides in exception handler + ) try: self._handle_mock_testing_rate_limit_error( model_group=model_group, kwargs=kwargs @@ -6696,26 +6699,26 @@ def init_auto_router_deployment(self, deployment: Deployment): """ from litellm.router_strategy.auto_router.auto_router import AutoRouter - auto_router_config_path: Optional[ - str - ] = deployment.litellm_params.auto_router_config_path + auto_router_config_path: Optional[str] = ( + deployment.litellm_params.auto_router_config_path + ) auto_router_config: Optional[str] = deployment.litellm_params.auto_router_config if auto_router_config_path is None and auto_router_config is None: raise ValueError( "auto_router_config_path or auto_router_config is required for auto-router deployments. Please set it in the litellm_params" ) - default_model: Optional[ - str - ] = deployment.litellm_params.auto_router_default_model + default_model: Optional[str] = ( + deployment.litellm_params.auto_router_default_model + ) if default_model is None: raise ValueError( "auto_router_default_model is required for auto-router deployments. Please set it in the litellm_params" ) - embedding_model: Optional[ - str - ] = deployment.litellm_params.auto_router_embedding_model + embedding_model: Optional[str] = ( + deployment.litellm_params.auto_router_embedding_model + ) if embedding_model is None: raise ValueError( "auto_router_embedding_model is required for auto-router deployments. Please set it in the litellm_params" @@ -6758,13 +6761,13 @@ def init_complexity_router_deployment(self, deployment: Deployment): ComplexityRouter, ) - complexity_router_config: Optional[ - dict - ] = deployment.litellm_params.complexity_router_config + complexity_router_config: Optional[dict] = ( + deployment.litellm_params.complexity_router_config + ) - default_model: Optional[ - str - ] = deployment.litellm_params.complexity_router_default_model + default_model: Optional[str] = ( + deployment.litellm_params.complexity_router_default_model + ) # If no default model specified, try to get from config tiers if default_model is None and complexity_router_config: @@ -7375,9 +7378,9 @@ def get_deployment_credentials_with_provider( # Add custom_llm_provider if deployment.litellm_params.custom_llm_provider: - credentials[ - "custom_llm_provider" - ] = deployment.litellm_params.custom_llm_provider + credentials["custom_llm_provider"] = ( + deployment.litellm_params.custom_llm_provider + ) elif "/" in deployment.litellm_params.model: # Extract provider from "provider/model" format credentials["custom_llm_provider"] = deployment.litellm_params.model.split( diff --git a/litellm/router_strategy/auto_router/auto_router.py b/litellm/router_strategy/auto_router/auto_router.py index 6a786115193..4ead7225abc 100644 --- a/litellm/router_strategy/auto_router/auto_router.py +++ b/litellm/router_strategy/auto_router/auto_router.py @@ -1,6 +1,7 @@ """ Auto-Routing Strategy that works with a Semantic Router Config """ + from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from litellm._logging import verbose_router_logger diff --git a/litellm/router_strategy/base_routing_strategy.py b/litellm/router_strategy/base_routing_strategy.py index 6e410ef14a3..885798d706c 100644 --- a/litellm/router_strategy/base_routing_strategy.py +++ b/litellm/router_strategy/base_routing_strategy.py @@ -25,9 +25,9 @@ def __init__( if should_batch_redis_writes: self.setup_sync_task(default_sync_interval) - self.in_memory_keys_to_update: set[ - str - ] = set() # Set with max size of 1000 keys + self.in_memory_keys_to_update: set[str] = ( + set() + ) # Set with max size of 1000 keys def setup_sync_task(self, default_sync_interval: Optional[Union[int, float]]): """Setup the sync task in a way that's compatible with FastAPI""" diff --git a/litellm/router_strategy/budget_limiter.py b/litellm/router_strategy/budget_limiter.py index 64dc5fe4741..d3edd1be8ec 100644 --- a/litellm/router_strategy/budget_limiter.py +++ b/litellm/router_strategy/budget_limiter.py @@ -10,11 +10,11 @@ Example: ``` openai: - budget_limit: 0.000000000001 - time_period: 1d + budget_limit: 0.000000000001 + time_period: 1d anthropic: - budget_limit: 100 - time_period: 7d + budget_limit: 100 + time_period: 7d ``` """ @@ -100,9 +100,9 @@ def __init__( self.dual_cache = dual_cache self.redis_increment_operation_queue: List[RedisPipelineIncrementOperation] = [] asyncio.create_task(self.periodic_sync_in_memory_spend_with_redis()) - self.provider_budget_config: Optional[ - GenericBudgetConfigType - ] = provider_budget_config + self.provider_budget_config: Optional[GenericBudgetConfigType] = ( + provider_budget_config + ) self.deployment_budget_config: Optional[GenericBudgetConfigType] = None self.tag_budget_config: Optional[GenericBudgetConfigType] = None self._init_provider_budgets() diff --git a/litellm/router_strategy/complexity_router/complexity_router.py b/litellm/router_strategy/complexity_router/complexity_router.py index 29bed360fab..e51249b1cb1 100644 --- a/litellm/router_strategy/complexity_router/complexity_router.py +++ b/litellm/router_strategy/complexity_router/complexity_router.py @@ -8,6 +8,7 @@ Inspired by ClawRouter: https://github.com/BlockRunAI/ClawRouter """ + import re from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union diff --git a/litellm/router_strategy/complexity_router/evals/eval_complexity_router.py b/litellm/router_strategy/complexity_router/evals/eval_complexity_router.py index a361d95a0ac..939ecdd2d22 100644 --- a/litellm/router_strategy/complexity_router/evals/eval_complexity_router.py +++ b/litellm/router_strategy/complexity_router/evals/eval_complexity_router.py @@ -4,6 +4,7 @@ Tests the router's ability to correctly classify prompts into complexity tiers. Run with: python -m litellm.router_strategy.complexity_router.evals.eval_complexity_router """ + import os # Add parent to path for imports @@ -273,16 +274,20 @@ def run_eval() -> Tuple[int, int, List[dict]]: { "case": i, "description": case.description, - "prompt": case.prompt[:80] + "..." - if len(case.prompt) > 80 - else case.prompt, + "prompt": ( + case.prompt[:80] + "..." + if len(case.prompt) > 80 + else case.prompt + ), "expected": case.expected_tier.value, "actual": tier.value, "score": round(score, 3), "signals": signals, - "acceptable": [t.value for t in case.acceptable_tiers] - if case.acceptable_tiers - else None, + "acceptable": ( + [t.value for t in case.acceptable_tiers] + if case.acceptable_tiers + else None + ), } ) diff --git a/litellm/router_utils/common_utils.py b/litellm/router_utils/common_utils.py index 7530247ce75..bef42e23848 100644 --- a/litellm/router_utils/common_utils.py +++ b/litellm/router_utils/common_utils.py @@ -32,9 +32,9 @@ def add_model_file_id_mappings( model_file_id_mapping = {} if isinstance(healthy_deployments, list): for deployment, response in zip(healthy_deployments, responses): - model_file_id_mapping[ - deployment.get("model_info", {}).get("id") - ] = response.id + model_file_id_mapping[deployment.get("model_info", {}).get("id")] = ( + response.id + ) elif isinstance(healthy_deployments, dict): for model_id, file_id in healthy_deployments.items(): model_file_id_mapping[model_id] = file_id diff --git a/litellm/router_utils/cooldown_callbacks.py b/litellm/router_utils/cooldown_callbacks.py index 32777a1dd4d..343328dacf3 100644 --- a/litellm/router_utils/cooldown_callbacks.py +++ b/litellm/router_utils/cooldown_callbacks.py @@ -59,9 +59,9 @@ async def router_cooldown_event_callback( pass # get the prometheus logger from in memory loggers - prometheusLogger: Optional[ - PrometheusLogger - ] = _get_prometheus_logger_from_callbacks() + prometheusLogger: Optional[PrometheusLogger] = ( + _get_prometheus_logger_from_callbacks() + ) if prometheusLogger is not None: prometheusLogger.set_deployment_complete_outage( diff --git a/litellm/router_utils/get_retry_from_policy.py b/litellm/router_utils/get_retry_from_policy.py index ec326ebb50d..162d6428f85 100644 --- a/litellm/router_utils/get_retry_from_policy.py +++ b/litellm/router_utils/get_retry_from_policy.py @@ -1,5 +1,5 @@ """ -Get num retries for an exception. +Get num retries for an exception. - Account for retry policy by exception type. """ diff --git a/litellm/router_utils/pattern_match_deployments.py b/litellm/router_utils/pattern_match_deployments.py index 69d6ab9b6e2..48f85a83411 100644 --- a/litellm/router_utils/pattern_match_deployments.py +++ b/litellm/router_utils/pattern_match_deployments.py @@ -34,7 +34,7 @@ def calculate_pattern_specificity(pattern: str) -> Tuple[int, int]: @staticmethod def sorted_patterns( - patterns: Dict[str, List[Dict]] + patterns: Dict[str, List[Dict]], ) -> List[Tuple[str, List[Dict]]]: """ Cached property for patterns sorted by specificity. @@ -105,11 +105,13 @@ def _return_pattern_matched_deployments( new_deployments = [] for deployment in deployments: new_deployment = copy.deepcopy(deployment) - new_deployment["litellm_params"][ - "model" - ] = PatternMatchRouter.set_deployment_model_name( - matched_pattern=matched_pattern, - litellm_deployment_litellm_model=deployment["litellm_params"]["model"], + new_deployment["litellm_params"]["model"] = ( + PatternMatchRouter.set_deployment_model_name( + matched_pattern=matched_pattern, + litellm_deployment_litellm_model=deployment["litellm_params"][ + "model" + ], + ) ) new_deployments.append(new_deployment) diff --git a/litellm/router_utils/router_callbacks/track_deployment_metrics.py b/litellm/router_utils/router_callbacks/track_deployment_metrics.py index 1f226879d03..9039b0df8e6 100644 --- a/litellm/router_utils/router_callbacks/track_deployment_metrics.py +++ b/litellm/router_utils/router_callbacks/track_deployment_metrics.py @@ -1,5 +1,5 @@ """ -Helper functions to get/set num success and num failures per deployment +Helper functions to get/set num success and num failures per deployment set_deployment_failures_for_current_minute diff --git a/litellm/router_utils/search_api_router.py b/litellm/router_utils/search_api_router.py index 491a25e58ef..a26aa7e71ee 100644 --- a/litellm/router_utils/search_api_router.py +++ b/litellm/router_utils/search_api_router.py @@ -121,9 +121,9 @@ async def async_search_with_fallbacks( ) # Set up kwargs for the fallback system - kwargs[ - "model" - ] = search_tool_name # Use model field for compatibility with fallback system + kwargs["model"] = ( + search_tool_name # Use model field for compatibility with fallback system + ) kwargs["original_generic_function"] = original_function # Bind router_instance to the helper method using partial kwargs["original_function"] = partial( diff --git a/litellm/search/__init__.py b/litellm/search/__init__.py index a3ebb3d870b..51f311618e3 100644 --- a/litellm/search/__init__.py +++ b/litellm/search/__init__.py @@ -1,6 +1,7 @@ """ LiteLLM Search API module. """ + from litellm.search.cost_calculator import search_provider_cost_per_query from litellm.search.main import asearch, search diff --git a/litellm/search/cost_calculator.py b/litellm/search/cost_calculator.py index 9821c12ae49..841c003dff3 100644 --- a/litellm/search/cost_calculator.py +++ b/litellm/search/cost_calculator.py @@ -1,6 +1,7 @@ """ Cost calculation for search providers. """ + from typing import Optional, Tuple from litellm.utils import get_model_info diff --git a/litellm/search/main.py b/litellm/search/main.py index 6b2c837fd55..7711dee6e54 100644 --- a/litellm/search/main.py +++ b/litellm/search/main.py @@ -1,6 +1,7 @@ """ Main Search function for LiteLLM. """ + import asyncio import contextvars from functools import partial @@ -242,10 +243,10 @@ def search( raise ValueError("All items in query list must be strings") # Get provider config - search_provider_config: Optional[ - BaseSearchConfig - ] = ProviderConfigManager.get_provider_search_config( - provider=SearchProviders(search_provider), + search_provider_config: Optional[BaseSearchConfig] = ( + ProviderConfigManager.get_provider_search_config( + provider=SearchProviders(search_provider), + ) ) if search_provider_config is None: diff --git a/litellm/secret_managers/aws_secret_manager.py b/litellm/secret_managers/aws_secret_manager.py index fbe951e6492..60d0a713eff 100644 --- a/litellm/secret_managers/aws_secret_manager.py +++ b/litellm/secret_managers/aws_secret_manager.py @@ -4,7 +4,7 @@ Relevant issue: https://github.com/BerriAI/litellm/issues/1883 Requires: -* `os.environ["AWS_REGION_NAME"], +* `os.environ["AWS_REGION_NAME"], * `pip install boto3>=1.28.57` """ diff --git a/litellm/secret_managers/aws_secret_manager_v2.py b/litellm/secret_managers/aws_secret_manager_v2.py index c1b4d019dcf..4461e34396e 100644 --- a/litellm/secret_managers/aws_secret_manager_v2.py +++ b/litellm/secret_managers/aws_secret_manager_v2.py @@ -10,7 +10,7 @@ Relevant issue: https://github.com/BerriAI/litellm/issues/1883 Requires: -* `os.environ["AWS_REGION_NAME"], +* `os.environ["AWS_REGION_NAME"], * `pip install boto3>=1.28.57` """ diff --git a/litellm/secret_managers/secret_manager_handler.py b/litellm/secret_managers/secret_manager_handler.py index eb90dda0e99..2df0cd13db0 100644 --- a/litellm/secret_managers/secret_manager_handler.py +++ b/litellm/secret_managers/secret_manager_handler.py @@ -3,6 +3,7 @@ Handles retrieving secrets from different secret management systems. """ + import base64 import os from typing import Any, Optional @@ -160,9 +161,11 @@ def get_secret_from_manager( # noqa: PLR0915 if isinstance(client, CustomSecretManager): secret = client.sync_read_secret( secret_name=secret_name, - optional_params=key_management_settings.model_dump() - if key_management_settings - else None, + optional_params=( + key_management_settings.model_dump() + if key_management_settings + else None + ), ) if secret is None: raise ValueError( diff --git a/litellm/setup_wizard.py b/litellm/setup_wizard.py index ee5918e1273..cfefc508d86 100644 --- a/litellm/setup_wizard.py +++ b/litellm/setup_wizard.py @@ -433,9 +433,9 @@ def _collect_keys(providers: List[Dict]) -> Dict[str, str]: f" {blue('❯')} Azure deployment name {grey('(e.g. my-gpt4o)')}: " ) if deployment: - env_vars[ - f"_LITELLM_AZURE_DEPLOYMENT_{p['id'].upper()}" - ] = deployment + env_vars[f"_LITELLM_AZURE_DEPLOYMENT_{p['id'].upper()}"] = ( + deployment + ) # Store the key returned by validation — may be a re-entered replacement env_vars[p["env_key"]] = SetupWizard._validate_and_report(p, key) diff --git a/litellm/skills/main.py b/litellm/skills/main.py index c6ef6f28fb6..3ff0f52c641 100644 --- a/litellm/skills/main.py +++ b/litellm/skills/main.py @@ -173,10 +173,10 @@ def create_skill( ) # Get provider config for external providers (Anthropic, etc.) - skills_api_provider_config: Optional[ - BaseSkillsAPIConfig - ] = ProviderConfigManager.get_provider_skills_api_config( - provider=litellm.LlmProviders(custom_llm_provider), + skills_api_provider_config: Optional[BaseSkillsAPIConfig] = ( + ProviderConfigManager.get_provider_skills_api_config( + provider=litellm.LlmProviders(custom_llm_provider), + ) ) if skills_api_provider_config is None: @@ -354,10 +354,10 @@ def list_skills( ) # Get provider config for external providers (Anthropic, etc.) - skills_api_provider_config: Optional[ - BaseSkillsAPIConfig - ] = ProviderConfigManager.get_provider_skills_api_config( - provider=litellm.LlmProviders(custom_llm_provider), + skills_api_provider_config: Optional[BaseSkillsAPIConfig] = ( + ProviderConfigManager.get_provider_skills_api_config( + provider=litellm.LlmProviders(custom_llm_provider), + ) ) if skills_api_provider_config is None: @@ -529,10 +529,10 @@ def get_skill( ) # Get provider config for external providers (Anthropic, etc.) - skills_api_provider_config: Optional[ - BaseSkillsAPIConfig - ] = ProviderConfigManager.get_provider_skills_api_config( - provider=litellm.LlmProviders(custom_llm_provider), + skills_api_provider_config: Optional[BaseSkillsAPIConfig] = ( + ProviderConfigManager.get_provider_skills_api_config( + provider=litellm.LlmProviders(custom_llm_provider), + ) ) if skills_api_provider_config is None: @@ -696,10 +696,10 @@ def delete_skill( ) # Get provider config for external providers (Anthropic, etc.) - skills_api_provider_config: Optional[ - BaseSkillsAPIConfig - ] = ProviderConfigManager.get_provider_skills_api_config( - provider=litellm.LlmProviders(custom_llm_provider), + skills_api_provider_config: Optional[BaseSkillsAPIConfig] = ( + ProviderConfigManager.get_provider_skills_api_config( + provider=litellm.LlmProviders(custom_llm_provider), + ) ) if skills_api_provider_config is None: diff --git a/litellm/types/integrations/datadog_llm_obs.py b/litellm/types/integrations/datadog_llm_obs.py index 1f281e93e8c..4ea5ed66b87 100644 --- a/litellm/types/integrations/datadog_llm_obs.py +++ b/litellm/types/integrations/datadog_llm_obs.py @@ -3,6 +3,7 @@ API Reference: https://docs.datadoghq.com/llm_observability/setup/api/?tab=example#api-standards """ + from typing import Any, Dict, List, Literal, Optional from typing_extensions import TypedDict diff --git a/litellm/types/llms/bedrock.py b/litellm/types/llms/bedrock.py index 54237dfb37a..6830d95d36f 100644 --- a/litellm/types/llms/bedrock.py +++ b/litellm/types/llms/bedrock.py @@ -203,7 +203,9 @@ class ConverseResponseBlock(TypedDict, total=False): str ] # end_turn | tool_use | max_tokens | stop_sequence | content_filtered usage: Required[ConverseTokenUsageBlock] - serviceTier: ServiceTierBlock # Optional - only present when serviceTier was sent in request + serviceTier: ( + ServiceTierBlock # Optional - only present when serviceTier was sent in request + ) class ToolJsonSchemaBlock(TypedDict, total=False): diff --git a/litellm/types/llms/openai.py b/litellm/types/llms/openai.py index 5a80b40d61f..a265198e6b8 100644 --- a/litellm/types/llms/openai.py +++ b/litellm/types/llms/openai.py @@ -970,12 +970,12 @@ def __init__(self, **kwargs): class Hyperparameters(BaseModel): batch_size: Optional[Union[str, int]] = None # "Number of examples in each batch." - learning_rate_multiplier: Optional[ - Union[str, float] - ] = None # Scaling factor for the learning rate - n_epochs: Optional[ - Union[str, int] - ] = None # "The number of epochs to train the model for" + learning_rate_multiplier: Optional[Union[str, float]] = ( + None # Scaling factor for the learning rate + ) + n_epochs: Optional[Union[str, int]] = ( + None # "The number of epochs to train the model for" + ) model_config = {"extra": "allow"} @@ -1004,18 +1004,18 @@ class FineTuningJobCreate(BaseModel): model: str # "The name of the model to fine-tune." training_file: str # "The ID of an uploaded file that contains training data." - hyperparameters: Optional[ - Hyperparameters - ] = None # "The hyperparameters used for the fine-tuning job." - suffix: Optional[ - str - ] = None # "A string of up to 18 characters that will be added to your fine-tuned model name." - validation_file: Optional[ - str - ] = None # "The ID of an uploaded file that contains validation data." - integrations: Optional[ - List[str] - ] = None # "A list of integrations to enable for your fine-tuning job." + hyperparameters: Optional[Hyperparameters] = ( + None # "The hyperparameters used for the fine-tuning job." + ) + suffix: Optional[str] = ( + None # "A string of up to 18 characters that will be added to your fine-tuned model name." + ) + validation_file: Optional[str] = ( + None # "The ID of an uploaded file that contains validation data." + ) + integrations: Optional[List[str]] = ( + None # "A list of integrations to enable for your fine-tuning job." + ) seed: Optional[int] = None # "The seed controls the reproducibility of the job." diff --git a/litellm/types/management_endpoints/cache_settings_endpoints.py b/litellm/types/management_endpoints/cache_settings_endpoints.py index fd68f43e7b0..6d8cb63a15c 100644 --- a/litellm/types/management_endpoints/cache_settings_endpoints.py +++ b/litellm/types/management_endpoints/cache_settings_endpoints.py @@ -13,14 +13,14 @@ class CacheSettingsField(BaseModel): field_value: Any field_description: str field_default: Any = None - options: Optional[ - List[str] - ] = None # For fields with predefined options/enum values + options: Optional[List[str]] = ( + None # For fields with predefined options/enum values + ) ui_field_name: str # User-friendly display name link: Optional[str] = None # Documentation link for the field - redis_type: Optional[ - str - ] = None # Which Redis type this field applies to (node, cluster, sentinel) + redis_type: Optional[str] = ( + None # Which Redis type this field applies to (node, cluster, sentinel) + ) # Redis type descriptions diff --git a/litellm/types/management_endpoints/router_settings_endpoints.py b/litellm/types/management_endpoints/router_settings_endpoints.py index 4f09b7da853..4c2abfdbca8 100644 --- a/litellm/types/management_endpoints/router_settings_endpoints.py +++ b/litellm/types/management_endpoints/router_settings_endpoints.py @@ -75,9 +75,9 @@ class RouterSettingsField(BaseModel): field_value: Any field_description: str field_default: Any = None - options: Optional[ - List[str] - ] = None # For fields with predefined options/enum values + options: Optional[List[str]] = ( + None # For fields with predefined options/enum values + ) ui_field_name: str # User-friendly display name link: Optional[str] = None # Documentation link for the field diff --git a/litellm/types/mcp_server/mcp_server_manager.py b/litellm/types/mcp_server/mcp_server_manager.py index ed391f0af68..c26383b292f 100644 --- a/litellm/types/mcp_server/mcp_server_manager.py +++ b/litellm/types/mcp_server/mcp_server_manager.py @@ -28,19 +28,19 @@ class MCPServer(BaseModel): auth_type: Optional[MCPAuthType] = None authentication_token: Optional[str] = None mcp_info: Optional[MCPInfo] = None - extra_headers: Optional[ - List[str] - ] = None # allow admin to specify which headers to forward from client to the MCP server + extra_headers: Optional[List[str]] = ( + None # allow admin to specify which headers to forward from client to the MCP server + ) allowed_tools: Optional[List[str]] = None disallowed_tools: Optional[List[str]] = None tool_name_to_display_name: Optional[Dict[str, str]] = None tool_name_to_description: Optional[Dict[str, str]] = None - allowed_params: Optional[ - Dict[str, List[str]] - ] = None # map of tool names to allowed parameter lists - static_headers: Optional[ - Dict[str, str] - ] = None # static headers to forward to the MCP server + allowed_params: Optional[Dict[str, List[str]]] = ( + None # map of tool names to allowed parameter lists + ) + static_headers: Optional[Dict[str, str]] = ( + None # static headers to forward to the MCP server + ) # OAuth-specific fields client_id: Optional[str] = None client_secret: Optional[str] = None diff --git a/litellm/types/proxy/guardrails/guardrail_hooks/generic_guardrail_api.py b/litellm/types/proxy/guardrails/guardrail_hooks/generic_guardrail_api.py index c87086bdce2..94f219a5fc6 100644 --- a/litellm/types/proxy/guardrails/guardrail_hooks/generic_guardrail_api.py +++ b/litellm/types/proxy/guardrails/guardrail_hooks/generic_guardrail_api.py @@ -60,9 +60,9 @@ class GenericGuardrailAPIRequest(BaseModel): input_type: Literal["request", "response"] litellm_call_id: Optional[str] = None # the call id of the individual LLM call - litellm_trace_id: Optional[ - str - ] = None # the trace id of the LLM call - useful if there are multiple LLM calls for the same conversation + litellm_trace_id: Optional[str] = ( + None # the trace id of the LLM call - useful if there are multiple LLM calls for the same conversation + ) structured_messages: Optional[List[AllMessageValues]] = None images: Optional[List[str]] = None tools: Optional[List[ChatCompletionToolParam]] = None diff --git a/litellm/types/proxy/guardrails/guardrail_hooks/openai/openai_moderation.py b/litellm/types/proxy/guardrails/guardrail_hooks/openai/openai_moderation.py index ee67626967e..7d81cf9fe03 100644 --- a/litellm/types/proxy/guardrails/guardrail_hooks/openai/openai_moderation.py +++ b/litellm/types/proxy/guardrails/guardrail_hooks/openai/openai_moderation.py @@ -8,11 +8,11 @@ class BaseOpenAIModerationGuardrailConfigModel(GuardrailConfigModel): """Base configuration model for the OpenAI Moderation guardrail""" - model: Optional[ - Literal["omni-moderation-latest", "text-moderation-latest"] - ] = Field( - default="omni-moderation-latest", - description="The OpenAI moderation model to use. 'omni-moderation-latest' supports more categorization options and multi-modal inputs. Defaults to 'omni-moderation-latest'.", + model: Optional[Literal["omni-moderation-latest", "text-moderation-latest"]] = ( + Field( + default="omni-moderation-latest", + description="The OpenAI moderation model to use. 'omni-moderation-latest' supports more categorization options and multi-modal inputs. Defaults to 'omni-moderation-latest'.", + ) ) diff --git a/litellm/types/proxy/guardrails/guardrail_hooks/pillar.py b/litellm/types/proxy/guardrails/guardrail_hooks/pillar.py index 92e76d6693a..17391c070d0 100644 --- a/litellm/types/proxy/guardrails/guardrail_hooks/pillar.py +++ b/litellm/types/proxy/guardrails/guardrail_hooks/pillar.py @@ -1,6 +1,7 @@ """ Pillar Security Guardrail Config Model """ + from typing import Optional from pydantic import BaseModel, Field diff --git a/litellm/types/proxy/management_endpoints/internal_user_endpoints.py b/litellm/types/proxy/management_endpoints/internal_user_endpoints.py index 4770877daba..6023094a920 100644 --- a/litellm/types/proxy/management_endpoints/internal_user_endpoints.py +++ b/litellm/types/proxy/management_endpoints/internal_user_endpoints.py @@ -25,13 +25,13 @@ class UserListResponse(BaseModel): class BulkUpdateUserRequest(BaseModel): """Request for bulk user updates""" - users: Optional[ - List[UpdateUserRequest] - ] = None # List of specific user update requests + users: Optional[List[UpdateUserRequest]] = ( + None # List of specific user update requests + ) all_users: Optional[bool] = False # Flag to update all users - user_updates: Optional[ - UpdateUserRequestNoUserIDorEmail - ] = None # Updates to apply to all users when all_users=True + user_updates: Optional[UpdateUserRequestNoUserIDorEmail] = ( + None # Updates to apply to all users when all_users=True + ) @field_validator("users", "all_users", "user_updates") @classmethod diff --git a/litellm/types/proxy/management_endpoints/model_management_endpoints.py b/litellm/types/proxy/management_endpoints/model_management_endpoints.py index be4d730e93e..bbbfc0de9f8 100644 --- a/litellm/types/proxy/management_endpoints/model_management_endpoints.py +++ b/litellm/types/proxy/management_endpoints/model_management_endpoints.py @@ -21,12 +21,12 @@ class UpdateUsefulLinksRequest(BaseModel): class NewModelGroupRequest(BaseModel): access_group: str # The access group name (e.g., "production-models") - model_names: Optional[ - List[str] - ] = None # Existing model groups to include - tags ALL deployments for each name - model_ids: Optional[ - List[str] - ] = None # Specific deployment IDs to tag (more precise than model_names) + model_names: Optional[List[str]] = ( + None # Existing model groups to include - tags ALL deployments for each name + ) + model_ids: Optional[List[str]] = ( + None # Specific deployment IDs to tag (more precise than model_names) + ) class NewModelGroupResponse(BaseModel): @@ -37,12 +37,12 @@ class NewModelGroupResponse(BaseModel): class UpdateModelGroupRequest(BaseModel): - model_names: Optional[ - List[str] - ] = None # Updated list of model groups to include - tags ALL deployments for each name - model_ids: Optional[ - List[str] - ] = None # Specific deployment IDs to tag (more precise than model_names) + model_names: Optional[List[str]] = ( + None # Updated list of model groups to include - tags ALL deployments for each name + ) + model_ids: Optional[List[str]] = ( + None # Specific deployment IDs to tag (more precise than model_names) + ) class DeleteModelGroupResponse(BaseModel): diff --git a/litellm/types/rerank.py b/litellm/types/rerank.py index fb6dae0d1df..d2c252a1e92 100644 --- a/litellm/types/rerank.py +++ b/litellm/types/rerank.py @@ -59,9 +59,9 @@ class RerankResponseResult(TypedDict, total=False): class RerankResponse(BaseModel): id: Optional[str] = None - results: Optional[ - List[RerankResponseResult] - ] = None # Contains index and relevance_score + results: Optional[List[RerankResponseResult]] = ( + None # Contains index and relevance_score + ) meta: Optional[RerankResponseMeta] = None # Contains api_version and billed_units # Define private attributes using PrivateAttr diff --git a/litellm/types/router.py b/litellm/types/router.py index 4257628e7cb..2e867e93a7c 100644 --- a/litellm/types/router.py +++ b/litellm/types/router.py @@ -95,16 +95,18 @@ class ModelInfo(BaseModel): id: Optional[ str ] # Allow id to be optional on input, but it will always be present as a str in the model instance - db_model: bool = False # used for proxy - to separate models which are stored in the db vs. config. + db_model: bool = ( + False # used for proxy - to separate models which are stored in the db vs. config. + ) updated_at: Optional[datetime.datetime] = None updated_by: Optional[str] = None created_at: Optional[datetime.datetime] = None created_by: Optional[str] = None - base_model: Optional[ - str - ] = None # specify if the base model is azure/gpt-3.5-turbo etc for accurate cost tracking + base_model: Optional[str] = ( + None # specify if the base model is azure/gpt-3.5-turbo etc for accurate cost tracking + ) tier: Optional[Literal["free", "paid"]] = None """ @@ -173,12 +175,12 @@ class GenericLiteLLMParams(CredentialLiteLLMParams, CustomPricingLiteLLMParams): custom_llm_provider: Optional[str] = None tpm: Optional[int] = None rpm: Optional[int] = None - timeout: Optional[ - Union[float, str, httpx.Timeout] - ] = None # if str, pass in as os.environ/ - stream_timeout: Optional[ - Union[float, str] - ] = None # timeout when making stream=True calls, if str, pass in as os.environ/ + timeout: Optional[Union[float, str, httpx.Timeout]] = ( + None # if str, pass in as os.environ/ + ) + stream_timeout: Optional[Union[float, str]] = ( + None # timeout when making stream=True calls, if str, pass in as os.environ/ + ) max_retries: Optional[int] = None organization: Optional[str] = None # for openai orgs configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS = None diff --git a/litellm/types/search.py b/litellm/types/search.py index bbac1237a19..e94477fe1f6 100644 --- a/litellm/types/search.py +++ b/litellm/types/search.py @@ -3,6 +3,7 @@ This module defines types for the unified search API across different providers. """ + from typing import List, Optional from typing_extensions import Required, TypedDict diff --git a/litellm/types/videos/utils.py b/litellm/types/videos/utils.py index bf51fdda370..afa368ecdc3 100644 --- a/litellm/types/videos/utils.py +++ b/litellm/types/videos/utils.py @@ -4,6 +4,7 @@ Follows the pattern used in responses/utils.py for consistency. Format: vid_{base64_encoded_string} """ + import base64 from typing import Optional, Tuple diff --git a/litellm/utils.py b/litellm/utils.py index 088ee07d630..80efe7137d8 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -783,9 +783,9 @@ def function_setup( # noqa: PLR0915 coroutine_checker = get_coroutine_checker_fn() ## DYNAMIC CALLBACKS ## - dynamic_callbacks: Optional[ - List[Union[str, Callable, "CustomLogger"]] - ] = kwargs.pop("callbacks", None) + dynamic_callbacks: Optional[List[Union[str, Callable, "CustomLogger"]]] = ( + kwargs.pop("callbacks", None) + ) all_callbacks = get_dynamic_callbacks(dynamic_callbacks=dynamic_callbacks) if len(all_callbacks) > 0: @@ -1695,9 +1695,9 @@ def wrapper(*args, **kwargs): # noqa: PLR0915 exception=e, retry_policy=kwargs.get("retry_policy"), ) - kwargs[ - "retry_policy" - ] = reset_retry_policy() # prevent infinite loops + kwargs["retry_policy"] = ( + reset_retry_policy() + ) # prevent infinite loops litellm.num_retries = ( None # set retries to None to prevent infinite loops ) @@ -1744,9 +1744,9 @@ def wrapper(*args, **kwargs): # noqa: PLR0915 exception=e, retry_policy=kwargs.get("retry_policy"), ) - kwargs[ - "retry_policy" - ] = reset_retry_policy() # prevent infinite loops + kwargs["retry_policy"] = ( + reset_retry_policy() + ) # prevent infinite loops litellm.num_retries = ( None # set retries to None to prevent infinite loops ) @@ -3762,10 +3762,10 @@ def pre_process_non_default_params( if "response_format" in non_default_params: if provider_config is not None: - non_default_params[ - "response_format" - ] = provider_config.get_json_schema_from_pydantic_object( - response_format=non_default_params["response_format"] + non_default_params["response_format"] = ( + provider_config.get_json_schema_from_pydantic_object( + response_format=non_default_params["response_format"] + ) ) else: non_default_params["response_format"] = type_to_response_format_param( @@ -3894,16 +3894,16 @@ def pre_process_optional_params( True # so that main.py adds the function call to the prompt ) if "tools" in non_default_params: - optional_params[ - "functions_unsupported_model" - ] = non_default_params.pop("tools") + optional_params["functions_unsupported_model"] = ( + non_default_params.pop("tools") + ) non_default_params.pop( "tool_choice", None ) # causes ollama requests to hang elif "functions" in non_default_params: - optional_params[ - "functions_unsupported_model" - ] = non_default_params.pop("functions") + optional_params["functions_unsupported_model"] = ( + non_default_params.pop("functions") + ) elif ( litellm.add_function_to_prompt ): # if user opts to add it to prompt instead @@ -7511,9 +7511,9 @@ def __init__(self, model_response: ModelResponse, convert_to_delta: bool = False if convert_to_delta is True: _stream_response = ModelResponseStream() _stream_response.choices[0].delta.content = model_response.choices[0].message.content # type: ignore - self.model_response: Union[ - ModelResponse, ModelResponseStream - ] = _stream_response + self.model_response: Union[ModelResponse, ModelResponseStream] = ( + _stream_response + ) else: self.model_response = model_response self.is_done = False diff --git a/litellm/vector_store_files/utils.py b/litellm/vector_store_files/utils.py index ffe73516bda..1ee5b47e306 100644 --- a/litellm/vector_store_files/utils.py +++ b/litellm/vector_store_files/utils.py @@ -21,7 +21,7 @@ def _filter_params(params: Dict[str, Any], model: Any) -> Dict[str, Any]: @staticmethod def get_create_request_params( - params: Dict[str, Any] + params: Dict[str, Any], ) -> VectorStoreFileCreateRequest: filtered = VectorStoreFileRequestUtils._filter_params( params=params, model=VectorStoreFileCreateRequest @@ -37,7 +37,7 @@ def get_list_query_params(params: Dict[str, Any]) -> VectorStoreFileListQueryPar @staticmethod def get_update_request_params( - params: Dict[str, Any] + params: Dict[str, Any], ) -> VectorStoreFileUpdateRequest: filtered = VectorStoreFileRequestUtils._filter_params( params=params, model=VectorStoreFileUpdateRequest diff --git a/litellm/vector_stores/vector_store_registry.py b/litellm/vector_stores/vector_store_registry.py index 2596f968a06..1fd95b16309 100644 --- a/litellm/vector_stores/vector_store_registry.py +++ b/litellm/vector_stores/vector_store_registry.py @@ -24,9 +24,9 @@ class VectorStoreIndexRegistry: def __init__( self, vector_store_indexes: List[LiteLLM_ManagedVectorStoreIndex] = [] ): - self.vector_store_indexes: List[ - LiteLLM_ManagedVectorStoreIndex - ] = vector_store_indexes + self.vector_store_indexes: List[LiteLLM_ManagedVectorStoreIndex] = ( + vector_store_indexes + ) def get_vector_store_indexes(self) -> List[LiteLLM_ManagedVectorStoreIndex]: """ diff --git a/litellm/videos/main.py b/litellm/videos/main.py index cd61293cd1c..a61fe99d584 100644 --- a/litellm/videos/main.py +++ b/litellm/videos/main.py @@ -174,7 +174,10 @@ def video_generation( # noqa: PLR0915 extra_query: Optional[Dict[str, Any]] = None, extra_body: Optional[Dict[str, Any]] = None, **kwargs, -) -> Union[VideoObject, Coroutine[Any, Any, VideoObject],]: +) -> Union[ + VideoObject, + Coroutine[Any, Any, VideoObject], +]: """ Maps the https://api.openai.com/v1/videos endpoint. @@ -203,11 +206,11 @@ def video_generation( # noqa: PLR0915 ) # get provider config - video_generation_provider_config: Optional[ - BaseVideoConfig - ] = ProviderConfigManager.get_provider_video_config( - model=model, - provider=litellm.LlmProviders(custom_llm_provider), + video_generation_provider_config: Optional[BaseVideoConfig] = ( + ProviderConfigManager.get_provider_video_config( + model=model, + provider=litellm.LlmProviders(custom_llm_provider), + ) ) if video_generation_provider_config is None: @@ -286,7 +289,10 @@ def video_content( extra_query: Optional[Dict[str, Any]] = None, extra_body: Optional[Dict[str, Any]] = None, **kwargs, -) -> Union[bytes, Coroutine[Any, Any, bytes],]: +) -> Union[ + bytes, + Coroutine[Any, Any, bytes], +]: """ Download video content from OpenAI's video API. @@ -331,11 +337,11 @@ def video_content( litellm_params = GenericLiteLLMParams(**kwargs) # get provider config - video_provider_config: Optional[ - BaseVideoConfig - ] = ProviderConfigManager.get_provider_video_config( - model=None, - provider=litellm.LlmProviders(custom_llm_provider), + video_provider_config: Optional[BaseVideoConfig] = ( + ProviderConfigManager.get_provider_video_config( + model=None, + provider=litellm.LlmProviders(custom_llm_provider), + ) ) if video_provider_config is None: @@ -574,7 +580,10 @@ def video_remix( # noqa: PLR0915 extra_query: Optional[Dict[str, Any]] = None, extra_body: Optional[Dict[str, Any]] = None, **kwargs, -) -> Union[VideoObject, Coroutine[Any, Any, VideoObject],]: +) -> Union[ + VideoObject, + Coroutine[Any, Any, VideoObject], +]: """ Maps the https://api.openai.com/v1/videos/{video_id}/remix endpoint. @@ -604,11 +613,11 @@ def video_remix( # noqa: PLR0915 litellm_params = GenericLiteLLMParams(**kwargs) # get provider config - video_remix_provider_config: Optional[ - BaseVideoConfig - ] = ProviderConfigManager.get_provider_video_config( - model=None, - provider=litellm.LlmProviders(custom_llm_provider), + video_remix_provider_config: Optional[BaseVideoConfig] = ( + ProviderConfigManager.get_provider_video_config( + model=None, + provider=litellm.LlmProviders(custom_llm_provider), + ) ) if video_remix_provider_config is None: @@ -793,7 +802,10 @@ def video_list( # noqa: PLR0915 extra_query: Optional[Dict[str, Any]] = None, extra_body: Optional[Dict[str, Any]] = None, **kwargs, -) -> Union[List[VideoObject], Coroutine[Any, Any, List[VideoObject]],]: +) -> Union[ + List[VideoObject], + Coroutine[Any, Any, List[VideoObject]], +]: """ Maps the https://api.openai.com/v1/videos endpoint. @@ -820,11 +832,11 @@ def video_list( # noqa: PLR0915 litellm_params = GenericLiteLLMParams(**kwargs) # get provider config - video_list_provider_config: Optional[ - BaseVideoConfig - ] = ProviderConfigManager.get_provider_video_config( - model=None, - provider=litellm.LlmProviders(custom_llm_provider), + video_list_provider_config: Optional[BaseVideoConfig] = ( + ProviderConfigManager.get_provider_video_config( + model=None, + provider=litellm.LlmProviders(custom_llm_provider), + ) ) if video_list_provider_config is None: @@ -991,7 +1003,10 @@ def video_status( # noqa: PLR0915 extra_query: Optional[Dict[str, Any]] = None, extra_body: Optional[Dict[str, Any]] = None, **kwargs, -) -> Union[VideoObject, Coroutine[Any, Any, VideoObject],]: +) -> Union[ + VideoObject, + Coroutine[Any, Any, VideoObject], +]: """ Retrieve video status from OpenAI's video API. @@ -1043,11 +1058,11 @@ def video_status( # noqa: PLR0915 litellm_params = GenericLiteLLMParams(**kwargs) # get provider config - video_status_provider_config: Optional[ - BaseVideoConfig - ] = ProviderConfigManager.get_provider_video_config( - model=None, - provider=litellm.LlmProviders(custom_llm_provider), + video_status_provider_config: Optional[BaseVideoConfig] = ( + ProviderConfigManager.get_provider_video_config( + model=None, + provider=litellm.LlmProviders(custom_llm_provider), + ) ) if video_status_provider_config is None: @@ -1186,11 +1201,11 @@ def video_create_character( litellm_params = GenericLiteLLMParams(**kwargs) - provider_config: Optional[ - BaseVideoConfig - ] = ProviderConfigManager.get_provider_video_config( - model=None, - provider=litellm.LlmProviders(custom_llm_provider), + provider_config: Optional[BaseVideoConfig] = ( + ProviderConfigManager.get_provider_video_config( + model=None, + provider=litellm.LlmProviders(custom_llm_provider), + ) ) if provider_config is None: @@ -1315,11 +1330,11 @@ def video_get_character( litellm_params = GenericLiteLLMParams(**kwargs) - provider_config: Optional[ - BaseVideoConfig - ] = ProviderConfigManager.get_provider_video_config( - model=None, - provider=litellm.LlmProviders(custom_llm_provider), + provider_config: Optional[BaseVideoConfig] = ( + ProviderConfigManager.get_provider_video_config( + model=None, + provider=litellm.LlmProviders(custom_llm_provider), + ) ) if provider_config is None: @@ -1447,11 +1462,11 @@ def video_edit( litellm_params = GenericLiteLLMParams(**kwargs) - provider_config: Optional[ - BaseVideoConfig - ] = ProviderConfigManager.get_provider_video_config( - model=None, - provider=litellm.LlmProviders(custom_llm_provider), + provider_config: Optional[BaseVideoConfig] = ( + ProviderConfigManager.get_provider_video_config( + model=None, + provider=litellm.LlmProviders(custom_llm_provider), + ) ) if provider_config is None: @@ -1582,11 +1597,11 @@ def video_extension( litellm_params = GenericLiteLLMParams(**kwargs) - provider_config: Optional[ - BaseVideoConfig - ] = ProviderConfigManager.get_provider_video_config( - model=None, - provider=litellm.LlmProviders(custom_llm_provider), + provider_config: Optional[BaseVideoConfig] = ( + ProviderConfigManager.get_provider_video_config( + model=None, + provider=litellm.LlmProviders(custom_llm_provider), + ) ) if provider_config is None: