Add blocking progress mode to Python async #116

Merged
merged 39 commits into from
Oct 22, 2024
Commits
6dae328
Expose `ucxx::Worker` epoll file descriptor getter
pentschev Nov 2, 2023
155ab58
Add blocking progress mode to Python async
pentschev Nov 2, 2023
1c1ab87
Test blocking mode in CI
pentschev Nov 2, 2023
e5f4a40
Disable Python future on `blocking` mode testing
pentschev Nov 6, 2023
7952ef1
Merge remote-tracking branch 'upstream/branch-0.36' into python-async…
pentschev Jan 15, 2024
edd3192
Add timeout to Python's async blocking progress mode
pentschev Jan 17, 2024
e3f9cc3
Support blocking mode in `send_recv` Python benchmark
pentschev Jan 17, 2024
b5f95f0
Schedule cancelation in `ProgressTask` deleter
pentschev Jan 17, 2024
60e49d1
Rerun CI
pentschev Jan 17, 2024
91ab7bf
Revert accidental CI script changes
pentschev Jan 17, 2024
26480a5
Disable blocking progress mode delayed submission benchmarks
pentschev Jan 18, 2024
6da6c5b
Merge remote-tracking branch 'upstream/branch-0.36' into python-async…
pentschev Jan 18, 2024
80d7e14
Merge remote-tracking branch 'upstream/branch-0.36' into python-async…
pentschev Jan 30, 2024
2b7c4cf
Merge remote-tracking branch 'upstream/branch-0.36' into python-async…
pentschev Jan 30, 2024
77b3659
Remove `pytest.mark.gpu`
pentschev Jan 30, 2024
1e3c55e
Merge branch 'branch-0.40' into python-async-blocking-mode
pentschev Jul 25, 2024
755bc97
Merge remote-tracking branch 'origin/python-async-blocking-mode' into…
pentschev Jul 25, 2024
e22b227
Re-enable blocking Distributed tests in CI
pentschev Jul 25, 2024
c6f9654
Merge branch 'branch-0.40' into python-async-blocking-mode
pentschev Jul 25, 2024
d4a8795
Merge remote-tracking branch 'upstream/branch-0.40' into python-async…
pentschev Sep 10, 2024
71d3697
Merge remote-tracking branch 'origin/python-async-blocking-mode' into…
pentschev Sep 10, 2024
c800ca4
Fix progress timeout and docstrings
pentschev Sep 30, 2024
18e3cf0
Cancel progress tasks before closing of event loop
pentschev Sep 30, 2024
2ffac66
Merge remote-tracking branch 'upstream/branch-0.41' into python-async…
pentschev Sep 30, 2024
c5c2ceb
Use `partial` to rewrite `event_loop.close`
pentschev Oct 1, 2024
279cb4c
Reset `self.asyncio_task` to `None` after cancellation
pentschev Oct 1, 2024
c624898
Fix comments' phrasing
pentschev Oct 1, 2024
5769f31
Cancel `_arm_worker` instead of `sock_recv`
pentschev Oct 1, 2024
4f7a7f2
Merge branch 'branch-0.41' into python-async-blocking-mode
pentschev Oct 4, 2024
23bb0bb
Merge branch 'branch-0.41' into python-async-blocking-mode
pentschev Oct 4, 2024
01cbe8a
Merge branch 'branch-0.41' into python-async-blocking-mode
pentschev Oct 7, 2024
50d3f47
Merge remote-tracking branch 'upstream/branch-0.41' into python-async…
pentschev Oct 18, 2024
dbb6386
Revert "Resolve thread-safety issues in distributed-ucxx (#295)"
pentschev Oct 18, 2024
8e4bc38
Adjust properties and blocking progress mode initialization
pentschev Oct 22, 2024
a806e45
Merge remote-tracking branch 'origin/python-async-blocking-mode' into…
pentschev Oct 22, 2024
9e2d017
Fix unreachable test
pentschev Oct 22, 2024
94d05e9
Merge branch 'branch-0.41' into python-async-blocking-mode
pentschev Oct 22, 2024
caf67f9
Ensure writer is closed to prevent Distributed check failure
pentschev Oct 22, 2024
74913fa
Merge remote-tracking branch 'origin/python-async-blocking-mode' into…
pentschev Oct 22, 2024
23 changes: 15 additions & 8 deletions ci/test_python.sh
@@ -42,20 +42,27 @@ rapids-logger "Python Async Tests"
# run_py_tests_async PROGRESS_MODE ENABLE_DELAYED_SUBMISSION ENABLE_PYTHON_FUTURE SKIP
run_py_tests_async thread 0 0 0
run_py_tests_async thread 1 1 0
run_py_tests_async blocking 0 0 0

rapids-logger "Python Benchmarks"
# run_py_benchmark BACKEND PROGRESS_MODE ASYNCIO_WAIT ENABLE_DELAYED_SUBMISSION ENABLE_PYTHON_FUTURE NBUFFERS SLOW
run_py_benchmark ucxx-core thread 0 0 0 1 0
run_py_benchmark ucxx-core thread 1 0 0 1 0

for nbuf in 1 8; do
if [[ ! $RAPIDS_CUDA_VERSION =~ 11.2.* ]]; then
# run_py_benchmark BACKEND PROGRESS_MODE ASYNCIO_WAIT ENABLE_DELAYED_SUBMISSION ENABLE_PYTHON_FUTURE NBUFFERS SLOW
run_py_benchmark ucxx-async thread 0 0 0 ${nbuf} 0
run_py_benchmark ucxx-async thread 0 0 1 ${nbuf} 0
run_py_benchmark ucxx-async thread 0 1 0 ${nbuf} 0
run_py_benchmark ucxx-async thread 0 1 1 ${nbuf} 0
fi
for progress_mode in "blocking" "thread"; do
for nbuf in 1 8; do
if [[ ! $RAPIDS_CUDA_VERSION =~ 11.2.* ]]; then
# run_py_benchmark BACKEND PROGRESS_MODE ASYNCIO_WAIT ENABLE_DELAYED_SUBMISSION ENABLE_PYTHON_FUTURE NBUFFERS SLOW
run_py_benchmark ucxx-async ${progress_mode} 0 0 0 ${nbuf} 0
run_py_benchmark ucxx-async ${progress_mode} 0 0 1 ${nbuf} 0
if [[ ${progress_mode} != "blocking" ]]; then
# Delayed submission isn't supported by blocking progress mode
# run_py_benchmark BACKEND PROGRESS_MODE ASYNCIO_WAIT ENABLE_DELAYED_SUBMISSION ENABLE_PYTHON_FUTURE NBUFFERS SLOW
run_py_benchmark ucxx-async ${progress_mode} 0 1 0 ${nbuf} 0
run_py_benchmark ucxx-async ${progress_mode} 0 1 1 ${nbuf} 0
fi
fi
done
done

rapids-logger "C++ future -> Python future notifier example"
Expand Down
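The nested loops in this hunk amount to a small benchmark matrix: both progress modes run with one and eight buffers, with and without the Python future, and the delayed-submission combinations are skipped for `blocking`. A minimal Python sketch of the same enumeration, purely for illustration (the CI itself is the shell script above; `benchmark_matrix` is a hypothetical helper, and the CUDA-version guard is left out):

```python
from itertools import product


def benchmark_matrix():
    """Enumerate (progress_mode, delayed_submission, python_future, nbuffers).

    Mirrors the shell loops: delayed submission is never enabled for the
    "blocking" progress mode, since that mode does not support it.
    """
    runs = []
    for progress_mode, nbuf in product(("blocking", "thread"), (1, 8)):
        # Hypothetical encoding of the `!= "blocking"` guard in the script.
        delayed_options = (0,) if progress_mode == "blocking" else (0, 1)
        for delayed, python_future in product(delayed_options, (0, 1)):
            runs.append((progress_mode, delayed, python_future, nbuf))
    return runs
```

Under these assumptions the matrix holds four `blocking` runs and eight `thread` runs.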
2 changes: 2 additions & 0 deletions ci/test_python_distributed.sh
@@ -37,6 +37,7 @@ print_ucx_config

rapids-logger "Run distributed-ucxx tests with conda package"
# run_distributed_ucxx_tests PROGRESS_MODE ENABLE_DELAYED_SUBMISSION ENABLE_PYTHON_FUTURE
run_distributed_ucxx_tests blocking 0 0
run_distributed_ucxx_tests polling 0 0
run_distributed_ucxx_tests thread 0 0
run_distributed_ucxx_tests thread 0 1
@@ -46,6 +47,7 @@ run_distributed_ucxx_tests thread 1 1
install_distributed_dev_mode

# run_distributed_ucxx_tests_internal PROGRESS_MODE ENABLE_DELAYED_SUBMISSION ENABLE_PYTHON_FUTURE
run_distributed_ucxx_tests_internal blocking 0 0
run_distributed_ucxx_tests_internal polling 0 0
run_distributed_ucxx_tests_internal thread 0 0
run_distributed_ucxx_tests_internal thread 0 1
Expand Down
17 changes: 17 additions & 0 deletions cpp/include/ucxx/worker.h
@@ -253,6 +253,23 @@ class Worker : public Component {
*/
void initBlockingProgressMode();

/**
* @brief Get the epoll file descriptor associated with the worker.
*
* Get the epoll file descriptor associated with the worker when running in blocking mode.
* The worker only has an associated epoll file descriptor after
* `initBlockingProgressMode()` is executed.
*
* The file descriptor is destroyed as part of the `ucxx::Worker` destructor, thus any
* reference to it shall not be used after that.
*
* @throws std::runtime_error if `initBlockingProgressMode()` was not executed to run the
* worker in blocking progress mode.
*
* @returns the file descriptor.
*/
int getEpollFileDescriptor();

/**
* @brief Arm the UCP worker.
*
Expand Down
8 changes: 8 additions & 0 deletions cpp/src/worker.cpp
@@ -220,6 +220,14 @@ void Worker::initBlockingProgressMode()
}
}

int Worker::getEpollFileDescriptor()
{
if (_epollFileDescriptor == 0)
throw std::runtime_error("Worker not running in blocking progress mode");

return _epollFileDescriptor;
}

bool Worker::arm()
{
ucs_status_t status = ucp_worker_arm(_handle);
Expand Down
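The point of exposing the epoll file descriptor is that an event loop can sleep until the worker has something to progress, instead of polling. The `BlockingMode` implementation is not part of the hunks shown here, so the following is only a sketch of the general idea under stated assumptions: `FakeWorker` is a hypothetical stand-in (a pipe plays the role of the epoll file descriptor), `wait_and_progress` is an illustrative name, and `loop.add_reader` is used for brevity on a Unix selector event loop. The PR itself may wait on the descriptor differently.

```python
import asyncio
import os


class FakeWorker:
    """Hypothetical stand-in for a UCXX worker; NOT the real API."""

    def __init__(self):
        # The pipe's read end plays the role of the worker's epoll fd.
        self._read_fd, self._write_fd = os.pipe()
        os.set_blocking(self._read_fd, False)
        self.progressed = 0

    @property
    def epoll_file_descriptor(self):
        return self._read_fd

    def arm(self):
        # Assumption: this fake worker can always be armed.
        return True

    def progress(self):
        # Drain pending "events" and count the progress call.
        try:
            os.read(self._read_fd, 4096)
        except BlockingIOError:
            pass
        self.progressed += 1

    def simulate_event(self):
        os.write(self._write_fd, b"x")

    def close(self):
        os.close(self._read_fd)
        os.close(self._write_fd)


async def wait_and_progress(worker, timeout):
    """Arm the worker, sleep until its fd is readable or the timeout expires,
    then progress. Returns True if an event woke us, False on timeout."""
    loop = asyncio.get_running_loop()
    ready = loop.create_future()
    fd = worker.epoll_file_descriptor

    def on_readable():
        if not ready.done():
            ready.set_result(None)

    worker.arm()
    loop.add_reader(fd, on_readable)
    try:
        await asyncio.wait_for(ready, timeout)
        woke = True
    except asyncio.TimeoutError:
        woke = False
    finally:
        loop.remove_reader(fd)
    # Progress in both cases, so a missed event cannot stall forever.
    worker.progress()
    return woke


async def demo():
    worker = FakeWorker()
    try:
        loop = asyncio.get_running_loop()
        loop.call_later(0.01, worker.simulate_event)
        first = await wait_and_progress(worker, timeout=1.0)
        second = await wait_and_progress(worker, timeout=0.05)
        return first, second, worker.progressed
    finally:
        worker.close()
```

Progressing after a timeout as well as after a wake-up is a design choice of this sketch: it trades an occasional spurious progress call for protection against a lost notification.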
3 changes: 3 additions & 0 deletions python/distributed-ucxx/distributed_ucxx/tests/test_ucxx.py
@@ -411,7 +411,10 @@ async def test_comm_closed_on_read_error():
with pytest.raises((asyncio.TimeoutError, CommClosedError)):
await wait_for(reader.read(), 0.01)

await writer.close()

assert reader.closed()
assert writer.closed()


@pytest.mark.flaky(
Expand Down
84 changes: 48 additions & 36 deletions python/distributed-ucxx/distributed_ucxx/ucxx.py
@@ -14,7 +14,6 @@
import struct
import weakref
from collections.abc import Awaitable, Callable, Collection
from threading import Lock
from typing import TYPE_CHECKING, Any
from unittest.mock import patch

@@ -50,13 +49,6 @@
pre_existing_cuda_context = False
cuda_context_created = False
multi_buffer = None
# Lock protecting access to _resources dict
_resources_lock = Lock()
# Mapping from UCXX context handles to sets of registered dask resource IDs
# Used to track when there are no more users of the context, at which point
# its progress task and notification thread can be shut down.
# See _register_dask_resource and _deregister_dask_resource.
_resources = dict()


_warning_suffix = (
@@ -103,13 +95,13 @@ def make_register():
count = itertools.count()

def register() -> int:
"""Register a Dask resource with the resource tracker.
"""Register a Dask resource with the UCXX context.

Generate a unique ID for the resource and register it with the resource
tracker. The resource ID is later used to deregister the resource from
the tracker calling `_deregister_dask_resource(resource_id)`, which
stops the notifier thread and progress tasks when no more UCXX resources
are alive.
Register a Dask resource with the UCXX context and keep track of it with the
use of a unique ID for the resource. The resource ID is later used to
deregister the resource from the UCXX context by calling
`_deregister_dask_resource(resource_id)`, which stops the notifier thread
and progress tasks when no more UCXX resources are alive.

Returns
-------
@@ -118,13 +110,9 @@ def register() -> int:
`_deregister_dask_resource` during stop/destruction of the resource.
"""
ctx = ucxx.core._get_ctx()
handle = ctx.context.handle
with _resources_lock:
if handle not in _resources:
_resources[handle] = set()

with ctx._dask_resources_lock:
resource_id = next(count)
_resources[handle].add(resource_id)
ctx._dask_resources.add(resource_id)
ctx.start_notifier_thread()
ctx.continuous_ucx_progress()
return resource_id
@@ -138,11 +126,11 @@ def register() -> int:


def _deregister_dask_resource(resource_id):
"""Deregister a Dask resource from the resource tracker.
"""Deregister a Dask resource from the UCXX context.

Deregister a Dask resource from the resource tracker with given ID, and if
no resources remain after deregistration, stop the notifier thread and
progress tasks.
Deregister a Dask resource from the UCXX context with given ID, and if no
resources remain after deregistration, stop the notifier thread and progress
tasks.

Parameters
----------
@@ -156,22 +144,40 @@ def _deregister_dask_resource(resource_id):
return

ctx = ucxx.core._get_ctx()
handle = ctx.context.handle

# Check if the attribute exists first, in tests the UCXX context may have
# been reset before some resources are deregistered.
with _resources_lock:
try:
_resources[handle].remove(resource_id)
except KeyError:
pass
if hasattr(ctx, "_dask_resources_lock"):
with ctx._dask_resources_lock:
try:
ctx._dask_resources.remove(resource_id)
except KeyError:
pass

# Stop notifier thread and progress tasks if no Dask resources using
# UCXX communicators are running anymore.
if len(ctx._dask_resources) == 0:
ctx.stop_notifier_thread()
ctx.progress_tasks.clear()

# Stop notifier thread and progress tasks if no Dask resources using
# UCXX communicators are running anymore.
if handle in _resources and len(_resources[handle]) == 0:
ctx.stop_notifier_thread()
ctx.progress_tasks.clear()
del _resources[handle]

def _allocate_dask_resources_tracker() -> None:
"""Allocate Dask resources tracker.

Allocate a Dask resources tracker in the UCXX context. This is useful to
track Distributed communicators so that progress and notifier threads can
be cleanly stopped when no UCXX communicators are alive anymore.
"""
ctx = ucxx.core._get_ctx()
if not hasattr(ctx, "_dask_resources"):
# TODO: Move the `Lock` to a file/module-level variable for true
# lock-safety. The approach implemented below could cause race
# conditions if this function is called simultaneously by multiple
# threads.
from threading import Lock

ctx._dask_resources = set()
ctx._dask_resources_lock = Lock()


def init_once():
@@ -181,6 +187,11 @@ def init_once():
global multi_buffer

if ucxx is not None:
# Ensure reallocation of Dask resources tracker if the UCXX context was
# reset since the previous `init_once()` call. This may happen in tests,
# where the `ucxx_loop` fixture will reset the context after each test.
_allocate_dask_resources_tracker()

return

# remove/process dask.ucx flags for valid ucx options
@@ -243,6 +254,7 @@ def init_once():
# environment, so the user's external environment can safely
# override things here.
ucxx.init(options=ucx_config, env_takes_precedence=True)
_allocate_dask_resources_tracker()

pool_size_str = dask.config.get("distributed.rmm.pool-size")

Expand Down
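The TODO in `_allocate_dask_resources_tracker` notes that creating the `Lock` lazily on the context can race when several threads call the function at once. A minimal sketch of the alternative the TODO hints at, assuming a module-level lock guards both the lazy allocation and the register/deregister bookkeeping. `FakeContext`, `register` and `deregister` are hypothetical names used only for illustration, not the distributed-ucxx API:

```python
import itertools
from threading import Lock

# Module-level lock: it exists before any thread can reach the lazy
# allocation below, so creating the tracker cannot race.
_tracker_lock = Lock()
_count = itertools.count()


class FakeContext:
    """Hypothetical stand-in for the UCXX application context."""

    def __init__(self):
        self.running = False

    def start(self):
        # Stands in for starting the notifier thread and progress tasks.
        self.running = True

    def stop(self):
        # Stands in for stopping the notifier thread and progress tasks.
        self.running = False


def register(ctx):
    """Track a new resource on the context and return its unique ID."""
    with _tracker_lock:
        if not hasattr(ctx, "_dask_resources"):
            ctx._dask_resources = set()
        resource_id = next(_count)
        ctx._dask_resources.add(resource_id)
        ctx.start()
        return resource_id


def deregister(ctx, resource_id):
    """Drop a resource; stop the context once no resources remain."""
    with _tracker_lock:
        resources = getattr(ctx, "_dask_resources", None)
        if resources is None:
            # The context was reset before this resource was deregistered.
            return
        resources.discard(resource_id)
        if not resources:
            ctx.stop()
```

Holding a single lock around allocation and bookkeeping keeps the sketch simple; whether that coarse granularity suits the real code is a separate question.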
17 changes: 17 additions & 0 deletions python/ucxx/ucxx/_lib/libucxx.pyx
@@ -617,6 +617,23 @@ cdef class UCXWorker():
with nogil:
self._worker.get().initBlockingProgressMode()

def arm(self) -> bool:
cdef bint armed

with nogil:
armed = self._worker.get().arm()

return armed

@property
def epoll_file_descriptor(self) -> int:
cdef int epoll_file_descriptor = 0

with nogil:
epoll_file_descriptor = self._worker.get().getEpollFileDescriptor()

return epoll_file_descriptor

def progress(self) -> None:
with nogil:
self._worker.get().progress()
Expand Down
2 changes: 2 additions & 0 deletions python/ucxx/ucxx/_lib/ucxx_api.pxd
@@ -229,6 +229,8 @@ cdef extern from "<ucxx/api.h>" namespace "ucxx" nogil:
uint16_t port, ucp_listener_conn_callback_t callback, void *callback_args
) except +raise_py_error
void initBlockingProgressMode() except +raise_py_error
int getEpollFileDescriptor()
bint arm() except +raise_py_error
void progress()
bint progressOnce()
void progressWorkerEvent(int epoll_timeout)
Expand Down
19 changes: 11 additions & 8 deletions python/ucxx/ucxx/_lib_async/application_context.py
@@ -13,7 +13,7 @@
from ucxx.exceptions import UCXMessageTruncatedError
from ucxx.types import Tag

from .continuous_ucx_progress import PollingMode, ThreadMode
from .continuous_ucx_progress import BlockingMode, PollingMode, ThreadMode
from .endpoint import Endpoint
from .exchange_peer_info import exchange_peer_info
from .listener import ActiveClients, Listener, _listener_handler
@@ -56,8 +56,8 @@ def __init__(
self.context = ucx_api.UCXContext(config_dict)
self.worker = ucx_api.UCXWorker(
self.context,
enable_delayed_submission=self._enable_delayed_submission,
enable_python_future=self._enable_python_future,
enable_delayed_submission=self.enable_delayed_submission,
enable_python_future=self.enable_python_future,
)

self.start_notifier_thread()
@@ -82,12 +82,12 @@ def progress_mode(self, progress_mode):
else:
progress_mode = "thread"

valid_progress_modes = ["polling", "thread", "thread-polling"]
valid_progress_modes = ["blocking", "polling", "thread", "thread-polling"]
if not isinstance(progress_mode, str) or not any(
progress_mode == m for m in valid_progress_modes
):
raise ValueError(
f"Unknown progress mode {progress_mode}, valid modes are: "
f"Unknown progress mode '{progress_mode}', valid modes are: "
"'blocking', 'polling', 'thread' or 'thread-polling'"
)

@@ -121,8 +121,9 @@ def enable_delayed_submission(self, enable_delayed_submission):
and explicit_enable_delayed_submission
):
raise ValueError(
f"Delayed submission requested, but {self.progress_mode} does not "
"support it, 'thread' or 'thread-polling' progress mode required."
f"Delayed submission requested, but '{self.progress_mode}' does "
"not support it, 'thread' or 'thread-polling' progress mode "
"required."
)

self._enable_delayed_submission = explicit_enable_delayed_submission
@@ -153,7 +154,7 @@ def enable_python_future(self, enable_python_future):
and explicit_enable_python_future
):
logger.warning(
f"Notifier thread requested, but {self.progress_mode} does not "
f"Notifier thread requested, but '{self.progress_mode}' does not "
"support it, using Python wait_yield()."
)
explicit_enable_python_future = False
@@ -464,6 +465,8 @@ def continuous_ucx_progress(self, event_loop=None):
task = ThreadMode(self.worker, loop, polling_mode=True)
elif self.progress_mode == "polling":
task = PollingMode(self.worker, loop)
elif self.progress_mode == "blocking":
task = BlockingMode(self.worker, loop)

self.progress_tasks[loop] = task

Expand Down
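Taken together, the hunks in this file add `blocking` to the accepted progress modes while keeping delayed submission restricted to `thread` and `thread-polling`. A standalone sketch of those two rules, assuming nothing beyond what the error messages above state; `check_options` and the two constants are hypothetical names, not part of `ApplicationContext`:

```python
VALID_PROGRESS_MODES = ("blocking", "polling", "thread", "thread-polling")
# Per the error message above, only these modes support delayed submission.
DELAYED_SUBMISSION_MODES = ("thread", "thread-polling")


def check_options(progress_mode, enable_delayed_submission=False):
    """Validate a progress mode and its delayed-submission setting."""
    if progress_mode not in VALID_PROGRESS_MODES:
        raise ValueError(
            f"Unknown progress mode '{progress_mode}', valid modes are: "
            "'blocking', 'polling', 'thread' or 'thread-polling'"
        )
    if enable_delayed_submission and progress_mode not in DELAYED_SUBMISSION_MODES:
        raise ValueError(
            f"Delayed submission requested, but '{progress_mode}' does not "
            "support it, 'thread' or 'thread-polling' progress mode required."
        )
    return progress_mode, enable_delayed_submission
```

For example, `check_options("blocking")` passes, while `check_options("blocking", True)` raises `ValueError`.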