Skip to content

Commit 69f5461

Browse files
authored
Add user usage model and tracking endpoints (#443)
Signed-off-by: Trevor Grant <[email protected]>
1 parent 9112242 commit 69f5461

File tree

7 files changed

+428
-19
lines changed

7 files changed

+428
-19
lines changed

webapp/packages/api/user-service/agent_factory/__init__.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
what_to_do_prompt_template, \
66
how_to_use_swagger_tools, \
77
how_to_use_gofannon_agents
8-
from litellm import acompletion
98
from models.agent import GenerateCodeRequest, GenerateCodeResponse, Agent
109
import json
1110
import asyncio
@@ -183,16 +182,19 @@ async def code_gen_with_thoughts():
183182
if provider == "openai":
184183
name_doc_config['response_format'] = { "type": "json_object" }
185184

186-
name_doc_gen_task = acompletion(
187-
model=f"{provider}/{model}",
188-
messages=name_doc_messages,
189-
**name_doc_config
190-
)
191-
185+
async def name_doc_generation():
186+
content, _ = await call_llm(
187+
provider=provider,
188+
model=model,
189+
messages=name_doc_messages,
190+
parameters=name_doc_config,
191+
)
192+
return content
193+
192194
# ---- Run tasks concurrently ----
193-
(code_body, thoughts), name_doc_response = await asyncio.gather(
195+
(code_body, thoughts), name_doc_content = await asyncio.gather(
194196
code_gen_with_thoughts(),
195-
name_doc_gen_task
197+
name_doc_generation()
196198
)
197199

198200
# ---- Process Code Generation Response ----
@@ -215,7 +217,6 @@ async def run(input_dict, tools):
215217
full_code = f"{header}\n{indented_body}"
216218

217219
# ---- Process Name and Docstring Response ----
218-
name_doc_content = name_doc_response.choices[0].message.content
219220
try:
220221
# Clean up potential markdown
221222
if name_doc_content.strip().startswith("```json"):

webapp/packages/api/user-service/main.py

Lines changed: 88 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from fastapi.middleware.cors import CORSMiddleware
55
from typing import Annotated, Optional, Dict, Any, List
66
from pydantic import BaseModel, Field
7+
from pydantic.config import ConfigDict
78
from datetime import datetime
89
import uuid
910
import json
@@ -29,18 +30,20 @@
2930
)
3031

3132
from services.llm_service import call_llm
33+
from services.user_service import get_user_service, UserService
3234

3335
# Import the shared provider configuration
3436
from config.provider_config import PROVIDER_CONFIG as APP_PROVIDER_CONFIG
3537
from config.routes_config import RouterConfig, resolve_router_configs
3638
from models.agent import (
37-
GenerateCodeRequest, GenerateCodeResponse, RunCodeRequest,
39+
GenerateCodeRequest, GenerateCodeResponse, RunCodeRequest,
3840
RunCodeResponse, Agent, CreateAgentRequest, Deployment, DeployedApi
3941
)
4042
from models.demo import (
4143
GenerateDemoCodeRequest, GenerateDemoCodeResponse,
4244
CreateDemoAppRequest, DemoApp
4345
)
46+
from models.user import User
4447

4548
from agent_factory.remote_mcp_client import RemoteMCPClient
4649

@@ -122,6 +125,31 @@ class ClientLogPayload(BaseModel):
122125
class FetchSpecRequest(BaseModel):
123126
url: str
124127

128+
129+
class UpdateMonthlyAllowanceRequest(BaseModel):
    """Request body for ``PUT /users/me/monthly-allowance``."""

    # Allow callers to send either ``monthly_allowance`` or ``monthlyAllowance``.
    model_config = ConfigDict(populate_by_name=True)

    # Required new allowance value.
    monthly_allowance: float = Field(alias="monthlyAllowance")
133+
134+
135+
class UpdateResetDateRequest(BaseModel):
    """Request body for ``PUT /users/me/allowance-reset-date``."""

    # Allow callers to send either ``allowance_reset_date`` or ``allowanceResetDate``.
    model_config = ConfigDict(populate_by_name=True)

    # Required. NOTE(review): the date is carried as a float; its unit
    # (e.g. epoch seconds vs. day-of-month) is not visible here — confirm.
    allowance_reset_date: float = Field(alias="allowanceResetDate")
139+
140+
141+
class UpdateSpendRemainingRequest(BaseModel):
    """Request body for ``PUT /users/me/spend-remaining``."""

    # Allow callers to send either ``spend_remaining`` or ``spendRemaining``.
    model_config = ConfigDict(populate_by_name=True)

    # Required new remaining-spend value.
    spend_remaining: float = Field(alias="spendRemaining")
145+
146+
147+
class AddUsageRequest(BaseModel):
    """Request body for ``POST /users/me/usage``."""

    # Allow callers to send either ``response_cost`` or ``responseCost``.
    model_config = ConfigDict(populate_by_name=True)

    # Required cost of the response being recorded.
    response_cost: float = Field(alias="responseCost")
    # Optional free-form data stored alongside the usage entry.
    metadata: Optional[Dict[str, Any]] = Field(default=None)
152+
125153
# Import models after defining local ones to avoid circular dependencies
126154
from models.chat import ChatRequest, ChatMessage, ChatResponse, ProviderConfig, SessionData
127155

@@ -134,11 +162,19 @@ def get_logger() -> ObservabilityService:
134162
"""Dependency to get the observability service instance."""
135163
return get_observability_service()
136164

165+
def get_user_service_dep(db: DatabaseService = Depends(get_db)) -> UserService:
    """Dependency that builds a UserService from the injected database service."""
    service = get_user_service(db)
    return service
167+
137168
# Background task for LLM processing
138169
async def process_chat(ticket_id: str, request: ChatRequest, user: dict, req: Request):
139170
# Background tasks don't have access to dependency injection, so we get service instances directly
140171
db_service = get_database_service(settings)
172+
user_service = get_user_service(db_service)
141173
logger = get_observability_service()
174+
user_basic_info = {
175+
"email": user.get("email"),
176+
"name": user.get("name") or user.get("displayName")
177+
}
142178
try:
143179
# Update ticket status
144180
ticket_data = {
@@ -207,8 +243,11 @@ async def process_chat(ticket_id: str, request: ChatRequest, user: dict, req: Re
207243
model=request.model,
208244
messages=messages,
209245
parameters=request.parameters,
210-
tools=built_in_tools if built_in_tools else None
211-
)
246+
tools=built_in_tools if built_in_tools else None,
247+
user_service=user_service,
248+
user_id=user.get("uid"),
249+
user_basic_info=user_basic_info,
250+
)
212251

213252
# Update ticket with success
214253
ticket_data.update({
@@ -349,6 +388,52 @@ def get_model_config(provider: str, model: str):
349388
raise HTTPException(status_code=404, detail="Model not found")
350389
return available_providers[provider]["models"][model]
351390

391+
392+
@router.get("/users/me", response_model=User)
def get_current_user_profile(
    user: dict = Depends(get_current_user),
    user_service: UserService = Depends(get_user_service_dep),
):
    """Look up the caller's user record via ``UserService.get_user``.

    The record is keyed by the ``uid`` from the auth dict, falling back to
    ``"anonymous"`` when no uid is present; the full auth dict is passed along.
    """
    uid = user.get("uid", "anonymous")
    return user_service.get_user(uid, user)
395+
396+
397+
@router.put("/users/me/monthly-allowance", response_model=User)
def set_monthly_allowance(
    request: UpdateMonthlyAllowanceRequest,
    user: dict = Depends(get_current_user),
    user_service: UserService = Depends(get_user_service_dep),
):
    """Forward the requested monthly allowance to ``UserService.set_monthly_allowance``.

    The target record is the caller's ``uid`` (or ``"anonymous"`` when absent).
    """
    # NOTE(review): the caller sets the allowance on their own record; confirm
    # this is intended rather than an admin-only operation.
    uid = user.get("uid", "anonymous")
    new_allowance = request.monthly_allowance
    return user_service.set_monthly_allowance(uid, new_allowance, user)
404+
405+
406+
@router.put("/users/me/allowance-reset-date", response_model=User)
def set_allowance_reset_date(
    request: UpdateResetDateRequest,
    user: dict = Depends(get_current_user),
    user_service: UserService = Depends(get_user_service_dep),
):
    """Forward the requested reset date to ``UserService.set_reset_date``.

    The target record is the caller's ``uid`` (or ``"anonymous"`` when absent).
    """
    uid = user.get("uid", "anonymous")
    reset_date = request.allowance_reset_date
    return user_service.set_reset_date(uid, reset_date, user)
413+
414+
415+
@router.post("/users/me/reset-allowance", response_model=User)
def reset_allowance(
    user: dict = Depends(get_current_user),
    user_service: UserService = Depends(get_user_service_dep),
):
    """Invoke ``UserService.reset_allowance`` for the caller's record.

    The target record is the caller's ``uid`` (or ``"anonymous"`` when absent).
    """
    # NOTE(review): the caller can trigger a reset of their own allowance;
    # confirm this is intended rather than an admin-only operation.
    uid = user.get("uid", "anonymous")
    return user_service.reset_allowance(uid, user)
418+
419+
420+
@router.put("/users/me/spend-remaining", response_model=User)
def update_spend_remaining(
    request: UpdateSpendRemainingRequest,
    user: dict = Depends(get_current_user),
    user_service: UserService = Depends(get_user_service_dep),
):
    """Forward the requested remaining spend to ``UserService.update_spend_remaining``.

    The target record is the caller's ``uid`` (or ``"anonymous"`` when absent).
    """
    # NOTE(review): the caller overwrites their own remaining spend; confirm
    # this is intended rather than an admin-only operation.
    uid = user.get("uid", "anonymous")
    remaining = request.spend_remaining
    return user_service.update_spend_remaining(uid, remaining, user)
427+
428+
429+
@router.post("/users/me/usage", response_model=User)
def add_usage_entry(
    request: AddUsageRequest,
    user: dict = Depends(get_current_user),
    user_service: UserService = Depends(get_user_service_dep),
):
    """Forward a usage entry (cost plus optional metadata) to ``UserService.add_usage``.

    The target record is the caller's ``uid`` (or ``"anonymous"`` when absent).
    """
    uid = user.get("uid", "anonymous")
    cost = request.response_cost
    extra = request.metadata
    return user_service.add_usage(uid, cost, extra, user)
436+
352437
@router.post("/chat")
353438
async def chat(request: ChatRequest, req: Request, background_tasks: BackgroundTasks, user: dict = Depends(get_current_user)):
354439
"""Submit a chat request and get a ticket ID"""

webapp/packages/api/user-service/models/user.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from datetime import datetime
2+
from typing import List, Optional, Any
3+
4+
from pydantic import BaseModel, Field
5+
from pydantic.alias_generators import to_camel
6+
from pydantic.config import ConfigDict
7+
8+
9+
class UsageEntry(BaseModel):
    """One recorded usage event: a timestamp, its cost, and optional metadata."""

    # Fields may be populated by Python name or by camelCase alias.
    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)

    # Defaults to the naive UTC time at which the entry is created.
    # NOTE(review): datetime.utcnow is deprecated as of Python 3.12; switching
    # to an aware datetime would change the serialized form — confirm first.
    timestamp: datetime = Field(default_factory=datetime.utcnow)
    # Required cost of the response.
    response_cost: float = Field(alias="responseCost")
    # Optional free-form data attached to the entry.
    metadata: Optional[Any] = Field(default=None)
15+
16+
17+
class UsageInfo(BaseModel):
    """Allowance settings and the list of usage entries for a user."""

    # Fields may be populated by Python name or by camelCase alias.
    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)

    # Defaults to 100.0.
    monthly_allowance: float = Field(default=100.0, alias="monthlyAllowance")
    # Defaults to 0.0. NOTE(review): unit of this float is not visible here — confirm.
    allowance_reset_date: float = Field(default=0.0, alias="allowanceResetDate")
    # Defaults to 100.0, matching the default monthly allowance.
    spend_remaining: float = Field(default=100.0, alias="spendRemaining")
    # Starts as a fresh empty list per instance.
    usage: List[UsageEntry] = Field(default_factory=list)
24+
25+
26+
class BillingInfo(BaseModel):
    """Optional billing details (plan and status) attached to a user."""

    # Fields may be populated by Python name or by camelCase alias.
    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)

    plan: Optional[str] = Field(default=None)
    status: Optional[str] = Field(default=None)
