diff --git a/ompi/mca/coll/acoll/Makefile.am b/ompi/mca/coll/acoll/Makefile.am
index fdbd7edbbd2..705ed5bf479 100644
--- a/ompi/mca/coll/acoll/Makefile.am
+++ b/ompi/mca/coll/acoll/Makefile.am
@@ -15,6 +15,7 @@ sources = \
         coll_acoll_allgather.c \
         coll_acoll_bcast.c \
         coll_acoll_gather.c \
+        coll_acoll_alltoall.c \
         coll_acoll_reduce.c \
         coll_acoll_allreduce.c \
         coll_acoll_barrier.c \
diff --git a/ompi/mca/coll/acoll/coll_acoll.h b/ompi/mca/coll/acoll/coll_acoll.h
index 91f9a2475fa..e5ec8a381a8 100644
--- a/ompi/mca/coll/acoll/coll_acoll.h
+++ b/ompi/mca/coll/acoll/coll_acoll.h
@@ -1,6 +1,6 @@
 /* -*- Mode: C; c-basic-offset:4 ; indent-tabs-mode:nil -*- */
 /*
- * Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2024 - 2025 Advanced Micro Devices, Inc. All rights reserved.
  * $COPYRIGHT$
  *
  * Additional copyrights may follow
@@ -66,6 +66,13 @@ int mca_coll_acoll_gather_intra(const void *sbuf, size_t scount, struct ompi_dat
                                 void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype, int root,
                                 struct ompi_communicator_t *comm, mca_coll_base_module_t *module);
 
+int mca_coll_acoll_alltoall(const void *sbuf, size_t scount,
+                            struct ompi_datatype_t *sdtype,
+                            void* rbuf, size_t rcount,
+                            struct ompi_datatype_t *rdtype,
+                            struct ompi_communicator_t *comm,
+                            mca_coll_base_module_t *module);
+
 int mca_coll_acoll_reduce_intra(const void *sbuf, void *rbuf, size_t count,
                                 struct ompi_datatype_t *dtype, struct ompi_op_t *op, int root,
                                 struct ompi_communicator_t *comm, mca_coll_base_module_t *module);
@@ -80,6 +87,8 @@ int mca_coll_acoll_barrier_intra(struct ompi_communicator_t *comm, mca_coll_base
 END_C_DECLS
 
 #define MCA_COLL_ACOLL_ROOT_CHANGE_THRESH 10
+#define MCA_COLL_ACOLL_SPLIT_FACTOR_LIST_LEN 6
+#define MCA_COLL_ACOLL_SPLIT_FACTOR_LIST {2, 4, 8, 16, 32, 64}
 
 typedef enum MCA_COLL_ACOLL_SG_SIZES {
     MCA_COLL_ACOLL_SG_SIZE_1 = 8,
@@ -142,6 +151,18 @@ typedef struct coll_acoll_data {
     int sync[2];
 } coll_acoll_data_t;
 
+/* Explicit values are assigned to the enum literals because they are used
+ * as array indices; this keeps the mapping valid irrespective of what the
+ * compiler would otherwise assign. */
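+/* For example, mca_coll_acoll_get_msg_thresh() in coll_acoll_alltoall.c
+ * declares a table msg_thres[DIST_END] and indexes it with subc->r2r_dist,
+ * so these literals double as stable array offsets. */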
+typedef enum MCA_COLL_ACOLL_R2R_DIST {
+    DIST_CORE = 0,
+    DIST_L3CACHE,
+    DIST_NUMA,
+    DIST_SOCKET,
+    DIST_NODE,
+    DIST_END
+} MCA_COLL_ACOLL_R2R_DIST_T;
+
 typedef struct coll_acoll_subcomms {
     ompi_communicator_t *local_comm;
     ompi_communicator_t *local_r_comm;
@@ -152,6 +173,7 @@ typedef struct coll_acoll_subcomms {
     ompi_communicator_t *orig_comm;
     ompi_communicator_t *socket_comm;
     ompi_communicator_t *socket_ldr_comm;
+    ompi_communicator_t *split_comm[MCA_COLL_ACOLL_SPLIT_FACTOR_LIST_LEN]; // AllToAll odd even split comm
     int num_nodes;
     int derived_node_size;
     int is_root_node;
@@ -170,6 +192,7 @@ typedef struct coll_acoll_subcomms {
     int initialized;
     int prev_init_root;
     int num_root_change;
+    MCA_COLL_ACOLL_R2R_DIST_T r2r_dist;
 
     ompi_communicator_t *numa_comm_ldrs;
     ompi_communicator_t *node_comm;
@@ -193,6 +216,12 @@ typedef struct coll_acoll_reserve_mem {
     bool reserve_mem_in_use;
 } coll_acoll_reserve_mem_t;
 
+typedef struct {
+    int split_factor;
+    size_t psplit_msg_thresh;
+    size_t xpmem_msg_thresh;
+} coll_acoll_alltoall_attr_t;
+
 struct mca_coll_acoll_module_t {
     mca_coll_base_module_t super;
     MCA_COLL_ACOLL_SG_SIZES sg_size;
@@ -218,6 +247,7 @@ struct mca_coll_acoll_module_t {
     coll_acoll_subcomms_t **subc;
     coll_acoll_reserve_mem_t reserve_mem_s;
     int num_subc;
+    coll_acoll_alltoall_attr_t alltoall_attr;
 };
 
 #ifdef HAVE_XPMEM_H
diff --git a/ompi/mca/coll/acoll/coll_acoll_alltoall.c b/ompi/mca/coll/acoll/coll_acoll_alltoall.c
new file mode 100644
index 00000000000..ab58a38d503
--- /dev/null
+++ b/ompi/mca/coll/acoll/coll_acoll_alltoall.c
@@ -0,0 +1,568 @@
+/*
+ * Copyright (c) 2024 - 2025 Advanced Micro Devices, Inc. All rights reserved.
+ * $COPYRIGHT$
+ *
+ * Additional copyrights may follow
+ *
+ * $HEADER$
+ */
+
+#include "ompi_config.h"
+
+#include "mpi.h"
+#include "ompi/constants.h"
+#include "ompi/datatype/ompi_datatype.h"
+#include "ompi/mca/coll/coll.h"
+#include "ompi/mca/coll/base/coll_tags.h"
+#include "ompi/mca/coll/base/coll_base_functions.h"
+#include "ompi/mca/coll/base/coll_base_util.h"
+#include "coll_acoll.h"
+#include "ompi/mca/pml/pml.h"
+#include "opal/util/bit_ops.h"
+#include "coll_acoll_utils.h"
+
+static void mca_coll_acoll_get_split_factor_and_base_algo
+                (size_t scount, struct ompi_datatype_t *sdtype,
+                 size_t rcount, struct ompi_datatype_t *rdtype,
+                 bool is_inplace,
+                 struct ompi_communicator_t *comm,
+                 bool* sync_enable,
+                 int* split_factor)
+{
+    (*sync_enable) = false;
+    (*split_factor) = 2;
+
+    size_t dsize = 0;
+    size_t total_dsize = 0;
+
+    int comm_size = ompi_comm_size(comm);
+
+    if (false == is_inplace) {
+        ompi_datatype_type_size(sdtype, &dsize);
+        total_dsize = dsize * (ptrdiff_t)scount;
+    } else {
+        ompi_datatype_type_size(rdtype, &dsize);
+        total_dsize = dsize * (ptrdiff_t)rcount;
+    }
+
+    if (comm_size <= 8) {
+        if (total_dsize <= 128) {
+            (*sync_enable) = true;
+        } else {
+            (*sync_enable) = false;
+        }
+        (*split_factor) = 2;
+    } else if (comm_size <= 16) {
+        if (total_dsize <= 192) {
+            (*sync_enable) = true;
+            (*split_factor) = 4;
+        } else if (total_dsize <= 512) {
+            (*sync_enable) = false;
+            (*split_factor) = 4;
+        } else if (total_dsize <= 4096) {
+            (*sync_enable) = false;
+            (*split_factor) = 2;
+        } else {
+            (*sync_enable) = true;
+            (*split_factor) = 2;
+        }
+    } else if (comm_size <= 24) {
+        if (total_dsize <= 64) {
+            (*sync_enable) = true;
+            (*split_factor) = 4;
+        } else if (total_dsize <= 1024) {
+            (*sync_enable) = false;
+            (*split_factor) = 4;
+        } else {
+            (*sync_enable) = false;
+            (*split_factor) = 2;
+        }
+    } else if (comm_size <= 32) {
+        if (total_dsize <= 64) {
+            (*sync_enable) = true;
+            (*split_factor) = 4;
+        } else if (total_dsize <= 1024) {
+            (*sync_enable) = false;
+            (*split_factor) = 4;
+        } else if (total_dsize <= 4096) {
+            (*sync_enable) = false;
+            (*split_factor) = 2;
+        } else {
+            (*sync_enable) = true;
+            (*split_factor) = 2;
+        }
+    } else if (comm_size <= 48) {
+        if (total_dsize <= 64) {
+            (*sync_enable) = true;
+            (*split_factor) = 4;
+        } else if (total_dsize <= 1024) {
+            (*sync_enable) = false;
+            (*split_factor) = 4;
+        } else {
+            (*sync_enable) = false;
+            (*split_factor) = 2;
+        }
+    } else if (comm_size <= 64) {
+        if (total_dsize <= 64) {
+            (*sync_enable) = true;
+            (*split_factor) = 4;
+        } else if (total_dsize <= 1024) {
+            (*sync_enable) = false;
+            (*split_factor) = 4;
+        } else {
+            (*sync_enable) = false;
+            (*split_factor) = 2;
+        }
+    } else if (comm_size <= 72) {
+        if (total_dsize <= 64) {
+            (*sync_enable) = true;
+            (*split_factor) = 4;
+        } else if (total_dsize <= 1024) {
+            (*sync_enable) = false;
+            (*split_factor) = 4;
+        } else {
+            (*sync_enable) = false;
+            (*split_factor) = 2;
+        }
+    } else if (comm_size <= 96) {
+        if (total_dsize <= 64) {
+            (*sync_enable) = true;
+            (*split_factor) = 4;
+        } else if (total_dsize <= 1024) {
+            (*sync_enable) = false;
+            (*split_factor) = 4;
+        } else {
+            (*sync_enable) = false;
+            (*split_factor) = 2;
+        }
+    } else if (comm_size <= 128) {
+        if (total_dsize <= 64) {
+            (*sync_enable) = true;
+            (*split_factor) = 8;
+        } else if (total_dsize <= 512) {
+            (*sync_enable) = false;
+            (*split_factor) = 8;
+        } else {
+            (*sync_enable) = false;
+            (*split_factor) = 2;
+        }
+    } else {
+        if (total_dsize <= 32) {
+            (*sync_enable) = true;
+            (*split_factor) = 8;
+        } else if (total_dsize <= 2048) {
+            (*sync_enable) = false;
+            (*split_factor) = 8;
+        } else if (total_dsize <= 8192) {
+            (*sync_enable) = false;
+            (*split_factor) = 2;
+        } else {
+            (*sync_enable) = true;
+            (*split_factor) = 2;
+        }
+    }
+
+    /* Communicator sizes that are not a multiple of the split factor are
+     * supported only when comm_size % split_factor <= 1, so reduce the
+     * factor until that holds. The split factor must be a power of 2,
+     * otherwise behavior is undefined. */
+    while ((2 < (*split_factor)) &&
+           (1 < (comm_size % (*split_factor)))) {
+        (*split_factor) = (*split_factor) / 2;
+    }
+}
+
+/* Derive the upper bound on message size below which the parallel split
+ * algorithm is applicable, based on the rank-to-rank distance. */
+static inline size_t mca_coll_acoll_get_msg_thresh(coll_acoll_subcomms_t *subc,
+                                                   mca_coll_acoll_module_t *acoll_module)
+{
+    size_t msg_thres[DIST_END] = {4096, 2048, 1024, 1024, 512};
+    size_t dsize_thresh = msg_thres[subc->r2r_dist];
+
+    /* Override if the associated mca param is set. */
+    if (0 < (acoll_module->alltoall_attr).psplit_msg_thresh) {
+        dsize_thresh = (acoll_module->alltoall_attr).psplit_msg_thresh;
+    }
+
+    return dsize_thresh;
+}
+
+static int mca_coll_acoll_last_rank_scatter_gather
+               (const void *sbuf, size_t scount,
+                struct ompi_datatype_t *sdtype,
+                void* rbuf, size_t rcount,
+                struct ompi_datatype_t *rdtype,
+                char* work_buf,
+                struct ompi_communicator_t *comm,
+                mca_coll_acoll_module_t *acoll_module)
+{
+    int error;
+    int rank = ompi_comm_rank(comm);
+    int size = ompi_comm_size(comm);
+    MPI_Aint sbext, sblb;
+    MPI_Aint rext, rlb;
+    error = ompi_datatype_get_extent(sdtype, &sblb, &sbext);
+    if (MPI_SUCCESS != error) {
+        return error;
+    }
+    error = ompi_datatype_get_extent(rdtype, &rlb, &rext);
+    if (MPI_SUCCESS != error) {
+        return error;
+    }
+
+    MPI_Status status;
+    int subgrp_size = acoll_module->sg_cnt;
+
+    /* Scatter/Gather fused code. */
+    /* The last rank does a scatter and a gather to the sub group leaders. */
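+    /* For example, with size == 9 and subgrp_size == 4, rank 8 sends
+     * blocks 0-3 to leader 0 and blocks 4-7 to leader 4; each leader
+     * forwards one block to every other rank in its sub group, gathers
+     * the replies, and ships them back to rank 8 in a single message. */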
+    if ((size - 1) == rank) {
+        /* The last block of data belongs to this last rank, and copying it
+         * to rbuf from sbuf suffices. */
+        error = ompi_datatype_sndrcv(
+                    (char*)sbuf + ((size - 1) * scount * sbext),
+                    scount, sdtype,
+                    (char*)rbuf + ((size - 1) * rcount * rext),
+                    rcount, rdtype);
+        if (MPI_SUCCESS != error) { goto error_handler; }
+
+        /* Scattering data to the sub group leaders, with sub group size
+         * worth of data. */
+        for (int cur_rank = 0; cur_rank < (size - 1); cur_rank += subgrp_size) {
+            int sg_scount = ((cur_rank + subgrp_size) >= size) ?
+                                ((size - (cur_rank + 1)) * scount) :
+                                (scount * subgrp_size);
+            error = MCA_PML_CALL(send((char*)sbuf + (cur_rank * scount * sbext),
+                                      sg_scount, sdtype, cur_rank,
+                                      MCA_COLL_BASE_TAG_ALLTOALL,
+                                      MCA_PML_BASE_SEND_STANDARD, comm));
+            if (MPI_SUCCESS != error) { goto error_handler; }
+        }
+
+        /* Gathering data from the sub group leaders, with sub group size
+         * worth of data. */
+        for (int cur_rank = 0; cur_rank < (size - 1); cur_rank += subgrp_size) {
+            int sg_rcount = ((cur_rank + subgrp_size) >= size) ?
+                                ((size - (cur_rank + 1)) * rcount) :
+                                (rcount * subgrp_size);
+            error = MCA_PML_CALL(recv((char*)rbuf + (cur_rank * rcount * rext),
+                                      sg_rcount, rdtype, cur_rank,
+                                      MCA_COLL_BASE_TAG_ALLTOALL, comm, &status));
+            if (MPI_SUCCESS != error) { goto error_handler; }
+        }
+    } else {
+        /* The 0th rank within a sub group is considered the sub group leader. */
+        if (0 == (rank % subgrp_size)) {
+            /* Receive sub group specific data from the last rank. */
+            int sg_rcount = ((rank + subgrp_size) >= size) ?
+                                ((size - (rank + 1)) * rcount) :
+                                (rcount * subgrp_size);
+            error = MCA_PML_CALL(recv(work_buf,
+                                      sg_rcount, rdtype, size - 1,
+                                      MCA_COLL_BASE_TAG_ALLTOALL, comm, &status));
+            if (MPI_SUCCESS != error) { goto error_handler; }
+
+            int end_rank = ((rank + subgrp_size) >= size) ?
+                               (size - 1) : (rank + subgrp_size);
+
+            /* The data received from the last rank is distributed within the
+             * sub group. */
+            error = ompi_datatype_copy_content_same_ddt(rdtype, rcount,
+                        (char*)rbuf + ((size - 1) * rcount * rext),
+                        (char*)work_buf);
+            if (MPI_SUCCESS != error) { goto error_handler; }
+
+            for (int cur_rank = rank + 1; cur_rank < end_rank; ++cur_rank) {
+                error = MCA_PML_CALL(send(((char*)work_buf +
+                                           ((cur_rank % subgrp_size) * rcount * rext)),
+                                          rcount, rdtype, cur_rank,
+                                          MCA_COLL_BASE_TAG_ALLTOALL,
+                                          MCA_PML_BASE_SEND_STANDARD, comm));
+                if (MPI_SUCCESS != error) { goto error_handler; }
+            }
+
+            /* The sub group leader gathers the data for the last rank from the
+             * sub group and then sends it to the last rank. */
+            error = ompi_datatype_sndrcv(
+                        (char*)sbuf + ((size - 1) * scount * sbext),
+                        scount, sdtype,
+                        (char*)work_buf,
+                        rcount, rdtype);
+            if (MPI_SUCCESS != error) { goto error_handler; }
+
+            for (int cur_rank = rank + 1; cur_rank < end_rank; ++cur_rank) {
+                error = MCA_PML_CALL(recv(((char*)work_buf +
+                                           ((cur_rank % subgrp_size) * rcount * rext)),
+                                          rcount, rdtype, cur_rank,
+                                          MCA_COLL_BASE_TAG_ALLTOALL, comm, &status));
+                if (MPI_SUCCESS != error) { goto error_handler; }
+            }
+
+            int sg_rscount = ((rank + subgrp_size) >= size) ?
+                                 ((size - (rank + 1)) * rcount) :
+                                 (rcount * subgrp_size);
+            error = MCA_PML_CALL(send(work_buf,
+                                      sg_rscount, rdtype, size - 1,
+                                      MCA_COLL_BASE_TAG_ALLTOALL,
+                                      MCA_PML_BASE_SEND_STANDARD, comm));
+            if (MPI_SUCCESS != error) { goto error_handler; }
+        } else {
+            /* The leaf ranks send/receive the data for/from the last rank
+             * to/from the sub group leader. */
+            int sg_ldr_rank = ((rank / subgrp_size) * subgrp_size);
+
+            error = MCA_PML_CALL(recv((char*)rbuf + ((size - 1) * rcount * rext),
+                                      rcount, rdtype, sg_ldr_rank,
+                                      MCA_COLL_BASE_TAG_ALLTOALL, comm, &status));
+            if (MPI_SUCCESS != error) { goto error_handler; }
+
+            error = MCA_PML_CALL(send((char*)sbuf + ((size - 1) * scount * sbext),
+                                      scount, sdtype, sg_ldr_rank,
+                                      MCA_COLL_BASE_TAG_ALLTOALL,
+                                      MCA_PML_BASE_SEND_STANDARD, comm));
+            if (MPI_SUCCESS != error) { goto error_handler; }
+        }
+    }
+
+error_handler:
+
+    return error;
+}
+
+/* Dispatch to one of the base linear alltoall algorithms, with or without
+ * per-request synchronization. */
+static inline int mca_coll_acoll_base_alltoall_dispatcher
+                      (const void *sbuf, size_t scount,
+                       struct ompi_datatype_t *sdtype,
+                       void* rbuf, size_t rcount,
+                       struct ompi_datatype_t *rdtype,
+                       struct ompi_communicator_t *comm,
+                       mca_coll_acoll_module_t *acoll_module,
+                       bool sync_enable)
+{
+    int error;
+
+    if (sync_enable) {
+        error = ompi_coll_base_alltoall_intra_linear_sync
+                    ((char*)sbuf, scount, sdtype,
+                     (char*)rbuf, rcount, rdtype,
+                     comm, &acoll_module->super, 0);
+    } else {
+        error = ompi_coll_base_alltoall_intra_basic_linear
+                    ((char*)sbuf, scount, sdtype,
+                     (char*)rbuf, rcount, rdtype,
+                     comm, &acoll_module->super);
+    }
+    return error;
+}
+
+static inline int mca_coll_acoll_exchange_data
+                      (const void *sbuf, size_t scount,
+                       struct ompi_datatype_t *sdtype,
+                       void* rbuf, size_t rcount,
+                       struct ompi_datatype_t *rdtype,
+                       char* work_buf,
+                       struct ompi_communicator_t *comm,
+                       mca_coll_acoll_module_t *acoll_module,
+                       int grp_split_f)
+{
+    /* sbuf is not used, but kept to maintain uniform arguments. */
+    (void) sbuf;
+    (void) scount;
+    (void) sdtype;
+
+    int error;
+    int rank = ompi_comm_rank(comm);
+    int size = ompi_comm_size(comm);
+    MPI_Aint rext, rlb;
+    error = ompi_datatype_get_extent(rdtype, &rlb, &rext);
+    if (MPI_SUCCESS != error) {
+        return error;
+    }
+
+    /* Exchange data among groups of split factor (2, 4, or 8) consecutive
+     * ranks. */
+    int ps_grp_size = grp_split_f;
+    int ps_grp_start_rank = (rank / ps_grp_size) * ps_grp_size;
+    int ps_grp_num_ranks = size / ps_grp_size;
+    size_t ps_grp_rcount = ps_grp_num_ranks * rcount;
+    size_t ps_grp_rcount_ext = ps_grp_rcount * rext;
+    size_t ps_grp_buf_copy_stride = ps_grp_size * rcount * rext;
+
+    /* Create a new datatype that iterates over the send buffer in strides
+     * of ps_grp_size * rcount. */
+    struct ompi_datatype_t *new_ddt;
+    ompi_datatype_create_vector(ps_grp_num_ranks, rcount,
+                                (rcount * ps_grp_size),
+                                rdtype, &new_ddt);
+    error = ompi_datatype_commit(&new_ddt);
+    if (MPI_SUCCESS != error) { goto error_handler; }
+
+    for (int iter = 1; iter < ps_grp_size; ++iter) {
+        int next_rank = ps_grp_start_rank + ((rank + iter) % ps_grp_size);
+        int prev_rank = ps_grp_start_rank +
+                        ((rank + ps_grp_size - iter) % ps_grp_size);
+        int read_pos = ((rank + iter) % ps_grp_size);
+
+        error = ompi_coll_base_sendrecv
+                    ((char*)rbuf + ((ptrdiff_t)read_pos * rcount * rext),
+                     1, new_ddt, next_rank,
+                     MCA_COLL_BASE_TAG_ALLTOALL,
+                     (char*)work_buf + ((iter - 1) * ps_grp_rcount_ext),
+                     ps_grp_rcount, rdtype, prev_rank,
+                     MCA_COLL_BASE_TAG_ALLTOALL,
+                     comm, MPI_STATUS_IGNORE, rank);
+        if (MPI_SUCCESS != error) { goto error_handler; }
+    }
+
+    error = ompi_datatype_destroy(&new_ddt);
+    if (MPI_SUCCESS != error) { goto error_handler; }
+
+    /* Copy received data to the correct blocks. */
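+    /* For example, with grp_split_f == 2 and an even rank r: after the
+     * split-comm alltoall, rbuf block 2i already holds rank 2i's data for
+     * rank r, while block 2i+1 holds rank 2i's data destined for rank r+1.
+     * The strided sendrecv above shipped those odd blocks to rank r+1 and
+     * received rank r+1's blocks for rank r into work_buf; the loop below
+     * spreads them into rbuf blocks 1, 3, 5, ... */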
+    for (int iter = 1; iter < ps_grp_size; ++iter) {
+        int write_pos = ((rank + ps_grp_size - iter) % ps_grp_size);
+        char* dst = (char*)rbuf + (write_pos * rcount * rext);
+        char* src = (char*)work_buf + ((iter - 1) * ps_grp_rcount_ext);
+
+        for (int i = 0; i < ps_grp_num_ranks; ++i) {
+            error = ompi_datatype_copy_content_same_ddt(rdtype, rcount,
+                                                        dst, src);
+            if (MPI_SUCCESS != error) { goto error_handler; }
+
+            dst = dst + ps_grp_buf_copy_stride;
+            src = src + (rcount * rext);
+        }
+    }
+
+error_handler:
+
+    return error;
+}
+
+/* Parallel Split AllToAll algorithm in a nutshell:
+ * 1. Divide the ranks into split_factor parallel groups.
+ *    - Rank r is part of parallel group i if r % split_factor == i.
+ * 2. Perform all_to_all among the split groups in parallel.
+ * 3. Divide the ranks into exchange groups, where each group contains
+ *    split_factor consecutive ranks.
+ *    - Rank r is part of exchange group i if r / split_factor == i.
+ * 4. Exchange data among the ranks in each exchange group to complete
+ *    the all_to_all. */
+int mca_coll_acoll_alltoall
+        (const void *sbuf, size_t scount,
+         struct ompi_datatype_t *sdtype,
+         void* rbuf, size_t rcount,
+         struct ompi_datatype_t *rdtype,
+         struct ompi_communicator_t *comm,
+         mca_coll_base_module_t *module)
+{
+    int rank = ompi_comm_rank(comm);
+    int size = ompi_comm_size(comm);
+    int error = MPI_SUCCESS;
+
+    MPI_Aint rext, rlb;
+    error = ompi_datatype_get_extent(rdtype, &rlb, &rext);
+    if (MPI_SUCCESS != error) { return error; }
+
+    mca_coll_acoll_module_t *acoll_module = (mca_coll_acoll_module_t *)module;
+    coll_acoll_subcomms_t *subc = NULL;
+
+    /* Obtain the subcomms structure. */
+    error = check_and_create_subc(comm, acoll_module, &subc);
+    /* Fall back to the base linear algorithm if subcomms is not obtained
+     * or the communicator is too small. */
+    if ((NULL == subc) || (size < 4)) {
+        return mca_coll_acoll_base_alltoall_dispatcher
+                   (sbuf, scount, sdtype,
+                    rbuf, rcount, rdtype,
+                    comm, acoll_module, false);
+    }
+
+    coll_acoll_reserve_mem_t* reserve_mem_gather = &(acoll_module->reserve_mem_s);
+
+    if (!subc->initialized && (size > 2)) {
+        error = mca_coll_acoll_comm_split_init(comm, acoll_module, subc, 0);
+        if (MPI_SUCCESS != error) { return error; }
+    }
+
+    /* Derive the upper bound on message size where this algorithm is
+     * applicable. */
+    size_t dsize_thresh = mca_coll_acoll_get_msg_thresh(subc, acoll_module);
+
+    if (dsize_thresh < (rcount * rext)) {
+        return mca_coll_acoll_base_alltoall_dispatcher
+                   (sbuf, scount, sdtype,
+                    rbuf, rcount, rdtype,
+                    comm, acoll_module, false);
+    }
+
+    bool sync_enable = false;
+    int grp_split_f = 2;
+    if ((acoll_module->alltoall_attr).split_factor > 0) {
+        grp_split_f = (acoll_module->alltoall_attr).split_factor;
+
+        /* Communicator sizes that are not a multiple of the split factor
+         * are supported only when comm_size % split_factor <= 1, so reduce
+         * the factor until that holds. The split factor must be a power
+         * of 2, otherwise behavior is undefined. */
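+        /* For example, with size == 21 and a requested factor of 8:
+         * 21 % 8 == 5, so the factor is reduced to 4; 21 % 4 == 1, which
+         * is allowed. */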
+        while ((2 < grp_split_f) &&
+               (1 < (size % grp_split_f))) {
+            grp_split_f = grp_split_f / 2;
+        }
+    } else {
+        mca_coll_acoll_get_split_factor_and_base_algo
+            (scount, sdtype, rcount, rdtype,
+             (MPI_IN_PLACE == sbuf), comm,
+             &sync_enable, &grp_split_f);
+    }
+
+    char* work_buf_free = NULL;
+    char* work_buf = NULL;
+    MPI_Aint rgap = 0, ssize;
+
+    ssize = opal_datatype_span(&rdtype->super, size * rcount, &rgap);
+    work_buf_free = (char*)coll_acoll_buf_alloc(reserve_mem_gather, ssize);
+    if (NULL == work_buf_free) {
+        error = OMPI_ERR_OUT_OF_RESOURCE;
+        goto error_handler;
+    }
+    work_buf = work_buf_free - rgap;
+
+    /* When size is odd, the data to and from the last rank is handled as a
+     * separate case. */
+    if ((0 == (size % 2)) || (rank != (size - 1))) {
+        /* Perform all_to_all among the parallel-split groups. */
+        struct ompi_communicator_t *split_comm;
+
+        /* Select the right split_comm: pow2_idx is log2(grp_split_f) - 1,
+         * the index of grp_split_f in MCA_COLL_ACOLL_SPLIT_FACTOR_LIST. */
+        int pow2_idx = -2;
+        int tmp_grp_split_f = grp_split_f;
+        while (tmp_grp_split_f > 0) {
+            pow2_idx += 1;
+            tmp_grp_split_f = tmp_grp_split_f / 2;
+        }
+        split_comm = subc->split_comm[pow2_idx];
+
+        error = mca_coll_acoll_base_alltoall_dispatcher
+                    (sbuf, (grp_split_f * scount), sdtype,
+                     rbuf, (grp_split_f * rcount), rdtype,
+                     split_comm, acoll_module, sync_enable);
+        if (MPI_SUCCESS != error) { goto error_handler; }
+
+        /* Exchange data among consecutive blocks of split factor ranks. */
+        error = mca_coll_acoll_exchange_data
+                    (sbuf, scount, sdtype,
+                     rbuf, rcount, rdtype,
+                     work_buf, comm, acoll_module, grp_split_f);
+        if (MPI_SUCCESS != error) { goto error_handler; }
+    }
+
+    /* Data transfer for the last rank. */
+    if (0 != (size % 2)) {
+        error = mca_coll_acoll_last_rank_scatter_gather
+                    (sbuf, scount, sdtype,
+                     rbuf, rcount, rdtype,
+                     work_buf, comm, acoll_module);
+        if (MPI_SUCCESS != error) { goto error_handler; }
+    }
+
+error_handler:
+    coll_acoll_buf_free(reserve_mem_gather, work_buf_free);
+
+    return error;
+}
diff --git a/ompi/mca/coll/acoll/coll_acoll_component.c b/ompi/mca/coll/acoll/coll_acoll_component.c
index 6a8651fcf81..d3c8d2469b1 100644
--- a/ompi/mca/coll/acoll/coll_acoll_component.c
+++ b/ompi/mca/coll/acoll/coll_acoll_component.c
@@ -1,6 +1,6 @@
 /* -*- Mode: C; c-acoll-offset:4 ; indent-tabs-mode:nil -*- */
 /*
- * Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2024 - 2025 Advanced Micro Devices, Inc. All rights reserved.
  * $COPYRIGHT$
  *
  * Additional copyrights may follow
@@ -42,6 +42,9 @@ int mca_coll_acoll_allgather_ring_1 = 0;
 int mca_coll_acoll_reserve_memory_for_algo = 0;
 uint64_t mca_coll_acoll_reserve_memory_size_for_algo = 128 * 32768; // 4 MB
 uint64_t mca_coll_acoll_xpmem_buffer_size = 128 * 32768;
+int mca_coll_acoll_alltoall_split_factor = 0;
+size_t mca_coll_acoll_alltoall_psplit_msg_thres = 0;
+size_t mca_coll_acoll_alltoall_xpmem_msg_thres = 0;
 
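+/* The new alltoall tunables are set like any other MCA parameter, e.g.:
+ *     mpirun --mca coll_acoll_alltoall_split_factor 4 \
+ *            --mca coll_acoll_alltoall_psplit_msg_thresh 2048 ./app */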
 /* By default utilize xpmem based algorithms applicable when built with xpmem. */
 int mca_coll_acoll_without_xpmem = 0;
@@ -193,6 +196,24 @@
         "assumed to persist for the duration of the application.",
         MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, OPAL_INFO_LVL_9, MCA_BASE_VAR_SCOPE_READONLY,
         &mca_coll_acoll_xpmem_use_sr_buf);
+    (void) mca_base_component_var_register(
+        &mca_coll_acoll_component.collm_version, "alltoall_split_factor",
+        "Split factor value to be used in the alltoall parallel split "
+        "algorithm; valid values are 2, 4, 8, 16, 32, 64.",
+        MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, OPAL_INFO_LVL_9, MCA_BASE_VAR_SCOPE_READONLY,
+        &mca_coll_acoll_alltoall_split_factor);
+    (void) mca_base_component_var_register(
+        &mca_coll_acoll_component.collm_version, "alltoall_psplit_msg_thresh",
+        "Message threshold above which the parallel split alltoall algorithm "
+        "should not be used.",
+        MCA_BASE_VAR_TYPE_SIZE_T, NULL, 0, 0, OPAL_INFO_LVL_9, MCA_BASE_VAR_SCOPE_READONLY,
+        &mca_coll_acoll_alltoall_psplit_msg_thres);
+    (void) mca_base_component_var_register(
+        &mca_coll_acoll_component.collm_version, "alltoall_xpmem_msg_thresh",
+        "Message threshold above which the xpmem based linear alltoall "
+        "algorithm should be used for intra node cases.",
+        MCA_BASE_VAR_TYPE_SIZE_T, NULL, 0, 0, OPAL_INFO_LVL_9, MCA_BASE_VAR_SCOPE_READONLY,
+        &mca_coll_acoll_alltoall_xpmem_msg_thres);
 
     return OMPI_SUCCESS;
 }
@@ -218,6 +239,33 @@ static void mca_coll_acoll_module_construct(mca_coll_acoll_module_t *module)
         (module->reserve_mem_s).reserve_mem_allocate = true;
         (module->reserve_mem_s).reserve_mem_size = mca_coll_acoll_reserve_memory_size_for_algo;
     }
+
+    /* Ensure a valid split factor is given. */
+    int8_t valid_sf = 0;
+    const int split_factor_list[MCA_COLL_ACOLL_SPLIT_FACTOR_LIST_LEN] =
+                                    MCA_COLL_ACOLL_SPLIT_FACTOR_LIST;
+    for (int ii = 0; ii < MCA_COLL_ACOLL_SPLIT_FACTOR_LIST_LEN; ++ii) {
+        if (split_factor_list[ii] == mca_coll_acoll_alltoall_split_factor) {
+            valid_sf = 1;
+            break;
+        }
+    }
+    (module->alltoall_attr).split_factor = 0;
+    if (1 == valid_sf) {
+        (module->alltoall_attr).split_factor = mca_coll_acoll_alltoall_split_factor;
+    }
+
+    (module->alltoall_attr).psplit_msg_thresh = 0;
+    if (0 < mca_coll_acoll_alltoall_psplit_msg_thres) {
+        (module->alltoall_attr).psplit_msg_thresh =
+            mca_coll_acoll_alltoall_psplit_msg_thres;
+    }
+
+    (module->alltoall_attr).xpmem_msg_thresh = 0;
+    if (0 < mca_coll_acoll_alltoall_xpmem_msg_thres) {
+        (module->alltoall_attr).xpmem_msg_thresh =
+            mca_coll_acoll_alltoall_xpmem_msg_thres;
+    }
 }
 
 /*
@@ -332,6 +380,13 @@ static void mca_coll_acoll_module_destruct(mca_coll_acoll_module_t *module)
                 }
             }
         }
+
+        for (int k = 0; k < MCA_COLL_ACOLL_SPLIT_FACTOR_LIST_LEN; ++k) {
+            if (subc->split_comm[k] != NULL) {
+                ompi_comm_free(&(subc->split_comm[k]));
+                subc->split_comm[k] = NULL;
+            }
+        }
         subc->initialized = 0;
         free(subc);
         module->subc[i] = NULL;
@@ -345,6 +400,10 @@ static void mca_coll_acoll_module_destruct(mca_coll_acoll_module_t *module)
         && (NULL != (module->reserve_mem_s).reserve_mem)) {
         free((module->reserve_mem_s).reserve_mem);
     }
+
+    (module->alltoall_attr).split_factor = 0;
+    (module->alltoall_attr).psplit_msg_thresh = 0;
+    (module->alltoall_attr).xpmem_msg_thresh = 0;
 }
 
 OBJ_CLASS_INSTANCE(mca_coll_acoll_module_t, mca_coll_base_module_t, mca_coll_acoll_module_construct,
diff --git a/ompi/mca/coll/acoll/coll_acoll_module.c b/ompi/mca/coll/acoll/coll_acoll_module.c
index 3924e755dc0..697f47bb8d0 100644
--- a/ompi/mca/coll/acoll/coll_acoll_module.c
+++ b/ompi/mca/coll/acoll/coll_acoll_module.c
@@ -158,6 +158,7 @@ mca_coll_base_module_t *mca_coll_acoll_comm_query(struct ompi_communicator_t *co
 
     acoll_module->super.coll_allgather = mca_coll_acoll_allgather;
     acoll_module->super.coll_allreduce = mca_coll_acoll_allreduce_intra;
+    acoll_module->super.coll_alltoall = mca_coll_acoll_alltoall;
     acoll_module->super.coll_barrier = mca_coll_acoll_barrier_intra;
     acoll_module->super.coll_bcast = mca_coll_acoll_bcast;
     acoll_module->super.coll_gather = mca_coll_acoll_gather_intra;
@@ -181,6 +182,7 @@ static int acoll_module_enable(mca_coll_base_module_t *module, struct ompi_commu
 
     ACOLL_INSTALL_COLL_API(comm, acoll_module, allgather);
     ACOLL_INSTALL_COLL_API(comm, acoll_module, allreduce);
+    ACOLL_INSTALL_COLL_API(comm, acoll_module, alltoall);
     ACOLL_INSTALL_COLL_API(comm, acoll_module, barrier);
     ACOLL_INSTALL_COLL_API(comm, acoll_module, bcast);
     ACOLL_INSTALL_COLL_API(comm, acoll_module, gather);
@@ -201,6 +203,7 @@ static int acoll_module_disable(mca_coll_base_module_t *module, struct ompi_comm
 
     ACOLL_UNINSTALL_COLL_API(comm, acoll_module, allgather);
     ACOLL_UNINSTALL_COLL_API(comm, acoll_module, allreduce);
+    ACOLL_UNINSTALL_COLL_API(comm, acoll_module, alltoall);
     ACOLL_UNINSTALL_COLL_API(comm, acoll_module, barrier);
     ACOLL_UNINSTALL_COLL_API(comm, acoll_module, bcast);
     ACOLL_UNINSTALL_COLL_API(comm, acoll_module, gather);
diff --git a/ompi/mca/coll/acoll/coll_acoll_utils.h b/ompi/mca/coll/acoll/coll_acoll_utils.h
index c665ad2babc..20b1f26df42 100644
--- a/ompi/mca/coll/acoll/coll_acoll_utils.h
+++ b/ompi/mca/coll/acoll/coll_acoll_utils.h
@@ -1,6 +1,6 @@
 /* -*- Mode: C; indent-tabs-mode:nil -*- */
 /*
- * Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2024 - 2025 Advanced Micro Devices, Inc. All rights reserved.
  * $COPYRIGHT$
  *
  * Additional copyrights may follow
@@ -154,6 +154,10 @@ static inline int check_and_create_subc(ompi_communicator_t *comm,
         subc->local_root[j] = 0;
     }
 
+    for (int k = 0; k < MCA_COLL_ACOLL_SPLIT_FACTOR_LIST_LEN; ++k) {
+        subc->split_comm[k] = NULL;
+    }
+
     subc->numa_comm = NULL;
     subc->numa_comm_ldrs = NULL;
     subc->node_comm = NULL;
@@ -251,6 +255,133 @@ static inline int mca_coll_acoll_create_base_comm(ompi_communicator_t **parent_c
     return err;
 }
 
+/* Check whether the rank that follows par_comm_rank in the parent
+ * communicator is also a member of sub_comm. */
+static inline int mca_coll_acoll_is_adj_rank_same_sub_comm
+                      (ompi_communicator_t* sub_comm,
+                       int sub_comm_size, int* sub_comm_ranks,
+                       ompi_group_t* parent_grp,
+                       int parent_comm_size, int* parent_comm_ranks,
+                       int par_comm_rank, bool* is_adj)
+{
+    (*is_adj) = false;
+
+    ompi_group_t *sub_comm_grp;
+    int error = ompi_comm_group(sub_comm, &sub_comm_grp);
+    if (MPI_SUCCESS != error) {
+        return error;
+    }
+
+    for (int i = 0; i < sub_comm_size; ++i) {
+        sub_comm_ranks[i] = i;
+    }
+
+    /* The current rank is guaranteed to be in all the accessed subcomms. */
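+    /* Translate the sub_comm ranks into parent-communicator ranks, then
+     * set is_adj if the next rank in ring order,
+     * (par_comm_rank + 1) % parent_comm_size, is among them. */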
+    error = ompi_group_translate_ranks(sub_comm_grp, sub_comm_size, sub_comm_ranks,
+                                       parent_grp, parent_comm_ranks);
+    if (MPI_SUCCESS != error) {
+        return error;
+    }
+
+    for (int ii = 0; ii < sub_comm_size; ++ii) {
+        if (((par_comm_rank + 1) % parent_comm_size) == parent_comm_ranks[ii]) {
+            (*is_adj) = true;
+            break;
+        }
+    }
+
+    return MPI_SUCCESS;
+}
+
+/* Infer the typical distance between consecutive ranks (core, L3 cache,
+ * NUMA, socket, or node) from the process mapping and store it in
+ * subc->r2r_dist. */
+static inline int mca_coll_acoll_derive_r2r_latency
+                      (ompi_communicator_t *comm,
+                       coll_acoll_subcomms_t *subc,
+                       mca_coll_acoll_module_t *acoll_module)
+{
+    int size = ompi_comm_size(comm);
+    int rank = ompi_comm_rank(comm);
+    subc->r2r_dist = DIST_NODE;
+
+    coll_acoll_reserve_mem_t *rsv_mem = &(acoll_module->reserve_mem_s);
+    int* workbuf = (int *) coll_acoll_buf_alloc(rsv_mem,
+                                                2 * size * sizeof(int));
+    if (NULL == workbuf) {
+        return OMPI_ERR_OUT_OF_RESOURCE;
+    }
+
+    int* comm_ranks = workbuf + size;
+    ompi_group_t *comm_grp;
+    int error = ompi_comm_group(comm, &comm_grp);
+    if (MPI_SUCCESS != error) { goto error_handler; }
+
+    bool is_same_l3 = false;
+    int distance = DIST_CORE; /* map-by core distance. */
+    int l3_comm_size = subc->subgrp_size;
+    error = mca_coll_acoll_is_adj_rank_same_sub_comm(subc->subgrp_comm,
+                l3_comm_size, workbuf, comm_grp,
+                size, comm_ranks, rank, &is_same_l3);
+    if (MPI_SUCCESS != error) { goto error_handler; }
+
+    bool is_same_numa = false;
+    if (!is_same_l3) {
+        distance = DIST_L3CACHE; /* map-by l3 distance. */
+        int numa_comm_size = ompi_comm_size(subc->numa_comm);
+        error = mca_coll_acoll_is_adj_rank_same_sub_comm(subc->numa_comm,
+                    numa_comm_size, workbuf, comm_grp,
+                    size, comm_ranks, rank, &is_same_numa);
+        if (MPI_SUCCESS != error) { goto error_handler; }
+    }
+
+    bool is_same_socket = false;
+    if ((!is_same_l3) && (!is_same_numa)) {
+        distance = DIST_NUMA; /* map-by numa distance. */
+        int socket_comm_size = ompi_comm_size(subc->socket_comm);
+        error = mca_coll_acoll_is_adj_rank_same_sub_comm(subc->socket_comm,
+                    socket_comm_size, workbuf, comm_grp,
+                    size, comm_ranks, rank, &is_same_socket);
+        if (MPI_SUCCESS != error) { goto error_handler; }
+    }
+
+    bool is_same_node = false;
+    if ((!is_same_l3) && (!is_same_numa) && (!is_same_socket)) {
+        distance = DIST_SOCKET; /* map-by socket distance. */
+        int local_comm_size = ompi_comm_size(subc->local_comm);
+        error = mca_coll_acoll_is_adj_rank_same_sub_comm(subc->local_comm,
+                    local_comm_size, workbuf, comm_grp,
+                    size, comm_ranks, rank, &is_same_node);
+        if (MPI_SUCCESS != error) { goto error_handler; }
+    }
+
+    if ((!is_same_l3) && (!is_same_numa) && (!is_same_socket) && (!is_same_node)) {
+        distance = DIST_NODE; /* map-by node distance. */
+    }
+
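+    /* Each rank votes with the distance to its ring neighbor; the most
+     * frequent distance across the communicator is chosen as the
+     * representative rank-to-rank distance. */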
+    error = (comm)->c_coll->coll_allgather(&distance, 1, MPI_INT,
+                                           workbuf, 1, MPI_INT,
+                                           comm, &acoll_module->super);
+    if (MPI_SUCCESS != error) { goto error_handler; }
+
+    int dist_count_array[DIST_END] = {0};
+    for (int ii = 0; ii < size; ++ii) {
+        dist_count_array[workbuf[ii]] += 1;
+    }
+
+    int max_idx = DIST_CORE;
+    for (int ii = (max_idx + 1); ii < DIST_END; ++ii) {
+        if (dist_count_array[ii] > dist_count_array[max_idx]) {
+            max_idx = ii;
+        }
+    }
+    subc->r2r_dist = max_idx;
+
+    coll_acoll_buf_free(rsv_mem, workbuf);
+    return MPI_SUCCESS;
+
+error_handler:
+    coll_acoll_buf_free(rsv_mem, workbuf);
+    return error;
+}
+
 static inline int mca_coll_acoll_comm_split_init(ompi_communicator_t *comm,
                                                  mca_coll_acoll_module_t *acoll_module,
                                                  coll_acoll_subcomms_t *subc,
@@ -576,6 +707,32 @@ static inline int mca_coll_acoll_comm_split_init(ompi_communicator_t *comm,
             return err;
         }
     }
+
+    err = mca_coll_acoll_derive_r2r_latency(comm, subc, acoll_module);
+    if (MPI_SUCCESS != err) {
+        return err;
+    }
+
+    const int split_factor_list_len = MCA_COLL_ACOLL_SPLIT_FACTOR_LIST_LEN;
+    const int split_factor_list[MCA_COLL_ACOLL_SPLIT_FACTOR_LIST_LEN] =
+                                    MCA_COLL_ACOLL_SPLIT_FACTOR_LIST;
+    for (int ii = 0; ii < split_factor_list_len; ++ii) {
+        int split_comm_color = rank % split_factor_list[ii];
+
+        /* When comm size is not a perfect multiple of the split factor, the
+         * leftover ranks are grouped under their own color. For example,
+         * with size 10 and factor 4, ranks 8 and 9 get color 4. At call
+         * time such a factor is used only if comm size % split factor <= 1;
+         * otherwise the algorithm falls back to a smaller factor,
+         * ultimately 2. */
+        if ((0 != (size % split_factor_list[ii])) &&
+            (rank >= (size - (size % split_factor_list[ii])))) {
+            split_comm_color = split_factor_list[ii];
+        }
+        err = ompi_comm_split(comm, split_comm_color, rank,
+                              &subc->split_comm[ii], false);
+        if (MPI_SUCCESS != err) {
+            return err;
+        }
+    }
+
+    subc->derived_node_size = (size + subc->num_nodes - 1) / subc->num_nodes;
 }