-
Notifications
You must be signed in to change notification settings - Fork 32
Expand file tree
/
Copy pathruntime.py
More file actions
572 lines (489 loc) · 22.1 KB
/
runtime.py
File metadata and controls
572 lines (489 loc) · 22.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
import logging
import os
from typing import Any, AsyncGenerator
from uuid import uuid4
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.runnables.config import RunnableConfig
from langchain_core.tools import BaseTool
from langgraph.errors import EmptyInputError, GraphRecursionError, InvalidUpdateError
from langgraph.graph.state import CompiledStateGraph
from langgraph.types import Command, Interrupt, StateSnapshot
from uipath.runtime import (
UiPathBreakpointResult,
UiPathExecuteOptions,
UiPathRuntimeResult,
UiPathRuntimeStatus,
UiPathRuntimeStorageProtocol,
UiPathStreamOptions,
)
from uipath.runtime.errors import (
UiPathBaseRuntimeError,
UiPathErrorCategory,
UiPathErrorCode,
)
from uipath.runtime.events import (
UiPathRuntimeEvent,
UiPathRuntimeMessageEvent,
UiPathRuntimeStateEvent,
UiPathRuntimeStatePhase,
)
from uipath.runtime.schema import UiPathRuntimeSchema
from uipath_langchain.agent.tools.tool_node import RunnableCallableWithTool
from uipath_langchain.chat.hitl import get_confirmation_schema
from uipath_langchain.runtime.errors import LangGraphErrorCode, LangGraphRuntimeError
from uipath_langchain.runtime.messages import UiPathChatMessagesMapper
from uipath_langchain.runtime.schema import get_entrypoints_schema, get_graph_schema
from ._serialize import serialize_output
logger = logging.getLogger(__name__)
class UiPathLangGraphRuntime:
    """
    A runtime class for executing LangGraph graphs within the UiPath framework.
    """

    def __init__(
        self,
        graph: CompiledStateGraph[Any, Any, Any, Any],
        runtime_id: str | None = None,
        entrypoint: str | None = None,
        callbacks: list[BaseCallbackHandler] | None = None,
        storage: UiPathRuntimeStorageProtocol | None = None,
    ):
        """
        Initialize the runtime.

        Args:
            graph: The CompiledStateGraph to execute
            runtime_id: Unique identifier for this runtime instance; defaults
                to "default" and is also used as the LangGraph thread_id
            entrypoint: Optional entrypoint name (for schema generation)
            callbacks: Optional LangChain callback handlers attached to every
                graph execution via the runnable config
            storage: Optional storage backend forwarded to the chat mapper
        """
        self.graph: CompiledStateGraph[Any, Any, Any, Any] = graph
        self.runtime_id: str = runtime_id or "default"
        self.entrypoint: str | None = entrypoint
        self.callbacks: list[BaseCallbackHandler] = callbacks or []
        # The chat mapper translates LangChain messages into UiPath events;
        # it must exist before we attach the tool-confirmation schemas below.
        self.chat = UiPathChatMessagesMapper(self.runtime_id, storage)
        self.chat.tools_requiring_confirmation = self._get_tool_confirmation_info()
        # Cache middleware node names once so streaming can filter them cheaply.
        self._middleware_node_names: set[str] = self._detect_middleware_nodes()