31+
32+
33+
class BasicInfo(BaseModel):
    """Optional identifying details (display name and email) for a user."""

    # Fields may be populated by Python name or by camelCase alias.
    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)

    display_name: Optional[str] = Field(default=None, alias="displayName")
    email: Optional[str] = Field(default=None)
38+
39+
40+
class User(BaseModel):
    """A user document combining basic, billing, and usage information."""

    # Fields may be populated by Python name or by alias.
    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)

    # Required identifier, serialized under the ``_id`` alias.
    # NOTE(review): ``_id``/``_rev`` look like document-store keys — confirm
    # against the database service.
    id: str = Field(alias="_id")
    rev: Optional[str] = Field(default=None, alias="_rev")

    # Both timestamps default to the naive UTC time at model creation.
    created_at: datetime = Field(default_factory=datetime.utcnow, alias="createdAt")
    updated_at: datetime = Field(default_factory=datetime.utcnow, alias="updatedAt")

    # Nested sections each default to a freshly constructed empty model.
    basic_info: BasicInfo = Field(default_factory=BasicInfo, alias="basicInfo")
    billing_info: BillingInfo = Field(default_factory=BillingInfo, alias="billingInfo")
    usage_info: UsageInfo = Field(default_factory=UsageInfo, alias="usageInfo")

