Skip to content

Commit 45c0f3f

Browse files
authored
Merge pull request #259 from madsbk/branch-25.08-merge-25.06
Branch 25.08 merge 25.06
2 parents e4d1c1f + 7b5bfcd commit 45c0f3f

12 files changed

Lines changed: 209 additions & 59 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@ build
66
.mypy_cache
77
.hypothesis
88
__pycache__
9+
compile_commands.json

conda/environments/all_cuda-128_arch-aarch64.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ dependencies:
3838
- pytest
3939
- python>=3.10,<3.13
4040
- rapids-build-backend>=0.3.0,<0.4.0.dev0
41-
- ray-default==2.42.*,>=0.0.0a0
4241
- rmm==25.8.*,>=0.0.0a0
4342
- scikit-build-core>=0.10.0
4443
- spdlog

cpp/include/rapidsmpf/progress_thread.hpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,10 +132,15 @@ class ProgressThread {
132132
*
133133
* @param logger The logger instance to use.
134134
* @param statistics The statistics instance to use (disabled by default).
135+
* @param sleep The duration to sleep between each progress loop iteration.
136+
* If 0, the thread yields execution instead of sleeping. Anecdotally, a 1 us
137+
* sleep time (the default) is sufficient to avoid starvation and get smooth
138+
* progress.
135139
*/
136140
ProgressThread(
137141
Communicator::Logger& logger,
138-
std::shared_ptr<Statistics> statistics = std::make_shared<Statistics>(false)
142+
std::shared_ptr<Statistics> statistics = std::make_shared<Statistics>(false),
143+
Duration sleep = std::chrono::microseconds{1}
139144
);
140145

141146
~ProgressThread();

cpp/include/rapidsmpf/shuffler/shuffler.hpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ class Shuffler {
215215
*
216216
* @param chunk The chunk to insert.
217217
*/
218-
void insert_into_outbox(detail::Chunk&& chunk);
218+
void insert_into_ready_postbox(detail::Chunk&& chunk);
219219

220220
/// @brief Get an new unique chunk ID.
221221
[[nodiscard]] detail::ChunkID get_new_cid();
@@ -255,10 +255,10 @@ class Shuffler {
255255
rmm::cuda_stream_view stream_;
256256
BufferResource* br_;
257257
bool active_{true};
258-
detail::PostBox<Rank>
259-
outgoing_chunks_; ///< Outgoing chunks, that are ready to be sent to other ranks.
260-
detail::PostBox<PartID> received_chunks_; ///< Received chunks, that are ready to be
261-
///< extracted by the user.
258+
detail::PostBox<Rank> outgoing_postbox_; ///< Postbox for outgoing chunks, that are
259+
///< ready to be sent to other ranks.
260+
detail::PostBox<PartID> ready_postbox_; ///< Postbox for received chunks, that are
261+
///< ready to be extracted by the user.
262262

263263
std::shared_ptr<Communicator> comm_;
264264
std::shared_ptr<ProgressThread> progress_thread_;

cpp/src/progress_thread.cpp

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,20 @@ void ProgressThread::FunctionState::operator()() {
3737
}
3838

3939
ProgressThread::ProgressThread(
40-
Communicator::Logger& logger, std::shared_ptr<Statistics> statistics
40+
Communicator::Logger& logger, std::shared_ptr<Statistics> statistics, Duration sleep
4141
)
42-
: thread_([this]() {
43-
if (!is_thread_initialized_) {
44-
// This thread needs to have a cuda context associated with it.
45-
// For now, do so by calling cudaFree to initialise the driver.
46-
RAPIDSMPF_CUDA_TRY(cudaFree(nullptr));
47-
is_thread_initialized_ = true;
48-
}
49-
return event_loop();
50-
}),
42+
: thread_(
43+
[this]() {
44+
if (!is_thread_initialized_) {
45+
// This thread needs to have a cuda context associated with it.
46+
// For now, do so by calling cudaFree to initialise the driver.
47+
RAPIDSMPF_CUDA_TRY(cudaFree(nullptr));
48+
is_thread_initialized_ = true;
49+
}
50+
return event_loop();
51+
},
52+
sleep
53+
),
5154
logger_(logger),
5255
statistics_(std::move(statistics)) {
5356
RAPIDSMPF_EXPECTS(statistics_ != nullptr, "the statistics pointer cannot be NULL");

cpp/src/shuffler/shuffler.cpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ class Shuffler::Progress {
175175

176176
// Check for new chunks in the inbox and send off their metadata.
177177
auto const t0_send_metadata = Clock::now();
178-
for (auto&& chunk : shuffler_.outgoing_chunks_.extract_all_ready()) {
178+
for (auto&& chunk : shuffler_.outgoing_postbox_.extract_all_ready()) {
179179
auto dst = shuffler_.partition_owner(shuffler_.comm_, chunk.pid);
180180
log.trace("send metadata to ", dst, ": ", chunk);
181181
RAPIDSMPF_EXPECTS(
@@ -282,7 +282,7 @@ class Shuffler::Progress {
282282
chunk.gpu_data =
283283
std::move(allocate_buffer(0, shuffler_.stream_, shuffler_.br_));
284284
}
285-
shuffler_.insert_into_outbox(std::move(chunk));
285+
shuffler_.insert_into_ready_postbox(std::move(chunk));
286286
}
287287
}
288288

@@ -325,7 +325,7 @@ class Shuffler::Progress {
325325
auto chunk = extract_value(in_transit_chunks_, cid);
326326
auto future = extract_value(in_transit_futures_, cid);
327327
chunk.gpu_data = shuffler_.comm_->get_gpu_data(std::move(future));
328-
shuffler_.insert_into_outbox(std::move(chunk));
328+
shuffler_.insert_into_ready_postbox(std::move(chunk));
329329
}
330330
}
331331

@@ -356,7 +356,7 @@ class Shuffler::Progress {
356356
|| !(
357357
fire_and_forget_.empty() && incoming_chunks_.empty()
358358
&& outgoing_chunks_.empty() && in_transit_chunks_.empty()
359-
&& in_transit_futures_.empty() && shuffler_.outgoing_chunks_.empty()
359+
&& in_transit_futures_.empty() && shuffler_.outgoing_postbox_.empty()
360360
))
361361
? ProgressThread::ProgressState::InProgress
362362
: ProgressThread::ProgressState::Done;
@@ -404,13 +404,13 @@ Shuffler::Shuffler(
404404
partition_owner{partition_owner},
405405
stream_{stream},
406406
br_{br},
407-
outgoing_chunks_{
407+
outgoing_postbox_{
408408
[this](PartID pid) -> Rank {
409409
return this->partition_owner(this->comm_, pid);
410410
}, // extract Rank from pid
411411
static_cast<std::size_t>(comm->nranks())
412412
},
413-
received_chunks_{
413+
ready_postbox_{
414414
[](PartID pid) -> PartID { return pid; }, // identity mapping
415415
static_cast<std::size_t>(total_num_partitions),
416416
},
@@ -459,14 +459,14 @@ void Shuffler::shutdown() {
459459
}
460460
}
461461

462-
void Shuffler::insert_into_outbox(detail::Chunk&& chunk) {
462+
void Shuffler::insert_into_ready_postbox(detail::Chunk&& chunk) {
463463
auto& log = comm_->logger();
464464
log.trace("insert_into_outbox: ", chunk);
465465
auto pid = chunk.pid;
466466
if (chunk.expected_num_chunks) {
467467
finish_counter_.move_goalpost(chunk.pid, chunk.expected_num_chunks);
468468
} else {
469-
received_chunks_.insert(std::move(chunk));
469+
ready_postbox_.insert(std::move(chunk));
470470
}
471471
finish_counter_.add_finished_chunk(pid);
472472
}
@@ -481,9 +481,9 @@ void Shuffler::insert(detail::Chunk&& chunk) {
481481
statistics_->add_bytes_stat("shuffle-payload-send", chunk.gpu_data->size);
482482
statistics_->add_bytes_stat("shuffle-payload-recv", chunk.gpu_data->size);
483483
}
484-
insert_into_outbox(std::move(chunk));
484+
insert_into_ready_postbox(std::move(chunk));
485485
} else {
486-
outgoing_chunks_.insert(std::move(chunk));
486+
outgoing_postbox_.insert(std::move(chunk));
487487
}
488488
}
489489

@@ -555,7 +555,7 @@ std::vector<PackedData> Shuffler::extract(PartID pid) {
555555
// Protect the chunk extraction to make sure we don't get a chunk
556556
// `Shuffler::spill` is in the process of spilling.
557557
std::unique_lock<std::mutex> lock(outbox_spilling_mutex_);
558-
auto chunks = received_chunks_.extract(pid);
558+
auto chunks = ready_postbox_.extract(pid);
559559
lock.unlock();
560560
std::vector<PackedData> ret;
561561
ret.reserve(chunks.size());
@@ -610,7 +610,7 @@ std::size_t Shuffler::spill(std::optional<std::size_t> amount) {
610610
if (spill_need > 0) {
611611
std::lock_guard<std::mutex> lock(outbox_spilling_mutex_);
612612
spilled =
613-
postbox_spilling(br_, comm_->logger(), stream_, received_chunks_, spill_need);
613+
postbox_spilling(br_, comm_->logger(), stream_, ready_postbox_, spill_need);
614614
}
615615
return spilled;
616616
}
@@ -625,7 +625,7 @@ detail::ChunkID Shuffler::get_new_cid() {
625625

626626
std::string Shuffler::str() const {
627627
std::stringstream ss;
628-
ss << "Shuffler(outgoing=" << outgoing_chunks_ << ", received=" << received_chunks_
628+
ss << "Shuffler(outgoing=" << outgoing_postbox_ << ", received=" << ready_postbox_
629629
<< ", " << finish_counter_;
630630
return ss.str();
631631
}

dependencies.yaml

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ channels:
4949
dependencies:
5050
build-universal:
5151
common:
52-
- output_types: [conda, pyproject]
52+
- output_types: [conda, pyproject, requirements]
5353
packages:
5454
- &cmake_ver cmake>=3.26.4,!=3.30.0
5555
- ninja
@@ -87,18 +87,18 @@ dependencies:
8787
- cuda-nvcc
8888
rapids_build_skbuild:
8989
common:
90-
- output_types: [conda, pyproject]
90+
- output_types: [conda, pyproject, requirements]
9191
packages:
9292
- &rapids_build_backend rapids-build-backend>=0.3.0,<0.4.0.dev0
9393
- output_types: conda
9494
packages:
9595
- scikit-build-core>=0.10.0
96-
- output_types: [pyproject]
96+
- output_types: [pyproject, requirements]
9797
packages:
9898
- scikit-build-core[pyproject]>=0.10.0
9999
build-python:
100100
common:
101-
- output_types: [conda, pyproject]
101+
- output_types: [conda, pyproject, requirements]
102102
packages:
103103
- cython>=3.0.3
104104
- *rmm_unsuffixed
@@ -176,12 +176,14 @@ dependencies:
176176
- rapidsmpf==25.8.*,>=0.0.0a0
177177
test_python:
178178
common:
179-
- output_types: [conda, pyproject]
179+
- output_types: conda
180+
packages:
181+
- gdb
182+
- output_types: [conda, pyproject, requirements]
180183
packages:
181184
- cudf==25.8.*,>=0.0.0a0
182185
- dask-cuda==25.8.*,>=0.0.0a0
183186
- dask-cudf==25.8.*,>=0.0.0a0
184-
- gdb
185187
- psutil
186188
- pytest
187189
- ucxx==0.45.*,>=0.0.0a0
@@ -207,8 +209,17 @@ dependencies:
207209
- myst-parser
208210
- numpydoc
209211
- pydata-sphinx-theme
210-
- ray-default==2.42.*,>=0.0.0a0
211212
- sphinx
212213
- sphinx-autobuild
213214
- sphinx-copybutton
214215
- ucxx==0.45.*,>=0.0.0a0
216+
specific:
217+
- output_types: conda
218+
matrices:
219+
- matrix:
220+
arch: x86_64
221+
packages:
222+
- ray-default==2.42.*,>=0.0.0a0
223+
- matrix:
224+
arch: aarch64
225+
packages:

python/rapidsmpf/rapidsmpf/integrations/dask/shuffler.py

Lines changed: 59 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from __future__ import annotations
66

7+
import threading
78
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
89

910
from distributed import get_worker
@@ -22,24 +23,62 @@
2223
if TYPE_CHECKING:
2324
from collections.abc import Callable, MutableMapping, Sequence
2425

25-
from distributed import Worker
26+
from distributed import Client, Worker
2627

2728

28-
_shuffle_counter: int = 0
29+
# Set of available shuffle IDs
30+
_shuffle_id_vacancy: set[int] = set(range(Shuffler.max_concurrent_shuffles))
31+
_shuffle_id_vacancy_lock: threading.Lock = threading.Lock()
2932

3033

31-
def get_shuffle_id() -> int:
34+
def _get_new_shuffle_id(client: Client) -> int:
3235
"""
33-
Return the unique id for a new shuffle.
36+
Get a new available shuffle ID.
37+
38+
Since RapidsMPF only supports a limited number of shuffler instances at
39+
any given time, this function maintains a shared pool of shuffle IDs.
40+
41+
If no IDs are available locally, it queries all workers for IDs in use,
42+
updates the vacancy set accordingly, and retries. If all IDs are in use
43+
across the cluster, an error is raised.
44+
45+
Parameters
46+
----------
47+
client
48+
A Dask distributed client used to query workers for active shuffle IDs.
3449
3550
Returns
3651
-------
37-
The enumerated integer id for the current shuffle.
52+
A unique shuffle ID not currently in use.
53+
54+
Raises
55+
------
56+
ValueError
57+
If all shuffle IDs are currently in use across the cluster.
3858
"""
39-
global _shuffle_counter # noqa: PLW0603
59+
global _shuffle_id_vacancy # noqa: PLW0603
4060

41-
_shuffle_counter += 1
42-
return _shuffle_counter
61+
with _shuffle_id_vacancy_lock:
62+
if not _shuffle_id_vacancy:
63+
64+
def get_occupied_ids(dask_worker: Worker) -> set[int]:
65+
ctx = get_worker_context(dask_worker)
66+
with ctx.lock:
67+
return set(ctx.shufflers.keys())
68+
69+
# We start with setting all IDs as vacant and then subtract all
70+
# IDs occupied on any one worker.
71+
_shuffle_id_vacancy = set(range(Shuffler.max_concurrent_shuffles))
72+
_shuffle_id_vacancy.difference_update(
73+
*client.run(get_occupied_ids).values()
74+
)
75+
if not _shuffle_id_vacancy:
76+
raise ValueError(
77+
f"Cannot shuffle more than {Shuffler.max_concurrent_shuffles} "
78+
"times in a single Dask compute."
79+
)
80+
81+
return _shuffle_id_vacancy.pop()
4382

4483

4584
def get_shuffler(
@@ -280,13 +319,17 @@ def _extract_partition(
280319
-------
281320
Extracted DataFrame partition.
282321
"""
283-
if callback is None:
284-
raise ValueError("Missing callback in _extract_partition.")
285-
return callback(
286-
partition_id,
287-
column_names,
288-
get_shuffler(shuffle_id),
289-
)
322+
shuffler = get_shuffler(shuffle_id)
323+
try:
324+
return callback(
325+
partition_id,
326+
column_names,
327+
shuffler,
328+
)
329+
finally:
330+
if shuffler.finished():
331+
ctx = get_worker_context()
332+
del ctx.shufflers[shuffle_id]
290333

291334

292335
def rapidsmpf_shuffle_graph(
@@ -383,7 +426,7 @@ def rapidsmpf_shuffle_graph(
383426
"""
384427
# Get the shuffle id
385428
client = get_dask_client()
386-
shuffle_id = get_shuffle_id()
429+
shuffle_id = _get_new_shuffle_id(client)
387430

388431
# Check integration argument
389432
if not isinstance(integration, DaskIntegration):

python/rapidsmpf/rapidsmpf/shuffler.pxd

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
22
# SPDX-License-Identifier: Apache-2.0
33

4-
from libc.stdint cimport uint16_t, uint32_t
4+
from libc.stdint cimport uint8_t, uint32_t
55
from libcpp cimport bool
66
from libcpp.memory cimport shared_ptr, unique_ptr
77
from libcpp.string cimport string
@@ -38,7 +38,7 @@ cdef extern from "<rapidsmpf/shuffler/shuffler.hpp>" nogil:
3838
cpp_Shuffler(
3939
shared_ptr[cpp_Communicator] comm,
4040
shared_ptr[cpp_ProgressThread] comm,
41-
uint16_t op_id,
41+
uint8_t op_id,
4242
uint32_t total_num_partitions,
4343
cuda_stream_view stream,
4444
cpp_BufferResource *br,

python/rapidsmpf/rapidsmpf/shuffler.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def unpack_and_concat(
3434
) -> Table: ...
3535

3636
class Shuffler:
37+
max_concurrent_shuffles: int
3738
def __init__(
3839
self,
3940
comm: Communicator,

0 commit comments

Comments
 (0)