async def execute(
self,
input: dict[str, Any] | None = None,
options: UiPathExecuteOptions | None = None,
) -> UiPathRuntimeResult:
"""Execute the graph with the provided input and configuration."""
try:
graph_input = await self._get_graph_input(input, options)
graph_config = self._get_graph_config()
# Execute without streaming
graph_output = await self.graph.ainvoke(
graph_input,
graph_config,
interrupt_before=options.breakpoints if options else None,
)
# Get final state and create result
result = await self._create_runtime_result(graph_config, graph_output)
return result
except Exception as e:
raise self.create_runtime_error(e) from e
    async def stream(
        self,
        input: dict[str, Any] | None = None,
        options: UiPathStreamOptions | None = None,
    ) -> AsyncGenerator[UiPathRuntimeEvent, None]:
        """
        Stream graph execution events in real-time.

        Yields UiPath UiPathRuntimeEvent instances (thin wrappers around framework data),
        then yields the final UiPathRuntimeResult as the last item.

        Yields:
            - UiPathRuntimeMessageEvent: Wraps framework messages (BaseMessage, chunks, etc.)
            - UiPathRuntimeStateEvent: Wraps framework state updates
            - Final event: UiPathRuntimeResult or UiPathBreakpointResult

        Example:
            async for event in runtime.stream():
                if isinstance(event, UiPathRuntimeResult):
                    # Last event is the result
                    print(f"Final result: {event}")
                elif isinstance(event, UiPathRuntimeMessageEvent):
                    # Access framework-specific message
                    message = event.payload  # BaseMessage or AIMessageChunk
                    print(f"Message: {message.content}")
                elif isinstance(event, UiPathRuntimeStateEvent):
                    # Access framework-specific state
                    state = event.payload
                    print(f"Node {event.node_name} updated: {state}")

        Raises:
            LangGraphRuntimeError: If execution fails
        """
        try:
            graph_input = await self._get_graph_input(input, options)
            graph_config = self._get_graph_config()
            # Track final chunk for result creation
            final_chunk: dict[Any, Any] | None = None
            # Stream events from graph. With subgraphs=True each chunk is a
            # (namespace, stream_mode, data) triple.
            async for stream_chunk in self.graph.astream(
                graph_input,
                graph_config,
                interrupt_before=options.breakpoints if options else None,
                stream_mode=["messages", "updates", "tasks"],
                subgraphs=True,
            ):
                namespace, chunk_type, data = stream_chunk
                # Emit UiPathRuntimeMessageEvent for messages
                if chunk_type == "messages":
                    # "messages" data arrives as (message, metadata) tuples.
                    if isinstance(data, tuple):
                        message, _ = data
                        try:
                            events = await self.chat.map_event(message)
                        except Exception as e:
                            # Mapping failures are non-fatal: log and keep streaming.
                            logger.warning(f"Error mapping message event: {e}")
                            events = None
                        if events:
                            for mapped_event in events:
                                event = UiPathRuntimeMessageEvent(
                                    payload=mapped_event,
                                )
                                yield event
                # Emit UiPathRuntimeStateEvent for state updates
                elif chunk_type == "updates":
                    if isinstance(data, dict):
                        # Exclude middleware nodes so they never become the
                        # final chunk used for the runtime result.
                        filtered_data = {
                            node_name: agent_data
                            for node_name, agent_data in data.items()
                            if not self._is_middleware_node(node_name)
                        }
                        if filtered_data:
                            final_chunk = filtered_data
                        # Emit state update event for each node (middleware
                        # included), skipping LangGraph's metadata entry.
                        for node_name, agent_data in data.items():
                            if node_name in ("__metadata__",):
                                continue
                            state_event = UiPathRuntimeStateEvent(
                                payload=serialize_output(agent_data)
                                if isinstance(agent_data, dict)
                                else {},
                                node_name=node_name,
                                qualified_node_name=self._build_node_name(
                                    namespace,
                                    node_name,
                                ),
                            )
                            yield state_event
                elif chunk_type == "tasks":
                    if isinstance(data, dict):
                        task_name = data.get("name", "")
                        # "input" marks a task start; "result" marks its end,
                        # faulted when an error is present.
                        if "input" in data:
                            phase = UiPathRuntimeStatePhase.STARTED
                        elif "result" in data:
                            phase = (
                                UiPathRuntimeStatePhase.FAULTED
                                if data.get("error")
                                else UiPathRuntimeStatePhase.COMPLETED
                            )
                        else:
                            phase = None
                        if phase is not None:
                            state_event = UiPathRuntimeStateEvent(
                                payload=serialize_output(data),
                                node_name=task_name,
                                qualified_node_name=self._build_node_name(
                                    namespace,
                                    task_name,
                                ),
                                phase=phase,
                            )
                            yield state_event
            # Extract output from final chunk
            graph_output = self._extract_graph_result(final_chunk)
            # Get final state and create result
            result = await self._create_runtime_result(graph_config, graph_output)
            # Yield the final result as last event
            yield result
        except Exception as e:
            raise self.create_runtime_error(e) from e
async def get_schema(self) -> UiPathRuntimeSchema:
"""Get schema for this LangGraph runtime."""
schema_details = get_entrypoints_schema(self.graph)
return UiPathRuntimeSchema(
filePath=self.entrypoint,
uniqueId=str(uuid4()),
type="agent",
input=schema_details.schema["input"],
output=schema_details.schema["output"],
graph=get_graph_schema(self.graph, xray=1),
)
# This can be overriden by subclasses working with custom exception hierarchies
def create_runtime_error(self, e: Exception) -> UiPathBaseRuntimeError:
"""Handle execution errors and create appropriate LangGraphRuntimeError."""
if isinstance(e, LangGraphRuntimeError):
return e
detail = f"Error: {str(e)}"
if isinstance(e, GraphRecursionError):
return LangGraphRuntimeError(
LangGraphErrorCode.GRAPH_LOAD_ERROR,
"Graph recursion limit exceeded",
detail,
UiPathErrorCategory.USER,
)
if isinstance(e, InvalidUpdateError):
return LangGraphRuntimeError(
LangGraphErrorCode.GRAPH_INVALID_UPDATE,
str(e),
detail,
UiPathErrorCategory.USER,
)
if isinstance(e, EmptyInputError):
return LangGraphRuntimeError(
LangGraphErrorCode.GRAPH_EMPTY_INPUT,
"The input data is empty",
detail,
UiPathErrorCategory.USER,
)
return LangGraphRuntimeError(
UiPathErrorCode.EXECUTION_ERROR,
"Graph execution failed",
detail,
UiPathErrorCategory.USER,
)
def _get_graph_config(self) -> RunnableConfig:
"""Build graph execution configuration."""
graph_config: RunnableConfig = {
"configurable": {"thread_id": self.runtime_id},
"callbacks": self.callbacks,
}
# Add optional config from environment
recursion_limit = os.environ.get("LANGCHAIN_RECURSION_LIMIT")
max_concurrency = os.environ.get("LANGCHAIN_MAX_CONCURRENCY")
if recursion_limit is not None:
graph_config["recursion_limit"] = int(recursion_limit)
if max_concurrency is not None:
graph_config["max_concurrency"] = int(max_concurrency)
return graph_config
async def _get_graph_input(
self,
input: dict[str, Any] | None,
options: UiPathExecuteOptions | None,
) -> Any:
"""Process and return graph input."""
graph_input = input or {}
if isinstance(graph_input, dict):
messages = graph_input.get("messages", None)
if messages and isinstance(messages, list):
graph_input["messages"] = self.chat.map_messages(messages)
if options and options.resume:
return Command(resume=graph_input)
return graph_input
async def _get_graph_state(
self,
graph_config: RunnableConfig,
) -> StateSnapshot | None:
"""Get final graph state."""
try:
return await self.graph.aget_state(graph_config)
except Exception:
return None
def _extract_graph_result(self, final_chunk: Any) -> Any:
"""
Extract the result from a LangGraph output chunk according to the graph's output channels.
Args:
final_chunk: The final chunk from graph.astream()
output_channels: The graph's output channel configuration
Returns:
The extracted result according to the graph's output_channels configuration
"""
# Unwrap from subgraph tuple format if needed
if isinstance(final_chunk, tuple) and len(final_chunk) == 2:
final_chunk = final_chunk[1]
# If the result isn't a dict or graph doesn't define output channels, return as is
if not isinstance(final_chunk, dict):
return final_chunk
output_channels = self.graph.output_channels
# Case 1: Single output channel as string
if isinstance(output_channels, str):
return final_chunk.get(output_channels, final_chunk)
# Case 2: Multiple output channels as sequence
elif hasattr(output_channels, "__iter__") and not isinstance(
output_channels, str
):
# Check which channels are present
available_channels = [ch for ch in output_channels if ch in final_chunk]
# If no available channels, output may contain the last_node name as key
unwrapped_final_chunk = {}
if not available_channels and len(final_chunk) == 1:
potential_unwrap = next(iter(final_chunk.values()))
if isinstance(potential_unwrap, dict):
unwrapped_final_chunk = potential_unwrap
available_channels = [
ch for ch in output_channels if ch in unwrapped_final_chunk
]
if available_channels:
# Create a dict with the available channels
return {
channel: final_chunk.get(channel)
or unwrapped_final_chunk.get(channel)
for channel in available_channels
}
# Fallback for any other case
return final_chunk
def _is_interrupted(self, state: StateSnapshot) -> bool:
"""Check if execution was interrupted (static or dynamic)."""
# An execution is considered interrupted if there are any next nodes (static interrupt)
# or if there are any dynamic interrupts present
return bool(state.next) or bool(state.interrupts)
async def _create_runtime_result(
self,
graph_config: RunnableConfig,
graph_output: Any,
) -> UiPathRuntimeResult:
"""
Get final graph state and create the execution result.
Args:
graph_config: The graph execution configuration
graph_output: The graph execution output
"""
# Get the final state
graph_state = await self._get_graph_state(graph_config)
# Check if execution was interrupted (static or dynamic)
if graph_state and self._is_interrupted(graph_state):
return await self._create_suspended_result(graph_state)
else:
# Normal completion
return self._create_success_result(graph_output)
    async def _create_suspended_result(
        self,
        graph_state: StateSnapshot,
    ) -> UiPathRuntimeResult:
        """Create result for suspended execution.

        Builds a map of unresolved dynamic interrupts (interrupt id -> value)
        and returns a SUSPENDED result carrying it; when there are no dynamic
        interrupts the pause must be a static breakpoint instead.
        """
        interrupt_map: dict[str, Any] = {}
        if graph_state.interrupts:
            for interrupt in graph_state.interrupts:
                if isinstance(interrupt, Interrupt):
                    # Find which task this interrupt belongs to
                    for task in graph_state.tasks:
                        if task.interrupts and interrupt in task.interrupts:
                            # Only include if this task is still waiting for interrupt resolution
                            if task.interrupts and not task.result:
                                interrupt_map[interrupt.id] = interrupt.value
                            # Owning task found — stop scanning either way.
                            break
        # If we have dynamic interrupts, return suspended with interrupt map
        # The output is used to create the resume triggers
        if interrupt_map:
            return UiPathRuntimeResult(
                output=interrupt_map,
                status=UiPathRuntimeStatus.SUSPENDED,
            )
        else:
            # Static interrupt (breakpoint)
            return self._create_breakpoint_result(graph_state)
