-
Notifications
You must be signed in to change notification settings - Fork 128
Expand file tree
/
Copy pathalltoallv_onesided.c
More file actions
131 lines (118 loc) · 5.25 KB
/
alltoallv_onesided.c
File metadata and controls
131 lines (118 loc) · 5.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
/**
* Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
#include "config.h"
#include "tl_ucp.h"
#include "alltoallv.h"
#include "core/ucc_progress_queue.h"
#include "utils/ucc_math.h"
#include "tl_ucp_sendrecv.h"
void ucc_tl_ucp_alltoallv_onesided_progress(ucc_coll_task_t *ctask);
/**
 * Start the one-sided alltoallv: for every peer rank, PUT this rank's
 * source slice (selected by the peer's send displacement/count) into the
 * peer's destination buffer at the peer's receive displacement, then
 * atomically increment the peer's pSync counter so the target can detect
 * delivery in the progress function.
 *
 * @param ctask base collective task; must be a ucc_tl_ucp_task_t.
 * @return status of enqueueing onto the progress queue, or the task's
 *         error status if posting a put/atomic failed.
 */
ucc_status_t ucc_tl_ucp_alltoallv_onesided_start(ucc_coll_task_t *ctask)
{
    ucc_tl_ucp_task_t *task = ucc_derived_of(ctask, ucc_tl_ucp_task_t);
    ucc_tl_ucp_team_t *team = TASK_TEAM(task);
    ptrdiff_t src = (ptrdiff_t)TASK_ARGS(task).src.info_v.buffer;
    ptrdiff_t dest = (ptrdiff_t)TASK_ARGS(task).dst.info_v.buffer;
    ucc_memory_type_t mtype = TASK_ARGS(task).src.info_v.mem_type;
    ucc_rank_t grank = UCC_TL_TEAM_RANK(team);
    ucc_rank_t gsize = UCC_TL_TEAM_SIZE(team);
    /* pSync is the remote-completion counter each peer atomically bumps;
     * the progress function waits for it to reach gsize. */
    long *pSync = TASK_ARGS(task).global_work_buffer;
    ucc_aint_t *s_disp = TASK_ARGS(task).src.info_v.displacements;
    ucc_aint_t *d_disp = TASK_ARGS(task).dst.info_v.displacements;
    size_t sdt_size = ucc_dt_size(TASK_ARGS(task).src.info_v.datatype);
    size_t rdt_size = ucc_dt_size(TASK_ARGS(task).dst.info_v.datatype);
    /* Default to the local handle; overridden below when the caller passed
     * a global (per-rank) source memh table. src_memh/dst_memh live in a
     * union, so this read is only meaningful once the mask is checked. */
    ucc_mem_map_mem_h src_memh = TASK_ARGS(task).src_memh.local_memh;
    ucc_mem_map_mem_h *dst_memh = TASK_ARGS(task).dst_memh.global_memh;
    ucc_rank_t peer;
    size_t sd_disp, dd_disp, data_size;

    ucc_tl_ucp_task_reset(task, UCC_INPROGRESS);
    if (TASK_ARGS(task).mask & UCC_COLL_ARGS_FIELD_MEM_MAP_SRC_MEMH) {
        if (TASK_ARGS(task).flags & UCC_COLL_ARGS_FLAG_SRC_MEMH_GLOBAL) {
            /* global table: pick this rank's own entry for the local source */
            src_memh = TASK_ARGS(task).src_memh.global_memh[grank];
        }
    }

    /* perform a put to each member peer using the peer's index in the
     * destination displacement. Iteration starts at grank+1 — presumably to
     * stagger traffic across ranks rather than have everyone target rank 0
     * first. NOTE(review): the loop bound relies on ucc_tl_ucp_put_nb
     * advancing task->onesided.put_posted once per posted put — confirm. */
    for (peer = (grank + 1) % gsize; task->onesided.put_posted < gsize;
         peer = (peer + 1) % gsize) {
        /* byte offset of the slice destined for 'peer' in our source buffer */
        sd_disp =
            ucc_coll_args_get_displacement(&TASK_ARGS(task), s_disp, peer) *
            sdt_size;
        /* byte offset at which 'peer' expects data from us (symmetric layout) */
        dd_disp =
            ucc_coll_args_get_displacement(&TASK_ARGS(task), d_disp, peer) *
            rdt_size;
        data_size =
            ucc_coll_args_get_count(&TASK_ARGS(task),
                                    TASK_ARGS(task).src.info_v.counts, peer) *
            sdt_size;
        UCPCHECK_GOTO(ucc_tl_ucp_put_nb(PTR_OFFSET(src, sd_disp),
                                        PTR_OFFSET(dest, dd_disp),
                                        data_size, mtype, peer, src_memh,
                                        dst_memh, team, task),
                      task, out);
        /* signal delivery: bump the peer's pSync counter after the put */
        UCPCHECK_GOTO(ucc_tl_ucp_atomic_inc(pSync, peer,
                                            dst_memh, team),
                      task, out);
    }
    return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
out:
    /* UCPCHECK_GOTO recorded the failure in task->super.status */
    return task->super.status;
}
/**
 * Progress the one-sided alltoallv: poll until all local puts/atomics have
 * completed and every peer has bumped our pSync counter, then reset the
 * counter for reuse and mark the task complete.
 *
 * @param ctask base collective task; must be a ucc_tl_ucp_task_t.
 */
