Skip to content

Commit 20670f8

Browse files
CORE: add ucc task internal status (#441) (#452)
Co-authored-by: Sergey Lebedev <[email protected]>
1 parent 1e9f7af commit 20670f8

38 files changed: +277 / -335 lines changed

src/components/tl/nccl/allgatherv/allgatherv.c

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -62,9 +62,9 @@ ucc_status_t ucc_tl_nccl_allgatherv_p2p_start(ucc_coll_task_t *coll_task)
6262
size_t sdt_size, rdt_size, count, displ;
6363
ucc_rank_t peer;
6464

65-
task->super.super.status = UCC_INPROGRESS;
66-
sdt_size = ucc_dt_size(args->src.info.datatype);
67-
rdt_size = ucc_dt_size(args->dst.info_v.datatype);
65+
task->super.status = UCC_INPROGRESS;
66+
sdt_size = ucc_dt_size(args->src.info.datatype);
67+
rdt_size = ucc_dt_size(args->dst.info_v.datatype);
6868
UCC_TL_NCCL_PROFILE_REQUEST_EVENT(coll_task, "nccl_allgatherv_start", 0);
6969
NCCLCHECK_GOTO(ncclGroupStart(), exit_coll, status, UCC_TL_TEAM_LIB(team));
7070
count = args->src.info.count;
@@ -129,7 +129,7 @@ ucc_status_t ucc_tl_nccl_allgatherv_bcopy_start(ucc_coll_task_t *coll_task)
129129
size_t max_count, rdt_size, sdt_size, displ, scount, rcount;
130130
ucc_rank_t peer;
131131

132-
task->super.super.status = UCC_INPROGRESS;
132+
task->super.status = UCC_INPROGRESS;
133133
UCC_TL_NCCL_PROFILE_REQUEST_EVENT(coll_task, "nccl_allgatherv_start", 0);
134134
max_count = task->allgatherv_bcopy.max_count;
135135
scount = args->src.info.count;
@@ -236,8 +236,8 @@ ucc_status_t ucc_tl_nccl_allgatherv_bcast_start(ucc_coll_task_t *coll_task)
236236
size_t rdt_size, count, displ;
237237
ucc_rank_t peer;
238238

239-
task->super.super.status = UCC_INPROGRESS;
240-
rdt_size = ucc_dt_size(args->dst.info_v.datatype);
239+
task->super.status = UCC_INPROGRESS;
240+
rdt_size = ucc_dt_size(args->dst.info_v.datatype);
241241
UCC_TL_NCCL_PROFILE_REQUEST_EVENT(coll_task, "nccl_allgatherv_start", 0);
242242
NCCLCHECK_GOTO(ncclGroupStart(), exit_coll, status, UCC_TL_TEAM_LIB(team));
243243
for (peer = 0; peer < size; peer++) {

src/components/tl/nccl/tl_nccl_coll.c

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ ucc_status_t ucc_tl_nccl_collective_sync(ucc_tl_nccl_task_t *task,
151151
ucc_status_t status = UCC_OK;
152152
CUresult cu_status;
153153

154-
task->host_status = task->super.super.status;
154+
task->host_status = task->super.status;
155155
if (ctx->cfg.sync_type == UCC_TL_NCCL_COMPLETION_SYNC_TYPE_EVENT) {
156156
status = ucc_mc_ee_event_post(stream, task->completed,
157157
UCC_EE_CUDA_STREAM);
@@ -166,14 +166,8 @@ ucc_status_t ucc_tl_nccl_collective_sync(ucc_tl_nccl_task_t *task,
166166
}
167167
}
168168

169-
status = task->super.progress(&task->super);
170-
if (status == UCC_INPROGRESS) {
171-
ucc_progress_enqueue(UCC_TL_CORE_CTX(TASK_TEAM(task))->pq,
172-
&task->super);
173-
return UCC_OK;
174-
}
175-
176-
return ucc_task_complete(&task->super);
169+
return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(TASK_TEAM(task))->pq,
170+
&task->super);
177171
}
178172

179173
ucc_status_t ucc_tl_nccl_alltoall_start(ucc_coll_task_t *coll_task)
@@ -190,13 +184,13 @@ ucc_status_t ucc_tl_nccl_alltoall_start(ucc_coll_task_t *coll_task)
190184
size_t data_size;
191185
ucc_rank_t peer;
192186

193-
task->super.super.status = UCC_INPROGRESS;
194-
data_size = (size_t)(args->src.info.count / gsize) *
187+
task->super.status = UCC_INPROGRESS;
188+
data_size = (size_t)(args->src.info.count / gsize) *
195189
ucc_dt_size(args->src.info.datatype);
196190
ucc_assert(args->src.info.count % gsize == 0);
197191
if (data_size == 0) {
198-
task->super.super.status = UCC_OK;
199-
return UCC_OK;
192+
task->super.status = UCC_OK;
193+
return ucc_task_complete(&task->super);
200194
}
201195
UCC_TL_NCCL_PROFILE_REQUEST_EVENT(coll_task, "nccl_alltoall_start", 0);
202196
NCCLCHECK_GOTO(ncclGroupStart(), exit_coll, status, UCC_TL_TEAM_LIB(team));
@@ -242,9 +236,9 @@ ucc_status_t ucc_tl_nccl_alltoallv_start(ucc_coll_task_t *coll_task)
242236
size_t sdt_size, rdt_size, count, displ;
243237
ucc_rank_t peer;
244238

245-
task->super.super.status = UCC_INPROGRESS;
246-
sdt_size = ucc_dt_size(args->src.info_v.datatype);
247-
rdt_size = ucc_dt_size(args->dst.info_v.datatype);
239+
task->super.status = UCC_INPROGRESS;
240+
sdt_size = ucc_dt_size(args->src.info_v.datatype);
241+
rdt_size = ucc_dt_size(args->dst.info_v.datatype);
248242
UCC_TL_NCCL_PROFILE_REQUEST_EVENT(coll_task, "nccl_alltoallv_start", 0);
249243
NCCLCHECK_GOTO(ncclGroupStart(), exit_coll, status, UCC_TL_TEAM_LIB(team));
250244
for (peer = 0; peer < UCC_TL_TEAM_SIZE(team); peer++) {
@@ -304,7 +298,7 @@ ucc_status_t ucc_tl_nccl_allreduce_start(ucc_coll_task_t *coll_task)
304298
ncclDataType_t dt;
305299

306300
dt = ucc_to_nccl_dtype[UCC_DT_PREDEFINED_ID(args->dst.info.datatype)];
307-
task->super.super.status = UCC_INPROGRESS;
301+
task->super.status = UCC_INPROGRESS;
308302
UCC_TL_NCCL_PROFILE_REQUEST_EVENT(coll_task,
309303
args->coll_type == UCC_COLL_TYPE_BARRIER
310304
? "nccl_barrier_start"
@@ -356,7 +350,7 @@ ucc_status_t ucc_tl_nccl_allgather_start(ucc_coll_task_t *coll_task)
356350
src = (void *)((ptrdiff_t)args->dst.info.buffer + (count / size) *
357351
ucc_dt_size(args->dst.info.datatype) * rank);
358352
}
359-
task->super.super.status = UCC_INPROGRESS;
353+
task->super.status = UCC_INPROGRESS;
360354
UCC_TL_NCCL_PROFILE_REQUEST_EVENT(coll_task, "nccl_allgather_start", 0);
361355
NCCLCHECK_GOTO(ncclAllGather(src, dst, count / size, dt,
362356
team->nccl_comm, stream),
@@ -411,7 +405,7 @@ ucc_status_t ucc_tl_nccl_bcast_start(ucc_coll_task_t *coll_task)
411405
ncclDataType_t dt;
412406

413407
dt = ucc_to_nccl_dtype[UCC_DT_PREDEFINED_ID(args->src.info.datatype)];
414-
task->super.super.status = UCC_INPROGRESS;
408+
task->super.status = UCC_INPROGRESS;
415409
UCC_TL_NCCL_PROFILE_REQUEST_EVENT(coll_task, "nccl_bcast_start", 0);
416410
NCCLCHECK_GOTO(ncclBroadcast(src, src, count, dt, root, team->nccl_comm,
417411
stream),
@@ -449,7 +443,7 @@ ucc_status_t ucc_tl_nccl_reduce_scatter_start(ucc_coll_task_t *coll_task)
449443
ncclDataType_t dt;
450444

451445
dt = ucc_to_nccl_dtype[UCC_DT_PREDEFINED_ID(args->dst.info.datatype)];
452-
task->super.super.status = UCC_INPROGRESS;
446+
task->super.status = UCC_INPROGRESS;
453447
UCC_TL_NCCL_PROFILE_REQUEST_EVENT(coll_task, "nccl_reduce_scatter_start", 0);
454448
if (UCC_IS_INPLACE(*args)) {
455449
count /= UCC_TL_TEAM_SIZE(team);
@@ -507,7 +501,7 @@ ucc_status_t ucc_tl_nccl_reduce_start(ucc_coll_task_t *coll_task)
507501
}
508502
}
509503
nccl_dt = ucc_to_nccl_dtype[UCC_DT_PREDEFINED_ID(ucc_dt)];
510-
task->super.super.status = UCC_INPROGRESS;
504+
task->super.status = UCC_INPROGRESS;
511505
NCCLCHECK_GOTO(ncclReduce(src, dst, count, nccl_dt, op, args->root,
512506
team->nccl_comm, stream),
513507
exit_coll, status, UCC_TL_TEAM_LIB(team));

src/components/tl/nccl/tl_nccl_context.c

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,33 +10,31 @@
1010
#include "core/ucc_ee.h"
1111

1212

13-
ucc_status_t ucc_tl_nccl_event_collective_progress(ucc_coll_task_t *coll_task)
13+
void ucc_tl_nccl_event_collective_progress(ucc_coll_task_t *coll_task)
1414
{
1515
ucc_tl_nccl_task_t *task = ucc_derived_of(coll_task, ucc_tl_nccl_task_t);
1616
ucc_status_t status;
1717

1818
ucc_assert(task->completed != NULL);
1919
status = ucc_mc_ee_event_test(task->completed, UCC_EE_CUDA_STREAM);
20-
coll_task->super.status = status;
20+
coll_task->status = status;
2121
#ifdef HAVE_PROFILING_TL_NCCL
22-
if (coll_task->super.status == UCC_OK) {
22+
if (coll_task->status == UCC_OK) {
2323
UCC_TL_NCCL_PROFILE_REQUEST_EVENT(coll_task, "nccl_coll_done", 0);
2424
}
2525
#endif
26-
return coll_task->super.status;
2726
}
2827

29-
ucc_status_t ucc_tl_nccl_driver_collective_progress(ucc_coll_task_t *coll_task)
28+
void ucc_tl_nccl_driver_collective_progress(ucc_coll_task_t *coll_task)
3029
{
3130
ucc_tl_nccl_task_t *task = ucc_derived_of(coll_task, ucc_tl_nccl_task_t);
3231

33-
coll_task->super.status = task->host_status;
32+
coll_task->status = task->host_status;
3433
#ifdef HAVE_PROFILING_TL_NCCL
35-
if (coll_task->super.status == UCC_OK) {
34+
if (coll_task->status == UCC_OK) {
3635
UCC_TL_NCCL_PROFILE_REQUEST_EVENT(coll_task, "nccl_coll_done", 0);
3736
}
3837
#endif
39-
return coll_task->super.status;
4038
}
4139

4240
static void ucc_tl_nccl_req_mpool_obj_init(ucc_mpool_t *mp, void *obj,
@@ -82,7 +80,7 @@ static void ucc_tl_nccl_req_mapped_mpool_obj_init(ucc_mpool_t *mp, void *obj,
8280
st = cudaHostGetDevicePointer((void **)(&req->dev_status),
8381
(void *)&req->host_status, 0);
8482
if (st != cudaSuccess) {
85-
req->super.super.status = UCC_ERR_NO_MESSAGE;
83+
req->super.status = UCC_ERR_NO_MESSAGE;
8684
}
8785
req->super.progress = ucc_tl_nccl_driver_collective_progress;
8886
}

src/components/tl/sharp/tl_sharp_coll.c

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ ucc_tl_sharp_mem_deregister(ucc_tl_sharp_context_t *ctx,
115115
return UCC_OK;
116116
}
117117

118-
ucc_status_t ucc_tl_sharp_collective_progress(ucc_coll_task_t *coll_task)
118+
void ucc_tl_sharp_collective_progress(ucc_coll_task_t *coll_task)
119119
{
120120
ucc_tl_sharp_task_t *task = ucc_derived_of(coll_task, ucc_tl_sharp_task_t);
121121
int completed;
@@ -125,18 +125,18 @@ ucc_status_t ucc_tl_sharp_collective_progress(ucc_coll_task_t *coll_task)
125125
if (completed) {
126126
if (TASK_ARGS(task).coll_type == UCC_COLL_TYPE_ALLREDUCE) {
127127
if (!UCC_IS_INPLACE(TASK_ARGS(task))) {
128-
ucc_tl_sharp_mem_deregister(TASK_CTX(task), task->allreduce.s_mem_h);
128+
ucc_tl_sharp_mem_deregister(TASK_CTX(task),
129+
task->allreduce.s_mem_h);
129130
}
130-
ucc_tl_sharp_mem_deregister(TASK_CTX(task), task->allreduce.r_mem_h);
131+
ucc_tl_sharp_mem_deregister(TASK_CTX(task),
132+
task->allreduce.r_mem_h);
131133
}
132134
sharp_coll_req_free(task->req_handle);
133-
coll_task->super.status = UCC_OK;
135+
coll_task->status = UCC_OK;
134136
UCC_TL_SHARP_PROFILE_REQUEST_EVENT(coll_task,
135137
"sharp_collective_done", 0);
136138
}
137139
}
138-
139-
return coll_task->super.status;
140140
}
141141

142142
ucc_status_t ucc_tl_sharp_barrier_start(ucc_coll_task_t *coll_task)
@@ -145,23 +145,17 @@ ucc_status_t ucc_tl_sharp_barrier_start(ucc_coll_task_t *coll_task)
145145
ucc_tl_sharp_team_t *team = TASK_TEAM(task);
146146
int ret;
147147

148-
task->super.super.status = UCC_INPROGRESS;
149148
UCC_TL_SHARP_PROFILE_REQUEST_EVENT(coll_task, "sharp_barrier_start", 0);
150149

151150
ret = sharp_coll_do_barrier_nb(team->sharp_comm, &task->req_handle);
152151
if (ret != SHARP_COLL_SUCCESS) {
153152
tl_error(UCC_TASK_LIB(task), "sharp_coll_do_barrier_nb failed:%s",
154153
sharp_coll_strerror(ret));
155-
coll_task->super.status = UCC_ERR_NO_RESOURCE;
154+
coll_task->status = UCC_ERR_NO_RESOURCE;
156155
return ucc_task_complete(coll_task);
157156
}
158157

159-
if (UCC_INPROGRESS == ucc_tl_sharp_collective_progress(coll_task)) {
160-
ucc_progress_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
161-
return UCC_OK;
162-
}
163-
164-
return ucc_task_complete(coll_task);
158+
return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
165159
}
166160

167161
ucc_status_t ucc_tl_sharp_allreduce_start(ucc_coll_task_t *coll_task)
@@ -177,7 +171,6 @@ ucc_status_t ucc_tl_sharp_allreduce_start(ucc_coll_task_t *coll_task)
177171
size_t data_size;
178172
int ret;
179173

180-
task->super.super.status = UCC_INPROGRESS;
181174
UCC_TL_SHARP_PROFILE_REQUEST_EVENT(coll_task, "sharp_allreduce_start", 0);
182175

183176
sharp_type = ucc_to_sharp_dtype[UCC_DT_PREDEFINED_ID(dt)];
@@ -217,16 +210,11 @@ ucc_status_t ucc_tl_sharp_allreduce_start(ucc_coll_task_t *coll_task)
217210
if (ret != SHARP_COLL_SUCCESS) {
218211
tl_error(UCC_TASK_LIB(task), "sharp_coll_do_allreduce_nb failed:%s",
219212
sharp_coll_strerror(ret));
220-
coll_task->super.status = UCC_ERR_NO_RESOURCE;
213+
coll_task->status = UCC_ERR_NO_RESOURCE;
221214
return ucc_task_complete(coll_task);
222215
}
223216

224-
if (UCC_INPROGRESS == ucc_tl_sharp_collective_progress(coll_task)) {
225-
ucc_progress_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
226-
return UCC_OK;
227-
}
228-
229-
return ucc_task_complete(coll_task);
217+
return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
230218
}
231219

232220
ucc_status_t ucc_tl_sharp_allreduce_init(ucc_tl_sharp_task_t *task)

src/components/tl/ucp/allgather/allgather.c

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,6 @@
77
#include "tl_ucp.h"
88
#include "allgather.h"
99

10-
ucc_status_t ucc_tl_ucp_allgather_ring_start(ucc_coll_task_t *task);
11-
ucc_status_t ucc_tl_ucp_allgather_ring_progress(ucc_coll_task_t *task);
12-
1310
ucc_status_t ucc_tl_ucp_allgather_init(ucc_tl_ucp_task_t *task)
1411
{
1512
if ((!UCC_DT_IS_PREDEFINED((TASK_ARGS(task)).dst.info.datatype)) ||

src/components/tl/ucp/allgather/allgather.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
#include "../tl_ucp_coll.h"
1010

1111
ucc_status_t ucc_tl_ucp_allgather_init(ucc_tl_ucp_task_t *task);
12-
ucc_status_t ucc_tl_ucp_allgather_ring_progress(ucc_coll_task_t *task);
12+
13+
void ucc_tl_ucp_allgather_ring_progress(ucc_coll_task_t *task);
14+
1315
ucc_status_t ucc_tl_ucp_allgather_ring_start(ucc_coll_task_t *task);
1416

1517
/* Uses allgather_kn_radix from config */

src/components/tl/ucp/allgather/allgather_knomial.c

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
task->allgather_kn.phase = _phase; \
2020
} while (0)
2121

22-
ucc_status_t ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
22+
void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
2323
{
2424
ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
2525
ucc_coll_args_t *args = &TASK_ARGS(task);
@@ -64,7 +64,7 @@ ucc_status_t ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
6464
if (KN_NODE_EXTRA == node_type) {
6565
if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) {
6666
SAVE_STATE(UCC_KN_PHASE_EXTRA);
67-
return task->super.super.status;
67+
return;
6868
}
6969
goto out;
7070
}
@@ -111,7 +111,7 @@ ucc_status_t ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
111111
UCC_KN_PHASE_LOOP:
112112
if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) {
113113
SAVE_STATE(UCC_KN_PHASE_LOOP);
114-
return task->super.super.status;
114+
return;
115115
}
116116
ucc_knomial_pattern_next_iteration_backward(p);
117117
}
@@ -128,29 +128,28 @@ ucc_status_t ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
128128
UCC_KN_PHASE_PROXY:
129129
if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) {
130130
SAVE_STATE(UCC_KN_PHASE_PROXY);
131-
return task->super.super.status;
131+
return;
132132
}
133133

134134
out:
135-
task->super.super.status = UCC_OK;
135+
task->super.status = UCC_OK;
136136
UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allgather_kn_done", 0);
137-
return task->super.super.status;
138137
}
139138

