From b5282de7e2f64e75a118193c2f8369ec6ebb60e9 Mon Sep 17 00:00:00 2001 From: Sergey Lebedev Date: Wed, 14 Jan 2026 08:46:10 +0100 Subject: [PATCH] TL/UCP: use memcpy instead of sendrecv in agv --- .../tl/ucp/allgatherv/allgatherv_ring.c | 89 ++++++++++--------- 1 file changed, 48 insertions(+), 41 deletions(-) diff --git a/src/components/tl/ucp/allgatherv/allgatherv_ring.c b/src/components/tl/ucp/allgatherv/allgatherv_ring.c index 77e26354966..b26ec292367 100644 --- a/src/components/tl/ucp/allgatherv/allgatherv_ring.c +++ b/src/components/tl/ucp/allgatherv/allgatherv_ring.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2021-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -32,33 +32,45 @@ void ucc_tl_ucp_allgatherv_ring_progress(ucc_coll_task_t *coll_task) sendto = ucc_ep_map_eval(task->subset.map, (trank + 1) % tsize); recvfrom = ucc_ep_map_eval(task->subset.map, (trank - 1 + tsize) % tsize); - while (task->tagged.send_posted < tsize) { - send_idx = - ucc_ep_map_eval(task->subset.map, (trank - - task->tagged.send_posted + 1 + - tsize) % tsize); + while (task->tagged.send_posted < tsize - 1) { + send_idx = ucc_ep_map_eval( + task->subset.map, + (trank - task->tagged.send_posted + tsize) % tsize); data_displ = ucc_coll_args_get_displacement( args, args->dst.info_v.displacements, send_idx) * rdt_size; - data_size = - ucc_coll_args_get_count(args, args->dst.info_v.counts, send_idx) * - rdt_size; - UCPCHECK_GOTO(ucc_tl_ucp_send_nb((void *)(rbuf + data_displ), data_size, - rmem, sendto, team, task), - task, out); - recv_idx = - ucc_ep_map_eval(task->subset.map, (trank - - task->tagged.recv_posted + - tsize) % tsize); + data_size = ucc_coll_args_get_count( + args, args->dst.info_v.counts, send_idx) * + rdt_size; + UCPCHECK_GOTO( + ucc_tl_ucp_send_nb( + (void *)(rbuf + data_displ), + data_size, + rmem, + sendto, + team, + task), + task, + out); + recv_idx = ucc_ep_map_eval( + task->subset.map, + (trank - task->tagged.recv_posted - 1 + tsize) % tsize); data_displ = ucc_coll_args_get_displacement( args, args->dst.info_v.displacements, recv_idx) * rdt_size; - data_size = - ucc_coll_args_get_count(args, args->dst.info_v.counts, recv_idx) * - rdt_size; - UCPCHECK_GOTO(ucc_tl_ucp_recv_nb((void *)(rbuf + data_displ), data_size, - rmem, recvfrom, team, task), - task, out); + data_size = ucc_coll_args_get_count( + args, args->dst.info_v.counts, recv_idx) * + rdt_size; + UCPCHECK_GOTO( + ucc_tl_ucp_recv_nb( + (void *)(rbuf + data_displ), + data_size, + rmem, + recvfrom, + team, + task), + task, + out); if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) { return; } @@ -66,7 +78,7 @@ void ucc_tl_ucp_allgatherv_ring_progress(ucc_coll_task_t *coll_task) ucc_assert(UCC_TL_UCP_TASK_P2P_COMPLETE(task)); task->super.status = UCC_OK; out: - return; + UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allgatherv_ring_done", 0); } ucc_status_t ucc_tl_ucp_allgatherv_ring_start(ucc_coll_task_t *coll_task) @@ -80,31 +92,26 @@ ucc_status_t ucc_tl_ucp_allgatherv_ring_start(ucc_coll_task_t *coll_task) ucc_memory_type_t rmem = args->dst.info_v.mem_type; ucc_rank_t grank = UCC_TL_TEAM_RANK(team); size_t data_size, data_displ, rdt_size; + ucc_status_t status; + UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allgatherv_ring_start", 0); ucc_tl_ucp_task_reset(task, UCC_INPROGRESS); if (!UCC_IS_INPLACE(*args)) { - /* TODO replace local sendrecv with memcpy? */ rdt_size = ucc_dt_size(args->dst.info_v.datatype); - data_displ = ucc_coll_args_get_displacement(args, - args->dst.info_v.displacements, grank) * rdt_size; - data_size = ucc_coll_args_get_count(args, - args->dst.info_v.counts, grank) * rdt_size; - UCPCHECK_GOTO(ucc_tl_ucp_recv_nb(PTR_OFFSET(rbuf, data_displ), data_size, - rmem, grank, team, task), - task, error); - UCPCHECK_GOTO(ucc_tl_ucp_send_nb(sbuf, data_size, smem, grank, team, task), - task, error); - } else { - /* to simplify progress fucnction and make it identical for - in-place and non in-place */ - task->tagged.send_posted = task->tagged.send_completed = 1; - task->tagged.recv_posted = task->tagged.recv_completed = 1; + data_displ = ucc_coll_args_get_displacement( + args, args->dst.info_v.displacements, grank) * + rdt_size; + data_size = ucc_coll_args_get_count( + args, args->dst.info_v.counts, grank) * + rdt_size; + status = ucc_mc_memcpy( + PTR_OFFSET(rbuf, data_displ), sbuf, data_size, rmem, smem); + if (ucc_unlikely(UCC_OK != status)) { + return status; + } } - return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super); -error: - return task->super.status; } ucc_status_t ucc_tl_ucp_allgatherv_ring_init_common(ucc_tl_ucp_task_t *task)