@@ -151,7 +151,7 @@ ucc_status_t ucc_tl_nccl_collective_sync(ucc_tl_nccl_task_t *task,
151151 ucc_status_t status = UCC_OK ;
152152 CUresult cu_status ;
153153
154- task -> host_status = task -> super .super . status ;
154+ task -> host_status = task -> super .status ;
155155 if (ctx -> cfg .sync_type == UCC_TL_NCCL_COMPLETION_SYNC_TYPE_EVENT ) {
156156 status = ucc_mc_ee_event_post (stream , task -> completed ,
157157 UCC_EE_CUDA_STREAM );
@@ -166,14 +166,8 @@ ucc_status_t ucc_tl_nccl_collective_sync(ucc_tl_nccl_task_t *task,
166166 }
167167 }
168168
169- status = task -> super .progress (& task -> super );
170- if (status == UCC_INPROGRESS ) {
171- ucc_progress_enqueue (UCC_TL_CORE_CTX (TASK_TEAM (task ))-> pq ,
172- & task -> super );
173- return UCC_OK ;
174- }
175-
176- return ucc_task_complete (& task -> super );
169+ return ucc_progress_queue_enqueue (UCC_TL_CORE_CTX (TASK_TEAM (task ))-> pq ,
170+ & task -> super );
177171}
178172
179173ucc_status_t ucc_tl_nccl_alltoall_start (ucc_coll_task_t * coll_task )
@@ -190,13 +184,13 @@ ucc_status_t ucc_tl_nccl_alltoall_start(ucc_coll_task_t *coll_task)
190184 size_t data_size ;
191185 ucc_rank_t peer ;
192186
193- task -> super .super . status = UCC_INPROGRESS ;
194- data_size = (size_t )(args -> src .info .count / gsize ) *
187+ task -> super .status = UCC_INPROGRESS ;
188+ data_size = (size_t )(args -> src .info .count / gsize ) *
195189 ucc_dt_size (args -> src .info .datatype );
196190 ucc_assert (args -> src .info .count % gsize == 0 );
197191 if (data_size == 0 ) {
198- task -> super .super . status = UCC_OK ;
199- return UCC_OK ;
192+ task -> super .status = UCC_OK ;
193+ return ucc_task_complete ( & task -> super ) ;
200194 }
201195 UCC_TL_NCCL_PROFILE_REQUEST_EVENT (coll_task , "nccl_alltoall_start" , 0 );
202196 NCCLCHECK_GOTO (ncclGroupStart (), exit_coll , status , UCC_TL_TEAM_LIB (team ));
@@ -242,9 +236,9 @@ ucc_status_t ucc_tl_nccl_alltoallv_start(ucc_coll_task_t *coll_task)
242236 size_t sdt_size , rdt_size , count , displ ;
243237 ucc_rank_t peer ;
244238
245- task -> super .super . status = UCC_INPROGRESS ;
246- sdt_size = ucc_dt_size (args -> src .info_v .datatype );
247- rdt_size = ucc_dt_size (args -> dst .info_v .datatype );
239+ task -> super .status = UCC_INPROGRESS ;
240+ sdt_size = ucc_dt_size (args -> src .info_v .datatype );
241+ rdt_size = ucc_dt_size (args -> dst .info_v .datatype );
248242 UCC_TL_NCCL_PROFILE_REQUEST_EVENT (coll_task , "nccl_alltoallv_start" , 0 );
249243 NCCLCHECK_GOTO (ncclGroupStart (), exit_coll , status , UCC_TL_TEAM_LIB (team ));
250244 for (peer = 0 ; peer < UCC_TL_TEAM_SIZE (team ); peer ++ ) {
@@ -304,7 +298,7 @@ ucc_status_t ucc_tl_nccl_allreduce_start(ucc_coll_task_t *coll_task)
304298 ncclDataType_t dt ;
305299
306300 dt = ucc_to_nccl_dtype [UCC_DT_PREDEFINED_ID (args -> dst .info .datatype )];
307- task -> super .super . status = UCC_INPROGRESS ;
301+ task -> super .status = UCC_INPROGRESS ;
308302 UCC_TL_NCCL_PROFILE_REQUEST_EVENT (coll_task ,
309303 args -> coll_type == UCC_COLL_TYPE_BARRIER
310304 ? "nccl_barrier_start"
@@ -356,7 +350,7 @@ ucc_status_t ucc_tl_nccl_allgather_start(ucc_coll_task_t *coll_task)
356350 src = (void * )((ptrdiff_t )args -> dst .info .buffer + (count / size ) *
357351 ucc_dt_size (args -> dst .info .datatype ) * rank );
358352 }
359- task -> super .super . status = UCC_INPROGRESS ;
353+ task -> super .status = UCC_INPROGRESS ;
360354 UCC_TL_NCCL_PROFILE_REQUEST_EVENT (coll_task , "nccl_allgather_start" , 0 );
361355 NCCLCHECK_GOTO (ncclAllGather (src , dst , count / size , dt ,
362356 team -> nccl_comm , stream ),
@@ -411,7 +405,7 @@ ucc_status_t ucc_tl_nccl_bcast_start(ucc_coll_task_t *coll_task)
411405 ncclDataType_t dt ;
412406
413407 dt = ucc_to_nccl_dtype [UCC_DT_PREDEFINED_ID (args -> src .info .datatype )];
414- task -> super .super . status = UCC_INPROGRESS ;
408+ task -> super .status = UCC_INPROGRESS ;
415409 UCC_TL_NCCL_PROFILE_REQUEST_EVENT (coll_task , "nccl_bcast_start" , 0 );
416410 NCCLCHECK_GOTO (ncclBroadcast (src , src , count , dt , root , team -> nccl_comm ,
417411 stream ),
@@ -449,7 +443,7 @@ ucc_status_t ucc_tl_nccl_reduce_scatter_start(ucc_coll_task_t *coll_task)
449443 ncclDataType_t dt ;
450444
451445 dt = ucc_to_nccl_dtype [UCC_DT_PREDEFINED_ID (args -> dst .info .datatype )];
452- task -> super .super . status = UCC_INPROGRESS ;
446+ task -> super .status = UCC_INPROGRESS ;
453447 UCC_TL_NCCL_PROFILE_REQUEST_EVENT (coll_task , "nccl_reduce_scatter_start" , 0 );
454448 if (UCC_IS_INPLACE (* args )) {
455449 count /= UCC_TL_TEAM_SIZE (team );
@@ -507,7 +501,7 @@ ucc_status_t ucc_tl_nccl_reduce_start(ucc_coll_task_t *coll_task)
507501 }
508502 }
509503 nccl_dt = ucc_to_nccl_dtype [UCC_DT_PREDEFINED_ID (ucc_dt )];
510- task -> super .super . status = UCC_INPROGRESS ;
504+ task -> super .status = UCC_INPROGRESS ;
511505 NCCLCHECK_GOTO (ncclReduce (src , dst , count , nccl_dt , op , args -> root ,
512506 team -> nccl_comm , stream ),
513507 exit_coll , status , UCC_TL_TEAM_LIB (team ));
0 commit comments