11from dataclasses import dataclass
2- from typing import Any , AsyncGenerator , Callable , Generic , List , TypeVar
2+ from typing import Any , AsyncGenerator , Callable , Generic , List , Optional , TypeVar
33
44import ijson
55from pydantic import BaseModel
@@ -73,11 +73,13 @@ def __init__(
7373 prefix : str ,
7474 streamed_item_cls : type [StreamedItem ],
7575 remainder_cls : type [Remainder ] = _NoopRemainder ,
76+ validation_context : Optional [object ] = None ,
7677 ):
7778 self .input = input
7879 self .prefix = prefix
7980 self .streamed_item_cls = streamed_item_cls
8081 self .remainder_cls = remainder_cls
82+ self .validation_context = validation_context
8183 self .parser = ijson .parse_async (
8284 _AsyncStreamWrapper (input ),
8385 multiple_values = True ,
@@ -102,7 +104,7 @@ async def __anext__(self) -> StreamedItem:
102104 self .remainder_builder .event (event , value )
103105
104106 if event not in ("start_map" , "start_array" ):
105- return self .streamed_item_cls .model_validate (value )
107+ return self .streamed_item_cls .model_validate (value , context = self . validation_context )
106108
107109 depth = 1
108110 obj = _ObjectBuilder ()
@@ -118,7 +120,7 @@ async def __anext__(self) -> StreamedItem:
118120 elif event in ("end_map" , "end_array" ):
119121 depth -= 1
120122
121- return self .streamed_item_cls .model_validate (obj .get_value ())
123+ return self .streamed_item_cls .model_validate (obj .get_value (), context = self . validation_context )
122124
123125 def get_remainder (self ) -> Remainder :
124126 if not self .done :
0 commit comments