diff --git a/src/nccl/safe.rs b/src/nccl/safe.rs index 51f7dd33..8b176274 100644 --- a/src/nccl/safe.rs +++ b/src/nccl/safe.rs @@ -1,5 +1,7 @@ use super::{result, sys}; -use crate::driver::{CudaContext, CudaStream, DevicePtr, DevicePtrMut}; +use crate::driver::{ + CudaContext, CudaStream, CudaView, CudaViewMut, DevicePtr, DevicePtrMut, SyncOnDrop, +}; use std::{mem::MaybeUninit, sync::Arc, vec, vec::Vec}; pub use result::{group_end, group_start}; @@ -98,16 +100,16 @@ impl Comm { /// let n_devices = CudaDevice::count().unwrap() as usize; /// let devices : Vec<_> = (0..n_devices).flat_map(CudaDevice::new).collect(); /// let comms = Comm::from_devices(devices).unwrap(); - /// group_start().unwrap(); + /// let mut group = comms.group(); /// (0..n_devices).map(|i| { /// let comm = &comms[i]; /// let dev = comm.device(); /// let slice = dev.htod_copy(vec![(i + 1) as f32 * 1.0; n]).unwrap(); /// let mut slice_receive = dev.alloc_zeros::(n).unwrap(); - /// comm.all_reduce(&slice, &mut slice_receive, &ReduceOp::Sum) + /// group.all_reduce(&slice, &mut slice_receive, &ReduceOp::Sum) /// .unwrap(); /// }); - /// group_start().unwrap(); + /// drop(group); /// ``` pub fn from_devices(streams: Vec>) -> Result, result::NcclError> { let n_streams = streams.len(); @@ -445,6 +447,325 @@ impl Comm { } } +/// An NCCL Group. Calls [group_start()] via [Comm::group()], and [group_end()] on Drop. +/// +/// Works with the event tracking system in [CudaContext] by delaying the drop of [SyncOnDrop] of all +/// [CudaView]/[CudaViewMut] until **after** [group_end()] is called on drop. +/// +/// Note that the main difference between the calls on [Group] vs [Comm] is that group **requires** +/// [CudaView]/[CudaViewMut]. This is because we need to enforce that the views' lifetimes outlive the +/// group's lifetime. This is not necessarily possible with [DevicePtr]/[DevicePtrMut] because they capture +/// the `&self` lifetime of the borrow instead of the view's original data lifetime. 
+/// +/// When using [Group], you will likely need to create all the views you intend to use within the group +/// **before** starting the group. +/// +/// ```ignore +/// let send_view: CudaView<'_, u8> = ...; +/// let recv_view: CudaViewMut<'_, u8> = ...; +/// let mut group = comm.group(); +/// group.all_gather(send_view, recv_view); +/// ``` +/// +/// If you create the views after the group is created, Rust will complain about the views not outliving the +/// group lifetime. +/// +/// See [nvidia docs](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/groups.html) +#[derive(Debug)] +pub struct Group<'a> { + comm: &'a Comm, + syncs: Vec>, +} + +impl<'a> Drop for Group<'a> { + fn drop(&mut self) { + group_end().unwrap(); + } +} + +impl Comm { + /// Initializes a new group call, see [nccl docs](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/groups.html) + pub fn group(&self) -> Group<'_> { + group_start().unwrap(); + Group { + comm: self, + syncs: Vec::new(), + } + } +} + +impl<'g> Group<'g> { + /// The underlying [Comm] object. 
+ pub fn comm(&self) -> &'g Comm { + self.comm + } + + /// Send data to one peer, see [nccl docs](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/p2p.html#ncclsend) + pub fn send<'s: 'g, T: NcclType>( + &mut self, + data: CudaView<'s, T>, + peer: i32, + ) -> Result<(), result::NcclError> { + let count = data.len(); + let (src, record_src) = data.view_ptr(&self.comm.stream); + unsafe { + result::send( + src as _, + count, + T::as_nccl_type(), + peer, + self.comm.comm, + self.comm.stream.cu_stream as _, + ) + }?; + self.syncs.push(record_src); + Ok(()) + } + + /// Receive data from one peer, see [nccl docs](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/p2p.html#ncclrecv) + pub fn recv<'r: 'g, T: NcclType>( + &mut self, + buff: CudaViewMut<'r, T>, + peer: i32, + ) -> Result { + let count = buff.len(); + let (dst, record_dst) = buff.view_ptr_mut(&self.comm.stream); + let status = unsafe { + result::recv( + dst as _, + count, + T::as_nccl_type(), + peer, + self.comm.comm, + self.comm.stream.cu_stream as _, + ) + }?; + self.syncs.push(record_dst); + Ok(status) + } + + /// Broadcasts a value from `root` rank to every other rank's `recvbuff`. + /// `sendbuff` is ignored on ranks other than `root`, so you can pass `None` + /// on non-root ranks. + /// + /// `sendbuff` must be `Some` on the root rank! 
+ /// + /// See [nccl docs](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#broadcast) + pub fn broadcast<'s: 'g, 'r: 'g, T: NcclType>( + &mut self, + sendbuff: Option>, + recvbuff: CudaViewMut<'r, T>, + root: i32, + ) -> Result { + debug_assert!(sendbuff.is_some() || self.comm.rank != root as usize); + let count = recvbuff.len(); + let (src, record_src) = sendbuff.map(|b| b.view_ptr(&self.comm.stream)).unzip(); + let (dst, record_dst) = recvbuff.view_ptr_mut(&self.comm.stream); + let status = unsafe { + result::broadcast( + src.map(|ptr| ptr as _).unwrap_or(std::ptr::null()), + dst as _, + count, + T::as_nccl_type(), + root, + self.comm.comm, + self.comm.stream.cu_stream as _, + ) + }?; + if let Some(record_src) = record_src { + self.syncs.push(record_src); + } + self.syncs.push(record_dst); + Ok(status) + } + + /// In place version of [Group::broadcast()]. + /// See [nccl docs](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#broadcast) + pub fn broadcast_in_place<'r: 'g, T: NcclType>( + &mut self, + recvbuff: CudaViewMut<'r, T>, + root: i32, + ) -> Result { + let count = recvbuff.len(); + let (dst, record_dst) = recvbuff.view_ptr_mut(&self.comm.stream); + let status = unsafe { + result::broadcast( + dst as _, + dst as _, + count, + T::as_nccl_type(), + root, + self.comm.comm, + self.comm.stream.cu_stream as _, + ) + }?; + self.syncs.push(record_dst); + Ok(status) + } + + /// See [nccl docs](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allgather) + pub fn all_gather<'s: 'g, 'r: 'g, T: NcclType>( + &mut self, + sendbuff: CudaView<'s, T>, + recvbuff: CudaViewMut<'r, T>, + ) -> Result { + let sendcount = sendbuff.len(); + let (src, record_src) = sendbuff.view_ptr(&self.comm.stream); + let (dst, record_dst) = recvbuff.view_ptr_mut(&self.comm.stream); + let status = unsafe { + result::all_gather( + src as _, + dst as _, + sendcount, + T::as_nccl_type(), + 
self.comm.comm, + self.comm.stream.cu_stream as _, + ) + }?; + self.syncs.push(record_src); + self.syncs.push(record_dst); + Ok(status) + } + + /// See [nccl docs](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allreduce) + pub fn all_reduce<'s: 'g, 'r: 'g, T: NcclType>( + &mut self, + sendbuff: CudaView<'s, T>, + recvbuff: CudaViewMut<'r, T>, + reduce_op: &ReduceOp, + ) -> Result { + let count = sendbuff.len(); + let (src, record_src) = sendbuff.view_ptr(&self.comm.stream); + let (dst, record_dst) = recvbuff.view_ptr_mut(&self.comm.stream); + let status = unsafe { + result::all_reduce( + src as _, + dst as _, + count, + T::as_nccl_type(), + convert_to_nccl_reduce_op(reduce_op), + self.comm.comm, + self.comm.stream.cu_stream as _, + ) + }?; + self.syncs.push(record_src); + self.syncs.push(record_dst); + Ok(status) + } + + /// In place version of [Group::all_reduce()]. + /// See [nccl docs](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allreduce) + pub fn all_reduce_in_place<'r: 'g, T: NcclType>( + &mut self, + buff: CudaViewMut<'r, T>, + reduce_op: &ReduceOp, + ) -> Result { + let count = buff.len(); + let (dst, record_dst) = buff.view_ptr_mut(&self.comm.stream); + let status = unsafe { + result::all_reduce( + dst as _, + dst as _, + count, + T::as_nccl_type(), + convert_to_nccl_reduce_op(reduce_op), + self.comm.comm, + self.comm.stream.cu_stream as _, + ) + }?; + self.syncs.push(record_dst); + Ok(status) + } + + /// Reduces the `sendbuff` from all ranks into the `recvbuff` on the + /// `root` rank. + /// + /// `recvbuff` must be `Some` on the root rank! 
+ /// + /// See [nccl docs](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#reduce) + pub fn reduce<'s: 'g, 'r: 'g, T: NcclType>( + &mut self, + sendbuff: CudaView<'s, T>, + recvbuff: Option>, + reduce_op: &ReduceOp, + root: i32, + ) -> Result { + debug_assert!(recvbuff.is_some() || self.comm.rank != root as usize); + let count = sendbuff.len(); + let (src, record_src) = sendbuff.view_ptr(&self.comm.stream); + let (dst, record_dst) = recvbuff.map(|b| b.view_ptr_mut(&self.comm.stream)).unzip(); + let status = unsafe { + result::reduce( + src as _, + dst.map(|ptr| ptr as _).unwrap_or(std::ptr::null_mut()), + count, + T::as_nccl_type(), + convert_to_nccl_reduce_op(reduce_op), + root, + self.comm.comm, + self.comm.stream.cu_stream as _, + ) + }?; + self.syncs.push(record_src); + if let Some(record_dst) = record_dst { + self.syncs.push(record_dst); + } + Ok(status) + } + + /// In place version of [Group::reduce()]. + /// See [nccl docs](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#reduce) + pub fn reduce_in_place<'s: 'g, T: NcclType>( + &mut self, + recvbuff: CudaViewMut<'s, T>, + reduce_op: &ReduceOp, + root: i32, + ) -> Result { + let count = recvbuff.len(); + let (dst, record_dst) = recvbuff.view_ptr_mut(&self.comm.stream); + let status = unsafe { + result::reduce( + dst as _, + dst as _, + count, + T::as_nccl_type(), + convert_to_nccl_reduce_op(reduce_op), + root, + self.comm.comm, + self.comm.stream.cu_stream as _, + ) + }?; + self.syncs.push(record_dst); + Ok(status) + } + + /// See [nccl docs](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#reducescatter) + pub fn reduce_scatter<'s: 'g, 'r: 'g, T: NcclType>( + &mut self, + sendbuff: CudaView<'s, T>, + recvbuff: CudaViewMut<'r, T>, + reduce_op: &ReduceOp, + ) -> Result { + let count = recvbuff.len(); + let (src, record_src) = sendbuff.view_ptr(&self.comm.stream); + let (dst, record_dst) = 
recvbuff.view_ptr_mut(&self.comm.stream); + let status = unsafe { + result::reduce_scatter( + src as _, + dst as _, + count, + T::as_nccl_type(), + convert_to_nccl_reduce_op(reduce_op), + self.comm.comm, + self.comm.stream.cu_stream as _, + ) + }?; + self.syncs.push(record_src); + self.syncs.push(record_dst); + Ok(status) + } +} + #[macro_export] macro_rules! group { ($x:block) => {