Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 5 additions & 21 deletions distributed/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import weakref
from collections import defaultdict, deque
from collections.abc import (
Awaitable,
Callable,
Container,
Coroutine,
Expand Down Expand Up @@ -560,33 +559,18 @@ def start_periodic_callbacks(self):
if not pc.is_running():
pc.start()

def _stop_listeners(self) -> asyncio.Future:
listeners_to_stop: set[Awaitable] = set()

def _stop_listeners(self) -> None:
for listener in self.listeners:
future = listener.stop()
if inspect.isawaitable(future):
warnings.warn(
f"{type(listener)} is using an asynchronous `stop` method. "
"Support for asynchronous `Listener.stop` has been deprecated and "
"will be removed in a future version",
DeprecationWarning,
)
listeners_to_stop.add(future)
elif hasattr(listener, "abort_handshaking_comms"):
listener.stop()
if hasattr(listener, "abort_handshaking_comms"):
listener.abort_handshaking_comms()

return asyncio.gather(*listeners_to_stop)

def stop(self) -> None:
if self.__stopped:
return
self.__stopped = True
self.monitor.close()
if not (stop_listeners := self._stop_listeners()).done():
self._ongoing_background_tasks.call_soon(
asyncio.wait_for(stop_listeners, timeout=None) # type: ignore[arg-type]
)
self._stop_listeners()
if self._workdir is not None:
self._workdir.release()

Expand Down Expand Up @@ -935,7 +919,7 @@ async def close(self, timeout: float | None = None, reason: str = "") -> None:

self.__stopped = True
self.monitor.close()
await self._stop_listeners()
self._stop_listeners()

# TODO: Deal with exceptions
await self._ongoing_background_tasks.stop()
Expand Down
36 changes: 4 additions & 32 deletions distributed/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,19 @@
import sys
import threading
import weakref
from contextlib import asynccontextmanager
from unittest import mock

import pytest
from tornado.ioloop import IOLoop

import dask

from distributed import core
from distributed.batched import BatchedSend
from distributed.comm.core import CommClosedError, FatalCommClosedError
from distributed.comm.registry import backends
from distributed.comm.tcp import TCPBackend, TCPListener
from distributed.comm.tcp import TCPBackend, TCPConnector
from distributed.core import (
ConnectionPool,
RPCClosed,
Expand Down Expand Up @@ -171,13 +173,11 @@ class MyServer(Server):
default_port = 8756


@pytest.mark.slow
@gen_test()
async def test_server_listen():
"""
Test various Server.listen() arguments and their effect.
"""
import socket

try:
EXTERNAL_IP4 = get_ip()
Expand All @@ -186,8 +186,6 @@ async def test_server_listen():
except socket.gaierror:
pytest.skip("no network access")

from contextlib import asynccontextmanager

@asynccontextmanager
async def listen_on(cls, *args, **kwargs):
server = cls({})
Expand Down Expand Up @@ -631,8 +629,6 @@ async def test_connection_pool_close_while_connecting(monkeypatch):
Ensure a closed connection pool guarantees to have no connections left open
even if it is closed mid-connecting
"""
from distributed.comm.registry import backends
from distributed.comm.tcp import TCPBackend, TCPConnector

class SlowConnector(TCPConnector):
async def connect(self, address, deserialize, **connection_args):
Expand Down Expand Up @@ -672,8 +668,6 @@ async def connect_to_server():
@gen_test()
async def test_connection_pool_outside_cancellation(monkeypatch):
# Ensure cancellation errors are properly reraised
from distributed.comm.registry import backends
from distributed.comm.tcp import TCPBackend, TCPConnector

class SlowConnector(TCPConnector):
async def connect(self, address, deserialize, **connection_args):
Expand Down Expand Up @@ -707,11 +701,9 @@ async def connect_to_server():
assert all(t.cancelled() for t in tasks)


@pytest.mark.slow
@gen_test()
async def test_connection_pool_catch_all_cancellederrors(monkeypatch):
from distributed.comm.registry import backends
from distributed.comm.tcp import TCPBackend, TCPConnector

in_connect = asyncio.Event()
block_connect = asyncio.Event()

Expand Down Expand Up @@ -922,7 +914,6 @@ async def test_ticks(s, a, b):
@gen_cluster(config={"distributed.admin.tick.interval": "20 ms"})
async def test_tick_logging(s, a, b):
pytest.importorskip("crick")
from distributed import core

old = core.tick_maximum_delay
core.tick_maximum_delay = 0.001
Expand Down Expand Up @@ -1289,25 +1280,6 @@ def stream_not_leading_position(self, other, stream): ...
assert not _expects_comm(instance.stream_not_leading_position)


class AsyncStopTCPListener(TCPListener):
async def stop(self):
await asyncio.sleep(0)
super().stop()


class TCPAsyncListenerBackend(TCPBackend):
_listener_class = AsyncStopTCPListener


@gen_test()
async def test_async_listener_stop(monkeypatch):
monkeypatch.setitem(backends, "tcp", TCPAsyncListenerBackend())
with pytest.warns(DeprecationWarning):
async with Server({}) as s:
await s.listen(0)
assert s.listeners


@gen_test()
async def test_messages_are_ordered_bsend():
ledger = []
Expand Down
2 changes: 1 addition & 1 deletion distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1620,7 +1620,7 @@ async def close( # type: ignore
# otherwise
c.close()

await self._stop_listeners()
self._stop_listeners()
await self.rpc.close()

# Give some time for a UCX scheduler to complete closing endpoints
Expand Down
Loading