-
-
Notifications
You must be signed in to change notification settings - Fork 6.9k
Expand file tree
/
Copy pathauth_checks.py
More file actions
3640 lines (3094 loc) · 122 KB
/
auth_checks.py
File metadata and controls
3640 lines (3094 loc) · 122 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# What is this?
## Common auth checks between jwt + key based auth
"""
Got Valid Token from Cache, DB
Run checks for:
1. If user can call model
2. If user is in budget
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
"""
import asyncio
import os
import re
import time
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union, cast
from fastapi import HTTPException, Request, status
from pydantic import BaseModel
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.caching.dual_cache import LimitedSizeOrderedDict
from litellm.constants import (
CLI_JWT_EXPIRATION_HOURS,
CLI_JWT_TOKEN_NAME,
DEFAULT_ACCESS_GROUP_CACHE_TTL,
DEFAULT_IN_MEMORY_TTL,
DEFAULT_MANAGEMENT_OBJECT_IN_MEMORY_CACHE_TTL,
DEFAULT_MAX_RECURSE_DEPTH,
EMAIL_BUDGET_ALERT_MAX_SPEND_ALERT_PERCENTAGE,
)
from litellm.litellm_core_utils.dd_tracing import tracer
from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
from litellm.proxy._types import (
RBAC_ROLES,
CallInfo,
LiteLLM_AccessGroupTable,
LiteLLM_BudgetTable,
LiteLLM_EndUserTable,
Litellm_EntityType,
LiteLLM_JWTAuth,
LiteLLM_ObjectPermissionTable,
LiteLLM_OrganizationMembershipTable,
LiteLLM_OrganizationTable,
LiteLLM_ProjectTableCachedObj,
LiteLLM_TagTable,
LiteLLM_TeamMembership,
LiteLLM_TeamTable,
LiteLLM_TeamTableCachedObj,
LiteLLM_UserTable,
LiteLLMRoutes,
LitellmUserRoles,
NewTeamRequest,
ProxyErrorTypes,
ProxyException,
RoleBasedPermissions,
SpecialModelNames,
UserAPIKeyAuth,
)
from litellm.proxy.auth.route_checks import RouteChecks
from litellm.proxy.db.exception_handler import PrismaDBExceptionHandler
from litellm.proxy.guardrails.tool_name_extraction import (
TOOL_CAPABLE_CALL_TYPES,
extract_request_tool_names,
)
from litellm.proxy.route_llm_request import route_request
from litellm.proxy.utils import PrismaClient, ProxyLogging, log_db_metrics
from litellm.router import Router
from litellm.utils import get_utc_datetime
from .auth_checks_organization import organization_role_based_access_check
from .auth_utils import get_model_from_request
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = Union[_Span, Any]
else:
Span = Any
# Bounded (max 100 entries) record of last DB access times — presumably keyed per entity; verify against callers.
last_db_access_time = LimitedSizeOrderedDict(max_size=100)
db_cache_expiry = DEFAULT_IN_MEMORY_TTL  # refresh every 5s
# Union of OpenAI-compatible + management routes, used for route membership checks.
all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value
def _log_budget_lookup_failure(entity: str, error: Exception) -> None:
"""
Log a warning when budget lookup fails; cache will not be populated.
Skips logging for expected "user not found" cases (bare Exception from
get_user_object when user_id_upsert=False). Adds a schema migration hint
when the error appears schema-related.
"""
# Skip logging for expected "user not found" - not caching is correct
if str(error) == "" and type(error).__name__ == "Exception":
return
err_str = str(error).lower()
hint = ""
if any(
x in err_str
for x in ("column", "schema", "does not exist", "prisma", "migrate")
):
hint = (
" Run `prisma db push` or `prisma migrate deploy` to fix schema mismatches."
)
verbose_proxy_logger.error(
f"Budget lookup failed for {entity}; cache will not be populated. "
f"Each request will hit the database. Error: {error}.{hint}"
)
def _is_model_cost_zero(
    model: Optional[Union[str, List[str]]], llm_router: Optional[Router]
) -> bool:
    """
    Return True only when every given model has explicitly-configured zero
    pricing (both input and output cost set to 0, not None).

    Resolution goes through the router's ``get_model_group_info``. The
    function is conservative: a missing router, unknown model group,
    undefined pricing, or any lookup error yields False ("assume it costs
    money").
    """
    if model is None or llm_router is None:
        return False
    # Normalize to a list so single-model and multi-model calls share one path.
    models: List[str] = [model] if isinstance(model, str) else model
    for name in models:
        try:
            info = llm_router.get_model_group_info(model_group=name)
            if info is None:
                # Unknown model group -> conservative default.
                verbose_proxy_logger.debug(
                    f"No model group info found for {name}, assuming it has cost"
                )
                return False
            in_cost = info.input_cost_per_token
            out_cost = info.output_cost_per_token
            # None means "pricing not configured", which is NOT the same as free.
            if in_cost is None or out_cost is None:
                verbose_proxy_logger.debug(
                    f"Model {name} has undefined cost (input: {in_cost}, output: {out_cost}), assuming it has cost"
                )
                return False
            if in_cost > 0 or out_cost > 0:
                verbose_proxy_logger.debug(
                    f"Model {name} has non-zero cost (input: {in_cost}, output: {out_cost})"
                )
                return False
            verbose_proxy_logger.debug(
                f"Model {name} has zero cost explicitly configured (input: {in_cost}, output: {out_cost})"
            )
        except Exception as err:
            # Any failure while determining cost -> treat as paid.
            verbose_proxy_logger.debug(
                f"Error checking cost for model {name}: {str(err)}, assuming it has cost"
            )
            return False
    # Every model checked is explicitly free.
    return True
