From 59d76c8d8fa849dad13d6db22a32313226a2c784 Mon Sep 17 00:00:00 2001 From: nu1lx Date: Mon, 13 Apr 2026 08:58:34 +0000 Subject: [PATCH] fix(langchain): handle done markers after stdout without trailing newline --- .../langchain/agents/middleware/shell_tool.py | 59 +++++++++++-------- .../implementations/test_shell_tool.py | 55 +++++++++++++++++ 2 files changed, 89 insertions(+), 25 deletions(-) diff --git a/libs/langchain_v1/langchain/agents/middleware/shell_tool.py b/libs/langchain_v1/langchain/agents/middleware/shell_tool.py index 93fd978ca4fa0..ea009cae71de6 100644 --- a/libs/langchain_v1/langchain/agents/middleware/shell_tool.py +++ b/libs/langchain_v1/langchain/agents/middleware/shell_tool.py @@ -237,6 +237,14 @@ def execute(self, command: str, *, timeout: float) -> CommandExecutionResult: return self._collect_output(marker, deadline, timeout) + @staticmethod + def _split_stdout_marker(data: str, marker: str) -> tuple[str, str | None]: + """Split real stdout from an appended completion marker.""" + marker_index = data.find(marker) + if marker_index == -1: + return data, None + return data[:marker_index], data[marker_index:] + def _collect_output( self, marker: str, @@ -265,8 +273,32 @@ def _collect_output( if data is None: continue - if source == "stdout" and data.startswith(marker): - _, _, status = data.partition(" ") + marker_data: str | None = None + if source == "stdout": + data, marker_data = self._split_stdout_marker(data, marker) + + if data: + total_lines += 1 + encoded = data.encode("utf-8", "replace") + total_bytes += len(encoded) + + if total_lines > self._policy.max_output_lines: + truncated_by_lines = True + elif ( + self._policy.max_output_bytes is not None + and total_bytes > self._policy.max_output_bytes + ): + truncated_by_bytes = True + elif source == "stderr": + stripped = data.rstrip("\n") + collected.append(f"[stderr] {stripped}") + if data.endswith("\n"): + collected.append("\n") + else: + collected.append(data) + + if marker_data is not None: + _, _, status = marker_data.partition(" ") exit_code = self._safe_int(status.strip()) # Drain any remaining stderr that may have arrived concurrently. # The stderr reader thread runs independently, so output might @@ -274,29 +306,6 @@ def _collect_output( self._drain_remaining_stderr(collected, deadline) break - total_lines += 1 - encoded = data.encode("utf-8", "replace") - total_bytes += len(encoded) - - if total_lines > self._policy.max_output_lines: - truncated_by_lines = True - continue - - if ( - self._policy.max_output_bytes is not None - and total_bytes > self._policy.max_output_bytes - ): - truncated_by_bytes = True - continue - - if source == "stderr": - stripped = data.rstrip("\n") - collected.append(f"[stderr] {stripped}") - if data.endswith("\n"): - collected.append("\n") - else: - collected.append(data) - if timed_out: LOGGER.warning( "Command timed out after %.2f seconds; restarting shell session.", diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_shell_tool.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_shell_tool.py index 776b01a6e7a06..37534344875ca 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_shell_tool.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_shell_tool.py @@ -394,6 +394,61 @@ def test_empty_output_replaced_with_no_output(tmp_path: Path) -> None: middleware.after_agent(state, runtime) +def test_stdout_without_trailing_newline_does_not_timeout(tmp_path: Path) -> None: + """Test stdout is preserved when the done marker shares the same line.""" + policy = HostExecutionPolicy(command_timeout=1.0) + middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace", execution_policy=policy) + runtime = Runtime() + state = _empty_state() + try: + updates = middleware.before_agent(state, runtime) + if updates: + state.update(cast("ShellToolState", updates)) + resources = middleware._get_or_create_resources(state) + + result = middleware._run_shell_tool( + resources, + {"command": "printf 'hello without newline'"}, + tool_call_id="test-id", + ) + + assert isinstance(result, ToolMessage) + assert result.status == "success" + assert result.content == "hello without newline" + assert result.artifact["timed_out"] is False + assert result.artifact["exit_code"] == 0 + finally: + middleware.after_agent(state, runtime) + + +def test_truncated_stdout_without_trailing_newline_does_not_timeout(tmp_path: Path) -> None: + """Test truncation does not prevent marker detection on the same line.""" + policy = HostExecutionPolicy(command_timeout=1.0, max_output_bytes=5) + middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace", execution_policy=policy) + runtime = Runtime() + state = _empty_state() + try: + updates = middleware.before_agent(state, runtime) + if updates: + state.update(cast("ShellToolState", updates)) + resources = middleware._get_or_create_resources(state) + + result = middleware._run_shell_tool( + resources, + {"command": "printf 'hello without newline'"}, + tool_call_id="test-id", + ) + + assert isinstance(result, ToolMessage) + assert result.status == "success" + assert "truncated at 5 bytes" in result.content.lower() + assert result.artifact["timed_out"] is False + assert result.artifact["exit_code"] == 0 + assert result.artifact["truncated_by_bytes"] is True + finally: + middleware.after_agent(state, runtime) + + def test_stderr_output_labeling(tmp_path: Path) -> None: """Test that stderr output is properly labeled.""" middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace")