Skip to content

Commit fda666b

Browse files
committed
Use reference counting for storing inherited run trees to support garbage collection
1 parent 9f232ca commit fda666b

File tree

5 files changed

+78
-1
lines changed

5 files changed

+78
-1
lines changed

libs/core/langchain_core/callbacks/manager.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2480,7 +2480,12 @@ def _configure(
24802480
run_tree.trace_id,
24812481
run_tree.dotted_order,
24822482
)
2483-
handler.run_map[str(run_tree.id)] = run_tree
2483+
run_id_str = str(run_tree.id)
2484+
if run_id_str not in handler.run_map:
2485+
handler.run_map[run_id_str] = run_tree
2486+
handler._external_run_ids.setdefault( # noqa: SLF001
2487+
run_id_str, 0
2488+
)
24842489
for var, inheritable, handler_class, env_var in _configure_hooks:
24852490
create_one = (
24862491
env_var is not None

libs/core/langchain_core/tracers/base.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,15 @@ def _end_trace(self, run: Run) -> None:
4747
if not run.parent_run_id:
4848
self._persist_run(run)
4949
self.run_map.pop(str(run.id))
50+
# If this run's parent was injected from an external tracing context
51+
# (e.g. a langsmith @traceable), decrement its child refcount and
52+
# remove it from run_map once the last child is done.
53+
parent_id = str(run.parent_run_id) if run.parent_run_id else None
54+
if parent_id and parent_id in self._external_run_ids:
55+
self._external_run_ids[parent_id] -= 1
56+
if self._external_run_ids[parent_id] <= 0:
57+
self.run_map.pop(parent_id, None)
58+
del self._external_run_ids[parent_id]
5059
self._on_run_update(run)
5160

5261
def on_chat_model_start(
@@ -568,6 +577,15 @@ async def _end_trace(self, run: Run) -> None:
568577
if not run.parent_run_id:
569578
await self._persist_run(run)
570579
self.run_map.pop(str(run.id))
580+
# If this run's parent was injected from an external tracing context
581+
# (e.g. a langsmith @traceable), decrement its child refcount and
582+
# remove it from run_map once the last child is done.
583+
parent_id = str(run.parent_run_id) if run.parent_run_id else None
584+
if parent_id and parent_id in self._external_run_ids:
585+
self._external_run_ids[parent_id] -= 1
586+
if self._external_run_ids[parent_id] <= 0:
587+
self.run_map.pop(parent_id, None)
588+
del self._external_run_ids[parent_id]
571589
await self._on_run_update(run)
572590

573591
@override

libs/core/langchain_core/tracers/core.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def __init__(
5353
] = "original",
5454
run_map: dict[str, Run] | None = None,
5555
order_map: dict[UUID, tuple[UUID, str]] | None = None,
56+
_external_run_ids: dict[str, int] | None = None,
5657
**kwargs: Any,
5758
) -> None:
5859
"""Initialize the tracer.
@@ -74,6 +75,7 @@ def __init__(
7475
it does NOT raise an attribute error `on_chat_model_start`
7576
run_map: Optional shared map of run ID to run.
7677
order_map: Optional shared map of run ID to trace ordering data.
78+
_external_run_ids: Optional shared set of externally injected run IDs.
7779
**kwargs: Additional keyword arguments that will be passed to the
7880
superclass.
7981
"""
@@ -87,6 +89,16 @@ def __init__(
8789
self.order_map = order_map if order_map is not None else {}
8890
"""Map of run ID to (trace_id, dotted_order). Cleared when tracer GCed."""
8991

92+
self._external_run_ids: dict[str, int] = (
93+
_external_run_ids if _external_run_ids is not None else {}
94+
)
95+
"""Refcount of active children per externally-injected run ID.
96+
97+
These runs are added to `run_map` so child runs can find their parent,
98+
but they are not managed by the tracer's callback lifecycle. When
99+
the last child finishes the entry is evicted to avoid memory leaks.
100+
"""
101+
90102
@abstractmethod
91103
def _persist_run(self, run: Run) -> Coroutine[Any, Any, None] | None:
92104
"""Persist a run."""
@@ -117,6 +129,9 @@ def _start_trace(self, run: Run) -> Coroutine[Any, Any, None] | None: # type: i
117129
run.dotted_order += "." + current_dotted_order
118130
if parent_run := self.run_map.get(str(run.parent_run_id)):
119131
self._add_child_run(parent_run, run)
132+
parent_key = str(run.parent_run_id)
133+
if parent_key in self._external_run_ids:
134+
self._external_run_ids[parent_key] += 1
120135
else:
121136
if self.log_missing_parent:
122137
logger.debug(

libs/core/langchain_core/tracers/langchain.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ def copy_with_metadata_defaults(
189189
metadata=merged_metadata,
190190
run_map=self.run_map,
191191
order_map=self.order_map,
192+
_external_run_ids=self._external_run_ids,
192193
)
193194

194195
def _start_trace(self, run: Run) -> None:

libs/core/tests/unit_tests/runnables/test_tracing_interops.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,44 @@ def parent(_: Any) -> str:
555555
assert kitten_run.dotted_order.startswith(grandchild_run.dotted_order)
556556

557557

558+
def test_traceable_parent_run_map_cleanup() -> None:
559+
"""External RunTree injected into run_map is cleaned up when its child ends.
560+
561+
When a `@traceable` function invokes a LangChain `Runnable`, the
562+
`RunTree` is added to the tracer's `run_map` so child runs can
563+
reference it. Previously the entry was never removed, causing a
564+
memory leak that grew with every call.
565+
"""
566+
mock_session = MagicMock()
567+
mock_client_ = Client(
568+
session=mock_session, api_key="test", auto_batch_tracing=False
569+
)
570+
571+
@RunnableLambda
572+
def child(x: str) -> str:
573+
return x
574+
575+
with tracing_context(client=mock_client_, enabled=True):
576+
577+
@traceable
578+
def parent(x: str) -> str:
579+
return child.invoke(x)
580+
581+
parent("hello")
582+
583+
# All LangChainTracer instances created during the call should have an
584+
# empty run_map after the call completes.
585+
import gc # noqa: PLC0415
586+
587+
gc.collect()
588+
tracers = [o for o in gc.get_objects() if isinstance(o, LangChainTracer)]
589+
for tracer in tracers:
590+
assert tracer.run_map == {}, (
591+
f"run_map should be empty but contains: "
592+
f"{[getattr(v, 'name', k) for k, v in tracer.run_map.items()]}"
593+
)
594+
595+
558596
class TestTracerMetadataThroughInvoke:
559597
"""Tests for tracer metadata merging through invoke calls."""
560598

0 commit comments

Comments
 (0)