Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion chatbot-core/api/models/embedding_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@

logger = LoggerFactory.instance().get_logger("api")

EMBEDDING_MODEL = load_embedding_model(CONFIG["retrieval"]["embedding_model_name"], logger)
EMBEDDING_MODEL = load_embedding_model(
CONFIG["retrieval"]["embedding_model_name"])
20 changes: 17 additions & 3 deletions chatbot-core/api/routes/chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,19 @@ async def chatbot_stream(websocket: WebSocket, session_id: str):
message_data = json.loads(data)
user_message = message_data.get("message", "")

if len(user_message) > 2000:
logger.warning(
"Truncated massive payload from session %s", session_id)
user_message = user_message[:2000]

if not user_message:
continue

async for token in get_chatbot_reply_stream(
session_id,
user_message,
):

await websocket.send_text(
json.dumps({"token": token})
)
Expand Down Expand Up @@ -168,6 +174,7 @@ def start_chat(response: Response):
)
return SessionResponse(session_id=session_id)


@router.delete(
"/sessions/{session_id}",
response_model=DeleteResponse,
Expand Down Expand Up @@ -228,7 +235,6 @@ def get_chat_history(session_id: str):
# Chat Endpoint
@router.post("/sessions/{session_id}/message", response_model=ChatResponse)
def chatbot_reply(session_id: str, request: ChatRequest, _background_tasks: BackgroundTasks):

"""
POST endpoint to handle chatbot replies.

Expand All @@ -247,11 +253,16 @@ def chatbot_reply(session_id: str, request: ChatRequest, _background_tasks: Back
status_code=404,
detail="Session not found.",
)
reply = get_chatbot_reply(session_id, request.message)

if len(request.message) > 2000:
logger.warning("Truncated massive payload from session %s", session_id)
request.message = request.message[:2000]

reply = get_chatbot_reply(session_id, request.message)
_background_tasks.add_task(
persist_session,
session_id,
)
)

return reply

Expand Down Expand Up @@ -301,6 +312,9 @@ async def chatbot_reply_with_files(
status_code=422,
detail="Either message or files must be provided.",
)
if has_message and len(message) > 2000:
logger.warning("Truncated massive payload from session %s", session_id)
message = message[:2000]

# Process uploaded files
processed_files: List[FileAttachment] = []
Expand Down
8 changes: 5 additions & 3 deletions chatbot-core/api/services/chat_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ def get_chatbot_reply(

memory = get_session(session_id)
if memory is None:
raise RuntimeError(f"Session '{session_id}' not found in the memory store.")
raise RuntimeError(
f"Session '{session_id}' not found in the memory store.")

context = retrieve_context(user_input)
logger.debug("Context retrieved: %s", _sanitize_log_payload(context))
Expand Down Expand Up @@ -341,7 +342,8 @@ def _execute_search_tools(tool_calls) -> str:
})

return "\n\n".join(
f"[Result of the search tool {res['tool']}]:\n{res.get('output', '')}".strip()
f"[Result of the search tool {res['tool']}]:\n{res.get('output', '')}".strip(
)
for res in retrieved_results
)

Expand Down Expand Up @@ -389,7 +391,6 @@ def retrieve_context(user_input: str) -> str:
data_retrieved, _ = get_relevant_documents(
user_input,
EMBEDDING_MODEL,
logger=logger,
source_name="plugins",
top_k=retrieval_config["top_k"]
)
Expand Down Expand Up @@ -548,6 +549,7 @@ def _extract_relevance_score(response: str) -> str:

return relevance_score


def _generate_search_query_from_logs(log_text: str) -> str:
"""
Uses the LLM to extract a concise error signature from the logs
Expand Down
26 changes: 14 additions & 12 deletions chatbot-core/api/services/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
from datetime import datetime, timedelta
from threading import Lock
from typing import Optional
from langchain.memory import ConversationBufferMemory
from langchain.memory import ConversationBufferWindowMemory
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from api.config.loader import CONFIG
from api.services.sessionmanager import(
from api.services.sessionmanager import (
delete_session_file,
load_session,
session_exists_in_json,
Expand Down Expand Up @@ -41,13 +41,13 @@ def init_session() -> str:
session_id = str(uuid.uuid4())
with _lock:
_sessions[session_id] = {
"memory": ConversationBufferMemory(return_messages=True),
"memory": ConversationBufferWindowMemory(k=10, return_messages=True),
"last_accessed": datetime.now()
}
return session_id


def _restore_persisted_message(memory: ConversationBufferMemory, message: object) -> None:
def _restore_persisted_message(memory: ConversationBufferWindowMemory, message: object) -> None:
"""
Restore one persisted message into LangChain memory.

Expand All @@ -71,7 +71,7 @@ def _restore_persisted_message(memory: ConversationBufferMemory, message: object
memory.chat_memory.add_message(message_class(content=content))


def get_session(session_id: str) -> Optional[ConversationBufferMemory]:
def get_session(session_id: str) -> Optional[ConversationBufferWindowMemory]:
"""
Retrieve the chat session memory for the given session ID.
Lazily restores from disk if missing in memory.
Expand All @@ -80,22 +80,22 @@ def get_session(session_id: str) -> Optional[ConversationBufferMemory]:
session_id (str): The session identifier.

Returns:
Optional[ConversationBufferMemory]: The memory object if found, else None.
Optional[ConversationBufferWindowMemory]: The memory object if found, else None.
"""

with _lock:

session_data = _sessions.get(session_id)

if session_data :
if session_data:
session_data["last_accessed"] = datetime.now()
return session_data["memory"]

history = load_session(session_id)
if not history:
return None

memory = ConversationBufferMemory(return_messages=True)
memory = ConversationBufferWindowMemory(k=10, return_messages=True)
for msg in history:
_restore_persisted_message(memory, msg)

Expand All @@ -106,14 +106,15 @@ def get_session(session_id: str) -> Optional[ConversationBufferMemory]:

return memory

async def get_session_async(session_id: str) -> Optional[ConversationBufferMemory]:

async def get_session_async(session_id: str) -> Optional[ConversationBufferWindowMemory]:
    """
    Awaitable counterpart of get_session.

    The synchronous get_session call is handed to a worker thread through
    asyncio.to_thread, so the event loop is not blocked while it runs.

    Args:
        session_id (str): The session identifier.

    Returns:
        Optional[ConversationBufferWindowMemory]: Whatever get_session
        returns for this session ID.
    """
    # Run the synchronous lookup off the event loop thread.
    memory = await asyncio.to_thread(get_session, session_id)
    return memory


def persist_session(session_id: str)-> None:
def persist_session(session_id: str) -> None:
"""
Persist the current session messages to disk.

Expand All @@ -129,7 +130,6 @@ def persist_session(session_id: str)-> None:
append_message(session_id, messages)



def delete_session(session_id: str) -> bool:
"""
Delete a chat session and its persisted data.
Expand Down Expand Up @@ -207,9 +207,9 @@ def get_last_accessed(session_id: str) -> Optional[datetime]:
if not history:
return None


return history["last_accessed"]


def set_last_accessed(session_id: str, timestamp: datetime) -> bool:
"""
Set the last accessed timestamp for a given session (for testing purposes).
Expand All @@ -236,6 +236,7 @@ def set_last_accessed(session_id: str, timestamp: datetime) -> bool:

return False


def get_session_count() -> int:
"""
Get the total number of active sessions (for testing purposes).
Expand All @@ -246,6 +247,7 @@ def get_session_count() -> int:
with _lock:
return len(_sessions)


def cleanup_expired_sessions() -> int:
"""
Remove sessions that have not been accessed within the configured timeout period.
Expand Down
34 changes: 25 additions & 9 deletions chatbot-core/api/tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"search_community_threads": {"query": str},
})


def get_default_tools_call(query: str):
"""
Returns a default list of tool calls using the user query,
Expand Down Expand Up @@ -62,6 +63,7 @@ def get_default_tools_call(query: str):
}
]


def validate_tool_calls(tool_calls_parsed: list, logger) -> bool:
"""
Validates that each tool call has a valid tool name and matching params.
Expand All @@ -85,16 +87,19 @@ def validate_tool_calls(tool_calls_parsed: list, logger) -> bool:

for param_name, param_type in expected_params.items():
if param_name not in params:
logger.warning("Tool: %s: Param %s is not expected.", tool, param_name)
logger.warning(
"Tool: %s: Param %s is not expected.", tool, param_name)
valid = False
if not isinstance(params[param_name], param_type):
logger.warning("Tool: %s: Param %s is not of the expected type %s.",
tool, param_name, param_type.__name__)
tool, param_name, param_type.__name__)
valid = False