async def _run_project_checks(
    project_object: Optional[LiteLLM_ProjectTableCachedObj],
    _model: Optional[Union[str, List[str]]],
    llm_router: Optional[Router],
    skip_budget_checks: bool,
    valid_token: Optional[UserAPIKeyAuth],
    proxy_logging_obj: ProxyLogging,
) -> None:
    """
    Run all project-level checks: blocked, model access, budget, soft budget.

    No-op when no project is attached. Extracted from common_checks() to
    keep statement count manageable.
    """
    if project_object is None:
        return
    # 1.1. Blocked projects are rejected outright.
    if project_object.blocked is True:
        raise Exception(
            f"Project={project_object.project_id} is blocked. Update via `/project/update` if you're an admin."
        )
    # 2.2. Model-access check only applies when the project restricts models.
    if _model and len(project_object.models) > 0:
        can_project_access_model(
            model=_model,
            project_object=project_object,
            llm_router=llm_router,
        )
    if skip_budget_checks:
        return
    # 3.0.2. Hard budget check (raises on exceed).
    await _project_max_budget_check(
        project_object=project_object,
        valid_token=valid_token,
        proxy_logging_obj=proxy_logging_obj,
    )
    # 3.0.3. Soft budget check (alert only, never blocks).
    await _project_soft_budget_check(
        project_object=project_object,
        valid_token=valid_token,
        proxy_logging_obj=proxy_logging_obj,
    )
def _enforce_user_param_check(
    general_settings: dict, request: Request, request_body: dict, route: str
) -> None:
    """
    When ``enforce_user_param`` is enabled, require a 'user' field in the
    body of POST requests to LLM API routes. MCP routes are exempt.

    Raises a plain Exception when the requirement is violated.
    """
    if not general_settings.get("enforce_user_param", False):
        return
    # Only POST requests are subject to the check.
    method = getattr(request, "method", None)
    if not (method and method.upper() == "POST"):
        return
    # Only OpenAI-compatible LLM routes are subject to the check.
    if not RouteChecks.is_llm_api_route(route=route):
        return
    # MCP routes are exempt (matched literally or via route templates).
    mcp_routes = LiteLLMRoutes.mcp_routes.value
    if route in mcp_routes or RouteChecks.check_route_access(
        route=route, allowed_routes=mcp_routes
    ):
        return
    if "user" not in request_body:
        raise Exception(
            f"'user' param not passed in. 'enforce_user_param'={general_settings['enforce_user_param']}"
        )
