Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
329 changes: 325 additions & 4 deletions src/nccl/safe.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use super::{result, sys};
use crate::driver::{CudaContext, CudaStream, DevicePtr, DevicePtrMut};
use crate::driver::{
CudaContext, CudaStream, CudaView, CudaViewMut, DevicePtr, DevicePtrMut, SyncOnDrop,
};
use std::{mem::MaybeUninit, sync::Arc, vec, vec::Vec};

pub use result::{group_end, group_start};
Expand Down Expand Up @@ -98,16 +100,16 @@ impl Comm {
/// let n_devices = CudaDevice::count().unwrap() as usize;
/// let devices : Vec<_> = (0..n_devices).flat_map(CudaDevice::new).collect();
/// let comms = Comm::from_devices(devices).unwrap();
/// group_start().unwrap();
/// let mut group = comms.group();
Comment thread
chelsea0x3b marked this conversation as resolved.
/// (0..n_devices).map(|i| {
/// let comm = &comms[i];
/// let dev = comm.device();
/// let slice = dev.htod_copy(vec![(i + 1) as f32 * 1.0; n]).unwrap();
/// let mut slice_receive = dev.alloc_zeros::<f32>(n).unwrap();
/// comm.all_reduce(&slice, &mut slice_receive, &ReduceOp::Sum)
/// group.all_reduce(&slice, &mut slice_receive, &ReduceOp::Sum)
/// .unwrap();
/// });
/// group_start().unwrap();
/// drop(group);
/// ```
pub fn from_devices(streams: Vec<Arc<CudaStream>>) -> Result<Vec<Self>, result::NcclError> {
let n_streams = streams.len();
Expand Down Expand Up @@ -445,6 +447,325 @@ impl Comm {
}
}

/// An NCCL Group. Calls [group_start()] via [Comm::group()], and [group_end()] on Drop.
///
/// Works with the event tracking system in [CudaContext] by delaying the drop of [SyncOnDrop] of all
/// [CudaView]/[CudaViewMut] until **after** [group_end()] is called on drop.
///
/// Note that the main difference between the calls on [Group] vs [Comm] is that group **requires**
/// [CudaView]/[CudaViewMut]. This is because we need to enforce that the view's lifetimes outlive the
/// group's lifetime. This is not necessarily possible with [DevicePtr]/[DevicePtrMut] because they capture
/// the &self lifetime of the borrow instead of the view's original data lifetime.
///
/// When using [Group], you will likely need to create all the views you intend to use within the group
/// **before** starting the group.
///
/// ```ignore
/// let send_view: CudaView<'_, u8> = ...;
/// let recv_view: CudaViewMut<'_, u8> = ...;
/// let mut group = comm.group();
/// group.all_gather(send_view, recv_view);
/// ```
///
/// If you create the views after the group is created, rust will complain about the views not outliving the
/// group lifetime.
///
/// See [nvidia docs](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/groups.html)
#[derive(Debug)]
pub struct Group<'a> {
comm: &'a Comm,
syncs: Vec<SyncOnDrop<'a>>,
}
Comment on lines +475 to +478

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Group must not be Send because NCCL groups are thread-local. According to NCCL documentation, ncclGroupEnd must be called by the same thread that called ncclGroupStart. Since Group implements RAII for these calls, moving a Group to another thread would result in undefined behavior when it is dropped. Adding a PhantomData<*const ()> marker will prevent the struct from being Send.

pub struct Group<'a> {
    comm: &'a Comm,
    syncs: Vec<SyncOnDrop<'a>>,
    _marker: std::marker::PhantomData<*const ()>,
}


impl<'a> Drop for Group<'a> {
fn drop(&mut self) {
group_end().unwrap();
}
}

impl Comm {
/// Initializes a new group call: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/groups.html
pub fn group(&self) -> Group<'_> {
group_start().unwrap();
Group {
comm: self,
syncs: Vec::new(),
}
}
Comment on lines +488 to +494

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

group_start() can fail (e.g., if the maximum number of nested groups is reached). This method should return a Result to allow the caller to handle such errors gracefully instead of panicking via unwrap().

    pub fn group(&self) -> Result<Group<'_>, result::NcclError> {
        group_start()?;
        Ok(Group {
            comm: self,
            syncs: Vec::new(),
            _marker: std::marker::PhantomData,
        })
    }

}

impl<'g> Group<'g> {
/// The underlying [Comm] object.
pub fn comm(&self) -> &'g Comm {
self.comm
}

