Skip to content

Commit 1dabfd8

Browse files
authored
Made object streams co/contravariant (#483)
Fixes #466.
1 parent add8313 commit 1dabfd8

File tree

4 files changed

+45
-24
lines changed

4 files changed

+45
-24
lines changed

docs/versionhistory.rst

+2
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.
2727
type of the value passed to ``task_status.started()``
2828
- The ``Listener`` class is now covariant in its stream type
2929
- ``create_memory_object_stream()`` now allows passing only ``item_type``
30+
- Object receive streams are now covariant and object send streams are correspondingly
31+
contravariant
3032
- Fixed ``CapacityLimiter`` on the asyncio backend to order waiting tasks in the FIFO
3133
order (instead of LIFO) (PR by Conor Stevenson)
3234
- Fixed ``CancelScope.cancel()`` not working on asyncio if called before entering the

src/anyio/abc/_streams.py

+12-11
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,12 @@
1010
from ._tasks import TaskGroup
1111

1212
T_Item = TypeVar("T_Item")
13-
T_Stream_co = TypeVar("T_Stream_co", covariant=True)
13+
T_co = TypeVar("T_co", covariant=True)
14+
T_contra = TypeVar("T_contra", contravariant=True)
1415

1516

1617
class UnreliableObjectReceiveStream(
17-
Generic[T_Item], AsyncResource, TypedAttributeProvider
18+
Generic[T_co], AsyncResource, TypedAttributeProvider
1819
):
1920
"""
2021
An interface for receiving objects.
@@ -26,17 +27,17 @@ class UnreliableObjectReceiveStream(
2627
given type parameter.
2728
"""
2829

29-
def __aiter__(self) -> UnreliableObjectReceiveStream[T_Item]:
30+
def __aiter__(self) -> UnreliableObjectReceiveStream[T_co]:
3031
return self
3132

32-
async def __anext__(self) -> T_Item:
33+
async def __anext__(self) -> T_co:
3334
try:
3435
return await self.receive()
3536
except EndOfStream:
3637
raise StopAsyncIteration
3738

3839
@abstractmethod
39-
async def receive(self) -> T_Item:
40+
async def receive(self) -> T_co:
4041
"""
4142
Receive the next item.
4243
@@ -49,7 +50,7 @@ async def receive(self) -> T_Item:
4950

5051

5152
class UnreliableObjectSendStream(
52-
Generic[T_Item], AsyncResource, TypedAttributeProvider
53+
Generic[T_contra], AsyncResource, TypedAttributeProvider
5354
):
5455
"""
5556
An interface for sending objects.
@@ -59,7 +60,7 @@ class UnreliableObjectSendStream(
5960
"""
6061

6162
@abstractmethod
62-
async def send(self, item: T_Item) -> None:
63+
async def send(self, item: T_contra) -> None:
6364
"""
6465
Send an item to the peer(s).
6566
@@ -80,14 +81,14 @@ class UnreliableObjectStream(
8081
"""
8182

8283

83-
class ObjectReceiveStream(UnreliableObjectReceiveStream[T_Item]):
84+
class ObjectReceiveStream(UnreliableObjectReceiveStream[T_co]):
8485
"""
8586
A receive message stream which guarantees that messages are received in the same
8687
order in which they were sent, and that no messages are missed.
8788
"""
8889

8990

90-
class ObjectSendStream(UnreliableObjectSendStream[T_Item]):
91+
class ObjectSendStream(UnreliableObjectSendStream[T_contra]):
9192
"""
9293
A send message stream which guarantees that messages are delivered in the same order
9394
in which they were sent, without missing any messages in the middle.
@@ -186,12 +187,12 @@ async def send_eof(self) -> None:
186187
AnyByteStream = Union[ObjectStream[bytes], ByteStream]
187188

188189

189-
class Listener(Generic[T_Stream_co], AsyncResource, TypedAttributeProvider):
190+
class Listener(Generic[T_co], AsyncResource, TypedAttributeProvider):
190191
"""An interface for objects that let you accept incoming connections."""
191192

192193
@abstractmethod
193194
async def serve(
194-
self, handler: Callable[[T_Stream_co], Any], task_group: TaskGroup | None = None
195+
self, handler: Callable[[T_co], Any], task_group: TaskGroup | None = None
195196
) -> None:
196197
"""
197198
Accept incoming connections as they come in and start tasks to handle them.

src/anyio/streams/memory.py

+15-13
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
from ..lowlevel import checkpoint
1717

1818
T_Item = TypeVar("T_Item")
19+
T_co = TypeVar("T_co", covariant=True)
20+
T_contra = TypeVar("T_contra", contravariant=True)
1921

2022

2123
class MemoryObjectStreamStatistics(NamedTuple):
@@ -55,14 +57,14 @@ def statistics(self) -> MemoryObjectStreamStatistics:
5557

5658

5759
@dataclass(eq=False)
58-
class MemoryObjectReceiveStream(Generic[T_Item], ObjectReceiveStream[T_Item]):
59-
_state: MemoryObjectStreamState[T_Item]
60+
class MemoryObjectReceiveStream(Generic[T_co], ObjectReceiveStream[T_co]):
61+
_state: MemoryObjectStreamState[T_co]
6062
_closed: bool = field(init=False, default=False)
6163

6264
def __post_init__(self) -> None:
6365
self._state.open_receive_channels += 1
6466

65-
def receive_nowait(self) -> T_Item:
67+
def receive_nowait(self) -> T_co:
6668
"""
6769
Receive the next item if it can be done without waiting.
6870
@@ -90,14 +92,14 @@ def receive_nowait(self) -> T_Item:
9092

9193
raise WouldBlock
9294

93-
async def receive(self) -> T_Item:
95+
async def receive(self) -> T_co:
9496
await checkpoint()
9597
try:
9698
return self.receive_nowait()
9799
except WouldBlock:
98100
# Add ourselves in the queue
99101
receive_event = Event()
100-
container: list[T_Item] = []
102+
container: list[T_co] = []
101103
self._state.waiting_receivers[receive_event] = container
102104

103105
try:
@@ -115,7 +117,7 @@ async def receive(self) -> T_Item:
115117
else:
116118
raise EndOfStream
117119

118-
def clone(self) -> MemoryObjectReceiveStream[T_Item]:
120+
def clone(self) -> MemoryObjectReceiveStream[T_co]:
119121
"""
120122
Create a clone of this receive stream.
121123
@@ -157,7 +159,7 @@ def statistics(self) -> MemoryObjectStreamStatistics:
157159
"""
158160
return self._state.statistics()
159161

160-
def __enter__(self) -> MemoryObjectReceiveStream[T_Item]:
162+
def __enter__(self) -> MemoryObjectReceiveStream[T_co]:
161163
return self
162164

163165
def __exit__(
@@ -170,14 +172,14 @@ def __exit__(
170172

171173

172174
@dataclass(eq=False)
173-
class MemoryObjectSendStream(Generic[T_Item], ObjectSendStream[T_Item]):
174-
_state: MemoryObjectStreamState[T_Item]
175+
class MemoryObjectSendStream(Generic[T_contra], ObjectSendStream[T_contra]):
176+
_state: MemoryObjectStreamState[T_contra]
175177
_closed: bool = field(init=False, default=False)
176178

177179
def __post_init__(self) -> None:
178180
self._state.open_send_channels += 1
179181

180-
def send_nowait(self, item: T_Item) -> None:
182+
def send_nowait(self, item: T_contra) -> None:
181183
"""
182184
Send an item immediately if it can be done without waiting.
183185
@@ -203,7 +205,7 @@ def send_nowait(self, item: T_Item) -> None:
203205
else:
204206
raise WouldBlock
205207

206-
async def send(self, item: T_Item) -> None:
208+
async def send(self, item: T_contra) -> None:
207209
"""
208210
Send an item to the stream.
209211
@@ -236,7 +238,7 @@ async def send(self, item: T_Item) -> None:
236238
):
237239
raise BrokenResourceError
238240

239-
def clone(self) -> MemoryObjectSendStream[T_Item]:
241+
def clone(self) -> MemoryObjectSendStream[T_contra]:
240242
"""
241243
Create a clone of this send stream.
242244
@@ -279,7 +281,7 @@ def statistics(self) -> MemoryObjectStreamStatistics:
279281
"""
280282
return self._state.statistics()
281283

282-
def __enter__(self) -> MemoryObjectSendStream[T_Item]:
284+
def __enter__(self) -> MemoryObjectSendStream[T_contra]:
283285
return self
284286

285287
def __exit__(

tests/streams/test_memory.py

+16
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
fail_after,
1414
wait_all_tasks_blocked,
1515
)
16+
from anyio.abc import ObjectReceiveStream, ObjectSendStream
1617
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
1718

1819
pytestmark = pytest.mark.anyio
@@ -345,3 +346,18 @@ async def test_sync_close() -> None:
345346

346347
with pytest.raises(ClosedResourceError):
347348
receive_stream.receive_nowait()
349+
350+
351+
async def test_type_variance() -> None:
352+
"""
353+
This test does not do anything at run time, but since the test suite is also checked
354+
with a static type checker, it ensures that the memory object stream
355+
co/contravariance works as intended. If it doesn't, one or both of the following
356+
reassignments will trip the type checker.
357+
358+
"""
359+
send, receive = create_memory_object_stream(item_type=float)
360+
receive1: MemoryObjectReceiveStream[complex] = receive # noqa: F841
361+
receive2: ObjectReceiveStream[complex] = receive # noqa: F841
362+
send1: MemoryObjectSendStream[int] = send # noqa: F841
363+
send2: ObjectSendStream[int] = send # noqa: F841

0 commit comments

Comments
 (0)