return valid

# pylint: disable=too-many-locals


def get_inverted_scores(
semantic_chunk_ids: List[str],
semantic_scores: List[float],
Expand Down Expand Up @@ -122,18 +127,20 @@ def get_inverted_scores(
"""
if not 0 <= semantic_weight <= 1:
semantic_weight = 0.5
semantic_map = {semantic_chunk_ids[i]:semantic_scores[i]
semantic_map = {semantic_chunk_ids[i]: semantic_scores[i]
for i in range(len(semantic_chunk_ids))}
keyword_map = {keyword_chunk_ids[i]:keyword_scores[i]
keyword_map = {keyword_chunk_ids[i]: keyword_scores[i]
for i in range(len(keyword_chunk_ids))}

all_chunk_ids = set(semantic_map.keys()).union(keyword_map.keys())

default_keyword = min(keyword_map.values()) if keyword_map else 0
default_semantic = max(semantic_map.values()) if semantic_map else 1.5

keyword_vals = [keyword_map.get(cid, default_keyword) for cid in all_chunk_ids]
semantic_vals = [semantic_map.get(cid, default_semantic) for cid in all_chunk_ids]
keyword_vals = [keyword_map.get(cid, default_keyword)
for cid in all_chunk_ids]
semantic_vals = [semantic_map.get(cid, default_semantic)
for cid in all_chunk_ids]

keyword_norm = _min_max_normalize(keyword_vals)
sem_max = max(semantic_vals) if semantic_vals else 1.0
Expand All @@ -146,6 +153,7 @@ def get_inverted_scores(
for i, cid in enumerate(all_chunk_ids)
]


def _min_max_normalize(values: List[float]) -> List[float]:
"""
Normalize a list of floats to [0, 1].
Expand All @@ -160,6 +168,7 @@ def _min_max_normalize(values: List[float]) -> List[float]:
rng = vmax - vmin
return [(v - vmin) / rng for v in values]


def extract_chunks_content(chunks: List[Dict], logger) -> str:
"""
Builds a single context string from a list of chunks by replacing code block
Expand All @@ -177,7 +186,8 @@ def extract_chunks_content(chunks: List[Dict], logger) -> str:
item_id = item.get("id", "")
text = item.get("chunk_text", "")
if not item_id:
logger.warning("Id of retrieved context not found. Skipping element.")
logger.warning(
"Id of retrieved context not found. Skipping element.")
continue
if text:
code_iter = iter(item.get("code_blocks", []))
Expand All @@ -193,6 +203,7 @@ def extract_chunks_content(chunks: List[Dict], logger) -> str:
else retrieval_config["empty_context_message"]
)


def is_valid_plugin(plugin_name: str) -> bool:
"""
Checks whether the given plugin name exists in the list of known plugin names.
Expand All @@ -217,6 +228,7 @@ def tokenize(item: str) -> str:

return False


def filter_retrieved_data(
semantic_data: List[Dict],
keyword_data: List[Dict],
Expand Down Expand Up @@ -245,6 +257,7 @@ def tokenize(item: str) -> str:

return semantic_filtered_data, keyword_filtered_data


def make_placeholder_replacer(code_iter, item_id, logger):
"""
Returns a function to replace code block placeholders in retrieved text
Expand All @@ -261,10 +274,12 @@ def replace(_match):
try:
return next(code_iter)
except StopIteration:
logger.warning("More placeholders than code blocks in chunk with ID %s", item_id)
logger.warning(
"More placeholders than code blocks in chunk with ID %s", item_id)
return "[MISSING_CODE]"
return replace


def retrieve_documents(query: str, keywords: str, logger, source_name: str, embedding_model):
"""
Retrieve documents using both semantic and keyword-based methods.
Expand All @@ -282,7 +297,6 @@ def retrieve_documents(query: str, keywords: str, logger, source_name: str, embe
data_retrieved_semantic, scores_semantic = get_relevant_documents(
query,
embedding_model,
logger=logger,
source_name=source_name,
top_k=retrieval_config["top_k_semantic"]
)
Expand All @@ -302,6 +316,8 @@ def retrieve_documents(query: str, keywords: str, logger, source_name: str, embe

# pylint: disable=too-many-arguments
# pylint: disable=too-many-positional-arguments


def extract_top_chunks(
data_retrieved_semantic,
scores_semantic,
Expand Down
2 changes: 2 additions & 0 deletions chatbot-core/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
"""Top-level pytest configuration and plugins."""
# Plugin modules pytest should load for the whole test run, declared here in
# the top-level conftest. NOTE(review): the module path suggests it provides
# shared unit-test environment mocks — confirm against tests/unit/mocks.
pytest_plugins = ["tests.unit.mocks.test_env"]
Loading
Loading