Skip to content

Commit ab95d90

Browse files
committed
feat: pass decoded_token to llm and retrievers
1 parent f4ab85a commit ab95d90

File tree

9 files changed

+75
-25
lines changed

9 files changed

+75
-25
lines changed

application/agents/base.py

Lines changed: 13 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -9,10 +9,21 @@
99

1010

1111
class BaseAgent:
12-
def __init__(self, endpoint, llm_name, gpt_model, api_key, user_api_key=None):
12+
def __init__(
13+
self,
14+
endpoint,
15+
llm_name,
16+
gpt_model,
17+
api_key,
18+
user_api_key=None,
19+
decoded_token=None,
20+
):
1321
self.endpoint = endpoint
1422
self.llm = LLMCreator.create_llm(
15-
llm_name, api_key=api_key, user_api_key=user_api_key
23+
llm_name,
24+
api_key=api_key,
25+
user_api_key=user_api_key,
26+
decoded_token=decoded_token,
1627
)
1728
self.llm_handler = get_llm_handler(llm_name)
1829
self.gpt_model = gpt_model

application/agents/classic_agent.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -17,8 +17,12 @@ def __init__(
1717
user_api_key=None,
1818
prompt="",
1919
chat_history=None,
20+
decoded_token=None,
2021
):
21-
super().__init__(endpoint, llm_name, gpt_model, api_key, user_api_key)
22+
super().__init__(
23+
endpoint, llm_name, gpt_model, api_key, user_api_key, decoded_token
24+
)
25+
self.user = decoded_token.get("sub")
2226
self.prompt = prompt
2327
self.chat_history = chat_history if chat_history is not None else []
2428

@@ -73,7 +77,7 @@ def _gen_inner(
7377
)
7478
messages_combine.append({"role": "user", "content": query})
7579

76-
tools_dict = self._get_user_tools()
80+
tools_dict = self._get_user_tools(self.user)
7781
self._prepare_tools(tools_dict)
7882

7983
resp = self._llm_gen(messages_combine, log_context)

application/api/answer/routes.py

Lines changed: 13 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -264,7 +264,10 @@ def complete_stream(
264264
doc["source"] = "None"
265265

266266
llm = LLMCreator.create_llm(
267-
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
267+
settings.LLM_NAME,
268+
api_key=settings.API_KEY,
269+
user_api_key=user_api_key,
270+
decoded_token=decoded_token,
268271
)
269272

270273
if should_save_conversation:
@@ -420,6 +423,7 @@ def post(self):
420423
user_api_key=user_api_key,
421424
prompt=prompt,
422425
chat_history=history,
426+
decoded_token=decoded_token,
423427
)
424428

425429
retriever = RetrieverCreator.create_retriever(
@@ -431,6 +435,7 @@ def post(self):
431435
token_limit=token_limit,
432436
gpt_model=gpt_model,
433437
user_api_key=user_api_key,
438+
decoded_token=decoded_token,
434439
)
435440

436441
return Response(
@@ -565,6 +570,7 @@ def post(self):
565570
user_api_key=user_api_key,
566571
prompt=prompt,
567572
chat_history=history,
573+
decoded_token=decoded_token,
568574
)
569575

570576
retriever = RetrieverCreator.create_retriever(
@@ -576,6 +582,7 @@ def post(self):
576582
token_limit=token_limit,
577583
gpt_model=gpt_model,
578584
user_api_key=user_api_key,
585+
decoded_token=decoded_token,
579586
)
580587

581588
response_full = ""
@@ -623,7 +630,10 @@ def post(self):
623630
doc["source"] = "None"
624631

625632
llm = LLMCreator.create_llm(
626-
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
633+
settings.LLM_NAME,
634+
api_key=settings.API_KEY,
635+
user_api_key=user_api_key,
636+
decoded_token=decoded_token,
627637
)
628638

629639
result = {"answer": response_full, "sources": source_log_docs}
@@ -743,6 +753,7 @@ def post(self):
743753
token_limit=token_limit,
744754
gpt_model=gpt_model,
745755
user_api_key=user_api_key,
756+
decoded_token=decoded_token,
746757
)
747758

748759
docs = retriever.search(question)

application/llm/base.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,8 @@
55

66

77
class BaseLLM(ABC):
8-
def __init__(self):
8+
def __init__(self, decoded_token):
9+
self.decoded_token = decoded_token
910
self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
1011

1112
def _apply_decorator(self, method, decorators, *args, **kwargs):

application/llm/llm_creator.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99
from application.llm.google_ai import GoogleLLM
1010
from application.llm.novita import NovitaLLM
1111

12+
1213
class LLMCreator:
1314
llms = {
1415
"openai": OpenAILLM,
@@ -21,12 +22,14 @@ class LLMCreator:
2122
"premai": PremAILLM,
2223
"groq": GroqLLM,
2324
"google": GoogleLLM,
24-
"novita": NovitaLLM
25+
"novita": NovitaLLM,
2526
}
2627

2728
@classmethod
28-
def create_llm(cls, type, api_key, user_api_key, *args, **kwargs):
29+
def create_llm(cls, type, api_key, user_api_key, decoded_token, *args, **kwargs):
2930
llm_class = cls.llms.get(type.lower())
3031
if not llm_class:
3132
raise ValueError(f"No LLM class found for type {type}")
32-
return llm_class(api_key, user_api_key, *args, **kwargs)
33+
return llm_class(
34+
api_key, user_api_key, decoded_token=decoded_token, *args, **kwargs
35+
)

