@@ -17,25 +17,20 @@ namespace nccl {
1717
1818namespace KC = KokkosComm;
1919
20- template <KokkosExecutionSpace ExecSpace, KokkosView SendView, KokkosView RecvView>
21- auto broadcast (const ExecSpace &space, const SendView &sv, const RecvView &rv, int root, ncclComm_t comm) -> Req<Nccl> {
22- using ST = typename SendView::non_const_value_type;
23- using RT = typename RecvView::non_const_value_type;
24- static_assert (std::is_same_v<ST, RT>,
25- " KokkosComm::Experimental::nccl::broadcast: View value types must be identical" );
26- static_assert (KC::rank<SendView>() <= 1 and KC::rank<RecvView>() <= 1 ,
20+ template <KokkosView View>
21+ auto broadcast (const Kokkos::Cuda& space, View& v, int root, ncclComm_t comm) -> Req<Nccl> {
22+ using T = typename View::non_const_value_type;
23+ static_assert (KC::rank<View>() <= 1 ,
2724 " KokkosComm::Experimental::nccl::broadcast: Views with rank higher than 1 are not supported" );
2825 Kokkos::Tools::pushRegion (" KokkosComm::Experimental::nccl::broadcast" );
2926
3027 Req<Nccl> req{space.cuda_stream ()};
31- if (KC::is_contiguous (sv) and KC::is_contiguous (rv)) {
32- ncclBroadcast (KC::data_handle (sv), KC::data_handle (rv), KC::span (sv), Impl::datatype_v<ST>, root, comm,
33- space.cuda_stream ());
28+ if (KC::is_contiguous (v)) {
29+ ncclBcast (KC::data_handle (v), KC::span (v), Impl::datatype_v<T>, root, comm, space.cuda_stream ());
3430 } else {
3531 Kokkos::abort (" KokkosComm::Experimental::nccl::broadcast: unimplemented for non-contiguous views" );
3632 }
37- req.extend_view_lifetime (sv);
38- req.extend_view_lifetime (rv);
33+ req.extend_view_lifetime (v);
3934
4035 Kokkos::Tools::popRegion ();
4136 return req;
@@ -44,10 +39,10 @@ auto broadcast(const ExecSpace &space, const SendView &sv, const RecvView &rv, i
4439} // namespace nccl
4540namespace Impl {
4641
47- template <KokkosView SendView, KokkosView RecvView >
48- struct Broadcast <SendView, RecvView , Kokkos::Cuda, Nccl> {
49- static auto execute (Handle<Kokkos::Cuda, Nccl> & h, const SendView sv, RecvView rv , int root) -> Req<Nccl> {
50- return nccl::broadcast (h.space (), sv, rv , root, h.comm ());
42+ template <KokkosView View >
43+ struct Broadcast <View , Kokkos::Cuda, Nccl> {
44+ static auto execute (Handle<Kokkos::Cuda, Nccl>& h, View v , int root) -> Req<Nccl> {
45+ return nccl::broadcast (h.space (), v , root, h.comm ());
5146 }
5247};
5348
0 commit comments