140139
ucc_status_t ucc_tl_ucp_allgather_knomial_start(ucc_coll_task_t *coll_task)
141140
{
142141
ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
143142
ucc_coll_args_t *args = &TASK_ARGS(task);
144143
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
145-
ucc_rank_t rank = UCC_TL_TEAM_RANK(team);
146-
ucc_rank_t size = UCC_TL_TEAM_SIZE(team);
144+
ucc_rank_t rank = UCC_TL_TEAM_RANK(team);
145+
ucc_rank_t size = UCC_TL_TEAM_SIZE(team);
147146
ucc_kn_radix_t radix = task->allgather_kn.p.radix;
148147
ucc_rank_t broot = 0;
149148
ucc_status_t status;
150149
ptrdiff_t offset;
151150

152151
UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allgather_kn_start", 0);
153-
ucc_tl_ucp_task_reset(task);
152+
ucc_tl_ucp_task_reset(task, UCC_INPROGRESS);
154153
if (coll_task->bargs.args.coll_type == UCC_COLL_TYPE_BCAST) {
155154
broot = coll_task->bargs.args.root;
156155
rank = VRANK(rank, broot, size);
@@ -173,13 +172,7 @@ ucc_status_t ucc_tl_ucp_allgather_knomial_start(ucc_coll_task_t *coll_task)
173172
}
174173
}
175174
task->allgather_kn.sbuf = PTR_OFFSET(args->dst.info.buffer, offset);
176-
177-
status = ucc_tl_ucp_allgather_knomial_progress(&task->super);
178-
if (UCC_INPROGRESS == status) {
179-
ucc_progress_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
180-
return UCC_OK;
181-
}
182-
return ucc_task_complete(coll_task);
175+
return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
183176
}
184177

185178
ucc_status_t ucc_tl_ucp_allgather_knomial_init_r(

0 commit comments

Comments
 (0)