Skip to content

Commit d374d77

Browse files
authored
User context provided (#310)
* Model selection implemented * Refactor: moved default model env variable to correct files * User context provided
1 parent c98e4cd commit d374d77

File tree

5 files changed

+30
-19
lines changed

5 files changed

+30
-19
lines changed

services/chatbot/src/chatbot/chat_api.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
get_or_create_session_id,
1111
store_api_key,
1212
store_model_name,
13+
get_user_jwt
1314
)
1415

1516
chat_bp = Blueprint("chat", __name__, url_prefix="/genai")
@@ -53,14 +54,15 @@ async def chat():
5354
session_id = await get_or_create_session_id()
5455
openai_api_key = await get_api_key(session_id)
5556
model_name = await get_model_name(session_id)
57+
user_jwt = await get_user_jwt()
5658
if not openai_api_key:
5759
return jsonify({"message": "Missing OpenAI API key. Please authenticate."}), 400
5860
data = await request.get_json()
5961
message = data.get("message", "").strip()
6062
id = data.get("id", uuid4().int & (1 << 63) - 1)
6163
if not message:
6264
return jsonify({"message": "Message is required", "id": id}), 400
63-
reply, response_id = await process_user_message(session_id, message, openai_api_key, model_name)
65+
reply, response_id = await process_user_message(session_id, message, openai_api_key, model_name, user_jwt)
6466
return jsonify({"id": response_id, "message": reply}), 200
6567

6668

services/chatbot/src/chatbot/chat_service.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,13 @@ async def delete_chat_history(session_id):
2222
await db.chat_sessions.delete_one({"session_id": session_id})
2323

2424

25-
async def process_user_message(session_id, user_message, api_key, model_name):
25+
async def process_user_message(session_id, user_message, api_key, model_name, user_jwt):
2626
history = await get_chat_history(session_id)
2727
# generate a unique numeric id for the message that is random but unique
2828
source_message_id = uuid4().int & (1 << 63) - 1
2929
history.append({"id": source_message_id, "role": "user", "content": user_message})
3030
# Run LangGraph agent
31-
response = await execute_langgraph_agent(api_key, model_name, history, session_id)
31+
response = await execute_langgraph_agent(api_key, model_name, history, user_jwt, session_id)
3232
print("Response", response)
3333
reply: Messages = response.get("messages", [{}])[-1]
3434
print("Reply", reply.content)

services/chatbot/src/chatbot/langgraph_agent.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from langgraph.prebuilt import create_react_agent
1919

2020
from .extensions import postgresdb
21-
from .mcp_client import mcp_client
21+
from .mcp_client import get_mcp_client
2222

2323

2424
async def get_retriever_tool(api_key):
@@ -46,7 +46,7 @@ async def get_retriever_tool(api_key):
4646
return retriever_tool
4747

4848

49-
async def build_langgraph_agent(api_key, model_name):
49+
async def build_langgraph_agent(api_key, model_name, user_jwt):
5050
system_prompt = textwrap.dedent(
5151
"""
5252
You are crAPI Assistant — an expert agent that helps users explore and test the Completely Ridiculous API (crAPI), a vulnerable-by-design application for learning and evaluating modern API security issues.
@@ -86,6 +86,7 @@ async def build_langgraph_agent(api_key, model_name):
8686
)
8787
llm = ChatOpenAI(api_key=api_key, model=model_name)
8888
toolkit = SQLDatabaseToolkit(db=postgresdb, llm=llm)
89+
mcp_client = get_mcp_client(user_jwt)
8990
mcp_tools = await mcp_client.get_tools()
9091
db_tools = toolkit.get_tools()
9192
tools = mcp_tools + db_tools
@@ -95,8 +96,8 @@ async def build_langgraph_agent(api_key, model_name):
9596
return agent_node
9697

9798

98-
async def execute_langgraph_agent(api_key, model_name, messages, session_id=None):
99-
agent = await build_langgraph_agent(api_key, model_name)
99+
async def execute_langgraph_agent(api_key, model_name, messages, user_jwt, session_id=None):
100+
agent = await build_langgraph_agent(api_key, model_name, user_jwt)
100101
print("messages", messages)
101102
print("Session ID", session_id)
102103
response = await agent.ainvoke({"messages": messages})
Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
1-
import asyncio
2-
import os
3-
41
from langchain_mcp_adapters.client import MultiServerMCPClient
52

6-
mcp_client = MultiServerMCPClient(
7-
{
8-
"crapi": {
9-
"transport": "streamable_http",
10-
"url": "http://localhost:5500/mcp/",
11-
"headers": {},
12-
},
13-
}
14-
)
3+
def get_mcp_client(user_jwt: str | None) -> MultiServerMCPClient:
4+
headers = {}
5+
if user_jwt:
6+
headers["Authorization"] = f"Bearer {user_jwt}"
7+
8+
return MultiServerMCPClient(
9+
{
10+
"crapi": {
11+
"transport": "streamable_http",
12+
"url": "http://localhost:5500/mcp/",
13+
"headers": headers,
14+
}
15+
}
16+
)

services/chatbot/src/chatbot/session_service.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,9 @@ async def get_model_name(session_id):
5757
if "model_name" not in doc:
5858
return Config.DEFAULT_MODEL_NAME
5959
return doc["model_name"]
60+
61+
async def get_user_jwt() -> str | None:
62+
auth = request.headers.get("Authorization", "")
63+
if auth.startswith("Bearer "):
64+
return auth.replace("Bearer ", "")
65+
return None

0 commit comments

Comments
 (0)