diff --git a/RAGManager/app/agents/graph.py b/RAGManager/app/agents/graph.py index 52825a5..8916ea5 100644 --- a/RAGManager/app/agents/graph.py +++ b/RAGManager/app/agents/graph.py @@ -7,7 +7,6 @@ context_builder, fallback_final, fallback_inicial, - generator, guard_final, guard_inicial, parafraseo, @@ -18,7 +17,6 @@ route_after_guard_inicial, ) from app.agents.state import AgentState -from app.agents.routing import route_after_guard def create_agent_graph() -> StateGraph: """ @@ -26,17 +24,17 @@ def create_agent_graph() -> StateGraph: The graph implements the following flow: 1. START -> agent_host (Nodo 1) - Prepares state, no DB operations - 2. agent_host -> guard (Nodo 2) - Validates for malicious content - 3. guard -> [conditional]: - - malicious -> fallback -> END (stops processing, no DB save) + 2. agent_host -> guard_inicial (Nodo 2) - Validates for malicious content + 3. guard_inicial -> [conditional]: + - malicious -> fallback_inicial -> END (stops processing, no DB save) - continue -> parafraseo (Nodo 4) 4. parafraseo -> Saves message to DB, retrieves chat history, paraphrases - 5. parafraseo -> retriever (Nodo 5) - 6. retriever -> context_builder (Nodo 6) - 7. context_builder -> guard (validates response) - 8. guard -> [conditional]: - - malicious -> fallback -> END - - continue -> END (success) + 5. parafraseo -> retriever (Nodo 5) - Retrieves relevant chunks from vector DB + 6. retriever -> context_builder (Nodo 6) - Builds enriched query and generates response + 7. context_builder -> guard_final (Nodo 8) - Validates response for risky content + 8. guard_final -> [conditional]: + - risky -> fallback_final -> END + - continue -> END (success, returns generated message) Returns: Configured StateGraph instance ready for execution @@ -51,7 +49,6 @@ def create_agent_graph() -> StateGraph: workflow.add_node("parafraseo", parafraseo) workflow.add_node("retriever", retriever) workflow.add_node("context_builder", context_builder) - workflow.add_node("generator", generator) workflow.add_node("guard_final", guard_final) workflow.add_node("fallback_final", fallback_final) @@ -81,11 +78,8 @@ def create_agent_graph() -> StateGraph: # retriever -> context_builder workflow.add_edge("retriever", "context_builder") - # context_builder -> guard - workflow.add_edge("context_builder", "guard") - - # generator -> guard_final - workflow.add_edge("generator", "guard_final") + # context_builder -> guard_final + workflow.add_edge("context_builder", "guard_final") # guard_final -> conditional routing workflow.add_conditional_edges( diff --git a/RAGManager/app/agents/nodes/retriever.py b/RAGManager/app/agents/nodes/retriever.py index 8c7d053..0f2ec42 100644 --- a/RAGManager/app/agents/nodes/retriever.py +++ b/RAGManager/app/agents/nodes/retriever.py @@ -105,33 +105,35 @@ def retriever(state: AgentState) -> AgentState: Retriever node - Performs semantic search in vector database using LangChain PGVector. This node: - 1. Takes 3 paraphrased phrases from parafraseo node - 2. For each phrase, uses PGVector's similarity_search to query the database - 3. Retrieves top 3 most relevant chunks per phrase + 1. Takes 3 paraphrased statements from parafraseo node + 2. For each statement, uses PGVector's similarity_search to query the database + 3. Retrieves top 3 most relevant chunks per statement 4. Creates a unique union of all retrieved chunks (no duplicates) 5. Stores the chunk contents in relevant_chunks Args: - state: Agent state containing paraphrased_phrases (list of 3 phrases) + state: Agent state containing paraphrased_statements (list of 3 statements) Returns: Updated state with relevant_chunks set (list of unique chunk contents) """ updated_state = state.copy() - # Get the 3 paraphrased phrases from state - paraphrased_phrases = state.get("paraphrased_phrases") + # Get the 3 paraphrased statements from state + paraphrased_statements = state.get("paraphrased_statements") - if not paraphrased_phrases or len(paraphrased_phrases) == 0: + if not paraphrased_statements or len(paraphrased_statements) == 0: logger.warning( - "No paraphrased phrases found in state. Parafraseo node may not have been executed yet." + "No paraphrased statements found in state. Parafraseo node may not have been executed yet." ) updated_state["relevant_chunks"] = [] return updated_state - # Ensure we have exactly 3 phrases (or at least handle what we get) - phrases_to_process = paraphrased_phrases[:3] if len(paraphrased_phrases) >= 3 else paraphrased_phrases - logger.info(f"Retrieving documents for {len(phrases_to_process)} phrases") + # Ensure we have exactly 3 statements (or at least handle what we get) + statements_to_process = ( + paraphrased_statements[:3] if len(paraphrased_statements) >= 3 else paraphrased_statements + ) + logger.info(f"Retrieving documents for {len(statements_to_process)} statements") # Use a set to track unique chunk IDs to avoid duplicates seen_chunk_ids: set[str] = set() @@ -141,10 +143,10 @@ def retriever(state: AgentState) -> AgentState: # Get PGVector instance vector_store = _get_vector_store() - # For each phrase, retrieve top 3 chunks using PGVector's similarity_search - for phrase in phrases_to_process: - logger.debug(f"Retrieving chunks for phrase: {phrase[:50]}...") - chunks = _retrieve_chunks_for_phrase(vector_store, phrase, top_k=3) + # For each statement, retrieve top 3 chunks using PGVector's similarity_search + for statement in statements_to_process: + logger.debug(f"Retrieving chunks for statement: {statement[:50]}...") + chunks = _retrieve_chunks_for_phrase(vector_store, statement, top_k=3) # Add chunks to unique list (avoiding duplicates by chunk ID) for chunk_id, chunk_content in chunks: @@ -153,7 +155,7 @@ def retriever(state: AgentState) -> AgentState: unique_chunks.append(chunk_content) logger.debug(f"Added chunk {chunk_id} to results") - logger.info(f"Retrieved {len(unique_chunks)} unique chunks from {len(phrases_to_process)} phrases") + logger.info(f"Retrieved {len(unique_chunks)} unique chunks from {len(statements_to_process)} statements") except Exception as e: logger.error(f"Error during retrieval: {e}", exc_info=True) unique_chunks = []