application/retriever/brave_search.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@ def __init__(
1717
token_limit=150,
1818
gpt_model="docsgpt",
1919
user_api_key=None,
20+
decoded_token=None,
2021
):
2122
self.question = question
2223
self.source = source
@@ -35,6 +36,7 @@ def __init__(
3536
)
3637
)
3738
self.user_api_key = user_api_key
39+
self.decoded_token = decoded_token
3840

3941
def _get_data(self):
4042
if self.chunks == 0:
@@ -81,7 +83,10 @@ def gen(self):
8183
messages_combine.append({"role": "user", "content": self.question})
8284

8385
llm = LLMCreator.create_llm(
84-
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
86+
settings.LLM_NAME,
87+
api_key=settings.API_KEY,
88+
user_api_key=self.user_api_key,
89+
decoded_token=self.decoded_token,
8590
)
8691

8792
completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
@@ -100,5 +105,5 @@ def get_params(self):
100105
"chunks": self.chunks,
101106
"token_limit": self.token_limit,
102107
"gpt_model": self.gpt_model,
103-
"user_api_key": self.user_api_key
108+
"user_api_key": self.user_api_key,
104109
}

application/retriever/classic_rag.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@ def __init__(
1717
user_api_key=None,
1818
llm_name=settings.LLM_NAME,
1919
api_key=settings.API_KEY,
20+
decoded_token=None,
2021
):
2122
self.original_question = ""
2223
self.chat_history = chat_history if chat_history is not None else []
@@ -37,10 +38,14 @@ def __init__(
3738
self.llm_name = llm_name
3839
self.api_key = api_key
3940
self.llm = LLMCreator.create_llm(
40-
self.llm_name, api_key=self.api_key, user_api_key=self.user_api_key
41+
self.llm_name,
42+
api_key=self.api_key,
43+
user_api_key=self.user_api_key,
44+
decoded_token=decoded_token,
4145
)
4246
self.question = self._rephrase_query()
4347
self.vectorstore = source["active_docs"] if "active_docs" in source else None
48+
self.decoded_token = decoded_token
4449

4550
def _rephrase_query(self):
4651
if (

application/retriever/duckduck_search.py

Lines changed: 14 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@ def __init__(
1717
token_limit=150,
1818
gpt_model="docsgpt",
1919
user_api_key=None,
20+
decoded_token=None,
2021
):
2122
self.question = question
2223
self.source = source
@@ -35,6 +36,7 @@ def __init__(
3536
)
3637
)
3738
self.user_api_key = user_api_key
39+
self.decoded_token = decoded_token
3840

3941
def _parse_lang_string(self, input_string):
4042
result = []
@@ -88,17 +90,20 @@ def gen(self):
8890
for doc in docs:
8991
yield {"source": doc}
9092

91-
if len(self.chat_history) > 0:
93+
if len(self.chat_history) > 0:
9294
for i in self.chat_history:
93-
if "prompt" in i and "response" in i:
94-
messages_combine.append({"role": "user", "content": i["prompt"]})
95-
messages_combine.append(
96-
{"role": "assistant", "content": i["response"]}
97-
)
95+
if "prompt" in i and "response" in i:
96+
messages_combine.append({"role": "user", "content": i["prompt"]})
97+
messages_combine.append(
98+
{"role": "assistant", "content": i["response"]}
99+
)
98100
messages_combine.append({"role": "user", "content": self.question})
99101

100102
llm = LLMCreator.create_llm(
101-
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
103+
settings.LLM_NAME,
104+
api_key=settings.API_KEY,
105+
user_api_key=self.user_api_key,
106+
decoded_token=self.decoded_token,
102107
)
103108

104109
completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
@@ -107,7 +112,7 @@ def gen(self):
107112

108113
def search(self):
109114
return self._get_data()
110-
115+
111116
def get_params(self):
112117
return {
113118
"question": self.question,
@@ -117,5 +122,5 @@ def get_params(self):
117122
"chunks": self.chunks,
118123
"token_limit": self.token_limit,
119124
"gpt_model": self.gpt_model,
120-
"user_api_key": self.user_api_key
125+
"user_api_key": self.user_api_key,
121126
}

application/usage.py

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -9,10 +9,15 @@
99
usage_collection = db["token_usage"]
1010

1111

12-
def update_token_usage(user_api_key, token_usage):
12+
def update_token_usage(decoded_token, user_api_key, token_usage):
1313
if "pytest" in sys.modules:
1414
return
15+
if decoded_token:
16+
user_id = decoded_token["sub"]
17+
else:
18+
user_id = None
1519
usage_data = {
20+
"user_id": user_id,
1621
"api_key": user_api_key,
1722
"prompt_tokens": token_usage["prompt_tokens"],
1823
"generated_tokens": token_usage["generated_tokens"],
@@ -35,7 +40,7 @@ def wrapper(self, model, messages, stream, tools, **kwargs):
3540
self.token_usage["generated_tokens"] += num_tokens_from_object_or_list(
3641
result
3742
)
38-
update_token_usage(self.user_api_key, self.token_usage)
43+
update_token_usage(self.decoded_token, self.user_api_key, self.token_usage)
3944
return result
4045

4146
return wrapper
@@ -54,6 +59,6 @@ def wrapper(self, model, messages, stream, tools, **kwargs):
5459
yield r
5560
for line in batch:
5661
self.token_usage["generated_tokens"] += num_tokens_from_string(line)
57-
update_token_usage(self.user_api_key, self.token_usage)
62+
update_token_usage(self.decoded_token, self.user_api_key, self.token_usage)
5863

5964
return wrapper

0 commit comments

Comments (0)