Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion instructor/dsl/partial.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from jiter import from_json
from pydantic import BaseModel, create_model
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined

from instructor.mode import Mode
from instructor.utils import extract_json_from_stream, extract_json_from_stream_async
Expand Down Expand Up @@ -171,14 +172,19 @@ def _build_partial_object(
else:
result[field_name] = field_value

# Set missing fields to None or empty nested models
# Set missing fields to their default value (if one exists), None, or empty nested models
for field_name, field_info in model.model_fields.items():
if field_name not in result:
field_type = field_info.annotation
if isinstance(field_type, type) and issubclass(field_type, BaseModel):
result[field_name] = _build_partial_object(
{}, field_type, tracker, "", **kwargs
)
elif field_info.default is not PydanticUndefined:
# Use the field's default value (e.g. Literal["Person"] = "Person")
result[field_name] = field_info.default
elif field_info.default_factory is not None:
result[field_name] = field_info.default_factory()
else:
result[field_name] = None

Expand Down
130 changes: 130 additions & 0 deletions tests/dsl/test_partial.py
Original file line number Diff line number Diff line change
Expand Up @@ -1129,3 +1129,133 @@ class Container(BaseModel):
assert result.content[0].text == "Introduction paragraph"
assert result.content[1].title == "Section 1.1"
assert result.content[1].content[1].title == "Subsection 1.1.1"


class TestDefaultValuesInPartialStreaming:
"""Tests for fields with default values during partial streaming.

Fields with default values (e.g., Literal["Person"] = "Person") should
use their default in partial responses rather than being set to None.

Fixes: https://github.com/instructor-ai/instructor/issues/2054
"""

def test_literal_default_present_from_first_chunk(self):
"""Literal field with default should appear immediately during streaming."""

class Person(BaseModel):
type: Literal["Person"] = "Person"
name: str
age: int

PartialModel = Partial[Person]

# Simulate streaming: first chunk has partial data, type not yet in JSON
chunks = ['{"name": "Jo', 'hn", "age": 25}']

results = list(PartialModel.model_from_chunks(iter(chunks)))
assert len(results) >= 1

# Even on the first partial chunk, type should be "Person" (the default)
assert results[0].type == "Person"
assert results[0].name == "Jo"

# Final result should also have the default
assert results[-1].type == "Person"
assert results[-1].name == "John"
assert results[-1].age == 25

def test_multiple_literal_defaults(self):
"""Multiple Literal fields with defaults should all appear immediately."""

class Event(BaseModel):
kind: Literal["event"] = "event"
source: Literal["api"] = "api"
message: str

PartialModel = Partial[Event]

chunks = ['{"message": "hel', 'lo"}']

results = list(PartialModel.model_from_chunks(iter(chunks)))
assert len(results) >= 1

# Both literal defaults should be present from the start
assert results[0].kind == "event"
assert results[0].source == "api"

def test_default_value_non_literal(self):
"""Non-Literal fields with defaults should also use their default."""

class Config(BaseModel):
retries: int = 3
name: str

PartialModel = Partial[Config]

chunks = ['{"name": "tes', 't"}']

results = list(PartialModel.model_from_chunks(iter(chunks)))
assert len(results) >= 1

# retries should be 3 (the default), not None
assert results[0].retries == 3

def test_default_factory_field(self):
"""Fields with default_factory should use the factory value."""

class Tags(BaseModel):
items: list[str] = Field(default_factory=list)
name: str

PartialModel = Partial[Tags]

chunks = ['{"name": "tes', 't"}']

results = list(PartialModel.model_from_chunks(iter(chunks)))
assert len(results) >= 1

# items should be [] (from default_factory), not None
assert results[0].items == []

def test_default_not_overridden_when_field_present(self):
"""When the field is present in the JSON, the streamed value should be used."""

class Person(BaseModel):
type: Literal["Person"] = "Person"
name: str

PartialModel = Partial[Person]

# JSON includes the type field explicitly
chunks = ['{"type": "Person", "name": "Alice"}']

results = list(PartialModel.model_from_chunks(iter(chunks)))
assert results[-1].type == "Person"
assert results[-1].name == "Alice"

@pytest.mark.asyncio
async def test_literal_default_present_in_async_streaming(self):
"""Async streaming should also use defaults for missing Literal fields."""

class Person(BaseModel):
type: Literal["Person"] = "Person"
name: str
age: int

PartialModel = Partial[Person]

async def async_chunks():
yield '{"name": "Jo'
yield 'hn", "age": 25}'

results = []
async for obj in PartialModel.model_from_chunks_async(async_chunks()):
results.append(obj)

assert len(results) >= 1
# type should be "Person" from the very first partial result
assert results[0].type == "Person"
assert results[-1].type == "Person"
assert results[-1].name == "John"
assert results[-1].age == 25