def _reject_clientside_metadata_tags_check(
general_settings: dict, request_body: dict, route: str
) -> None:
if not general_settings.get("reject_clientside_metadata_tags", False):
return
if (
RouteChecks.is_llm_api_route(route=route)
and "metadata" in request_body
and isinstance(request_body["metadata"], dict)
and "tags" in request_body["metadata"]
):
raise ProxyException(
message=f"Client-side 'metadata.tags' not allowed in request. 'reject_clientside_metadata_tags'={general_settings['reject_clientside_metadata_tags']}. Tags can only be set via API key metadata.",
type=ProxyErrorTypes.bad_request_error,
param="metadata.tags",
code=status.HTTP_400_BAD_REQUEST,
)
def _global_proxy_budget_check(
    global_proxy_spend: Optional[float], skip_budget_checks: bool, route: str
) -> None:
    """
    Enforce the proxy-wide ``litellm.max_budget`` on LLM API calls.

    Model-listing routes (/models, /v1/models) are exempt. Raises
    litellm.BudgetExceededError once aggregate spend passes the limit.
    """
    # Disabled, explicitly skipped, or no spend figure -> nothing to enforce.
    if litellm.max_budget <= 0 or skip_budget_checks or global_proxy_spend is None:
        return
    if not RouteChecks.is_llm_api_route(route=route):
        return
    if route in ("/v1/models", "/models"):
        return
    if global_proxy_spend > litellm.max_budget:
        raise litellm.BudgetExceededError(
            current_cost=global_proxy_spend, max_budget=litellm.max_budget
        )
def _guardrail_modification_check(
    request_body: dict, team_object: Optional[LiteLLM_TeamTable]
) -> None:
    """
    Block requests that set client-side guardrails unless the caller's
    team is permitted to modify guardrails.

    Raises HTTPException(403) on violation.
    """
    request_metadata: dict = request_body.get("metadata", {}) or {}
    if not request_metadata.get("guardrails"):
        return

    from litellm.proxy.guardrails.guardrail_helpers import can_modify_guardrails

    if can_modify_guardrails(team_object):
        return
    raise HTTPException(
        status_code=403,
        detail={"error": "Your team does not have permission to modify guardrails."},
    )
async def check_tools_allowlist(
    request_body: dict,
    valid_token: Optional[UserAPIKeyAuth],
    team_object: Optional[LiteLLM_TeamTable],
    route: str,
) -> None:
    """
    Enforce key/team tool allowlist (metadata.allowed_tools). No DB in hot path —
    effective allowlist is read from valid_token.metadata and valid_token.team_metadata.
    Raises ProxyException with tool_access_denied if a tool is not allowed.
    """
    from litellm.litellm_core_utils.api_route_to_call_types import (
        get_call_types_for_route,
    )

    if valid_token is None:
        return
    # Only enforce on routes whose call types can actually carry tools.
    call_types = get_call_types_for_route(route)
    if not call_types:
        return
    if all(ct.value not in TOOL_CAPABLE_CALL_TYPES for ct in call_types):
        return
    tool_names = extract_request_tool_names(route, request_body)
    if not tool_names:
        return

    def _allowed_tools_from(meta: Any) -> Any:
        # Defensive: metadata fields may be None or a non-dict.
        return meta.get("allowed_tools") if isinstance(meta, dict) else None

    key_allowed = _allowed_tools_from(valid_token.metadata)
    team_allowed = _allowed_tools_from(valid_token.team_metadata)
    # A non-empty key-level allowlist overrides the team-level one.
    if isinstance(key_allowed, list) and key_allowed:
        effective = key_allowed
    else:
        effective = team_allowed
    # No (usable) allowlist configured -> everything is permitted.
    if not isinstance(effective, list) or not effective:
        return
    allowed = {str(t) for t in effective}
    disallowed = [name for name in tool_names if name not in allowed]
    if disallowed:
        raise ProxyException(
            message=f"Tool(s) {disallowed} are not in the allowed tools list for this key/team.",
            type=ProxyErrorTypes.tool_access_denied,
            param="tools",
            code=status.HTTP_403_FORBIDDEN,
        )
