Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 11 additions & 17 deletions RAGManager/app/agents/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
context_builder,
fallback_final,
fallback_inicial,
generator,
guard_final,
guard_inicial,
parafraseo,
Expand All @@ -18,25 +17,24 @@
route_after_guard_inicial,
)
from app.agents.state import AgentState
from app.agents.routing import route_after_guard

def create_agent_graph() -> StateGraph:
"""
Create and configure the LangGraph agent graph.

The graph implements the following flow:
1. START -> agent_host (Nodo 1) - Prepares state, no DB operations
2. agent_host -> guard (Nodo 2) - Validates for malicious content
3. guard -> [conditional]:
- malicious -> fallback -> END (stops processing, no DB save)
2. agent_host -> guard_inicial (Nodo 2) - Validates for malicious content
3. guard_inicial -> [conditional]:
- malicious -> fallback_inicial -> END (stops processing, no DB save)
- continue -> parafraseo (Nodo 4)
4. parafraseo - Saves message to DB, retrieves chat history, paraphrases
5. parafraseo -> retriever (Nodo 5)
6. retriever -> context_builder (Nodo 6)
7. context_builder -> guard (validates response)
8. guard -> [conditional]:
- malicious -> fallback -> END
- continue -> END (success)
5. parafraseo -> retriever (Nodo 5) - Retrieves relevant chunks from vector DB
6. retriever -> context_builder (Nodo 6) - Builds enriched query and generates response
7. context_builder -> guard_final (Nodo 8) - Validates response for risky content
8. guard_final -> [conditional]:
- risky -> fallback_final -> END
- continue -> END (success, returns generated message)

Returns:
Configured StateGraph instance ready for execution
Expand All @@ -51,7 +49,6 @@ def create_agent_graph() -> StateGraph:
workflow.add_node("parafraseo", parafraseo)
workflow.add_node("retriever", retriever)
workflow.add_node("context_builder", context_builder)
workflow.add_node("generator", generator)
workflow.add_node("guard_final", guard_final)
workflow.add_node("fallback_final", fallback_final)

Expand Down Expand Up @@ -81,11 +78,8 @@ def create_agent_graph() -> StateGraph:
# retriever -> context_builder
workflow.add_edge("retriever", "context_builder")

# context_builder -> guard
workflow.add_edge("context_builder", "guard")

# generator -> guard_final
workflow.add_edge("generator", "guard_final")
# context_builder -> guard_final
workflow.add_edge("context_builder", "guard_final")

# guard_final -> conditional routing
workflow.add_conditional_edges(
Expand Down
34 changes: 18 additions & 16 deletions RAGManager/app/agents/nodes/retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,33 +105,35 @@ def retriever(state: AgentState) -> AgentState:
Retriever node - Performs semantic search in vector database using LangChain PGVector.

This node:
1. Takes 3 paraphrased phrases from parafraseo node
2. For each phrase, uses PGVector's similarity_search to query the database
3. Retrieves top 3 most relevant chunks per phrase
1. Takes 3 paraphrased statements from parafraseo node
2. For each statement, uses PGVector's similarity_search to query the database
3. Retrieves top 3 most relevant chunks per statement
4. Merges all retrieved chunks into one list, de-duplicated by chunk ID
5. Stores the chunk contents in relevant_chunks

Args:
state: Agent state containing paraphrased_phrases (list of 3 phrases)
state: Agent state containing paraphrased_statements (list of 3 statements)

Returns:
Updated state with relevant_chunks set (list of unique chunk contents)
"""
updated_state = state.copy()

# Get the 3 paraphrased phrases from state
paraphrased_phrases = state.get("paraphrased_phrases")
# Get the 3 paraphrased statements from state
paraphrased_statements = state.get("paraphrased_statements")

if not paraphrased_phrases or len(paraphrased_phrases) == 0:
if not paraphrased_statements or len(paraphrased_statements) == 0:
logger.warning(
"No paraphrased phrases found in state. Parafraseo node may not have been executed yet."
"No paraphrased statements found in state. Parafraseo node may not have been executed yet."
)
updated_state["relevant_chunks"] = []
return updated_state

# Ensure we have exactly 3 phrases (or at least handle what we get)
phrases_to_process = paraphrased_phrases[:3] if len(paraphrased_phrases) >= 3 else paraphrased_phrases
logger.info(f"Retrieving documents for {len(phrases_to_process)} phrases")
# Process at most the first 3 statements; if fewer were provided, use them all
statements_to_process = (
paraphrased_statements[:3] if len(paraphrased_statements) >= 3 else paraphrased_statements
)
logger.info(f"Retrieving documents for {len(statements_to_process)} statements")

# Use a set to track unique chunk IDs to avoid duplicates
seen_chunk_ids: set[str] = set()
Expand All @@ -141,10 +143,10 @@ def retriever(state: AgentState) -> AgentState:
# Get PGVector instance
vector_store = _get_vector_store()

# For each phrase, retrieve top 3 chunks using PGVector's similarity_search
for phrase in phrases_to_process:
logger.debug(f"Retrieving chunks for phrase: {phrase[:50]}...")
chunks = _retrieve_chunks_for_phrase(vector_store, phrase, top_k=3)
# For each statement, retrieve top 3 chunks using PGVector's similarity_search
for statement in statements_to_process:
logger.debug(f"Retrieving chunks for statement: {statement[:50]}...")
chunks = _retrieve_chunks_for_phrase(vector_store, statement, top_k=3)

# Add chunks to unique list (avoiding duplicates by chunk ID)
for chunk_id, chunk_content in chunks:
Expand All @@ -153,7 +155,7 @@ def retriever(state: AgentState) -> AgentState:
unique_chunks.append(chunk_content)
logger.debug(f"Added chunk {chunk_id} to results")

logger.info(f"Retrieved {len(unique_chunks)} unique chunks from {len(phrases_to_process)} phrases")
logger.info(f"Retrieved {len(unique_chunks)} unique chunks from {len(statements_to_process)} statements")
except Exception as e:
logger.error(f"Error during retrieval: {e}", exc_info=True)
unique_chunks = []
Expand Down