Skip to content

Commit ab0eda3

Browse files
authored
Merge pull request #13224 from Sergei-Lebedev/topic/fix_ucc_inplace
coll/ucc: refactor UCC collective operations to handle MPI_IN_PLACE correctly
2 parents f5b0876 + 887e7af commit ab0eda3

9 files changed

+111
-75
lines changed

ompi/mca/coll/ucc/coll_ucc_allgather.c

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,20 @@ static inline ucc_status_t mca_coll_ucc_allgather_init(const void *sbuf, size_t
1515
ucc_coll_req_h *req,
1616
mca_coll_ucc_req_t *coll_req)
1717
{
18-
ucc_datatype_t ucc_sdt, ucc_rdt;
18+
ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8;
19+
bool is_inplace = (MPI_IN_PLACE == sbuf);
1920
int comm_size = ompi_comm_size(ucc_module->comm);
2021

21-
if (!ompi_datatype_is_contiguous_memory_layout(sdtype, scount) ||
22+
if (!(is_inplace || ompi_datatype_is_contiguous_memory_layout(sdtype, scount)) ||
2223
!ompi_datatype_is_contiguous_memory_layout(rdtype, rcount * comm_size)) {
2324
goto fallback;
2425
}
25-
ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype);
26+
2627
ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype);
28+
if (!is_inplace) {
29+
ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype);
30+
}
31+
2732
if (COLL_UCC_DT_UNSUPPORTED == ucc_sdt ||
2833
COLL_UCC_DT_UNSUPPORTED == ucc_rdt) {
2934
UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s",
@@ -50,7 +55,7 @@ static inline ucc_status_t mca_coll_ucc_allgather_init(const void *sbuf, size_t
5055
}
5156
};
5257

53-
if (MPI_IN_PLACE == sbuf) {
58+
if (is_inplace) {
5459
coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
5560
coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
5661
}

ompi/mca/coll/ucc/coll_ucc_allgatherv.c

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,23 @@
99

1010
#include "coll_ucc_common.h"
1111

12-
static inline ucc_status_t mca_coll_ucc_allgatherv_init(const void *sbuf, int scount,
12+
static inline ucc_status_t mca_coll_ucc_allgatherv_init(const void *sbuf, size_t scount,
1313
struct ompi_datatype_t *sdtype,
1414
void* rbuf, ompi_count_array_t rcounts, ompi_disp_array_t rdisps,
1515
struct ompi_datatype_t *rdtype,
1616
mca_coll_ucc_module_t *ucc_module,
1717
ucc_coll_req_h *req,
1818
mca_coll_ucc_req_t *coll_req)
1919
{
20-
ucc_datatype_t ucc_sdt, ucc_rdt;
20+
ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8;
21+
bool is_inplace = (MPI_IN_PLACE == sbuf);
22+
uint64_t flags = 0;
2123

22-
ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype);
2324
ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype);
25+
if (!is_inplace) {
26+
ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype);
27+
}
28+
2429
if (COLL_UCC_DT_UNSUPPORTED == ucc_sdt ||
2530
COLL_UCC_DT_UNSUPPORTED == ucc_rdt) {
2631
UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s",
@@ -29,13 +34,13 @@ static inline ucc_status_t mca_coll_ucc_allgatherv_init(const void *sbuf, int sc
2934
goto fallback;
3035
}
3136

32-
uint64_t flags = ompi_count_array_is_64bit(rcounts) ? UCC_COLL_ARGS_FLAG_COUNT_64BIT : 0;
33-
flags |= ompi_disp_array_is_64bit(rdisps) ? UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT : 0;
37+
flags = (ompi_count_array_is_64bit(rcounts) ? UCC_COLL_ARGS_FLAG_COUNT_64BIT : 0) |
38+
(ompi_disp_array_is_64bit(rdisps) ? UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT : 0) |
39+
(is_inplace ? UCC_COLL_ARGS_FLAG_IN_PLACE : 0);
3440

3541
ucc_coll_args_t coll = {
42+
.mask = flags ? UCC_COLL_ARGS_FIELD_FLAGS : 0,
3643
.flags = flags,
37-
.mask = 0,
38-
.flags = 0,
3944
.coll_type = UCC_COLL_TYPE_ALLGATHERV,
4045
.src.info = {
4146
.buffer = (void*)sbuf,
@@ -52,10 +57,6 @@ static inline ucc_status_t mca_coll_ucc_allgatherv_init(const void *sbuf, int sc
5257
}
5358
};
5459

55-
if (MPI_IN_PLACE == sbuf) {
56-
coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
57-
coll.flags |= UCC_COLL_ARGS_FLAG_IN_PLACE;
58-
}
5960
COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module);
6061
return UCC_OK;
6162
fallback:

ompi/mca/coll/ucc/coll_ucc_alltoall.c

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,20 @@ static inline ucc_status_t mca_coll_ucc_alltoall_init(const void *sbuf, size_t s
1515
ucc_coll_req_h *req,
1616
mca_coll_ucc_req_t *coll_req)
1717
{
18-
ucc_datatype_t ucc_sdt, ucc_rdt;
18+
ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8;
19+
bool is_inplace = (MPI_IN_PLACE == sbuf);
1920
int comm_size = ompi_comm_size(ucc_module->comm);
2021

21-
if (!ompi_datatype_is_contiguous_memory_layout(sdtype, scount * comm_size) ||
22+
if (!(is_inplace || ompi_datatype_is_contiguous_memory_layout(sdtype, scount * comm_size)) ||
2223
!ompi_datatype_is_contiguous_memory_layout(rdtype, rcount * comm_size)) {
2324
goto fallback;
2425
}
25-
ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype);
26+
2627
ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype);
28+
if (!is_inplace) {
29+
ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype);
30+
}
31+
2732
if (COLL_UCC_DT_UNSUPPORTED == ucc_sdt ||
2833
COLL_UCC_DT_UNSUPPORTED == ucc_rdt) {
2934
UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s",
@@ -50,7 +55,7 @@ static inline ucc_status_t mca_coll_ucc_alltoall_init(const void *sbuf, size_t s
5055
}
5156
};
5257