async def common_checks(  # noqa: PLR0915
    request_body: dict,
    team_object: Optional[LiteLLM_TeamTable],
    user_object: Optional[LiteLLM_UserTable],
    end_user_object: Optional[LiteLLM_EndUserTable],
    global_proxy_spend: Optional[float],
    general_settings: dict,
    route: str,
    llm_router: Optional[Router],
    proxy_logging_obj: ProxyLogging,
    valid_token: Optional[UserAPIKeyAuth],
    request: Request,
    skip_budget_checks: bool = False,
    project_object: Optional[LiteLLM_ProjectTableCachedObj] = None,
) -> bool:
    """
    Common checks across jwt + key-based auth.
    1. If team is blocked
    1.1. If project is blocked
    2. If team can call model
    2.2 If project can call model
    3. If team is in budget
    3.0.2. If project is in budget
    3.0.3. If project is over soft budget (alert only)
    4. If user passed in (JWT or key.user_id) - is in budget
    5. If end_user (either via JWT or 'user' passed to /chat/completions, /embeddings endpoint) is in budget
    6. [OPTIONAL] If 'enforce_end_user' enabled - did developer pass in 'user' param for openai endpoints
    7. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
    8. [OPTIONAL] If guardrails modified - is request allowed to change this
    9. Check if request body is safe
    10. [OPTIONAL] Organization checks - is user_object.organization_id is set, run these checks
    11. [OPTIONAL] Vector store checks - is the object allowed to access the vector store

    Returns True when all checks pass; individual checks signal failure by
    raising (Exception / ProxyException / HTTPException / BudgetExceededError).
    """
    from litellm.proxy.proxy_server import prisma_client, user_api_key_cache
    # Requested model — may be a single name or a list, depending on the route.
    _model: Optional[Union[str, List[str]]] = get_model_from_request(
        request_body, route
    )
    # 1. If team is blocked
    if team_object is not None and team_object.blocked is True:
        raise Exception(
            f"Team={team_object.team_id} is blocked. Update via `/team/unblock` if you're an admin."
        )
    # 2. If team can call model
    if _model and team_object:
        with tracer.trace("litellm.proxy.auth.common_checks.can_team_access_model"):
            # can_team_access_model returns Literal[True] or raises ProxyException
            await can_team_access_model(
                model=_model,
                team_object=team_object,
                llm_router=llm_router,
                team_model_aliases=valid_token.team_model_aliases
                if valid_token
                else None,
                valid_token=valid_token,
            )
    # Require trace id for agent keys when agent has require_trace_id_on_calls_by_agent
    if valid_token is not None and valid_token.agent_id:
        from litellm.proxy.agent_endpoints.agent_registry import global_agent_registry
        from litellm.proxy.litellm_pre_call_utils import get_chain_id_from_headers
        agent = global_agent_registry.get_agent_by_id(agent_id=valid_token.agent_id)
        if agent is not None:
            require_trace_id = (agent.litellm_params or {}).get(
                "require_trace_id_on_calls_by_agent"
            )
            if require_trace_id:
                headers_dict = dict(request.headers)
                trace_id = get_chain_id_from_headers(headers_dict)
                if not trace_id:
                    raise ProxyException(
                        message="Requests made with this agent's key must include the x-litellm-trace-id header.",
                        type=ProxyErrorTypes.bad_request_error,
                        param=None,
                        code=status.HTTP_400_BAD_REQUEST,
                    )
    ## 2.1 If user can call model (if personal key)
    if _model and team_object is None and user_object is not None:
        with tracer.trace("litellm.proxy.auth.common_checks.can_user_call_model"):
            await can_user_call_model(
                model=_model,
                llm_router=llm_router,
                user_object=user_object,
            )
    # 1.1 - 2.2 - 3.0.2 - 3.0.3: Project checks (blocked, model access, budget)
    with tracer.trace("litellm.proxy.auth.common_checks.run_project_checks"):
        await _run_project_checks(
            project_object=project_object,
            _model=_model,
            llm_router=llm_router,
            skip_budget_checks=skip_budget_checks,
            valid_token=valid_token,
            proxy_logging_obj=proxy_logging_obj,
        )
    # If this is a free model, skip all budget checks
    if not skip_budget_checks:
        # 3. If team is in budget
        with tracer.trace("litellm.proxy.auth.common_checks.team_max_budget_check"):
            await _team_max_budget_check(
                team_object=team_object,
                proxy_logging_obj=proxy_logging_obj,
                valid_token=valid_token,
            )
        # 3.0.5. If team is over soft budget (alert only, doesn't block)
        with tracer.trace("litellm.proxy.auth.common_checks.team_soft_budget_check"):
            await _team_soft_budget_check(
                team_object=team_object,
                proxy_logging_obj=proxy_logging_obj,
                valid_token=valid_token,
            )
        # 3.1. If organization is in budget
        with tracer.trace(
            "litellm.proxy.auth.common_checks.organization_max_budget_check"
        ):
            await _organization_max_budget_check(
                valid_token=valid_token,
                team_object=team_object,
                prisma_client=prisma_client,
                user_api_key_cache=user_api_key_cache,
                proxy_logging_obj=proxy_logging_obj,
            )
        # 3.2. Tag-level budgets (tags taken from the request body/metadata).
        with tracer.trace("litellm.proxy.auth.common_checks.tag_max_budget_check"):
            await _tag_max_budget_check(
                request_body=request_body,
                prisma_client=prisma_client,
                user_api_key_cache=user_api_key_cache,
                proxy_logging_obj=proxy_logging_obj,
                valid_token=valid_token,
            )
    # 4. If user is in budget
    ## 4.1 check personal budget, if personal key
    if (
        (team_object is None or team_object.team_id is None)
        and user_object is not None
        and user_object.max_budget is not None
    ):
        user_budget = user_object.max_budget
        if user_budget < user_object.spend:
            raise litellm.BudgetExceededError(
                current_cost=user_object.spend,
                max_budget=user_budget,
                message=f"ExceededBudget: User={user_object.user_id} over budget. Spend={user_object.spend}, Budget={user_budget}",
            )
    ## 4.2 check team member budget, if team key
    with tracer.trace("litellm.proxy.auth.common_checks.check_team_member_budget"):
        await _check_team_member_budget(
            team_object=team_object,
            user_object=user_object,
            valid_token=valid_token,
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
            proxy_logging_obj=proxy_logging_obj,
        )
    # 5. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
    if (
        end_user_object is not None
        and end_user_object.litellm_budget_table is not None
    ):
        end_user_budget = end_user_object.litellm_budget_table.max_budget
        if end_user_budget is not None and end_user_object.spend > end_user_budget:
            raise litellm.BudgetExceededError(
                current_cost=end_user_object.spend,
                max_budget=end_user_budget,
                message=f"ExceededBudget: End User={end_user_object.user_id} over budget. Spend={end_user_object.spend}, Budget={end_user_budget}",
            )
    # 6. enforce_user_param / tags rejection / 7. global proxy budget / 8. guardrails
    _enforce_user_param_check(general_settings, request, request_body, route)
    _reject_clientside_metadata_tags_check(general_settings, request_body, route)
    _global_proxy_budget_check(global_proxy_spend, skip_budget_checks, route)
    _guardrail_modification_check(request_body, team_object)
    # 10 [OPTIONAL] Organization RBAC checks
    organization_role_based_access_check(
        user_object=user_object, route=route, request_body=request_body
    )
    # UI dashboard keys (team_id == "litellm-dashboard") use UI-route checks.
    token_team = getattr(valid_token, "team_id", None)
    token_type: Literal["ui", "api"] = (
        "ui" if token_team is not None and token_team == "litellm-dashboard" else "api"
    )
    # Route permission check — result unused here; the helper presumably
    # raises on disallowed routes (see RouteChecks) — TODO confirm.
    _is_route_allowed = _is_allowed_route(
        route=route,
        token_type=token_type,
        user_obj=user_object,
        request=request,
        request_data=request_body,
        valid_token=valid_token,
    )
    # 11. [OPTIONAL] Vector store checks - is the object allowed to access the vector store
    with tracer.trace("litellm.proxy.auth.common_checks.vector_store_access_check"):
        await vector_store_access_check(
            request_body=request_body,
            team_object=team_object,
            valid_token=valid_token,
        )
    # 12. [OPTIONAL] Tool allowlist - key/team allowed_tools (no DB in hot path)
    with tracer.trace("litellm.proxy.auth.common_checks.check_tools_allowlist"):
        await check_tools_allowlist(
            request_body=request_body,
            valid_token=valid_token,
            team_object=team_object,
            route=route,
        )
    return True
