|
14 | 14 | Any, |
15 | 15 | Callable, |
16 | 16 | Generic, |
| 17 | + Union, |
17 | 18 | TypeVar, |
18 | 19 | cast, |
19 | 20 | get_args, |
@@ -595,22 +596,66 @@ def prepare_response_model(response_model: type[T] | None) -> type[T] | None: |
595 | 596 | if response_model is None: |
596 | 597 | return None |
597 | 598 |
|
598 | | - if is_simple_type(response_model): |
599 | | - from instructor.dsl.simple_type import ModelAdapter |
| 599 | + # `list[int | str]` and similar scalar lists are treated as simple types and should |
| 600 | + # be adapted, not converted into an IterableModel. |
| 601 | + origin = get_origin(response_model) |
| 602 | + if origin is list and is_simple_type(response_model): |
| 603 | + args = get_args(response_model) |
| 604 | + |
| 605 | + def _is_model_type(t: Any) -> bool: |
| 606 | + if inspect.isclass(t) and issubclass(t, BaseModel): |
| 607 | + return True |
| 608 | + return get_origin(t) is Union and all( |
| 609 | + inspect.isclass(m) and issubclass(m, BaseModel) for m in get_args(t) |
| 610 | + ) |
600 | 611 |
|
601 | | - response_model = ModelAdapter[response_model] |
| 612 | + # If the list element is a Pydantic model (or union of models), this is a |
| 613 | + # structured "iterable extraction" response model, not a simple scalar list. |
| 614 | + if args and _is_model_type(args[0]): |
| 615 | + origin = None |
| 616 | + else: |
| 617 | + from instructor.dsl.simple_type import ModelAdapter |
| 618 | + |
| 619 | + response_model = ModelAdapter[response_model] # type: ignore[invalid-type-form] |
| 620 | + origin = get_origin(response_model) |
602 | 621 |
|
603 | 622 | if is_typed_dict(response_model): |
604 | | - response_model: BaseModel = create_model( |
605 | | - response_model.__name__, |
606 | | - **{k: (v, ...) for k, v in response_model.__annotations__.items()}, |
| 623 | + response_model = cast( |
| 624 | + type[BaseModel], |
| 625 | + create_model( |
| 626 | + response_model.__name__, |
| 627 | + **{k: (v, ...) for k, v in response_model.__annotations__.items()}, |
| 628 | + ), |
607 | 629 | ) |
608 | 630 |
|
609 | | - if get_origin(response_model) is Iterable: |
| 631 | + # Recompute after potential wrapping/conversion above. |
| 632 | + origin = get_origin(response_model) |
| 633 | + if origin in {Iterable, list}: |
610 | 634 | from instructor.dsl.iterable import IterableModel |
611 | 635 |
|
612 | | - iterable_element_class = get_args(response_model)[0] |
613 | | - response_model = cast(BaseModel, IterableModel(iterable_element_class)) # type: ignore |
| 636 | + args = get_args(response_model) |
| 637 | + if not args or args[0] is None: |
| 638 | + raise ValueError( |
| 639 | + "response_model must be parameterized, e.g. list[User] or Iterable[User]" |
| 640 | + ) |
| 641 | + iterable_element_class = args[0] |
| 642 | + if is_typed_dict(iterable_element_class): |
| 643 | + iterable_element_class = cast( |
| 644 | + type[BaseModel], |
| 645 | + create_model( |
| 646 | + iterable_element_class.__name__, |
| 647 | + **{ |
| 648 | + k: (v, ...) |
| 649 | + for k, v in iterable_element_class.__annotations__.items() |
| 650 | + }, |
| 651 | + ), |
| 652 | + ) |
| 653 | + response_model = IterableModel(cast(type[BaseModel], iterable_element_class)) |
| 654 | + |
| 655 | + if is_simple_type(response_model): |
| 656 | + from instructor.dsl.simple_type import ModelAdapter |
| 657 | + |
| 658 | + response_model = ModelAdapter[response_model] # type: ignore[invalid-type-form] |
614 | 659 |
|
615 | 660 | # Import here to avoid circular dependency |
616 | 661 | from ..processing.function_calls import OpenAISchema, openai_schema |
|
0 commit comments