Skip to content

Added the receive_nowait() method to all streams #487

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 18 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.
and ``anyio.wait_writable`` before closing a socket. Among other things,
this prevents an OSError on the ``ProactorEventLoop``.
(`#896 <https://github.com/agronholm/anyio/pull/896>`_; PR by @graingert)
- Added the ``receive_nowait()`` method to the entire stream class hierarchy
- Fixed ``anyio.Path.copy()`` and ``anyio.Path.copy_into()`` failing on Python 3.14.0a7
- Fixed return annotation of ``__aexit__`` on async context managers. CMs which can
suppress exceptions should return ``bool``, or ``None`` otherwise.
Expand Down
70 changes: 70 additions & 0 deletions src/anyio/_backends/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -1028,6 +1028,16 @@ def _spawn_task_from_thread(
class StreamReaderWrapper(abc.ByteReceiveStream):
_stream: asyncio.StreamReader

def receive_nowait(self, max_bytes: int = 65536) -> bytes:
if self._stream.exception():
raise self._stream.exception()
elif not self._stream._buffer: # type: ignore[attr-defined]
raise WouldBlock

data = self._stream._buffer[:max_bytes] # type: ignore[attr-defined]
del self._stream._buffer[:max_bytes] # type: ignore[attr-defined]
return data

async def receive(self, max_bytes: int = 65536) -> bytes:
data = await self._stream.read(max_bytes)
if data:
Expand Down Expand Up @@ -1594,6 +1604,20 @@ async def aclose(self) -> None:
self._closed = True
self._transport.close()

def receive_nowait(self) -> tuple[bytes, IPSockAddrType]:
with self._receive_guard:
# If the buffer is empty, raise WouldBlock
if not self._protocol.read_queue and not self._transport.is_closing():
raise WouldBlock

try:
return self._protocol.read_queue.popleft()
except IndexError:
if self._closed:
raise ClosedResourceError from None
else:
raise BrokenResourceError from None

async def receive(self) -> tuple[bytes, IPSockAddrType]:
with self._receive_guard:
await AsyncIOBackend.checkpoint()
Expand Down Expand Up @@ -1642,6 +1666,22 @@ async def aclose(self) -> None:
self._closed = True
self._transport.close()

def receive_nowait(self) -> bytes:
with self._receive_guard:
# If the buffer is empty, raise WouldBlock
if not self._protocol.read_queue and not self._transport.is_closing():
raise WouldBlock

try:
packet = self._protocol.read_queue.popleft()
except IndexError:
if self._closed:
raise ClosedResourceError from None
else:
raise BrokenResourceError from None

return packet[0]

async def receive(self) -> bytes:
with self._receive_guard:
await AsyncIOBackend.checkpoint()
Expand Down Expand Up @@ -1674,6 +1714,21 @@ async def send(self, item: bytes) -> None:


class UNIXDatagramSocket(_RawSocketMixin, abc.UNIXDatagramSocket):
def receive_nowait(self) -> UNIXDatagramPacketType:
with self._receive_guard:
while True:
try:
data = self._raw_socket.recvfrom(65536)
except BlockingIOError:
raise WouldBlock from None
except OSError as exc:
if self._closing:
raise ClosedResourceError from None
else:
raise BrokenResourceError from exc
else:
return data

async def receive(self) -> UNIXDatagramPacketType:
loop = get_running_loop()
await AsyncIOBackend.checkpoint()
Expand Down Expand Up @@ -1710,6 +1765,21 @@ async def send(self, item: UNIXDatagramPacketType) -> None:


class ConnectedUNIXDatagramSocket(_RawSocketMixin, abc.ConnectedUNIXDatagramSocket):
def receive_nowait(self) -> bytes:
with self._receive_guard:
while True:
try:
data = self._raw_socket.recv(65536)
except BlockingIOError:
raise WouldBlock from None
except OSError as exc:
if self._closing:
raise ClosedResourceError from None
else:
raise BrokenResourceError from exc
else:
return data

async def receive(self) -> bytes:
loop = get_running_loop()
await AsyncIOBackend.checkpoint()
Expand Down
43 changes: 43 additions & 0 deletions src/anyio/_backends/_trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,8 @@ def _convert_socket_error(self, exc: BaseException) -> NoReturn:
raise ClosedResourceError from exc
elif self._trio_socket.fileno() < 0 and self._closed:
raise ClosedResourceError from None
elif isinstance(exc, BlockingIOError):
raise WouldBlock from exc
elif isinstance(exc, OSError):
raise BrokenResourceError from exc
else:
Expand All @@ -419,6 +421,18 @@ def __init__(self, trio_socket: TrioSocketType) -> None:
self._receive_guard = ResourceGuard("reading from")
self._send_guard = ResourceGuard("writing to")

def receive_nowait(self, max_bytes: int = 65536) -> bytes:
with self._receive_guard:
try:
data = self._raw_socket.recv(max_bytes)
except BaseException as exc:
self._convert_socket_error(exc)

if data:
return data
else:
raise EndOfStream

async def receive(self, max_bytes: int = 65536) -> bytes:
with self._receive_guard:
try:
Expand Down Expand Up @@ -550,6 +564,14 @@ def __init__(self, trio_socket: TrioSocketType) -> None:
self._receive_guard = ResourceGuard("reading from")
self._send_guard = ResourceGuard("writing to")

def receive_nowait(self) -> tuple[bytes, IPSockAddrType]:
with self._receive_guard:
try:
data, addr = self._raw_socket.recvfrom(65536)
return data, convert_ipv6_sockaddr(addr)
except BaseException as exc:
self._convert_socket_error(exc)

async def receive(self) -> tuple[bytes, IPSockAddrType]:
with self._receive_guard:
try:
Expand All @@ -572,6 +594,13 @@ def __init__(self, trio_socket: TrioSocketType) -> None:
self._receive_guard = ResourceGuard("reading from")
self._send_guard = ResourceGuard("writing to")

def receive_nowait(self) -> bytes:
with self._receive_guard:
try:
return self._raw_socket.recv(65536)
except BaseException as exc:
self._convert_socket_error(exc)

async def receive(self) -> bytes:
with self._receive_guard:
try:
Expand All @@ -593,6 +622,13 @@ def __init__(self, trio_socket: TrioSocketType) -> None:
self._receive_guard = ResourceGuard("reading from")
self._send_guard = ResourceGuard("writing to")

def receive_nowait(self) -> UNIXDatagramPacketType:
with self._receive_guard:
try:
return self._raw_socket.recvfrom(65536)
except BaseException as exc:
self._convert_socket_error(exc)

async def receive(self) -> UNIXDatagramPacketType:
with self._receive_guard:
try:
Expand All @@ -617,6 +653,13 @@ def __init__(self, trio_socket: TrioSocketType) -> None:
self._receive_guard = ResourceGuard("reading from")
self._send_guard = ResourceGuard("writing to")

def receive_nowait(self) -> bytes:
with self._receive_guard:
try:
return self._raw_socket.recv(65536)
except BaseException as exc:
self._convert_socket_error(exc)

async def receive(self) -> bytes:
with self._receive_guard:
try:
Expand Down
42 changes: 41 additions & 1 deletion src/anyio/abc/_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
from abc import abstractmethod
from collections.abc import Callable
from typing import Any, Generic, TypeVar, Union
from warnings import warn

from .._core._exceptions import EndOfStream
from .._core._exceptions import EndOfStream, WouldBlock
from .._core._typedattr import TypedAttributeProvider
from ._resources import AsyncResource
from ._tasks import TaskGroup
Expand All @@ -27,6 +28,16 @@ class UnreliableObjectReceiveStream(
given type parameter.
"""

def __init_subclass__(cls, **kwargs: Any) -> None:
super().__init_subclass__(**kwargs)
if cls.receive_nowait is UnreliableObjectReceiveStream.receive_nowait:
warn(
f"{cls.__qualname__} does not implement receive_nowait(). In v5.0, "
f"receive_nowait() will become an abstract method and an exception "
f"will be raised if not implemented in a stream class.",
DeprecationWarning,
)

def __aiter__(self) -> UnreliableObjectReceiveStream[T_co]:
return self

Expand All @@ -36,6 +47,20 @@ async def __anext__(self) -> T_co:
except EndOfStream:
raise StopAsyncIteration

def receive_nowait(self) -> T_co:
"""
Receive the next item if it can be done without waiting.

:raises ~anyio.ClosedResourceError: if the receive stream has been explicitly
closed
:raises ~anyio.EndOfStream: if this stream has been closed from the other end
:raises ~anyio.BrokenResourceError: if this stream has been rendered unusable
due to external causes
:raises ~anyio.WouldBlock: if there is no item immediately available

"""
raise NotImplementedError

@abstractmethod
async def receive(self) -> T_co:
"""
Expand Down Expand Up @@ -132,6 +157,21 @@ async def __anext__(self) -> bytes:
except EndOfStream:
raise StopAsyncIteration

def receive_nowait(self, max_bytes: int = 65536) -> bytes:
"""
Receive at most ``max_bytes`` bytes from the peer, if it can be done without
blocking.

.. note:: Implementers of this interface should not return an empty
:class:`bytes` object, and users should ignore them.

:param max_bytes: maximum number of bytes to receive
:return: the received bytes
:raises ~anyio.EndOfStream: if this stream has been closed from the other end
:raises ~anyio.WouldBlock: if there is no data waiting to be received
"""
raise WouldBlock

@abstractmethod
async def receive(self, max_bytes: int = 65536) -> bytes:
"""
Expand Down
21 changes: 21 additions & 0 deletions src/anyio/streams/buffered.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,27 @@ def buffer(self) -> bytes:
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
return self.receive_stream.extra_attributes

def receive_nowait(self, max_bytes: int = 65536) -> bytes:
if self._closed:
raise ClosedResourceError

if self._buffer:
chunk = bytes(self._buffer[:max_bytes])
del self._buffer[:max_bytes]
return chunk
elif isinstance(self.receive_stream, ByteReceiveStream):
return self.receive_stream.receive_nowait(max_bytes)
else:
# With a bytes-oriented object stream, we need to handle any surplus bytes
# we get from the receive_nowait() call
chunk = self.receive_stream.receive_nowait()
if len(chunk) > max_bytes:
# Save the surplus bytes in the buffer
self._buffer.extend(chunk[max_bytes:])
return chunk[:max_bytes]
else:
return chunk

async def receive(self, max_bytes: int = 65536) -> bytes:
if self._closed:
raise ClosedResourceError
Expand Down
11 changes: 0 additions & 11 deletions src/anyio/streams/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,17 +78,6 @@ def __post_init__(self) -> None:
self._state.open_receive_channels += 1

def receive_nowait(self) -> T_co:
"""
Receive the next item if it can be done without waiting.

:return: the received item
:raises ~anyio.ClosedResourceError: if this send stream has been closed
:raises ~anyio.EndOfStream: if the buffer is empty and this stream has been
closed from the sending end
:raises ~anyio.WouldBlock: if there are no items in the buffer and no tasks
waiting to send

"""
if self._closed:
raise ClosedResourceError

Expand Down
3 changes: 3 additions & 0 deletions src/anyio/streams/stapled.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ class StapledByteStream(ByteStream):
send_stream: ByteSendStream
receive_stream: ByteReceiveStream

def receive_nowait(self, max_bytes: int = 65536) -> bytes:
return self.receive_stream.receive_nowait(max_bytes)

async def receive(self, max_bytes: int = 65536) -> bytes:
return await self.receive_stream.receive(max_bytes)

Expand Down
9 changes: 9 additions & 0 deletions src/anyio/streams/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ def __post_init__(self, encoding: str, errors: str) -> None:
decoder_class = codecs.getincrementaldecoder(encoding)
self._decoder = decoder_class(errors=errors)

def receive_nowait(self) -> str:
while True:
chunk = self.transport_stream.receive_nowait()
if decoded := self._decoder.decode(chunk):
return decoded

async def receive(self) -> str:
while True:
chunk = await self.transport_stream.receive()
Expand Down Expand Up @@ -126,6 +132,9 @@ def __post_init__(self, encoding: str, errors: str) -> None:
self.transport_stream, encoding=encoding, errors=errors
)

def receive_nowait(self) -> str:
return self._receive_stream.receive_nowait()

async def receive(self) -> str:
return await self._receive_stream.receive()

Expand Down
Loading