53-
if (MPI_IN_PLACE == sbuf) {
58+
if (is_inplace) {
5459
coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
5560
coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
5661
}

ompi/mca/coll/ucc/coll_ucc_alltoallv.c

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,15 @@ static inline ucc_status_t mca_coll_ucc_alltoallv_init(const void *sbuf, ompi_co
1717
ucc_coll_req_h *req,
1818
mca_coll_ucc_req_t *coll_req)
1919
{
20-
ucc_datatype_t ucc_sdt, ucc_rdt;
20+
ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8;
21+
bool is_inplace = (MPI_IN_PLACE == sbuf);
22+
uint64_t flags = 0;
2123

22-
ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype);
2324
ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype);
25+
if (!is_inplace) {
26+
ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype);
27+
}
28+
2429
if (COLL_UCC_DT_UNSUPPORTED == ucc_sdt ||
2530
COLL_UCC_DT_UNSUPPORTED == ucc_rdt) {
2631
UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s",
@@ -30,13 +35,13 @@ static inline ucc_status_t mca_coll_ucc_alltoallv_init(const void *sbuf, ompi_co
3035
}
3136

3237
/* Assumes that send counts/displs and recv counts/displs are both 32-bit or both 64-bit */
33-
uint64_t flags = ompi_count_array_is_64bit(scounts) ? UCC_COLL_ARGS_FLAG_COUNT_64BIT : 0;
34-
flags |= ompi_disp_array_is_64bit(sdisps) ? UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT : 0;
38+
flags = (ompi_count_array_is_64bit(scounts) ? UCC_COLL_ARGS_FLAG_COUNT_64BIT : 0) |
39+
(ompi_disp_array_is_64bit(sdisps) ? UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT : 0) |
40+
(is_inplace ? UCC_COLL_ARGS_FLAG_IN_PLACE : 0);
3541

3642
ucc_coll_args_t coll = {
43+
.mask = flags ? UCC_COLL_ARGS_FIELD_FLAGS : 0,
3744
.flags = flags,
38-
.mask = 0,
39-
.flags = 0,
4045
.coll_type = UCC_COLL_TYPE_ALLTOALLV,
4146
.src.info_v = {
4247
.buffer = (void*)sbuf,
@@ -54,10 +59,6 @@ static inline ucc_status_t mca_coll_ucc_alltoallv_init(const void *sbuf, ompi_co
5459
}
5560
};
5661

57-
if (MPI_IN_PLACE == sbuf) {
58-
coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
59-
coll.flags |= UCC_COLL_ARGS_FLAG_IN_PLACE;
60-
}
6162
COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module);
6263
return UCC_OK;
6364
fallback:

ompi/mca/coll/ucc/coll_ucc_gather.c

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,27 +17,35 @@ ucc_status_t mca_coll_ucc_gather_init(const void *sbuf, size_t scount, struct om
1717
ucc_coll_req_h *req,
1818
mca_coll_ucc_req_t *coll_req)
1919
{
20-
ucc_datatype_t ucc_sdt, ucc_rdt;
20+
ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8;
21+
bool is_inplace = (MPI_IN_PLACE == sbuf);
2122
int comm_rank = ompi_comm_rank(ucc_module->comm);
2223
int comm_size = ompi_comm_size(ucc_module->comm);
2324

24-
if (!ompi_datatype_is_contiguous_memory_layout(sdtype, scount)) {
25-
goto fallback;
26-
}
27-
ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype);
2825
if (comm_rank == root) {
29-
if (!ompi_datatype_is_contiguous_memory_layout(rdtype, rcount * comm_size)) {
26+
if (!(is_inplace || ompi_datatype_is_contiguous_memory_layout(sdtype, scount)) ||
27+
!ompi_datatype_is_contiguous_memory_layout(rdtype, rcount * comm_size)) {
3028
goto fallback;
3129
}
30+
3231
ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype);
33-
if ((COLL_UCC_DT_UNSUPPORTED == ucc_rdt) ||
34-
(MPI_IN_PLACE != sbuf && COLL_UCC_DT_UNSUPPORTED == ucc_sdt)) {
32+
if (!is_inplace) {
33+
ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype);
34+
}
35+
36+
if ((COLL_UCC_DT_UNSUPPORTED == ucc_sdt) ||
37+
(COLL_UCC_DT_UNSUPPORTED == ucc_rdt)) {
3538
UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s",
36-
(COLL_UCC_DT_UNSUPPORTED == ucc_rdt) ?
37-
rdtype->super.name : sdtype->super.name);
39+
(COLL_UCC_DT_UNSUPPORTED == ucc_sdt) ?
40+
sdtype->super.name : rdtype->super.name);
3841
goto fallback;
3942
}
4043
} else {
44+
if (!ompi_datatype_is_contiguous_memory_layout(sdtype, scount)) {
45+
goto fallback;
46+
}
47+
48+
ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype);
4149
if (COLL_UCC_DT_UNSUPPORTED == ucc_sdt) {
4250
UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s",
4351
sdtype->super.name);
@@ -64,7 +72,7 @@ ucc_status_t mca_coll_ucc_gather_init(const void *sbuf, size_t scount, struct om
6472
},
6573
};
6674