/// Send data to one peer, see [cuda docs](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/p2p.html#ncclsend)
pub fn send<'s: 'g, T: NcclType>(
&mut self,
data: CudaView<'s, T>,
peer: i32,
) -> Result<(), result::NcclError> {
let count = data.len();
let (src, record_src) = data.view_ptr(&self.comm.stream);
unsafe {
result::send(
src as _,
count,
T::as_nccl_type(),
peer,
self.comm.comm,
self.comm.stream.cu_stream as _,
)
}?;
self.syncs.push(record_src);
Ok(())
}

/// Receive data from one peer, see [cuda docs](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/p2p.html#ncclrecv)
pub fn recv<'r: 'g, T: NcclType>(
&mut self,
buff: CudaViewMut<'r, T>,
peer: i32,
) -> Result<result::NcclStatus, result::NcclError> {
let count = buff.len();
let (dst, record_dst) = buff.view_ptr_mut(&self.comm.stream);
let status = unsafe {
result::recv(
dst as _,
count,
T::as_nccl_type(),
peer,
self.comm.comm,
self.comm.stream.cu_stream as _,
)
}?;
self.syncs.push(record_dst);
Ok(status)
}

/// Broadcasts a value from `root` rank to every other ranks `recvbuff`.
/// sendbuff is ignored on ranks other than `root`, so you can pass `None`
/// on non-root ranks.
///
/// sendbuff must be Some on root rank!
///
/// See [nccl docs](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#broadcast)
pub fn broadcast<'s: 'g, 'r: 'g, T: NcclType>(
&mut self,
sendbuff: Option<CudaView<'s, T>>,
recvbuff: CudaViewMut<'r, T>,
root: i32,
) -> Result<result::NcclStatus, result::NcclError> {
debug_assert!(sendbuff.is_some() || self.comm.rank != root as usize);
let count = recvbuff.len();
Comment thread
chelsea0x3b marked this conversation as resolved.
let (src, record_src) = sendbuff.map(|b| b.view_ptr(&self.comm.stream)).unzip();
let (dst, record_dst) = recvbuff.view_ptr_mut(&self.comm.stream);
let status = unsafe {
result::broadcast(
src.map(|ptr| ptr as _).unwrap_or(std::ptr::null()),
dst as _,
count,
T::as_nccl_type(),
root,
self.comm.comm,
self.comm.stream.cu_stream as _,
)
}?;
if let Some(record_src) = record_src {
self.syncs.push(record_src);
}
self.syncs.push(record_dst);
Ok(status)
}

/// In place version of [Comm::broadcast()].
/// See [nccl docs](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#broadcast)
pub fn broadcast_in_place<'r: 'g, T: NcclType>(
&mut self,
recvbuff: CudaViewMut<'r, T>,
root: i32,
) -> Result<result::NcclStatus, result::NcclError> {
let count = recvbuff.len();
let (dst, record_dst) = recvbuff.view_ptr_mut(&self.comm.stream);
let status = unsafe {
result::broadcast(
dst as _,
dst as _,
count,
T::as_nccl_type(),
root,
self.comm.comm,
self.comm.stream.cu_stream as _,
)
}?;
self.syncs.push(record_dst);
Ok(status)
}

/// See [nccl docs](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allgather)
pub fn all_gather<'s: 'g, 'r: 'g, T: NcclType>(
&mut self,
sendbuff: CudaView<'s, T>,
recvbuff: CudaViewMut<'r, T>,
) -> Result<result::NcclStatus, result::NcclError> {
let sendcount = sendbuff.len();
Comment thread
chelsea0x3b marked this conversation as resolved.
let (src, record_src) = sendbuff.view_ptr(&self.comm.stream);
let (dst, record_dst) = recvbuff.view_ptr_mut(&self.comm.stream);
let status = unsafe {
result::all_gather(
src as _,
dst as _,
sendcount,
T::as_nccl_type(),
self.comm.comm,
self.comm.stream.cu_stream as _,
)
}?;
self.syncs.push(record_src);
self.syncs.push(record_dst);
Ok(status)
}

