diff --git a/ompi/mca/coll/ucc/coll_ucc_allgather.c b/ompi/mca/coll/ucc/coll_ucc_allgather.c index 776109c8fd2..c80aebb2a2c 100644 --- a/ompi/mca/coll/ucc/coll_ucc_allgather.c +++ b/ompi/mca/coll/ucc/coll_ucc_allgather.c @@ -15,15 +15,20 @@ static inline ucc_status_t mca_coll_ucc_allgather_init(const void *sbuf, size_t ucc_coll_req_h *req, mca_coll_ucc_req_t *coll_req) { - ucc_datatype_t ucc_sdt, ucc_rdt; + ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; + bool is_inplace = (MPI_IN_PLACE == sbuf); int comm_size = ompi_comm_size(ucc_module->comm); - if (!ompi_datatype_is_contiguous_memory_layout(sdtype, scount) || + if (!(is_inplace || ompi_datatype_is_contiguous_memory_layout(sdtype, scount)) || !ompi_datatype_is_contiguous_memory_layout(rdtype, rcount * comm_size)) { goto fallback; } - ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype); + ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype); + if (!is_inplace) { + ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype); + } + if (COLL_UCC_DT_UNSUPPORTED == ucc_sdt || COLL_UCC_DT_UNSUPPORTED == ucc_rdt) { UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s", @@ -50,7 +55,7 @@ static inline ucc_status_t mca_coll_ucc_allgather_init(const void *sbuf, size_t } }; - if (MPI_IN_PLACE == sbuf) { + if (is_inplace) { coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; } diff --git a/ompi/mca/coll/ucc/coll_ucc_allgatherv.c b/ompi/mca/coll/ucc/coll_ucc_allgatherv.c index 8652a360eb8..1a3ba27f053 100644 --- a/ompi/mca/coll/ucc/coll_ucc_allgatherv.c +++ b/ompi/mca/coll/ucc/coll_ucc_allgatherv.c @@ -17,10 +17,14 @@ static inline ucc_status_t mca_coll_ucc_allgatherv_init(const void *sbuf, size_t ucc_coll_req_h *req, mca_coll_ucc_req_t *coll_req) { - ucc_datatype_t ucc_sdt, ucc_rdt; + ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; + bool is_inplace = (MPI_IN_PLACE == sbuf); - ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype); ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype); + if (!is_inplace) { + ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype); + } + if (COLL_UCC_DT_UNSUPPORTED == ucc_sdt || COLL_UCC_DT_UNSUPPORTED == ucc_rdt) { UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s", @@ -48,7 +52,7 @@ static inline ucc_status_t mca_coll_ucc_allgatherv_init(const void *sbuf, size_t } }; - if (MPI_IN_PLACE == sbuf) { + if (is_inplace) { coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; } diff --git a/ompi/mca/coll/ucc/coll_ucc_alltoall.c b/ompi/mca/coll/ucc/coll_ucc_alltoall.c index 6b921033a57..1fce7b1f733 100644 --- a/ompi/mca/coll/ucc/coll_ucc_alltoall.c +++ b/ompi/mca/coll/ucc/coll_ucc_alltoall.c @@ -15,15 +15,20 @@ static inline ucc_status_t mca_coll_ucc_alltoall_init(const void *sbuf, size_t s ucc_coll_req_h *req, mca_coll_ucc_req_t *coll_req) { - ucc_datatype_t ucc_sdt, ucc_rdt; + ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; + bool is_inplace = (MPI_IN_PLACE == sbuf); int comm_size = ompi_comm_size(ucc_module->comm); - if (!ompi_datatype_is_contiguous_memory_layout(sdtype, scount * comm_size) || + if (!(is_inplace || ompi_datatype_is_contiguous_memory_layout(sdtype, scount * comm_size)) || !ompi_datatype_is_contiguous_memory_layout(rdtype, rcount * comm_size)) { goto fallback; } - ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype); + ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype); + if (!is_inplace) { + ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype); + } + if (COLL_UCC_DT_UNSUPPORTED == ucc_sdt || COLL_UCC_DT_UNSUPPORTED == ucc_rdt) { UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s", @@ -50,7 +55,7 @@ static inline ucc_status_t mca_coll_ucc_alltoall_init(const void *sbuf, size_t s } }; - if (MPI_IN_PLACE == sbuf) { + if (is_inplace) { coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; } diff --git a/ompi/mca/coll/ucc/coll_ucc_alltoallv.c b/ompi/mca/coll/ucc/coll_ucc_alltoallv.c index 06ea98adfe2..53fd0cfa4d7 100644 --- a/ompi/mca/coll/ucc/coll_ucc_alltoallv.c +++ b/ompi/mca/coll/ucc/coll_ucc_alltoallv.c @@ -17,10 +17,14 @@ static inline ucc_status_t mca_coll_ucc_alltoallv_init(const void *sbuf, const i ucc_coll_req_h *req, mca_coll_ucc_req_t *coll_req) { - ucc_datatype_t ucc_sdt, ucc_rdt; + ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; + bool is_inplace = (MPI_IN_PLACE == sbuf); - ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype); ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype); + if (!is_inplace) { + ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype); + } + if (COLL_UCC_DT_UNSUPPORTED == ucc_sdt || COLL_UCC_DT_UNSUPPORTED == ucc_rdt) { UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s", @@ -49,7 +53,7 @@ static inline ucc_status_t mca_coll_ucc_alltoallv_init(const void *sbuf, const i } }; - if (MPI_IN_PLACE == sbuf) { + if (is_inplace) { coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; } diff --git a/ompi/mca/coll/ucc/coll_ucc_gather.c b/ompi/mca/coll/ucc/coll_ucc_gather.c index f74dc25d8a0..8ede6a58e58 100644 --- a/ompi/mca/coll/ucc/coll_ucc_gather.c +++ b/ompi/mca/coll/ucc/coll_ucc_gather.c @@ -17,27 +17,35 @@ ucc_status_t mca_coll_ucc_gather_init(const void *sbuf, size_t scount, struct om ucc_coll_req_h *req, mca_coll_ucc_req_t *coll_req) { - ucc_datatype_t ucc_sdt, ucc_rdt; + ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; + bool is_inplace = (MPI_IN_PLACE == sbuf); int comm_rank = ompi_comm_rank(ucc_module->comm); int comm_size = ompi_comm_size(ucc_module->comm); - if (!ompi_datatype_is_contiguous_memory_layout(sdtype, scount)) { - goto fallback; - } - ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype); if (comm_rank == root) { - if (!ompi_datatype_is_contiguous_memory_layout(rdtype, rcount * comm_size)) { + if (!(is_inplace || ompi_datatype_is_contiguous_memory_layout(sdtype, scount)) || + !ompi_datatype_is_contiguous_memory_layout(rdtype, rcount * comm_size)) { goto fallback; } + ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype); - if ((COLL_UCC_DT_UNSUPPORTED == ucc_rdt) || - (MPI_IN_PLACE != sbuf && COLL_UCC_DT_UNSUPPORTED == ucc_sdt)) { + if (!is_inplace) { + ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype); + } + + if ((COLL_UCC_DT_UNSUPPORTED == ucc_sdt) || + (COLL_UCC_DT_UNSUPPORTED == ucc_rdt)) { UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s", - (COLL_UCC_DT_UNSUPPORTED == ucc_rdt) ? - rdtype->super.name : sdtype->super.name); + (COLL_UCC_DT_UNSUPPORTED == ucc_sdt) ? + sdtype->super.name : rdtype->super.name); goto fallback; } } else { + if (!ompi_datatype_is_contiguous_memory_layout(sdtype, scount)) { + goto fallback; + } + + ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype); if (COLL_UCC_DT_UNSUPPORTED == ucc_sdt) { UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s", sdtype->super.name); @@ -64,7 +72,7 @@ ucc_status_t mca_coll_ucc_gather_init(const void *sbuf, size_t scount, struct om }, }; - if (MPI_IN_PLACE == sbuf) { + if (is_inplace) { coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; } diff --git a/ompi/mca/coll/ucc/coll_ucc_gatherv.c b/ompi/mca/coll/ucc/coll_ucc_gatherv.c index 15023698ffa..13049a76e0f 100644 --- a/ompi/mca/coll/ucc/coll_ucc_gatherv.c +++ b/ompi/mca/coll/ucc/coll_ucc_gatherv.c @@ -17,20 +17,24 @@ static inline ucc_status_t mca_coll_ucc_gatherv_init(const void *sbuf, size_t sc ucc_coll_req_h *req, mca_coll_ucc_req_t *coll_req) { - ucc_datatype_t ucc_sdt, ucc_rdt; + ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; + bool is_inplace = (MPI_IN_PLACE == sbuf); int comm_rank = ompi_comm_rank(ucc_module->comm); - ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype); if (comm_rank == root) { ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype); - if ((COLL_UCC_DT_UNSUPPORTED == ucc_rdt) || - (MPI_IN_PLACE != sbuf && COLL_UCC_DT_UNSUPPORTED == ucc_sdt)) { + if (!is_inplace) { + ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype); + } + if ((COLL_UCC_DT_UNSUPPORTED == ucc_sdt) || + (COLL_UCC_DT_UNSUPPORTED == ucc_rdt)) { UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s", - (COLL_UCC_DT_UNSUPPORTED == ucc_rdt) ? - rdtype->super.name : sdtype->super.name); + (COLL_UCC_DT_UNSUPPORTED == ucc_sdt) ? + sdtype->super.name : rdtype->super.name); goto fallback; } } else { + ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype); if (COLL_UCC_DT_UNSUPPORTED == ucc_sdt) { UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s", sdtype->super.name); @@ -58,7 +62,7 @@ static inline ucc_status_t mca_coll_ucc_gatherv_init(const void *sbuf, size_t sc }, }; - if (MPI_IN_PLACE == sbuf) { + if (is_inplace) { coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; } diff --git a/ompi/mca/coll/ucc/coll_ucc_scatter.c b/ompi/mca/coll/ucc/coll_ucc_scatter.c index 6154e820b40..548ce290bdf 100644 --- a/ompi/mca/coll/ucc/coll_ucc_scatter.c +++ b/ompi/mca/coll/ucc/coll_ucc_scatter.c @@ -18,21 +18,35 @@ ucc_status_t mca_coll_ucc_scatter_init(const void *sbuf, size_t scount, ucc_coll_req_h *req, mca_coll_ucc_req_t *coll_req) { - ucc_datatype_t ucc_sdt, ucc_rdt; + ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; + bool is_inplace = (MPI_IN_PLACE == rbuf); int comm_rank = ompi_comm_rank(ucc_module->comm); int comm_size = ompi_comm_size(ucc_module->comm); - ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype); if (comm_rank == root) { + if (!(is_inplace || ompi_datatype_is_contiguous_memory_layout(rdtype, rcount)) || + !ompi_datatype_is_contiguous_memory_layout(sdtype, scount * comm_size)) { + goto fallback; + } + ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype); + if (!is_inplace) { + ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype); + } + if ((COLL_UCC_DT_UNSUPPORTED == ucc_sdt) || - (MPI_IN_PLACE != rbuf && COLL_UCC_DT_UNSUPPORTED == ucc_rdt)) { + (COLL_UCC_DT_UNSUPPORTED == ucc_rdt)) { UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s", (COLL_UCC_DT_UNSUPPORTED == ucc_sdt) ? sdtype->super.name : rdtype->super.name); goto fallback; } } else { + if (!ompi_datatype_is_contiguous_memory_layout(rdtype, rcount)) { + goto fallback; + } + + ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype); if (COLL_UCC_DT_UNSUPPORTED == ucc_rdt) { UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s", rdtype->super.name); @@ -59,7 +73,7 @@ ucc_status_t mca_coll_ucc_scatter_init(const void *sbuf, size_t scount, }, }; - if (MPI_IN_PLACE == rbuf) { + if (is_inplace) { coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; } diff --git a/ompi/mca/coll/ucc/coll_ucc_scatterv.c b/ompi/mca/coll/ucc/coll_ucc_scatterv.c index 848508e8bc6..738aa14a953 100644 --- a/ompi/mca/coll/ucc/coll_ucc_scatterv.c +++ b/ompi/mca/coll/ucc/coll_ucc_scatterv.c @@ -18,22 +18,25 @@ ucc_status_t mca_coll_ucc_scatterv_init(const void *sbuf, const int *scounts, ucc_coll_req_h *req, mca_coll_ucc_req_t *coll_req) { - ucc_datatype_t ucc_sdt, ucc_rdt; + ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; + bool is_inplace = (MPI_IN_PLACE == rbuf); int comm_rank = ompi_comm_rank(ucc_module->comm); - int comm_size = ompi_comm_size(ucc_module->comm); - ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype); if (comm_rank == root) { ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype); + if (!is_inplace) { + ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype); + } + if ((COLL_UCC_DT_UNSUPPORTED == ucc_sdt) || - (MPI_IN_PLACE != rbuf && COLL_UCC_DT_UNSUPPORTED == ucc_rdt)) { + (COLL_UCC_DT_UNSUPPORTED == ucc_rdt)) { UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s", (COLL_UCC_DT_UNSUPPORTED == ucc_sdt) ? sdtype->super.name : rdtype->super.name); goto fallback; } - } else { + ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype); if (COLL_UCC_DT_UNSUPPORTED == ucc_rdt) { UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s", rdtype->super.name); @@ -61,7 +64,7 @@ ucc_status_t mca_coll_ucc_scatterv_init(const void *sbuf, const int *scounts, }, }; - if (MPI_IN_PLACE == rbuf) { + if (is_inplace) { coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; }