-
Notifications
You must be signed in to change notification settings - Fork 794
/
Copy path_utils.py
294 lines (217 loc) · 10 KB
/
_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
from __future__ import annotations as _annotations
import asyncio
import contextvars
import time
import uuid
from collections.abc import AsyncIterable, AsyncIterator, Iterator
from contextlib import asynccontextmanager, suppress
from dataclasses import dataclass, is_dataclass
from datetime import datetime, timezone
from functools import partial
from types import GenericAlias
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union
from pydantic import BaseModel
from pydantic.json_schema import JsonSchemaValue
from typing_extensions import ParamSpec, TypeAlias, TypeGuard, is_typeddict
from pydantic_graph._utils import AbstractSpan
# Re-assigned so `AbstractSpan` becomes part of this module's own namespace (re-export).
AbstractSpan = AbstractSpan
if TYPE_CHECKING:
    # Imported only for type annotations to avoid circular imports at runtime.
    from pydantic_ai.agent import AgentRun, AgentRunResult
    from pydantic_graph import GraphRun, GraphRunResult
    from . import messages as _messages
    from .tools import ObjectJsonSchema
_P = ParamSpec('_P')  # parameters of the callable passed to `run_in_executor`
_R = TypeVar('_R')  # return type of the callable passed to `run_in_executor`
async def run_in_executor(func: Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs) -> _R:
    """Run a synchronous callable on the default executor without blocking the event loop.

    The current `contextvars` context is copied into the worker thread so context variables
    set by the caller remain visible inside `func`.
    """
    context = contextvars.copy_context()
    bound_call = partial(func, *args, **kwargs)
    # `Context.run` executes the bound call inside the copied context on the executor thread.
    return await asyncio.get_running_loop().run_in_executor(None, context.run, bound_call)
def is_model_like(type_: Any) -> bool:
    """Check whether `type_` is a pydantic model, dataclass or TypedDict.

    These should all generate a JSON Schema with `{"type": "object"}` and therefore be usable directly as
    function parameters.
    """
    # Reject non-classes, and parameterized generics like `list[int]` (which may pass an
    # `isinstance(..., type)` check on some Python versions).
    if not isinstance(type_, type) or isinstance(type_, GenericAlias):
        return False
    return issubclass(type_, BaseModel) or is_dataclass(type_) or is_typeddict(type_)  # pyright: ignore[reportUnknownArgumentType]
def check_object_json_schema(schema: JsonSchemaValue) -> ObjectJsonSchema:
    """Validate that `schema` describes a JSON object, resolving a top-level `$ref` into `$defs` when possible.

    Raises:
        UserError: If the schema is neither an object schema nor a `$ref`.
    """
    from .exceptions import UserError

    if schema.get('type') == 'object':
        return schema
    if schema.get('$ref') is not None:
        # Strip the leading '#/$defs/' prefix to look the definition up in `$defs`.
        resolved = schema.get('$defs', {}).get(schema['$ref'][8:])
        if "'$ref': '#/$defs/" in str(resolved):
            # The resolved definition still references other `$defs`, so the original schema
            # (which carries the `$defs` mapping) must be returned intact.
            return schema
        # NOTE(review): if the `$defs` lookup fails, `resolved` is `None` and is returned as-is,
        # which violates the declared return type — confirm callers tolerate this.
        return resolved
    raise UserError('Schema must be an object')
# Shared typevar used by `Some`, `Option` and the helpers below.
T = TypeVar('T')


@dataclass
class Some(Generic[T]):
    """Wrapper marking a value as present, analogous to Rust's `Option::Some` type."""

    # The wrapped value.
    value: T
