
Commit b2db842

treat keep threshold as a hard cap

1 parent 9c21f83

File tree: 2 files changed, +66 -12 lines changed

libs/langchain_v1/langchain/agents/middleware/summarization.py
Lines changed: 12 additions & 4 deletions

@@ -3,6 +3,7 @@
 import uuid
 import warnings
 from collections.abc import Callable, Iterable, Mapping
+from functools import cache
 from typing import Any, Literal, cast
 
 from langchain_core.messages import (
@@ -259,6 +260,10 @@ def _find_token_based_cutoff(self, messages: list[AnyMessage]) -> int | None:
         if not messages:
             return 0
 
+        @cache
+        def suffix_token_count(start_index: int) -> int:
+            return self.token_counter(messages[start_index:])
+
         kind, value = self.keep
         if kind == "fraction":
             max_input_tokens = self._get_profile_limits()
@@ -273,7 +278,7 @@ def _find_token_based_cutoff(self, messages: list[AnyMessage]) -> int | None:
         if target_token_count <= 0:
             target_token_count = 1
 
-        if self.token_counter(messages) <= target_token_count:
+        if suffix_token_count(0) <= target_token_count:
             return 0
 
         # Use binary search to identify the earliest message index that keeps the
@@ -286,7 +291,7 @@ def _find_token_based_cutoff(self, messages: list[AnyMessage]) -> int | None:
                 break
 
             mid = (left + right) // 2
-            if self.token_counter(messages[mid:]) <= target_token_count:
+            if suffix_token_count(mid) <= target_token_count:
                 cutoff_candidate = mid
                 right = mid
             else:
@@ -300,8 +305,11 @@ def _find_token_based_cutoff(self, messages: list[AnyMessage]) -> int | None:
                 return 0
             cutoff_candidate = len(messages) - 1
 
-        for i in range(cutoff_candidate, -1, -1):
-            if self._is_safe_cutoff_point(messages, i):
+        for i in range(cutoff_candidate, len(messages) + 1):
+            if (
+                self._is_safe_cutoff_point(messages, i)
+                and suffix_token_count(i) <= target_token_count
+            ):
                 return i
 
         return 0
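
For readers skimming the diff: the change caches per-suffix token counts, binary-searches for the earliest index whose suffix fits under the target, then scans forward (rather than backward, as before this commit) so the returned cutoff is both a safe boundary and under the cap. A minimal self-contained sketch of that strategy follows; find_cutoff, token_counts, and is_safe are illustrative stand-ins, not the middleware's actual API.

# Sketch only: stand-alone version of the search strategy in the hunks above.
from collections.abc import Callable
from functools import cache


def find_cutoff(
    token_counts: list[int],
    target_token_count: int,
    is_safe: Callable[[int], bool],
) -> int:
    """Return the smallest index whose suffix fits under the hard cap."""

    @cache
    def suffix_token_count(start_index: int) -> int:
        # Each suffix is counted at most once, mirroring the @cache
        # helper added in the diff.
        return sum(token_counts[start_index:])

    if suffix_token_count(0) <= target_token_count:
        return 0  # everything already fits; keep the full history

    # Binary search for the earliest index whose suffix fits the cap.
    left, right = 0, len(token_counts)
    cutoff_candidate = len(token_counts)
    while left < right:
        mid = (left + right) // 2
        if suffix_token_count(mid) <= target_token_count:
            cutoff_candidate = mid
            right = mid
        else:
            left = mid + 1

    # Scan forward from the candidate: the first index that is both a safe
    # boundary (e.g. not between an AI tool call and its ToolMessage) and
    # still under the cap wins. Scanning backward could return a safe index
    # whose suffix exceeds the cap, which is exactly the bug being fixed.
    for i in range(cutoff_candidate, len(token_counts) + 1):
        if is_safe(i) and suffix_token_count(i) <= target_token_count:
            return i
    return 0


# Example: four messages costing [100, 100, 500, 500] tokens, cap 500,
# where cutting at index 2 or 3 would split a tool-call pair.
print(find_cutoff([100, 100, 500, 500], 500, lambda i: i not in (2, 3)))  # 4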

libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_summarization.py
Lines changed: 54 additions & 8 deletions

@@ -1,10 +1,17 @@
-from typing import TYPE_CHECKING
+from typing import Iterable
 from unittest.mock import patch
 
 import pytest
 from langchain_core.language_models import ModelProfile
 from langchain_core.language_models.chat_models import BaseChatModel
-from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, RemoveMessage, ToolMessage
+from langchain_core.messages import (
+    AIMessage,
+    AnyMessage,
+    HumanMessage,
+    MessageLikeRepresentation,
+    RemoveMessage,
+    ToolMessage,
+)
 from langchain_core.outputs import ChatGeneration, ChatResult
 from langgraph.graph.message import REMOVE_ALL_MESSAGES
 
@@ -316,7 +323,7 @@ def test_summarization_middleware_profile_inference_triggers_summary() -> None:
 
 
 def test_summarization_middleware_token_retention_pct_respects_tool_pairs() -> None:
-    """Ensure token retention keeps pairs together even if exceeding target tokens."""
+    """Ensure token retention never splits tool pairs while enforcing hard caps."""
 
     def token_counter(messages: list[AnyMessage]) -> int:
         return sum(len(getattr(message, "content", "")) for message in messages)
@@ -344,13 +351,13 @@ def token_counter(messages: list[AnyMessage]) -> int:
     assert result is not None
 
     preserved_messages = result["messages"][2:]
-    assert preserved_messages == messages[1:]
+    assert preserved_messages == messages[3:]
+    assert token_counter(preserved_messages) <= 500
+    assert not any(isinstance(msg, (AIMessage, ToolMessage)) for msg in preserved_messages)
 
-    target_token_count = int(1000 * 0.5)
     preserved_tokens = middleware.token_counter(preserved_messages)
-
-    # Tool pair retention can exceed the target token count but should keep the pair intact.
-    assert preserved_tokens > target_token_count
+    target_token_count = int(1000 * 0.5)
+    assert preserved_tokens <= target_token_count
 
 
 def test_summarization_middleware_missing_profile() -> None:
@@ -783,6 +790,45 @@ def test_summarization_middleware_tool_call_in_search_range() -> None:
     assert middleware._is_safe_cutoff_point(messages, 1)
 
 
+def test_summarization_middleware_results_under_window() -> None:
+    """Ensure summarization keeps the retained messages under the model's token limit."""
+
+    def _token_counter(messages: Iterable[MessageLikeRepresentation]) -> int:
+        count = 0
+        for message in messages:
+            if isinstance(message, ToolMessage):
+                count = count + 500
+            else:
+                count = count + 100
+        return count
+
+    state = {
+        "messages": [
+            HumanMessage(content="Message 1"),
+            AIMessage(
+                content="Message 2",
+                tool_calls=[
+                    {"name": "test", "args": {}, "id": "call-1"},
+                    {"name": "test", "args": {}, "id": "call-2"},
+                ],
+            ),
+            ToolMessage(content="Result 2-1", tool_call_id="call-1"),
+            ToolMessage(content="Result 2-2", tool_call_id="call-2"),
+        ]
+    }
+
+    middleware = SummarizationMiddleware(
+        model=ProfileChatModel(),
+        trigger=("fraction", 0.80),
+        keep=("fraction", 0.5),
+        token_counter=_token_counter,
+    )
+    result = middleware.before_model(state, None)
+    assert result is not None
+    count_after_summarization = _token_counter(result["messages"])
+    assert count_after_summarization <= 1000  # max_input_tokens of ProfileChatModel
+
+
 def test_summarization_middleware_zero_and_negative_target_tokens() -> None:
     """Test handling of edge cases with target token calculations."""
     # Test with very small fraction that rounds to zero
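
Downstream, the observable effect is that keep now bounds the retained history strictly rather than as a soft target. A hedged usage sketch, not part of the diff: the trigger and keep arguments mirror the tests above, while init_chat_model and the model string are assumptions standing in for any chat model whose profile reports max_input_tokens.

# Usage sketch only -- the middleware import path follows the file path in
# this diff; init_chat_model and the model string are assumptions.
from langchain.agents.middleware.summarization import SummarizationMiddleware
from langchain.chat_models import init_chat_model

middleware = SummarizationMiddleware(
    model=init_chat_model("openai:gpt-4o-mini"),  # any chat model with a profile
    trigger=("fraction", 0.80),  # start summarizing at 80% of max_input_tokens
    keep=("fraction", 0.5),      # retained suffix now hard-capped at 50%
)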
