Skip to content

Commit c98e4cd

Browse files
authored
Model selection implemented (#309)
* Model selection implemented * Refactor: moved default model env variable to correct files
1 parent 4cf7b81 commit c98e4cd

File tree

10 files changed

+155
-12
lines changed

10 files changed

+155
-12
lines changed

deploy/docker/docker-compose.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ services:
166166
- MONGO_DB_USER=admin
167167
- MONGO_DB_PASSWORD=crapisecretpassword
168168
- MONGO_DB_NAME=crapi
169+
- DEFAULT_MODEL=gpt-4o-mini
169170
# - CHATBOT_OPENAI_API_KEY=
170171
depends_on:
171172
mongodb:

deploy/helm/templates/chatbot/config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,4 @@ data:
2121
MONGO_DB_PASSWORD: {{ .Values.mongodb.config.mongoPassword }}
2222
MONGO_DB_NAME: {{ .Values.mongodb.config.mongoDbName }}
2323
CHATBOT_OPENAI_API_KEY: {{ .Values.openAIApiKey }}
24+
DEFAULT_MODEL: {{ .Values.chatbot.config.defaultModel | quote }}

deploy/helm/values.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ chatbot:
151151
postgresDbDriver: postgres
152152
mongoDbDriver: mongodb
153153
secretKey: crapi
154+
defaultModel: gpt-4o-mini
154155
deploymentLabels:
155156
app: crapi-chatbot
156157
podLabels:

services/chatbot/src/chatbot/chat_api.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
import logging
22
from quart import Blueprint, jsonify, request, session
33
from uuid import uuid4
4+
from .config import Config
45
from .chat_service import delete_chat_history, get_chat_history, process_user_message
56
from .session_service import (
67
delete_api_key,
78
get_api_key,
9+
get_model_name,
810
get_or_create_session_id,
911
store_api_key,
12+
store_model_name,
1013
)
1114

1215
chat_bp = Blueprint("chat", __name__, url_prefix="/genai")
@@ -34,19 +37,30 @@ async def init():
3437
await store_api_key(session_id, openai_api_key)
3538
return jsonify({"message": "Initialized"}), 200
3639

40+
@chat_bp.route("/model", methods=["POST"])
41+
async def model():
42+
session_id = await get_or_create_session_id()
43+
data = await request.get_json()
44+
model_name = Config.DEFAULT_MODEL_NAME
45+
if data and "model_name" in data and data["model_name"]:
46+
model_name = data["model_name"]
47+
logger.debug("Setting model %s for session %s", model_name, session_id)
48+
await store_model_name(session_id, model_name)
49+
return jsonify({"model_used": model_name}), 200
3750

3851
@chat_bp.route("/ask", methods=["POST"])
3952
async def chat():
4053
session_id = await get_or_create_session_id()
4154
openai_api_key = await get_api_key(session_id)
55+
model_name = await get_model_name(session_id)
4256
if not openai_api_key:
4357
return jsonify({"message": "Missing OpenAI API key. Please authenticate."}), 400
4458
data = await request.get_json()
4559
message = data.get("message", "").strip()
4660
id = data.get("id", uuid4().int & (1 << 63) - 1)
4761
if not message:
4862
return jsonify({"message": "Message is required", "id": id}), 400
49-
reply, response_id = await process_user_message(session_id, message, openai_api_key)
63+
reply, response_id = await process_user_message(session_id, message, openai_api_key, model_name)
5064
return jsonify({"id": response_id, "message": reply}), 200
5165

5266

services/chatbot/src/chatbot/chat_service.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,13 @@ async def delete_chat_history(session_id):
2222
await db.chat_sessions.delete_one({"session_id": session_id})
2323

2424

25-
async def process_user_message(session_id, user_message, api_key):
25+
async def process_user_message(session_id, user_message, api_key, model_name):
2626
history = await get_chat_history(session_id)
2727
# generate a unique numeric id for the message that is random but unique
2828
source_message_id = uuid4().int & (1 << 63) - 1
2929
history.append({"id": source_message_id, "role": "user", "content": user_message})
3030
# Run LangGraph agent
31-
response = await execute_langgraph_agent(api_key, history, session_id)
31+
response = await execute_langgraph_agent(api_key, model_name, history, session_id)
3232
print("Response", response)
3333
reply: Messages = response.get("messages", [{}])[-1]
3434
print("Reply", reply.content)

services/chatbot/src/chatbot/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@
1010
class Config:
1111
SECRET_KEY = os.getenv("SECRET_KEY", "super-secret")
1212
MONGO_URI = MONGO_CONNECTION_URI
13+
DEFAULT_MODEL_NAME = os.getenv("DEFAULT_MODEL", "gpt-4o-mini")

services/chatbot/src/chatbot/langgraph_agent.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@
2020
from .extensions import postgresdb
2121
from .mcp_client import mcp_client
2222

23-
model_name = "gpt-4o-mini"
24-
2523

2624
async def get_retriever_tool(api_key):
2725
embeddings = OpenAIEmbeddings(api_key=api_key)
@@ -48,7 +46,7 @@ async def get_retriever_tool(api_key):
4846
return retriever_tool
4947

5048

51-
async def build_langgraph_agent(api_key):
49+
async def build_langgraph_agent(api_key, model_name):
5250
system_prompt = textwrap.dedent(
5351
"""
5452
You are crAPI Assistant — an expert agent that helps users explore and test the Completely Ridiculous API (crAPI), a vulnerable-by-design application for learning and evaluating modern API security issues.
@@ -86,7 +84,7 @@ async def build_langgraph_agent(api_key):
8684
Use the tools only if you don't know the answer.
8785
"""
8886
)
89-
llm = ChatOpenAI(api_key=api_key, model="gpt-4o-mini")
87+
llm = ChatOpenAI(api_key=api_key, model=model_name)
9088
toolkit = SQLDatabaseToolkit(db=postgresdb, llm=llm)
9189
mcp_tools = await mcp_client.get_tools()
9290
db_tools = toolkit.get_tools()
@@ -97,8 +95,8 @@ async def build_langgraph_agent(api_key):
9795
return agent_node
9896

9997

100-
async def execute_langgraph_agent(api_key, messages, session_id=None):
101-
agent = await build_langgraph_agent(api_key)
98+
async def execute_langgraph_agent(api_key, model_name, messages, session_id=None):
99+
agent = await build_langgraph_agent(api_key, model_name)
102100
print("messages", messages)
103101
print("Session ID", session_id)
104102
response = await agent.ainvoke({"messages": messages})

services/chatbot/src/chatbot/session_service.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import os
22
import uuid
3-
3+
from .config import Config
44
from quart import after_this_request, request
55

66
from .extensions import db
@@ -44,3 +44,16 @@ async def delete_api_key(session_id):
4444
await db.sessions.update_one(
4545
{"session_id": session_id}, {"$unset": {"openai_api_key": ""}}
4646
)
47+
48+
async def store_model_name(session_id, model_name):
49+
await db.sessions.update_one(
50+
{"session_id": session_id}, {"$set": {"model_name": model_name}}, upsert=True
51+
)
52+
53+
async def get_model_name(session_id):
54+
doc = await db.sessions.find_one({"session_id": session_id})
55+
if not doc:
56+
return Config.DEFAULT_MODEL_NAME
57+
if "model_name" not in doc:
58+
return Config.DEFAULT_MODEL_NAME
59+
return doc["model_name"]

services/web/src/components/bot/ActionProvider.tsx

Lines changed: 105 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ class ActionProvider {
141141
}
142142
console.log(res);
143143
const successmessage = this.createChatBotMessage(
144-
"Chatbot initialized successfully.",
144+
"Chatbot initialized successfully. By default, GPT-4o-mini model is being used. To change chatbot's model, please type model and press enter.",
145145
Math.floor(Math.random() * 65536),
146146
{
147147
loading: true,
@@ -154,6 +154,95 @@ class ActionProvider {
154154
});
155155
};
156156

157+
handleModelSelection = (initRequired: boolean): void => {
158+
console.log("Initialization required:", initRequired);
159+
if (initRequired) {
160+
const message = this.createChatBotMessage(
161+
"Chatbot not initialized. To initialize the chatbot, please type init and press enter.",
162+
Math.floor(Math.random() * 65536),
163+
{
164+
loading: true,
165+
terminateLoading: true,
166+
role: "assistant",
167+
},
168+
);
169+
this.addMessageToState(message);
170+
} else {
171+
this.addModelSelectionToState();
172+
const message = this.createChatBotMessage(
173+
`Type one of these available options and press enter:\n\n` +
174+
`1. \`gpt-4o\` : GPT-4 Omni (fastest, multimodal, best for general use)\n\n` +
175+
`2. \`gpt-4o-mini\` : Lighter version of GPT-4o (efficient for most tasks)\n\n` +
176+
`3. \`gpt-4-turbo\` : GPT-4 Turbo (older but solid performance)\n\n` +
177+
`4. \`gpt-3.5-turbo\` : GPT-3.5 Turbo (cheaper, good for lightweight tasks)\n\n` +
178+
`5. \`gpt-3.5-turbo-16k\` : Like above but with 16k context window\n\n` +
179+
`By default, GPT-4o-mini will be used if any invalid option is entered.`,
180+
Math.floor(Math.random() * 65536),
181+
{
182+
loading: true,
183+
terminateLoading: true,
184+
role: "assistant",
185+
},
186+
);
187+
this.addMessageToState(message);
188+
}
189+
};
190+
191+
handleModelConfirmation = (model_name: string | null, accessToken: string): void => {
192+
const validModels: Record<string, string> = {
193+
"1": "gpt-4o",
194+
"2": "gpt-4o-mini",
195+
"3": "gpt-4-turbo",
196+
"4": "gpt-3.5-turbo",
197+
"5": "gpt-3.5-turbo-16k",
198+
"gpt-4o": "gpt-4o",
199+
"gpt-4o-mini": "gpt-4o-mini",
200+
"gpt-4-turbo": "gpt-4-turbo",
201+
"gpt-3.5-turbo": "gpt-3.5-turbo",
202+
"gpt-3.5-turbo-16k": "gpt-3.5-turbo-16k"
203+
};
204+
const selectedModel = model_name?.trim();
205+
const modelToUse = selectedModel && validModels[selectedModel] ? validModels[selectedModel] : null;
206+
207+
const modelUrl = APIService.CHATBOT_SERVICE + "genai/model";
208+
superagent
209+
.post(modelUrl)
210+
.send({ model_name: modelToUse })
211+
.set("Accept", "application/json")
212+
.set("Content-Type", "application/json")
213+
.set("Authorization", `Bearer ${accessToken}`)
214+
.end((err, res) => {
215+
if (err) {
216+
console.log(err);
217+
const errormessage = this.createChatBotMessage(
218+
"Failed to set model. Please try again.",
219+
Math.floor(Math.random() * 65536),
220+
{
221+
loading: true,
222+
terminateLoading: true,
223+
role: "assistant",
224+
},
225+
);
226+
this.addMessageToState(errormessage);
227+
return;
228+
}
229+
230+
console.log(res);
231+
const currentModel = res.body?.model_used || modelToUse;
232+
const successmessage = this.createChatBotMessage(
233+
`Model has been successfully set to ${currentModel}. You can now start chatting.`,
234+
Math.floor(Math.random() * 65536),
235+
{
236+
loading: true,
237+
terminateLoading: true,
238+
role: "assistant",
239+
},
240+
);
241+
this.addMessageToState(successmessage);
242+
this.addModelConfirmationToState();
243+
});
244+
};
245+
157246
handleChat = (message: string, accessToken: string): void => {
158247
const chatUrl = APIService.CHATBOT_SERVICE + "genai/ask";
159248
console.log("Chat message:", message);
@@ -223,7 +312,7 @@ class ActionProvider {
223312
this.addMessageToState(message);
224313
} else {
225314
const message = this.createChatBotMessage(
226-
"Chat with the bot and exploit it.",
315+
"Chat with the bot and exploit it. To change chatbot's model, please type model and press enter.",
227316
Math.floor(Math.random() * 65536),
228317
{
229318
loading: true,
@@ -303,6 +392,20 @@ class ActionProvider {
303392
}));
304393
};
305394

395+
addModelSelectionToState = (): void => {
396+
this.setState((state) => ({
397+
...state,
398+
modelSelection: true,
399+
}));
400+
};
401+
402+
addModelConfirmationToState = (): void => {
403+
this.setState((state) => ({
404+
...state,
405+
modelSelection: false,
406+
}));
407+
};
408+
306409
clearMessages = (): void => {
307410
this.setState((state) => ({
308411
...state,

services/web/src/components/bot/MessageParser.tsx

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import request from "superagent";
1818
interface State {
1919
initializationRequired?: boolean;
2020
initializing?: boolean;
21+
modelSelection?: boolean;
2122
accessToken: string;
2223
chatHistory: ChatMessage[];
2324
}
@@ -40,6 +41,8 @@ interface ActionProvider {
4041
chatHistory: ChatMessage[],
4142
) => void;
4243
handleNotInitialized: () => void;
44+
handleModelSelection: (initRequired: boolean) => void;
45+
handleModelConfirmation: (message: string, accessToken: string) => void;
4346
handleChat: (message: string, accessToken: string) => void;
4447
}
4548

@@ -107,6 +110,12 @@ class MessageParser {
107110
return this.actionProvider.handleInitialize(
108111
this.state.initializationRequired,
109112
);
113+
} else if (message_l === "model" || message_l === "models") {
114+
const [initRequired, chatHistory] = await this.initializationRequired();
115+
this.state.initializationRequired = initRequired;
116+
this.state.chatHistory = chatHistory;
117+
console.log("State help:", this.state);
118+
return this.actionProvider.handleModelSelection(this.state.initializationRequired);
110119
} else if (
111120
message_l === "clear" ||
112121
message_l === "reset" ||
@@ -121,6 +130,8 @@ class MessageParser {
121130
);
122131
} else if (this.state.initializationRequired) {
123132
return this.actionProvider.handleNotInitialized();
133+
} else if (this.state.modelSelection) {
134+
return this.actionProvider.handleModelConfirmation(message, this.state.accessToken);
124135
}
125136

126137
return this.actionProvider.handleChat(message, this.state.accessToken);

0 commit comments

Comments
 (0)