Skip to content

Commit 122d2f4

Browse files
authored
Add blocking progress mode to Python async (#116)
Implements the blocking progress mode (UCX-Py default), which was still not implemented in UCXX. Authors: - Peter Andreas Entschev (https://github.com/pentschev) Approvers: - Lawrence Mitchell (https://github.com/wence-) - AJ Schmidt (https://github.com/ajschmidt8) - Ray Douglass (https://github.com/raydouglass) URL: #116
1 parent cfe9008 commit 122d2f4

File tree

12 files changed

+271
-62
lines changed

12 files changed

+271
-62
lines changed

ci/test_python.sh

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,20 +42,27 @@ rapids-logger "Python Async Tests"
4242
# run_py_tests_async PROGRESS_MODE ENABLE_DELAYED_SUBMISSION ENABLE_PYTHON_FUTURE SKIP
4343
run_py_tests_async thread 0 0 0
4444
run_py_tests_async thread 1 1 0
45+
run_py_tests_async blocking 0 0 0
4546

4647
rapids-logger "Python Benchmarks"
4748
# run_py_benchmark BACKEND PROGRESS_MODE ASYNCIO_WAIT ENABLE_DELAYED_SUBMISSION ENABLE_PYTHON_FUTURE NBUFFERS SLOW
4849
run_py_benchmark ucxx-core thread 0 0 0 1 0
4950
run_py_benchmark ucxx-core thread 1 0 0 1 0
5051

51-
for nbuf in 1 8; do
52-
if [[ ! $RAPIDS_CUDA_VERSION =~ 11.2.* ]]; then
53-
# run_py_benchmark BACKEND PROGRESS_MODE ASYNCIO_WAIT ENABLE_DELAYED_SUBMISSION ENABLE_PYTHON_FUTURE NBUFFERS SLOW
54-
run_py_benchmark ucxx-async thread 0 0 0 ${nbuf} 0
55-
run_py_benchmark ucxx-async thread 0 0 1 ${nbuf} 0
56-
run_py_benchmark ucxx-async thread 0 1 0 ${nbuf} 0
57-
run_py_benchmark ucxx-async thread 0 1 1 ${nbuf} 0
58-
fi
52+
for progress_mode in "blocking" "thread"; do
53+
for nbuf in 1 8; do
54+
if [[ ! $RAPIDS_CUDA_VERSION =~ 11.2.* ]]; then
55+
# run_py_benchmark BACKEND PROGRESS_MODE ASYNCIO_WAIT ENABLE_DELAYED_SUBMISSION ENABLE_PYTHON_FUTURE NBUFFERS SLOW
56+
run_py_benchmark ucxx-async ${progress_mode} 0 0 0 ${nbuf} 0
57+
run_py_benchmark ucxx-async ${progress_mode} 0 0 1 ${nbuf} 0
58+
if [[ ${progress_mode} != "blocking" ]]; then
59+
# Delayed submission isn't supported by blocking progress mode
60+
# run_py_benchmark BACKEND PROGRESS_MODE ASYNCIO_WAIT ENABLE_DELAYED_SUBMISSION ENABLE_PYTHON_FUTURE NBUFFERS SLOW
61+
run_py_benchmark ucxx-async ${progress_mode} 0 1 0 ${nbuf} 0
62+
run_py_benchmark ucxx-async ${progress_mode} 0 1 1 ${nbuf} 0
63+
fi
64+
fi
65+
done
5966
done
6067

6168
rapids-logger "C++ future -> Python future notifier example"

ci/test_python_distributed.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ print_ucx_config
3737

3838
rapids-logger "Run distributed-ucxx tests with conda package"
3939
# run_distributed_ucxx_tests PROGRESS_MODE ENABLE_DELAYED_SUBMISSION ENABLE_PYTHON_FUTURE
40+
run_distributed_ucxx_tests blocking 0 0
4041
run_distributed_ucxx_tests polling 0 0
4142
run_distributed_ucxx_tests thread 0 0
4243
run_distributed_ucxx_tests thread 0 1
@@ -46,6 +47,7 @@ run_distributed_ucxx_tests thread 1 1
4647
install_distributed_dev_mode
4748

4849
# run_distributed_ucxx_tests_internal PROGRESS_MODE ENABLE_DELAYED_SUBMISSION ENABLE_PYTHON_FUTURE
50+
run_distributed_ucxx_tests_internal blocking 0 0
4951
run_distributed_ucxx_tests_internal polling 0 0
5052
run_distributed_ucxx_tests_internal thread 0 0
5153
run_distributed_ucxx_tests_internal thread 0 1

cpp/include/ucxx/worker.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,23 @@ class Worker : public Component {
253253
*/
254254
void initBlockingProgressMode();
255255

256+
/**
257+
* @brief Get the epoll file descriptor associated with the worker.
258+
*
259+
* Get the epoll file descriptor associated with the worker when running in blocking mode.
260+
* The worker only has an associated epoll file descriptor after
261+
* `initBlockingProgressMode()` is executed.
262+
*
263+
* The file descriptor is destroyed as part of the `ucxx::Worker` destructor, thus any
264+
* reference to it shall not be used after that.
265+
*
266+
* @throws std::runtime_error if `initBlockingProgressMode()` was not executed to run the
267+
* worker in blocking progress mode.
268+
*
269+
* @returns the file descriptor.
270+
*/
271+
int getEpollFileDescriptor();
272+
256273
/**
257274
* @brief Arm the UCP worker.
258275
*

cpp/src/worker.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,14 @@ void Worker::initBlockingProgressMode()
220220
}
221221
}
222222

223+
int Worker::getEpollFileDescriptor()
224+
{
225+
if (_epollFileDescriptor == 0)
226+
throw std::runtime_error("Worker not running in blocking progress mode");
227+
228+
return _epollFileDescriptor;
229+
}
230+
223231
bool Worker::arm()
224232
{
225233
ucs_status_t status = ucp_worker_arm(_handle);

python/distributed-ucxx/distributed_ucxx/tests/test_ucxx.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,10 @@ async def test_comm_closed_on_read_error():
411411
with pytest.raises((asyncio.TimeoutError, CommClosedError)):
412412
await wait_for(reader.read(), 0.01)
413413

414+
await writer.close()
415+
414416
assert reader.closed()
417+
assert writer.closed()
415418

416419

417420
@pytest.mark.flaky(

python/distributed-ucxx/distributed_ucxx/ucxx.py

Lines changed: 48 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import struct
1515
import weakref
1616
from collections.abc import Awaitable, Callable, Collection
17-
from threading import Lock
1817
from typing import TYPE_CHECKING, Any
1918
from unittest.mock import patch
2019

@@ -50,13 +49,6 @@
5049
pre_existing_cuda_context = False
5150
cuda_context_created = False
5251
multi_buffer = None
53-
# Lock protecting access to _resources dict
54-
_resources_lock = Lock()
55-
# Mapping from UCXX context handles to sets of registered dask resource IDs
56-
# Used to track when there are no more users of the context, at which point
57-
# its progress task and notification thread can be shut down.
58-
# See _register_dask_resource and _deregister_dask_resource.
59-
_resources = dict()
6052

6153

6254
_warning_suffix = (
@@ -103,13 +95,13 @@ def make_register():
10395
count = itertools.count()
10496

10597
def register() -> int:
106-
"""Register a Dask resource with the resource tracker.
98+
"""Register a Dask resource with the UCXX context.
10799
108-
Generate a unique ID for the resource and register it with the resource
109-
tracker. The resource ID is later used to deregister the resource from
110-
the tracker calling `_deregister_dask_resource(resource_id)`, which
111-
stops the notifier thread and progress tasks when no more UCXX resources
112-
are alive.
100+
Register a Dask resource with the UCXX context and keep track of it with the
101+
use of a unique ID for the resource. The resource ID is later used to
102+
deregister the resource from the UCXX context calling
103+
`_deregister_dask_resource(resource_id)`, which stops the notifier thread
104+
and progress tasks when no more UCXX resources are alive.
113105
114106
Returns
115107
-------
@@ -118,13 +110,9 @@ def register() -> int:
118110
`_deregister_dask_resource` during stop/destruction of the resource.
119111
"""
120112
ctx = ucxx.core._get_ctx()
121-
handle = ctx.context.handle
122-
with _resources_lock:
123-
if handle not in _resources:
124-
_resources[handle] = set()
125-
113+
with ctx._dask_resources_lock:
126114
resource_id = next(count)
127-
_resources[handle].add(resource_id)
115+
ctx._dask_resources.add(resource_id)
128116
ctx.start_notifier_thread()
129117
ctx.continuous_ucx_progress()
130118
return resource_id
@@ -138,11 +126,11 @@ def register() -> int:
138126

139127

140128
def _deregister_dask_resource(resource_id):
141-
"""Deregister a Dask resource from the resource tracker.
129+
"""Deregister a Dask resource with the UCXX context.
142130
143-
Deregister a Dask resource from the resource tracker with given ID, and if
144-
no resources remain after deregistration, stop the notifier thread and
145-
progress tasks.
131+
Deregister a Dask resource from the UCXX context with given ID, and if no
132+
resources remain after deregistration, stop the notifier thread and progress
133+
tasks.
146134
147135
Parameters
148136
----------
@@ -156,22 +144,40 @@ def _deregister_dask_resource(resource_id):
156144
return
157145

158146
ctx = ucxx.core._get_ctx()
159-
handle = ctx.context.handle
160147

161148
# Check if the attribute exists first, in tests the UCXX context may have
162149
# been reset before some resources are deregistered.
163-
with _resources_lock:
164-
try:
165-
_resources[handle].remove(resource_id)
166-
except KeyError:
167-
pass
150+
if hasattr(ctx, "_dask_resources_lock"):
151+
with ctx._dask_resources_lock:
152+
try:
153+
ctx._dask_resources.remove(resource_id)
154+
except KeyError:
155+
pass
156+
157+
# Stop notifier thread and progress tasks if no Dask resources using
158+
# UCXX communicators are running anymore.
159+
if len(ctx._dask_resources) == 0:
160+
ctx.stop_notifier_thread()
161+
ctx.progress_tasks.clear()
168162

169-
# Stop notifier thread and progress tasks if no Dask resources using
170-
# UCXX communicators are running anymore.
171-
if handle in _resources and len(_resources[handle]) == 0:
172-
ctx.stop_notifier_thread()
173-
ctx.progress_tasks.clear()
174-
del _resources[handle]
163+
164+
def _allocate_dask_resources_tracker() -> None:
165+
"""Allocate Dask resources tracker.
166+
167+
Allocate a Dask resources tracker in the UCXX context. This is useful to
168+
track Distributed communicators so that progress and notifier threads can
169+
be cleanly stopped when no UCXX communicators are alive anymore.
170+
"""
171+
ctx = ucxx.core._get_ctx()
172+
if not hasattr(ctx, "_dask_resources"):
173+
# TODO: Move the `Lock` to a file/module-level variable for true
174+
# lock-safety. The approach implemented below could cause race
175+
# conditions if this function is called simultaneously by multiple
176+
# threads.
177+
from threading import Lock
178+
179+
ctx._dask_resources = set()
180+
ctx._dask_resources_lock = Lock()
175181

176182

177183
def init_once():
@@ -181,6 +187,11 @@ def init_once():
181187
global multi_buffer
182188

183189
if ucxx is not None:
190+
# Ensure reallocation of Dask resources tracker if the UCXX context was
191+
# reset since the previous `init_once()` call. This may happen in tests,
192+
# where the `ucxx_loop` fixture will reset the context after each test.
193+
_allocate_dask_resources_tracker()
194+
184195
return
185196

186197
# remove/process dask.ucx flags for valid ucx options
@@ -243,6 +254,7 @@ def init_once():
243254
# environment, so the user's external environment can safely
244255
# override things here.
245256
ucxx.init(options=ucx_config, env_takes_precedence=True)
257+
_allocate_dask_resources_tracker()
246258

247259
pool_size_str = dask.config.get("distributed.rmm.pool-size")
248260

python/ucxx/ucxx/_lib/libucxx.pyx

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -617,6 +617,23 @@ cdef class UCXWorker():
617617
with nogil:
618618
self._worker.get().initBlockingProgressMode()
619619

620+
def arm(self) -> bool:
621+
cdef bint armed
622+
623+
with nogil:
624+
armed = self._worker.get().arm()
625+
626+
return armed
627+
628+
@property
629+
def epoll_file_descriptor(self) -> int:
630+
cdef int epoll_file_descriptor = 0
631+
632+
with nogil:
633+
epoll_file_descriptor = self._worker.get().getEpollFileDescriptor()
634+
635+
return epoll_file_descriptor
636+
620637
def progress(self) -> None:
621638
with nogil:
622639
self._worker.get().progress()

python/ucxx/ucxx/_lib/ucxx_api.pxd

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,8 @@ cdef extern from "<ucxx/api.h>" namespace "ucxx" nogil:
229229
uint16_t port, ucp_listener_conn_callback_t callback, void *callback_args
230230
) except +raise_py_error
231231
void initBlockingProgressMode() except +raise_py_error
232+
int getEpollFileDescriptor() except +raise_py_error  # C++ side throws std::runtime_error when blocking progress mode was not initialized
233+
bint arm() except +raise_py_error
232234
void progress()
233235
bint progressOnce()
234236
void progressWorkerEvent(int epoll_timeout)

python/ucxx/ucxx/_lib_async/application_context.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from ucxx.exceptions import UCXMessageTruncatedError
1414
from ucxx.types import Tag
1515

16-
from .continuous_ucx_progress import PollingMode, ThreadMode
16+
from .continuous_ucx_progress import BlockingMode, PollingMode, ThreadMode
1717
from .endpoint import Endpoint
1818
from .exchange_peer_info import exchange_peer_info
1919
from .listener import ActiveClients, Listener, _listener_handler
@@ -56,8 +56,8 @@ def __init__(
5656
self.context = ucx_api.UCXContext(config_dict)
5757
self.worker = ucx_api.UCXWorker(
5858
self.context,
59-
enable_delayed_submission=self._enable_delayed_submission,
60-
enable_python_future=self._enable_python_future,
59+
enable_delayed_submission=self.enable_delayed_submission,
60+
enable_python_future=self.enable_python_future,
6161
)
6262

6363
self.start_notifier_thread()
@@ -82,12 +82,12 @@ def progress_mode(self, progress_mode):
8282
else:
8383
progress_mode = "thread"
8484

85-
valid_progress_modes = ["polling", "thread", "thread-polling"]
85+
valid_progress_modes = ["blocking", "polling", "thread", "thread-polling"]
8686
if not isinstance(progress_mode, str) or not any(
8787
progress_mode == m for m in valid_progress_modes
8888
):
8989
raise ValueError(
90-
f"Unknown progress mode {progress_mode}, valid modes are: "
90+
f"Unknown progress mode '{progress_mode}', valid modes are: "
9191
"'blocking', 'polling', 'thread' or 'thread-polling'"
9292
)
9393

@@ -121,8 +121,9 @@ def enable_delayed_submission(self, enable_delayed_submission):
121121
and explicit_enable_delayed_submission
122122
):
123123
raise ValueError(
124-
f"Delayed submission requested, but {self.progress_mode} does not "
125-
"support it, 'thread' or 'thread-polling' progress mode required."
124+
f"Delayed submission requested, but '{self.progress_mode}' does "
125+
"not support it, 'thread' or 'thread-polling' progress mode "
126+
"required."
126127
)
127128

128129
self._enable_delayed_submission = explicit_enable_delayed_submission
@@ -153,7 +154,7 @@ def enable_python_future(self, enable_python_future):
153154
and explicit_enable_python_future
154155
):
155156
logger.warning(
156-
f"Notifier thread requested, but {self.progress_mode} does not "
157+
f"Notifier thread requested, but '{self.progress_mode}' does not "
157158
"support it, using Python wait_yield()."
158159
)
159160
explicit_enable_python_future = False
@@ -464,6 +465,8 @@ def continuous_ucx_progress(self, event_loop=None):
464465
task = ThreadMode(self.worker, loop, polling_mode=True)
465466
elif self.progress_mode == "polling":
466467
task = PollingMode(self.worker, loop)
468+
elif self.progress_mode == "blocking":
469+
task = BlockingMode(self.worker, loop)
467470

468471
self.progress_tasks[loop] = task
469472

0 commit comments

Comments
 (0)