def _is_ui_route(
    route: str,
    user_obj: Optional[LiteLLM_UserTable] = None,
) -> bool:
    """
    Return True when ``route`` is one of the dashboard/UI-only routes.

    A route matches either by string prefix or by route-pattern template
    against ``LiteLLMRoutes.ui_routes``.
    """
    ui_routes = LiteLLMRoutes.ui_routes.value
    # Prefix match first (guarded against non-string routes).
    if route is not None and isinstance(route, str):
        for candidate in ui_routes:
            if route.startswith(candidate):
                return True
    # Fall back to template/pattern matching (routes with path params).
    for candidate in ui_routes:
        if RouteChecks._route_matches_pattern(route=route, pattern=candidate):
            return True
    return False
def _get_user_role(
    user_obj: Optional[LiteLLM_UserTable],
) -> Optional[LitellmUserRoles]:
    """
    Map a user record's raw role string to a LitellmUserRoles enum.

    Returns None when no user is given; unrecognized role strings fall
    back to INTERNAL_USER.
    """
    if user_obj is None:
        return None
    try:
        return LitellmUserRoles(user_obj.user_role)
    except ValueError:
        # Unknown role string -> safest non-admin default.
        return LitellmUserRoles.INTERNAL_USER
def _is_api_route_allowed(
    route: str,
    request: Request,
    request_data: dict,
    valid_token: Optional[UserAPIKeyAuth],
    user_obj: Optional[LiteLLM_UserTable] = None,
) -> bool:
    """
    Validate an API-token request against route permissions.

    Proxy admins skip the per-route check; everyone else goes through
    RouteChecks.non_proxy_admin_allowed_routes_check.
    """
    role = _get_user_role(user_obj=user_obj)
    if valid_token is None:
        raise Exception("Invalid proxy server token passed. valid_token=None.")
    if _is_user_proxy_admin(user_obj=user_obj):
        return True
    # Non-admin: delegate the allow/deny decision to RouteChecks.
    RouteChecks.non_proxy_admin_allowed_routes_check(
        user_obj=user_obj,
        _user_role=role,
        route=route,
        request=request,
        request_data=request_data,
        valid_token=valid_token,
    )
    return True
