Skip to content

Commit 60ba9da

Browse files
.
1 parent 388bd2d commit 60ba9da

1 file changed

Lines changed: 15 additions & 5 deletions

File tree

  • deepeval/integrations/pydantic_ai

deepeval/integrations/pydantic_ai/agent.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import AsyncGenerator, Generic, Optional, List, Any
1+
from typing import AsyncIterator, Generic, Optional, List, Any
22
from contextvars import ContextVar
33
from contextlib import asynccontextmanager
44
from collections.abc import Sequence
@@ -41,6 +41,7 @@
4141
from pydantic_ai import models, _system_prompt
4242
from pydantic_ai.output import OutputDataT, OutputSpec
4343
from pydantic.json_schema import GenerateJsonSchema
44+
from pydantic_ai.result import StreamedRunResult
4445

4546
from deepeval.integrations.pydantic_ai.utils import create_patched_tool, update_trace_context, patch_llm_model
4647

@@ -59,8 +60,17 @@ def pydantic_ai_installed():
5960

6061
_IS_RUN_SYNC = ContextVar("deepeval_is_run_sync", default=False)
6162

63+
try:
64+
from typing import TypeVar
65+
AgentDepsT = TypeVar('AgentDepsT', default=None, covariant=True)
66+
OutputDataT = TypeVar('OutputDataT', default=str, covariant=True)
67+
except TypeError:
68+
from typing_extensions import TypeVar
69+
AgentDepsT = TypeVar('AgentDepsT', default=None, covariant=True)
70+
OutputDataT = TypeVar('OutputDataT', default=str, covariant=True)
71+
6272
class DeepEvalPydanticAIAgent(
63-
Agent[AgentDepsT, OutputDataT],
73+
Agent,
6474
Generic[AgentDepsT, OutputDataT], # make subclass generic
6575
):
6676

@@ -186,7 +196,7 @@ async def run(
186196
thread_id: Optional[str] = None,
187197
metrics: Optional[List[BaseMetric]] = None,
188198
metric_collection: Optional[str] = None,
189-
) -> AgentRunResult[Any]:
199+
) -> AgentRunResult[OutputDataT]:
190200
input = user_prompt
191201

192202
agent_name = super().name if super().name is not None else "Agent"
@@ -272,7 +282,7 @@ def run_sync(
272282
user_id: Optional[str] = None,
273283
metric_collection: Optional[str] = None,
274284
metrics: Optional[List[BaseMetric]] = None,
275-
) -> AgentRunResult[Any]:
285+
) -> AgentRunResult[OutputDataT]:
276286
input = user_prompt
277287

278288
token = _IS_RUN_SYNC.set(True)
@@ -364,7 +374,7 @@ async def run_stream(
364374
user_id: Optional[str] = None,
365375
metric_collection: Optional[str] = None,
366376
metrics: Optional[List[BaseMetric]] = None,
367-
) -> AsyncGenerator[AgentRunResult[Any], None]:
377+
) -> AsyncIterator[StreamedRunResult[AgentDepsT, OutputDataT]]:
368378
input = user_prompt
369379

370380
agent_name = super().name if super().name is not None else "Agent"

0 commit comments

Comments
 (0)