/// See [nccl docs](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allreduce)
pub fn all_reduce<'s: 'g, 'r: 'g, T: NcclType>(
&mut self,
sendbuff: CudaView<'s, T>,
recvbuff: CudaViewMut<'r, T>,
reduce_op: &ReduceOp,
) -> Result<result::NcclStatus, result::NcclError> {
let count = sendbuff.len();
Comment thread
chelsea0x3b marked this conversation as resolved.
let (src, record_src) = sendbuff.view_ptr(&self.comm.stream);
let (dst, record_dst) = recvbuff.view_ptr_mut(&self.comm.stream);
let status = unsafe {
result::all_reduce(
src as _,
dst as _,
count,
T::as_nccl_type(),
convert_to_nccl_reduce_op(reduce_op),
self.comm.comm,
self.comm.stream.cu_stream as _,
)
}?;
self.syncs.push(record_src);
self.syncs.push(record_dst);
Ok(status)
}

/// In place version of [Comm::all_reduce()].
/// See [nccl docs](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allreduce)
pub fn all_reduce_in_place<'r: 'g, T: NcclType>(
&mut self,
buff: CudaViewMut<'r, T>,
reduce_op: &ReduceOp,
) -> Result<result::NcclStatus, result::NcclError> {
let count = buff.len();
let (dst, record_dst) = buff.view_ptr_mut(&self.comm.stream);
let status = unsafe {
result::all_reduce(
dst as _,
dst as _,
count,
T::as_nccl_type(),
convert_to_nccl_reduce_op(reduce_op),
self.comm.comm,
self.comm.stream.cu_stream as _,
)
}?;
self.syncs.push(record_dst);
Ok(status)
}

/// Reduces the sendbuff from all ranks into the recvbuff on the
/// `root` rank.
///
/// recvbuff must be Some on the root rank!
///
/// See [nccl docs](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#reduce)
pub fn reduce<'s: 'g, 'r: 'g, T: NcclType>(
&mut self,
sendbuff: CudaView<'s, T>,
recvbuff: Option<CudaViewMut<'r, T>>,
reduce_op: &ReduceOp,
root: i32,
) -> Result<result::NcclStatus, result::NcclError> {
debug_assert!(recvbuff.is_some() || self.comm.rank != root as usize);
let count = sendbuff.len();
Comment thread
chelsea0x3b marked this conversation as resolved.
let (src, record_src) = sendbuff.view_ptr(&self.comm.stream);
let (dst, record_dst) = recvbuff.map(|b| b.view_ptr_mut(&self.comm.stream)).unzip();
let status = unsafe {
result::reduce(
src as _,
dst.map(|ptr| ptr as _).unwrap_or(std::ptr::null_mut()),
count,
T::as_nccl_type(),
convert_to_nccl_reduce_op(reduce_op),
root,
self.comm.comm,
self.comm.stream.cu_stream as _,
)
}?;
self.syncs.push(record_src);
if let Some(record_dst) = record_dst {
self.syncs.push(record_dst);
}
Ok(status)
}

/// In place version of [Comm::reduce()].
/// See [nccl docs](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#reduce)
pub fn reduce_in_place<'s: 'g, T: NcclType>(
&mut self,
recvbuff: CudaViewMut<'s, T>,
reduce_op: &ReduceOp,
root: i32,
) -> Result<result::NcclStatus, result::NcclError> {
let count = recvbuff.len();
let (dst, record_dst) = recvbuff.view_ptr_mut(&self.comm.stream);
let status = unsafe {
result::reduce(
dst as _,
dst as _,
count,
T::as_nccl_type(),
convert_to_nccl_reduce_op(reduce_op),
root,
self.comm.comm,
self.comm.stream.cu_stream as _,
)
}?;
self.syncs.push(record_dst);
Ok(status)
}

/// See [nccl docs](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#reducescatter)
pub fn reduce_scatter<'s: 'g, 'r: 'g, T: NcclType>(
&mut self,
sendbuff: CudaView<'s, T>,
recvbuff: CudaViewMut<'r, T>,
reduce_op: &ReduceOp,
) -> Result<result::NcclStatus, result::NcclError> {
let count = recvbuff.len();
Comment thread
chelsea0x3b marked this conversation as resolved.
let (src, record_src) = sendbuff.view_ptr(&self.comm.stream);
let (dst, record_dst) = recvbuff.view_ptr_mut(&self.comm.stream);
let status = unsafe {
result::reduce_scatter(
src as _,
dst as _,
count,
T::as_nccl_type(),
convert_to_nccl_reduce_op(reduce_op),
self.comm.comm,
self.comm.stream.cu_stream as _,
)
}?;
self.syncs.push(record_src);
self.syncs.push(record_dst);
Ok(status)
}
}

#[macro_export]
macro_rules! group {
($x:block) => {
Expand Down
Loading