-
-
Notifications
You must be signed in to change notification settings - Fork 78
Expand file tree
/
Copy pathtest_chatbot.py
More file actions
204 lines (153 loc) · 8.18 KB
/
test_chatbot.py
File metadata and controls
204 lines (153 loc) · 8.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
"""Integration Tests for the chatbot."""
from pydantic import ValidationError
import pytest
from api.models.schemas import ChatResponse
from api.services import memory
@pytest.fixture(autouse=True)
def reset_memory_sessions():
    """Reset the sessions held by the ``memory`` service ahead of each test.

    Declared with ``autouse=True`` so tests pick it up without having to
    request it by name.
    """
    memory.reset_sessions()
def test_create_session(client):
    """Check that ``POST /sessions`` creates a session.

    Expects a 201 status, a string ``session_id`` in the JSON body and a
    ``Location`` header pointing at the message endpoint of that session.
    """
    created = client.post("/sessions")
    assert created.status_code == 201

    data = created.json()
    assert "session_id" in data
    assert isinstance(data["session_id"], str)

    expected_location = f"/sessions/{data['session_id']}/message"
    assert created.headers["Location"] == expected_location
def test_reply_to_existing_session(client, mock_llm_provider, mock_get_relevant_documents):
    """Should return a chatbot reply for a valid session and input message.

    The ``mock_*`` fixtures are configured with canned return values, so the
    reply carried by the response must be exactly the canned LLM answer. The
    response body is also validated against the ``ChatResponse`` schema.
    """
    create_resp = client.post("/sessions")
    session_id = create_resp.json()["session_id"]
    mock_llm_provider.generate.return_value = "LLM answers to the query"
    mock_get_relevant_documents.return_value = get_relevant_documents_output()

    payload = {"message": "Hello"}
    response = client.post(f"/sessions/{session_id}/message", data=payload)
    assert response.status_code == 200

    try:
        chat_response = ChatResponse.model_validate(response.json())
    except ValidationError as e:
        # pytest.fail rather than `assert False`: an explicit failure that is
        # not stripped when Python runs with -O.
        pytest.fail(f"Response did not match the expected schema: {e}")

    assert chat_response.reply == "LLM answers to the query"
def test_reply_to_nonexistent_session(client):
    """Posting a message to an unknown session ID must yield a 404."""
    url = "/sessions/nonexistent-session/message"
    outcome = client.post(url, data={"message": "Hello"})

    assert outcome.status_code == 404
    assert outcome.json() == {"detail": "Session not found."}
def test_delete_existing_session(client):
    """Deleting a freshly created session returns 200 and a confirmation body."""
    session_id = client.post("/sessions").json()["session_id"]

    deletion = client.delete(f"/sessions/{session_id}")

    assert deletion.status_code == 200
    expected_body = {"message": f"Session {session_id} deleted."}
    assert deletion.json() == expected_body
def test_delete_nonexistent_session(client):
    """Deleting an unknown session ID must yield a 404 and the not-found detail."""
    not_found = {"detail": "Session not found."}

    outcome = client.delete("/sessions/invalid-session")

    assert outcome.status_code == 404
    assert outcome.json() == not_found
def test_reply_after_session_deleted(client):
    """A message sent to a session after its deletion must yield a 404."""
    session_id = client.post("/sessions").json()["session_id"]
    client.delete(f"/sessions/{session_id}")

    late_reply = client.post(
        f"/sessions/{session_id}/message",
        data={"message": "Is anyone there?"},
    )

    assert late_reply.status_code == 404
    assert late_reply.json() == {"detail": "Session not found."}
def test_reply_with_empty_message(client):
    """A whitespace-only message (and no file) must be rejected with a 422."""
    session_id = client.post("/sessions").json()["session_id"]

    rejected = client.post(
        f"/sessions/{session_id}/message",
        data={"message": " "},
    )

    assert rejected.status_code == 422
    detail = rejected.json()["detail"]
    assert detail == "Either a message or at least one file must be provided."
def test_full_chat_lifecycle(client, mock_llm_provider, mock_get_relevant_documents):
    """Walk one session through creation, a single exchange and deletion."""
    mock_llm_provider.generate.return_value = "Hello from the bot!"
    mock_get_relevant_documents.return_value = get_relevant_documents_output()

    # Step 1: open the session.
    created = client.post("/sessions")
    assert created.status_code == 201
    session_id = created.json()["session_id"]

    # Step 2: send one message and read the canned reply back.
    answered = client.post(
        f"/sessions/{session_id}/message",
        data={"message": "Hello"},
    )
    assert answered.status_code == 200
    assert answered.json()["reply"] == "Hello from the bot!"

    # Step 3: remove the session and check the confirmation.
    removed = client.delete(f"/sessions/{session_id}")
    assert removed.status_code == 200
    assert removed.json()["message"] == f"Session {session_id} deleted."
