|
1 | 1 | import inspect |
2 | 2 | from typing import Optional, List |
3 | 3 | from contextvars import ContextVar |
| 4 | +from contextlib import asynccontextmanager |
4 | 5 |
|
5 | 6 | from deepeval.prompt import Prompt |
6 | 7 | from deepeval.tracing.types import AgentSpan |
@@ -96,10 +97,12 @@ async def run( |
96 | 97 | bound = sig.bind_partial(*args, **kwargs) |
97 | 98 | bound.apply_defaults() |
98 | 99 | input = bound.arguments.get("user_prompt", None) |
| 100 | + |
| 101 | + agent_name = super().name if super().name is not None else "Agent" |
99 | 102 |
|
100 | 103 | with Observer( |
101 | 104 | span_type="agent" if not _IS_RUN_SYNC.get() else "custom", |
102 | | - func_name="Agent" if not _IS_RUN_SYNC.get() else "run", |
| 105 | + func_name=agent_name if not _IS_RUN_SYNC.get() else "run", |
103 | 106 | function_kwargs={"input": input}, |
104 | 107 | metrics=self.agent_metrics if not _IS_RUN_SYNC.get() else None, |
105 | 108 | metric_collection=self.agent_metric_collection if not _IS_RUN_SYNC.get() else None, |
@@ -149,9 +152,11 @@ def run_sync( |
149 | 152 |
|
150 | 153 | token = _IS_RUN_SYNC.set(True) |
151 | 154 |
|
| 155 | + agent_name = super().name if super().name is not None else "Agent" |
| 156 | + |
152 | 157 | with Observer( |
153 | 158 | span_type="agent", |
154 | | - func_name="Agent", |
| 159 | + func_name=agent_name, |
155 | 160 | function_kwargs={"input": input}, |
156 | 161 | metrics=self.agent_metrics, |
157 | 162 | metric_collection=self.agent_metric_collection, |
@@ -187,7 +192,64 @@ def run_sync( |
187 | 192 |
|
188 | 193 | return result |
189 | 194 |
|
190 | | - |
| 195 | + @asynccontextmanager |
| 196 | + async def run_stream( |
| 197 | + self, |
| 198 | + *args, |
| 199 | + name: Optional[str] = None, |
| 200 | + tags: Optional[List[str]] = None, |
| 201 | + metadata: Optional[dict] = None, |
| 202 | + thread_id: Optional[str] = None, |
| 203 | + user_id: Optional[str] = None, |
| 204 | + metric_collection: Optional[str] = None, |
| 205 | + metrics: Optional[List[BaseMetric]] = None, |
| 206 | + **kwargs |
| 207 | + ): |
| 208 | + sig = inspect.signature(super().run_stream) |
| 209 | + super_params = sig.parameters |
| 210 | + super_kwargs = {k: v for k, v in kwargs.items() if k in super_params} |
| 211 | + bound = sig.bind_partial(*args, **super_kwargs) |
| 212 | + bound.apply_defaults() |
| 213 | + input = bound.arguments.get("user_prompt", None) |
| 214 | + |
| 215 | + agent_name = super().name if super().name is not None else "Agent" |
| 216 | + |
| 217 | + with Observer( |
| 218 | + span_type="agent", |
| 219 | + func_name=agent_name, |
| 220 | + function_kwargs={"input": input}, |
| 221 | + metrics=self.agent_metrics, |
| 222 | + metric_collection=self.agent_metric_collection, |
| 223 | + ) as observer: |
| 224 | + final_result = None |
| 225 | + async with super().run_stream(*args, **super_kwargs) as result: |
| 226 | + try: |
| 227 | + yield result |
| 228 | + finally: |
| 229 | + try: |
| 230 | + final_result = await result.get_output() |
| 231 | + observer.result = final_result |
| 232 | + except Exception: |
| 233 | + pass |
| 234 | + |
| 235 | + update_trace_context( |
| 236 | + trace_name=name if name is not None else self.trace_name, |
| 237 | + trace_tags=tags if tags is not None else self.trace_tags, |
| 238 | + trace_metadata=metadata if metadata is not None else self.trace_metadata, |
| 239 | + trace_thread_id=thread_id if thread_id is not None else self.trace_thread_id, |
| 240 | + trace_user_id=user_id if user_id is not None else self.trace_user_id, |
| 241 | + trace_metric_collection=metric_collection if metric_collection is not None else self.trace_metric_collection, |
| 242 | + trace_metrics=metrics if metrics is not None else self.trace_metrics, |
| 243 | + trace_input=input, |
| 244 | + trace_output=(final_result if final_result is not None else None), |
| 245 | + ) |
| 246 | + agent_span: AgentSpan = current_span_context.get() |
| 247 | + try: |
| 248 | + if final_result is not None: |
| 249 | + agent_span.tools_called = extract_tools_called(final_result) |
| 250 | + except: |
| 251 | + pass |
| 252 | + |
191 | 253 | def tool( |
192 | 254 | self, |
193 | 255 | *args, |
|
0 commit comments