Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion libs/core/langchain_core/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from pydantic.v1 import BaseModel as BaseModelV1
from pydantic.v1 import ValidationError as ValidationErrorV1
from pydantic.v1 import validate_arguments as validate_arguments_v1
from pydantic_core import SchemaError as PydanticSchemaError
from typing_extensions import override

from langchain_core.callbacks import (
Expand Down Expand Up @@ -326,7 +327,27 @@ def create_schema_from_function(
# This code should be re-written to simply construct a Pydantic model
# using inspect.signature and create_model.
warnings.simplefilter("ignore", category=PydanticDeprecationWarning)
validated = validate_arguments(func, config=_SchemaConfig) # type: ignore[operator]
try:
validated = validate_arguments(func, config=_SchemaConfig) # type: ignore[operator]
except PydanticSchemaError:
# Injected arg types (e.g. Protocol unions like StateLike) may
# not be valid for Pydantic schema creation. Replace them with
# Any and retry, since they are filtered from the final schema.
_saved: dict[str, Any] = {}
for param_name, param in sig.parameters.items():
if _is_injected_arg_type(param.annotation):
_saved[param_name] = func.__annotations__.get(param_name)
func.__annotations__[param_name] = Annotated[
Any, InjectedToolArg()
]
try:
validated = validate_arguments(func, config=_SchemaConfig) # type: ignore[operator]
finally:
for param_name, original in _saved.items():
if original is not None:
func.__annotations__[param_name] = original
else:
func.__annotations__.pop(param_name, None)

# Let's ignore `self` and `cls` arguments for class and instance methods
# If qualified name has a ".", then it likely belongs in a class namespace
Expand Down
97 changes: 97 additions & 0 deletions libs/core/tests/unit_tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
from typing import (
Annotated,
Any,
ClassVar,
Generic,
Literal,
Protocol,
TypeVar,
cast,
get_type_hints,
Expand Down Expand Up @@ -3636,3 +3638,98 @@ def config_tool(name: str, *, enabled: bool, count: int, prefix: str) -> str:
# Invoke with only required argument - falsy defaults should be applied
result = config_tool.invoke({"name": "test"})
assert result == "name=test, enabled=False, count=0, prefix=''"


def test_injected_arg_with_protocol_union_type() -> None:
    """Test that an injected arg typed as a Protocol union doesn't break schema.

    Regression test for https://github.com/langchain-ai/langchain/issues/32067.
    When a tool parameter is annotated with InjectedToolArg and typed as a Union
    of Protocol types (like StateLike), Pydantic cannot build a validator for it.

    The test checks three things for such a tool:

    * the tool can be created and its input schema still lists the injected
      parameter,
    * the tool-call schema omits the injected parameter,
    * the tool can be invoked with a value supplied for the injected parameter.
    """

    class TypedDictLike(Protocol):
        __required_keys__: ClassVar[frozenset[str]]
        __optional_keys__: ClassVar[frozenset[str]]

    class DataclassLike(Protocol):
        __dataclass_fields__: ClassVar[dict[str, Any]]

    # A union of Protocol types — not valid for Pydantic isinstance checks
    state_like_type = TypedDictLike | DataclassLike | BaseModel

    @tool
    def tool_with_protocol_state(
        query: str,
        state: Annotated[state_like_type, InjectedToolArg],
    ) -> str:
        """Search with injected state.

        Args:
            query: The search query.
            state: The injected state.
        """
        return f"query={query}"

    # get_input_schema includes injected args (used for runtime validation)
    schema = _schema(tool_with_protocol_state.get_input_schema())
    assert "query" in schema["properties"]
    assert "state" in schema["properties"]

    # tool_call_schema excludes injected args (sent to LLM)
    tc_schema = _schema(tool_with_protocol_state.tool_call_schema)
    assert "state" not in tc_schema.get("properties", {})

    # Tool should be invokable with injected state value
    result = tool_with_protocol_state.invoke(
        {"query": "hello", "state": {"messages": []}}
    )
    assert result == "query=hello"


def test_injected_arg_filtered_from_schema_at_creation() -> None:
    """Test that every kind of injected arg is hidden from the tool-call schema.

    The tool under test mixes a regular parameter, a parameter annotated with
    InjectedToolArg, and a parameter whose type subclasses
    _DirectlyInjectedToolArg. All three must appear in the input schema, only
    the regular parameter may appear in tool_call_schema, and the tool must
    still be invokable when values for the injected parameters are supplied.
    """

    @dataclass
    class CustomRuntime(_DirectlyInjectedToolArg):
        data: dict[str, Any]

    @tool
    def multi_injected_tool(
        query: str,
        injected_ann: Annotated[Any, InjectedToolArg],
        injected_direct: CustomRuntime,
    ) -> str:
        """Tool with multiple injected args.

        Args:
            query: The search query.
            injected_ann: Annotated injected arg.
            injected_direct: Directly injected arg.
        """
        return query

    # get_input_schema includes injected args (used for runtime validation)
    schema = _schema(multi_injected_tool.get_input_schema())
    assert "query" in schema["properties"]
    assert "injected_ann" in schema["properties"]
    assert "injected_direct" in schema["properties"]

    # tool_call_schema excludes injected args (sent to LLM); iterating the
    # mapping yields its keys, so the property names are compared in order.
    tc_schema = _schema(multi_injected_tool.tool_call_schema)
    assert list(tc_schema["properties"]) == ["query"]

    # Invocation should work
    result = multi_injected_tool.invoke(
        {
            "query": "test",
            "injected_ann": "some_value",
            "injected_direct": CustomRuntime(data={"key": "val"}),
        }
    )
    assert result == "test"