diff --git a/src/components/tl/ucp/alltoall/alltoall_onesided.c b/src/components/tl/ucp/alltoall/alltoall_onesided.c index 6ba61130c13..8cb9f5d573e 100644 --- a/src/components/tl/ucp/alltoall/alltoall_onesided.c +++ b/src/components/tl/ucp/alltoall/alltoall_onesided.c @@ -9,6 +9,7 @@ #include "alltoall.h" #include "core/ucc_progress_queue.h" #include "utils/ucc_math.h" +#include "tl_ucp_coll.h" #include "tl_ucp_sendrecv.h" #define CONGESTION_THRESHOLD 8 @@ -71,8 +72,14 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_sched_finalize(ucc_coll_task_t *ctask) ucc_status_t ucc_tl_ucp_alltoall_onesided_finalize(ucc_coll_task_t *coll_task) { - ucc_status_t status; + ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); + ucc_status_t status; + status = ucc_tl_ucp_coll_dynamic_segment_destroy(task); + if (ucc_unlikely(UCC_OK != status)) { + tl_error(UCC_TASK_LIB(coll_task), + "failed to destroy dynamic segment local handles"); + } status = ucc_tl_ucp_coll_finalize(coll_task); if (ucc_unlikely(UCC_OK != status)) { tl_error(UCC_TASK_LIB(coll_task), "failed to finalize collective"); @@ -91,27 +98,43 @@ void ucc_tl_ucp_alltoall_onesided_get_progress(ucc_coll_task_t *ctask) ucc_rank_t gsize = UCC_TL_TEAM_SIZE(team); uint32_t ntokens = task->alltoall_onesided.tokens; int64_t npolls = task->alltoall_onesided.npolls; - /* To resolve remote virtual addresses, the dst_memh is the one that must - * have the rkey information. For this algorithm, we need to swap the - * src and dst handles to operate correctly */ - ucc_mem_map_mem_h *dst_memh = TASK_ARGS(task).src_memh.global_memh; - uint32_t *posted = &task->onesided.get_posted; - uint32_t *completed = &task->onesided.get_completed; - ucc_rank_t peer = (grank + *posted + 1) % gsize; - ucc_mem_map_mem_h src_memh; + /* For GET, we read from each peer's src buffer into our local dst buffer. + * remote_rkeys is the per-rank array of src rkeys (what we GET from); + * local_h is our own dst buffer registration (where data lands). */ + ucc_mem_map_mem_h *remote_rkeys = TASK_ARGS(task).src_memh.global_memh; + uint32_t *posted = &task->onesided.get_posted; + uint32_t *completed = &task->onesided.get_completed; + ucc_rank_t peer = (grank + *posted + 1) % gsize; + ucc_mem_map_mem_h local_h; size_t nelems; + ucc_status_t status; + + if (task->flags & UCC_TL_UCP_TASK_FLAG_USE_DYN_SEG) { + status = ucc_tl_ucp_test_dynamic_segment(task); + if (status == UCC_INPROGRESS) { + return; + } + if (UCC_OK != status) { + task->super.status = status; + tl_error(UCC_TL_TEAM_LIB(team), + "failed to exchange dynamic segments"); + return; + } + local_h = task->dynamic_segments.dst_local; + remote_rkeys = (ucc_mem_map_mem_h *)task->dynamic_segments.src_global; + } else { + local_h = (TASK_ARGS(task).flags & UCC_COLL_ARGS_FLAG_DST_MEMH_GLOBAL) + ? TASK_ARGS(task).dst_memh.global_memh[grank] + : TASK_ARGS(task).dst_memh.local_memh; + } nelems = TASK_ARGS(task).src.info.count; nelems = (nelems / gsize) * ucc_dt_size(TASK_ARGS(task).src.info.datatype); - src_memh = (TASK_ARGS(task).flags & UCC_COLL_ARGS_FLAG_DST_MEMH_GLOBAL) - ? TASK_ARGS(task).dst_memh.global_memh[grank] - : TASK_ARGS(task).dst_memh.local_memh; - for (; *posted < gsize; peer = (peer + 1) % gsize) { UCPCHECK_GOTO(ucc_tl_ucp_get_nb(PTR_OFFSET(dest, peer * nelems), PTR_OFFSET(src, grank * nelems), - nelems, mtype, peer, src_memh, dst_memh, - team, task), + nelems, mtype, peer, local_h, + remote_rkeys, team, task), task, out); if (!alltoall_onesided_handle_completion(task, posted, completed, @@ -122,7 +145,10 @@ void ucc_tl_ucp_alltoall_onesided_get_progress(ucc_coll_task_t *ctask) alltoall_onesided_wait_completion(task, npolls); out: - return; + if (task->super.status != UCC_INPROGRESS && + (task->flags & UCC_TL_UCP_TASK_FLAG_USE_DYN_SEG)) { + task->super.status = ucc_tl_ucp_coll_dynamic_segment_finalize(task); + } } void ucc_tl_ucp_alltoall_onesided_put_progress(ucc_coll_task_t *ctask) @@ -142,12 +168,28 @@ void ucc_tl_ucp_alltoall_onesided_put_progress(ucc_coll_task_t *ctask) ucc_rank_t peer = (grank + *posted + 1) % gsize; ucc_mem_map_mem_h src_memh; size_t nelems; + ucc_status_t status; nelems = TASK_ARGS(task).src.info.count; nelems = (nelems / gsize) * ucc_dt_size(TASK_ARGS(task).src.info.datatype); - src_memh = (TASK_ARGS(task).flags & UCC_COLL_ARGS_FLAG_SRC_MEMH_GLOBAL) + if (task->flags & UCC_TL_UCP_TASK_FLAG_USE_DYN_SEG) { + status = ucc_tl_ucp_test_dynamic_segment(task); + if (status == UCC_INPROGRESS) { + return; + } + if (UCC_OK != status) { + task->super.status = status; + tl_error(UCC_TL_TEAM_LIB(team), + "failed to exchange dynamic segments"); + return; + } + src_memh = task->dynamic_segments.src_local; + dst_memh = (ucc_mem_map_mem_h *)task->dynamic_segments.dst_global; + } else { + src_memh = (TASK_ARGS(task).flags & UCC_COLL_ARGS_FLAG_SRC_MEMH_GLOBAL) ? TASK_ARGS(task).src_memh.global_memh[grank] : TASK_ARGS(task).src_memh.local_memh; + } for (; *posted < gsize; peer = (peer + 1) % gsize) { UCPCHECK_GOTO( @@ -165,15 +207,30 @@ void ucc_tl_ucp_alltoall_onesided_put_progress(ucc_coll_task_t *ctask) alltoall_onesided_wait_completion(task, npolls); out: - return; + if (task->super.status != UCC_INPROGRESS && + (task->flags & UCC_TL_UCP_TASK_FLAG_USE_DYN_SEG)) { + task->super.status = ucc_tl_ucp_coll_dynamic_segment_finalize(task); + } } ucc_status_t ucc_tl_ucp_alltoall_onesided_start(ucc_coll_task_t *ctask) { ucc_tl_ucp_task_t *task = ucc_derived_of(ctask, ucc_tl_ucp_task_t); ucc_tl_ucp_team_t *team = TASK_TEAM(task); + ucc_status_t status; ucc_tl_ucp_task_reset(task, UCC_INPROGRESS); + if (task->flags & UCC_TL_UCP_TASK_FLAG_USE_DYN_SEG) { + status = ucc_tl_ucp_coll_dynamic_segment_exchange_nb(task); + if (UCC_OK != status && UCC_INPROGRESS != status) { + task->super.status = status; + tl_error(UCC_TL_TEAM_LIB(team), + "failed to exchange dynamic segments"); + return task->super.status; + } + } + + /* Start the onesided operations */ return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super); } @@ -190,7 +247,7 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_init(ucc_base_coll_args_t *coll_args, }; size_t perc_bw = UCC_TL_UCP_TEAM_LIB(tl_team)->cfg.alltoall_onesided_percent_bw; - ucc_tl_ucp_alltoall_onesided_alg_t alg = + ucc_tl_ucp_onesided_alg_type alg = UCC_TL_UCP_TEAM_LIB(tl_team)->cfg.alltoall_onesided_alg; ucc_tl_ucp_schedule_t *tl_schedule = NULL; ucc_rank_t group_size = 1; @@ -208,15 +265,6 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_init(ucc_base_coll_args_t *coll_args, ucc_sbgp_t *sbgp; ALLTOALL_TASK_CHECK(coll_args->args, tl_team); - if (!(coll_args->args.mask & UCC_COLL_ARGS_FIELD_FLAGS) || - (coll_args->args.mask & UCC_COLL_ARGS_FIELD_FLAGS && - (!(coll_args->args.flags & - UCC_COLL_ARGS_FLAG_MEM_MAPPED_BUFFERS)))) { - tl_error(UCC_TL_TEAM_LIB(tl_team), - "non memory mapped buffers are not supported"); - status = UCC_ERR_NOT_SUPPORTED; - return status; - } if (!(coll_args->args.mask & UCC_COLL_ARGS_FIELD_MEM_MAP_SRC_MEMH)) { coll_args->args.src_memh.global_memh = NULL; @@ -228,7 +276,6 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_init(ucc_base_coll_args_t *coll_args, return status; } } - if (!(coll_args->args.mask & UCC_COLL_ARGS_FIELD_MEM_MAP_DST_MEMH)) { coll_args->args.dst_memh.global_memh = NULL; } else { @@ -239,6 +286,8 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_init(ucc_base_coll_args_t *coll_args, return status; } } + + status = ucc_tl_ucp_get_schedule(tl_team, coll_args, (ucc_tl_ucp_schedule_t **)&tl_schedule); if (ucc_unlikely(UCC_OK != status)) { @@ -260,12 +309,35 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_init(ucc_base_coll_args_t *coll_args, group_size = sbgp->group_size; } - task = ucc_tl_ucp_init_task(coll_args, team); + task = ucc_tl_ucp_init_task(coll_args, team); + if (ucc_unlikely(!task)) { + status = UCC_ERR_NO_MEMORY; + goto out; + } task->super.finalize = ucc_tl_ucp_alltoall_onesided_finalize; a2a_task = &task->super; + /* initialize dynamic segments */ + if (alg == UCC_TL_UCP_ALLTOALL_ONESIDED_GET || + (alg == UCC_TL_UCP_ALLTOALL_ONESIDED_AUTO && + sbgp->group_size >= CONGESTION_THRESHOLD)) { + alg = UCC_TL_UCP_ALLTOALL_ONESIDED_GET; + } else if (alg == UCC_TL_UCP_ALLTOALL_ONESIDED_AUTO) { + alg = UCC_TL_UCP_ALLTOALL_ONESIDED_PUT; + } + status = ucc_tl_ucp_coll_dynamic_segment_init(&coll_args->args, alg, task); + if (UCC_OK != status) { + if (status != UCC_ERR_NOT_SUPPORTED) { + tl_error(UCC_TL_TEAM_LIB(tl_team), + "failed to initialize dynamic segments"); + } + ucc_tl_ucp_coll_finalize(&task->super); + goto out; + } + status = ucc_tl_ucp_coll_init(&barrier_coll_args, team, &barrier_task); if (status != UCC_OK) { + task->super.finalize(&task->super); goto out; } if (perc_bw > 100) { @@ -278,23 +350,25 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_init(ucc_base_coll_args_t *coll_args, nelems = nelems / UCC_TL_TEAM_SIZE(tl_team); param.field_mask = UCP_EP_PERF_PARAM_FIELD_MESSAGE_SIZE; attr.field_mask = UCP_EP_PERF_ATTR_FIELD_ESTIMATED_TIME; - param.message_size = nelems * ucc_dt_size(TASK_ARGS(task).src.info.datatype);; + param.message_size = nelems * ucc_dt_size(TASK_ARGS(task).src.info.datatype); ucc_tl_ucp_get_ep( tl_team, (UCC_TL_TEAM_RANK(tl_team) + 1) % UCC_TL_TEAM_SIZE(tl_team), &ep); ucp_ep_evaluate_perf(ep, ¶m, &attr); - rate = (1 / attr.estimated_time) * (double)(perc_bw / 100.0); - ratio = (nelems > 0) ? nelems * group_size : 1; - task->alltoall_onesided.tokens = rate / ratio; + if (attr.estimated_time > 0) { + rate = (1 / attr.estimated_time) * (double)(perc_bw / 100.0); + ratio = (nelems > 0) ? nelems * group_size : 1; + task->alltoall_onesided.tokens = rate / ratio; + } else { + task->alltoall_onesided.tokens = 0; + } if (task->alltoall_onesided.tokens < 1) { task->alltoall_onesided.tokens = 1; } task->super.post = ucc_tl_ucp_alltoall_onesided_start; npolls = task->n_polls; - if (alg == UCC_TL_UCP_ALLTOALL_ONESIDED_GET || - (alg == UCC_TL_UCP_ALLTOALL_ONESIDED_AUTO && - group_size >= CONGESTION_THRESHOLD)) { + if (alg == UCC_TL_UCP_ALLTOALL_ONESIDED_GET) { npolls = nelems * ucc_dt_size(TASK_ARGS(task).src.info.datatype); if (npolls < task->n_polls) { npolls = task->n_polls; @@ -313,6 +387,7 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_init(ucc_base_coll_args_t *coll_args, ucc_task_subscribe_dep(a2a_task, barrier_task, UCC_EVENT_COMPLETED); *task_h = &schedule->super; + return status; out: if (tl_schedule) { diff --git a/src/components/tl/ucp/alltoallv/alltoallv_onesided.c b/src/components/tl/ucp/alltoallv/alltoallv_onesided.c index d572ce1aa35..6eaba91e509 100644 --- a/src/components/tl/ucp/alltoallv/alltoallv_onesided.c +++ b/src/components/tl/ucp/alltoallv/alltoallv_onesided.c @@ -11,6 +11,8 @@ #include "utils/ucc_math.h" #include "tl_ucp_sendrecv.h" +void ucc_tl_ucp_alltoallv_onesided_progress(ucc_coll_task_t *ctask); + ucc_status_t ucc_tl_ucp_alltoallv_onesided_start(ucc_coll_task_t *ctask) { ucc_tl_ucp_task_t *task = ucc_derived_of(ctask, ucc_tl_ucp_task_t); @@ -32,6 +34,12 @@ ucc_status_t ucc_tl_ucp_alltoallv_onesided_start(ucc_coll_task_t *ctask) ucc_tl_ucp_task_reset(task, UCC_INPROGRESS); + if (TASK_ARGS(task).mask & UCC_COLL_ARGS_FIELD_MEM_MAP_SRC_MEMH) { + if (TASK_ARGS(task).flags & UCC_COLL_ARGS_FLAG_SRC_MEMH_GLOBAL) { + src_memh = TASK_ARGS(task).src_memh.global_memh[grank]; + } + } + /* perform a put to each member peer using the peer's index in the * destination displacement. */ for (peer = (grank + 1) % gsize; task->onesided.put_posted < gsize; @@ -43,8 +51,8 @@ ucc_status_t ucc_tl_ucp_alltoallv_onesided_start(ucc_coll_task_t *ctask) ucc_coll_args_get_displacement(&TASK_ARGS(task), d_disp, peer) * rdt_size; data_size = - ucc_coll_args_get_count( - &TASK_ARGS(task), TASK_ARGS(task).src.info_v.counts, peer) * + ucc_coll_args_get_count(&TASK_ARGS(task), + TASK_ARGS(task).src.info_v.counts, peer) * sdt_size; UCPCHECK_GOTO(ucc_tl_ucp_put_nb(PTR_OFFSET(src, sd_disp), @@ -56,6 +64,7 @@ ucc_status_t ucc_tl_ucp_alltoallv_onesided_start(ucc_coll_task_t *ctask) dst_memh, team), task, out); } + return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super); out: return task->super.status; @@ -63,15 +72,14 @@ ucc_status_t ucc_tl_ucp_alltoallv_onesided_start(ucc_coll_task_t *ctask) void ucc_tl_ucp_alltoallv_onesided_progress(ucc_coll_task_t *ctask) { - ucc_tl_ucp_task_t *task = ucc_derived_of(ctask, ucc_tl_ucp_task_t); - ucc_tl_ucp_team_t *team = TASK_TEAM(task); - ucc_rank_t gsize = UCC_TL_TEAM_SIZE(team); - long *pSync = TASK_ARGS(task).global_work_buffer; + ucc_tl_ucp_task_t *task = ucc_derived_of(ctask, ucc_tl_ucp_task_t); + ucc_tl_ucp_team_t *team = TASK_TEAM(task); + long *pSync = TASK_ARGS(task).global_work_buffer; + ucc_rank_t gsize = UCC_TL_TEAM_SIZE(team); if (ucc_tl_ucp_test_onesided(task, gsize) == UCC_INPROGRESS) { return; } - pSync[0] = 0; task->super.status = UCC_OK; } @@ -81,8 +89,8 @@ ucc_status_t ucc_tl_ucp_alltoallv_onesided_init(ucc_base_coll_args_t *coll_args, ucc_coll_task_t **task_h) { ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t); + ucc_status_t status = UCC_OK; ucc_tl_ucp_task_t *task; - ucc_status_t status; ALLTOALLV_TASK_CHECK(coll_args->args, tl_team); if (!(coll_args->args.mask & UCC_COLL_ARGS_FIELD_GLOBAL_WORK_BUFFER)) { @@ -104,13 +112,24 @@ ucc_status_t ucc_tl_ucp_alltoallv_onesided_init(ucc_base_coll_args_t *coll_args, } if (!(coll_args->args.mask & UCC_COLL_ARGS_FIELD_MEM_MAP_DST_MEMH)) { coll_args->args.dst_memh.global_memh = NULL; + } else { + if (!(coll_args->args.flags & UCC_COLL_ARGS_FLAG_DST_MEMH_GLOBAL)) { + tl_error(UCC_TL_TEAM_LIB(tl_team), + "onesided alltoallv requires global memory handles for dst " + "buffers"); + status = UCC_ERR_INVALID_PARAM; + goto out; + } } - task = ucc_tl_ucp_init_task(coll_args, team); + task = ucc_tl_ucp_init_task(coll_args, team); + if (ucc_unlikely(!task)) { + status = UCC_ERR_NO_MEMORY; + goto out; + } *task_h = &task->super; task->super.post = ucc_tl_ucp_alltoallv_onesided_start; task->super.progress = ucc_tl_ucp_alltoallv_onesided_progress; - status = UCC_OK; out: return status; } diff --git a/src/components/tl/ucp/tl_ucp.c b/src/components/tl/ucp/tl_ucp.c index e2cfe1434d5..64601b49532 100644 --- a/src/components/tl/ucp/tl_ucp.c +++ b/src/components/tl/ucp/tl_ucp.c @@ -308,10 +308,10 @@ static ucs_config_field_t ucc_tl_ucp_context_config_table[] = { ucc_offsetof(ucc_tl_ucp_context_config_t, memtype_copy_enable), UCC_CONFIG_TYPE_BOOL}, - {"EXPORTED_MEMORY_HANDLE", "n", - "If set to yes, initialize UCP context with the exported memory handle " - "feature, which is useful for offload devices such as a DPU. Otherwise " - "disable the use of this feature.", + {"EXPORTED_MEMORY_HANDLE", "0", + "If set to 1, initialize UCP context with the exported memory handle " + "feature, which is useful for offload devices such as a DPU. Set to 0 " + "to disable this feature (default is 0).", ucc_offsetof(ucc_tl_ucp_context_config_t, exported_memory_handle), UCC_CONFIG_TYPE_BOOL}, diff --git a/src/components/tl/ucp/tl_ucp.h b/src/components/tl/ucp/tl_ucp.h index d75d7cec6f4..eff83cae607 100644 --- a/src/components/tl/ucp/tl_ucp.h +++ b/src/components/tl/ucp/tl_ucp.h @@ -42,12 +42,12 @@ typedef struct ucc_tl_ucp_iface { /* Extern iface should follow the pattern: ucc_tl_ */ extern ucc_tl_ucp_iface_t ucc_tl_ucp; -typedef enum ucc_tl_ucp_alltoall_onesided_alg_type { +typedef enum ucc_tl_ucp_onesided_alg_type { UCC_TL_UCP_ALLTOALL_ONESIDED_PUT, UCC_TL_UCP_ALLTOALL_ONESIDED_GET, UCC_TL_UCP_ALLTOALL_ONESIDED_AUTO, UCC_TL_UCP_ALLTOALL_ONESIDED_LAST -} ucc_tl_ucp_alltoall_onesided_alg_t; +} ucc_tl_ucp_onesided_alg_type; typedef struct ucc_tl_ucp_lib_config { ucc_tl_lib_config_t super; @@ -88,7 +88,7 @@ typedef struct ucc_tl_ucp_lib_config { ucc_ternary_auto_value_t use_topo; int use_reordering; uint32_t alltoall_onesided_percent_bw; - ucc_tl_ucp_alltoall_onesided_alg_t alltoall_onesided_alg; + ucc_tl_ucp_onesided_alg_type alltoall_onesided_alg; } ucc_tl_ucp_lib_config_t; typedef enum ucc_tl_ucp_local_copy_type { @@ -168,8 +168,6 @@ typedef struct ucc_tl_ucp_team { ucc_status_t status; uint32_t seq_num; ucc_tl_ucp_task_t *preconnect_task; - void * va_base[MAX_NR_SEGMENTS]; - size_t base_length[MAX_NR_SEGMENTS]; ucc_tl_ucp_worker_t * worker; ucc_tl_ucp_team_config_t cfg; const char * tuning_str; @@ -313,4 +311,19 @@ void ucc_tl_ucp_pre_register_mem(ucc_tl_ucp_team_t *team, void *addr, ucc_status_t ucc_tl_ucp_ctx_remote_populate(ucc_tl_ucp_context_t *ctx, ucc_mem_map_params_t map, ucc_team_oob_coll_t oob); + +ucc_status_t ucc_tl_ucp_mem_map(const ucc_base_context_t *context, + ucc_mem_map_mode_t mode, + ucc_mem_map_memh_t *memh, + ucc_mem_map_tl_t *tl_h); + +ucc_status_t ucc_tl_ucp_memh_pack(const ucc_base_context_t *context, + ucc_mem_map_mode_t mode, + ucc_mem_map_tl_t *tl_h, + void **pack_buffer); + +ucc_status_t ucc_tl_ucp_mem_unmap(const ucc_base_context_t *context, + ucc_mem_map_mode_t mode, + ucc_mem_map_tl_t *memh); + #endif diff --git a/src/components/tl/ucp/tl_ucp_coll.c b/src/components/tl/ucp/tl_ucp_coll.c index d0eb03cedc5..38da2336b5d 100644 --- a/src/components/tl/ucp/tl_ucp_coll.c +++ b/src/components/tl/ucp/tl_ucp_coll.c @@ -8,6 +8,8 @@ #include "tl_ucp_coll.h" #include "components/mc/ucc_mc.h" #include "core/ucc_team.h" +#include "utils/ucc_math.h" +#include "utils/ucc_coll_utils.h" #include "barrier/barrier.h" #include "alltoall/alltoall.h" #include "alltoallv/alltoallv.h" @@ -24,6 +26,9 @@ #include "fanout/fanout.h" #include "scatterv/scatterv.h" +/* Selects which buffer is mapped in dynamic_segment_map_memh */ +enum { DYN_SEG_DST = 0, DYN_SEG_SRC = 1 }; + const ucc_tl_ucp_default_alg_desc_t ucc_tl_ucp_default_alg_descs[UCC_TL_UCP_N_DEFAULT_ALG_SELECT_STR] = { { @@ -106,6 +111,790 @@ ucc_status_t ucc_tl_ucp_coll_finalize(ucc_coll_task_t *coll_task) return UCC_OK; } +static inline ucc_status_t dynamic_segment_map_memh(ucc_mem_map_memh_t **memh, + ucc_coll_args_t *coll_args, + int is_src, + ucc_tl_ucp_task_t *task) +{ + ucc_tl_ucp_team_t *tl_team = UCC_TL_UCP_TASK_TEAM(task); + ucc_tl_ucp_context_t *ctx = UCC_TL_UCP_TEAM_CTX(tl_team); + ucc_status_t status = UCC_OK; + ucc_mem_map_memh_t *lmemh = NULL; + void *buffer; + ucc_count_t total_count; + ucc_datatype_t datatype; + + lmemh = ucc_calloc(1, sizeof(ucc_mem_map_memh_t), "dyn_memh"); + if (lmemh == NULL) { + tl_error(UCC_TASK_LIB(task), "failed to allocate memh"); + status = UCC_ERR_NO_MEMORY; + goto out; + } + lmemh->tl_h = ucc_calloc(1, sizeof(ucc_mem_map_tl_t), "dyn_tlh"); + if (!lmemh->tl_h) { + tl_error(UCC_TASK_LIB(task), "failed to allocate memh"); + ucc_free(lmemh); + status = UCC_ERR_NO_MEMORY; + goto out; + } + + if (is_src) { + total_count = coll_args->src.info.count; + buffer = coll_args->src.info.buffer; + datatype = coll_args->src.info.datatype; + } else { + total_count = coll_args->dst.info.count; + buffer = coll_args->dst.info.buffer; + datatype = coll_args->dst.info.datatype; + } + + lmemh->address = buffer; + lmemh->len = total_count * ucc_dt_size(datatype); + lmemh->num_tls = 1; /* Only one transport layer (UCP) */ + strncpy(lmemh->tl_h->tl_name, "ucp", UCC_MEM_MAP_TL_NAME_LEN - 1); + status = ucc_tl_ucp_mem_map(&ctx->super.super, UCC_MEM_MAP_MODE_EXPORT, + lmemh, lmemh->tl_h); + if (UCC_OK != status) { + tl_error(UCC_TASK_LIB(task), "failed to map memory for memh"); + ucc_free(lmemh->tl_h); + ucc_free(lmemh); + goto out; + } + *memh = lmemh; +out: + return status; +} + +/* + * This function initializes dynamic memory segments for onesided collectives. + * It checks if user-provided memory handles are available, and if not, + * creates and maps local memory handles for source and destination buffers. + * These handles will be exchanged across ranks for remote memory access. */ +UCC_TL_UCP_PROFILE_FUNC(ucc_status_t, ucc_tl_ucp_coll_dynamic_segment_init, + (coll_args, alg, task), ucc_coll_args_t *coll_args, + ucc_tl_ucp_onesided_alg_type alg, + ucc_tl_ucp_task_t *task) +{ + ucc_tl_ucp_team_t *tl_team = UCC_TL_UCP_TASK_TEAM(task); + ucc_tl_ucp_context_t *ctx = UCC_TL_UCP_TEAM_CTX(tl_team); + ucc_status_t status = UCC_OK; + ucc_mem_map_memh_t *src_memh; + ucc_mem_map_memh_t *dst_memh; + + if ((coll_args->coll_type == UCC_COLL_TYPE_ALLTOALLV) || + (coll_args->coll_type == UCC_COLL_TYPE_ALLGATHERV) || + (coll_args->coll_type == UCC_COLL_TYPE_GATHERV) || + (coll_args->coll_type == UCC_COLL_TYPE_REDUCE_SCATTERV) || + (coll_args->coll_type == UCC_COLL_TYPE_SCATTERV)) { + tl_debug(UCC_TASK_LIB(task), "dynamic segments are not supported for %s", + ucc_coll_type_str(coll_args->coll_type)); + return UCC_ERR_NOT_SUPPORTED; + } + if ((coll_args->mask & UCC_COLL_ARGS_FIELD_FLAGS) && + (coll_args->flags & UCC_COLL_ARGS_FLAG_MEM_MAPPED_BUFFERS)) { + return UCC_OK; + } + /* Skip dynamic mapping when the user has already provided the handle that + * this algorithm needs: src for GET (peers pull from it), dst for PUT + * (peers push into it). The non-DYN_SEG progress path will use it. */ + if (alg == UCC_TL_UCP_ALLTOALL_ONESIDED_GET) { + if ((coll_args->mask & UCC_COLL_ARGS_FIELD_MEM_MAP_SRC_MEMH) && + (coll_args->mask & UCC_COLL_ARGS_FIELD_FLAGS) && + (coll_args->flags & UCC_COLL_ARGS_FLAG_SRC_MEMH_GLOBAL)) { + return UCC_OK; + } + } else { + if ((coll_args->mask & UCC_COLL_ARGS_FIELD_MEM_MAP_DST_MEMH) && + (coll_args->mask & UCC_COLL_ARGS_FIELD_FLAGS) && + (coll_args->flags & UCC_COLL_ARGS_FLAG_DST_MEMH_GLOBAL)) { + return UCC_OK; + } + } + /* Register both local buffers for RDMA. Only the remotely-accessed side + * (src for GET, dst for PUT) is exported to peers during exchange; the + * other side is kept as the local UCP memh for the RDMA operation. */ + status = dynamic_segment_map_memh(&src_memh, coll_args, DYN_SEG_SRC, task); + if (UCC_OK != status) { + return status; + } + status = dynamic_segment_map_memh(&dst_memh, coll_args, DYN_SEG_DST, task); + if (UCC_OK != status) { + ucc_tl_ucp_mem_unmap(&ctx->super.super, UCC_MEM_MAP_MODE_EXPORT, + src_memh->tl_h); + ucc_free(src_memh->tl_h); + ucc_free(src_memh); + return status; + } + memset(&task->dynamic_segments, 0, sizeof(task->dynamic_segments)); + task->dynamic_segments.src_local = src_memh; + task->dynamic_segments.dst_local = dst_memh; + task->dynamic_segments.src_global = NULL; + task->dynamic_segments.dst_global = NULL; + task->dynamic_segments.alg = alg; + task->flags |= UCC_TL_UCP_TASK_FLAG_USE_DYN_SEG; + task->dynamic_segments.exchange_step = UCC_TL_UCP_DYN_SEG_EXCHANGE_STEP_INIT; + return status; +} + +static inline ucc_status_t dynamic_segment_alloc_seg(ucc_mem_map_memh_t **memh, + size_t packed_size) +{ + ucc_mem_map_memh_t *lmemh; + + lmemh = ucc_calloc(1, sizeof(ucc_mem_map_memh_t) + packed_size, "packed_memh"); + if (!lmemh) { + return UCC_ERR_NO_MEMORY; + } + lmemh->tl_h = ucc_calloc(1, sizeof(ucc_mem_map_tl_t), "packed_tl_h"); + if (!lmemh->tl_h) { + ucc_free(lmemh); + return UCC_ERR_NO_MEMORY; + } + *memh = lmemh; + return UCC_OK; +} + +static inline void dynamic_segment_memh_pack(ucc_context_h ctx, + ucc_tl_ucp_dyn_seg_args_t *args, + int is_src) +{ + ucc_mem_map_memh_t *memh; + void *pack_buffer; + size_t pack_size; + void *address; + size_t len; + + if (is_src) { + pack_buffer = args->src_pack_buffer; + pack_size = args->src_pack_size; + address = args->src_memh_local->address; + len = args->src_memh_local->len; + memh = args->src_memh_pack; + } else { + pack_buffer = args->dst_pack_buffer; + pack_size = args->dst_pack_size; + address = args->dst_memh_local->address; + len = args->dst_memh_local->len; + memh = args->dst_memh_pack; + } + /* Pack into exchange buffer. pack_buffer layout: + * [0 .. TL_NAME_LEN) : TL name (null-padded) + * [TL_NAME_LEN .. +sizeof(size_t)): packed rkey size + * [TL_NAME_LEN + sizeof(size_t).. ): packed rkey data */ + strncpy(memh->pack_buffer, "ucp", UCC_MEM_MAP_TL_NAME_LEN - 1); + memcpy(PTR_OFFSET(memh->pack_buffer, UCC_MEM_MAP_TL_NAME_LEN), &pack_size, + sizeof(size_t)); + memcpy(PTR_OFFSET(memh->pack_buffer, UCC_MEM_MAP_TL_NAME_LEN + sizeof(size_t)), + pack_buffer, pack_size); + if (!is_src) { + memcpy(memh->tl_h, args->dst_memh_local->tl_h, sizeof(ucc_mem_map_tl_t)); + } else { + memcpy(memh->tl_h, args->src_memh_local->tl_h, sizeof(ucc_mem_map_tl_t)); + } + memh->mode = UCC_MEM_MAP_MODE_EXPORT; + memh->context = ctx; + memh->address = address; + memh->len = len; + memh->num_tls = 1; +} + +static ucc_status_t +dynamic_segment_pack_memory_handles(ucc_tl_ucp_dyn_seg_args_t *args) +{ + ucc_tl_ucp_team_t *tl_team = UCC_TL_UCP_TASK_TEAM(args->task); + ucc_tl_ucp_context_t *ctx = UCC_TL_UCP_TEAM_CTX(tl_team); + ucc_status_t status; + + if (args->src_memh_local) { + status = ucc_tl_ucp_memh_pack(&ctx->super.super, UCC_MEM_MAP_MODE_EXPORT, + args->src_memh_local->tl_h, + &args->src_pack_buffer); + if (status != UCC_OK) { + tl_error(UCC_TASK_LIB(args->task), "failed to pack src memory handle"); + return status; + } + args->src_pack_size = args->src_memh_local->tl_h->packed_size; + } + if (args->dst_memh_local) { + status = ucc_tl_ucp_memh_pack(&ctx->super.super, UCC_MEM_MAP_MODE_EXPORT, + args->dst_memh_local->tl_h, + &args->dst_pack_buffer); + if (status != UCC_OK) { + tl_error(UCC_TASK_LIB(args->task), "failed to pack dst memory handle"); + ucc_free(args->src_pack_buffer); + args->src_pack_buffer = NULL; + return status; + } + args->dst_pack_size = args->dst_memh_local->tl_h->packed_size; + } + return UCC_OK; +} + +static ucc_status_t +dynamic_segment_calculate_sizes_start(ucc_tl_ucp_dyn_seg_args_t *args, + ucc_service_coll_req_t **scoll_req) +{ + ucc_tl_ucp_team_t *tl_team = UCC_TL_UCP_TASK_TEAM(args->task); + ucc_team_t *core_team = UCC_TL_CORE_TEAM(tl_team); + ucc_subset_t subset; + size_t *global_sizes; + ucc_status_t status; + size_t local_pack_size; + + subset.map = UCC_TL_TEAM_MAP(tl_team); + subset.myrank = UCC_TL_TEAM_RANK(tl_team); + + /* Calculate total pack size for this rank - only destination handles + * Use inner packed TL size; exchange_size will add outer TL header. */ + if (args->task->dynamic_segments.alg == UCC_TL_UCP_ALLTOALL_ONESIDED_GET) { + local_pack_size = args->src_pack_size; + } else { + local_pack_size = args->dst_pack_size; + } + + global_sizes = ucc_calloc(UCC_TL_TEAM_SIZE(tl_team), sizeof(size_t), "global sizes"); + if (!global_sizes) { + tl_error(UCC_TASK_LIB(args->task), + "failed to allocate global sizes buffer"); + return UCC_ERR_NO_MEMORY; + } + args->global_sizes = global_sizes; + + status = ucc_service_allgather(core_team, &local_pack_size, global_sizes, + sizeof(size_t), subset, scoll_req); + if (status != UCC_OK) { + tl_error(UCC_TASK_LIB(args->task), + "failed to start service allgather for sizes"); + ucc_free(global_sizes); + args->global_sizes = NULL; + return status; + } + return UCC_OK; +} + +static ucc_status_t +dynamic_segment_calculate_sizes_test(ucc_tl_ucp_dyn_seg_args_t *args, + ucc_service_coll_req_t **scoll_req) +{ + ucc_tl_ucp_team_t *tl_team = UCC_TL_UCP_TASK_TEAM(args->task); + ucc_tl_ucp_context_t *ctx = UCC_TL_UCP_TEAM_CTX(tl_team); + ucc_status_t status; + int i; + + status = ucc_collective_test(&(*scoll_req)->task->super); + if (status == UCC_INPROGRESS) { + if (ctx->cfg.service_worker) { + ucp_worker_progress(ctx->service_worker.ucp_worker); + } + return UCC_INPROGRESS; + } else if (status != UCC_OK) { + tl_error(UCC_TASK_LIB(args->task), + "failed during service allgather for sizes %s", + ucc_status_string(status)); + ucc_service_coll_finalize(*scoll_req); + ucc_free(args->global_sizes); + args->global_sizes = NULL; + return status; + } + args->max_individual_pack_size = 0; + for (i = 0; i < UCC_TL_TEAM_SIZE(tl_team); i++) { + if (args->global_sizes[i] > args->max_individual_pack_size) { + args->max_individual_pack_size = args->global_sizes[i]; + } + } + /* Total per-rank slot: memh struct + TL name + packed-size field + rkey data. */ + args->exchange_size = sizeof(ucc_mem_map_memh_t) + UCC_MEM_MAP_TL_NAME_LEN + + sizeof(size_t) + args->max_individual_pack_size; + ucc_service_coll_finalize(*scoll_req); + ucc_free(args->global_sizes); + args->global_sizes = NULL; + *scoll_req = NULL; + return UCC_OK; +} + +static ucc_status_t +dynamic_segment_allocate_buffers(ucc_tl_ucp_dyn_seg_args_t *args) +{ + ucc_tl_ucp_team_t *tl_team = UCC_TL_UCP_TASK_TEAM(args->task); + size_t pack_slots = args->max_individual_pack_size + + UCC_MEM_MAP_TL_NAME_LEN + sizeof(size_t); + ucc_status_t status; + + status = dynamic_segment_alloc_seg(&args->src_memh_pack, pack_slots); + if (UCC_OK != status) { + tl_error(UCC_TASK_LIB(args->task), "failed to allocate src_memh_pack"); + return status; + } + status = dynamic_segment_alloc_seg(&args->dst_memh_pack, pack_slots); + if (UCC_OK != status) { + tl_error(UCC_TASK_LIB(args->task), "failed to allocate dst_memh_pack"); + goto err_free_src; + } + args->exchange_buffer = ucc_calloc(1, args->exchange_size, "exchange buffer"); + if (!args->exchange_buffer) { + tl_error(UCC_TASK_LIB(args->task), "failed to allocate exchange buffer"); + status = UCC_ERR_NO_MEMORY; + goto err_free_dst; + } + args->task->dynamic_segments.global_buffer = + ucc_calloc(UCC_TL_TEAM_SIZE(tl_team), args->exchange_size, "global buffer"); + if (!args->task->dynamic_segments.global_buffer) { + tl_error(UCC_TASK_LIB(args->task), "failed to allocate global buffer"); + ucc_free(args->exchange_buffer); + args->exchange_buffer = NULL; + status = UCC_ERR_NO_MEMORY; + goto err_free_dst; + } + return UCC_OK; + +err_free_dst: + ucc_free(args->dst_memh_pack->tl_h); + ucc_free(args->dst_memh_pack); + args->dst_memh_pack = NULL; +err_free_src: + ucc_free(args->src_memh_pack->tl_h); + ucc_free(args->src_memh_pack); + args->src_memh_pack = NULL; + return status; +} + +static ucc_status_t +dynamic_segment_pack_and_exchange_data_start(ucc_tl_ucp_dyn_seg_args_t *args, + ucc_service_coll_req_t **scoll_req) +{ + ucc_tl_ucp_team_t *tl_team = UCC_TL_UCP_TASK_TEAM(args->task); + ucc_tl_ucp_context_t *ctx = UCC_TL_UCP_TEAM_CTX(tl_team); + ucc_team_t *core_team = UCC_TL_CORE_TEAM(tl_team); + ucc_subset_t subset; + ucc_status_t status; + size_t copy_size; + + subset.map = UCC_TL_TEAM_MAP(tl_team); + subset.myrank = UCC_TL_TEAM_RANK(tl_team); + + /* Include outer TL header (name + size) in the copy size in addition to memh struct */ + copy_size = sizeof(ucc_mem_map_memh_t) + UCC_MEM_MAP_TL_NAME_LEN + + sizeof(size_t); + if (args->task->dynamic_segments.alg == UCC_TL_UCP_ALLTOALL_ONESIDED_GET) { + /* GET: each rank shares its src handle so peers can GET from it */ + dynamic_segment_memh_pack((ucc_context_h)&ctx->super.super, args, DYN_SEG_SRC); + copy_size += args->src_pack_size; + memcpy(args->exchange_buffer, args->src_memh_pack, copy_size); + } else { + /* PUT: each rank shares its dst handle so peers can PUT into it */ + dynamic_segment_memh_pack((ucc_context_h)&ctx->super.super, args, DYN_SEG_DST); + copy_size += args->dst_pack_size; + memcpy(args->exchange_buffer, args->dst_memh_pack, copy_size); + } + /* Allgather the packed memory handles */ + status = ucc_service_allgather(core_team, args->exchange_buffer, + args->task->dynamic_segments.global_buffer, + args->exchange_size, subset, scoll_req); + if (status != UCC_OK) { + tl_error(UCC_TASK_LIB(args->task), + "failed to start service allgather for memory handles"); + return status; + } + return UCC_OK; +} + +static ucc_status_t +dynamic_segment_pack_and_exchange_data_test(ucc_tl_ucp_dyn_seg_args_t *args, + ucc_service_coll_req_t **scoll_req) +{ + ucc_tl_ucp_team_t *tl_team = UCC_TL_UCP_TASK_TEAM(args->task); + ucc_tl_ucp_context_t *ctx = UCC_TL_UCP_TEAM_CTX(tl_team); + ucc_status_t status; + + status = ucc_collective_test(&(*scoll_req)->task->super); + if (status == UCC_INPROGRESS) { + if (ctx->cfg.service_worker) { + ucp_worker_progress(ctx->service_worker.ucp_worker); + } + return UCC_INPROGRESS; + } else if (status != UCC_OK) { + tl_error(UCC_TASK_LIB(args->task), + "failed during service allgather for memory handles %s", + ucc_status_string(status)); + ucc_service_coll_finalize(*scoll_req); + return status; + } + ucc_service_coll_finalize(*scoll_req); + *scoll_req = NULL; + return UCC_OK; +} + +static ucc_status_t +dynamic_segment_import_memory_handles(ucc_tl_ucp_dyn_seg_args_t *args) +{ + ucc_tl_ucp_team_t *tl_team = UCC_TL_UCP_TASK_TEAM(args->task); + ucc_tl_ucp_context_t *ctx = UCC_TL_UCP_TEAM_CTX(tl_team); + ucc_mem_map_memh_t **global; + size_t offset; + ucc_status_t status; + int i; + int j; + + /* Only allocate destination handles for one-sided operations */ + global = + ucc_calloc(UCC_TL_TEAM_SIZE(tl_team), sizeof(ucc_mem_map_memh_t *), "global"); + if (!global) { + tl_error(UCC_TASK_LIB(args->task), + "failed to allocate global memory handles"); + return UCC_ERR_NO_MEMORY; + } + + /* Import memory handles for each rank using ucc_tl_ucp_mem_map */ + for (i = 0; i < UCC_TL_TEAM_SIZE(tl_team); i++) { + /* Each rank's data in global buffer contains only: + - dst_memh_pack: sizeof(ucc_mem_map_memh_t) + max_individual_pack_size + sizeof(size_t) * 2 + */ + offset = i * args->exchange_size; + global[i] = + (ucc_mem_map_memh_t *)PTR_OFFSET( + args->task->dynamic_segments.global_buffer, offset); + global[i]->tl_h = + ucc_calloc(1, sizeof(ucc_mem_map_tl_t), "global tl_h"); + if (!global[i]->tl_h) { + tl_error(UCC_TASK_LIB(args->task), + "failed to allocate global tl handles"); + status = UCC_ERR_NO_MEMORY; + goto out; + } + + status = ucc_tl_ucp_mem_map( + &ctx->super.super, UCC_MEM_MAP_MODE_IMPORT, + global[i], + global[i]->tl_h); + if (status != UCC_OK) { + tl_error(UCC_TASK_LIB(args->task), + "failed to import dst memory handle for rank %d ", i); + ucc_free(global[i]->tl_h); + global[i]->tl_h = NULL; + goto out; + } + } + + if (args->task->dynamic_segments.alg == UCC_TL_UCP_ALLTOALL_ONESIDED_GET) { + args->task->dynamic_segments.src_global = global; + args->task->dynamic_segments.dst_global = NULL; + } else { + args->task->dynamic_segments.src_global = NULL; + args->task->dynamic_segments.dst_global = global; + } + + return UCC_OK; +out: + for (j = 0; j < i; j++) { + if (global[j]->tl_h) { + /* we need to unmap */ + ucc_tl_ucp_mem_unmap(&ctx->super.super, UCC_MEM_MAP_MODE_IMPORT, + global[j]->tl_h); + } + ucc_free(global[j]->tl_h); + } + ucc_free(global); + return status; +} + +static void dynamic_segment_cleanup_buffers(ucc_tl_ucp_dyn_seg_args_t *args) +{ + if (!args) { + return; + } + + if (args->src_pack_buffer) { + ucc_free(args->src_pack_buffer); + args->src_pack_buffer = NULL; + } + if (args->dst_pack_buffer) { + ucc_free(args->dst_pack_buffer); + args->dst_pack_buffer = NULL; + } + if (args->src_memh_pack) { + if (args->src_memh_pack->tl_h) { + ucc_free(args->src_memh_pack->tl_h); + args->src_memh_pack->tl_h = NULL; + } + ucc_free(args->src_memh_pack); + args->src_memh_pack = NULL; + } + if (args->dst_memh_pack) { + if (args->dst_memh_pack->tl_h) { + ucc_free(args->dst_memh_pack->tl_h); + args->dst_memh_pack->tl_h = NULL; + } + ucc_free(args->dst_memh_pack); + args->dst_memh_pack = NULL; + } + if (args->exchange_buffer) { + ucc_free(args->exchange_buffer); + args->exchange_buffer = NULL; + } + if (args->global_sizes) { + ucc_free(args->global_sizes); + args->global_sizes = NULL; + } +} + +UCC_TL_UCP_PROFILE_FUNC(ucc_status_t, + ucc_tl_ucp_coll_dynamic_segment_exchange_nb, (task), + ucc_tl_ucp_task_t *task) +{ + ucc_tl_ucp_team_t *tl_team = UCC_TL_UCP_TASK_TEAM(task); + ucc_team_t *core_team = UCC_TL_CORE_TEAM(tl_team); + ucc_status_t status = UCC_OK; + + if (core_team->size == 0) { + tl_error(UCC_TASK_LIB(task), + "unable to exchange segments with team size of 0"); + return UCC_ERR_INVALID_PARAM; + } + + /* Initialize on first call */ + if (task->dynamic_segments.exchange_args == NULL) { + task->dynamic_segments.exchange_args = + ucc_calloc(1, sizeof(ucc_tl_ucp_dyn_seg_args_t), "exchange_args"); + if (!task->dynamic_segments.exchange_args) { + tl_error(UCC_TASK_LIB(task), "failed to allocate exchange_args"); + return UCC_ERR_NO_MEMORY; + } + task->dynamic_segments.exchange_args->task = task; + /* Only expose the exchanged side for packing; the other side is used + * locally by the RDMA operation and does not need to be shared. */ + if (task->dynamic_segments.alg == UCC_TL_UCP_ALLTOALL_ONESIDED_GET) { + task->dynamic_segments.exchange_args->src_memh_local = + task->dynamic_segments.src_local; + task->dynamic_segments.exchange_args->dst_memh_local = NULL; + } else { + task->dynamic_segments.exchange_args->src_memh_local = NULL; + task->dynamic_segments.exchange_args->dst_memh_local = + task->dynamic_segments.dst_local; + } + task->dynamic_segments.exchange_step = 0; + task->dynamic_segments.scoll_req_sizes = NULL; + task->dynamic_segments.scoll_req_data = NULL; + } + switch (task->dynamic_segments.exchange_step) { + case UCC_TL_UCP_DYN_SEG_EXCHANGE_STEP_INIT: + status = dynamic_segment_pack_memory_handles( + task->dynamic_segments.exchange_args); + if (status != UCC_OK) { + goto err_cleanup; + } + task->dynamic_segments.exchange_step = UCC_TL_UCP_DYN_SEG_EXCHANGE_STEP_SIZE_TEST; + return UCC_INPROGRESS; + + case UCC_TL_UCP_DYN_SEG_EXCHANGE_STEP_SIZE_TEST: + if (task->dynamic_segments.exchange_args->global_sizes == NULL) { + + /* First call - start the allgather */ + status = dynamic_segment_calculate_sizes_start( + task->dynamic_segments.exchange_args, + &task->dynamic_segments.scoll_req_sizes); + if (status != UCC_OK) { + tl_error(UCC_TASK_LIB(task), "failed to start allgather %s", + ucc_status_string(status)); + goto err_cleanup; + } + return UCC_INPROGRESS; + } else { + /* Subsequent calls - test for completion */ + status = dynamic_segment_calculate_sizes_test( + task->dynamic_segments.exchange_args, + &task->dynamic_segments.scoll_req_sizes); + if (status == UCC_INPROGRESS) { + return UCC_INPROGRESS; + } + if (status != UCC_OK) { + tl_error(UCC_TASK_LIB(task), "failed to test allgather"); + goto err_cleanup; + } + task->dynamic_segments.exchange_step = + UCC_TL_UCP_DYN_SEG_EXCHANGE_STEP_DATA_ALLOC; + return UCC_INPROGRESS; + } + + case UCC_TL_UCP_DYN_SEG_EXCHANGE_STEP_DATA_ALLOC: + status = dynamic_segment_allocate_buffers( + task->dynamic_segments.exchange_args); + if (status != UCC_OK) { + tl_error(UCC_TASK_LIB(task), "failed to allocate buffers"); + goto err_cleanup; + } + task->dynamic_segments.exchange_step = UCC_TL_UCP_DYN_SEG_EXCHANGE_STEP_DATA_START; + return UCC_INPROGRESS; + + case UCC_TL_UCP_DYN_SEG_EXCHANGE_STEP_DATA_START: + if (task->dynamic_segments.scoll_req_data == NULL) { + /* First call - start the allgather */ + status = dynamic_segment_pack_and_exchange_data_start( + task->dynamic_segments.exchange_args, + &task->dynamic_segments.scoll_req_data); + if (status != UCC_OK) { + tl_error(UCC_TASK_LIB(task), "failed to start data exchange %s", + ucc_status_string(status)); + goto err_cleanup_global; + } + return UCC_INPROGRESS; + } else { + /* Subsequent calls - test for completion */ + status = dynamic_segment_pack_and_exchange_data_test( + task->dynamic_segments.exchange_args, + &task->dynamic_segments.scoll_req_data); + if (status == UCC_INPROGRESS) { + return UCC_INPROGRESS; + } + if (status != UCC_OK) { + tl_error(UCC_TASK_LIB(task), "failed to test data exchange"); + goto err_cleanup_global; + } + } + /* falls through: data exchange complete, import handles below */ + } + status = dynamic_segment_import_memory_handles( + task->dynamic_segments.exchange_args); + if (status != UCC_OK) { + tl_error(UCC_TASK_LIB(task), "failed to import memory handles"); + goto err_cleanup_global; + } + + /* Cleanup and complete */ + dynamic_segment_cleanup_buffers(task->dynamic_segments.exchange_args); + ucc_free(task->dynamic_segments.exchange_args); + task->dynamic_segments.exchange_args = NULL; + task->dynamic_segments.exchange_step = UCC_TL_UCP_DYN_SEG_EXCHANGE_STEP_COMPLETE; + return UCC_OK; + +err_cleanup_global: + if (task->dynamic_segments.global_buffer) { + ucc_free(task->dynamic_segments.global_buffer); + task->dynamic_segments.global_buffer = NULL; + } +err_cleanup: + if (task->dynamic_segments.exchange_args) { + dynamic_segment_cleanup_buffers(task->dynamic_segments.exchange_args); + ucc_free(task->dynamic_segments.exchange_args); + task->dynamic_segments.exchange_args = NULL; + } + return status; +} + +UCC_TL_UCP_PROFILE_FUNC(ucc_status_t, ucc_tl_ucp_coll_dynamic_segment_finalize, + (task), ucc_tl_ucp_task_t *task) +{ + ucc_tl_ucp_team_t *team = UCC_TL_UCP_TASK_TEAM(task); + ucc_tl_ucp_context_t *ctx = UCC_TL_UCP_TEAM_CTX(team); + ucc_status_t status = UCC_OK; + ucc_status_t tmp; + size_t team_size = UCC_TL_TEAM_SIZE(team); + int i; + + if (!(task->flags & UCC_TL_UCP_TASK_FLAG_USE_DYN_SEG)) { + return UCC_OK; + } + if (task->dynamic_segments.src_global) { + for (i = 0; i < team_size; i++) { + if (task->dynamic_segments.src_global[i] && + task->dynamic_segments.src_global[i]->tl_h) { + tmp = ucc_tl_ucp_mem_unmap( + &ctx->super.super, UCC_MEM_MAP_MODE_IMPORT, + task->dynamic_segments.src_global[i]->tl_h); + if (tmp != UCC_OK) { + tl_error(UCC_TASK_LIB(task), + "failed to unmap src global memory handle for " + "rank %d", + i); + if (status == UCC_OK) { + status = tmp; + } + } + ucc_free(task->dynamic_segments.src_global[i]->tl_h); + task->dynamic_segments.src_global[i]->tl_h = NULL; + task->dynamic_segments.src_global[i] = NULL; + } + } + ucc_free(task->dynamic_segments.src_global); + task->dynamic_segments.src_global = NULL; + } + if (task->dynamic_segments.dst_global) { + for (i = 0; i < team_size; i++) { + if (task->dynamic_segments.dst_global[i] && + task->dynamic_segments.dst_global[i]->tl_h) { + tmp = ucc_tl_ucp_mem_unmap( + &ctx->super.super, UCC_MEM_MAP_MODE_IMPORT, + task->dynamic_segments.dst_global[i]->tl_h); + if (tmp != UCC_OK) { + tl_error(UCC_TASK_LIB(task), + "failed to unmap dst global memory handle for " + "rank %d", + i); + if (status == UCC_OK) { + status = tmp; + } + } + ucc_free(task->dynamic_segments.dst_global[i]->tl_h); + task->dynamic_segments.dst_global[i]->tl_h = NULL; + task->dynamic_segments.dst_global[i] = NULL; + } + } + ucc_free(task->dynamic_segments.dst_global); + task->dynamic_segments.dst_global = NULL; + } + /* Free global buffer */ + if (task->dynamic_segments.global_buffer) { + ucc_free(task->dynamic_segments.global_buffer); + task->dynamic_segments.global_buffer = NULL; + } + /* src_local and dst_local are kept alive for persistent collective reuse. + * They are released by ucc_tl_ucp_coll_dynamic_segment_destroy() when the + * task is truly destroyed. On the next start(), exchange_nb will detect + * exchange_args == NULL and restart the exchange from scratch. */ + return status; +} + +ucc_status_t ucc_tl_ucp_coll_dynamic_segment_destroy(ucc_tl_ucp_task_t *task) +{ + ucc_tl_ucp_team_t *team = UCC_TL_UCP_TASK_TEAM(task); + ucc_tl_ucp_context_t *ctx = UCC_TL_UCP_TEAM_CTX(team); + ucc_status_t status = UCC_OK; + + if (!(task->flags & UCC_TL_UCP_TASK_FLAG_USE_DYN_SEG)) { + return UCC_OK; + } + if (task->dynamic_segments.src_local) { + if (task->dynamic_segments.src_local->tl_h) { + status = + ucc_tl_ucp_mem_unmap(&ctx->super.super, UCC_MEM_MAP_MODE_EXPORT, + task->dynamic_segments.src_local->tl_h); + if (status != UCC_OK) { + tl_error(UCC_TASK_LIB(task), + "failed to unmap src local memory handle"); + } + ucc_free(task->dynamic_segments.src_local->tl_h); + task->dynamic_segments.src_local->tl_h = NULL; + } + ucc_free(task->dynamic_segments.src_local); + task->dynamic_segments.src_local = NULL; + } + if (task->dynamic_segments.dst_local) { + if (task->dynamic_segments.dst_local->tl_h) { + status = + ucc_tl_ucp_mem_unmap(&ctx->super.super, UCC_MEM_MAP_MODE_EXPORT, + task->dynamic_segments.dst_local->tl_h); + if (status != UCC_OK) { + tl_error(UCC_TASK_LIB(task), + "failed to unmap dst local memory handle"); + } + ucc_free(task->dynamic_segments.dst_local->tl_h); + task->dynamic_segments.dst_local->tl_h = NULL; + } + ucc_free(task->dynamic_segments.dst_local); + task->dynamic_segments.dst_local = NULL; + } + return status; +} + ucc_status_t ucc_tl_ucp_coll_init(ucc_base_coll_args_t *coll_args, ucc_base_team_t *team, ucc_coll_task_t **task_h) diff --git a/src/components/tl/ucp/tl_ucp_coll.h b/src/components/tl/ucp/tl_ucp_coll.h index b36ce7a3e84..420e33b3bde 100644 --- a/src/components/tl/ucp/tl_ucp_coll.h +++ b/src/components/tl/ucp/tl_ucp_coll.h @@ -15,6 +15,14 @@ #define UCC_UUNITS_AUTO_RADIX 4 #define UCC_TL_UCP_N_DEFAULT_ALG_SELECT_STR 9 +enum ucc_tl_ucp_dyn_seg_exchange_step { + UCC_TL_UCP_DYN_SEG_EXCHANGE_STEP_INIT, + UCC_TL_UCP_DYN_SEG_EXCHANGE_STEP_SIZE_TEST, + UCC_TL_UCP_DYN_SEG_EXCHANGE_STEP_DATA_ALLOC, + UCC_TL_UCP_DYN_SEG_EXCHANGE_STEP_DATA_START, + UCC_TL_UCP_DYN_SEG_EXCHANGE_STEP_COMPLETE +}; + ucc_status_t ucc_tl_ucp_team_default_score_str_alloc(ucc_tl_ucp_team_t *team, char *default_select_str[UCC_TL_UCP_N_DEFAULT_ALG_SELECT_STR]); @@ -276,4 +284,30 @@ static inline unsigned ucc_tl_ucp_get_knomial_radix(ucc_tl_ucp_team_t *team, return radix; } +ucc_status_t ucc_tl_ucp_coll_dynamic_segment_init(ucc_coll_args_t *coll_args, + ucc_tl_ucp_onesided_alg_type alg, + ucc_tl_ucp_task_t *task); + +ucc_status_t ucc_tl_ucp_coll_dynamic_segment_exchange_nb(ucc_tl_ucp_task_t *task); + +/* Called at end of each iteration: releases imported/global handles and + * resets exchange state for the next start(). */ +ucc_status_t ucc_tl_ucp_coll_dynamic_segment_finalize(ucc_tl_ucp_task_t *task); + +/* Called once at task destroy: releases the locally exported handles. */ +ucc_status_t ucc_tl_ucp_coll_dynamic_segment_destroy(ucc_tl_ucp_task_t *task); + +static inline ucc_status_t ucc_tl_ucp_test_dynamic_segment(ucc_tl_ucp_task_t *task) +{ + if (!(task->flags & UCC_TL_UCP_TASK_FLAG_USE_DYN_SEG)) { + return UCC_OK; + } + + if (task->dynamic_segments.exchange_step < UCC_TL_UCP_DYN_SEG_EXCHANGE_STEP_COMPLETE) { + return ucc_tl_ucp_coll_dynamic_segment_exchange_nb(task); + } + + return UCC_OK; +} + #endif diff --git a/src/components/tl/ucp/tl_ucp_context.c b/src/components/tl/ucp/tl_ucp_context.c index d60b1359ca9..06032469d63 100644 --- a/src/components/tl/ucp/tl_ucp_context.c +++ b/src/components/tl/ucp/tl_ucp_context.c @@ -206,7 +206,7 @@ UCC_CLASS_INIT_FUNC(ucc_tl_ucp_context_t, ucp_params.features |= UCP_FEATURE_EXPORTED_MEMH; } ucp_params.tag_sender_mask = UCC_TL_UCP_TAG_SENDER_MASK; - ucp_params.name = "UCC_UCP_CONTEXT"; + ucp_params.name = "UCC_UCP_CONTEXT"; if (params->estimated_num_ppn > 0) { ucp_params.field_mask |= UCP_PARAM_FIELD_ESTIMATED_NUM_PPN; @@ -307,9 +307,9 @@ UCC_CLASS_INIT_FUNC(ucc_tl_ucp_context_t, "failed to register progress function", err_thread_mode, UCC_ERR_NO_MESSAGE, self); - self->remote_info = NULL; - self->n_rinfo_segs = 0; - self->rkeys = NULL; + self->remote_info = NULL; + self->n_rinfo_segs = 0; + self->rkeys = NULL; if (params->params.mask & UCC_CONTEXT_PARAM_FIELD_MEM_PARAMS && params->params.mask & UCC_CONTEXT_PARAM_FIELD_OOB) { ucc_status = ucc_tl_ucp_ctx_remote_populate( @@ -552,8 +552,8 @@ ucc_status_t ucc_tl_ucp_ctx_remote_populate(ucc_tl_ucp_context_t * ctx, MAX_NR_SEGMENTS); return UCC_ERR_INVALID_PARAM; } - ctx->rkeys = - (ucp_rkey_h *)ucc_calloc(sizeof(ucp_rkey_h), nsegs * size, "ucp_ctx_rkeys"); + ctx->rkeys = (ucp_rkey_h *)ucc_calloc(sizeof(ucp_rkey_h), nsegs * size, + "ucp_ctx_rkeys"); if (NULL == ctx->rkeys) { tl_error(ctx->super.super.lib, "failed to allocated %zu bytes", sizeof(ucp_rkey_h) * nsegs * size); @@ -569,8 +569,8 @@ ucc_status_t ucc_tl_ucp_ctx_remote_populate(ucc_tl_ucp_context_t * ctx, } for (i = 0; i < nsegs; i++) { - mmap_params.field_mask = - UCP_MEM_MAP_PARAM_FIELD_ADDRESS | UCP_MEM_MAP_PARAM_FIELD_LENGTH; + mmap_params.field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS | + UCP_MEM_MAP_PARAM_FIELD_LENGTH; mmap_params.address = map.segments[i].address; mmap_params.length = map.segments[i].len; @@ -581,10 +581,11 @@ ucc_status_t ucc_tl_ucp_ctx_remote_populate(ucc_tl_ucp_context_t * ctx, ucc_status = ucs_status_to_ucc_status(status); goto fail_mem_map; } + ctx->remote_info[i].mem_h = (void *)mh; status = ucp_rkey_pack(ctx->worker.ucp_context, mh, - &ctx->remote_info[i].packed_key, - &ctx->remote_info[i].packed_key_len); + &ctx->remote_info[i].packed_key, + &ctx->remote_info[i].packed_key_len); if (UCS_OK != status) { tl_error(ctx->super.super.lib, "failed to pack UCP key with error code: %d", status); @@ -753,7 +754,6 @@ ucc_status_t ucc_tl_ucp_mem_map(const ucc_base_context_t *context, ucc_mem_map_m // copy name information for look ups later strncpy(tl_h->tl_name, "ucp", UCC_MEM_MAP_TL_NAME_LEN - 1); - /* nothing to do for UCC_MEM_MAP_MODE_EXPORT_OFFLOAD and UCC_MEM_MAP_MODE_IMPORT */ if (mode == UCC_MEM_MAP_MODE_EXPORT) { ucc_status = ucc_tl_ucp_mem_map_export(ctx, memh->address, memh->len, mode, m_data); if (UCC_OK != ucc_status) { @@ -801,28 +801,28 @@ ucc_status_t ucc_tl_ucp_mem_unmap(const ucc_base_context_t *context, ucc_mem_map ucp_rkey_buffer_release(data->rinfo.packed_key); data->rinfo.packed_key = NULL; } + // Free the data structure itself for export mode too + ucc_free(data); + memh->tl_data = NULL; } else if (mode == UCC_MEM_MAP_MODE_IMPORT || mode == UCC_MEM_MAP_MODE_IMPORT_OFFLOAD) { // need to free rkeys (data->rkey) , packed memh (data->packed_memh) if (data->packed_memh) { ucp_memh_buffer_release(data->packed_memh, NULL); + data->packed_memh = NULL; } if (data->rinfo.packed_key) { ucp_rkey_buffer_release(data->rinfo.packed_key); + data->rinfo.packed_key = NULL; } if (data->rkey) { ucp_rkey_destroy(data->rkey); } + ucc_free(data); + memh->tl_data = NULL; } else { ucc_error("Unknown mem map mode entered: %d", mode); return UCC_ERR_INVALID_PARAM; } - - /* Free the TL data structure */ - if (data) { - ucc_free(data); - memh->tl_data = NULL; - } - return UCC_OK; } diff --git a/src/components/tl/ucp/tl_ucp_sendrecv.h b/src/components/tl/ucp/tl_ucp_sendrecv.h index 5f0715bd486..de030889219 100644 --- a/src/components/tl/ucp/tl_ucp_sendrecv.h +++ b/src/components/tl/ucp/tl_ucp_sendrecv.h @@ -408,6 +408,7 @@ static inline ucc_status_t ucc_tl_ucp_check_memh(ucp_ep_h *ep, ucc_rank_t me, ucs_status_t ucs_status; int i; size_t offset; + size_t addr_offset; base = (uint64_t)dst_memh[me]->address; end = base + dst_memh[me]->len; @@ -415,8 +416,9 @@ static inline ucc_status_t ucc_tl_ucp_check_memh(ucp_ep_h *ep, ucc_rank_t me, if (!((uint64_t)va >= base && (uint64_t)va < end)) { return UCC_ERR_NOT_FOUND; } - *rva = (uint64_t)PTR_OFFSET(dst_memh[peer]->address, - ((uint64_t)va - (uint64_t)dst_memh[me]->address)); + addr_offset = (uint64_t)va - (uint64_t)dst_memh[me]->address; + *rva = (uint64_t)PTR_OFFSET(dst_memh[peer]->address, addr_offset); + dst_tl_data = (ucc_tl_ucp_memh_data_t *)dst_memh[peer]->tl_h[tl_index].tl_data; if (NULL == dst_tl_data->rkey) { offset = 0; @@ -444,6 +446,25 @@ static inline ucc_status_t ucc_tl_ucp_check_memh(ucp_ep_h *ep, ucc_rank_t me, return UCC_OK; } +static inline int resolve_segment(const void *va, size_t *key_sizes, + ptrdiff_t *key_offset, size_t nr_segments, + ucc_tl_ucp_remote_info_t *rinfo) +{ + int i; + uint64_t base; + uint64_t end; + + for (i = 0; i < nr_segments; i++) { + base = (uint64_t)rinfo[i].va_base; + end = base + rinfo[i].len; + if ((uint64_t)va >= base && (uint64_t)va < end) { + return i; + } + *key_offset += key_sizes[i]; + } + return -1; +} + static inline ucc_status_t ucc_tl_ucp_resolve_p2p_by_va(ucc_tl_ucp_team_t *team, void *va, ucp_ep_h *ep, ucc_rank_t peer, uint64_t *rva, ucp_rkey_h *rkey, @@ -461,12 +482,13 @@ ucc_tl_ucp_resolve_p2p_by_va(ucc_tl_ucp_team_t *team, void *va, ucp_ep_h *ep, void *offset; ptrdiff_t base_offset; ucc_status_t status; + ucc_rank_t ctx_peer; /* for onesided reg. through context */ *segment = -1; core_rank = ucc_ep_map_eval(UCC_TL_TEAM_MAP(team), peer); ucc_assert(UCC_TL_CORE_TEAM(team) != NULL); - peer = ucc_get_ctx_rank(UCC_TL_CORE_TEAM(team), core_rank); - + /* ctx_peer (context rank) is used for segment-based addressing */ + ctx_peer = ucc_get_ctx_rank(UCC_TL_CORE_TEAM(team), core_rank); offset = ucc_get_team_ep_addr(UCC_TL_CORE_CTX(team), UCC_TL_CORE_TEAM(team), core_rank, ucc_tl_ucp.super.super.id); @@ -474,48 +496,39 @@ ucc_tl_ucp_resolve_p2p_by_va(ucc_tl_ucp_team_t *team, void *va, ucp_ep_h *ep, rvas = (uint64_t *)base_offset; key_sizes = PTR_OFFSET(base_offset, (section_offset * 2)); keys = PTR_OFFSET(base_offset, (section_offset * 3)); - - for (int i = 0; i < ctx->n_rinfo_segs; i++) { - uint64_t base = (uint64_t)team->va_base[i]; - uint64_t end = base + team->base_length[i]; - if ((uint64_t)va >= base && - (uint64_t)va < end) { - *segment = i; - break; - } - key_offset += key_sizes[i]; - } - if (ucc_unlikely(0 > *segment)) { - if (dst_memh) { - /* check if segment is in src/dst memh */ - status = find_tl_index(dst_memh[grank], &tl_index); - if (status == UCC_ERR_NOT_FOUND) { - tl_error(UCC_TL_TEAM_LIB(team), - "attempt to perform one-sided operation with malformed mem map handle"); - return status; - } - - status = ucc_tl_ucp_check_memh(ep, grank, peer, va, rva, rkey, tl_index, dst_memh); - if (status == UCC_OK) { - return UCC_OK; + *segment = resolve_segment(va, key_sizes, &key_offset, ctx->n_rinfo_segs, + ctx->remote_info); + if (*segment >= 0) { + *rva = rvas[*segment] + + ((uint64_t)va - (uint64_t)ctx->remote_info[*segment].va_base); + if (ucc_unlikely(NULL == UCC_TL_UCP_REMOTE_RKEY(ctx, ctx_peer, *segment))) { + ucs_status_t ucs_status = ucp_ep_rkey_unpack( + *ep, PTR_OFFSET(keys, key_offset), + &UCC_TL_UCP_REMOTE_RKEY(ctx, ctx_peer, *segment)); + if (UCS_OK != ucs_status) { + return ucs_status_to_ucc_status(ucs_status); } } - /* general error if nothing was found */ - tl_error(UCC_TL_TEAM_LIB(team), - "attempt to perform one-sided operation on non-registered memory %p", va); - return UCC_ERR_NOT_FOUND; + *rkey = UCC_TL_UCP_REMOTE_RKEY(ctx, ctx_peer, *segment); + return UCC_OK; } - if (ucc_unlikely(NULL == UCC_TL_UCP_REMOTE_RKEY(ctx, peer, *segment))) { - ucs_status_t ucs_status = - ucp_ep_rkey_unpack(*ep, PTR_OFFSET(keys, key_offset), - &UCC_TL_UCP_REMOTE_RKEY(ctx, peer, *segment)); - if (UCS_OK != ucs_status) { - return ucs_status_to_ucc_status(ucs_status); + if (dst_memh) { + /* check if segment is in dst memh */ + status = find_tl_index(dst_memh[grank], &tl_index); + if (status == UCC_ERR_NOT_FOUND) { + tl_error(UCC_TL_TEAM_LIB(team), + "attempt to perform one-sided operation with malformed mem map handle"); + return status; + } + status = ucc_tl_ucp_check_memh(ep, grank, peer, va, rva, rkey, tl_index, dst_memh); + if (status == UCC_OK) { + return UCC_OK; } } - *rkey = UCC_TL_UCP_REMOTE_RKEY(ctx, peer, *segment); - *rva = rvas[*segment] + ((uint64_t)va - (uint64_t)team->va_base[*segment]); - return UCC_OK; + /* general error if nothing was found */ + tl_error(UCC_TL_TEAM_LIB(team), + "attempt to perform one-sided operation on non-registered memory %p", va); + return UCC_ERR_NOT_FOUND; } static inline ucc_status_t ucc_tl_ucp_flush(ucc_tl_ucp_team_t *team) diff --git a/src/components/tl/ucp/tl_ucp_task.h b/src/components/tl/ucp/tl_ucp_task.h index 19b137bbe79..91e828bd9ac 100644 --- a/src/components/tl/ucp/tl_ucp_task.h +++ b/src/components/tl/ucp/tl_ucp_task.h @@ -33,9 +33,28 @@ typedef struct ucc_tl_ucp_dpu_offload_buf_info enum ucc_tl_ucp_task_flags { /*indicates whether subset field of tl_ucp_task is set*/ - UCC_TL_UCP_TASK_FLAG_SUBSET = UCC_BIT(0), + UCC_TL_UCP_TASK_FLAG_SUBSET = UCC_BIT(0), + /* indicates usage of dynamic segments */ + UCC_TL_UCP_TASK_FLAG_USE_DYN_SEG = UCC_BIT(1), }; +/* Structure to hold dynamic segment exchange parameters and buffers */ +typedef struct { + ucc_tl_ucp_task_t *task; + void *src_pack_buffer; + void *dst_pack_buffer; + size_t src_pack_size; + size_t dst_pack_size; + size_t max_individual_pack_size; + size_t exchange_size; + ucc_mem_map_memh_t *src_memh_pack; + ucc_mem_map_memh_t *dst_memh_pack; + void *exchange_buffer; + ucc_mem_map_memh_t *src_memh_local; + ucc_mem_map_memh_t *dst_memh_local; + size_t *global_sizes; +} ucc_tl_ucp_dyn_seg_args_t; + typedef struct ucc_tl_ucp_task { ucc_coll_task_t super; uint32_t flags; @@ -223,6 +242,18 @@ typedef struct ucc_tl_ucp_task { }; uint32_t flush_posted; uint32_t flush_completed; + struct { + ucc_mem_map_memh_t *src_local; + ucc_mem_map_memh_t *dst_local; + ucc_mem_map_memh_t **src_global; + ucc_mem_map_memh_t **dst_global; + ucc_tl_ucp_dyn_seg_args_t *exchange_args; + void *global_buffer; + ucc_service_coll_req_t *scoll_req_sizes; /* For sizes allgather */ + ucc_service_coll_req_t *scoll_req_data; /* For data ex allgather */ + int exchange_step; + ucc_tl_ucp_onesided_alg_type alg; + } dynamic_segments; } ucc_tl_ucp_task_t; static inline void ucc_tl_ucp_task_reset(ucc_tl_ucp_task_t *task, diff --git a/src/components/tl/ucp/tl_ucp_team.c b/src/components/tl/ucp/tl_ucp_team.c index 9ff4664650c..d36619a1101 100644 --- a/src/components/tl/ucp/tl_ucp_team.c +++ b/src/components/tl/ucp/tl_ucp_team.c @@ -182,7 +182,6 @@ ucc_status_t ucc_tl_ucp_team_create_test(ucc_base_team_t *tl_team) { ucc_tl_ucp_team_t * team = ucc_derived_of(tl_team, ucc_tl_ucp_team_t); ucc_tl_ucp_context_t *ctx = UCC_TL_UCP_TEAM_CTX(team); - int i; ucc_status_t status; if (USE_SERVICE_WORKER(team)) { @@ -204,13 +203,6 @@ ucc_status_t ucc_tl_ucp_team_create_test(ucc_base_team_t *tl_team) } } - if (ctx->remote_info) { - for (i = 0; i < ctx->n_rinfo_segs; i++) { - team->va_base[i] = ctx->remote_info[i].va_base; - team->base_length[i] = ctx->remote_info[i].len; - } - } - tl_debug(tl_team->context->lib, "initialized tl team: %p", team); team->status = UCC_OK; return UCC_OK; diff --git a/test/gtest/coll/test_alltoall.cc b/test/gtest/coll/test_alltoall.cc index fadce3ba5ca..9a7e2282371 100644 --- a/test/gtest/coll/test_alltoall.cc +++ b/test/gtest/coll/test_alltoall.cc @@ -12,6 +12,11 @@ using Param_1 = std::tupleargs = coll; - coll->mask = 0; + coll->mask = coll_mask; coll->coll_type = UCC_COLL_TYPE_ALLTOALL; coll->src.info.mem_type = mem_type; coll->src.info.count = (ucc_count_t)single_rank_count * nprocs; @@ -52,9 +57,9 @@ class test_alltoall : public UccCollArgs, public ucc::test sbuf = team->procs[i].p->onesided_buf[0]; rbuf = team->procs[i].p->onesided_buf[1]; work_buf = (long *)team->procs[i].p->onesided_buf[2]; - coll->mask = UCC_COLL_ARGS_FIELD_FLAGS | + coll->mask |= UCC_COLL_ARGS_FIELD_FLAGS | UCC_COLL_ARGS_FIELD_GLOBAL_WORK_BUFFER; - coll->flags = UCC_COLL_ARGS_FLAG_MEM_MAPPED_BUFFERS; + coll->flags |= UCC_COLL_ARGS_FLAG_MEM_MAPPED_BUFFERS; coll->src.info.buffer = sbuf; coll->src.info.mem_type = UCC_MEMORY_TYPE_HOST; coll->dst.info.buffer = rbuf; @@ -135,9 +140,8 @@ class test_alltoall : public UccCollArgs, public ucc::test void data_fini_onesided(UccCollCtxVec ctxs) { for (gtest_ucc_coll_ctx_t *ctx : ctxs) { - ucc_coll_args_t *coll = ctx->args; ucc_free(ctx->init_buf); - free(coll); + free(ctx->args); free(ctx); } ctxs.clear(); @@ -246,6 +250,48 @@ UCC_TEST_P(test_alltoall_0, single_onesided) data_fini_onesided(ctxs); } +UCC_TEST_P(test_alltoall_0, single_onesided_dynamic_segment) +{ + const int team_id = std::get<0>(GetParam()); + const ucc_datatype_t dtype = std::get<1>(GetParam()); + ucc_memory_type_t mem_type = std::get<2>(GetParam()); + gtest_ucc_inplace_t inplace = std::get<3>(GetParam()); + const int count = std::get<4>(GetParam()); + UccTeam_h reference_team = UccJob::getStaticTeams()[team_id]; + int size = reference_team->procs.size(); + ucc_job_env_t env = {{"UCC_TL_UCP_TUNE", "alltoall:0-inf:@onesided"}}; + bool is_contig = true; + UccJob job(size, UccJob::UCC_JOB_CTX_GLOBAL_ONESIDED, env); + UccTeam_h team; + std::vector reference_ranks; + UccCollCtxVec ctxs; + + for (auto i = 0; i < reference_team->n_procs; i++) { + int rank = reference_team->procs[i].p->job_rank; + reference_ranks.push_back(rank); + if (is_contig && i > 0 && + (rank - reference_ranks[i - 1] > 1 || + reference_ranks[i - 1] - rank > 1)) { + is_contig = false; + } + } + team = job.create_team(reference_ranks, true, is_contig, true); + this->set_inplace(inplace); + SET_MEM_TYPE(mem_type); + /* for dynamic segments, setup as onesided and override the mask/flags */ + data_init(size, dtype, count, ctxs, team, false); + for (auto i = 0; i < ctxs.size(); i++) { + ctxs[i]->args->mask = UCC_COLL_ARGS_FIELD_GLOBAL_WORK_BUFFER | + (ctxs[i]->args->mask & UCC_COLL_ARGS_FIELD_FLAGS); + ctxs[i]->args->flags &= UCC_COLL_ARGS_FLAG_IN_PLACE; + } + UccReq req(team, ctxs); + req.start(); + req.wait(); + EXPECT_EQ(true, data_validate(ctxs)); + data_fini_onesided(ctxs); +} + UCC_TEST_P(test_alltoall_0, single_persistent) { const int team_id = std::get<0>(GetParam()); diff --git a/test/gtest/common/test_ucc.cc b/test/gtest/common/test_ucc.cc index 4728f2119ac..2673375f157 100644 --- a/test/gtest/common/test_ucc.cc +++ b/test/gtest/common/test_ucc.cc @@ -430,10 +430,10 @@ void proc_context_create(UccProcess_h proc, int id, ThreadAllgather *ta, bool is void proc_context_create_mem_params(UccProcess_h proc, int id, ThreadAllgather *ta) { + ucc_mem_map_t map[UCC_TEST_N_MEM_SEGMENTS] = {}; ucc_status_t status; ucc_context_config_h ctx_config; std::stringstream err_msg; - ucc_mem_map_t map[UCC_TEST_N_MEM_SEGMENTS]; status = ucc_context_config_read(proc->lib_h, NULL, &ctx_config); if (status != UCC_OK) { diff --git a/test/mpi/test_mpi.cc b/test/mpi/test_mpi.cc index 3c1ce77ecc3..feb046c83dd 100644 --- a/test/mpi/test_mpi.cc +++ b/test/mpi/test_mpi.cc @@ -41,11 +41,11 @@ static ucc_status_t oob_allgather_free(void *req) UccTestMpi::UccTestMpi(int argc, char *argv[], ucc_thread_mode_t _tm, int is_local, bool with_onesided) { + ucc_mem_map_t segments[UCC_TEST_N_MEM_SEGMENTS] = {0}; ucc_lib_config_h lib_config; ucc_context_config_h ctx_config; int size, rank; char *prev_env; - ucc_mem_map_t segments[UCC_TEST_N_MEM_SEGMENTS]; MPI_Comm_size(MPI_COMM_WORLD, &size); MPI_Comm_rank(MPI_COMM_WORLD, &rank);