1- from typing import AsyncGenerator , Generic , Optional , List , Any
1+ from typing import AsyncIterator , Generic , Optional , List , Any
22from contextvars import ContextVar
33from contextlib import asynccontextmanager
44from collections .abc import Sequence
4141 from pydantic_ai import models , _system_prompt
4242 from pydantic_ai .output import OutputDataT , OutputSpec
4343 from pydantic .json_schema import GenerateJsonSchema
44+ from pydantic_ai .result import StreamedRunResult
4445
4546 from deepeval .integrations .pydantic_ai .utils import create_patched_tool , update_trace_context , patch_llm_model
4647
@@ -59,8 +60,17 @@ def pydantic_ai_installed():
5960
6061_IS_RUN_SYNC = ContextVar ("deepeval_is_run_sync" , default = False )
6162
63+ try :
64+ from typing import TypeVar
65+ AgentDepsT = TypeVar ('AgentDepsT' , default = None , covariant = True )
66+ OutputDataT = TypeVar ('OutputDataT' , default = str , covariant = True )
67+ except TypeError :
68+ from typing_extensions import TypeVar
69+ AgentDepsT = TypeVar ('AgentDepsT' , default = None , covariant = True )
70+ OutputDataT = TypeVar ('OutputDataT' , default = str , covariant = True )
71+
6272class DeepEvalPydanticAIAgent (
63- Agent [ AgentDepsT , OutputDataT ] ,
73+ Agent ,
6474 Generic [AgentDepsT , OutputDataT ], # make subclass generic
6575):
6676
@@ -186,7 +196,7 @@ async def run(
186196 thread_id : Optional [str ] = None ,
187197 metrics : Optional [List [BaseMetric ]] = None ,
188198 metric_collection : Optional [str ] = None ,
189- ) -> AgentRunResult [Any ]:
199+ ) -> AgentRunResult [OutputDataT ]:
190200 input = user_prompt
191201
192202 agent_name = super ().name if super ().name is not None else "Agent"
@@ -272,7 +282,7 @@ def run_sync(
272282 user_id : Optional [str ] = None ,
273283 metric_collection : Optional [str ] = None ,
274284 metrics : Optional [List [BaseMetric ]] = None ,
275- ) -> AgentRunResult [Any ]:
285+ ) -> AgentRunResult [OutputDataT ]:
276286 input = user_prompt
277287
278288 token = _IS_RUN_SYNC .set (True )
@@ -364,7 +374,7 @@ async def run_stream(
364374 user_id : Optional [str ] = None ,
365375 metric_collection : Optional [str ] = None ,
366376 metrics : Optional [List [BaseMetric ]] = None ,
367- ) -> AsyncGenerator [ AgentRunResult [ Any ], None ]:
377+ ) -> AsyncIterator [ StreamedRunResult [ AgentDepsT , OutputDataT ] ]:
368378 input = user_prompt
369379
370380 agent_name = super ().name if super ().name is not None else "Agent"
0 commit comments