
Filter key hash and key alias on logs table #10715


Open · wants to merge 17 commits into main
134 changes: 92 additions & 42 deletions litellm/proxy/spend_tracking/spend_management_endpoints.py
@@ -1655,12 +1655,16 @@ async def ui_view_spend_logs( # noqa: PLR0915
),
status_filter: Optional[str] = fastapi.Query(
default=None,
description="Filter logs by status (e.g., success, failure)"
description="Filter logs by status (e.g., success || "" || null, failure)"
),
model: Optional[str] = fastapi.Query( # Add this new parameter
default=None,
description="Filter logs by model name"
),
key_alias: Optional[str] = fastapi.Query(
default=None,
description="Filter logs by key alias"
),
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
@@ -1705,56 +1709,102 @@ async def ui_view_spend_logs( # noqa: PLR0915
start_date_iso = start_date_obj.isoformat() # Already in UTC, no need to add Z
end_date_iso = end_date_obj.isoformat() # Already in UTC, no need to add Z

# Build where conditions
where_conditions: Dict[str, Any] = {
"startTime": {"gte": start_date_iso, "lte": end_date_iso} # Ensure date range is always applied
}
if api_key:
where_conditions["api_key"] = api_key
if user_id:
where_conditions["user"] = user_id
# Initialize total_records at the start
total_records = 0


# Calculate offset for pagination
offset = (page - 1) * page_size

# Build dynamic SQL conditions
conditions = []
params: list[Any] = []
param_index = 1 # To keep track of $1, $2, ...

conditions.append(f'"startTime" >= ${param_index}::timestamp')
params.append(start_date_iso)
param_index += 1

conditions.append(f'"startTime" <= ${param_index}::timestamp')
params.append(end_date_iso)
param_index += 1

# Optional filters
if user_id:
conditions.append(f'"user" = ${param_index}')
params.append(user_id)
param_index += 1

if request_id:
where_conditions["request_id"] = request_id
conditions.append(f'"request_id" = ${param_index}')
params.append(request_id)
param_index += 1

if team_id:
where_conditions["team_id"] = team_id
conditions.append(f'"team_id" = ${param_index}')
params.append(team_id)
param_index += 1

if model:
where_conditions["model"] = model
conditions.append(f'"model" = ${param_index}')
params.append(model)
param_index += 1

if min_spend is not None:
where_conditions.setdefault("spend", {}).update({"gte": min_spend})
conditions.append(f'"spend" >= ${param_index}')
params.append(min_spend)
param_index += 1

if max_spend is not None:
where_conditions.setdefault("spend", {}).update({"lte": max_spend})
conditions.append(f'"spend" <= ${param_index}')
params.append(max_spend)
param_index += 1

# Calculate skip value for pagination
skip = (page - 1) * page_size
if api_key:
conditions.append(f"metadata->>'user_api_key' = ${param_index}")
params.append(api_key)
param_index += 1

# Get paginated data
data = await prisma_client.db.litellm_spendlogs.find_many(
where=where_conditions,
order={
"startTime": "desc",
},
skip=skip,
take=page_size,
)
if key_alias:
conditions.append(f"metadata->>'user_api_key_alias' = ${param_index}")
params.append(key_alias)
param_index += 1

if status_filter:
status_lower = status_filter.strip().lower()

def status_filter_fn(row):
metadata = getattr(row, "metadata", {}) or {}
status_val = metadata.get("status") if isinstance(metadata, dict) else None
if status_lower == "failure":
return status_val == "failure"
elif status_lower == "success":
return status_val != "failure"
return True

data = list(filter(status_filter_fn, data))

# Get total count of records
total_records = await prisma_client.db.litellm_spendlogs.count(
where=where_conditions,
)
if status_filter.lower() == "failure":
conditions.append("metadata->>'status' = 'failure'")
else:
conditions.append("(metadata->>'status' IS NULL OR metadata->>'status' != 'failure')")

# Final WHERE clause from dynamic conditions
where_clause = " AND ".join(conditions)

# Add LIMIT and OFFSET as the final two parameters
params += [page_size, offset]
param_index += 2 # Advance the parameter index accordingly

# Dynamically reference the new LIMIT and OFFSET parameters
limit_param = f"${param_index - 2}"
offset_param = f"${param_index - 1}"

# Final dynamic paginated query
raw_query = f"""
SELECT * FROM "LiteLLM_SpendLogs"
WHERE {where_clause}
ORDER BY "startTime" DESC
LIMIT {limit_param} OFFSET {offset_param}
"""
data = await prisma_client.db.query_raw(raw_query, *params)

# Count query (without limit/offset params)
count_query = f"""
SELECT COUNT(*) FROM "LiteLLM_SpendLogs"
WHERE {where_clause}
"""
total_result = await prisma_client.db.query_raw(count_query, *params[:-2]) # Exclude limit/offset
total_records = int(total_result[0]["count"]) if total_result else 0


# Calculate total pages
total_pages = (total_records + page_size - 1) // page_size

107 changes: 64 additions & 43 deletions tests/litellm/proxy/spend_tracking/test_spend_management_endpoints.py
@@ -91,6 +91,13 @@ async def count(self, *args, **kwargs):
return 1
return len(mock_spend_logs)

async def query_raw(self, raw_query, *params):
if "count(*)" in raw_query.lower():
if "test_user_1" in params:
return [{"count": 1}]
return [{"count": len(mock_spend_logs)}]
return [mock_spend_logs[0]] if "test_user_1" in params else mock_spend_logs

class MockPrismaClient:
def __init__(self):
self.db = MockDB()
@@ -136,7 +143,7 @@ def __init__(self):

@pytest.mark.asyncio
async def test_ui_view_spend_logs_with_team_id(client, monkeypatch):
# Mock data for the test
# Mock data
mock_spend_logs = [
{
"id": "log1",
@@ -160,44 +167,36 @@ async def test_ui_view_spend_logs_with_team_id(client, monkeypatch):
},
]

# Create a mock prisma client
# Mock database with query_raw
class MockDB:
async def find_many(self, *args, **kwargs):
# Filter based on team_id in the where conditions
if (
"where" in kwargs
and "team_id" in kwargs["where"]
and kwargs["where"]["team_id"] == "team1"
):
return [mock_spend_logs[0]]
return mock_spend_logs

async def count(self, *args, **kwargs):
# Return count based on team_id filter
if (
"where" in kwargs
and "team_id" in kwargs["where"]
and kwargs["where"]["team_id"] == "team1"
):
return 1
return len(mock_spend_logs)
async def query_raw(self, query, *params):
if "count(*)" in query.lower():
# Simulate the count response
for param in params:
if param == "team1":
return [{"count": 1}]
return [{"count": 2}]
else:
for param in params:
if param == "team1":
return [mock_spend_logs[0]]
return mock_spend_logs

class MockPrismaClient:
def __init__(self):
self.db = MockDB()
self.db.litellm_spendlogs = self.db

# Apply the monkeypatch
mock_prisma_client = MockPrismaClient()
monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_prisma_client)
# Apply monkeypatch to replace the real Prisma client
monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", MockPrismaClient())

# Set up test dates
# Build test time range
start_date = (
datetime.datetime.now(timezone.utc) - datetime.timedelta(days=7)
).strftime("%Y-%m-%d %H:%M:%S")
end_date = datetime.datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")

# Make the request with team_id filter
# Make the GET request
response = client.get(
"/spend/logs/ui",
params={
@@ -208,18 +207,20 @@ def __init__(self):
headers={"Authorization": "Bearer sk-test"},
)

# Assert response
# Assertions
assert response.status_code == 200
data = response.json()

# Verify the filtered data
assert data["total"] == 1
assert len(data["data"]) == 1
assert data["data"][0]["team_id"] == "team1"


@pytest.mark.asyncio
async def test_ui_view_spend_logs_pagination(client, monkeypatch):
import datetime
from datetime import timezone

# Create a larger set of mock data for pagination testing
mock_spend_logs = [
{
@@ -229,29 +230,36 @@ async def test_ui_view_spend_logs_pagination(client, monkeypatch):
"user": f"test_user_{i % 3}",
"team_id": f"team{i % 2 + 1}",
"spend": 0.05 * i,
"startTime": datetime.datetime.now(timezone.utc).isoformat(),
"startTime": (
datetime.datetime.now(timezone.utc) - datetime.timedelta(minutes=i)
).isoformat(),
"model": "gpt-3.5-turbo" if i % 2 == 0 else "gpt-4",
}
for i in range(1, 26) # 25 records
]

# Create a mock prisma client with pagination support
# Mock DB with query_raw and count support
class MockDB:
async def find_many(self, *args, **kwargs):
# Handle pagination
skip = kwargs.get("skip", 0)
take = kwargs.get("take", 10)
return mock_spend_logs[skip : skip + take]

async def count(self, *args, **kwargs):
return len(mock_spend_logs)
async def query_raw(self, raw_query, *params):
if "count(*)" in raw_query.lower():
# Simulate total count response
return [{"count": len(mock_spend_logs)}]
else:
# Simulate paginated data response
try:
limit = int(params[-2])
offset = int(params[-1])
except (IndexError, ValueError):
limit = 10
offset = 0
return mock_spend_logs[offset: offset + limit]

class MockPrismaClient:
def __init__(self):
self.db = MockDB()
self.db.litellm_spendlogs = self.db

# Apply the monkeypatch
# Monkeypatch the prisma client
mock_prisma_client = MockPrismaClient()
monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_prisma_client)

@@ -300,6 +308,7 @@ def __init__(self):
assert data["page"] == 2



@pytest.mark.asyncio
async def test_ui_view_spend_logs_date_range_filter(client, monkeypatch):
# Create mock data with different dates
@@ -382,6 +391,21 @@ async def count(self, *args, **kwargs):
logs = await self.find_many(*args, **kwargs)
return len(logs)

async def query_raw(self, raw_query, *params):
# Simulate response to SELECT COUNT(*)
if "COUNT" in raw_query.upper():
filtered_logs = await self.find_many(where={"startTime": {
"gte": params[0],
"lte": params[1],
}})
return [{"count": str(len(filtered_logs))}]
else:
# Simulate raw query returning spend logs
return await self.find_many(where={"startTime": {
"gte": params[0],
"lte": params[1],
}})

class MockPrismaClient:
def __init__(self):
self.db = MockDB()
@@ -397,10 +421,7 @@ def __init__(self):

response = client.get(
"/spend/logs/ui",
params={
"start_date": start_date,
"end_date": end_date,
},
params={"start_date": start_date, "end_date": end_date},
headers={"Authorization": "Bearer sk-test"},
)
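
The tests above cover team_id, pagination, and date-range filtering but not the new key alias filter itself. A possible companion test is sketched below (editorial addition, not part of the diff); it assumes the same `client` fixture, imports, and `query_raw` mock pattern used in this file, and the alias value and mock rows are hypothetical.

```python
# Sketch of a companion test for the new key_alias filter, reusing the
# query_raw mock pattern from the tests above (fixture and imports assumed).
@pytest.mark.asyncio
async def test_ui_view_spend_logs_with_key_alias(client, monkeypatch):
    mock_spend_logs = [
        {
            "id": "log1",
            "request_id": "req1",
            "api_key": "hashed-token-1",
            "spend": 0.05,
            "startTime": datetime.datetime.now(timezone.utc).isoformat(),
            "metadata": {"user_api_key_alias": "prod-key"},
        },
    ]

    class MockDB:
        async def query_raw(self, raw_query, *params):
            # The endpoint issues a COUNT(*) query and a paginated SELECT;
            # both carry the alias as a positional parameter when it is set.
            if "count(*)" in raw_query.lower():
                return [{"count": 1}] if "prod-key" in params else [{"count": 0}]
            return mock_spend_logs if "prod-key" in params else []

    class MockPrismaClient:
        def __init__(self):
            self.db = MockDB()

    monkeypatch.setattr(
        "litellm.proxy.proxy_server.prisma_client", MockPrismaClient()
    )

    start_date = (
        datetime.datetime.now(timezone.utc) - datetime.timedelta(days=7)
    ).strftime("%Y-%m-%d %H:%M:%S")
    end_date = datetime.datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")

    response = client.get(
        "/spend/logs/ui",
        params={
            "start_date": start_date,
            "end_date": end_date,
            "key_alias": "prod-key",
        },
        headers={"Authorization": "Bearer sk-test"},
    )

    assert response.status_code == 200
    data = response.json()
    assert data["total"] == 1
    assert data["data"][0]["metadata"]["user_api_key_alias"] == "prod-key"
```
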

4 changes: 3 additions & 1 deletion ui/litellm-dashboard/src/components/networking.tsx
@@ -2149,14 +2149,15 @@ export const uiSpendLogsCall = async (
user_id?: string,
status_filter?: string,
model?: string,
key_alias?: string, // New: filter by key alias
) => {
try {
// Construct base URL
let url = proxyBaseUrl ? `${proxyBaseUrl}/spend/logs/ui` : `/spend/logs/ui`;

// Add query parameters if they exist
const queryParams = new URLSearchParams();
if (api_key) queryParams.append('api_key', api_key);
if (api_key) queryParams.append('api_key', api_key); // api_key is the key hash
if (team_id) queryParams.append('team_id', team_id);
if (request_id) queryParams.append('request_id', request_id);
if (start_date) queryParams.append('start_date', start_date);
@@ -2166,6 +2167,7 @@ if (user_id) queryParams.append('user_id', user_id);
if (user_id) queryParams.append('user_id', user_id);
if (status_filter) queryParams.append('status_filter', status_filter);
if (model) queryParams.append('model', model);
if (key_alias) queryParams.append('key_alias', key_alias);

// Append query parameters to URL if any exist
const queryString = queryParams.toString();
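
For context (editorial addition, not part of the diff), the dashboard helper simply forwards these values as query parameters to `/spend/logs/ui`. The equivalent direct HTTP call is sketched below; the base URL, key hash, alias, and proxy key are placeholders.

```python
# Rough sketch of the request uiSpendLogsCall constructs, expressed as a
# direct HTTP call against the proxy. All concrete values are placeholders.
import requests

resp = requests.get(
    "http://localhost:4000/spend/logs/ui",  # assumed local proxy URL
    params={
        "start_date": "2025-05-01 00:00:00",
        "end_date": "2025-05-08 00:00:00",
        "api_key": "hashed-token-value",  # key hash filter (placeholder)
        "key_alias": "prod-key",          # new key alias filter (placeholder)
        "status_filter": "failure",
        "page": 1,
        "page_size": 50,
    },
    headers={"Authorization": "Bearer sk-1234"},  # placeholder proxy key
)
resp.raise_for_status()
body = resp.json()
print(body["total"], len(body["data"]))
```
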