Skip to content

Commit 4bbd15f

Browse files
committed
Fix test_async_post_call_success_hook_includes_client_ip_user_agent
1 parent 7e36d47 commit 4bbd15f

File tree

1 file changed

+49
-17
lines changed

1 file changed

+49
-17
lines changed

tests/test_litellm/integrations/test_prometheus_client_ip_user_agent.py

Lines changed: 49 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
from unittest.mock import AsyncMock, MagicMock, patch
2+
13
import pytest
2-
from unittest.mock import MagicMock, patch
4+
35
from litellm.integrations.prometheus import PrometheusLogger
6+
from litellm.proxy._types import UserAPIKeyAuth
47
from litellm.types.integrations.prometheus import (
58
UserAPIKeyLabelValues,
69
)
7-
from litellm.proxy._types import UserAPIKeyAuth
810

911

1012
@pytest.mark.asyncio
@@ -72,10 +74,12 @@ async def test_async_post_call_failure_hook_includes_client_ip_user_agent():
7274
@pytest.mark.asyncio
7375
async def test_async_post_call_success_hook_includes_client_ip_user_agent():
7476
"""
75-
Test that async_post_call_success_hook includes client_ip and user_agent in UserAPIKeyLabelValues
77+
Test that async_log_success_event includes client_ip and user_agent in UserAPIKeyLabelValues.
78+
79+
Note: After PR #21159, the metric increment was moved from async_post_call_success_hook
80+
to async_log_success_event to prevent double-counting.
7681
"""
7782
# Mocking
78-
# Mocking
7983
with patch(
8084
"litellm.integrations.prometheus.PrometheusLogger.__init__", return_value=None
8185
):
@@ -84,27 +88,55 @@ async def test_async_post_call_success_hook_includes_client_ip_user_agent():
8488
logger.get_labels_for_metric = MagicMock(
8589
return_value=["client_ip", "user_agent"]
8690
)
87-
88-
data = {
91+
logger._should_skip_metrics_for_invalid_key = MagicMock(return_value=False)
92+
logger._increment_top_level_request_and_spend_metrics = MagicMock()
93+
logger._increment_token_metrics = MagicMock()
94+
logger._increment_remaining_budget_metrics = AsyncMock()
95+
logger._set_virtual_key_rate_limit_metrics = MagicMock()
96+
logger._set_latency_metrics = MagicMock()
97+
logger.set_llm_deployment_success_metrics = MagicMock()
98+
logger._increment_cache_metrics = MagicMock()
99+
100+
kwargs = {
89101
"model": "gpt-4",
90-
"metadata": {
91-
"requester_ip_address": "192.168.1.1",
92-
"user_agent": "success-agent",
102+
"litellm_params": {
103+
"metadata": {}
104+
},
105+
"start_time": None,
106+
"standard_logging_object": {
107+
"model_group": "gpt-4",
108+
"model_id": "model_1",
109+
"api_base": "http://api.base",
110+
"custom_llm_provider": "openai",
111+
"completion_tokens": 10,
112+
"total_tokens": 20,
113+
"response_cost": 0.01,
114+
"request_tags": [],
115+
"metadata": {
116+
"user_api_key_user_id": "user_1",
117+
"user_api_key_hash": "hash_1",
118+
"user_api_key_alias": "alias_1",
119+
"user_api_key_team_id": "team_1",
120+
"user_api_key_team_alias": "team_alias_1",
121+
"user_api_key_user_email": "test@example.com",
122+
"user_api_key_request_route": "/chat/completions",
123+
"requester_ip_address": "192.168.1.1",
124+
"user_agent": "success-agent",
125+
},
93126
},
94127
}
95-
user_api_key_dict = UserAPIKeyAuth(token="test_token")
96-
response = MagicMock()
97128

98129
# Mock prometheus_label_factory to inspect arguments
99130
with patch(
100131
"litellm.integrations.prometheus.prometheus_label_factory"
101132
) as mock_label_factory:
102133
mock_label_factory.return_value = {}
103134

104-
await logger.async_post_call_success_hook(
105-
data=data,
106-
user_api_key_dict=user_api_key_dict,
107-
response=response,
135+
await logger.async_log_success_event(
136+
kwargs=kwargs,
137+
response_obj=None,
138+
start_time=None,
139+
end_time=None,
108140
)
109141

110142
# Verification
@@ -114,8 +146,8 @@ async def test_async_post_call_success_hook_includes_client_ip_user_agent():
114146
calls = mock_label_factory.call_args_list
115147
found = False
116148
for call in calls:
117-
kwargs = call.kwargs
118-
enum_values = kwargs.get("enum_values")
149+
kwargs_args = call.kwargs
150+
enum_values = kwargs_args.get("enum_values")
119151
if isinstance(enum_values, UserAPIKeyLabelValues):
120152
if (
121153
enum_values.client_ip == "192.168.1.1"

0 commit comments

Comments
 (0)