Skip to content

Commit 89eb485

Browse files
TL/UCP: use memcpy instead of sendrecv in agv
1 parent 09f3e59 commit 89eb485

File tree

1 file changed

+48
-41
lines changed

1 file changed

+48
-41
lines changed

src/components/tl/ucp/allgatherv/allgatherv_ring.c

Lines changed: 48 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
* Copyright (c) 2021-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
33
*
44
* See file LICENSE for terms.
55
*/
@@ -32,41 +32,53 @@ void ucc_tl_ucp_allgatherv_ring_progress(ucc_coll_task_t *coll_task)
3232
sendto = ucc_ep_map_eval(task->subset.map, (trank + 1) % tsize);
3333
recvfrom = ucc_ep_map_eval(task->subset.map, (trank - 1 + tsize) % tsize);
3434

35-
while (task->tagged.send_posted < tsize) {
36-
send_idx =
37-
ucc_ep_map_eval(task->subset.map, (trank -
38-
task->tagged.send_posted + 1 +
39-
tsize) % tsize);
35+
while (task->tagged.send_posted < tsize - 1) {
36+
send_idx = ucc_ep_map_eval(
37+
task->subset.map,
38+
(trank - task->tagged.send_posted + tsize) % tsize);
4039
data_displ = ucc_coll_args_get_displacement(
4140
args, args->dst.info_v.displacements, send_idx) *
4241
rdt_size;
43-
data_size =
44-
ucc_coll_args_get_count(args, args->dst.info_v.counts, send_idx) *
45-
rdt_size;
46-
UCPCHECK_GOTO(ucc_tl_ucp_send_nb((void *)(rbuf + data_displ), data_size,
47-
rmem, sendto, team, task),
48-
task, out);
49-
recv_idx =
50-
ucc_ep_map_eval(task->subset.map, (trank -
51-
task->tagged.recv_posted +
52-
tsize) % tsize);
42+
data_size = ucc_coll_args_get_count(
43+
args, args->dst.info_v.counts, send_idx) *
44+
rdt_size;
45+
UCPCHECK_GOTO(
46+
ucc_tl_ucp_send_nb(
47+
(void *)(rbuf + data_displ),
48+
data_size,
49+
rmem,
50+
sendto,
51+
team,
52+
task),
53+
task,
54+
out);
55+
recv_idx = ucc_ep_map_eval(
56+
task->subset.map,
57+
(trank - task->tagged.recv_posted - 1 + tsize) % tsize);
5358
data_displ = ucc_coll_args_get_displacement(
5459
args, args->dst.info_v.displacements, recv_idx) *
5560
rdt_size;
56-
data_size =
57-
ucc_coll_args_get_count(args, args->dst.info_v.counts, recv_idx) *
58-
rdt_size;
59-
UCPCHECK_GOTO(ucc_tl_ucp_recv_nb((void *)(rbuf + data_displ), data_size,
60-
rmem, recvfrom, team, task),
61-
task, out);
61+
data_size = ucc_coll_args_get_count(
62+
args, args->dst.info_v.counts, recv_idx) *
63+
rdt_size;
64+
UCPCHECK_GOTO(
65+
ucc_tl_ucp_recv_nb(
66+
(void *)(rbuf + data_displ),
67+
data_size,
68+
rmem,
69+
recvfrom,
70+
team,
71+
task),
72+
task,
73+
out);
6274
if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) {
6375
return;
6476
}
6577
}
6678
ucc_assert(UCC_TL_UCP_TASK_P2P_COMPLETE(task));
6779
task->super.status = UCC_OK;
6880
out:
69-
return;
81+
UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allgatherv_ring_done", 0);
7082
}
7183

7284
ucc_status_t ucc_tl_ucp_allgatherv_ring_start(ucc_coll_task_t *coll_task)
@@ -80,31 +92,26 @@ ucc_status_t ucc_tl_ucp_allgatherv_ring_start(ucc_coll_task_t *coll_task)
8092
ucc_memory_type_t rmem = args->dst.info_v.mem_type;
8193
ucc_rank_t grank = UCC_TL_TEAM_RANK(team);
8294
size_t data_size, data_displ, rdt_size;
95+
ucc_status_t status;
8396

97+
UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allgatherv_ring_start", 0);
8498
ucc_tl_ucp_task_reset(task, UCC_INPROGRESS);
8599

86100
if (!UCC_IS_INPLACE(*args)) {
87-
/* TODO replace local sendrecv with memcpy? */
88101
rdt_size = ucc_dt_size(args->dst.info_v.datatype);
89-
data_displ = ucc_coll_args_get_displacement(args,
90-
args->dst.info_v.displacements, grank) * rdt_size;
91-
data_size = ucc_coll_args_get_count(args,
92-
args->dst.info_v.counts, grank) * rdt_size;
93-
UCPCHECK_GOTO(ucc_tl_ucp_recv_nb(PTR_OFFSET(rbuf, data_displ), data_size,
94-
rmem, grank, team, task),
95-
task, error);
96-
UCPCHECK_GOTO(ucc_tl_ucp_send_nb(sbuf, data_size, smem, grank, team, task),
97-
task, error);
98-
} else {
99-
/* to simplify progress fucnction and make it identical for
100-
in-place and non in-place */
101-
task->tagged.send_posted = task->tagged.send_completed = 1;
102-
task->tagged.recv_posted = task->tagged.recv_completed = 1;
102+
data_displ = ucc_coll_args_get_displacement(
103+
args, args->dst.info_v.displacements, grank) *
104+
rdt_size;
105+
data_size = ucc_coll_args_get_count(
106+
args, args->dst.info_v.counts, grank) *
107+
rdt_size;
108+
status = ucc_mc_memcpy(
109+
PTR_OFFSET(rbuf, data_displ), sbuf, data_size, rmem, smem);
110+
if (ucc_unlikely(UCC_OK != status)) {
111+
return status;
112+
}
103113
}
104-
105114
return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
106-
error:
107-
return task->super.status;
108115
}
109116

110117
ucc_status_t ucc_tl_ucp_allgatherv_ring_init_common(ucc_tl_ucp_task_t *task)

0 commit comments

Comments
 (0)