From a5d236eb7846c18d1e12c95a9826f88ae5539700 Mon Sep 17 00:00:00 2001 From: Giulio Leone <6887247+giulio-leone@users.noreply.github.com> Date: Mon, 9 Mar 2026 11:32:57 +0100 Subject: [PATCH 1/6] fix(core): prevent output corruption in RunnableRetry.batch when partial retries succeed MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit After retries, the final assembly used result.pop(0) to fill positions not in results_map. But result still contained successfully-retried values alongside exceptions, so the pop consumed the wrong elements — replacing exceptions with stale success values. Replace the pop-based assembly with an index-mapped lookup using last_remaining_indices so each original position maps to its correct result from the last retry batch. Fixes langchain-ai#35475 --- libs/core/langchain_core/runnables/retry.py | 16 ++++- .../unit_tests/runnables/test_runnable.py | 71 +++++++++++++++++++ 2 files changed, 85 insertions(+), 2 deletions(-) diff --git a/libs/core/langchain_core/runnables/retry.py b/libs/core/langchain_core/runnables/retry.py index 6b9f5fef2de7d..72cf0c585d78c 100644 --- a/libs/core/langchain_core/runnables/retry.py +++ b/libs/core/langchain_core/runnables/retry.py @@ -235,6 +235,7 @@ def _batch( not_set: list[Output] = [] result = not_set + last_remaining_indices: list[int] = list(range(len(inputs))) try: for attempt in self._sync_retrying(): with attempt: @@ -245,6 +246,7 @@ def _batch( ] if not remaining_indices: break + last_remaining_indices = remaining_indices pending_inputs = [inputs[i] for i in remaining_indices] pending_configs = [config[i] for i in remaining_indices] pending_run_managers = [run_manager[i] for i in remaining_indices] @@ -279,12 +281,16 @@ def _batch( if result is not_set: result = cast("list[Output]", [e] * len(inputs)) + # Map last retry results back to original indices so that + # successfully-retried values don't overwrite unrelated positions. + last_result_map = dict(zip(last_remaining_indices, result)) + outputs: list[Output | Exception] = [] for idx in range(len(inputs)): if idx in results_map: outputs.append(results_map[idx]) else: - outputs.append(result.pop(0)) + outputs.append(last_result_map[idx]) return outputs @override @@ -311,6 +317,7 @@ async def _abatch( not_set: list[Output] = [] result = not_set + last_remaining_indices: list[int] = list(range(len(inputs))) try: async for attempt in self._async_retrying(): with attempt: @@ -321,6 +328,7 @@ async def _abatch( ] if not remaining_indices: break + last_remaining_indices = remaining_indices pending_inputs = [inputs[i] for i in remaining_indices] pending_configs = [config[i] for i in remaining_indices] pending_run_managers = [run_manager[i] for i in remaining_indices] @@ -354,12 +362,16 @@ async def _abatch( if result is not_set: result = cast("list[Output]", [e] * len(inputs)) + # Map last retry results back to original indices so that + # successfully-retried values don't overwrite unrelated positions. + last_result_map = dict(zip(last_remaining_indices, result)) + outputs: list[Output | Exception] = [] for idx in range(len(inputs)): if idx in results_map: outputs.append(results_map[idx]) else: - outputs.append(result.pop(0)) + outputs.append(last_result_map[idx]) return outputs @override diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index b6848c2cd85b3..50e87065caf8e 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -3979,6 +3979,77 @@ def sometimes_fail(x: int) -> int: # pragma: no cover - trivial assert results == [0, 1, 2] +def test_retry_batch_no_corruption_on_partial_retry() -> None: + """Regression: items that still fail after retries must stay as exceptions. + + When one item succeeds on retry while another permanently fails, the stale + index mapping caused the permanent failure to be replaced by the retried + success value (GH-35475). + """ + failed_once = False + + def process(name: str) -> str: + nonlocal failed_once + if name == "ok": + return "ok-result" + if name == "retry_then_ok": + if not failed_once: + failed_once = True + msg = "transient" + raise ValueError(msg) + return "retry-result" + msg = "permanent" + raise ValueError(msg) + + runnable = RunnableLambda(process).with_retry( + stop_after_attempt=2, + wait_exponential_jitter=False, + retry_if_exception_type=(ValueError,), + ) + + result = runnable.batch( + ["ok", "retry_then_ok", "always_fail"], + return_exceptions=True, + ) + + assert result[0] == "ok-result" + assert result[1] == "retry-result" + assert isinstance(result[2], Exception) + + +async def test_async_retry_batch_no_corruption_on_partial_retry() -> None: + """Async variant of the partial-retry corruption regression test.""" + failed_once = False + + def process(name: str) -> str: + nonlocal failed_once + if name == "ok": + return "ok-result" + if name == "retry_then_ok": + if not failed_once: + failed_once = True + msg = "transient" + raise ValueError(msg) + return "retry-result" + msg = "permanent" + raise ValueError(msg) + + runnable = RunnableLambda(process).with_retry( + stop_after_attempt=2, + wait_exponential_jitter=False, + retry_if_exception_type=(ValueError,), + ) + + result = await runnable.abatch( + ["ok", "retry_then_ok", "always_fail"], + return_exceptions=True, + ) + + assert result[0] == "ok-result" + assert result[1] == "retry-result" + assert isinstance(result[2], Exception) + + async def test_async_retrying(mocker: MockerFixture) -> None: def _lambda(x: int) -> int: if x == 1: From 8a4472dee74a02eb9d622ff08a2bbc32d1b68491 Mon Sep 17 00:00:00 2001 From: Giulio Leone <6887247+giulio-leone@users.noreply.github.com> Date: Mon, 9 Mar 2026 15:02:31 +0100 Subject: [PATCH 2/6] fix: add strict=True to zip() calls to satisfy ruff B905 --- libs/core/langchain_core/runnables/retry.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/core/langchain_core/runnables/retry.py b/libs/core/langchain_core/runnables/retry.py index 72cf0c585d78c..26a173b453061 100644 --- a/libs/core/langchain_core/runnables/retry.py +++ b/libs/core/langchain_core/runnables/retry.py @@ -283,7 +283,7 @@ def _batch( # Map last retry results back to original indices so that # successfully-retried values don't overwrite unrelated positions. - last_result_map = dict(zip(last_remaining_indices, result)) + last_result_map = dict(zip(last_remaining_indices, result, strict=True)) outputs: list[Output | Exception] = [] for idx in range(len(inputs)): @@ -364,7 +364,7 @@ async def _abatch( # Map last retry results back to original indices so that # successfully-retried values don't overwrite unrelated positions. - last_result_map = dict(zip(last_remaining_indices, result)) + last_result_map = dict(zip(last_remaining_indices, result, strict=True)) outputs: list[Output | Exception] = [] for idx in range(len(inputs)): From 562c2bdec7d46928dcedfa148f91c9f2afb96b7a Mon Sep 17 00:00:00 2001 From: giulio-leone Date: Wed, 11 Mar 2026 16:50:10 +0100 Subject: [PATCH 3/6] docs(core): clarify retry result mapping invariant Clarify why last_result_map is only consulted for indices that still need fallback values after retries. --- libs/core/langchain_core/runnables/retry.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/libs/core/langchain_core/runnables/retry.py b/libs/core/langchain_core/runnables/retry.py index 26a173b453061..ea5df6db34de0 100644 --- a/libs/core/langchain_core/runnables/retry.py +++ b/libs/core/langchain_core/runnables/retry.py @@ -283,6 +283,9 @@ def _batch( # Map last retry results back to original indices so that # successfully-retried values don't overwrite unrelated positions. + # If the final retry attempt succeeded for every remaining input, those + # indices are already in results_map, so last_result_map is never + # consulted for them. last_result_map = dict(zip(last_remaining_indices, result, strict=True)) outputs: list[Output | Exception] = [] @@ -364,6 +367,9 @@ async def _abatch( # Map last retry results back to original indices so that # successfully-retried values don't overwrite unrelated positions. + # If the final retry attempt succeeded for every remaining input, those + # indices are already in results_map, so last_result_map is never + # consulted for them. last_result_map = dict(zip(last_remaining_indices, result, strict=True)) outputs: list[Output | Exception] = [] From 1b426b6ddcf31161f205e1e712268b8bf5ad6a9d Mon Sep 17 00:00:00 2001 From: giulio-leone Date: Fri, 13 Mar 2026 02:24:54 +0100 Subject: [PATCH 4/6] docs: clarify all-succeed-first-try edge case invariant Add detailed comment explaining why last_result_map is safe when every item succeeds on the first attempt, as suggested by @gambletan in review. --- libs/core/langchain_core/runnables/retry.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/libs/core/langchain_core/runnables/retry.py b/libs/core/langchain_core/runnables/retry.py index ea5df6db34de0..57bcb444d29e8 100644 --- a/libs/core/langchain_core/runnables/retry.py +++ b/libs/core/langchain_core/runnables/retry.py @@ -283,9 +283,12 @@ def _batch( # Map last retry results back to original indices so that # successfully-retried values don't overwrite unrelated positions. - # If the final retry attempt succeeded for every remaining input, those - # indices are already in results_map, so last_result_map is never - # consulted for them. + # Note: if all items succeed on the very first attempt, remaining_indices + # becomes empty on the second iteration → break before updating + # last_remaining_indices. In that case last_remaining_indices == + # range(len(inputs)) and result == the full first-attempt output, so the + # zip still pairs correctly. However, every idx will already be in + # results_map, so last_result_map is never actually consulted. last_result_map = dict(zip(last_remaining_indices, result, strict=True)) outputs: list[Output | Exception] = [] @@ -367,9 +370,12 @@ async def _abatch( # Map last retry results back to original indices so that # successfully-retried values don't overwrite unrelated positions. - # If the final retry attempt succeeded for every remaining input, those - # indices are already in results_map, so last_result_map is never - # consulted for them. + # Note: if all items succeed on the very first attempt, remaining_indices + # becomes empty on the second iteration → break before updating + # last_remaining_indices. In that case last_remaining_indices == + # range(len(inputs)) and result == the full first-attempt output, so the + # zip still pairs correctly. However, every idx will already be in + # results_map, so last_result_map is never actually consulted. last_result_map = dict(zip(last_remaining_indices, result, strict=True)) outputs: list[Output | Exception] = [] From 253949a5a6de4adf1515a822419290821f47f7bc Mon Sep 17 00:00:00 2001 From: Giulio Leone Date: Sat, 14 Mar 2026 20:10:46 +0100 Subject: [PATCH 5/6] fix: drop zip(strict=True) for Python 3.9 compatibility MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit zip(strict=True) was introduced in Python 3.10, but langchain-core supports Python >=3.9. Remove the strict parameter — the invariant (equal-length lists) is guaranteed by the retry loop logic. --- libs/core/langchain_core/runnables/retry.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/core/langchain_core/runnables/retry.py b/libs/core/langchain_core/runnables/retry.py index 57bcb444d29e8..a41ba6c04f602 100644 --- a/libs/core/langchain_core/runnables/retry.py +++ b/libs/core/langchain_core/runnables/retry.py @@ -289,7 +289,7 @@ def _batch( # range(len(inputs)) and result == the full first-attempt output, so the # zip still pairs correctly. However, every idx will already be in # results_map, so last_result_map is never actually consulted. - last_result_map = dict(zip(last_remaining_indices, result, strict=True)) + last_result_map = dict(zip(last_remaining_indices, result)) outputs: list[Output | Exception] = [] for idx in range(len(inputs)): @@ -376,7 +376,7 @@ async def _abatch( # range(len(inputs)) and result == the full first-attempt output, so the # zip still pairs correctly. However, every idx will already be in # results_map, so last_result_map is never actually consulted. - last_result_map = dict(zip(last_remaining_indices, result, strict=True)) + last_result_map = dict(zip(last_remaining_indices, result)) outputs: list[Output | Exception] = [] for idx in range(len(inputs)): From 616e9d6c97d6ad2f23e10765b59b9b822ddc8fc4 Mon Sep 17 00:00:00 2001 From: giulio-leone Date: Sat, 21 Mar 2026 11:48:17 +0100 Subject: [PATCH 6/6] fix(core): add explicit zip strict flags Rebasing onto current master surfaced two ruff B905 failures from the new retry result mapping logic. Add explicit strict=False to keep the small fix lint-clean without changing runtime behavior. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- libs/core/langchain_core/runnables/retry.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/core/langchain_core/runnables/retry.py b/libs/core/langchain_core/runnables/retry.py index a41ba6c04f602..f50f32fec3dc3 100644 --- a/libs/core/langchain_core/runnables/retry.py +++ b/libs/core/langchain_core/runnables/retry.py @@ -289,7 +289,7 @@ def _batch( # range(len(inputs)) and result == the full first-attempt output, so the # zip still pairs correctly. However, every idx will already be in # results_map, so last_result_map is never actually consulted. - last_result_map = dict(zip(last_remaining_indices, result)) + last_result_map = dict(zip(last_remaining_indices, result, strict=False)) outputs: list[Output | Exception] = [] for idx in range(len(inputs)): @@ -376,7 +376,7 @@ async def _abatch( # range(len(inputs)) and result == the full first-attempt output, so the # zip still pairs correctly. However, every idx will already be in # results_map, so last_result_map is never actually consulted. - last_result_map = dict(zip(last_remaining_indices, result)) + last_result_map = dict(zip(last_remaining_indices, result, strict=False)) outputs: list[Output | Exception] = [] for idx in range(len(inputs)):