Skip to content

Commit 7188224

Browse files
committed
fix: register sdk interrupt models with msgpack allowlist (#1500)
1 parent a3c7f53 commit 7188224

5 files changed

Lines changed: 93 additions & 2 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "uipath-langchain"
3-
version = "0.10.10"
3+
version = "0.10.11"
44
description = "Python SDK that enables developers to build and deploy LangGraph agents to the UiPath Cloud Platform"
55
readme = { file = "README.md", content-type = "text/markdown" }
66
requires-python = ">=3.11"

src/uipath_langchain/runtime/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
UiPathRuntimeFactoryRegistry,
55
)
66

7+
from uipath_langchain.runtime._msgpack_registry import (
8+
register_uipath_interrupt_models_with_msgpack,
9+
)
710
from uipath_langchain.runtime.factory import UiPathLangGraphRuntimeFactory
811
from uipath_langchain.runtime.runtime import UiPathLangGraphRuntime
912
from uipath_langchain.runtime.schema import (
@@ -14,6 +17,7 @@
1417

1518
def register_runtime_factory() -> None:
1619
"""Register the LangGraph factory. Called automatically via entry point."""
20+
register_uipath_interrupt_models_with_msgpack()
1721

1822
def create_factory(
1923
context: UiPathRuntimeContext | None = None,
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
"""Add UiPath SDK interrupt models to langgraph's msgpack safe-types set."""
2+
3+
from __future__ import annotations
4+
5+
import inspect
6+
7+
8+
def register_uipath_interrupt_models_with_msgpack() -> None:
9+
"""Allowlist every class in `uipath.platform.common.interrupt_models`.
10+
11+
Without this, langgraph emits an `unregistered type` warning whenever a
12+
checkpoint containing an SDK interrupt model (e.g. `CreateTask`) is loaded.
13+
"""
14+
from langgraph.checkpoint.serde import _msgpack as _lg_msgpack
15+
from uipath.platform.common import interrupt_models
16+
17+
extras = {
18+
(cls.__module__, cls.__name__)
19+
for _, cls in inspect.getmembers(interrupt_models, inspect.isclass)
20+
if cls.__module__ == interrupt_models.__name__
21+
}
22+
_lg_msgpack.SAFE_MSGPACK_TYPES = frozenset(_lg_msgpack.SAFE_MSGPACK_TYPES | extras)
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
"""Verify SDK interrupt models are registered with langgraph's msgpack allowlist."""
2+
3+
from __future__ import annotations
4+
5+
import logging
6+
7+
from langgraph.checkpoint.serde import _msgpack as _lg_msgpack
8+
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
9+
from uipath.platform.common.interrupt_models import CreateTask
10+
11+
from uipath_langchain.runtime._msgpack_registry import (
12+
register_uipath_interrupt_models_with_msgpack,
13+
)
14+
15+
16+
def test_create_task_is_in_safe_msgpack_types() -> None:
17+
register_uipath_interrupt_models_with_msgpack()
18+
19+
assert (
20+
"uipath.platform.common.interrupt_models",
21+
"CreateTask",
22+
) in _lg_msgpack.SAFE_MSGPACK_TYPES
23+
24+
25+
def test_all_interrupt_models_registered() -> None:
26+
"""Every public class in interrupt_models should be auto-registered."""
27+
register_uipath_interrupt_models_with_msgpack()
28+
29+
from uipath.platform.common import interrupt_models
30+
31+
for name in [
32+
"CreateTask",
33+
"CreateEscalation",
34+
"WaitTask",
35+
"InvokeProcess",
36+
"WaitJob",
37+
"CreateBatchTransform",
38+
]:
39+
assert hasattr(interrupt_models, name)
40+
assert (
41+
interrupt_models.__name__,
42+
name,
43+
) in _lg_msgpack.SAFE_MSGPACK_TYPES
44+
45+
46+
def test_round_trip_create_task_emits_no_warning(
47+
caplog: logging.LogRecord,
48+
) -> None:
49+
"""Deserializing a CreateTask checkpoint must not warn about unregistered types."""
50+
register_uipath_interrupt_models_with_msgpack()
51+
52+
serde = JsonPlusSerializer()
53+
task = CreateTask(title="hello", data={})
54+
type_, payload = serde.dumps_typed(task)
55+
56+
with caplog.at_level( # type: ignore[attr-defined]
57+
logging.WARNING, logger="langgraph.checkpoint.serde.jsonplus"
58+
):
59+
restored = serde.loads_typed((type_, payload))
60+
61+
assert isinstance(restored, CreateTask)
62+
assert all(
63+
"Deserializing unregistered type" not in record.message
64+
for record in caplog.records # type: ignore[attr-defined]
65+
)

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)