def _is_user_proxy_admin(user_obj: Optional[LiteLLM_UserTable]) -> bool:
    """
    Return True when the user record carries the PROXY_ADMIN role.

    Note: the previous implementation repeated the identical PROXY_ADMIN
    comparison twice back-to-back; the redundant duplicate has been removed
    (no behavior change).
    """
    if user_obj is None:
        return False
    return (
        user_obj.user_role is not None
        and user_obj.user_role == LitellmUserRoles.PROXY_ADMIN.value
    )
def _is_allowed_route(
    route: str,
    token_type: Literal["ui", "api"],
    request: Request,
    request_data: dict,
    valid_token: Optional[UserAPIKeyAuth],
    user_obj: Optional[LiteLLM_UserTable] = None,
) -> bool:
    """
    Dispatch route authorization based on token origin.

    A UI dashboard token hitting a UI route is allowed immediately; every
    other combination is validated as a normal API request.
    """
    is_ui_request = token_type == "ui" and _is_ui_route(route=route, user_obj=user_obj)
    if is_ui_request:
        return True
    return _is_api_route_allowed(
        route=route,
        request=request,
        request_data=request_data,
        valid_token=valid_token,
        user_obj=user_obj,
    )
def _allowed_routes_check(user_route: str, allowed_routes: list) -> bool:
    """
    Return if a user is allowed to access route. Helper function for `allowed_routes_check`.
    Parameters:
    - user_route: str - the route the user is trying to call
    - allowed_routes: List[str|LiteLLMRoutes] - allowed routes; entries may be
      LiteLLMRoutes group names (expanded to their route templates) or
      literal route strings.
    """
    from starlette.routing import compile_path

    for entry in allowed_routes:
        if entry in LiteLLMRoutes.__members__:
            # Group name: templates may contain path params, so compile each
            # one to a regex before matching.
            templates = LiteLLMRoutes[entry].value
            if any(compile_path(t)[0].match(user_route) for t in templates):
                return True
        elif entry == user_route:
            return True
    return False
