Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion libs/core/langchain_core/utils/function_calling.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
Annotated,
Any,
Literal,
NotRequired,
Required,
Union,
cast,
get_args,
Expand Down Expand Up @@ -280,7 +282,23 @@
docstring, list(annotations_)
)
fields: dict = {}
optional_keys = getattr(typed_dict, "__optional_keys__", frozenset())
for arg, arg_type in annotations_.items():
# Unwrap NotRequired[X] / Required[X] wrappers that appear when
# get_type_hints(include_extras=True) is used on a TypedDict.
# NotRequired marks a field as not required (has a default of None),
# Required marks a field as required (overrides total=False).
is_not_required = get_origin(arg_type) in {
NotRequired,
typing_extensions.NotRequired,
}
is_required_wrapper = get_origin(arg_type) in {
Required,
typing_extensions.Required,
}
if is_not_required or is_required_wrapper:
arg_type = get_args(arg_type)[0]

Check failure on line 300 in libs/core/langchain_core/utils/function_calling.py

View workflow job for this annotation

GitHub Actions / lint (libs/core, 3.12) / Python 3.12

ruff (PLW2901)

langchain_core/utils/function_calling.py:300:17: PLW2901 `for` loop variable `arg_type` overwritten by assignment target

Check failure on line 300 in libs/core/langchain_core/utils/function_calling.py

View workflow job for this annotation

GitHub Actions / lint (libs/core, 3.13) / Python 3.13

ruff (PLW2901)

langchain_core/utils/function_calling.py:300:17: PLW2901 `for` loop variable `arg_type` overwritten by assignment target

if get_origin(arg_type) in {Annotated, typing_extensions.Annotated}:
annotated_args = get_args(arg_type)
new_arg_type = _convert_any_typed_dicts_to_pydantic(
Expand All @@ -300,12 +318,22 @@
raise ValueError(msg)
if arg_desc := arg_descriptions.get(arg):
field_kwargs["description"] = arg_desc
# If the field was originally NotRequired and no explicit default
# was provided via Annotated metadata, mark it as optional.
if is_not_required and "default" not in field_kwargs:
field_kwargs["default"] = None
fields[arg] = (new_arg_type, Field_v1(**field_kwargs))
else:
new_arg_type = _convert_any_typed_dicts_to_pydantic(
arg_type, depth=depth + 1, visited=visited
)
field_kwargs = {"default": ...}
# NotRequired fields (or fields on total=False TypedDicts) have no
# mandatory value — use None as the default so Pydantic treats them
# as optional in the generated JSON schema.
if is_not_required or arg in optional_keys:
field_kwargs: dict = {"default": None}
else:
field_kwargs = {"default": ...}
if arg_desc := arg_descriptions.get(arg):
field_kwargs["description"] = arg_desc
fields[arg] = (new_arg_type, Field_v1(**field_kwargs))
Expand All @@ -321,6 +349,17 @@
_convert_any_typed_dicts_to_pydantic(arg, depth=depth + 1, visited=visited)
for arg in type_args
)
# NotRequired and Required only accept a single type argument; passing a
# tuple causes "NotRequired accepts only a single type. Got (<class ...>,)".
# Unwrap the single-element tuple in those cases.
_single_arg_origins = {
NotRequired,
typing_extensions.NotRequired,
Required,
typing_extensions.Required,
}
if origin in _single_arg_origins and len(type_args) == 1:
return cast("type", subscriptable_origin[type_args[0]]) # type: ignore[index]
return cast("type", subscriptable_origin[type_args]) # type: ignore[index]
return type_

Expand Down
79 changes: 79 additions & 0 deletions libs/core/tests/unit_tests/utils/test_function_calling.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from typing import (
Any,
Literal,
NotRequired,
Required,
TypeAlias,
)
from typing import TypedDict as TypingTypedDict
Expand Down Expand Up @@ -1242,3 +1244,80 @@ def test_convert_to_openai_function_json_schema_missing_title_includes_schema()
}
with pytest.raises(ValueError, match="my_field"):
convert_to_openai_function(schema_without_title)


@pytest.mark.parametrize("typed_dict", [ExtensionsTypedDict, TypingTypedDict])
def test_convert_typed_dict_notrequired_fields(typed_dict: type) -> None:
"""NotRequired fields in TypedDict must not raise TypeError (issue #34085).

Required fields must appear in the ``required`` list; NotRequired fields
must be absent from ``required`` and still appear as schema properties.
"""

class Tool(typed_dict): # type: ignore[misc]
"""A tool with optional fields."""

required_field: str
optional_field: NotRequired[str]
optional_int: NotRequired[int]

result = _convert_typed_dict_to_openai_function(Tool)

assert result["name"] == "Tool"
params = result["parameters"]
props = params["properties"]

# Both fields must appear in properties
assert "required_field" in props
assert "optional_field" in props
assert "optional_int" in props

assert props["required_field"] == {"type": "string"}
assert props["optional_field"] == {"type": "string"}
assert props["optional_int"] == {"type": "integer"}

# Only the truly required field is in the required list
assert params["required"] == ["required_field"]


@pytest.mark.parametrize("typed_dict", [ExtensionsTypedDict, TypingTypedDict])
def test_convert_typed_dict_required_fields_total_false(typed_dict: type) -> None:
"""Required[] on a total=False TypedDict forces a field into ``required``."""

class Tool(typed_dict, total=False): # type: ignore[misc]
"""A total=False tool."""

required_field: Required[str]
optional_field: str

result = _convert_typed_dict_to_openai_function(Tool)

params = result["parameters"]
props = params["properties"]

assert "required_field" in props
assert "optional_field" in props

# Required[] field must appear in the required list; optional_field must not
assert "required_field" in params.get("required", [])
assert "optional_field" not in params.get("required", [])


@pytest.mark.parametrize("typed_dict", [ExtensionsTypedDict, TypingTypedDict])
def test_convert_typed_dict_notrequired_nested_type(typed_dict: type) -> None:
"""NotRequired wrapping a generic type (e.g. list[str]) must not raise."""

class Tool(typed_dict): # type: ignore[misc]
"""Tool with nested NotRequired."""

items: NotRequired[list[str]]
count: int

result = _convert_typed_dict_to_openai_function(Tool)

params = result["parameters"]
props = params["properties"]

assert props["items"] == {"type": "array", "items": {"type": "string"}}
assert props["count"] == {"type": "integer"}
assert params["required"] == ["count"]
Loading