11/**
2- * Copyright (c) 2021-2024 , NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+ * Copyright (c) 2021-2026 , NVIDIA CORPORATION & AFFILIATES. All rights reserved.
33 *
44 * See file LICENSE for terms.
55 */
@@ -32,41 +32,53 @@ void ucc_tl_ucp_allgatherv_ring_progress(ucc_coll_task_t *coll_task)
3232 sendto = ucc_ep_map_eval (task -> subset .map , (trank + 1 ) % tsize );
3333 recvfrom = ucc_ep_map_eval (task -> subset .map , (trank - 1 + tsize ) % tsize );
3434
35- while (task -> tagged .send_posted < tsize ) {
36- send_idx =
37- ucc_ep_map_eval (task -> subset .map , (trank -
38- task -> tagged .send_posted + 1 +
39- tsize ) % tsize );
35+ while (task -> tagged .send_posted < tsize - 1 ) {
36+ send_idx = ucc_ep_map_eval (
37+ task -> subset .map ,
38+ (trank - task -> tagged .send_posted + tsize ) % tsize );
4039 data_displ = ucc_coll_args_get_displacement (
4140 args , args -> dst .info_v .displacements , send_idx ) *
4241 rdt_size ;
43- data_size =
44- ucc_coll_args_get_count (args , args -> dst .info_v .counts , send_idx ) *
45- rdt_size ;
46- UCPCHECK_GOTO (ucc_tl_ucp_send_nb ((void * )(rbuf + data_displ ), data_size ,
47- rmem , sendto , team , task ),
48- task , out );
49- recv_idx =
50- ucc_ep_map_eval (task -> subset .map , (trank -
51- task -> tagged .recv_posted +
52- tsize ) % tsize );
42+ data_size = ucc_coll_args_get_count (
43+ args , args -> dst .info_v .counts , send_idx ) *
44+ rdt_size ;
45+ UCPCHECK_GOTO (
46+ ucc_tl_ucp_send_nb (
47+ (void * )(rbuf + data_displ ),
48+ data_size ,
49+ rmem ,
50+ sendto ,
51+ team ,
52+ task ),
53+ task ,
54+ out );
55+ recv_idx = ucc_ep_map_eval (
56+ task -> subset .map ,
57+ (trank - task -> tagged .recv_posted - 1 + tsize ) % tsize );
5358 data_displ = ucc_coll_args_get_displacement (
5459 args , args -> dst .info_v .displacements , recv_idx ) *
5560 rdt_size ;
56- data_size =
57- ucc_coll_args_get_count (args , args -> dst .info_v .counts , recv_idx ) *
58- rdt_size ;
59- UCPCHECK_GOTO (ucc_tl_ucp_recv_nb ((void * )(rbuf + data_displ ), data_size ,
60- rmem , recvfrom , team , task ),
61- task , out );
61+ data_size = ucc_coll_args_get_count (
62+ args , args -> dst .info_v .counts , recv_idx ) *
63+ rdt_size ;
64+ UCPCHECK_GOTO (
65+ ucc_tl_ucp_recv_nb (
66+ (void * )(rbuf + data_displ ),
67+ data_size ,
68+ rmem ,
69+ recvfrom ,
70+ team ,
71+ task ),
72+ task ,
73+ out );
6274 if (UCC_INPROGRESS == ucc_tl_ucp_test (task )) {
6375 return ;
6476 }
6577 }
6678 ucc_assert (UCC_TL_UCP_TASK_P2P_COMPLETE (task ));
6779 task -> super .status = UCC_OK ;
6880out :
69- return ;
81+ UCC_TL_UCP_PROFILE_REQUEST_EVENT ( coll_task , "ucp_allgatherv_ring_done" , 0 ) ;
7082}
7183
7284ucc_status_t ucc_tl_ucp_allgatherv_ring_start (ucc_coll_task_t * coll_task )
@@ -80,31 +92,26 @@ ucc_status_t ucc_tl_ucp_allgatherv_ring_start(ucc_coll_task_t *coll_task)
8092 ucc_memory_type_t rmem = args -> dst .info_v .mem_type ;
8193 ucc_rank_t grank = UCC_TL_TEAM_RANK (team );
8294 size_t data_size , data_displ , rdt_size ;
95+ ucc_status_t status ;
8396
97+ UCC_TL_UCP_PROFILE_REQUEST_EVENT (coll_task , "ucp_allgatherv_ring_start" , 0 );
8498 ucc_tl_ucp_task_reset (task , UCC_INPROGRESS );
8599
86100 if (!UCC_IS_INPLACE (* args )) {
87- /* TODO replace local sendrecv with memcpy? */
88101 rdt_size = ucc_dt_size (args -> dst .info_v .datatype );
89- data_displ = ucc_coll_args_get_displacement (args ,
90- args -> dst .info_v .displacements , grank ) * rdt_size ;
91- data_size = ucc_coll_args_get_count (args ,
92- args -> dst .info_v .counts , grank ) * rdt_size ;
93- UCPCHECK_GOTO (ucc_tl_ucp_recv_nb (PTR_OFFSET (rbuf , data_displ ), data_size ,
94- rmem , grank , team , task ),
95- task , error );
96- UCPCHECK_GOTO (ucc_tl_ucp_send_nb (sbuf , data_size , smem , grank , team , task ),
97- task , error );
98- } else {
99- /* to simplify progress fucnction and make it identical for
100- in-place and non in-place */
101- task -> tagged .send_posted = task -> tagged .send_completed = 1 ;
102- task -> tagged .recv_posted = task -> tagged .recv_completed = 1 ;
102+ data_displ = ucc_coll_args_get_displacement (
103+ args , args -> dst .info_v .displacements , grank ) *
104+ rdt_size ;
105+ data_size = ucc_coll_args_get_count (
106+ args , args -> dst .info_v .counts , grank ) *
107+ rdt_size ;
108+ status = ucc_mc_memcpy (
109+ PTR_OFFSET (rbuf , data_displ ), sbuf , data_size , rmem , smem );
110+ if (ucc_unlikely (UCC_OK != status )) {
111+ return status ;
112+ }
103113 }
104-
105114 return ucc_progress_queue_enqueue (UCC_TL_CORE_CTX (team )-> pq , & task -> super );
106- error :
107- return task -> super .status ;
108115}
109116
110117ucc_status_t ucc_tl_ucp_allgatherv_ring_init_common (ucc_tl_ucp_task_t * task )
0 commit comments