diff --git a/libs/core/langchain_core/utils/function_calling.py b/libs/core/langchain_core/utils/function_calling.py index ea3d0f26231b0..96007325f1492 100644 --- a/libs/core/langchain_core/utils/function_calling.py +++ b/libs/core/langchain_core/utils/function_calling.py @@ -13,6 +13,8 @@ Annotated, Any, Literal, + NotRequired, + Required, Union, cast, get_args, @@ -280,7 +282,23 @@ def _convert_any_typed_dicts_to_pydantic( docstring, list(annotations_) ) fields: dict = {} + optional_keys = getattr(typed_dict, "__optional_keys__", frozenset()) for arg, arg_type in annotations_.items(): + # Unwrap NotRequired[X] / Required[X] wrappers that appear when + # get_type_hints(include_extras=True) is used on a TypedDict. + # NotRequired marks a field as not required (has a default of None), + # Required marks a field as required (overrides total=False). + is_not_required = get_origin(arg_type) in { + NotRequired, + typing_extensions.NotRequired, + } + is_required_wrapper = get_origin(arg_type) in { + Required, + typing_extensions.Required, + } + if is_not_required or is_required_wrapper: + arg_type = get_args(arg_type)[0] + if get_origin(arg_type) in {Annotated, typing_extensions.Annotated}: annotated_args = get_args(arg_type) new_arg_type = _convert_any_typed_dicts_to_pydantic( @@ -300,12 +318,22 @@ def _convert_any_typed_dicts_to_pydantic( raise ValueError(msg) if arg_desc := arg_descriptions.get(arg): field_kwargs["description"] = arg_desc + # If the field was originally NotRequired and no explicit default + # was provided via Annotated metadata, mark it as optional. + if is_not_required and "default" not in field_kwargs: + field_kwargs["default"] = None fields[arg] = (new_arg_type, Field_v1(**field_kwargs)) else: new_arg_type = _convert_any_typed_dicts_to_pydantic( arg_type, depth=depth + 1, visited=visited ) - field_kwargs = {"default": ...} + # NotRequired fields (or fields on total=False TypedDicts) have no + # mandatory value — use None as the default so Pydantic treats them + # as optional in the generated JSON schema. + if is_not_required or arg in optional_keys: + field_kwargs: dict = {"default": None} + else: + field_kwargs = {"default": ...} if arg_desc := arg_descriptions.get(arg): field_kwargs["description"] = arg_desc fields[arg] = (new_arg_type, Field_v1(**field_kwargs)) @@ -321,6 +349,17 @@ def _convert_any_typed_dicts_to_pydantic( _convert_any_typed_dicts_to_pydantic(arg, depth=depth + 1, visited=visited) for arg in type_args ) + # NotRequired and Required only accept a single type argument; passing a + # tuple causes "NotRequired accepts only a single type. Got (,)". + # Unwrap the single-element tuple in those cases. + _single_arg_origins = { + NotRequired, + typing_extensions.NotRequired, + Required, + typing_extensions.Required, + } + if origin in _single_arg_origins and len(type_args) == 1: + return cast("type", subscriptable_origin[type_args[0]]) # type: ignore[index] return cast("type", subscriptable_origin[type_args]) # type: ignore[index] return type_ diff --git a/libs/core/tests/unit_tests/utils/test_function_calling.py b/libs/core/tests/unit_tests/utils/test_function_calling.py index e56ae47e221d5..c5e59e701b533 100644 --- a/libs/core/tests/unit_tests/utils/test_function_calling.py +++ b/libs/core/tests/unit_tests/utils/test_function_calling.py @@ -4,6 +4,8 @@ from typing import ( Any, Literal, + NotRequired, + Required, TypeAlias, ) from typing import TypedDict as TypingTypedDict @@ -1242,3 +1244,80 @@ def test_convert_to_openai_function_json_schema_missing_title_includes_schema() } with pytest.raises(ValueError, match="my_field"): convert_to_openai_function(schema_without_title) + + +@pytest.mark.parametrize("typed_dict", [ExtensionsTypedDict, TypingTypedDict]) +def test_convert_typed_dict_notrequired_fields(typed_dict: type) -> None: + """NotRequired fields in TypedDict must not raise TypeError (issue #34085). + + Required fields must appear in the ``required`` list; NotRequired fields + must be absent from ``required`` and still appear as schema properties. + """ + + class Tool(typed_dict): # type: ignore[misc] + """A tool with optional fields.""" + + required_field: str + optional_field: NotRequired[str] + optional_int: NotRequired[int] + + result = _convert_typed_dict_to_openai_function(Tool) + + assert result["name"] == "Tool" + params = result["parameters"] + props = params["properties"] + + # Both fields must appear in properties + assert "required_field" in props + assert "optional_field" in props + assert "optional_int" in props + + assert props["required_field"] == {"type": "string"} + assert props["optional_field"] == {"type": "string"} + assert props["optional_int"] == {"type": "integer"} + + # Only the truly required field is in the required list + assert params["required"] == ["required_field"] + + +@pytest.mark.parametrize("typed_dict", [ExtensionsTypedDict, TypingTypedDict]) +def test_convert_typed_dict_required_fields_total_false(typed_dict: type) -> None: + """Required[] on a total=False TypedDict forces a field into ``required``.""" + + class Tool(typed_dict, total=False): # type: ignore[misc] + """A total=False tool.""" + + required_field: Required[str] + optional_field: str + + result = _convert_typed_dict_to_openai_function(Tool) + + params = result["parameters"] + props = params["properties"] + + assert "required_field" in props + assert "optional_field" in props + + # Required[] field must appear in the required list; optional_field must not + assert "required_field" in params.get("required", []) + assert "optional_field" not in params.get("required", []) + + +@pytest.mark.parametrize("typed_dict", [ExtensionsTypedDict, TypingTypedDict]) +def test_convert_typed_dict_notrequired_nested_type(typed_dict: type) -> None: + """NotRequired wrapping a generic type (e.g. list[str]) must not raise.""" + + class Tool(typed_dict): # type: ignore[misc] + """Tool with nested NotRequired.""" + + items: NotRequired[list[str]] + count: int + + result = _convert_typed_dict_to_openai_function(Tool) + + params = result["parameters"] + props = params["properties"] + + assert props["items"] == {"type": "array", "items": {"type": "string"}} + assert props["count"] == {"type": "integer"} + assert params["required"] == ["count"]