def _create_breakpoint_result(
self,
graph_state: StateSnapshot,
) -> UiPathBreakpointResult:
"""Create result for execution paused at a breakpoint."""
# Get next nodes - these are the nodes that will execute when resumed
next_nodes = list(graph_state.next)
# Determine breakpoint type and node
if next_nodes:
# Breakpoint is BEFORE these nodes (interrupt_before)
breakpoint_type = "before"
breakpoint_node = ", ".join(next_nodes)
else:
# Breakpoint is AFTER the last executed node (interrupt_after)
# Get the last executed node from tasks
breakpoint_type = "after"
if graph_state.tasks:
# Tasks contain the nodes that just executed
# Get the last task's name
breakpoint_node = graph_state.tasks[-1].name
else:
# Fallback if no tasks (shouldn't happen)
breakpoint_node = "unknown"
return UiPathBreakpointResult(
breakpoint_node=breakpoint_node,
breakpoint_type=breakpoint_type,
current_state=serialize_output(graph_state.values) or {},
next_nodes=next_nodes,
)
def _create_success_result(self, output: Any) -> UiPathRuntimeResult:
"""Create result for successful completion."""
return UiPathRuntimeResult(
output=serialize_output(output) or {},
status=UiPathRuntimeStatus.SUCCESSFUL,
)
def _detect_middleware_nodes(self) -> set[str]:
"""
Detect middleware nodes by their naming pattern.
Middleware nodes always contain both:
1. "Middleware" in the name (by convention)
2. A dot "." separator (MiddlewareName.hook_name)
Returns:
Set of middleware node names
"""
middleware_nodes: set[str] = set()
for node_name in self.graph.nodes.keys():
if "." in node_name and "Middleware" in node_name:
middleware_nodes.add(node_name)
return middleware_nodes
def _get_tool_confirmation_info(self) -> dict[str, Any]:
"""Build {tool_name: input_schema} for tools requiring confirmation.
Walks compiled graph nodes once at runtime init. This is needed because coded agents
(create_agent) export a compiled graph as the only artifact — there's no side channel
to pass confirmation metadata from the build step to the runtime.
"""
schemas: dict[str, Any] = {}
for node_spec in self.graph.nodes.values():
bound = getattr(node_spec, "bound", None)
if bound is None:
continue
# Coded agents: one tool per node
if isinstance(bound, RunnableCallableWithTool):
schema = get_confirmation_schema(bound.tool)
if schema is not None:
schemas[bound.tool.name] = schema
continue
# Low-code agents: multiple tools in one node
tools_by_name = getattr(bound, "tools_by_name", None)
if isinstance(tools_by_name, dict):
for tool in tools_by_name.values():
if not isinstance(tool, BaseTool):
continue
schema = get_confirmation_schema(tool)
if schema is not None:
schemas[tool.name] = schema
return schemas
def _is_middleware_node(self, node_name: str) -> bool:
"""Check if a node name represents a middleware node."""
return node_name in self._middleware_node_names
def _build_node_name(self, namespace: Any, node_name: str) -> str:
"""Build a fully qualified node name with subgraph prefix from the namespace.
When streaming with ``subgraphs=True``, LangGraph provides a namespace
tuple that identifies the subgraph hierarchy a node belongs to. This
method extracts the subgraph names and prepends them to the node name.
Args:
namespace: A tuple representing the subgraph hierarchy.
- () for the root graph.
- ("subgraph_name:node_id",) for a single-level subgraph.
- ("subgraph_name:node_id", "nested:node_id") for nested subgraphs.
node_name: The name of the node within its graph.
Returns:
The fully qualified node name. For example:
- "agent" when called from the root graph.
- "coder:generate" when called from the *coder* subgraph.
- "coder:debugger:analyze" when called from *debugger* nested inside *coder*.
"""
if not namespace:
return node_name
if not isinstance(namespace, (tuple, list)):
return node_name
parts = []
for ns in namespace:
if not isinstance(ns, str):
continue
if not ns:
continue
# Extract subgraph name (part before ':'), fall back to full string
part = ns.split(":")[0] if ":" in ns else ns
if part:
parts.append(part)
if not parts:
return node_name
prefix = ":".join(parts)
return f"{prefix}:{node_name}"
async def dispose(self) -> None:
"""Cleanup runtime resources."""
pass