Skip to content

Commit 6a5b4ad

Browse files
fix the workflow (#26)
* Fix the workflow: the conditional guard node routes to fallback_inicial if the prompt is malicious, otherwise to parafraseo; the second guard routes to fallback_final if malicious, otherwise the graph ends. * The fallback_final and fallback_inicial logic was merged into a single fallback node — keeping them as two separate nodes was unnecessary. * The state was changed to MessageState so the message history can be used. * Applied Copilot suggestions. --------- Co-authored-by: JPAmorin <juanpabloamorinjusto@gmail.com>
1 parent e7de51f commit 6a5b4ad

File tree

12 files changed

+102
-189
lines changed

12 files changed

+102
-189
lines changed

RAGManager/app/agents/graph.py

Lines changed: 17 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,13 @@
55
from app.agents.nodes import (
66
agent_host,
77
context_builder,
8-
fallback_final,
9-
fallback_inicial,
10-
generator,
8+
fallback,
119
guard,
1210
parafraseo,
1311
retriever,
1412
)
15-
from app.agents.routing import route_after_fallback_final, route_after_guard
1613
from app.agents.state import AgentState
17-
14+
from app.agents.routing import route_after_guard
1815

1916
def create_agent_graph() -> StateGraph:
2017
"""
@@ -23,13 +20,13 @@ def create_agent_graph() -> StateGraph:
2320
The graph implements the following flow:
2421
1. START -> agent_host (Nodo 1)
2522
2. agent_host -> guard (Nodo 2)
26-
3. guard -> [conditional] -> fallback_inicial (Nodo 3) or END
27-
4. fallback_inicial -> parafraseo (Nodo 4)
23+
3. guard -> [conditional] -> fallback (Nodo 3) or END
24+
4. fallback -> parafraseo (Nodo 4)
2825
5. parafraseo -> retriever (Nodo 5)
2926
6. retriever -> context_builder (Nodo 6)
3027
7. context_builder -> generator (Nodo 7)
31-
8. generator -> fallback_final (Nodo 8)
32-
9. fallback_final -> [conditional] -> END (with final_response) or END (with error)
28+
8. generator -> fallback (Nodo 8)
29+
9. fallback -> [conditional] -> END (with final_response) or END (with error)
3330
3431
Returns:
3532
Configured StateGraph instance ready for execution
@@ -40,12 +37,10 @@ def create_agent_graph() -> StateGraph:
4037
# Add nodes
4138
workflow.add_node("agent_host", agent_host)
4239
workflow.add_node("guard", guard)
43-
workflow.add_node("fallback_inicial", fallback_inicial)
40+
workflow.add_node("fallback", fallback)
4441
workflow.add_node("parafraseo", parafraseo)
4542
workflow.add_node("retriever", retriever)
4643
workflow.add_node("context_builder", context_builder)
47-
workflow.add_node("generator", generator)
48-
workflow.add_node("fallback_final", fallback_final)
4944

5045
# Define edges
5146
# Start -> agent_host
@@ -59,37 +54,29 @@ def create_agent_graph() -> StateGraph:
5954
"guard",
6055
route_after_guard,
6156
{
62-
"malicious": END, # End with error if malicious
63-
"continue": "fallback_inicial", # Continue to fallback_inicial if valid
57+
"malicious": "fallback", # go to fallback if malicious
58+
"continue": "parafraseo", # Continue to parafraseo if valid
6459
},
6560
)
6661

67-
# fallback_inicial -> parafraseo
68-
workflow.add_edge("fallback_inicial", "parafraseo")
69-
7062
# parafraseo -> retriever
7163
workflow.add_edge("parafraseo", "retriever")
7264

7365
# retriever -> context_builder
7466
workflow.add_edge("retriever", "context_builder")
7567

76-
# context_builder -> generator
77-
# Note: Primary LLM is called within context_builder node
78-
workflow.add_edge("context_builder", "generator")
68+
# context_builder -> guard
69+
workflow.add_edge("context_builder", "guard")
7970

80-
# generator -> fallback_final
81-
workflow.add_edge("generator", "fallback_final")
82-
83-
# fallback_final -> conditional routing
71+
# guard -> conditional routing
8472
workflow.add_conditional_edges(
85-
"fallback_final",
86-
route_after_fallback_final,
73+
"guard",
74+
route_after_guard,
8775
{
88-
"risky": END, # End with error if risky
89-
"continue": END, # End with final_response if valid
90-
# Note: Final LLM is called within fallback_final node
76+
"malicious": "fallback", # go to fallback if malicious
77+
"continue": END, # if there's no error ends
9178
},
9279
)
93-
80+
workflow.add_edge("fallback", END)
9481
# Compile the graph
9582
return workflow.compile()

RAGManager/app/agents/nodes/__init__.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,16 @@
22

33
from app.agents.nodes.agent_host import agent_host
44
from app.agents.nodes.context_builder import context_builder
5-
from app.agents.nodes.fallback_final import fallback_final
6-
from app.agents.nodes.fallback_inicial import fallback_inicial
7-
from app.agents.nodes.generator import generator
5+
from app.agents.nodes.fallback import fallback
86
from app.agents.nodes.guard import guard
97
from app.agents.nodes.parafraseo import parafraseo
108
from app.agents.nodes.retriever import retriever
119

1210
__all__ = [
1311
"agent_host",
1412
"guard",
15-
"fallback_inicial",
13+
"fallback",
1614
"parafraseo",
1715
"retriever",
1816
"context_builder",
19-
"generator",
20-
"fallback_final",
2117
]

RAGManager/app/agents/nodes/agent_host.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ def agent_host(state: AgentState) -> AgentState:
2929
Updated state with chat_session_id, chat_messages, and initial_context set
3030
"""
3131
updated_state = state.copy()
32-
prompt = state.get("prompt", "")
32+
33+
prompt = state["messages"][-1]
3334
chat_session_id = state.get("chat_session_id")
3435
user_id = state.get("user_id")
3536

RAGManager/app/agents/nodes/context_builder.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
"""Nodo 6: Context Builder - Enriches query with retrieved context."""
22

33
from app.agents.state import AgentState
4+
from langchain_core.messages import SystemMessage
5+
from langchain_openai import ChatOpenAI
6+
7+
llm = ChatOpenAI(model="gpt-5-nano")
48

59

610
def context_builder(state: AgentState) -> AgentState:
@@ -31,13 +35,18 @@ def context_builder(state: AgentState) -> AgentState:
3135
paraphrased = state.get("paraphrased_text", "")
3236
chunks = state.get("relevant_chunks", [])
3337

34-
# Build enriched query
35-
context_section = "\n\n".join(chunks) if chunks else ""
36-
enriched_query = f"{paraphrased}\n\nContext:\n{context_section}" if context_section else paraphrased
37-
updated_state["enriched_query"] = enriched_query
38+
# Build enriched query with context
39+
context_section = "\n\n".join(chunks) if chunks else "No relevant context found."
40+
41+
system_content = f"""You are a helpful assistant. Use the following context to answer the user's question.
42+
If the answer is not in the context, say you don't know.
43+
44+
Context:
45+
{context_section}"""
46+
47+
messages = [SystemMessage(content=system_content)] + state["messages"]
3848

39-
# TODO: Call Primary LLM here
40-
# updated_state["primary_response"] = call_primary_llm(enriched_query)
41-
updated_state["primary_response"] = None
49+
# Call Primary LLM
50+
response = llm.invoke(messages)
4251

43-
return updated_state
52+
return {"messages": [response]}
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
"""Nodo 3: Fallback - Handles fallback processing from multiple workflow points."""
2+
3+
import logging
4+
5+
from app.agents.state import AgentState
6+
from langchain_core.messages import SystemMessage
7+
from langchain_openai import ChatOpenAI
8+
9+
logger = logging.getLogger(__name__)
10+
11+
llm = ChatOpenAI(
12+
model="gpt-5-nano",
13+
)
14+
15+
# TO DO: implementar clase nodo fallback y inicializar el llm en el init
16+
def fallback(state: AgentState) -> AgentState:
17+
"""
18+
Fallback node - Performs fallback processing.
19+
20+
This node:
21+
1. Alerts about malicious prompt
22+
2. Generates an error_message from llm to show the user
23+
24+
Args:
25+
state: Agent state containing the prompt or initial context
26+
27+
Returns:
28+
error_message
29+
"""
30+
31+
logger.warning(
32+
"Defensive check triggered: Malicious prompt detected"
33+
)
34+
35+
messages = [
36+
SystemMessage(
37+
content="Your job is to generate an error message in user's language for the user explaining the database doesn't have the information to respond what the user asked"
38+
)
39+
] + state["messages"]
40+
error_message = llm.invoke(messages)
41+
return {"messages": [error_message]}
42+

RAGManager/app/agents/nodes/fallback_final.py

Lines changed: 0 additions & 40 deletions
This file was deleted.

RAGManager/app/agents/nodes/fallback_inicial.py

Lines changed: 0 additions & 48 deletions
This file was deleted.

RAGManager/app/agents/nodes/generator.py

Lines changed: 0 additions & 31 deletions
This file was deleted.

RAGManager/app/agents/nodes/guard.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@ def guard(state: AgentState) -> AgentState:
3737
Updated state with is_malicious and error_message set
3838
"""
3939
updated_state = state.copy()
40-
prompt = state.get("prompt", "")
40+
messages = state.get("messages", [])
41+
last_message = messages[-1] if messages else None
42+
prompt = last_message.content if last_message else ""
4143

4244
if not prompt:
4345
# Empty prompt is considered safe
@@ -62,7 +64,7 @@ def guard(state: AgentState) -> AgentState:
6264
updated_state["error_message"] = (
6365
"Jailbreak attempt detected. Your request contains content that violates security policies."
6466
)
65-
logger.warning(f"Jailbreak attempt detected in prompt: {prompt[:100]}...")
67+
logger.warning("Jailbreak attempt detected in prompt (len=%d)", len(prompt))
6668

6769
except Exception as e:
6870
# If validation fails due to error, log it but don't block the request

RAGManager/app/agents/nodes/parafraseo.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
"""Nodo 4: Parafraseo - Paraphrases user input."""
22

33
from app.agents.state import AgentState
4+
from langchain_core.messages import SystemMessage
5+
from langchain_openai import ChatOpenAI
6+
7+
llm = ChatOpenAI(model="gpt-5-nano")
48

59

610
def parafraseo(state: AgentState) -> AgentState:
@@ -24,9 +28,16 @@ def parafraseo(state: AgentState) -> AgentState:
2428
# 2. Improve clarity, adjust tone, or format as needed
2529
# 3. Set paraphrased_text with the result
2630

27-
# Placeholder: For now, we'll use the adjusted_text as-is
28-
updated_state = state.copy()
29-
text_to_paraphrase = state.get("adjusted_text") or state.get("prompt", "")
30-
updated_state["paraphrased_text"] = text_to_paraphrase
31+
# Paraphrase the last message using history
32+
33+
system_instruction = """You are an expert at paraphrasing user questions to be standalone and clear, given the conversation history.
34+
Reformulate the last user message to be a self-contained query that includes necessary context from previous messages.
35+
Do not answer the question, just rewrite it."""
36+
37+
messages = [SystemMessage(content=system_instruction)] + state["messages"]
38+
39+
response = llm.invoke(messages)
40+
updated_state = state.copy() # Create a copy of the state to update
41+
updated_state["paraphrased_text"] = response.content
3142

3243
return updated_state

0 commit comments

Comments
 (0)