def test_multiple_messages_in_session(client, mock_llm_provider, mock_get_relevant_documents):
    """Three consecutive messages in one session each receive their own reply."""
    # One canned value per expected call, consumed in order by the mocks.
    mock_llm_provider.generate.side_effect = ["Reply 1", "Reply 2", "Reply 3"]
    mock_get_relevant_documents.side_effect = [
        get_relevant_documents_output() for _ in range(3)
    ]

    created = client.post("/sessions")
    session_id = created.json()["session_id"]

    for i in range(3):
        answer = client.post(
            f"/sessions/{session_id}/message",
            data={"message": f"Msg {i+1}"},
        )
        assert answer.status_code == 200
        assert answer.json()["reply"] == f"Reply {i+1}"
def test_multiple_sessions_are_isolated(client, mock_llm_provider, mock_get_relevant_documents):
    """Ensure messages in different sessions don't interfere with each other.

    Two sessions are created and each receives one message. One of them is
    then deleted: the surviving session must keep answering (200) while the
    deleted one must report 404.
    """
    mock_llm_provider.generate.return_value = "LLM response"
    mock_get_relevant_documents.return_value = get_relevant_documents_output()

    active_session = client.post("/sessions").json()["session_id"]
    deleted_session = client.post("/sessions").json()["session_id"]

    # Check the setup steps explicitly so a broken setup fails here rather
    # than surfacing indirectly through the final assertions.
    first_active = client.post(f"/sessions/{active_session}/message", data={"message": "Hi A"})
    first_deleted = client.post(f"/sessions/{deleted_session}/message", data={"message": "Hi B"})
    assert first_active.status_code == 200
    assert first_deleted.status_code == 200

    delete_resp = client.delete(f"/sessions/{deleted_session}")
    assert delete_resp.status_code == 200

    response_active_session = client.post(
        f"/sessions/{active_session}/message",
        data={"message": "Message again"},
    )
    response_deleted_session = client.post(
        f"/sessions/{deleted_session}/message",
        data={"message": "Should be off"},
    )

    assert response_active_session.status_code == 200
    assert response_deleted_session.status_code == 404
    assert response_deleted_session.json() == {"detail": "Session not found."}
def test_get_history_empty_session(client):
    """The history of a brand-new session is an empty message list."""
    created = client.post("/sessions")
    session_id = created.json()["session_id"]

    history = client.get(f"/sessions/{session_id}/message")
    assert history.status_code == 200

    data = history.json()
    assert data["session_id"] == session_id
    assert data["messages"] == []
def test_get_history_with_messages(client, mock_llm_provider, mock_get_relevant_documents):
    """After one exchange the history holds the human turn, then the AI turn."""
    mock_llm_provider.generate.return_value = "Bot reply"
    mock_get_relevant_documents.return_value = get_relevant_documents_output()

    session_id = client.post("/sessions").json()["session_id"]
    client.post(f"/sessions/{session_id}/message", data={"message": "Hello"})

    history = client.get(f"/sessions/{session_id}/message")
    assert history.status_code == 200

    data = history.json()
    assert data["session_id"] == session_id
    assert len(data["messages"]) == 2

    # Exactly two entries at this point, so unpacking is safe.
    human_turn, ai_turn = data["messages"]
    assert human_turn["role"] == "human"
    assert human_turn["content"] == "Hello"
    assert ai_turn["role"] == "ai"
    assert ai_turn["content"] == "Bot reply"
def test_get_history_nonexistent_session(client):
    """Requesting the history of an unknown session ID must yield a 404."""
    url = "/sessions/nonexistent-session/message"

    outcome = client.get(url)

    assert outcome.status_code == 404
    assert outcome.json() == {"detail": "Session not found."}
def test_get_history_deleted_session(client):
    """Requesting the history of a session after its deletion must yield a 404."""
    created = client.post("/sessions")
    session_id = created.json()["session_id"]
    client.delete(f"/sessions/{session_id}")

    outcome = client.get(f"/sessions/{session_id}/message")

    assert outcome.status_code == 404
    assert outcome.json() == {"detail": "Session not found."}
def get_relevant_documents_output():
    """Build the canned value the tests assign to the ``get_relevant_documents`` mock.

    Returns a 2-tuple: a list holding one document dict (``id`` and
    ``chunk_text`` keys) and a list holding the single float ``0.84``.
    Fresh objects are built on every call.
    """
    documents = [{"id": "docid", "chunk_text": "Relevant chunk text."}]
    # Presumably one relevance score/distance per document -- confirm against
    # the real get_relevant_documents.
    scores = [0.84]
    return documents, scores