webapp/packages/api/user-service/services/llm_service.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,54 @@
11
import asyncio
22
import json
3+
from config import settings
34
from config.provider_config import PROVIDER_CONFIG
5+
from services.database_service import get_database_service
6+
from services.user_service import get_user_service
47
from typing import Any, Dict, List, Tuple, Optional
58
import litellm
69

710
from services.litellm_logger import ensure_litellm_logging
11+
from services.user_service import UserService
812

913
ensure_litellm_logging()
1014

15+
def _extract_response_cost(response_obj: Any) -> Optional[float]:
16+
standard_logging = None
17+
if hasattr(response_obj, "_hidden_params") and isinstance(response_obj._hidden_params, dict):
18+
standard_logging = response_obj._hidden_params.get("standard_logging_object")
19+
if isinstance(standard_logging, dict):
20+
try:
21+
cost_value = standard_logging.get("response_cost")
22+
if cost_value is not None:
23+
return float(cost_value)
24+
except Exception:
25+
pass
26+
usage = getattr(response_obj, "usage", None)
27+
if usage and getattr(usage, "total_cost", None) is not None:
28+
try:
29+
return float(getattr(usage, "total_cost"))
30+
except Exception:
31+
return None
32+
return None
33+
34+
1135
async def call_llm(
1236
provider: str,
1337
model: str,
1438
messages: List[Dict[str, Any]],
1539
parameters: Dict[str, Any],
1640
tools: Optional[List[Dict[str, Any]]] = None,
41+
user_service: Optional[UserService] = None,
42+
user_id: Optional[str] = None,
43+
user_basic_info: Optional[Dict[str, Any]] = None,
1744
) -> Tuple[str, Any]:
1845
"""
1946
Calls the specified language model using litellm, handling different API styles.
2047
Returns a tuple of (content, thoughts).
2148
"""
2249
model_config = PROVIDER_CONFIG.get(provider, {}).get("models", {}).get(model, {})
2350
api_style = model_config.get("api_style")
24-
51+
2552
model_string = f"{provider}/{model}"
2653

2754
thoughts = None
@@ -38,6 +65,14 @@ async def call_llm(
3865
if tools:
3966
kwargs["tools"] = tools
4067

68+
if user_service is None:
69+
user_service = get_user_service(get_database_service(settings))
70+
if user_id is None:
71+
user_id = "anonymous"
72+
73+
if user_service and user_id:
74+
user_service.require_allowance(user_id, basic_info=user_basic_info)
75+
4176
if api_style == "responses":
4277
# Use aresponses and aget_responses for OpenAI's special tools like built-in web search
4378
kwargs.pop('messages', None) # aresponses uses 'input' not 'messages'
@@ -121,4 +156,13 @@ async def call_llm(
121156
if thoughts is not None:
122157
thoughts = json.loads(json.dumps(thoughts, default=str))
123158

159+
if user_service and user_id:
160+
response_cost = None
161+
try:
162+
response_cost = _extract_response_cost(final_response if api_style == "responses" else response)
163+
except Exception:
164+
response_cost = None
165+
if response_cost is not None:
166+
user_service.add_usage(user_id, response_cost, basic_info=user_basic_info)
167+
124168
return content, thoughts

0 commit comments

Comments
 (0)