def allowed_routes_check(
    user_role: LitellmUserRoles,
    user_route: str,
    litellm_proxy_roles: LiteLLM_JWTAuth,
) -> bool:
    """
    Check if user -> not admin - allowed to access these routes.

    - PROXY_ADMIN: checked against the configured admin_allowed_routes.
    - TEAM: checked against team_allowed_routes, defaulting to
      openai + info routes when none are configured.
    - Any other role: denied.
    """
    if user_role == LitellmUserRoles.PROXY_ADMIN:
        return _allowed_routes_check(
            user_route=user_route,
            allowed_routes=litellm_proxy_roles.admin_allowed_routes,
        )
    if user_role == LitellmUserRoles.TEAM:
        team_routes = litellm_proxy_roles.team_allowed_routes
        if team_routes is None:
            # By default allow a team to call openai + info routes.
            team_routes = ["openai_routes", "info_routes"]
        return _allowed_routes_check(
            user_route=user_route, allowed_routes=team_routes
        )
    return False
def allowed_route_check_inside_route(
    user_api_key_dict: UserAPIKeyAuth,
    requested_user_id: Optional[str],
) -> bool:
    """
    Return True when the caller may act on ``requested_user_id``.

    Proxy admins (full or view-only) are always allowed; other users are
    only allowed to operate on their own user id.
    """
    admin_roles = (
        LitellmUserRoles.PROXY_ADMIN,
        LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY,
    )
    if user_api_key_dict.user_role in admin_roles:
        return True
    # Non-admins: permitted only when acting on their own user id.
    return (
        requested_user_id is not None
        and user_api_key_dict.user_id is not None
        and user_api_key_dict.user_id == requested_user_id
    )
def get_actual_routes(allowed_routes: list) -> list:
    """
    Expand LiteLLMRoutes group names into their concrete route lists.

    Entries that are not group names are passed through unchanged
    (treated as literal routes).
    """
    expanded: list = []
    for entry in allowed_routes:
        try:
            value = LiteLLMRoutes[entry].value
        except KeyError:
            # Not a group name -> keep the literal route.
            expanded.append(entry)
            continue
        # Route groups may be stored as sets; normalize to a list.
        expanded.extend(list(value) if isinstance(value, set) else value)
    return expanded
async def get_default_end_user_budget(
    prisma_client: Optional[PrismaClient],
    user_api_key_cache: DualCache,
    parent_otel_span: Optional[Span] = None,
) -> Optional[LiteLLM_BudgetTable]:
    """
    Resolve the proxy-wide default end-user budget
    (``litellm.max_end_user_budget_id``), cache-first.

    Args:
        prisma_client: Database client instance
        user_api_key_cache: Cache for storing/retrieving budget data
        parent_otel_span: Optional OpenTelemetry span for tracing

    Returns:
        LiteLLM_BudgetTable when configured and found; None when no DB is
        connected, no default budget id is configured, the budget row is
        missing, or the lookup errors out.
    """
    if prisma_client is None or litellm.max_end_user_budget_id is None:
        return None
    cache_key = f"default_end_user_budget:{litellm.max_end_user_budget_id}"
    # Serve from cache when possible.
    cached = await user_api_key_cache.async_get_cache(key=cache_key)
    if cached is not None:
        return LiteLLM_BudgetTable(**cached)
    try:
        record = await prisma_client.db.litellm_budgettable.find_unique(
            where={"budget_id": litellm.max_end_user_budget_id}
        )
        if record is None:
            verbose_proxy_logger.warning(
                f"Default end user budget not found in database: {litellm.max_end_user_budget_id}"
            )
            return None
        record_dict = record.dict()
        # Cache briefly so subsequent requests skip the DB read.
        await user_api_key_cache.async_set_cache(
            key=cache_key,
            value=record_dict,
            ttl=DEFAULT_MANAGEMENT_OBJECT_IN_MEMORY_CACHE_TTL,
        )
        return LiteLLM_BudgetTable(**record_dict)
    except Exception as e:
        verbose_proxy_logger.error(f"Error fetching default end user budget: {str(e)}")
        return None
