|
1 | | -from typing import TYPE_CHECKING |
| 1 | +from typing import Iterable |
2 | 2 | from unittest.mock import patch |
3 | 3 |
|
4 | 4 | import pytest |
5 | 5 | from langchain_core.language_models import ModelProfile |
6 | 6 | from langchain_core.language_models.chat_models import BaseChatModel |
7 | | -from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, RemoveMessage, ToolMessage |
| 7 | +from langchain_core.messages import ( |
| 8 | + AIMessage, |
| 9 | + AnyMessage, |
| 10 | + HumanMessage, |
| 11 | + MessageLikeRepresentation, |
| 12 | + RemoveMessage, |
| 13 | + ToolMessage, |
| 14 | +) |
8 | 15 | from langchain_core.outputs import ChatGeneration, ChatResult |
9 | 16 | from langgraph.graph.message import REMOVE_ALL_MESSAGES |
10 | 17 |
|
@@ -316,7 +323,7 @@ def test_summarization_middleware_profile_inference_triggers_summary() -> None: |
316 | 323 |
|
317 | 324 |
|
318 | 325 | def test_summarization_middleware_token_retention_pct_respects_tool_pairs() -> None: |
319 | | - """Ensure token retention keeps pairs together even if exceeding target tokens.""" |
| 326 | + """Ensure token retention never splits tool pairs while enforcing hard caps.""" |
320 | 327 |
|
321 | 328 | def token_counter(messages: list[AnyMessage]) -> int: |
322 | 329 | return sum(len(getattr(message, "content", "")) for message in messages) |
@@ -344,13 +351,13 @@ def token_counter(messages: list[AnyMessage]) -> int: |
344 | 351 | assert result is not None |
345 | 352 |
|
346 | 353 | preserved_messages = result["messages"][2:] |
347 | | - assert preserved_messages == messages[1:] |
| 354 | + assert preserved_messages == messages[3:] |
| 355 | + assert token_counter(preserved_messages) <= 500 |
| 356 | + assert not any(isinstance(msg, (AIMessage, ToolMessage)) for msg in preserved_messages) |
348 | 357 |
|
349 | | - target_token_count = int(1000 * 0.5) |
350 | 358 | preserved_tokens = middleware.token_counter(preserved_messages) |
351 | | - |
352 | | - # Tool pair retention can exceed the target token count but should keep the pair intact. |
353 | | - assert preserved_tokens > target_token_count |
| 359 | + target_token_count = int(1000 * 0.5) |
| 360 | + assert preserved_tokens <= target_token_count |
354 | 361 |
|
355 | 362 |
|
356 | 363 | def test_summarization_middleware_missing_profile() -> None: |
@@ -783,6 +790,45 @@ def test_summarization_middleware_tool_call_in_search_range() -> None: |
783 | 790 | assert middleware._is_safe_cutoff_point(messages, 1) |
784 | 791 |
|
785 | 792 |
|
| 793 | +def test_summarization_middleware_results_under_window() -> None: |
| 794 | + """Ensure the token count of the summarized result stays within the model's max input tokens.""" |
| 795 | + |
| 796 | + def _token_counter(messages: Iterable[MessageLikeRepresentation]) -> int: |
| 797 | + count = 0 |
| 798 | + for message in messages: |
| 799 | + if isinstance(message, ToolMessage): |
| 800 | + count = count + 500 |
| 801 | + else: |
| 802 | + count = count + 100 |
| 803 | + return count |
| 804 | + |
| 805 | + state = { |
| 806 | + "messages": [ |
| 807 | + HumanMessage(content="Message 1"), |
| 808 | + AIMessage( |
| 809 | + content="Message 2", |
| 810 | + tool_calls=[ |
| 811 | + {"name": "test", "args": {}, "id": "call-1"}, |
| 812 | + {"name": "test", "args": {}, "id": "call-2"}, |
| 813 | + ], |
| 814 | + ), |
| 815 | + ToolMessage(content="Result 2-1", tool_call_id="call-1"), |
| 816 | + ToolMessage(content="Result 2-2", tool_call_id="call-2"), |
| 817 | + ] |
| 818 | + } |
| 819 | + |
| 820 | + middleware = SummarizationMiddleware( |
| 821 | + model=ProfileChatModel(), |
| 822 | + trigger=("fraction", 0.80), |
| 823 | + keep=("fraction", 0.5), |
| 824 | + token_counter=_token_counter, |
| 825 | + ) |
| 826 | + result = middleware.before_model(state, None) |
| 827 | + assert result is not None |
| 828 | + count_after_summarization = _token_counter(result["messages"]) |
| 829 | + assert count_after_summarization <= 1000 # max_input_tokens of ProfileChatModel |
| 830 | + |
| 831 | + |
786 | 832 | def test_summarization_middleware_zero_and_negative_target_tokens() -> None: |
787 | 833 | """Test handling of edge cases with target token calculations.""" |
788 | 834 | # Test with very small fraction that rounds to zero |
|
0 commit comments