diff --git a/RAGManager/app/agents/graph.py b/RAGManager/app/agents/graph.py index 9c869bd..52825a5 100644 --- a/RAGManager/app/agents/graph.py +++ b/RAGManager/app/agents/graph.py @@ -5,11 +5,17 @@ from app.agents.nodes import ( agent_host, context_builder, - fallback, - guard, + fallback_final, + fallback_inicial, + generator, + guard_final, + guard_inicial, parafraseo, retriever, ) +from app.agents.routing import ( + route_after_guard_final, + route_after_guard_inicial, +) from app.agents.state import AgentState -from app.agents.routing import route_after_guard @@ -39,29 +45,35 @@ def create_agent_graph() -> StateGraph: # Add nodes workflow.add_node("agent_host", agent_host) - workflow.add_node("guard", guard) - workflow.add_node("fallback", fallback) + workflow.add_node("guard_inicial", guard_inicial) + workflow.add_node("fallback_inicial", fallback_inicial) workflow.add_node("parafraseo", parafraseo) workflow.add_node("retriever", retriever) workflow.add_node("context_builder", context_builder) + workflow.add_node("generator", generator) + workflow.add_node("guard_final", guard_final) + workflow.add_node("fallback_final", fallback_final) # Define edges # Start -> agent_host workflow.add_edge(START, "agent_host") - # agent_host -> guard - workflow.add_edge("agent_host", "guard") + # agent_host -> guard_inicial + workflow.add_edge("agent_host", "guard_inicial") - # guard -> conditional routing + # guard_inicial -> conditional routing workflow.add_conditional_edges( - "guard", - route_after_guard, + "guard_inicial", + route_after_guard_inicial, { - "malicious": "fallback", # go to fallback if malicious - "continue": "parafraseo", # Continue to parafraseo if valid + "malicious": "fallback_inicial", # Exception path: malicious content detected + "continue": "parafraseo", # Normal path: continue processing }, ) + # fallback_inicial -> END (stop flow with error message) + workflow.add_edge("fallback_inicial", END) + # parafraseo -> retriever 
workflow.add_edge("parafraseo", "retriever") @@ -71,15 +83,21 @@ def create_agent_graph() -> StateGraph: - # context_builder -> guard - workflow.add_edge("context_builder", "guard") + # context_builder -> generator + workflow.add_edge("context_builder", "generator") - # guard -> conditional routing + # generator -> guard_final + workflow.add_edge("generator", "guard_final") + + # guard_final -> conditional routing workflow.add_conditional_edges( - "guard", - route_after_guard, + "guard_final", + route_after_guard_final, { - "malicious": "fallback", # go to fallback if malicious - "continue": END, # if there's no error ends + "risky": "fallback_final", # Exception path: risky content detected + "continue": END, # Normal path: end successfully }, ) - workflow.add_edge("fallback", END) + + # fallback_final -> END (stop flow with error message) + workflow.add_edge("fallback_final", END) + # Compile the graph return workflow.compile() diff --git a/RAGManager/app/agents/nodes/__init__.py b/RAGManager/app/agents/nodes/__init__.py index aeeefa8..47d0708 100644 --- a/RAGManager/app/agents/nodes/__init__.py +++ b/RAGManager/app/agents/nodes/__init__.py @@ -2,15 +2,21 @@ from app.agents.nodes.agent_host import agent_host from app.agents.nodes.context_builder import context_builder -from app.agents.nodes.fallback import fallback -from app.agents.nodes.guard import guard +from app.agents.nodes.fallback_final import fallback_final +from app.agents.nodes.fallback_inicial import fallback_inicial +from app.agents.nodes.generator import generator +from app.agents.nodes.guard_final import guard_final +from app.agents.nodes.guard_inicial import guard_inicial from app.agents.nodes.parafraseo import parafraseo from app.agents.nodes.retriever import retriever __all__ = [ "agent_host", - "guard", - "fallback", + "guard_inicial", + "guard_final", + "fallback_inicial", + "fallback_final", + "generator", "parafraseo", "retriever", "context_builder", diff --git a/RAGManager/app/agents/nodes/fallback_final.py b/RAGManager/app/agents/nodes/fallback_final.py new file mode 100644 index 0000000..d6b1d73 --- /dev/null 
+++ b/RAGManager/app/agents/nodes/fallback_final.py @@ -0,0 +1,30 @@ +"""Nodo 8: Fallback Final - Stops processing when risky content is detected.""" + +import logging + +from app.agents.state import AgentState + +logger = logging.getLogger(__name__) + + +def fallback_final(state: AgentState) -> AgentState: + """ + Fallback Final node - Stops processing when risky content is detected. + + This node: + 1. Sets error message indicating that the information requested is classified or not free to know + 2. Stops the flow by routing to END + + Args: + state: Agent state containing the response flagged as risky + + Returns: + Updated state with error_message set, ready to route to END + """ + updated_state = state.copy() + + # Set error message for risky content + updated_state["error_message"] = "The information requested is classified or not free to know." + logger.warning("Risky content detected. Stopping processing. Response content not logged for security.") + + return updated_state diff --git a/RAGManager/app/agents/nodes/fallback_inicial.py b/RAGManager/app/agents/nodes/fallback_inicial.py new file mode 100644 index 0000000..e687bf8 --- /dev/null +++ b/RAGManager/app/agents/nodes/fallback_inicial.py @@ -0,0 +1,30 @@ +"""Nodo 3: Fallback Inicial - Stops processing when malicious content is detected.""" + +import logging + +from app.agents.state import AgentState + +logger = logging.getLogger(__name__) + + +def fallback_inicial(state: AgentState) -> AgentState: + """ + Fallback Inicial node - Stops processing when malicious content is detected. + + This node: + 1. Sets error message indicating that the user's intentions break the chatbot's rules + 2. 
Stops the flow by routing to END + + Args: + state: Agent state containing the prompt flagged as malicious + + Returns: + Updated state with error_message set, ready to route to END + """ + updated_state = state.copy() + + # Set error message for malicious content + updated_state["error_message"] = "The user's intentions break the chatbot's rules." + logger.warning("Malicious content detected. Stopping processing. Prompt content not logged for security.") + + return updated_state diff --git a/RAGManager/app/agents/nodes/guard_final.py b/RAGManager/app/agents/nodes/guard_final.py new file mode 100644 index 0000000..3e017c3 --- /dev/null +++ b/RAGManager/app/agents/nodes/guard_final.py @@ -0,0 +1,74 @@ +"""Nodo Guard Final - Validates generated response for PII (risky information detection).""" + +import logging + +from guardrails import Guard +from guardrails.hub import DetectPII + +from app.agents.state import AgentState +from app.core.config import settings + +logger = logging.getLogger(__name__) + +# Initialize Guard with DetectPII validator +# Note: The validator must be installed via: guardrails hub install hub://guardrails/detect_pii +_guard_final = Guard().use( + DetectPII( + pii_entities=settings.guardrails_pii_entities, + on_fail="noop", # Don't raise exceptions, handle via state flags + ) +) + + +def guard_final(state: AgentState) -> AgentState: + """ + Guard final node - Validates generated response for PII using Guardrails DetectPII. + + This node: + 1. Validates the generated_response using Guardrails DetectPII validator + 2. Sets is_risky flag if PII is detected + 3. 
Sets error_message if risky content is detected + + Args: + state: Agent state containing the generated_response + + Returns: + Updated state with is_risky and error_message set + """ + updated_state = state.copy() + generated_response = state.get("generated_response", "") + + if not generated_response: + # Empty response is considered safe + updated_state["is_risky"] = False + updated_state["error_message"] = None + return updated_state + + try: + # Validate the generated response using Guardrails + validation_result = _guard_final.validate(generated_response) + + # Check if validation passed + # The validator returns ValidationResult with outcome + # If validation fails, outcome will indicate failure + if validation_result.validation_passed: + updated_state["is_risky"] = False + updated_state["error_message"] = None + logger.debug("Generated response passed PII detection") + else: + # PII detected + updated_state["is_risky"] = True + updated_state["error_message"] = ( + "PII detected in generated response. The information requested is classified or not free to know." + ) + logger.warning("PII detected in generated response. 
Response content not logged for security.") + + except Exception as e: + # If validation fails due to error, log it but don't block the request + # This is a safety measure - if Guardrails fails, we allow the request + # but log the error for monitoring + logger.error(f"Error during PII detection: {e}") + updated_state["is_risky"] = False + updated_state["error_message"] = None + + return updated_state diff --git a/RAGManager/app/agents/nodes/guard.py b/RAGManager/app/agents/nodes/guard_inicial.py similarity index 84% rename from RAGManager/app/agents/nodes/guard.py rename to RAGManager/app/agents/nodes/guard_inicial.py index a681769..c6080f3 100644 --- a/RAGManager/app/agents/nodes/guard.py +++ b/RAGManager/app/agents/nodes/guard_inicial.py @@ -1,4 +1,4 @@ -"""Nodo 2: Guard - Validates for malicious content.""" +"""Nodo 2: Guard Inicial - Validates for malicious content (jailbreak detection).""" import logging @@ -12,7 +12,7 @@ # Initialize Guard with DetectJailbreak validator # Note: The validator must be installed via: guardrails hub install hub://guardrails/detect_jailbreak -_guard = Guard().use( +_guard_inicial = Guard().use( DetectJailbreak( threshold=settings.guardrails_jailbreak_threshold, device=settings.guardrails_device, @@ -21,9 +21,9 @@ ) -def guard(state: AgentState) -> AgentState: +def guard_inicial(state: AgentState) -> AgentState: """ - Guard node - Validates user input for malicious content using Guardrails DetectJailbreak. + Guard inicial node - Validates user input for jailbreak attempts using Guardrails DetectJailbreak. This node: 1. 
Validates the prompt using Guardrails DetectJailbreak validator @@ -49,7 +49,7 @@ def guard(state: AgentState) -> AgentState: try: # Validate the prompt using Guardrails - validation_result = _guard.validate(prompt) + validation_result = _guard_inicial.validate(prompt) # Check if validation passed # The validator returns ValidationResult with outcome @@ -64,7 +64,7 @@ def guard(state: AgentState) -> AgentState: updated_state["error_message"] = ( "Jailbreak attempt detected. Your request contains content that violates security policies." ) - logger.warning("Jailbreak attempt detected in prompt (len=%d)", len(prompt)) + logger.warning("Jailbreak attempt detected. Prompt content not logged for security.") except Exception as e: # If validation fails due to error, log it but don't block the request diff --git a/RAGManager/app/agents/routing.py b/RAGManager/app/agents/routing.py index 3b04236..5807313 100644 --- a/RAGManager/app/agents/routing.py +++ b/RAGManager/app/agents/routing.py @@ -3,9 +3,9 @@ from app.agents.state import AgentState -def route_after_guard(state: AgentState) -> str: +def route_after_guard_inicial(state: AgentState) -> str: """ - Route after Guard node (Nodo 2) validation. + Route after Guard Inicial node validation. Determines the next step based on whether the prompt was flagged as malicious. @@ -18,3 +18,20 @@ def route_after_guard(state: AgentState) -> str: if state.get("is_malicious", False): return "malicious" return "continue" + + +def route_after_guard_final(state: AgentState) -> str: + """ + Route after Guard Final node validation. + + Determines the next step based on whether the response was flagged as risky. 
+ + Args: + state: Current agent state + + Returns: + "risky" if the response is risky, "continue" otherwise + """ + if state.get("is_risky", False): + return "risky" + return "continue" diff --git a/RAGManager/app/core/config.py b/RAGManager/app/core/config.py index d9e5a7d..b96d119 100644 --- a/RAGManager/app/core/config.py +++ b/RAGManager/app/core/config.py @@ -47,6 +47,19 @@ class Settings(BaseSettings): default="cpu", description="Device for model inference.", ) + guardrails_pii_entities: list[str] = Field( + default=[ + "EMAIL_ADDRESS", + "PHONE_NUMBER", + "CREDIT_CARD", + "US_SSN", + "US_PASSPORT", + "US_DRIVER_LICENSE", + "IBAN_CODE", + "IP_ADDRESS", + ], + description="List of PII entity types to detect using DetectPII validator.", + ) model_config = SettingsConfigDict( env_file=".env", env_file_encoding="utf-8", diff --git a/RAGManager/pyproject.toml b/RAGManager/pyproject.toml index 284bbb7..e7191ac 100644 --- a/RAGManager/pyproject.toml +++ b/RAGManager/pyproject.toml @@ -17,7 +17,9 @@ dependencies = [ "pydantic-settings>=2.0.0", "typing-extensions>=4.15.0", "uvicorn>=0.38.0", - "guardrails-ai>=0.5.10", + "guardrails-ai>=0.6.2", + "presidio-analyzer>=2.2.360", + "presidio-anonymizer>=2.2.360", ]