async def _apply_default_budget_to_end_user(
    end_user_obj: LiteLLM_EndUserTable,
    prisma_client: PrismaClient,
    user_api_key_cache: DualCache,
    parent_otel_span: Optional[Span] = None,
) -> LiteLLM_EndUserTable:
    """
    Attach the proxy's default end-user budget to ``end_user_obj`` when the
    user has no explicit budget and a default is configured.

    Mutates and returns the same object; returns it unchanged when no
    default applies.
    """
    # Short-circuit: explicit budget already set, or no default configured.
    needs_default = (
        end_user_obj.litellm_budget_table is None
        and litellm.max_end_user_budget_id is not None
    )
    if not needs_default:
        return end_user_obj
    default_budget = await get_default_end_user_budget(
        prisma_client=prisma_client,
        user_api_key_cache=user_api_key_cache,
        parent_otel_span=parent_otel_span,
    )
    if default_budget is None:
        return end_user_obj
    end_user_obj.litellm_budget_table = default_budget
    verbose_proxy_logger.debug(
        f"Applied default budget {litellm.max_end_user_budget_id} to end user {end_user_obj.user_id}"
    )
    return end_user_obj
def _check_end_user_budget(
    end_user_obj: LiteLLM_EndUserTable,
    route: str,
) -> None:
    """
    Raise when the end user has spent past their budget.

    Info routes are exempt; end users with no budget table (or no
    max_budget) are never blocked.

    Raises:
        litellm.BudgetExceededError: If end user has exceeded their budget
    """
    if RouteChecks.is_info_route(route):
        return
    budget_table = end_user_obj.litellm_budget_table
    if budget_table is None:
        return
    max_budget = budget_table.max_budget
    if max_budget is None or end_user_obj.spend <= max_budget:
        return
    raise litellm.BudgetExceededError(
        current_cost=end_user_obj.spend,
        max_budget=max_budget,
        message=f"ExceededBudget: End User={end_user_obj.user_id} over budget. Spend={end_user_obj.spend}, Budget={max_budget}",
    )
@log_db_metrics
async def get_end_user_object(
    end_user_id: Optional[str],
    prisma_client: Optional[PrismaClient],
    user_api_key_cache: DualCache,
    route: str,
    parent_otel_span: Optional[Span] = None,
    proxy_logging_obj: Optional[ProxyLogging] = None,
) -> Optional[LiteLLM_EndUserTable]:
    """
    Returns end user object from database or cache.
    If end user exists but has no budget_id, applies the default budget
    (if configured via litellm.max_end_user_budget_id).
    Args:
        end_user_id: The ID of the end user
        prisma_client: Database client instance
        user_api_key_cache: Cache for storing/retrieving data
        route: The request route
        parent_otel_span: Optional OpenTelemetry span for tracing
        proxy_logging_obj: Optional proxy logging object
    Returns:
        LiteLLM_EndUserTable if found, None otherwise
    Raises:
        Exception: if no database is connected
        litellm.BudgetExceededError: if the end user is over budget
    """
    if prisma_client is None:
        raise Exception("No db connected")
    if end_user_id is None:
        return None
    _key = "end_user_id:{}".format(end_user_id)
    # Check cache first
    cached_user_obj = await user_api_key_cache.async_get_cache(key=_key)
    if cached_user_obj is not None:
        return_obj = LiteLLM_EndUserTable(**cached_user_obj)
        # Apply default budget if needed
        return_obj = await _apply_default_budget_to_end_user(
            end_user_obj=return_obj,
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
            parent_otel_span=parent_otel_span,
        )
        # Check budget limits
        _check_end_user_budget(end_user_obj=return_obj, route=route)
        return return_obj
    # Fetch from database
    try:
        response = await prisma_client.db.litellm_endusertable.find_unique(
            where={"user_id": end_user_id},
            include={"litellm_budget_table": True, "object_permission": True},
        )
        if response is None:
            # Bare Exception signals "not found"; swallowed below -> returns None.
            raise Exception
        # Convert to LiteLLM_EndUserTable object
        _response = LiteLLM_EndUserTable(**response.dict())
        # Apply default budget if needed
        _response = await _apply_default_budget_to_end_user(
            end_user_obj=_response,
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
            parent_otel_span=parent_otel_span,
        )
        # Save to cache (always store as dict for consistency)
        await user_api_key_cache.async_set_cache(
            key="end_user_id:{}".format(end_user_id), value=_response.dict()
        )
        # Check budget limits
        _check_end_user_budget(end_user_obj=_response, route=route)
        return _response
    except Exception as e:
        # Budget violations must propagate; any other failure (including
        # "not found") resolves to None.
        if isinstance(e, litellm.BudgetExceededError):
            raise e
        return None
@log_db_metrics