Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 31 additions & 6 deletions libs/core/langchain_core/messages/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
Protocol,
TypeVar,
cast,
get_type_hints,
overload,
)
from xml.sax.saxutils import escape, quoteattr
Expand Down Expand Up @@ -1093,6 +1094,7 @@ def trim_messages(
start_on: str | type[BaseMessage] | Sequence[str | type[BaseMessage]] | None = None,
include_system: bool = False,
text_splitter: Callable[[str], list[str]] | TextSplitter | None = None,
token_counter_is_per_message: bool = False,
) -> list[BaseMessage]:
r"""Trim messages to be below a token count.

Expand Down Expand Up @@ -1183,6 +1185,11 @@ def trim_messages(
splitter assumes that separators are kept, so that split contents can be
directly concatenated to recreate the original text. Defaults to splitting
on newlines.
token_counter_is_per_message: If `True`, `token_counter` is treated as a
per-message callable `(msg: BaseMessage) -> int`. Auto-detection only works
for annotated callables whose first positional parameter is typed as
`BaseMessage` or a subclass. Use this flag for lambdas or unannotated
callables, which cannot be reliably auto-detected.

Returns:
List of trimmed `BaseMessage`.
Expand Down Expand Up @@ -1417,12 +1424,30 @@ def dummy_token_counter(messages: list[BaseMessage]) -> int:
if hasattr(actual_token_counter, "get_num_tokens_from_messages"):
list_token_counter = actual_token_counter.get_num_tokens_from_messages
elif callable(actual_token_counter):
if (
next(
iter(inspect.signature(actual_token_counter).parameters.values())
).annotation
is BaseMessage
):
try:
hints = get_type_hints(actual_token_counter)
except (NameError, AttributeError, TypeError):
# Fall back to raw annotations if type-hint resolution fails.
hints = {}
params = list(inspect.signature(actual_token_counter).parameters.values())
_positional_kinds = {
inspect.Parameter.POSITIONAL_OR_KEYWORD,
inspect.Parameter.POSITIONAL_ONLY,
}
first_positional = next(
(p for p in params if p.kind in _positional_kinds), None
)
first_annotation = (
hints.get(first_positional.name, first_positional.annotation)
if first_positional is not None
else inspect.Parameter.empty
)
is_per_message = token_counter_is_per_message or (
first_annotation is not inspect.Parameter.empty
and isinstance(first_annotation, type)
and issubclass(first_annotation, BaseMessage)
)
if is_per_message:

def list_token_counter(messages: Sequence[BaseMessage]) -> int:
return sum(actual_token_counter(msg) for msg in messages) # type: ignore[arg-type, misc]
Expand Down
67 changes: 67 additions & 0 deletions libs/core/tests/unit_tests/messages/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,6 +753,73 @@ def test_trim_messages_token_counter_shortcut_with_options() -> None:
assert messages == messages_copy


def test_trim_messages_per_message_base_message_annotation() -> None:
def counter(_msg: BaseMessage) -> int:
return 10

result = trim_messages(_MESSAGES_TO_TRIM, max_tokens=30, token_counter=counter)
assert len(result) == 3


def test_trim_messages_per_message_subclass_annotation() -> None:
def counter(_msg: HumanMessage) -> int:
return 10

result = trim_messages(_MESSAGES_TO_TRIM, max_tokens=30, token_counter=counter) # type: ignore[arg-type]
assert len(result) == 3


def test_trim_messages_per_message_string_annotation() -> None:
def counter(_msg: "BaseMessage") -> int:
return 10

result = trim_messages(_MESSAGES_TO_TRIM, max_tokens=30, token_counter=counter)
assert len(result) == 3


def test_trim_messages_per_message_lambda_with_flag() -> None:
result = trim_messages(
_MESSAGES_TO_TRIM,
max_tokens=30,
token_counter=lambda _msg: 10,
token_counter_is_per_message=True,
)
assert len(result) == 3


def test_trim_messages_per_message_unannotated_with_flag() -> None:
def counter(_msg): # type: ignore[no-untyped-def] # noqa: ANN001, ANN202
return 10

result = trim_messages(
_MESSAGES_TO_TRIM,
max_tokens=30,
token_counter=counter,
token_counter_is_per_message=True,
)
assert len(result) == 3


def test_trim_messages_list_counter_still_works() -> None:
def counter(messages: list[BaseMessage]) -> int:
return len(messages) * 10

result = trim_messages(_MESSAGES_TO_TRIM, max_tokens=30, token_counter=counter)
assert len(result) == 3


def test_trim_messages_get_num_tokens_from_messages_takes_precedence() -> None:
class CounterObj:
def __call__(self, _msg: BaseMessage) -> int:
return 10

def get_num_tokens_from_messages(self, messages: list[BaseMessage]) -> int:
return len(messages) * 10

result = trim_messages(_MESSAGES_TO_TRIM, max_tokens=30, token_counter=CounterObj())
assert len(result) == 3


class FakeTokenCountingModel(FakeChatModel):
@override
def get_num_tokens_from_messages(
Expand Down
Loading