void ucc_tl_ucp_alltoallv_onesided_progress(ucc_coll_task_t *ctask)
{
    ucc_tl_ucp_task_t *task  = ucc_derived_of(ctask, ucc_tl_ucp_task_t);
    ucc_tl_ucp_team_t *team  = TASK_TEAM(task);
    long              *pSync = TASK_ARGS(task).global_work_buffer;

    if (UCC_INPROGRESS ==
        ucc_tl_ucp_test_onesided(task, UCC_TL_TEAM_SIZE(team))) {
        /* still waiting on local completions or remote signals */
        return;
    }
    /* collective done: clear the sync counter so the buffer can be reused */
    *pSync             = 0;
    task->super.status = UCC_OK;
}
/**
 * Initialize a one-sided alltoallv task. Validates that the caller supplied
 * a global work buffer, memory-mapped buffers, and (when destination memory
 * handles are given) global handles, then allocates the task and wires up
 * the start/progress callbacks.
 *
 * @param coll_args collective arguments wrapper (mask/flags are validated).
 * @param team      base team; must be a ucc_tl_ucp_team_t.
 * @param task_h    [out] receives the initialized collective task.
 * @return UCC_OK on success; UCC_ERR_NOT_SUPPORTED or UCC_ERR_INVALID_PARAM
 *         when the required arguments/flags are missing.
 */
ucc_status_t ucc_tl_ucp_alltoallv_onesided_init(ucc_base_coll_args_t *coll_args,
                                                ucc_base_team_t *team,
                                                ucc_coll_task_t **task_h)
{
    ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t);
    ucc_status_t       status  = UCC_OK;
    ucc_tl_ucp_task_t *task;

    ALLTOALLV_TASK_CHECK(coll_args->args, tl_team);

    /* the pSync counter lives in the caller-provided global work buffer */
    if (!(coll_args->args.mask & UCC_COLL_ARGS_FIELD_GLOBAL_WORK_BUFFER)) {
        tl_error(UCC_TL_TEAM_LIB(tl_team),
                 "global work buffer not provided nor associated with team");
        status = UCC_ERR_NOT_SUPPORTED;
        goto out;
    }

    /* one-sided puts require remotely accessible (mapped) buffers */
    if ((coll_args->args.mask & UCC_COLL_ARGS_FIELD_FLAGS) &&
        !(coll_args->args.flags & UCC_COLL_ARGS_FLAG_MEM_MAPPED_BUFFERS)) {
        tl_error(UCC_TL_TEAM_LIB(tl_team),
                 "non memory mapped buffers are not supported");
        status = UCC_ERR_NOT_SUPPORTED;
        goto out;
    }

    if (!(coll_args->args.mask & UCC_COLL_ARGS_FIELD_MEM_MAP_SRC_MEMH)) {
        coll_args->args.src_memh.global_memh = NULL;
    }

    if (coll_args->args.mask & UCC_COLL_ARGS_FIELD_MEM_MAP_DST_MEMH) {
        /* remote writes need every rank's dst handle, not just the local one */
        if (!(coll_args->args.flags & UCC_COLL_ARGS_FLAG_DST_MEMH_GLOBAL)) {
            tl_error(UCC_TL_TEAM_LIB(tl_team),
                     "onesided alltoallv requires global memory handles for dst "
                     "buffers");
            status = UCC_ERR_INVALID_PARAM;
            goto out;
        }
    } else {
        coll_args->args.dst_memh.global_memh = NULL;
    }

    task                 = ucc_tl_ucp_init_task(coll_args, team);
    *task_h              = &task->super;
    task->super.post     = ucc_tl_ucp_alltoallv_onesided_start;
    task->super.progress = ucc_tl_ucp_alltoallv_onesided_progress;
out:
    return status;
}