@@ -60,6 +60,14 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
6060 ucc_status_t status ;
6161 size_t extra_count ;
6262
63+ uint32_t USE_CUDA = UCC_TL_UCP_TEAM_LIB (team )-> cfg .allgather_use_cuda ;
64+ if (!USE_CUDA ){
65+ if (UCC_INPROGRESS == ucc_tl_ucp_test (task )){
66+ // should I use ucc_tl_ucp_test_with_etasks ?
67+ return ;
68+ }
69+ }
70+
6371 EXEC_TASK_TEST (UCC_KN_PHASE_INIT , "failed during ee task test" ,
6472 task -> allgather_kn .etask );
6573 task -> allgather_kn .etask = NULL ;
@@ -210,23 +218,29 @@ ucc_status_t ucc_tl_ucp_allgather_knomial_start(ucc_coll_task_t *coll_task)
210218 & task -> allgather_kn .p );
211219 offset = ucc_buffer_block_offset (args -> dst .info .count , size , rank ) *
212220 ucc_dt_size (args -> dst .info .datatype );
213- if (!UCC_IS_INPLACE (* args )) {
214- status = ucc_coll_task_get_executor (& task -> super , & exec );
215- if (ucc_unlikely (status != UCC_OK )) {
216- task -> super .status = status ;
217- return status ;
218- }
219- eargs .task_type = UCC_EE_EXECUTOR_TASK_COPY ;
220- eargs .copy .dst = PTR_OFFSET (args -> dst .info .buffer , offset );
221- eargs .copy .src = args -> src .info .buffer ;
222- eargs .copy .len = args -> src .info .count *
223- ucc_dt_size (args -> src .info .datatype );
224- status = ucc_ee_executor_task_post (exec , & eargs ,
225- & task -> allgather_kn .etask );
226- if (ucc_unlikely (status != UCC_OK )) {
227- task -> super .status = status ;
228- return status ;
229- }
221+ if (USE_CUDA ){
222+ status = ucc_coll_task_get_executor (& task -> super , & exec );
223+ if (ucc_unlikely (status != UCC_OK )) {
224+ task -> super .status = status ;
225+ return status ;
226+ }
227+ eargs .task_type = UCC_EE_EXECUTOR_TASK_COPY ;
228+ eargs .copy .dst = PTR_OFFSET (args -> dst .info .buffer , offset );
229+ eargs .copy .src = args -> src .info .buffer ;
230+ eargs .copy .len = args -> src .info .count *
231+ ucc_dt_size (args -> src .info .datatype );
232+ status = ucc_ee_executor_task_post (exec , & eargs ,
233+ & task -> allgather_kn .etask );
234+ if (ucc_unlikely (status != UCC_OK )) {
235+ task -> super .status = status ;
236+ return status ;
237+ }
238+ } else {
239+ /*Loopback*/
240+ UCPCHECK_GOTO (ucc_tl_ucp_send_nb (args -> src .info .buffer , args -> src .info .count * ucc_dt_size (args -> src .info .datatype ),
241+ args -> src .info .mem_type , rank , team , task ),task , out );
242+ UCPCHECK_GOTO (ucc_tl_ucp_recv_nb (PTR_OFFSET (args -> dst .info .buffer , offset ), args -> src .info .count * ucc_dt_size (args -> src .info .datatype ),
243+ args -> dst .info .mem_type , rank , team , task ),task , out );
230244 }
231245 } else {
232246 ucc_kn_agx_pattern_init (size , rank , radix , args -> dst .info .count ,
@@ -430,6 +444,8 @@ ucc_status_t ucc_tl_ucp_allgather_knomial_init_r(
430444 task -> super .post = ucc_tl_ucp_allgather_knomial_start ;
431445 task -> super .progress = ucc_tl_ucp_allgather_knomial_progress ;
432446 task -> super .finalize = ucc_tl_ucp_allgather_knomial_finalize ;
447+ # trigger_post
448+ # trigger_progress
433449 status = register_memory (& task -> super );
434450 if (status < 0 ){
435451 tl_error (UCC_TASK_LIB (task ),
0 commit comments