Skip to content

Commit 785d54e

Browse files
authored
Fix the workflow: the old flow no longer made sense with our current logic. (#48)
1 parent 259a3ce commit 785d54e

File tree

2 files changed

+29
-33
lines changed

2 files changed

+29
-33
lines changed

RAGManager/app/agents/graph.py

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
context_builder,
88
fallback_final,
99
fallback_inicial,
10-
generator,
1110
guard_final,
1211
guard_inicial,
1312
parafraseo,
@@ -18,25 +17,24 @@
1817
route_after_guard_inicial,
1918
)
2019
from app.agents.state import AgentState
21-
from app.agents.routing import route_after_guard
2220

2321
def create_agent_graph() -> StateGraph:
2422
"""
2523
Create and configure the LangGraph agent graph.
2624
2725
The graph implements the following flow:
2826
1. START -> agent_host (Nodo 1) - Prepares state, no DB operations
29-
2. agent_host -> guard (Nodo 2) - Validates for malicious content
30-
3. guard -> [conditional]:
31-
- malicious -> fallback -> END (stops processing, no DB save)
27+
2. agent_host -> guard_inicial (Nodo 2) - Validates for malicious content
28+
3. guard_inicial -> [conditional]:
29+
- malicious -> fallback_inicial -> END (stops processing, no DB save)
3230
- continue -> parafraseo (Nodo 4)
3331
4. parafraseo -> Saves message to DB, retrieves chat history, paraphrases
34-
5. parafraseo -> retriever (Nodo 5)
35-
6. retriever -> context_builder (Nodo 6)
36-
7. context_builder -> guard (validates response)
37-
8. guard -> [conditional]:
38-
- malicious -> fallback -> END
39-
- continue -> END (success)
32+
5. parafraseo -> retriever (Nodo 5) - Retrieves relevant chunks from vector DB
33+
6. retriever -> context_builder (Nodo 6) - Builds enriched query and generates response
34+
7. context_builder -> guard_final (Nodo 8) - Validates response for risky content
35+
8. guard_final -> [conditional]:
36+
- risky -> fallback_final -> END
37+
- continue -> END (success, returns generated message)
4038
4139
Returns:
4240
Configured StateGraph instance ready for execution
@@ -51,7 +49,6 @@ def create_agent_graph() -> StateGraph:
5149
workflow.add_node("parafraseo", parafraseo)
5250
workflow.add_node("retriever", retriever)
5351
workflow.add_node("context_builder", context_builder)
54-
workflow.add_node("generator", generator)
5552
workflow.add_node("guard_final", guard_final)
5653
workflow.add_node("fallback_final", fallback_final)
5754

@@ -81,11 +78,8 @@ def create_agent_graph() -> StateGraph:
8178
# retriever -> context_builder
8279
workflow.add_edge("retriever", "context_builder")
8380

84-
# context_builder -> guard
85-
workflow.add_edge("context_builder", "guard")
86-
87-
# generator -> guard_final
88-
workflow.add_edge("generator", "guard_final")
81+
# context_builder -> guard_final
82+
workflow.add_edge("context_builder", "guard_final")
8983

9084
# guard_final -> conditional routing
9185
workflow.add_conditional_edges(

RAGManager/app/agents/nodes/retriever.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -105,33 +105,35 @@ def retriever(state: AgentState) -> AgentState:
105105
Retriever node - Performs semantic search in vector database using LangChain PGVector.
106106
107107
This node:
108-
1. Takes 3 paraphrased phrases from parafraseo node
109-
2. For each phrase, uses PGVector's similarity_search to query the database
110-
3. Retrieves top 3 most relevant chunks per phrase
108+
1. Takes 3 paraphrased statements from parafraseo node
109+
2. For each statement, uses PGVector's similarity_search to query the database
110+
3. Retrieves top 3 most relevant chunks per statement
111111
4. Creates a unique union of all retrieved chunks (no duplicates)
112112
5. Stores the chunk contents in relevant_chunks
113113
114114
Args:
115-
state: Agent state containing paraphrased_phrases (list of 3 phrases)
115+
state: Agent state containing paraphrased_statements (list of 3 statements)
116116
117117
Returns:
118118
Updated state with relevant_chunks set (list of unique chunk contents)
119119
"""
120120
updated_state = state.copy()
121121

122-
# Get the 3 paraphrased phrases from state
123-
paraphrased_phrases = state.get("paraphrased_phrases")
122+
# Get the 3 paraphrased statements from state
123+
paraphrased_statements = state.get("paraphrased_statements")
124124

125-
if not paraphrased_phrases or len(paraphrased_phrases) == 0:
125+
if not paraphrased_statements or len(paraphrased_statements) == 0:
126126
logger.warning(
127-
"No paraphrased phrases found in state. Parafraseo node may not have been executed yet."
127+
"No paraphrased statements found in state. Parafraseo node may not have been executed yet."
128128
)
129129
updated_state["relevant_chunks"] = []
130130
return updated_state
131131

132-
# Ensure we have exactly 3 phrases (or at least handle what we get)
133-
phrases_to_process = paraphrased_phrases[:3] if len(paraphrased_phrases) >= 3 else paraphrased_phrases
134-
logger.info(f"Retrieving documents for {len(phrases_to_process)} phrases")
132+
# Ensure we have exactly 3 statements (or at least handle what we get)
133+
statements_to_process = (
134+
paraphrased_statements[:3] if len(paraphrased_statements) >= 3 else paraphrased_statements
135+
)
136+
logger.info(f"Retrieving documents for {len(statements_to_process)} statements")
135137

136138
# Use a set to track unique chunk IDs to avoid duplicates
137139
seen_chunk_ids: set[str] = set()
@@ -141,10 +143,10 @@ def retriever(state: AgentState) -> AgentState:
141143
# Get PGVector instance
142144
vector_store = _get_vector_store()
143145

144-
# For each phrase, retrieve top 3 chunks using PGVector's similarity_search
145-
for phrase in phrases_to_process:
146-
logger.debug(f"Retrieving chunks for phrase: {phrase[:50]}...")
147-
chunks = _retrieve_chunks_for_phrase(vector_store, phrase, top_k=3)
146+
# For each statement, retrieve top 3 chunks using PGVector's similarity_search
147+
for statement in statements_to_process:
148+
logger.debug(f"Retrieving chunks for statement: {statement[:50]}...")
149+
chunks = _retrieve_chunks_for_phrase(vector_store, statement, top_k=3)
148150

149151
# Add chunks to unique list (avoiding duplicates by chunk ID)
150152
for chunk_id, chunk_content in chunks:
@@ -153,7 +155,7 @@ def retriever(state: AgentState) -> AgentState:
153155
unique_chunks.append(chunk_content)
154156
logger.debug(f"Added chunk {chunk_id} to results")
155157

156-
logger.info(f"Retrieved {len(unique_chunks)} unique chunks from {len(phrases_to_process)} phrases")
158+
logger.info(f"Retrieved {len(unique_chunks)} unique chunks from {len(statements_to_process)} statements")
157159
except Exception as e:
158160
logger.error(f"Error during retrieval: {e}", exc_info=True)
159161
unique_chunks = []

0 commit comments

Comments
 (0)