Skip to content
164 changes: 135 additions & 29 deletions libs/core/langchain_core/runnables/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@

import ast
import asyncio
import dis
import inspect
import sys
import textwrap

# Cannot move to TYPE_CHECKING as Mapping and Sequence are needed at runtime by
# RunnableConfigurableFields.
from collections.abc import Mapping, Sequence # noqa: TC003
from contextlib import suppress
from functools import lru_cache
from inspect import signature
from itertools import groupby
Expand Down Expand Up @@ -38,6 +40,7 @@
Iterable,
)
from contextvars import Context
from types import CodeType

from langchain_core.runnables.schema import StreamEvent

Expand Down Expand Up @@ -404,6 +407,63 @@ def get_lambda_source(func: Callable) -> str | None:
return visitor.source if visitor.count == 1 else name


@lru_cache(maxsize=256)
def _nonlocal_access_plan(
code: CodeType,
) -> tuple[tuple[str, ...], tuple[tuple[str, tuple[str, ...]], ...]]:
"""Compute a nonlocal access plan from bytecode.

Args:
code: Code object to scan.

Returns:
A tuple ``(plain_roots, chains)``, where:
- plain_roots: Names loaded without any attribute access.
- chains: ``(root, path)`` pairs for attribute-access chains.
"""
root_ops = {"LOAD_GLOBAL", "LOAD_DEREF", "LOAD_NAME"}
attr_ops = {"LOAD_ATTR", "LOAD_METHOD"}

plain_roots: set[str] = set()
chains: list[tuple[str, tuple[str, ...]]] = []

base: str | None = None
attrs: list[str] = []

def flush() -> None:
nonlocal base, attrs
if base is not None:
if attrs:
chains.append((base, tuple(attrs)))
else:
plain_roots.add(base)
base = None
attrs = []

for ins in dis.get_instructions(code):
op = ins.opname
if op in root_ops and isinstance(ins.argval, str):
flush()
base = ins.argval
continue
if op in attr_ops and isinstance(ins.argval, str):
if base is not None:
attrs.append(ins.argval)
continue
flush()

flush()

deduped = []
seen = set()
for c in chains:
if c not in seen:
seen.add(c)
deduped.append(c)

return tuple(sorted(plain_roots)), tuple(deduped)


@lru_cache(maxsize=256)
def get_function_nonlocals(func: Callable) -> list[Any]:
    """Get the nonlocal variables accessed by a function.

    The lookup is driven by the function's bytecode (via
    ``_nonlocal_access_plan``) rather than by its source text, so
    ``inspect.getsource`` is not called.

    Two kinds of values are reported:

    - every global or closure variable that is loaded without a following
      attribute access, and
    - the final attribute of a chain such as ``box.agent.invoke`` when the
      chain ends in one of the method names listed in ``terminal_methods``.

    Names that cannot be resolved, or that resolve to ``None``, are skipped.
    Results are memoized with ``lru_cache``: ``func`` must be hashable and the
    returned list is shared between calls, so callers should not mutate it.

    Args:
        func: The function to check.

    Returns:
        The nonlocal variables accessed by the function.
    """
    # Only attribute chains that end in one of these method names are resolved.
    terminal_methods = {
        "invoke",
        "ainvoke",
        "batch",
        "abatch",
        "stream",
        "astream",
        "transform",
        "atransform",
    }

    # Follow ``__wrapped__`` links down to the innermost callable. Visited ids
    # are remembered so that a cyclic chain cannot loop forever.
    target = func
    seen_wrapped: set[int] = set()
    while True:
        w = getattr(target, "__wrapped__", None)
        if not callable(w):
            break
        wid = id(w)
        if wid in seen_wrapped:
            break
        seen_wrapped.add(wid)
        target = w

    # No code object on the callable itself: fall back to ``__func__`` when
    # the object provides one.
    if getattr(target, "__code__", None) is None:
        target = getattr(target, "__func__", target)

    code = getattr(target, "__code__", None)
    if code is None:
        return []

    # Map each free variable name to the current contents of its closure cell.
    nonlocals_dict: dict[str, Any] = {}
    freevars = code.co_freevars
    closure = getattr(target, "__closure__", None)
    if closure and freevars:
        for name, cell in zip(freevars, closure, strict=False):
            # Reading an empty cell raises ValueError; such names stay
            # unresolved.
            with suppress(ValueError):
                nonlocals_dict[name] = cell.cell_contents

    globals_dict = getattr(target, "__globals__", {})

    plain_roots, chains = _nonlocal_access_plan(code)

    out: list[Any] = []
    seen_ids: set[int] = set()

    def add(v: Any) -> None:
        # De-duplicate by identity; the values themselves may be unhashable.
        vid = id(v)
        if vid not in seen_ids:
            seen_ids.add(vid)
            out.append(v)

    def resolve_root(name: str) -> Any | None:
        # Closure variables take precedence over module globals.
        if name in nonlocals_dict:
            return nonlocals_dict[name]
        return globals_dict.get(name)

    for name in plain_roots:
        v = resolve_root(name)
        if v is not None:
            add(v)

    for base, attrs in chains:
        if not attrs or attrs[-1] not in terminal_methods:
            continue
        v = resolve_root(base)
        if v is None:
            continue
        for a in attrs:
            # Attribute access may run arbitrary code (properties,
            # ``__getattr__``); any failure abandons this chain.
            try:
                v = getattr(v, a)
            except Exception:
                break
        else:
            if v is not None:
                add(v)

    return out


def indent_lines_after_first(text: str, prefix: str) -> str:
Expand Down
35 changes: 34 additions & 1 deletion libs/core/tests/unit_tests/runnables/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import inspect
from collections.abc import Callable
from typing import Any
from typing import Any, NoReturn

import pytest

Expand Down Expand Up @@ -73,3 +74,35 @@ def my_func6(value: str) -> str:
assert RunnableLambda(my_func3).deps == [agent]
assert RunnableLambda(my_func4).deps == [global_agent]
assert RunnableLambda(func).deps == [nl]


def test_deps_does_not_call_inspect_getsource() -> None:
    """Check that computing ``deps`` never calls ``inspect.getsource``."""
    error_message = "inspect.getsource was called while computing deps"

    def explode(*_args: Any, **_kwargs: Any) -> NoReturn:
        raise AssertionError(error_message)

    class Box:
        def __init__(self, a: RunnableLambda[str, str]) -> None:
            self.agent: RunnableLambda[str, str] = a

    # Swap in the exploding stand-in and always restore the real function,
    # even if the body fails.
    real_getsource = inspect.getsource
    inspect.getsource = explode
    try:
        inner: RunnableLambda[str, str] = RunnableLambda(lambda x: x)
        box = Box(inner)

        def my_func(x: str) -> str:
            return box.agent.invoke(x)

        runnable: RunnableLambda[str, str] = RunnableLambda(my_func)
        # Accessing the attribute is what triggers the dependency scan.
        _ = runnable.deps
    finally:
        inspect.getsource = real_getsource


def test_deps_is_cached_on_instance() -> None:
    """After the first access, ``deps`` must appear in the instance dict."""
    runnable = RunnableLambda(lambda x: x)
    _ = runnable.deps
    instance_attrs = runnable.__dict__
    assert "deps" in instance_attrs
Loading