diff --git a/libs/langchain_v1/langchain/agents/middleware/shell_tool.py b/libs/langchain_v1/langchain/agents/middleware/shell_tool.py index 93fd978ca4fa0..b59df79a83621 100644 --- a/libs/langchain_v1/langchain/agents/middleware/shell_tool.py +++ b/libs/langchain_v1/langchain/agents/middleware/shell_tool.py @@ -403,11 +403,18 @@ def _kill_process(self) -> None: return if hasattr(os, "killpg"): - with contextlib.suppress(ProcessLookupError): - os.killpg(os.getpgid(self._process.pid), signal.SIGKILL) - else: # pragma: no cover - with contextlib.suppress(ProcessLookupError): - self._process.kill() + try: + child_pgid = os.getpgid(self._process.pid) + # Only send a group kill when the child has a dedicated process group. + # If the child shares our group, killpg would terminate the caller too. + if child_pgid != os.getpgrp(): + os.killpg(child_pgid, signal.SIGKILL) + return + except ProcessLookupError: + return + + with contextlib.suppress(ProcessLookupError): + self._process.kill() def _enqueue_stream(self, stream: Any, label: str) -> None: for line in iter(stream.readline, ""): diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_shell_tool.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_shell_tool.py index 776b01a6e7a06..dcefe24ce092a 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_shell_tool.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_shell_tool.py @@ -1,10 +1,13 @@ from __future__ import annotations import gc +import os +import signal import tempfile import time from pathlib import Path from typing import cast +from unittest.mock import Mock import pytest from langchain_core.messages import ToolMessage @@ -14,6 +17,7 @@ from langchain.agents.middleware.shell_tool import ( HostExecutionPolicy, RedactionRule, + ShellSession, ShellToolMiddleware, ShellToolState, _SessionResources, @@ -546,3 +550,43 @@ def test_get_or_create_resources_reuses_existing(tmp_path: Path) -> None: # Clean up resources1.finalizer() + + +def test_kill_process_avoids_group_kill_for_shared_process_group( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + """Avoid `killpg` when child shares the caller process group.""" + session = ShellSession(tmp_path, HostExecutionPolicy(), ("/bin/bash",), {}) + process = Mock() + process.pid = 1234 + session._process = process # type: ignore[assignment] + + killpg_mock = Mock() + monkeypatch.setattr(os, "killpg", killpg_mock, raising=False) + monkeypatch.setattr(os, "getpgid", lambda _pid: 1000) + monkeypatch.setattr(os, "getpgrp", lambda: 1000) + + session._kill_process() + + killpg_mock.assert_not_called() + process.kill.assert_called_once_with() + + +def test_kill_process_uses_group_kill_for_dedicated_process_group( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + """Keep process-group kill behavior when child runs in a separate group.""" + session = ShellSession(tmp_path, HostExecutionPolicy(), ("/bin/bash",), {}) + process = Mock() + process.pid = 5678 + session._process = process # type: ignore[assignment] + + killpg_mock = Mock() + monkeypatch.setattr(os, "killpg", killpg_mock, raising=False) + monkeypatch.setattr(os, "getpgid", lambda _pid: 2000) + monkeypatch.setattr(os, "getpgrp", lambda: 1000) + + session._kill_process() + + killpg_mock.assert_called_once_with(2000, signal.SIGKILL) + process.kill.assert_not_called()