Skip to content

Commit 1dbcbe5

Browse files
authored
Separated guard and fallback logic into initial and final. (#32)
* Separated guard and fallback logic into initial and final. * Edited logging of PII-containing responses and jailbreak-attempting prompts.
1 parent 8584c86 commit 1dbcbe5

File tree

9 files changed

+219
-30
lines changed

9 files changed

+219
-30
lines changed

RAGManager/app/agents/graph.py

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,18 @@
55
from app.agents.nodes import (
66
agent_host,
77
context_builder,
8-
fallback,
9-
guard,
8+
fallback_final,
9+
fallback_inicial,
10+
generator,
11+
guard_final,
12+
guard_inicial,
1013
parafraseo,
1114
retriever,
1215
)
16+
from app.agents.routing import (
17+
route_after_guard_final,
18+
route_after_guard_inicial,
19+
)
1320
from app.agents.state import AgentState
1421
from app.agents.routing import route_after_guard
1522

@@ -39,29 +46,35 @@ def create_agent_graph() -> StateGraph:
3946

4047
# Add nodes
4148
workflow.add_node("agent_host", agent_host)
42-
workflow.add_node("guard", guard)
43-
workflow.add_node("fallback", fallback)
49+
workflow.add_node("guard_inicial", guard_inicial)
50+
workflow.add_node("fallback_inicial", fallback_inicial)
4451
workflow.add_node("parafraseo", parafraseo)
4552
workflow.add_node("retriever", retriever)
4653
workflow.add_node("context_builder", context_builder)
54+
workflow.add_node("generator", generator)
55+
workflow.add_node("guard_final", guard_final)
56+
workflow.add_node("fallback_final", fallback_final)
4757

4858
# Define edges
4959
# Start -> agent_host
5060
workflow.add_edge(START, "agent_host")
5161

52-
# agent_host -> guard
53-
workflow.add_edge("agent_host", "guard")
62+
# agent_host -> guard_inicial
63+
workflow.add_edge("agent_host", "guard_inicial")
5464

55-
# guard -> conditional routing
65+
# guard_inicial -> conditional routing
5666
workflow.add_conditional_edges(
57-
"guard",
58-
route_after_guard,
67+
"guard_inicial",
68+
route_after_guard_inicial,
5969
{
60-
"malicious": "fallback", # go to fallback if malicious
61-
"continue": "parafraseo", # Continue to parafraseo if valid
70+
"malicious": "fallback_inicial", # Exception path: malicious content detected
71+
"continue": "parafraseo", # Normal path: continue processing
6272
},
6373
)
6474

75+
# fallback_inicial -> END (stop flow with error message)
76+
workflow.add_edge("fallback_inicial", END)
77+
6578
# parafraseo -> retriever
6679
workflow.add_edge("parafraseo", "retriever")
6780

@@ -71,15 +84,21 @@ def create_agent_graph() -> StateGraph:
7184
# context_builder -> guard
7285
workflow.add_edge("context_builder", "guard")
7386

74-
# guard -> conditional routing
87+
# generator -> guard_final
88+
workflow.add_edge("generator", "guard_final")
89+
90+
# guard_final -> conditional routing
7591
workflow.add_conditional_edges(
76-
"guard",
77-
route_after_guard,
92+
"guard_final",
93+
route_after_guard_final,
7894
{
79-
"malicious": "fallback", # go to fallback if malicious
80-
"continue": END, # if there's no error ends
95+
"risky": "fallback_final", # Exception path: risky content detected
96+
"continue": END, # Normal path: end successfully
8197
},
8298
)
83-
workflow.add_edge("fallback", END)
99+
100+
# fallback_final -> END (stop flow with error message)
101+
workflow.add_edge("fallback_final", END)
102+
84103
# Compile the graph
85104
return workflow.compile()

RAGManager/app/agents/nodes/__init__.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,19 @@
22

33
from app.agents.nodes.agent_host import agent_host
44
from app.agents.nodes.context_builder import context_builder
5-
from app.agents.nodes.fallback import fallback
6-
from app.agents.nodes.guard import guard
5+
from app.agents.nodes.fallback_final import fallback_final
6+
from app.agents.nodes.fallback_inicial import fallback_inicial
7+
from app.agents.nodes.generator import generator
8+
from app.agents.nodes.guard_final import guard_final
9+
from app.agents.nodes.guard_inicial import guard_inicial
710
from app.agents.nodes.parafraseo import parafraseo
811
from app.agents.nodes.retriever import retriever
912

1013
__all__ = [
1114
"agent_host",
12-
"guard",
13-
"fallback",
15+
"guard_inicial",
16+
"guard_final",
17+
"fallback_inicial",
1418
"parafraseo",
1519
"retriever",
1620
"context_builder",
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
"""Nodo 8: Fallback Final - Stops processing when risky content is detected."""
2+
3+
import logging
4+
5+
from app.agents.state import AgentState
6+
7+
logger = logging.getLogger(__name__)
8+
9+
10+
def fallback_final(state: AgentState) -> AgentState:
    """Terminal fallback for the final guard: mark the run as blocked.

    Copies the incoming state, stores a user-facing error message saying
    the requested information cannot be disclosed, and hands the state
    back so the graph can route to END. The flagged response content is
    intentionally kept out of the logs.

    Args:
        state: Agent state whose generated response was judged risky.

    Returns:
        A copy of ``state`` with ``error_message`` populated, ready for END.
    """
    # Work on a shallow copy so the caller's state object is untouched.
    result = state.copy()
    result["error_message"] = "The information requested is classified or not free to know."

    # Deliberately omit the response body from the log entry.
    logger.warning("Risky content detected. Stopping processing. Response content not logged for security.")
    return result
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
"""Nodo 3: Fallback Inicial - Stops processing when malicious content is detected."""
2+
3+
import logging
4+
5+
from app.agents.state import AgentState
6+
7+
logger = logging.getLogger(__name__)
8+
9+
10+
def fallback_inicial(state: AgentState) -> AgentState:
    """Terminal fallback for the initial guard: mark the run as blocked.

    Copies the incoming state, stores a user-facing error message saying
    the user's request violates the chatbot's rules, and hands the state
    back so the graph can route to END. The flagged prompt content is
    intentionally kept out of the logs.

    Args:
        state: Agent state whose prompt was judged malicious.

    Returns:
        A copy of ``state`` with ``error_message`` populated, ready for END.
    """
    # Work on a shallow copy so the caller's state object is untouched.
    result = state.copy()
    result["error_message"] = "The user's intentions break the chatbot's rules."

    # Deliberately omit the prompt body from the log entry.
    logger.warning("Malicious content detected. Stopping processing. Prompt content not logged for security.")
    return result
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
"""Nodo Guard Final - Validates generated response for PII (risky information detection)."""
2+
3+
import logging
4+
5+
from guardrails import Guard
6+
from guardrails.hub import DetectPII
7+
8+
from app.agents.state import AgentState
9+
from app.core.config import settings
10+
11+
logger = logging.getLogger(__name__)
12+
13+
# Initialize Guard with DetectPII validator
14+
# Note: The validator must be installed via: guardrails hub install hub://guardrails/detect_pii
15+
_guard_final = Guard().use(
16+
DetectPII(
17+
pii_entities=settings.guardrails_pii_entities,
18+
on_fail="noop", # Don't raise exceptions, handle via state flags
19+
)
20+
)
21+
22+
23+
def guard_final(state: AgentState) -> AgentState:
    """
    Guard final node - Validates the generated response for PII using Guardrails DetectPII.

    This node:
    1. Validates ``generated_response`` with the module-level ``_guard_final`` guard
    2. Sets the ``is_risky`` flag if PII is detected
    3. Sets ``error_message`` when the response is flagged as risky

    Args:
        state: Agent state containing the ``generated_response`` to check

    Returns:
        Updated state with ``is_risky`` and ``error_message`` set
    """
    updated_state = state.copy()
    generated_response = state.get("generated_response", "")

    if not generated_response:
        # An empty response cannot contain PII, so it is considered safe.
        updated_state["is_risky"] = False
        updated_state["error_message"] = None
        return updated_state

    try:
        # Validate the generated response using Guardrails DetectPII.
        validation_result = _guard_final.validate(generated_response)

        # The guard returns a ValidationResult whose validation_passed flag
        # indicates whether any configured PII entity was found.
        if validation_result.validation_passed:
            updated_state["is_risky"] = False
            updated_state["error_message"] = None
            logger.debug("Generated response passed PII detection")
        else:
            # PII detected: flag the response so routing sends it to fallback_final.
            updated_state["is_risky"] = True
            updated_state["error_message"] = (
                "PII detected in generated response. The information requested is classified or not free to know."
            )
            logger.warning("PII detected in generated response. Response content not logged for security.")

    except Exception:
        # Deliberate fail-open: if the validator itself errors we let the
        # response through rather than blocking the request, but record the
        # failure (with traceback) for monitoring. logger.exception is used
        # instead of an eagerly-formatted logger.error(f"...") so the stack
        # trace is preserved and formatting stays lazy.
        # NOTE(review): fail-open weakens the PII guard when Guardrails is
        # misconfigured or unavailable — confirm this trade-off is intended.
        logger.exception("Error during PII detection")
        updated_state["is_risky"] = False
        updated_state["error_message"] = None

    return updated_state

RAGManager/app/agents/nodes/guard.py renamed to RAGManager/app/agents/nodes/guard_inicial.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Nodo 2: Guard - Validates for malicious content."""
1+
"""Nodo 2: Guard Inicial - Validates for malicious content (jailbreak detection)."""
22

33
import logging
44

@@ -12,7 +12,7 @@
1212

1313
# Initialize Guard with DetectJailbreak validator
1414
# Note: The validator must be installed via: guardrails hub install hub://guardrails/detect_jailbreak
15-
_guard = Guard().use(
15+
_guard_inicial = Guard().use(
1616
DetectJailbreak(
1717
threshold=settings.guardrails_jailbreak_threshold,
1818
device=settings.guardrails_device,
@@ -21,9 +21,9 @@
2121
)
2222

2323

24-
def guard(state: AgentState) -> AgentState:
24+
def guard_inicial(state: AgentState) -> AgentState:
2525
"""
26-
Guard node - Validates user input for malicious content using Guardrails DetectJailbreak.
26+
Guard inicial node - Validates user input for jailbreak attempts using Guardrails DetectJailbreak.
2727
2828
This node:
2929
1. Validates the prompt using Guardrails DetectJailbreak validator
@@ -49,7 +49,7 @@ def guard(state: AgentState) -> AgentState:
4949

5050
try:
5151
# Validate the prompt using Guardrails
52-
validation_result = _guard.validate(prompt)
52+
validation_result = _guard_inicial.validate(prompt)
5353

5454
# Check if validation passed
5555
# The validator returns ValidationResult with outcome
@@ -64,7 +64,7 @@ def guard(state: AgentState) -> AgentState:
6464
updated_state["error_message"] = (
6565
"Jailbreak attempt detected. Your request contains content that violates security policies."
6666
)
67-
logger.warning("Jailbreak attempt detected in prompt (len=%d)", len(prompt))
67+
logger.warning("Jailbreak attempt detected. Prompt content not logged for security.")
6868

6969
except Exception as e:
7070
# If validation fails due to error, log it but don't block the request

RAGManager/app/agents/routing.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
from app.agents.state import AgentState
44

55

6-
def route_after_guard(state: AgentState) -> str:
6+
def route_after_guard_inicial(state: AgentState) -> str:
77
"""
8-
Route after Guard node (Nodo 2) validation.
8+
Route after Guard Inicial node validation.
99
1010
Determines the next step based on whether the prompt was flagged as malicious.
1111
@@ -18,3 +18,20 @@ def route_after_guard(state: AgentState) -> str:
1818
if state.get("is_malicious", False):
1919
return "malicious"
2020
return "continue"
21+
22+
23+
def route_after_guard_final(state: AgentState) -> str:
    """Conditional-edge router that runs after the Guard Final node.

    Inspects the ``is_risky`` flag left in the state by ``guard_final``
    and picks the next edge label for the graph.

    Args:
        state: Current agent state.

    Returns:
        ``"risky"`` when the response was flagged as risky,
        ``"continue"`` otherwise.
    """
    return "risky" if state.get("is_risky", False) else "continue"

RAGManager/app/core/config.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,19 @@ class Settings(BaseSettings):
4747
default="cpu",
4848
description="Device for model inference.",
4949
)
50+
guardrails_pii_entities: list[str] = Field(
51+
default=[
52+
"EMAIL_ADDRESS",
53+
"PHONE_NUMBER",
54+
"CREDIT_CARD",
55+
"SSN",
56+
"US_PASSPORT",
57+
"US_DRIVER_LICENSE",
58+
"IBAN_CODE",
59+
"IP_ADDRESS",
60+
],
61+
description="List of PII entity types to detect using DetectPII validator.",
62+
)
5063
model_config = SettingsConfigDict(
5164
env_file=".env",
5265
env_file_encoding="utf-8",

RAGManager/pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ dependencies = [
1717
"pydantic-settings>=2.0.0",
1818
"typing-extensions>=4.15.0",
1919
"uvicorn>=0.38.0",
20-
"guardrails-ai>=0.5.10",
20+
"guardrails-ai>=0.6.2",
21+
"presidio-analyzer>=2.2.360",
22+
"presidio-anonymizer>=2.2.360",
2123
]
2224

2325
[project.optional-dependencies]

0 commit comments

Comments
 (0)