Skip to content

Commit 7cfce44

Browse files
feat: move guardrail state to inner_state; add inputs to guardrail state (#420)
1 parent 2db250c commit 7cfce44

12 files changed

Lines changed: 178 additions & 58 deletions

File tree

src/uipath_langchain/agent/guardrails/actions/escalate_action.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ async def _node(
8383
"GuardrailDescription": guardrail.description,
8484
"Component": scope.name.lower(),
8585
"ExecutionStage": _execution_stage_to_string(execution_stage),
86-
"GuardrailResult": state.guardrail_validation_result,
86+
"GuardrailResult": state.inner_state.guardrail_validation_result,
8787
}
8888

8989
# Add tenant and trace URL if base_url is configured
@@ -275,7 +275,7 @@ def _process_agent_escalation_response(
275275
return Command(update={"messages": msgs})
276276

277277
# POST_EXECUTION: update agent_result
278-
return Command(update={"agent_result": parsed})
278+
return Command(update={"inner_state": {"agent_result": parsed}})
279279
except Exception as e:
280280
raise AgentTerminationException(
281281
code=UiPathErrorCode.EXECUTION_ERROR,
@@ -519,7 +519,7 @@ def _extract_agent_escalation_content(
519519
if execution_stage == ExecutionStage.PRE_EXECUTION:
520520
return get_message_content(cast(AnyMessage, message))
521521

522-
output_content = state.agent_result or ""
522+
output_content = state.inner_state.agent_result or ""
523523
return json.dumps(output_content)
524524

525525

src/uipath_langchain/agent/guardrails/actions/log_action.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def action_node(
4949
async def _node(_state: AgentGuardrailsGraphState) -> dict[str, Any]:
5050
message = (
5151
self.message
52-
or f"Guardrail [{guardrail.name}] validation failed for [{scope.name}] [{execution_stage.name}] with the following reason: {_state.guardrail_validation_result}"
52+
or f"Guardrail [{guardrail.name}] validation failed for [{scope.name}] [{execution_stage.name}] with the following reason: {_state.inner_state.guardrail_validation_result}"
5353
)
5454

5555
logger.log(self.level, message)

src/uipath_langchain/agent/guardrails/guardrail_nodes.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,12 +106,17 @@ def _create_validation_command(
106106
AgentTerminationException: If the result is neither PASSED nor VALIDATION_FAILED.
107107
"""
108108
if guardrail_result.result == GuardrailValidationResultType.PASSED:
109-
return Command(goto=success_node, update={"guardrail_validation_result": None})
109+
return Command(
110+
goto=success_node,
111+
update={"inner_state": {"guardrail_validation_result": None}},
112+
)
110113

111114
if guardrail_result.result == GuardrailValidationResultType.VALIDATION_FAILED:
112115
return Command(
113116
goto=failure_node,
114-
update={"guardrail_validation_result": guardrail_result.reason},
117+
update={
118+
"inner_state": {"guardrail_validation_result": guardrail_result.reason}
119+
},
115120
)
116121

117122
# For other results (FEATURE_DISABLED, ENTITLEMENTS_MISSING, etc.), interrupt execution
@@ -260,7 +265,7 @@ def create_agent_terminate_guardrail_node(
260265
failure_node: str,
261266
) -> tuple[str, Callable[[AgentGuardrailsGraphState], Any]]:
262267
def _payload_generator(state: AgentGuardrailsGraphState) -> str:
263-
return str(state.agent_result)
268+
return str(state.inner_state.agent_result)
264269

265270
return _create_guardrail_node(
266271
guardrail,

src/uipath_langchain/agent/react/agent.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os
2-
from typing import Callable, Sequence, Type, TypeVar, cast
2+
from typing import Callable, Sequence, Type, TypeVar
33

44
from langchain_core.language_models import BaseChatModel
55
from langchain_core.messages import HumanMessage, SystemMessage
@@ -30,23 +30,17 @@
3030
create_terminate_node,
3131
)
3232
from .tools import create_flow_control_tools
33-
from .types import AgentGraphConfig, AgentGraphNode, AgentGraphState
33+
from .types import (
34+
AgentGraphConfig,
35+
AgentGraphNode,
36+
AgentGraphState,
37+
)
38+
from .utils import create_state_with_input
3439

3540
InputT = TypeVar("InputT", bound=BaseModel)
3641
OutputT = TypeVar("OutputT", bound=BaseModel)
3742

3843

39-
def create_state_with_input(input_schema: Type[InputT]):
40-
CompleteAgentGraphState = type(
41-
"CompleteAgentGraphState",
42-
(AgentGraphState, input_schema),
43-
{},
44-
)
45-
46-
cast(type[BaseModel], CompleteAgentGraphState).model_rebuild()
47-
return CompleteAgentGraphState
48-
49-
5044
def create_agent(
5145
model: BaseChatModel,
5246
tools: Sequence[BaseTool],
@@ -84,7 +78,7 @@ def create_agent(
8478

8579
tool_nodes = create_tool_node(agent_tools)
8680
tool_nodes_with_guardrails = create_tools_guardrails_subgraph(
87-
tool_nodes, guardrails
81+
tool_nodes, guardrails, input_schema=input_schema
8882
)
8983
terminate_node = create_terminate_node(output_schema, config.is_conversational)
9084

@@ -98,6 +92,7 @@ def create_agent(
9892
init_with_guardrails_subgraph = create_agent_init_guardrails_subgraph(
9993
(AgentGraphNode.GUARDED_INIT, init_node),
10094
guardrails,
95+
input_schema=input_schema,
10196
)
10297
builder.add_node(AgentGraphNode.INIT, init_with_guardrails_subgraph)
10398

@@ -107,6 +102,7 @@ def create_agent(
107102
terminate_with_guardrails_subgraph = create_agent_terminate_guardrails_subgraph(
108103
(AgentGraphNode.GUARDED_TERMINATE, terminate_node),
109104
guardrails,
105+
input_schema=input_schema,
110106
)
111107
builder.add_node(AgentGraphNode.TERMINATE, terminate_with_guardrails_subgraph)
112108

@@ -116,7 +112,7 @@ def create_agent(
116112
model, llm_tools, config.thinking_messages_limit, config.is_conversational
117113
)
118114
llm_with_guardrails_subgraph = create_llm_guardrails_subgraph(
119-
(AgentGraphNode.LLM, llm_node), guardrails
115+
(AgentGraphNode.LLM, llm_node), guardrails, input_schema=input_schema
120116
)
121117
builder.add_node(AgentGraphNode.AGENT, llm_with_guardrails_subgraph)
122118
builder.add_edge(AgentGraphNode.INIT, AgentGraphNode.AGENT)

src/uipath_langchain/agent/react/guardrails/guardrails_subgraph.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from functools import partial
2-
from typing import Any, Callable, Mapping, Sequence
2+
from typing import Any, Callable, Mapping, Sequence, TypeVar
33

44
from langgraph._internal._runnable import RunnableCallable
55
from langgraph.constants import END, START
66
from langgraph.graph import StateGraph
7+
from pydantic import BaseModel
78
from uipath.core.guardrails import DeterministicGuardrail
89
from uipath.platform.guardrails import (
910
BaseGuardrail,
@@ -26,6 +27,7 @@
2627
AgentGraphState,
2728
AgentGuardrailsGraphState,
2829
)
30+
from uipath_langchain.agent.react.utils import create_guardrails_state_with_input
2931

3032
_VALIDATOR_ALLOWED_STAGES = {
3133
"prompt_injection": {ExecutionStage.PRE_EXECUTION},
@@ -65,6 +67,7 @@ def _create_guardrails_subgraph(
6567
],
6668
GuardrailActionNode,
6769
] = create_llm_guardrail_node,
70+
input_schema: type[BaseModel] | None = None,
6871
):
6972
"""Build a subgraph that enforces guardrails around an inner node.
7073
@@ -83,7 +86,8 @@ def _create_guardrails_subgraph(
8386
"""
8487
inner_name, inner_node = main_inner_node
8588

86-
subgraph = StateGraph(AgentGuardrailsGraphState)
89+
CompleteAgentGuardrailsGraphState = create_guardrails_state_with_input(input_schema)
90+
subgraph = StateGraph(CompleteAgentGuardrailsGraphState)
8791

8892
subgraph.add_node(inner_name, inner_node)
8993

@@ -203,12 +207,14 @@ def _build_guardrail_node_chain(
203207
def create_llm_guardrails_subgraph(
204208
llm_node: tuple[str, Any],
205209
guardrails: Sequence[tuple[BaseGuardrail, GuardrailAction]] | None,
210+
input_schema: type[BaseModel] | None = None,
206211
):
207212
"""Create a guarded LLM node.
208213
209214
Args:
210215
llm_node: Tuple of (node_name, node_callable) for the LLM node.
211216
guardrails: Optional sequence of (guardrail, action) tuples.
217+
input_schema: Optional input schema to include in state.
212218
213219
Returns:
214220
Either the original node callable (if no applicable guardrails) or a compiled
@@ -229,17 +235,20 @@ def create_llm_guardrails_subgraph(
229235
scope=GuardrailScope.LLM,
230236
execution_stages=[ExecutionStage.PRE_EXECUTION, ExecutionStage.POST_EXECUTION],
231237
node_factory=create_llm_guardrail_node,
238+
input_schema=input_schema,
232239
)
233240

234241

235242
def create_tools_guardrails_subgraph(
236243
tool_nodes: Mapping[str, RunnableCallable],
237244
guardrails: Sequence[tuple[BaseGuardrail, GuardrailAction]] | None,
245+
input_schema: type[BaseModel] | None = None,
238246
) -> dict[str, RunnableCallable]:
239247
"""Create tool nodes with guardrails applied.
240248
Args:
241249
tool_nodes: Mapping of tool name to a LangGraph `ToolNode`.
242250
guardrails: Optional sequence of (guardrail, action) tuples.
251+
input_schema: Optional input schema to include in state.
243252
244253
Returns:
245254
A mapping of tool name to either the original `ToolNode` or a compiled subgraph
@@ -250,6 +259,7 @@ def create_tools_guardrails_subgraph(
250259
subgraph = create_tool_guardrails_subgraph(
251260
(tool_name, tool_node),
252261
guardrails,
262+
input_schema=input_schema,
253263
)
254264
result[tool_name] = subgraph
255265

@@ -259,6 +269,7 @@ def create_tools_guardrails_subgraph(
259269
def create_agent_init_guardrails_subgraph(
260270
init_node: tuple[str, Any],
261271
guardrails: Sequence[tuple[BaseGuardrail, GuardrailAction]] | None,
272+
input_schema: type[BaseModel] | None = None,
262273
) -> Any:
263274
"""Create a subgraph for the INIT node and apply AGENT guardrails after INIT.
264275
@@ -269,6 +280,7 @@ def create_agent_init_guardrails_subgraph(
269280
Args:
270281
init_node: Tuple of (node_name, node_callable) for the INIT node.
271282
guardrails: Optional sequence of (guardrail, action) tuples.
283+
input_schema: Optional input schema to include in state.
272284
273285
Returns:
274286
Either the original node callable (if no applicable guardrails) or a compiled
@@ -287,7 +299,8 @@ def create_agent_init_guardrails_subgraph(
287299
return init_node[1]
288300

289301
inner_name, inner_node = init_node
290-
subgraph = StateGraph(AgentGuardrailsGraphState)
302+
CompleteAgentGuardrailsGraphState = create_guardrails_state_with_input(input_schema)
303+
subgraph = StateGraph(CompleteAgentGuardrailsGraphState)
291304
subgraph.add_node(inner_name, inner_node)
292305
subgraph.add_edge(START, inner_name)
293306

@@ -307,15 +320,15 @@ def create_agent_init_guardrails_subgraph(
307320
def create_agent_terminate_guardrails_subgraph(
308321
terminate_node: tuple[str, Any],
309322
guardrails: Sequence[tuple[BaseGuardrail, GuardrailAction]] | None,
323+
input_schema: type[BaseModel] | None = None,
310324
):
311325
"""Create a subgraph for TERMINATE node that applies guardrails on the agent result."""
312326
node_name, node_func = terminate_node
313327

314328
def terminate_wrapper(state: Any) -> dict[str, Any]:
315329
# Call original terminate node
316330
result = node_func(state)
317-
# Store result in state
318-
return {"agent_result": result, "messages": state.messages}
331+
return {"inner_state": {"agent_result": result}}
319332

320333
applicable_guardrails = [
321334
(guardrail, _)
@@ -332,26 +345,31 @@ def terminate_wrapper(state: Any) -> dict[str, Any]:
332345
scope=GuardrailScope.AGENT,
333346
execution_stages=[ExecutionStage.POST_EXECUTION],
334347
node_factory=create_agent_terminate_guardrail_node,
348+
input_schema=input_schema,
335349
)
336350

351+
StateT = TypeVar("StateT", bound=AgentGraphState)
352+
337353
async def run_terminate_subgraph(
338-
state: AgentGraphState,
354+
state: StateT,
339355
) -> dict[str, Any]:
340356
result_state = await subgraph.ainvoke(state)
341-
return result_state["agent_result"]
357+
return result_state["inner_state"].agent_result
342358

343359
return run_terminate_subgraph
344360

345361

346362
def create_tool_guardrails_subgraph(
347363
tool_node: tuple[str, Any],
348364
guardrails: Sequence[tuple[BaseGuardrail, GuardrailAction]] | None,
365+
input_schema: type[BaseModel] | None = None,
349366
):
350367
"""Create a guarded tool node.
351368
352369
Args:
353370
tool_node: Tuple of (tool_name, tool_node_callable).
354371
guardrails: Optional sequence of (guardrail, action) tuples.
372+
input_schema: Optional input schema to include in state.
355373
356374
Returns:
357375
Either the original tool node callable (if no matching guardrails) or a compiled
@@ -374,4 +392,5 @@ def create_tool_guardrails_subgraph(
374392
scope=GuardrailScope.TOOL,
375393
execution_stages=[ExecutionStage.PRE_EXECUTION, ExecutionStage.POST_EXECUTION],
376394
node_factory=partial(create_tool_guardrail_node, tool_name=tool_name),
395+
input_schema=input_schema,
377396
)

src/uipath_langchain/agent/react/types.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,13 @@ class InnerAgentGraphState(BaseModel):
2626
termination: AgentTermination | None = None
2727

2828

29+
class InnerAgentGuardrailsGraphState(InnerAgentGraphState):
30+
"""Extended inner state for guardrails subgraph."""
31+
32+
guardrail_validation_result: Optional[str] = None
33+
agent_result: Optional[dict[str, Any]] = None
34+
35+
2936
class AgentGraphState(BaseModel):
3037
"""Agent Graph state for standard loop execution."""
3138

@@ -38,8 +45,9 @@ class AgentGraphState(BaseModel):
3845
class AgentGuardrailsGraphState(AgentGraphState):
3946
"""Agent Guardrails Graph state for guardrail subgraph."""
4047

41-
guardrail_validation_result: Optional[str] = None
42-
agent_result: Optional[dict[str, Any]] = None
48+
inner_state: Annotated[InnerAgentGuardrailsGraphState, merge_objects] = Field(
49+
default_factory=InnerAgentGuardrailsGraphState
50+
)
4351

4452

4553
class AgentGraphNode(StrEnum):

src/uipath_langchain/agent/react/utils.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
"""ReAct Agent loop utilities."""
22

3-
from typing import Any, Sequence
3+
from typing import Any, Sequence, TypeVar, cast
44

55
from langchain_core.messages import AIMessage, BaseMessage
66
from pydantic import BaseModel
77
from uipath.agent.react import END_EXECUTION_TOOL
88

99
from uipath_langchain.agent.react.jsonschema_pydantic_converter import create_model
10+
from uipath_langchain.agent.react.types import (
11+
AgentGraphState,
12+
AgentGuardrailsGraphState,
13+
)
1014

1115

1216
def resolve_input_model(
@@ -48,3 +52,41 @@ def count_consecutive_thinking_messages(messages: Sequence[BaseMessage]) -> int:
4852
count += 1
4953

5054
return count
55+
56+
57+
InputT = TypeVar("InputT", bound=BaseModel)
58+
GraphStateT = TypeVar("GraphStateT", bound=BaseModel)
59+
60+
61+
def _create_state_model_with_input(
62+
state_model: type[GraphStateT],
63+
input_schema: type[InputT] | None,
64+
model_name: str = "CompleteStateModel",
65+
) -> type[GraphStateT]:
66+
if input_schema is None:
67+
return state_model
68+
69+
CompleteStateModel = type(
70+
model_name,
71+
(state_model, input_schema),
72+
{},
73+
)
74+
75+
cast(type[GraphStateT], CompleteStateModel).model_rebuild()
76+
return CompleteStateModel
77+
78+
79+
def create_state_with_input(input_schema: type[InputT] | None) -> type[AgentGraphState]:
80+
return _create_state_model_with_input(
81+
AgentGraphState, input_schema, model_name="CompleteAgentGraphState"
82+
)
83+
84+
85+
def create_guardrails_state_with_input(
86+
input_schema: type[InputT] | None,
87+
) -> type[AgentGuardrailsGraphState]:
88+
return _create_state_model_with_input(
89+
AgentGuardrailsGraphState,
90+
input_schema,
91+
model_name="CompleteAgentGuardrailsGraphState",
92+
)

0 commit comments

Comments
 (0)