-
-
Notifications
You must be signed in to change notification settings - Fork 159
Adding safe Group api to nccl
#578
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,7 @@ | ||
| use super::{result, sys}; | ||
| use crate::driver::{CudaContext, CudaStream, DevicePtr, DevicePtrMut}; | ||
| use crate::driver::{ | ||
| CudaContext, CudaStream, CudaView, CudaViewMut, DevicePtr, DevicePtrMut, SyncOnDrop, | ||
| }; | ||
| use std::{mem::MaybeUninit, sync::Arc, vec, vec::Vec}; | ||
|
|
||
| pub use result::{group_end, group_start}; | ||
|
|
@@ -98,16 +100,16 @@ impl Comm { | |
| /// let n_devices = CudaDevice::count().unwrap() as usize; | ||
| /// let devices : Vec<_> = (0..n_devices).flat_map(CudaDevice::new).collect(); | ||
| /// let comms = Comm::from_devices(devices).unwrap(); | ||
| /// group_start().unwrap(); | ||
| /// let mut group = comms.group(); | ||
| /// (0..n_devices).map(|i| { | ||
| /// let comm = &comms[i]; | ||
| /// let dev = comm.device(); | ||
| /// let slice = dev.htod_copy(vec![(i + 1) as f32 * 1.0; n]).unwrap(); | ||
| /// let mut slice_receive = dev.alloc_zeros::<f32>(n).unwrap(); | ||
| /// comm.all_reduce(&slice, &mut slice_receive, &ReduceOp::Sum) | ||
| /// group.all_reduce(&slice, &mut slice_receive, &ReduceOp::Sum) | ||
| /// .unwrap(); | ||
| /// }); | ||
| /// group_start().unwrap(); | ||
| /// drop(group); | ||
| /// ``` | ||
| pub fn from_devices(streams: Vec<Arc<CudaStream>>) -> Result<Vec<Self>, result::NcclError> { | ||
| let n_streams = streams.len(); | ||
|
|
@@ -445,6 +447,325 @@ impl Comm { | |
| } | ||
| } | ||
|
|
||
| /// An NCCL Group. Calls [group_start()] via [Comm::group()], and [group_end()] on Drop. | ||
| /// | ||
| /// Works with the event tracking system in [CudaContext] by delaying the drop of [SyncOnDrop] of all | ||
| /// [CudaView]/[CudaViewMut] until **after** [group_end()] is called on drop. | ||
| /// | ||
| /// Note that the main difference between the calls on [Group] vs [Comm] is that group **requires** | ||
| /// [CudaView]/[CudaViewMut]. This is because we need to enforce that the view's lifetimes outlive the | ||
| /// group's lifetime. This is not necessarily possible with [DevicePtr]/[DevicePtrMut] because they capture | ||
| /// the &self lifetime of the borrow instead of the view's original data lifetime. | ||
| /// | ||
| /// When using [Group], you will likely need to create all the views you intend to use within the group | ||
| /// **before** starting the group. | ||
| /// | ||
| /// ```ignore | ||
| /// let send_view: CudaView<'_, u8> = ...; | ||
| /// let recv_view: CudaViewMut<'_, u8> = ...; | ||
| /// let mut group = comm.group(); | ||
| /// group.all_gather(send_view, recv_view); | ||
| /// ``` | ||
| /// | ||
| /// If you create the views after the group is created, rust will complain about the views not outliving the | ||
| /// group lifetime. | ||
| /// | ||
| /// See [nvidia docs](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/groups.html) | ||
| #[derive(Debug)] | ||
| pub struct Group<'a> { | ||
| comm: &'a Comm, | ||
| syncs: Vec<SyncOnDrop<'a>>, | ||
| } | ||
|
Comment on lines
+475
to
+478
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
pub struct Group<'a> {
comm: &'a Comm,
syncs: Vec<SyncOnDrop<'a>>,
_marker: std::marker::PhantomData<*const ()>,
} |
||
|
|
||
| impl<'a> Drop for Group<'a> { | ||
| fn drop(&mut self) { | ||
| group_end().unwrap(); | ||
| } | ||
| } | ||
|
|
||
| impl Comm { | ||
| /// Initializes a new group call: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/groups.html | ||
| pub fn group(&self) -> Group<'_> { | ||
| group_start().unwrap(); | ||
| Group { | ||
| comm: self, | ||
| syncs: Vec::new(), | ||
| } | ||
| } | ||
|
Comment on lines
+488
to
+494
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
pub fn group(&self) -> Result<Group<'_>, result::NcclError> {
group_start()?;
Ok(Group {
comm: self,
syncs: Vec::new(),
_marker: std::marker::PhantomData,
})
} |
||
| } | ||
|
|
||
| impl<'g> Group<'g> { | ||
| /// The underlying [Comm] object. | ||
| pub fn comm(&self) -> &'g Comm { | ||
| self.comm | ||
| } | ||
|
|
||
| /// Send data to one peer, see [cuda docs](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/p2p.html#ncclsend) | ||
| pub fn send<'s: 'g, T: NcclType>( | ||
| &mut self, | ||
| data: CudaView<'s, T>, | ||
| peer: i32, | ||
| ) -> Result<(), result::NcclError> { | ||
| let count = data.len(); | ||
| let (src, record_src) = data.view_ptr(&self.comm.stream); | ||
| unsafe { | ||
| result::send( | ||
| src as _, | ||
| count, | ||
| T::as_nccl_type(), | ||
| peer, | ||
| self.comm.comm, | ||
| self.comm.stream.cu_stream as _, | ||
| ) | ||
| }?; | ||
| self.syncs.push(record_src); | ||
| Ok(()) | ||
| } | ||
|
|
||
| /// Receive data from one peer, see [cuda docs](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/p2p.html#ncclrecv) | ||
| pub fn recv<'r: 'g, T: NcclType>( | ||
| &mut self, | ||
| buff: CudaViewMut<'r, T>, | ||
| peer: i32, | ||
| ) -> Result<result::NcclStatus, result::NcclError> { | ||
| let count = buff.len(); | ||
| let (dst, record_dst) = buff.view_ptr_mut(&self.comm.stream); | ||
| let status = unsafe { | ||
| result::recv( | ||
| dst as _, | ||
| count, | ||
| T::as_nccl_type(), | ||
| peer, | ||
| self.comm.comm, | ||
| self.comm.stream.cu_stream as _, | ||
| ) | ||
| }?; | ||
| self.syncs.push(record_dst); | ||
| Ok(status) | ||
| } | ||
|
|
||
| /// Broadcasts a value from `root` rank to every other ranks `recvbuff`. | ||
| /// sendbuff is ignored on ranks other than `root`, so you can pass `None` | ||
| /// on non-root ranks. | ||
| /// | ||
| /// sendbuff must be Some on root rank! | ||
| /// | ||
| /// See [nccl docs](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#broadcast) | ||
| pub fn broadcast<'s: 'g, 'r: 'g, T: NcclType>( | ||
| &mut self, | ||
| sendbuff: Option<CudaView<'s, T>>, | ||
| recvbuff: CudaViewMut<'r, T>, | ||
| root: i32, | ||
| ) -> Result<result::NcclStatus, result::NcclError> { | ||
| debug_assert!(sendbuff.is_some() || self.comm.rank != root as usize); | ||
| let count = recvbuff.len(); | ||
|
chelsea0x3b marked this conversation as resolved.
|
||
| let (src, record_src) = sendbuff.map(|b| b.view_ptr(&self.comm.stream)).unzip(); | ||
| let (dst, record_dst) = recvbuff.view_ptr_mut(&self.comm.stream); | ||
| let status = unsafe { | ||
| result::broadcast( | ||
| src.map(|ptr| ptr as _).unwrap_or(std::ptr::null()), | ||
| dst as _, | ||
| count, | ||
| T::as_nccl_type(), | ||
| root, | ||
| self.comm.comm, | ||
| self.comm.stream.cu_stream as _, | ||
| ) | ||
| }?; | ||
| if let Some(record_src) = record_src { | ||
| self.syncs.push(record_src); | ||
| } | ||
| self.syncs.push(record_dst); | ||
| Ok(status) | ||
| } | ||
|
|
||
| /// In place version of [Comm::broadcast()]. | ||
| /// See [nccl docs](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#broadcast) | ||
| pub fn broadcast_in_place<'r: 'g, T: NcclType>( | ||
| &mut self, | ||
| recvbuff: CudaViewMut<'r, T>, | ||
| root: i32, | ||
| ) -> Result<result::NcclStatus, result::NcclError> { | ||
| let count = recvbuff.len(); | ||
| let (dst, record_dst) = recvbuff.view_ptr_mut(&self.comm.stream); | ||
| let status = unsafe { | ||
| result::broadcast( | ||
| dst as _, | ||
| dst as _, | ||
| count, | ||
| T::as_nccl_type(), | ||
| root, | ||
| self.comm.comm, | ||
| self.comm.stream.cu_stream as _, | ||
| ) | ||
| }?; | ||
| self.syncs.push(record_dst); | ||
| Ok(status) | ||
| } | ||
|
|
||
| /// See [nccl docs](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allgather) | ||
| pub fn all_gather<'s: 'g, 'r: 'g, T: NcclType>( | ||
| &mut self, | ||
| sendbuff: CudaView<'s, T>, | ||
| recvbuff: CudaViewMut<'r, T>, | ||
| ) -> Result<result::NcclStatus, result::NcclError> { | ||
| let sendcount = sendbuff.len(); | ||
|
chelsea0x3b marked this conversation as resolved.
|
||
| let (src, record_src) = sendbuff.view_ptr(&self.comm.stream); | ||
| let (dst, record_dst) = recvbuff.view_ptr_mut(&self.comm.stream); | ||
| let status = unsafe { | ||
| result::all_gather( | ||
| src as _, | ||
| dst as _, | ||
| sendcount, | ||
| T::as_nccl_type(), | ||
| self.comm.comm, | ||
| self.comm.stream.cu_stream as _, | ||
| ) | ||
| }?; | ||
| self.syncs.push(record_src); | ||
| self.syncs.push(record_dst); | ||
| Ok(status) | ||
| } | ||
|
|
||
| /// See [nccl docs](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allreduce) | ||
| pub fn all_reduce<'s: 'g, 'r: 'g, T: NcclType>( | ||
| &mut self, | ||
| sendbuff: CudaView<'s, T>, | ||
| recvbuff: CudaViewMut<'r, T>, | ||
| reduce_op: &ReduceOp, | ||
| ) -> Result<result::NcclStatus, result::NcclError> { | ||
| let count = sendbuff.len(); | ||
|
chelsea0x3b marked this conversation as resolved.
|
||
| let (src, record_src) = sendbuff.view_ptr(&self.comm.stream); | ||
| let (dst, record_dst) = recvbuff.view_ptr_mut(&self.comm.stream); | ||
| let status = unsafe { | ||
| result::all_reduce( | ||
| src as _, | ||
| dst as _, | ||
| count, | ||
| T::as_nccl_type(), | ||
| convert_to_nccl_reduce_op(reduce_op), | ||
| self.comm.comm, | ||
| self.comm.stream.cu_stream as _, | ||
| ) | ||
| }?; | ||
| self.syncs.push(record_src); | ||
| self.syncs.push(record_dst); | ||
| Ok(status) | ||
| } | ||
|
|
||
| /// In place version of [Comm::all_reduce()]. | ||
| /// See [nccl docs](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allreduce) | ||
| pub fn all_reduce_in_place<'r: 'g, T: NcclType>( | ||
| &mut self, | ||
| buff: CudaViewMut<'r, T>, | ||
| reduce_op: &ReduceOp, | ||
| ) -> Result<result::NcclStatus, result::NcclError> { | ||
| let count = buff.len(); | ||
| let (dst, record_dst) = buff.view_ptr_mut(&self.comm.stream); | ||
| let status = unsafe { | ||
| result::all_reduce( | ||
| dst as _, | ||
| dst as _, | ||
| count, | ||
| T::as_nccl_type(), | ||
| convert_to_nccl_reduce_op(reduce_op), | ||
| self.comm.comm, | ||
| self.comm.stream.cu_stream as _, | ||
| ) | ||
| }?; | ||
| self.syncs.push(record_dst); | ||
| Ok(status) | ||
| } | ||
|
|
||
| /// Reduces the sendbuff from all ranks into the recvbuff on the | ||
| /// `root` rank. | ||
| /// | ||
| /// recvbuff must be Some on the root rank! | ||
| /// | ||
| /// See [nccl docs](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#reduce) | ||
| pub fn reduce<'s: 'g, 'r: 'g, T: NcclType>( | ||
| &mut self, | ||
| sendbuff: CudaView<'s, T>, | ||
| recvbuff: Option<CudaViewMut<'r, T>>, | ||
| reduce_op: &ReduceOp, | ||
| root: i32, | ||
| ) -> Result<result::NcclStatus, result::NcclError> { | ||
| debug_assert!(recvbuff.is_some() || self.comm.rank != root as usize); | ||
| let count = sendbuff.len(); | ||
|
chelsea0x3b marked this conversation as resolved.
|
||
| let (src, record_src) = sendbuff.view_ptr(&self.comm.stream); | ||
| let (dst, record_dst) = recvbuff.map(|b| b.view_ptr_mut(&self.comm.stream)).unzip(); | ||
| let status = unsafe { | ||
| result::reduce( | ||
| src as _, | ||
| dst.map(|ptr| ptr as _).unwrap_or(std::ptr::null_mut()), | ||
| count, | ||
| T::as_nccl_type(), | ||
| convert_to_nccl_reduce_op(reduce_op), | ||
| root, | ||
| self.comm.comm, | ||
| self.comm.stream.cu_stream as _, | ||
| ) | ||
| }?; | ||
| self.syncs.push(record_src); | ||
| if let Some(record_dst) = record_dst { | ||
| self.syncs.push(record_dst); | ||
| } | ||
| Ok(status) | ||
| } | ||
|
|
||
| /// In place version of [Comm::reduce()]. | ||
| /// See [nccl docs](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#reduce) | ||
| pub fn reduce_in_place<'s: 'g, T: NcclType>( | ||
| &mut self, | ||
| recvbuff: CudaViewMut<'s, T>, | ||
| reduce_op: &ReduceOp, | ||
| root: i32, | ||
| ) -> Result<result::NcclStatus, result::NcclError> { | ||
| let count = recvbuff.len(); | ||
| let (dst, record_dst) = recvbuff.view_ptr_mut(&self.comm.stream); | ||
| let status = unsafe { | ||
| result::reduce( | ||
| dst as _, | ||
| dst as _, | ||
| count, | ||
| T::as_nccl_type(), | ||
| convert_to_nccl_reduce_op(reduce_op), | ||
| root, | ||
| self.comm.comm, | ||
| self.comm.stream.cu_stream as _, | ||
| ) | ||
| }?; | ||
| self.syncs.push(record_dst); | ||
| Ok(status) | ||
| } | ||
|
|
||
| /// See [nccl docs](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#reducescatter) | ||
| pub fn reduce_scatter<'s: 'g, 'r: 'g, T: NcclType>( | ||
| &mut self, | ||
| sendbuff: CudaView<'s, T>, | ||
| recvbuff: CudaViewMut<'r, T>, | ||
| reduce_op: &ReduceOp, | ||
| ) -> Result<result::NcclStatus, result::NcclError> { | ||
| let count = recvbuff.len(); | ||
|
chelsea0x3b marked this conversation as resolved.
|
||
| let (src, record_src) = sendbuff.view_ptr(&self.comm.stream); | ||
| let (dst, record_dst) = recvbuff.view_ptr_mut(&self.comm.stream); | ||
| let status = unsafe { | ||
| result::reduce_scatter( | ||
| src as _, | ||
| dst as _, | ||
| count, | ||
| T::as_nccl_type(), | ||
| convert_to_nccl_reduce_op(reduce_op), | ||
| self.comm.comm, | ||
| self.comm.stream.cu_stream as _, | ||
| ) | ||
| }?; | ||
| self.syncs.push(record_src); | ||
| self.syncs.push(record_dst); | ||
| Ok(status) | ||
| } | ||
| } | ||
|
|
||
| #[macro_export] | ||
| macro_rules! group { | ||
| ($x:block) => { | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.