# Declared with `TypeAlias` so type checkers treat this as an alias rather than a plain variable.
Option: TypeAlias = Union[Some[T], None]
"""Analogous to Rust's `Option` type, usage: `Option[Thing]` is equivalent to `Some[Thing] | None`."""
class Unset:
"""A singleton to represent an unset value."""
pass
UNSET = Unset()
def is_set(t_or_unset: T | Unset) -> TypeGuard[T]:
return t_or_unset is not UNSET
@asynccontextmanager
async def group_by_temporal(
    aiterable: AsyncIterable[T], soft_max_interval: float | None
) -> AsyncIterator[AsyncIterable[list[T]]]:
    """Group items from an async iterable into lists based on time interval between them.

    Effectively, this debounces the iterator.

    This returns a context manager usable as an iterator so any pending tasks can be cancelled if an error occurs
    during iteration.

    Usage:
    ```python
    async with group_by_temporal(yield_groups(), 0.1) as groups_iter:
        async for groups in groups_iter:
            print(groups)
    ```

    Args:
        aiterable: The async iterable to group.
        soft_max_interval: Maximum interval over which to group items, this should avoid a trickle of items causing
            a group to never be yielded. It's a soft max in the sense that once we're over this time, we yield items
            as soon as `aiter.__anext__()` returns. If `None`, no grouping/debouncing is performed

    Returns:
        A context manager usable as an async iterable of lists of items produced by the input async iterable.
    """
    if soft_max_interval is None:
        # Debouncing disabled: pass every item straight through as its own single-element group.
        async def async_iter_groups_noop() -> AsyncIterator[list[T]]:
            async for item in aiterable:
                yield [item]

        yield async_iter_groups_noop()
        return

    # we might wait for the next item more than once, so we store the task to await next time
    task: asyncio.Task[T] | None = None

    async def async_iter_groups() -> AsyncIterator[list[T]]:
        nonlocal task

        assert soft_max_interval is not None and soft_max_interval >= 0, 'soft_max_interval must be a positive number'

        buffer: list[T] = []
        # NOTE(review): the first group's window starts here, before the first item arrives;
        # later groups start their window at their first item (see below).
        group_start_time = time.monotonic()
        aiterator = aiterable.__aiter__()

        while True:
            if group_start_time is None:
                # group hasn't started, we just wait for the maximum interval
                wait_time = soft_max_interval
            else:
                # wait for the time remaining in the group
                wait_time = soft_max_interval - (time.monotonic() - group_start_time)

            # if there's no current task, we get the next one
            if task is None:
                # aiter.__anext__() returns an Awaitable[T], not a Coroutine which asyncio.create_task expects
                # so far, this doesn't seem to be a problem
                task = asyncio.create_task(aiterator.__anext__())  # pyright: ignore[reportArgumentType]

            # we use asyncio.wait to avoid cancelling the coroutine if it's not done
            # (on timeout the task is left pending and re-awaited on the next loop iteration)
            done, _ = await asyncio.wait((task,), timeout=wait_time)

            if done:
                # the one task we waited for completed
                try:
                    item = done.pop().result()
                except StopAsyncIteration:
                    # if the task raised StopAsyncIteration, we're done iterating
                    if buffer:
                        yield buffer
                    task = None
                    break
                else:
                    # we got an item, add it to the buffer and set task to None to get the next item
                    buffer.append(item)
                    task = None
                    # if this is the first item in the group, set the group start time
                    if group_start_time is None:
                        group_start_time = time.monotonic()
            elif buffer:
                # otherwise if the task timeout expired and we have items in the buffer, yield the buffer
                yield buffer
                # clear the buffer and reset the group start time ready for the next group
                buffer = []
                group_start_time = None

    try:
        yield async_iter_groups()
    finally:  # pragma: no cover
        # after iteration if a tasks still exists, cancel it, this will only happen if an error occurred
        if task:
            task.cancel('Cancelling due to error in iterator')
            with suppress(asyncio.CancelledError):
                await task
def sync_anext(iterator: Iterator[T]) -> T:
    """Advance a sync iterator, translating exhaustion into `StopAsyncIteration`.

    Useful when iterating over a sync iterator in an async context, where `StopIteration`
    must not escape a coroutine.
    """
    try:
        value = next(iterator)
    except StopIteration as exc:
        raise StopAsyncIteration() from exc
    else:
        return value
def now_utc() -> datetime:
    """Return the current time as a timezone-aware UTC `datetime`."""
    return datetime.now(timezone.utc)
def guard_tool_call_id(t: _messages.ToolCallPart | _messages.ToolReturnPart | _messages.RetryPromptPart) -> str:
    """Return the part's tool call id, generating a fresh one if it is missing/falsy."""
    existing = t.tool_call_id
    if existing:
        return existing
    return generate_tool_call_id()
def generate_tool_call_id() -> str:
    """Generate a unique tool call id.

    Uniqueness comes from the embedded UUID4 hex digest.
    """
    return 'pyd_ai_' + uuid.uuid4().hex
class PeekableAsyncStream(Generic[T]):
    """Wraps an async iterable of type T and allows peeking at the *next* item without consuming it.

    At most one item (the upcoming one) is buffered; once it is yielded it is discarded, making this
    a single-pass stream.
    """

    def __init__(self, source: AsyncIterable[T]):
        self._source = source
        self._source_iter: AsyncIterator[T] | None = None  # created lazily on first use
        self._buffer: T | Unset = UNSET  # the peeked-but-not-yet-consumed item, if any
        self._exhausted = False

    def _iterator(self) -> AsyncIterator[T]:
        """Lazily create (and cache) the underlying async iterator."""
        if self._source_iter is None:
            self._source_iter = self._source.__aiter__()
        return self._source_iter

    async def peek(self) -> T | Unset:
        """Return the next item without consuming it, or `UNSET` if the stream is exhausted."""
        if self._exhausted:
            return UNSET
        # A buffered item means we already peeked; hand it back again.
        if isinstance(self._buffer, Unset):
            try:
                self._buffer = await self._iterator().__anext__()
            except StopAsyncIteration:
                self._exhausted = True
                return UNSET
        return self._buffer

    async def is_exhausted(self) -> bool:
        """Return True if the stream has no more items, False otherwise."""
        return isinstance(await self.peek(), Unset)

    def __aiter__(self) -> AsyncIterator[T]:
        # Single-pass stream: the object is its own iterator.
        return self

    async def __anext__(self) -> T:
        """Return the buffered item if present, otherwise pull the next item from the source.

        Raises:
            StopAsyncIteration: If the stream is exhausted.
        """
        if self._exhausted:
            raise StopAsyncIteration
        if not isinstance(self._buffer, Unset):
            item = self._buffer
            self._buffer = UNSET
            return item
        try:
            return await self._iterator().__anext__()
        except StopAsyncIteration:
            self._exhausted = True
            raise
def get_traceparent(x: AgentRun | AgentRunResult | GraphRun | GraphRunResult) -> str:
    """Return the run's traceparent string, or the empty string when no trace is active."""
    traceparent = x._traceparent(required=False)  # type: ignore[reportPrivateUsage]
    if traceparent:
        return traceparent
    return ''