67-
if (MPI_IN_PLACE == sbuf) {
75+
if (is_inplace) {
6876
coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS;
6977
coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
7078
}

ompi/mca/coll/ucc/coll_ucc_gatherv.c

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,34 +17,39 @@ static inline ucc_status_t mca_coll_ucc_gatherv_init(const void *sbuf, size_t sc
1717
ucc_coll_req_h *req,
1818
mca_coll_ucc_req_t *coll_req)
1919
{
20-
ucc_datatype_t ucc_sdt, ucc_rdt;
20+
ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8;
21+
bool is_inplace = (MPI_IN_PLACE == sbuf);
2122
int comm_rank = ompi_comm_rank(ucc_module->comm);
23+
uint64_t flags = 0;
2224

23-
ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype);
2425
if (comm_rank == root) {
2526
ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype);
26-
if ((COLL_UCC_DT_UNSUPPORTED == ucc_rdt) ||
27-
(MPI_IN_PLACE != sbuf && COLL_UCC_DT_UNSUPPORTED == ucc_sdt)) {
27+
if (!is_inplace) {
28+
ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype);
29+
}
30+
if ((COLL_UCC_DT_UNSUPPORTED == ucc_sdt) ||
31+
(COLL_UCC_DT_UNSUPPORTED == ucc_rdt)) {
2832
UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s",
29-
(COLL_UCC_DT_UNSUPPORTED == ucc_rdt) ?
30-
rdtype->super.name : sdtype->super.name);
33+
(COLL_UCC_DT_UNSUPPORTED == ucc_sdt) ?
34+
sdtype->super.name : rdtype->super.name);
3135
goto fallback;
3236
}
3337
} else {
38+
ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype);
3439
if (COLL_UCC_DT_UNSUPPORTED == ucc_sdt) {
3540
UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s",
3641
sdtype->super.name);
3742
goto fallback;
3843
}
3944
}
4045

41-
uint64_t flags = ompi_count_array_is_64bit(rcounts) ? UCC_COLL_ARGS_FLAG_COUNT_64BIT : 0;
42-
flags |= ompi_disp_array_is_64bit(disps) ? UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT : 0;
46+
flags = (ompi_count_array_is_64bit(rcounts) ? UCC_COLL_ARGS_FLAG_COUNT_64BIT : 0) |
47+
(ompi_disp_array_is_64bit(disps) ? UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT : 0) |
48+
(is_inplace ? UCC_COLL_ARGS_FLAG_IN_PLACE : 0);
4349

