diff --git a/libs/core/langchain_core/messages/utils.py b/libs/core/langchain_core/messages/utils.py index b60d57d844bd4..f97e53ba7af59 100644 --- a/libs/core/langchain_core/messages/utils.py +++ b/libs/core/langchain_core/messages/utils.py @@ -1076,6 +1076,35 @@ def merge_message_runs( return merged +def _is_per_message_counter(fn: Callable) -> bool: + """Determine if a callable is a per-message token counter. + + Returns True if the callable appears to accept a single BaseMessage + rather than a list of BaseMessages. Handles lambdas (no annotation), + postponed annotations (``from __future__ import annotations``), + and subclass annotations (e.g. ``HumanMessage``). + """ + try: + sig = inspect.signature(fn) + except (ValueError, TypeError): + return False + params = list(sig.parameters.values()) + if not params: + return False + ann = params[0].annotation + # No annotation (lambda, bare ``def``) — assume per-message counter + if ann is inspect.Parameter.empty: + return True + # Postponed / stringified annotation + if isinstance(ann, str): + return "message" in ann.lower() + # Direct type annotation — check if it's BaseMessage or a subclass + if isinstance(ann, type) and issubclass(ann, BaseMessage): + return True + return False + + + # TODO: Update so validation errors (for token_counter, for example) are raised on # init not at runtime. @_runnable_support @@ -1417,12 +1446,7 @@ def dummy_token_counter(messages: list[BaseMessage]) -> int: if hasattr(actual_token_counter, "get_num_tokens_from_messages"): list_token_counter = actual_token_counter.get_num_tokens_from_messages elif callable(actual_token_counter): - if ( - next( - iter(inspect.signature(actual_token_counter).parameters.values()) - ).annotation - is BaseMessage - ): + if _is_per_message_counter(actual_token_counter): def list_token_counter(messages: Sequence[BaseMessage]) -> int: return sum(actual_token_counter(msg) for msg in messages) # type: ignore[arg-type, misc]