11from functools import partial
2- from typing import Any , Callable , Mapping , Sequence
2+ from typing import Any , Callable , Mapping , Sequence , TypeVar
33
44from langgraph ._internal ._runnable import RunnableCallable
55from langgraph .constants import END , START
66from langgraph .graph import StateGraph
7+ from pydantic import BaseModel
78from uipath .core .guardrails import DeterministicGuardrail
89from uipath .platform .guardrails import (
910 BaseGuardrail ,
2627 AgentGraphState ,
2728 AgentGuardrailsGraphState ,
2829)
30+ from uipath_langchain .agent .react .utils import create_guardrails_state_with_input
2931
3032_VALIDATOR_ALLOWED_STAGES = {
3133 "prompt_injection" : {ExecutionStage .PRE_EXECUTION },
@@ -65,6 +67,7 @@ def _create_guardrails_subgraph(
6567 ],
6668 GuardrailActionNode ,
6769 ] = create_llm_guardrail_node ,
70+ input_schema : type [BaseModel ] | None = None ,
6871):
6972 """Build a subgraph that enforces guardrails around an inner node.
7073
@@ -83,7 +86,8 @@ def _create_guardrails_subgraph(
8386 """
8487 inner_name , inner_node = main_inner_node
8588
86- subgraph = StateGraph (AgentGuardrailsGraphState )
89+ CompleteAgentGuardrailsGraphState = create_guardrails_state_with_input (input_schema )
90+ subgraph = StateGraph (CompleteAgentGuardrailsGraphState )
8791
8892 subgraph .add_node (inner_name , inner_node )
8993
@@ -203,12 +207,14 @@ def _build_guardrail_node_chain(
203207def create_llm_guardrails_subgraph (
204208 llm_node : tuple [str , Any ],
205209 guardrails : Sequence [tuple [BaseGuardrail , GuardrailAction ]] | None ,
210+ input_schema : type [BaseModel ] | None = None ,
206211):
207212 """Create a guarded LLM node.
208213
209214 Args:
210215 llm_node: Tuple of (node_name, node_callable) for the LLM node.
211216 guardrails: Optional sequence of (guardrail, action) tuples.
217+ input_schema: Optional input schema to include in state.
212218
213219 Returns:
214220 Either the original node callable (if no applicable guardrails) or a compiled
@@ -229,17 +235,20 @@ def create_llm_guardrails_subgraph(
229235 scope = GuardrailScope .LLM ,
230236 execution_stages = [ExecutionStage .PRE_EXECUTION , ExecutionStage .POST_EXECUTION ],
231237 node_factory = create_llm_guardrail_node ,
238+ input_schema = input_schema ,
232239 )
233240
234241
235242def create_tools_guardrails_subgraph (
236243 tool_nodes : Mapping [str , RunnableCallable ],
237244 guardrails : Sequence [tuple [BaseGuardrail , GuardrailAction ]] | None ,
245+ input_schema : type [BaseModel ] | None = None ,
238246) -> dict [str , RunnableCallable ]:
239247 """Create tool nodes with guardrails applied.
240248 Args:
241249 tool_nodes: Mapping of tool name to a LangGraph `ToolNode`.
242250 guardrails: Optional sequence of (guardrail, action) tuples.
251+ input_schema: Optional input schema to include in state.
243252
244253 Returns:
245254 A mapping of tool name to either the original `ToolNode` or a compiled subgraph
@@ -250,6 +259,7 @@ def create_tools_guardrails_subgraph(
250259 subgraph = create_tool_guardrails_subgraph (
251260 (tool_name , tool_node ),
252261 guardrails ,
262+ input_schema = input_schema ,
253263 )
254264 result [tool_name ] = subgraph
255265
@@ -259,6 +269,7 @@ def create_tools_guardrails_subgraph(
259269def create_agent_init_guardrails_subgraph (
260270 init_node : tuple [str , Any ],
261271 guardrails : Sequence [tuple [BaseGuardrail , GuardrailAction ]] | None ,
272+ input_schema : type [BaseModel ] | None = None ,
262273) -> Any :
263274 """Create a subgraph for the INIT node and apply AGENT guardrails after INIT.
264275
@@ -269,6 +280,7 @@ def create_agent_init_guardrails_subgraph(
269280 Args:
270281 init_node: Tuple of (node_name, node_callable) for the INIT node.
271282 guardrails: Optional sequence of (guardrail, action) tuples.
283+ input_schema: Optional input schema to include in state.
272284
273285 Returns:
274286 Either the original node callable (if no applicable guardrails) or a compiled
@@ -287,7 +299,8 @@ def create_agent_init_guardrails_subgraph(
287299 return init_node [1 ]
288300
289301 inner_name , inner_node = init_node
290- subgraph = StateGraph (AgentGuardrailsGraphState )
302+ CompleteAgentGuardrailsGraphState = create_guardrails_state_with_input (input_schema )
303+ subgraph = StateGraph (CompleteAgentGuardrailsGraphState )
291304 subgraph .add_node (inner_name , inner_node )
292305 subgraph .add_edge (START , inner_name )
293306
@@ -307,15 +320,15 @@ def create_agent_init_guardrails_subgraph(
307320def create_agent_terminate_guardrails_subgraph (
308321 terminate_node : tuple [str , Any ],
309322 guardrails : Sequence [tuple [BaseGuardrail , GuardrailAction ]] | None ,
323+ input_schema : type [BaseModel ] | None = None ,
310324):
311325 """Create a subgraph for TERMINATE node that applies guardrails on the agent result."""
312326 node_name , node_func = terminate_node
313327
314328 def terminate_wrapper (state : Any ) -> dict [str , Any ]:
315329 # Call original terminate node
316330 result = node_func (state )
317- # Store result in state
318- return {"agent_result" : result , "messages" : state .messages }
331+ return {"inner_state" : {"agent_result" : result }}
319332
320333 applicable_guardrails = [
321334 (guardrail , _ )
@@ -332,26 +345,31 @@ def terminate_wrapper(state: Any) -> dict[str, Any]:
332345 scope = GuardrailScope .AGENT ,
333346 execution_stages = [ExecutionStage .POST_EXECUTION ],
334347 node_factory = create_agent_terminate_guardrail_node ,
348+ input_schema = input_schema ,
335349 )
336350
351+ StateT = TypeVar ("StateT" , bound = AgentGraphState )
352+
337353 async def run_terminate_subgraph (
338- state : AgentGraphState ,
354+ state : StateT ,
339355 ) -> dict [str , Any ]:
340356 result_state = await subgraph .ainvoke (state )
341- return result_state ["agent_result" ]
357+ return result_state ["inner_state" ]. agent_result
342358
343359 return run_terminate_subgraph
344360
345361
346362def create_tool_guardrails_subgraph (
347363 tool_node : tuple [str , Any ],
348364 guardrails : Sequence [tuple [BaseGuardrail , GuardrailAction ]] | None ,
365+ input_schema : type [BaseModel ] | None = None ,
349366):
350367 """Create a guarded tool node.
351368
352369 Args:
353370 tool_node: Tuple of (tool_name, tool_node_callable).
354371 guardrails: Optional sequence of (guardrail, action) tuples.
372+ input_schema: Optional input schema to include in state.
355373
356374 Returns:
357375 Either the original tool node callable (if no matching guardrails) or a compiled
@@ -374,4 +392,5 @@ def create_tool_guardrails_subgraph(
374392 scope = GuardrailScope .TOOL ,
375393 execution_stages = [ExecutionStage .PRE_EXECUTION , ExecutionStage .POST_EXECUTION ],
376394 node_factory = partial (create_tool_guardrail_node , tool_name = tool_name ),
395+ input_schema = input_schema ,
377396 )
0 commit comments