4450
ucc_coll_args_t coll = {
51+
.mask = flags ? UCC_COLL_ARGS_FIELD_FLAGS : 0,
4552
.flags = flags,
46-
.mask = 0,
47-
.flags = 0,
4853
.coll_type = UCC_COLL_TYPE_GATHERV,
4954
.root = root,
5055
.src.info = {
@@ -62,10 +67,6 @@ static inline ucc_status_t mca_coll_ucc_gatherv_init(const void *sbuf, size_t sc
6267
},
6368
};
6469

65-
if (MPI_IN_PLACE == sbuf) {
66-
coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS;
67-
coll.flags |= UCC_COLL_ARGS_FLAG_IN_PLACE;
68-
}
6970
COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module);
7071
return UCC_OK;
7172
fallback:

ompi/mca/coll/ucc/coll_ucc_reduce_scatter.c

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ ucc_status_t mca_coll_ucc_reduce_scatter_init(const void *sbuf, void *rbuf, ompi
2121
size_t total_count;
2222
int i;
2323
int comm_size = ompi_comm_size(ucc_module->comm);
24+
uint64_t flags = 0;
2425

2526
if (MPI_IN_PLACE == sbuf) {
2627
/* TODO: UCC defines inplace differently:
@@ -46,10 +47,11 @@ ucc_status_t mca_coll_ucc_reduce_scatter_init(const void *sbuf, void *rbuf, ompi
4647
total_count += ompi_count_array_get(rcounts, i);
4748
}
4849

50+
flags = (ompi_count_array_is_64bit(rcounts) ? UCC_COLL_ARGS_FLAG_COUNT_64BIT : 0);
51+
4952
ucc_coll_args_t coll = {
50-
.flags = ompi_count_array_is_64bit(rcounts) ? UCC_COLL_ARGS_FLAG_COUNT_64BIT : 0,
51-
.mask = 0,
52-
.flags = 0,
53+
.mask = flags ? UCC_COLL_ARGS_FIELD_FLAGS : 0,
54+
.flags = flags,
5355
.coll_type = UCC_COLL_TYPE_REDUCE_SCATTERV,
5456
.src.info = {
5557
.buffer = (void*)sbuf,

ompi/mca/coll/ucc/coll_ucc_scatter.c

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,35 @@ ucc_status_t mca_coll_ucc_scatter_init(const void *sbuf, size_t scount,
1818
ucc_coll_req_h *req,
1919
mca_coll_ucc_req_t *coll_req)
2020
{
21-
ucc_datatype_t ucc_sdt, ucc_rdt;
21+
ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8;
22+
bool is_inplace = (MPI_IN_PLACE == rbuf);
2223
int comm_rank = ompi_comm_rank(ucc_module->comm);
2324
int comm_size = ompi_comm_size(ucc_module->comm);
2425

25-
ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype);
2626
if (comm_rank == root) {
27+
if (!(is_inplace || ompi_datatype_is_contiguous_memory_layout(rdtype, rcount)) ||
28+
!ompi_datatype_is_contiguous_memory_layout(sdtype, scount * comm_size)) {
29+
goto fallback;
30+
}
31+
2732
ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype);
33+
if (!is_inplace) {
34+
ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype);
35+
}
36+
2837
if ((COLL_UCC_DT_UNSUPPORTED == ucc_sdt) ||
29-
(MPI_IN_PLACE != rbuf && COLL_UCC_DT_UNSUPPORTED == ucc_rdt)) {
38+
(COLL_UCC_DT_UNSUPPORTED == ucc_rdt)) {
3039
UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s",
3140
(COLL_UCC_DT_UNSUPPORTED == ucc_sdt) ?
3241
sdtype->super.name : rdtype->super.name);
3342
goto fallback;
3443
}
3544
} else {
45+
if (!ompi_datatype_is_contiguous_memory_layout(rdtype, rcount)) {
46+
goto fallback;
47+
}
48+
49+
ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype);
3650
if (COLL_UCC_DT_UNSUPPORTED == ucc_rdt) {
3751
UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s",
3852
rdtype->super.name);
@@ -59,7 +73,7 @@ ucc_status_t mca_coll_ucc_scatter_init(const void *sbuf, size_t scount,
5973
},
6074
};
6175

62-
if (MPI_IN_PLACE == rbuf) {
76+
if (is_inplace) {
6377
coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS;
6478
coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
6579
}

0 commit comments

Comments
 (0)