diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py
index 7026a2e8fb162..a49d39e2217df 100644
--- a/libs/core/langchain_core/tools/base.py
+++ b/libs/core/langchain_core/tools/base.py
@@ -37,6 +37,7 @@
 from pydantic.v1 import BaseModel as BaseModelV1
 from pydantic.v1 import ValidationError as ValidationErrorV1
 from pydantic.v1 import validate_arguments as validate_arguments_v1
+from pydantic_core import SchemaError as PydanticSchemaError
 from typing_extensions import override
 
 from langchain_core.callbacks import (
@@ -326,7 +327,29 @@ def create_schema_from_function(
             # This code should be re-written to simply construct a Pydantic model
             # using inspect.signature and create_model.
             warnings.simplefilter("ignore", category=PydanticDeprecationWarning)
-            validated = validate_arguments(func, config=_SchemaConfig)  # type: ignore[operator]
+            try:
+                validated = validate_arguments(func, config=_SchemaConfig)  # type: ignore[operator]
+            except PydanticSchemaError:
+                # Injected arg types (e.g. Protocol unions like StateLike) may
+                # not be valid for Pydantic schema creation. Retry with them
+                # typed as Any; they are omitted from the tool-call schema.
+                _saved: dict[str, Any] = {}
+                try:
+                    # Mutate inside the try so `finally` restores the caller's
+                    # annotations even if patching fails part-way through.
+                    for param_name, param in sig.parameters.items():
+                        if _is_injected_arg_type(param.annotation):
+                            _saved[param_name] = func.__annotations__.get(param_name)
+                            func.__annotations__[param_name] = Annotated[
+                                Any, InjectedToolArg()
+                            ]
+                    validated = validate_arguments(func, config=_SchemaConfig)  # type: ignore[operator]
+                finally:
+                    for param_name, original in _saved.items():
+                        if original is not None:
+                            func.__annotations__[param_name] = original
+                        else:
+                            func.__annotations__.pop(param_name, None)
 
     # Let's ignore `self` and `cls` arguments for class and instance methods
     # If qualified name has a ".", then it likely belongs in a class namespace
diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py
index 1067ecfa69097..68372a4ba8a02 100644
--- a/libs/core/tests/unit_tests/test_tools.py
+++ b/libs/core/tests/unit_tests/test_tools.py
@@ -14,8 +14,10 @@
 from typing import (
     Annotated,
     Any,
+    ClassVar,
     Generic,
     Literal,
+    Protocol,
     TypeVar,
     cast,
     get_type_hints,
@@ -3636,3 +3638,98 @@ def config_tool(name: str, *, enabled: bool, count: int, prefix: str) -> str:
     # Invoke with only required argument - falsy defaults should be applied
     result = config_tool.invoke({"name": "test"})
     assert result == "name=test, enabled=False, count=0, prefix=''"
+
+
+def test_injected_arg_with_protocol_union_type() -> None:
+    """Test that InjectedToolArg with a Protocol/Union type doesn't break schema.
+
+    Regression test for https://github.com/langchain-ai/langchain/issues/32067.
+    When a tool parameter is annotated with InjectedToolArg and typed as a Union
+    of Protocol types (like StateLike), Pydantic cannot build a validator for it.
+    The tool must still build, keeping it out of the tool-call schema only.
+    """
+
+    class TypedDictLike(Protocol):
+        __required_keys__: ClassVar[frozenset[str]]
+        __optional_keys__: ClassVar[frozenset[str]]
+
+    class DataclassLike(Protocol):
+        __dataclass_fields__: ClassVar[dict[str, Any]]
+
+    # A union of Protocol types — not valid for Pydantic isinstance checks
+    state_like_type = TypedDictLike | DataclassLike | BaseModel
+
+    @tool
+    def tool_with_protocol_state(
+        query: str,
+        state: Annotated[state_like_type, InjectedToolArg],
+    ) -> str:
+        """Search with injected state.
+
+        Args:
+            query: The search query.
+            state: The injected state.
+        """
+        return f"query={query}"
+
+    # get_input_schema includes injected args (used for runtime validation)
+    schema = _schema(tool_with_protocol_state.get_input_schema())
+    assert "query" in schema["properties"]
+    assert "state" in schema["properties"]
+
+    # tool_call_schema excludes injected args (sent to LLM)
+    tc_schema = _schema(tool_with_protocol_state.tool_call_schema)
+    assert "state" not in tc_schema.get("properties", {})
+
+    # Tool should be invokable with injected state value
+    result = tool_with_protocol_state.invoke(
+        {"query": "hello", "state": {"messages": []}}
+    )
+    assert result == "query=hello"
+
+
+def test_injected_arg_filtered_from_schema_at_creation() -> None:
+    """Test that injected args are hidden from the tool-call schema only.
+
+    Covers an InjectedToolArg-annotated param and a _DirectlyInjectedToolArg
+    subclass param: both stay in the input schema and are accepted by invoke.
+    """
+
+    @dataclass
+    class CustomRuntime(_DirectlyInjectedToolArg):
+        data: dict[str, Any]
+
+    @tool
+    def multi_injected_tool(
+        query: str,
+        injected_ann: Annotated[Any, InjectedToolArg],
+        injected_direct: CustomRuntime,
+    ) -> str:
+        """Tool with multiple injected args.
+
+        Args:
+            query: The search query.
+            injected_ann: Annotated injected arg.
+            injected_direct: Directly injected arg.
+        """
+        return query
+
+    # get_input_schema includes injected args (used for runtime validation)
+    schema = _schema(multi_injected_tool.get_input_schema())
+    assert "query" in schema["properties"]
+    assert "injected_ann" in schema["properties"]
+    assert "injected_direct" in schema["properties"]
+
+    # tool_call_schema excludes injected args (sent to LLM)
+    tc_schema = _schema(multi_injected_tool.tool_call_schema)
+    assert list(tc_schema["properties"].keys()) == ["query"]
+
+    # Invocation should work
+    result = multi_injected_tool.invoke(
+        {
+            "query": "test",
+            "injected_ann": "some_value",
+            "injected_direct": CustomRuntime(data={"key": "val"}),
+        }
+    )
+    assert result == "test"