diff --git a/litellm/proxy/common_utils/encrypt_decrypt_utils.py b/litellm/proxy/common_utils/encrypt_decrypt_utils.py index a5da5798f47..7d5f6fec54f 100644 --- a/litellm/proxy/common_utils/encrypt_decrypt_utils.py +++ b/litellm/proxy/common_utils/encrypt_decrypt_utils.py @@ -1,19 +1,12 @@ import base64 -import os from typing import Literal, Optional from litellm._logging import verbose_proxy_logger +from litellm.proxy.common_utils.signing_key_utils import get_proxy_signing_key def _get_salt_key(): - from litellm.proxy.proxy_server import master_key - - salt_key = os.getenv("LITELLM_SALT_KEY", None) - - if salt_key is None: - salt_key = master_key - - return salt_key + return get_proxy_signing_key() def encrypt_value_helper(value: str, new_encryption_key: Optional[str] = None): diff --git a/litellm/proxy/common_utils/signing_key_utils.py b/litellm/proxy/common_utils/signing_key_utils.py new file mode 100644 index 00000000000..f74498cd6f6 --- /dev/null +++ b/litellm/proxy/common_utils/signing_key_utils.py @@ -0,0 +1,17 @@ +import os +import sys +from typing import Optional + + +def get_proxy_signing_key() -> Optional[str]: + salt_key = os.getenv("LITELLM_SALT_KEY") + if salt_key is not None: + return salt_key + + proxy_server_module = sys.modules.get("litellm.proxy.proxy_server") + if proxy_server_module is not None: + proxy_master_key = getattr(proxy_server_module, "master_key", None) + if isinstance(proxy_master_key, str): + return proxy_master_key + + return os.getenv("LITELLM_MASTER_KEY") diff --git a/litellm/proxy/health_endpoints/_health_endpoints.py b/litellm/proxy/health_endpoints/_health_endpoints.py index 8a09edfd4c4..1e8961b2c60 100644 --- a/litellm/proxy/health_endpoints/_health_endpoints.py +++ b/litellm/proxy/health_endpoints/_health_endpoints.py @@ -1,5 +1,6 @@ import asyncio import copy +import json import logging import os import time @@ -26,6 +27,7 @@ WebhookEvent, ) from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.proxy.common_utils.encrypt_decrypt_utils import decrypt_value_helper from litellm.proxy.db.exception_handler import PrismaDBExceptionHandler from litellm.proxy.health_check import ( _clean_endpoint_data, @@ -41,6 +43,44 @@ #### Health ENDPOINTS #### +def _str_to_bool(value: Optional[str]) -> Optional[bool]: + if value is None: + return None + + normalized_value = value.strip().lower() + if normalized_value == "true": + return True + if normalized_value == "false": + return False + return None + + +def _get_env_secret( + secret_name: str, default_value: Optional[Union[str, bool]] = None +) -> Optional[Union[str, bool]]: + if secret_name.startswith("os.environ/"): + secret_name = secret_name.replace("os.environ/", "") + + secret_value = os.getenv(secret_name) + if secret_value is None: + return default_value + + return secret_value + + +def get_secret_bool( + secret_name: str, default_value: Optional[bool] = None +) -> Optional[bool]: + secret_value = _get_env_secret(secret_name=secret_name) + if secret_value is None: + return default_value + + if isinstance(secret_value, bool): + return secret_value + + return _str_to_bool(secret_value) + + def _resolve_os_environ_variables(params: dict) -> dict: """ Resolve ``os.environ/`` environment variables in ``litellm_params``. @@ -145,6 +185,90 @@ def get_callback_identifier(callback): return callback_name(callback) +def _parse_config_row_param_value(param_value: Any) -> dict: + if param_value is None: + return {} + + if isinstance(param_value, str): + try: + parsed_value = json.loads(param_value) + except json.JSONDecodeError: + return {} + return parsed_value if isinstance(parsed_value, dict) else {} + + if isinstance(param_value, dict): + return dict(param_value) + + try: + parsed_value = dict(param_value) + except (TypeError, ValueError): + return {} + + return parsed_value if isinstance(parsed_value, dict) else {} + + +def _is_truthy_config_flag(value: Any) -> bool: + if isinstance(value, bool): + return value + + if isinstance(value, str): + return value.strip().lower() == "true" + + if value is None: + return False + + return bool(value) + + +async def _resolve_test_email_address(prisma_client: Any) -> Optional[str]: + test_email_address = os.getenv("TEST_EMAIL_ADDRESS") + + try: + store_model_in_db = ( + get_secret_bool("STORE_MODEL_IN_DB", default_value=False) is True + ) + + if not store_model_in_db and prisma_client is not None: + general_settings_row = await prisma_client.db.litellm_config.find_unique( + where={"param_name": "general_settings"} + ) + general_settings = _parse_config_row_param_value( + getattr(general_settings_row, "param_value", None) + ) + store_model_in_db = _is_truthy_config_flag( + general_settings.get("store_model_in_db") + ) + + if not store_model_in_db or prisma_client is None: + return test_email_address + + environment_variables_row = await prisma_client.db.litellm_config.find_unique( + where={"param_name": "environment_variables"} + ) + environment_variables = _parse_config_row_param_value( + getattr(environment_variables_row, "param_value", None) + ) + db_test_email_address = environment_variables.get("TEST_EMAIL_ADDRESS") + + if db_test_email_address is None: + return test_email_address + + decrypted_test_email_address = decrypt_value_helper( + value=db_test_email_address, + key="TEST_EMAIL_ADDRESS", + exception_type="debug", + return_original_value=True, + ) + + return decrypted_test_email_address or test_email_address + except Exception as e: + verbose_proxy_logger.debug( + "Falling back to TEST_EMAIL_ADDRESS from env after DB lookup failed: %s", + str(e), + ) + return test_email_address + + router = APIRouter() services = Union[ Literal[ @@ -447,7 +571,7 @@ async def health_services_endpoint( # noqa: PLR0915 spend=0, max_budget=0, user_id=user_api_key_dict.user_id, - user_email=os.getenv("TEST_EMAIL_ADDRESS"), + user_email=await _resolve_test_email_address(prisma_client), team_id=user_api_key_dict.team_id, ) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 85a12f70f58..eb67763386b 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -13047,11 +13047,11 @@ def normalize_callback(callback): _value = os.getenv("SLACK_WEBHOOK_URL", None) _slack_env_vars[_var] = _value else: - # decode + decrypt the value - _decrypted_value = decrypt_value_helper( - value=env_variable, key=_var + _slack_env_vars[_var] = decrypt_value_helper( + value=env_variable, + key=_var, + return_original_value=True, ) - _slack_env_vars[_var] = _decrypted_value _alerting_types = proxy_logging_obj.slack_alerting_instance.alert_types _all_alert_types = ( @@ -13085,8 +13085,15 @@ def normalize_callback(callback): if env_variable is None: _email_env_vars[_var] = None else: - # decode + decrypt the value - _decrypted_value = decrypt_value_helper(value=env_variable, key=_var) + # Use return_original_value=True so this works for both: + # - DB mode: values already decrypted by _update_config_from_db → decryption + # fails gracefully and returns the original plaintext value + # - YAML mode: values still encrypted in config file → decrypted here + _decrypted_value = decrypt_value_helper( + value=env_variable, + key=_var, + return_original_value=True, + ) _email_env_vars[_var] = _decrypted_value alerting_data.append( diff --git a/tests/proxy_unit_tests/test_proxy_server.py b/tests/proxy_unit_tests/test_proxy_server.py index 7da4d41fbf1..dce2eb83250 100644 --- a/tests/proxy_unit_tests/test_proxy_server.py +++ b/tests/proxy_unit_tests/test_proxy_server.py @@ -2743,6 +2743,73 @@ async def test_get_config_callbacks_environment_variables(client_no_auth): assert otel_vars["OTEL_HEADERS"] == "key=value" +@pytest.mark.asyncio +async def test_get_config_callbacks_email_and_slack_values_are_not_decrypted_again( + client_no_auth, +): + """ + Test that /get/config/callbacks returns already-decrypted email/slack values as-is. + + decrypt_value_helper is called with return_original_value=True, so for already-plaintext + values (DB mode: decrypted by _update_config_from_db) it returns the original value + unchanged. For encrypted values (YAML mode) it properly decrypts them. + """ + mock_config_data = { + "litellm_settings": {}, + "environment_variables": { + "SLACK_WEBHOOK_URL": "https://hooks.slack.com/services/test/webhook", + "SMTP_HOST": "10.16.68.20", + "SMTP_PORT": "587", + "SMTP_USERNAME": "smtp-user", + "SMTP_PASSWORD": "smtp-password", + "SMTP_SENDER_EMAIL": "alerts@example.com", + "TEST_EMAIL_ADDRESS": "ops@example.com", + "EMAIL_LOGO_URL": "https://example.com/logo.png", + "EMAIL_SUPPORT_CONTACT": "support@example.com", + }, + "general_settings": {"alerting": ["slack"]}, + } + + proxy_config = getattr(litellm.proxy.proxy_server, "proxy_config") + + # Simulate return_original_value=True behaviour: return the value as-is (already plaintext) + def fake_decrypt(value, key, return_original_value=False, **kwargs): + return value + + with patch.object( + proxy_config, "get_config", new=AsyncMock(return_value=mock_config_data) + ), patch( + "litellm.proxy.proxy_server.decrypt_value_helper", + side_effect=fake_decrypt, + ) as decrypt_mock: + response = client_no_auth.get("/get/config/callbacks") + + assert response.status_code == 200 + result = response.json() + alerts = result["alerts"] + + slack_alert = next((alert for alert in alerts if alert["name"] == "slack"), None) + assert slack_alert is not None + assert slack_alert["variables"] == { + "SLACK_WEBHOOK_URL": "https://hooks.slack.com/services/test/webhook" + } + + email_alert = next((alert for alert in alerts if alert["name"] == "email"), None) + assert email_alert is not None + assert email_alert["variables"] == { + "SMTP_HOST": "10.16.68.20", + "SMTP_PORT": "587", + "SMTP_USERNAME": "smtp-user", + "SMTP_PASSWORD": "smtp-password", + "SMTP_SENDER_EMAIL": "alerts@example.com", + "TEST_EMAIL_ADDRESS": "ops@example.com", + "EMAIL_LOGO_URL": "https://example.com/logo.png", + "EMAIL_SUPPORT_CONTACT": "support@example.com", + } + # decrypt_value_helper is called once per SMTP var + once for SLACK_WEBHOOK_URL + assert decrypt_mock.call_count == len(mock_config_data["environment_variables"]) + + @pytest.mark.asyncio async def test_update_config_success_callback_normalization(): """ diff --git a/tests/test_litellm/proxy/common_utils/test_signing_key_utils.py b/tests/test_litellm/proxy/common_utils/test_signing_key_utils.py new file mode 100644 index 00000000000..33c22b66595 --- /dev/null +++ b/tests/test_litellm/proxy/common_utils/test_signing_key_utils.py @@ -0,0 +1,36 @@ +import sys +from types import SimpleNamespace + +from litellm.proxy.common_utils.signing_key_utils import get_proxy_signing_key + + +def test_get_proxy_signing_key_prefers_salt_key(monkeypatch): + monkeypatch.setenv("LITELLM_SALT_KEY", "salt-key") + monkeypatch.setenv("LITELLM_MASTER_KEY", "env-master-key") + monkeypatch.setitem( + sys.modules, + "litellm.proxy.proxy_server", + SimpleNamespace(master_key="proxy-master-key"), + ) + + assert get_proxy_signing_key() == "salt-key" + + +def test_get_proxy_signing_key_uses_loaded_proxy_server_master_key(monkeypatch): + monkeypatch.delenv("LITELLM_SALT_KEY", raising=False) + monkeypatch.delenv("LITELLM_MASTER_KEY", raising=False) + monkeypatch.setitem( + sys.modules, + "litellm.proxy.proxy_server", + SimpleNamespace(master_key="proxy-master-key"), + ) + + assert get_proxy_signing_key() == "proxy-master-key" + + +def test_get_proxy_signing_key_falls_back_to_env_master_key(monkeypatch): + monkeypatch.delenv("LITELLM_SALT_KEY", raising=False) + monkeypatch.setenv("LITELLM_MASTER_KEY", "env-master-key") + monkeypatch.delitem(sys.modules, "litellm.proxy.proxy_server", raising=False) + + assert get_proxy_signing_key() == "env-master-key" diff --git a/tests/test_litellm/proxy/health_endpoints/test_health_endpoints.py b/tests/test_litellm/proxy/health_endpoints/test_health_endpoints.py index bc3aec58991..7b5fb346249 100644 --- a/tests/test_litellm/proxy/health_endpoints/test_health_endpoints.py +++ b/tests/test_litellm/proxy/health_endpoints/test_health_endpoints.py @@ -1,3 +1,4 @@ +import json import os import sys import time @@ -16,6 +17,7 @@ from litellm.proxy.health_endpoints._health_endpoints import ( _db_health_readiness_check, + _resolve_os_environ_variables, get_callback_identifier, health_license_endpoint, health_services_endpoint, @@ -337,6 +339,252 @@ async def test_health_services_endpoint_sqs(status, error_message): mock_instance.async_health_check.assert_awaited_once() +def test_resolve_os_environ_variables_should_use_secret_manager_get_secret(): + params = { + "api_key": "os.environ/TEST_API_KEY", + "api_base": "https://example.com", + } + + with patch( + "litellm.proxy.health_endpoints._health_endpoints.get_secret", + return_value="resolved-secret-value", + ) as mock_get_secret: + result = _resolve_os_environ_variables(params) + + assert result == { + "api_key": "resolved-secret-value", + "api_base": "https://example.com", + } + mock_get_secret.assert_called_once_with("os.environ/TEST_API_KEY") + + +def test_resolve_os_environ_variables_should_resolve_nested_dicts_and_lists(): + params = { + "api_key": "os.environ/ROOT_SECRET", + "headers": { + "Authorization": "os.environ/AUTH_SECRET", + "static": "value", + }, + "fallbacks": [ + "os.environ/FALLBACK_SECRET", + { + "nested_list_key": "os.environ/NESTED_LIST_SECRET", + }, + ["os.environ/DEEP_LIST_SECRET", "plain-value"], + ], + } + + resolved_values = { + "os.environ/ROOT_SECRET": "root-secret", + "os.environ/AUTH_SECRET": "auth-secret", + "os.environ/FALLBACK_SECRET": "fallback-secret", + "os.environ/NESTED_LIST_SECRET": "nested-list-secret", + "os.environ/DEEP_LIST_SECRET": "deep-list-secret", + } + + with patch( + "litellm.proxy.health_endpoints._health_endpoints.get_secret", + side_effect=lambda secret_name: resolved_values[secret_name], + ) as mock_get_secret: + result = _resolve_os_environ_variables(params) + + assert result == { + "api_key": "root-secret", + "headers": { + "Authorization": "auth-secret", + "static": "value", + }, + "fallbacks": [ + "fallback-secret", + {"nested_list_key": "nested-list-secret"}, + ["deep-list-secret", "plain-value"], + ], + } + assert mock_get_secret.call_count == 5 + + +@pytest.mark.asyncio +async def test_health_services_endpoint_email_should_use_test_email_address_from_db_when_store_model_in_db_enabled(): + mock_prisma = MagicMock() + mock_prisma.db.litellm_config.find_unique = AsyncMock( + side_effect=[ + SimpleNamespace(param_value={"store_model_in_db": True}), + SimpleNamespace(param_value={"TEST_EMAIL_ADDRESS": "encrypted-db-value"}), + ] + ) + mock_slack_alerting = SimpleNamespace( + send_key_created_or_user_invited_email=AsyncMock() + ) + mock_proxy_logging_obj = SimpleNamespace( + slack_alerting_instance=mock_slack_alerting + ) + mock_user_api_key_dict = SimpleNamespace( + token="test-token", + user_id="test-user", + team_id="test-team", + ) + + with patch.dict(os.environ, {"TEST_EMAIL_ADDRESS": "env@example.com"}), patch( + "litellm.proxy.proxy_server.general_settings", + {}, + ), patch( + "litellm.proxy.proxy_server.prisma_client", + mock_prisma, + ), patch( + "litellm.proxy.proxy_server.proxy_logging_obj", + mock_proxy_logging_obj, + ), patch( + "litellm.proxy.health_endpoints._health_endpoints.get_secret_bool", + return_value=False, + ), patch( + "litellm.proxy.health_endpoints._health_endpoints.decrypt_value_helper", + return_value="db@example.com", + ): + result = await health_services_endpoint( + service="email", + user_api_key_dict=mock_user_api_key_dict, + ) + + assert result["status"] == "success" + assert ( + mock_prisma.db.litellm_config.find_unique.await_args_list[0].kwargs["where"] + == {"param_name": "general_settings"} + ) + assert ( + mock_prisma.db.litellm_config.find_unique.await_args_list[1].kwargs["where"] + == {"param_name": "environment_variables"} + ) + webhook_event = ( + mock_slack_alerting.send_key_created_or_user_invited_email.await_args.kwargs[ + "webhook_event" + ] + ) + assert webhook_event.user_email == "db@example.com" + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "store_model_in_db_secret,general_settings_row,environment_variables_row,expected_db_calls", + [ + (False, SimpleNamespace(param_value={"store_model_in_db": False}), None, 1), + (True, None, None, 1), + ], + ids=["db-disabled", "config-row-missing"], +) +async def test_health_services_endpoint_email_should_fall_back_to_env_test_email_address_when_db_disabled_or_missing( + store_model_in_db_secret, + general_settings_row, + environment_variables_row, + expected_db_calls, +): + mock_prisma = MagicMock() + db_rows = [] + if general_settings_row is not None: + db_rows.append(general_settings_row) + if store_model_in_db_secret: + db_rows.append(environment_variables_row) + mock_prisma.db.litellm_config.find_unique = AsyncMock(side_effect=db_rows) + mock_slack_alerting = SimpleNamespace( + send_key_created_or_user_invited_email=AsyncMock() + ) + mock_proxy_logging_obj = SimpleNamespace( + slack_alerting_instance=mock_slack_alerting + ) + mock_user_api_key_dict = SimpleNamespace( + token="test-token", + user_id="test-user", + team_id="test-team", + ) + + with patch.dict(os.environ, {"TEST_EMAIL_ADDRESS": "env@example.com"}), patch( + "litellm.proxy.proxy_server.general_settings", + {}, + ), patch( + "litellm.proxy.proxy_server.prisma_client", + mock_prisma, + ), patch( + "litellm.proxy.proxy_server.proxy_logging_obj", + mock_proxy_logging_obj, + ), patch( + "litellm.proxy.health_endpoints._health_endpoints.get_secret_bool", + return_value=store_model_in_db_secret, + ), patch( + "litellm.proxy.health_endpoints._health_endpoints.decrypt_value_helper" + ) as decrypt_mock: + result = await health_services_endpoint( + service="email", + user_api_key_dict=mock_user_api_key_dict, + ) + + assert result["status"] == "success" + webhook_event = ( + mock_slack_alerting.send_key_created_or_user_invited_email.await_args.kwargs[ + "webhook_event" + ] + ) + assert webhook_event.user_email == "env@example.com" + assert mock_prisma.db.litellm_config.find_unique.await_count == expected_db_calls + decrypt_mock.assert_not_called() + + +@pytest.mark.asyncio +async def test_health_services_endpoint_email_should_accept_json_string_environment_variables(): + mock_prisma = MagicMock() + mock_prisma.db.litellm_config.find_unique = AsyncMock( + return_value=SimpleNamespace( + param_value=json.dumps( + {"TEST_EMAIL_ADDRESS": "json-string-db-value"} + ) + ) + ) + mock_slack_alerting = SimpleNamespace( + send_key_created_or_user_invited_email=AsyncMock() + ) + mock_proxy_logging_obj = SimpleNamespace( + slack_alerting_instance=mock_slack_alerting + ) + mock_user_api_key_dict = SimpleNamespace( + token="test-token", + user_id="test-user", + team_id="test-team", + ) + + with patch.dict(os.environ, {"TEST_EMAIL_ADDRESS": "env@example.com"}), patch( + "litellm.proxy.proxy_server.general_settings", + {}, + ), patch( + "litellm.proxy.proxy_server.prisma_client", + mock_prisma, + ), patch( + "litellm.proxy.proxy_server.proxy_logging_obj", + mock_proxy_logging_obj, + ), patch( + "litellm.proxy.health_endpoints._health_endpoints.get_secret_bool", + return_value=True, + ), patch( + "litellm.proxy.health_endpoints._health_endpoints.decrypt_value_helper", + return_value="json@example.com", + ) as decrypt_mock: + result = await health_services_endpoint( + service="email", + user_api_key_dict=mock_user_api_key_dict, + ) + + assert result["status"] == "success" + decrypt_mock.assert_called_once_with( + value="json-string-db-value", + key="TEST_EMAIL_ADDRESS", + exception_type="debug", + return_original_value=True, + ) + webhook_event = ( + mock_slack_alerting.send_key_created_or_user_invited_email.await_args.kwargs[ + "webhook_event" + ] + ) + assert webhook_event.user_email == "json@example.com" + + @pytest.mark.asyncio async def test_health_license_endpoint_with_active_license(): license_data = { @@ -814,3 +1062,357 @@ def my_callback_function(): result = get_callback_identifier(my_callback_function) # Should fall back to callback_name() which returns __name__ assert result == "my_callback_function" + + +# ────────────────────────────────────────────────────────── +# _str_to_bool / _get_env_secret / get_secret_bool +# ────────────────────────────────────────────────────────── + + +def test_str_to_bool(): + from litellm.proxy.health_endpoints._health_endpoints import _str_to_bool + + assert _str_to_bool(None) is None + assert _str_to_bool("true") is True + assert _str_to_bool("True") is True + assert _str_to_bool(" TRUE ") is True + assert _str_to_bool("false") is False + assert _str_to_bool("False") is False + assert _str_to_bool(" FALSE ") is False + assert _str_to_bool("yes") is None + assert _str_to_bool("1") is None + assert _str_to_bool("") is None + + +def test_get_env_secret(monkeypatch): + from litellm.proxy.health_endpoints._health_endpoints import _get_env_secret + + # Not set – returns default + monkeypatch.delenv("MY_SECRET", raising=False) + assert _get_env_secret("MY_SECRET") is None + assert _get_env_secret("MY_SECRET", default_value="default") == "default" + + # Set via plain name + monkeypatch.setenv("MY_SECRET", "hello") + assert _get_env_secret("MY_SECRET") == "hello" + + # Set via os.environ/ prefix + assert _get_env_secret("os.environ/MY_SECRET") == "hello" + + # Not set, os.environ/ prefix – returns default + monkeypatch.delenv("MISSING_KEY", raising=False) + assert _get_env_secret("os.environ/MISSING_KEY", default_value=False) is False + + +def test_get_secret_bool(monkeypatch): + from litellm.proxy.health_endpoints._health_endpoints import get_secret_bool + + # Not set – returns default + monkeypatch.delenv("BOOL_FLAG", raising=False) + assert get_secret_bool("BOOL_FLAG") is None + assert get_secret_bool("BOOL_FLAG", default_value=True) is True + + # Set to "true" + monkeypatch.setenv("BOOL_FLAG", "true") + assert get_secret_bool("BOOL_FLAG") is True + + # Set to "false" + monkeypatch.setenv("BOOL_FLAG", "false") + assert get_secret_bool("BOOL_FLAG") is False + + # Set to unrecognised value – returns None (not default, because env var IS set) + monkeypatch.setenv("BOOL_FLAG", "maybe") + assert get_secret_bool("BOOL_FLAG", default_value=True) is None + + +# ────────────────────────────────────────────────────────── +# _parse_config_row_param_value +# ────────────────────────────────────────────────────────── + + +def test_parse_config_row_param_value(): + from litellm.proxy.health_endpoints._health_endpoints import ( + _parse_config_row_param_value, + ) + + # None → empty dict + assert _parse_config_row_param_value(None) == {} + + # Dict → copy of dict + assert _parse_config_row_param_value({"a": 1}) == {"a": 1} + + # Valid JSON string + assert _parse_config_row_param_value('{"x": 2}') == {"x": 2} + + # Invalid JSON string → empty dict + assert _parse_config_row_param_value("not-json") == {} + + # JSON string that parses to non-dict (list) → empty dict + assert _parse_config_row_param_value("[1, 2, 3]") == {} + + # Arbitrary non-convertible type → empty dict + assert _parse_config_row_param_value(12345) == {} + + +# ────────────────────────────────────────────────────────── +# _build_model_param_to_info_mapping +# ────────────────────────────────────────────────────────── + + +def test_build_model_param_to_info_mapping_basic(): + from litellm.proxy.health_endpoints._health_endpoints import ( + _build_model_param_to_info_mapping, + ) + + model_list = [ + { + "model_name": "gpt-4", + "model_info": {"id": "model-id-1"}, + "litellm_params": {"model": "openai/gpt-4"}, + }, + { + "model_name": "gpt-3.5", + "model_info": {"id": "model-id-2"}, + "litellm_params": {"model": "openai/gpt-3.5-turbo"}, + }, + ] + + result = _build_model_param_to_info_mapping(model_list) + + assert "openai/gpt-4" in result + assert result["openai/gpt-4"] == [{"model_name": "gpt-4", "model_id": "model-id-1"}] + assert result["openai/gpt-3.5-turbo"] == [ + {"model_name": "gpt-3.5", "model_id": "model-id-2"} + ] + + +def test_build_model_param_to_info_mapping_multiple_models_same_param(): + from litellm.proxy.health_endpoints._health_endpoints import ( + _build_model_param_to_info_mapping, + ) + + model_list = [ + { + "model_name": "prod-gpt4", + "model_info": {"id": "id-a"}, + "litellm_params": {"model": "openai/gpt-4"}, + }, + { + "model_name": "staging-gpt4", + "model_info": {"id": "id-b"}, + "litellm_params": {"model": "openai/gpt-4"}, + }, + ] + + result = _build_model_param_to_info_mapping(model_list) + + assert len(result["openai/gpt-4"]) == 2 + names = {entry["model_name"] for entry in result["openai/gpt-4"]} + assert names == {"prod-gpt4", "staging-gpt4"} + + +def test_build_model_param_to_info_mapping_skips_missing_fields(): + from litellm.proxy.health_endpoints._health_endpoints import ( + _build_model_param_to_info_mapping, + ) + + # Missing model_name or litellm_params.model → should be skipped + model_list = [ + { + "model_info": {"id": "id-1"}, + "litellm_params": {"model": "openai/gpt-4"}, + # no model_name + }, + { + "model_name": "gpt-4", + "model_info": {"id": "id-2"}, + "litellm_params": {}, # no model key + }, + ] + + result = _build_model_param_to_info_mapping(model_list) + assert result == {} + + +# ────────────────────────────────────────────────────────── +# _aggregate_health_check_results +# ────────────────────────────────────────────────────────── + + +def test_aggregate_health_check_results_healthy(): + from litellm.proxy.health_endpoints._health_endpoints import ( + _aggregate_health_check_results, + ) + + model_param_to_info = { + "openai/gpt-4": [{"model_name": "gpt-4", "model_id": "id-1"}] + } + healthy_endpoints = [{"model": "openai/gpt-4", "latency": 0.1}] + + result = _aggregate_health_check_results(model_param_to_info, healthy_endpoints, []) + + key = ("id-1", "gpt-4") + assert key in result + assert result[key]["healthy_count"] == 1 + assert result[key]["unhealthy_count"] == 0 + assert result[key]["error_message"] is None + + +def test_aggregate_health_check_results_unhealthy(): + from litellm.proxy.health_endpoints._health_endpoints import ( + _aggregate_health_check_results, + ) + + model_param_to_info = { + "openai/gpt-4": [{"model_name": "gpt-4", "model_id": "id-1"}] + } + unhealthy_endpoints = [{"model": "openai/gpt-4", "error": "connection refused"}] + + result = _aggregate_health_check_results(model_param_to_info, [], unhealthy_endpoints) + + key = ("id-1", "gpt-4") + assert key in result + assert result[key]["unhealthy_count"] == 1 + assert result[key]["error_message"] == "connection refused" + + +def test_aggregate_health_check_results_unknown_model_skipped(): + from litellm.proxy.health_endpoints._health_endpoints import ( + _aggregate_health_check_results, + ) + + model_param_to_info = {} + healthy_endpoints = [{"model": "openai/unknown"}] + + result = _aggregate_health_check_results(model_param_to_info, healthy_endpoints, []) + assert result == {} + + +# ────────────────────────────────────────────────────────── +# _save_background_health_checks_to_db +# ────────────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_save_background_health_checks_to_db_no_prisma(): + """When prisma_client is None the function returns early without error.""" + from litellm.proxy.health_endpoints._health_endpoints import ( + _save_background_health_checks_to_db, + ) + + # Should not raise + await _save_background_health_checks_to_db( + prisma_client=None, + model_list=[], + healthy_endpoints=[], + unhealthy_endpoints=[], + start_time=time.time(), + ) + + +@pytest.mark.asyncio +async def test_save_background_health_checks_to_db_saves_on_status_change(): + """When status changes, a DB write task should be created.""" + from litellm.proxy.health_endpoints._health_endpoints import ( + _save_background_health_checks_to_db, + ) + + model_list = [ + { + "model_name": "gpt-4", + "model_info": {"id": "model-id-1"}, + "litellm_params": {"model": "openai/gpt-4"}, + } + ] + healthy_endpoints = [{"model": "openai/gpt-4"}] + + # Simulate last check was "unhealthy" → status changes to "healthy" + last_check = SimpleNamespace( + model_id="model-id-1", + model_name="gpt-4", + status="unhealthy", + checked_at=None, + ) + + mock_prisma = MagicMock() + mock_prisma.get_all_latest_health_checks = AsyncMock(return_value=[last_check]) + mock_prisma.save_health_check_result = AsyncMock(return_value=None) + + with patch("asyncio.create_task") as mock_create_task: + await _save_background_health_checks_to_db( + prisma_client=mock_prisma, + model_list=model_list, + healthy_endpoints=healthy_endpoints, + unhealthy_endpoints=[], + start_time=time.time(), + ) + mock_create_task.assert_called_once() + + +@pytest.mark.asyncio +async def test_save_background_health_checks_to_db_skips_when_status_unchanged_recently(): + """When status is unchanged and last check was recent (<1 hour), no write.""" + from datetime import timezone + + from litellm.proxy.health_endpoints._health_endpoints import ( + _save_background_health_checks_to_db, + ) + + model_list = [ + { + "model_name": "gpt-4", + "model_info": {"id": "model-id-1"}, + "litellm_params": {"model": "openai/gpt-4"}, + } + ] + healthy_endpoints = [{"model": "openai/gpt-4"}] + + recent_time = datetime.now(timezone.utc) + last_check = SimpleNamespace( + model_id="model-id-1", + model_name="gpt-4", + status="healthy", # same as current + checked_at=recent_time, + ) + + mock_prisma = MagicMock() + mock_prisma.get_all_latest_health_checks = AsyncMock(return_value=[last_check]) + mock_prisma.save_health_check_result = AsyncMock(return_value=None) + + with patch("asyncio.create_task") as mock_create_task: + await _save_background_health_checks_to_db( + prisma_client=mock_prisma, + model_list=model_list, + healthy_endpoints=healthy_endpoints, + unhealthy_endpoints=[], + start_time=time.time(), + ) + mock_create_task.assert_not_called() + + +@pytest.mark.asyncio +async def test_save_background_health_checks_to_db_handles_db_error(): + """DB errors are caught and swallowed (health check should not fail).""" + from litellm.proxy.health_endpoints._health_endpoints import ( + _save_background_health_checks_to_db, + ) + + mock_prisma = MagicMock() + mock_prisma.get_all_latest_health_checks = AsyncMock( + side_effect=Exception("DB connection lost") + ) + + # Should not raise + await _save_background_health_checks_to_db( + prisma_client=mock_prisma, + model_list=[ + { + "model_name": "gpt-4", + "model_info": {"id": "id-1"}, + "litellm_params": {"model": "openai/gpt-4"}, + } + ], + healthy_endpoints=[{"model": "openai/gpt-4"}], + unhealthy_endpoints=[], + start_time=time.time(), + ) diff --git a/ui/litellm-dashboard/src/components/email_settings.tsx b/ui/litellm-dashboard/src/components/email_settings.tsx index 53d08ad64f5..8fd3a37398c 100644 --- a/ui/litellm-dashboard/src/components/email_settings.tsx +++ b/ui/litellm-dashboard/src/components/email_settings.tsx @@ -14,7 +14,7 @@ interface EmailSettingsProps { } const EmailSettings: React.FC = ({ accessToken, premiumUser, alerts }) => { - const handleSaveEmailSettings = async () => { + const handleSaveEmailSettings = async ({ silent = false }: { silent?: boolean } = {}) => { if (!accessToken) { return; } @@ -43,9 +43,15 @@ const EmailSettings: React.FC = ({ accessToken, premiumUser, }; try { await setCallbacksCall(accessToken, payload); - NotificationManager.success("Email settings updated successfully"); + if (!silent) { + NotificationManager.success("Email settings updated successfully"); + } } catch (error) { - NotificationManager.fromBackend(error); + if (!silent) { + NotificationManager.fromBackend(error); + } + // In silent mode (called from test flow) swallow the error so that + // the test can still proceed using env-var / YAML config values. } }; @@ -163,6 +169,12 @@ const EmailSettings: React.FC = ({ accessToken, premiumUser, onClick={async () => { if (!accessToken) return; try { + // Silently attempt to persist the current form values so the + // backend can read TEST_EMAIL_ADDRESS from the DB (DB mode). + // If saving is not supported (e.g. STORE_MODEL_IN_DB=False / + // YAML mode), this is a no-op and the backend will fall back to + // the TEST_EMAIL_ADDRESS environment variable instead. + await handleSaveEmailSettings({ silent: true }); await serviceHealthCheck(accessToken, "email"); NotificationManager.success("Email test triggered. Check your configured email inbox/logs."); } catch (error) {