diff --git a/src/anyio/__init__.py b/src/anyio/__init__.py index 578cda6f..8c79b411 100644 --- a/src/anyio/__init__.py +++ b/src/anyio/__init__.py @@ -40,6 +40,9 @@ from ._core._sockets import wait_socket_writable as wait_socket_writable from ._core._sockets import wait_writable as wait_writable from ._core._streams import create_memory_object_stream as create_memory_object_stream +from ._core._streams import ( + create_priority_memory_object_stream as create_priority_memory_object_stream, +) from ._core._subprocesses import open_process as open_process from ._core._subprocesses import run_process as run_process from ._core._synchronization import CapacityLimiter as CapacityLimiter diff --git a/src/anyio/_core/_streams.py b/src/anyio/_core/_streams.py index 6a9814e5..c59cc447 100644 --- a/src/anyio/_core/_streams.py +++ b/src/anyio/_core/_streams.py @@ -8,6 +8,7 @@ MemoryObjectReceiveStream, MemoryObjectSendStream, MemoryObjectStreamState, + PriorityMemoryObjectStreamState, ) T_Item = TypeVar("T_Item") @@ -50,3 +51,27 @@ def __new__( # type: ignore[misc] state = MemoryObjectStreamState[T_Item](max_buffer_size) return (MemoryObjectSendStream(state), MemoryObjectReceiveStream(state)) + + +class create_priority_memory_object_stream( + tuple[MemoryObjectSendStream[T_Item], MemoryObjectReceiveStream[T_Item]], +): + """ + Create a memory object stream. + + The stream's item type can be annotated like + :func:`create_priority_memory_object_stream[T_Item]`. + + :param max_buffer_size: number of items held in the buffer until ``send()`` starts + blocking + """ + + def __new__( # type: ignore[misc] + cls, max_buffer_size: float = 0 + ) -> tuple[MemoryObjectSendStream[T_Item], MemoryObjectReceiveStream[T_Item]]: + if max_buffer_size != math.inf and not isinstance(max_buffer_size, int): + raise ValueError("max_buffer_size must be either an integer or math.inf") + if max_buffer_size < 0: + raise ValueError("max_buffer_size cannot be negative") + state = PriorityMemoryObjectStreamState[T_Item](max_buffer_size) + return (MemoryObjectSendStream(state), MemoryObjectReceiveStream(state)) diff --git a/src/anyio/streams/memory.py b/src/anyio/streams/memory.py index 83bf1d97..e301281a 100644 --- a/src/anyio/streams/memory.py +++ b/src/anyio/streams/memory.py @@ -1,10 +1,11 @@ from __future__ import annotations +import heapq import warnings from collections import OrderedDict, deque from dataclasses import dataclass, field from types import TracebackType -from typing import Generic, NamedTuple, TypeVar +from typing import Generic, NamedTuple, Protocol, TypeVar from .. import ( BrokenResourceError, @@ -21,6 +22,26 @@ T_contra = TypeVar("T_contra", contravariant=True) +class Buffer(Protocol, Generic[T_Item]): + def append(self, /, v: T_Item) -> None: ... + def popleft(self) -> T_Item: ... + def __len__(self) -> int: ... + + +@dataclass(eq=False) +class HeapQ(Buffer[T_Item]): + items: list[T_Item] = field(default_factory=list, init=False) + + def append(self, v: T_Item) -> None: + heapq.heappush(self.items, v) + + def popleft(self) -> T_Item: + return heapq.heappop(self.items) + + def __len__(self) -> int: + return len(self.items) + + class MemoryObjectStreamStatistics(NamedTuple): current_buffer_used: int #: number of items stored in the buffer #: maximum number of items that can be stored on this stream (or :data:`math.inf`) @@ -48,7 +69,7 @@ def __repr__(self) -> str: @dataclass(eq=False) class MemoryObjectStreamState(Generic[T_Item]): max_buffer_size: float = field() - buffer: deque[T_Item] = field(init=False, default_factory=deque) + buffer: Buffer[T_Item] = field(init=False, default_factory=deque) open_send_channels: int = field(init=False, default=0) open_receive_channels: int = field(init=False, default=0) waiting_receivers: OrderedDict[Event, MemoryObjectItemReceiver[T_Item]] = field( @@ -69,6 +90,11 @@ def statistics(self) -> MemoryObjectStreamStatistics: ) +@dataclass(eq=False) +class PriorityMemoryObjectStreamState(MemoryObjectStreamState[T_Item]): + buffer: HeapQ[T_Item] = field(init=False, default_factory=HeapQ) + + @dataclass(eq=False) class MemoryObjectReceiveStream(Generic[T_co], ObjectReceiveStream[T_co]): _state: MemoryObjectStreamState[T_co]