-
Notifications
You must be signed in to change notification settings - Fork 128
TL/UCP: add support for onesided dynamic segments #1149
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 2 commits
bfe39bd
2718ac6
361bc56
d95b505
15a9f44
348ac56
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,6 +9,7 @@ | |
| #include "alltoall.h" | ||
| #include "core/ucc_progress_queue.h" | ||
| #include "utils/ucc_math.h" | ||
| #include "tl_ucp_coll.h" | ||
| #include "tl_ucp_sendrecv.h" | ||
|
|
||
| #define CONGESTION_THRESHOLD 8 | ||
|
|
@@ -100,13 +101,29 @@ void ucc_tl_ucp_alltoall_onesided_get_progress(ucc_coll_task_t *ctask) | |
| ucc_rank_t peer = (grank + *posted + 1) % gsize; | ||
| ucc_mem_map_mem_h src_memh; | ||
| size_t nelems; | ||
| ucc_status_t status; | ||
|
|
||
| nelems = TASK_ARGS(task).src.info.count; | ||
| nelems = (nelems / gsize) * ucc_dt_size(TASK_ARGS(task).src.info.datatype); | ||
| src_memh = (TASK_ARGS(task).flags & UCC_COLL_ARGS_FLAG_DST_MEMH_GLOBAL) | ||
| if (task->flags & UCC_TL_UCP_TASK_FLAG_USE_DYN_SEG) { | ||
| status = ucc_tl_ucp_test_dynamic_segment(task); | ||
| if (status == UCC_INPROGRESS) { | ||
| return; | ||
| } | ||
| if (UCC_OK != status) { | ||
| task->super.status = status; | ||
| tl_error(UCC_TL_TEAM_LIB(team), | ||
| "failed to exchange dynamic segments"); | ||
| return; | ||
| } | ||
| src_memh = task->dynamic_segments.dst_local; | ||
| dst_memh = (ucc_mem_map_mem_h *)task->dynamic_segments.src_global; | ||
| } else { | ||
| src_memh = (TASK_ARGS(task).flags & UCC_COLL_ARGS_FLAG_DST_MEMH_GLOBAL) | ||
| ? TASK_ARGS(task).dst_memh.global_memh[grank] | ||
| : TASK_ARGS(task).dst_memh.local_memh; | ||
| } | ||
|
|
||
| nelems = TASK_ARGS(task).src.info.count; | ||
| nelems = (nelems / gsize) * ucc_dt_size(TASK_ARGS(task).src.info.datatype); | ||
| for (; *posted < gsize; peer = (peer + 1) % gsize) { | ||
| UCPCHECK_GOTO(ucc_tl_ucp_get_nb(PTR_OFFSET(dest, peer * nelems), | ||
| PTR_OFFSET(src, grank * nelems), | ||
|
|
@@ -122,7 +139,10 @@ void ucc_tl_ucp_alltoall_onesided_get_progress(ucc_coll_task_t *ctask) | |
|
|
||
| alltoall_onesided_wait_completion(task, npolls); | ||
| out: | ||
| return; | ||
| if (task->super.status != UCC_INPROGRESS && | ||
| (task->flags & UCC_TL_UCP_TASK_FLAG_USE_DYN_SEG)) { | ||
| task->super.status = ucc_tl_ucp_coll_dynamic_segment_finalize(task); | ||
| } | ||
|
Comment on lines
+148
to
+151
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. logic: dynamic segment finalization happens after completion/error but the finalize call's error status overwrites the original task status. If the task completed successfully (UCC_OK) but finalization fails, the error is propagated; however, if the task already failed, the finalization error replaces it, losing the original failure reason. Should finalization errors be logged separately while preserving the original task failure status? |
||
| } | ||
|
|
||
| void ucc_tl_ucp_alltoall_onesided_put_progress(ucc_coll_task_t *ctask) | ||
|
|
@@ -142,12 +162,28 @@ void ucc_tl_ucp_alltoall_onesided_put_progress(ucc_coll_task_t *ctask) | |
| ucc_rank_t peer = (grank + *posted + 1) % gsize; | ||
| ucc_mem_map_mem_h src_memh; | ||
| size_t nelems; | ||
| ucc_status_t status; | ||
|
|
||
| nelems = TASK_ARGS(task).src.info.count; | ||
| nelems = (nelems / gsize) * ucc_dt_size(TASK_ARGS(task).src.info.datatype); | ||
| src_memh = (TASK_ARGS(task).flags & UCC_COLL_ARGS_FLAG_SRC_MEMH_GLOBAL) | ||
| if (task->flags & UCC_TL_UCP_TASK_FLAG_USE_DYN_SEG) { | ||
| status = ucc_tl_ucp_test_dynamic_segment(task); | ||
| if (status == UCC_INPROGRESS) { | ||
| return; | ||
| } | ||
| if (UCC_OK != status) { | ||
| task->super.status = status; | ||
| tl_error(UCC_TL_TEAM_LIB(team), | ||
| "failed to exchange dynamic segments"); | ||
| return; | ||
| } | ||
| src_memh = task->dynamic_segments.src_local; | ||
| dst_memh = (ucc_mem_map_mem_h *)task->dynamic_segments.dst_global; | ||
| } else { | ||
| src_memh = (TASK_ARGS(task).flags & UCC_COLL_ARGS_FLAG_SRC_MEMH_GLOBAL) | ||
| ? TASK_ARGS(task).src_memh.global_memh[grank] | ||
| : TASK_ARGS(task).src_memh.local_memh; | ||
| } | ||
|
|
||
| for (; *posted < gsize; peer = (peer + 1) % gsize) { | ||
| UCPCHECK_GOTO( | ||
|
|
@@ -165,15 +201,30 @@ void ucc_tl_ucp_alltoall_onesided_put_progress(ucc_coll_task_t *ctask) | |
|
|
||
| alltoall_onesided_wait_completion(task, npolls); | ||
| out: | ||
| return; | ||
| if (task->super.status != UCC_INPROGRESS && | ||
| (task->flags & UCC_TL_UCP_TASK_FLAG_USE_DYN_SEG)) { | ||
| task->super.status = ucc_tl_ucp_coll_dynamic_segment_finalize(task); | ||
| } | ||
|
Comment on lines
+210
to
+213
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. logic: same issue as GET progress: finalization status overwrites task status |
||
| } | ||
|
|
||
| ucc_status_t ucc_tl_ucp_alltoall_onesided_start(ucc_coll_task_t *ctask) | ||
| { | ||
| ucc_tl_ucp_task_t *task = ucc_derived_of(ctask, ucc_tl_ucp_task_t); | ||
| ucc_tl_ucp_team_t *team = TASK_TEAM(task); | ||
| ucc_status_t status; | ||
|
|
||
| ucc_tl_ucp_task_reset(task, UCC_INPROGRESS); | ||
| if (task->flags & UCC_TL_UCP_TASK_FLAG_USE_DYN_SEG) { | ||
| status = ucc_tl_ucp_coll_dynamic_segment_exchange_nb(task); | ||
| if (UCC_OK != status && UCC_INPROGRESS != status) { | ||
| task->super.status = status; | ||
| tl_error(UCC_TL_TEAM_LIB(team), | ||
| "failed to exchange dynamic segments"); | ||
| return task->super.status; | ||
| } | ||
| } | ||
|
Comment on lines
+223
to
+231
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. logic: when dynamic segment exchange completes synchronously (returns UCC_OK), the function continues to line 223 and enqueues the task. However, the progress functions (GET/PUT) will call |
||
|
|
||
| /* Start the onesided operations */ | ||
| return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super); | ||
| } | ||
|
|
||
|
|
@@ -190,7 +241,7 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_init(ucc_base_coll_args_t *coll_args, | |
| }; | ||
| size_t perc_bw = | ||
| UCC_TL_UCP_TEAM_LIB(tl_team)->cfg.alltoall_onesided_percent_bw; | ||
| ucc_tl_ucp_alltoall_onesided_alg_t alg = | ||
| ucc_tl_ucp_onesided_alg_type alg = | ||
| UCC_TL_UCP_TEAM_LIB(tl_team)->cfg.alltoall_onesided_alg; | ||
| ucc_tl_ucp_schedule_t *tl_schedule = NULL; | ||
| ucc_rank_t group_size = 1; | ||
|
|
@@ -208,15 +259,6 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_init(ucc_base_coll_args_t *coll_args, | |
| ucc_sbgp_t *sbgp; | ||
|
|
||
| ALLTOALL_TASK_CHECK(coll_args->args, tl_team); | ||
| if (!(coll_args->args.mask & UCC_COLL_ARGS_FIELD_FLAGS) || | ||
| (coll_args->args.mask & UCC_COLL_ARGS_FIELD_FLAGS && | ||
| (!(coll_args->args.flags & | ||
| UCC_COLL_ARGS_FLAG_MEM_MAPPED_BUFFERS)))) { | ||
| tl_error(UCC_TL_TEAM_LIB(tl_team), | ||
| "non memory mapped buffers are not supported"); | ||
| status = UCC_ERR_NOT_SUPPORTED; | ||
| return status; | ||
| } | ||
|
|
||
| if (!(coll_args->args.mask & UCC_COLL_ARGS_FIELD_MEM_MAP_SRC_MEMH)) { | ||
| coll_args->args.src_memh.global_memh = NULL; | ||
|
|
@@ -228,7 +270,6 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_init(ucc_base_coll_args_t *coll_args, | |
| return status; | ||
| } | ||
| } | ||
|
|
||
| if (!(coll_args->args.mask & UCC_COLL_ARGS_FIELD_MEM_MAP_DST_MEMH)) { | ||
| coll_args->args.dst_memh.global_memh = NULL; | ||
| } else { | ||
|
|
@@ -239,6 +280,8 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_init(ucc_base_coll_args_t *coll_args, | |
| return status; | ||
| } | ||
| } | ||
|
|
||
|
|
||
| status = ucc_tl_ucp_get_schedule(tl_team, coll_args, | ||
| (ucc_tl_ucp_schedule_t **)&tl_schedule); | ||
| if (ucc_unlikely(UCC_OK != status)) { | ||
|
|
@@ -264,6 +307,20 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_init(ucc_base_coll_args_t *coll_args, | |
| task->super.finalize = ucc_tl_ucp_alltoall_onesided_finalize; | ||
| a2a_task = &task->super; | ||
|
|
||
| /* initialize dynamic segments */ | ||
| if (alg == UCC_TL_UCP_ALLTOALL_ONESIDED_GET || | ||
| (alg == UCC_TL_UCP_ALLTOALL_ONESIDED_AUTO && | ||
| sbgp->group_size >= CONGESTION_THRESHOLD)) { | ||
| alg = UCC_TL_UCP_ALLTOALL_ONESIDED_GET; | ||
| } | ||
| status = ucc_tl_ucp_coll_dynamic_segment_init(&coll_args->args, alg, task); | ||
| if (UCC_OK != status) { | ||
| tl_error(UCC_TL_TEAM_LIB(tl_team), | ||
| "failed to initialize dynamic segments"); | ||
| ucc_tl_ucp_coll_finalize(&task->super); | ||
| goto out; | ||
| } | ||
|
Comment on lines
+328
to
+336
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. logic: if dynamic segment initialization fails, the code jumps to |
||
|
|
||
| status = ucc_tl_ucp_coll_init(&barrier_coll_args, team, &barrier_task); | ||
| if (status != UCC_OK) { | ||
| goto out; | ||
|
|
@@ -313,6 +370,7 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_init(ucc_base_coll_args_t *coll_args, | |
| ucc_task_subscribe_dep(a2a_task, barrier_task, | ||
| UCC_EVENT_COMPLETED); | ||
| *task_h = &schedule->super; | ||
|
|
||
| return status; | ||
| out: | ||
| if (tl_schedule) { | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -11,6 +11,8 @@ | |||||||||||||||||
| #include "utils/ucc_math.h" | ||||||||||||||||||
| #include "tl_ucp_sendrecv.h" | ||||||||||||||||||
|
|
||||||||||||||||||
| void ucc_tl_ucp_alltoallv_onesided_progress(ucc_coll_task_t *ctask); | ||||||||||||||||||
|
|
||||||||||||||||||
| ucc_status_t ucc_tl_ucp_alltoallv_onesided_start(ucc_coll_task_t *ctask) | ||||||||||||||||||
| { | ||||||||||||||||||
| ucc_tl_ucp_task_t *task = ucc_derived_of(ctask, ucc_tl_ucp_task_t); | ||||||||||||||||||
|
|
@@ -32,6 +34,12 @@ ucc_status_t ucc_tl_ucp_alltoallv_onesided_start(ucc_coll_task_t *ctask) | |||||||||||||||||
|
|
||||||||||||||||||
| ucc_tl_ucp_task_reset(task, UCC_INPROGRESS); | ||||||||||||||||||
|
|
||||||||||||||||||
| if (TASK_ARGS(task).mask & UCC_COLL_ARGS_FIELD_MEM_MAP_SRC_MEMH) { | ||||||||||||||||||
| if (TASK_ARGS(task).flags & UCC_COLL_ARGS_FLAG_SRC_MEMH_GLOBAL) { | ||||||||||||||||||
| src_memh = TASK_ARGS(task).src_memh.global_memh[grank]; | ||||||||||||||||||
| } | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| /* perform a put to each member peer using the peer's index in the | ||||||||||||||||||
| * destination displacement. */ | ||||||||||||||||||
| for (peer = (grank + 1) % gsize; task->onesided.put_posted < gsize; | ||||||||||||||||||
|
|
@@ -43,8 +51,8 @@ ucc_status_t ucc_tl_ucp_alltoallv_onesided_start(ucc_coll_task_t *ctask) | |||||||||||||||||
| ucc_coll_args_get_displacement(&TASK_ARGS(task), d_disp, peer) * | ||||||||||||||||||
| rdt_size; | ||||||||||||||||||
| data_size = | ||||||||||||||||||
| ucc_coll_args_get_count( | ||||||||||||||||||
| &TASK_ARGS(task), TASK_ARGS(task).src.info_v.counts, peer) * | ||||||||||||||||||
| ucc_coll_args_get_count(&TASK_ARGS(task), | ||||||||||||||||||
| TASK_ARGS(task).src.info_v.counts, peer) * | ||||||||||||||||||
| sdt_size; | ||||||||||||||||||
|
|
||||||||||||||||||
| UCPCHECK_GOTO(ucc_tl_ucp_put_nb(PTR_OFFSET(src, sd_disp), | ||||||||||||||||||
|
|
@@ -56,22 +64,22 @@ ucc_status_t ucc_tl_ucp_alltoallv_onesided_start(ucc_coll_task_t *ctask) | |||||||||||||||||
| dst_memh, team), | ||||||||||||||||||
| task, out); | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super); | ||||||||||||||||||
| out: | ||||||||||||||||||
| return task->super.status; | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| void ucc_tl_ucp_alltoallv_onesided_progress(ucc_coll_task_t *ctask) | ||||||||||||||||||
| { | ||||||||||||||||||
| ucc_tl_ucp_task_t *task = ucc_derived_of(ctask, ucc_tl_ucp_task_t); | ||||||||||||||||||
| ucc_tl_ucp_team_t *team = TASK_TEAM(task); | ||||||||||||||||||
| ucc_rank_t gsize = UCC_TL_TEAM_SIZE(team); | ||||||||||||||||||
| long *pSync = TASK_ARGS(task).global_work_buffer; | ||||||||||||||||||
| ucc_tl_ucp_task_t *task = ucc_derived_of(ctask, ucc_tl_ucp_task_t); | ||||||||||||||||||
| ucc_tl_ucp_team_t *team = TASK_TEAM(task); | ||||||||||||||||||
| long *pSync = TASK_ARGS(task).global_work_buffer; | ||||||||||||||||||
| ucc_rank_t gsize = UCC_TL_TEAM_SIZE(team); | ||||||||||||||||||
|
|
||||||||||||||||||
| if (ucc_tl_ucp_test_onesided(task, gsize) == UCC_INPROGRESS) { | ||||||||||||||||||
| return; | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| pSync[0] = 0; | ||||||||||||||||||
| task->super.status = UCC_OK; | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
@@ -81,8 +89,8 @@ ucc_status_t ucc_tl_ucp_alltoallv_onesided_init(ucc_base_coll_args_t *coll_args, | |||||||||||||||||
| ucc_coll_task_t **task_h) | ||||||||||||||||||
| { | ||||||||||||||||||
| ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t); | ||||||||||||||||||
| ucc_status_t status = UCC_OK; | ||||||||||||||||||
| ucc_tl_ucp_task_t *task; | ||||||||||||||||||
| ucc_status_t status; | ||||||||||||||||||
|
|
||||||||||||||||||
| ALLTOALLV_TASK_CHECK(coll_args->args, tl_team); | ||||||||||||||||||
| if (!(coll_args->args.mask & UCC_COLL_ARGS_FIELD_GLOBAL_WORK_BUFFER)) { | ||||||||||||||||||
|
|
@@ -104,13 +112,20 @@ ucc_status_t ucc_tl_ucp_alltoallv_onesided_init(ucc_base_coll_args_t *coll_args, | |||||||||||||||||
| } | ||||||||||||||||||
| if (!(coll_args->args.mask & UCC_COLL_ARGS_FIELD_MEM_MAP_DST_MEMH)) { | ||||||||||||||||||
| coll_args->args.dst_memh.global_memh = NULL; | ||||||||||||||||||
| } else { | ||||||||||||||||||
| if (!(coll_args->args.flags & UCC_COLL_ARGS_FLAG_DST_MEMH_GLOBAL)) { | ||||||||||||||||||
| tl_error(UCC_TL_TEAM_LIB(tl_team), | ||||||||||||||||||
| "onesided alltoallv requires global memory handles for dst " | ||||||||||||||||||
| "buffers"); | ||||||||||||||||||
| status = UCC_ERR_INVALID_PARAM; | ||||||||||||||||||
| goto out; | ||||||||||||||||||
| } | ||||||||||||||||||
| } | ||||||||||||||||||
|
Comment on lines
110
to
123
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. style: the destination memory-handle check (lines 106–116) has now moved after the global work buffer check; the source handle (lines 103–105) is still set to NULL if missing. Consider validating that the source memory-handle flag is also set and is global when provided (similar to the destination logic) to avoid silent failures later. Is there a reason source memory handles are always accepted as local or missing, while destination handles require the global flag? Should source handles also be validated when present? |
||||||||||||||||||
|
|
||||||||||||||||||
| task = ucc_tl_ucp_init_task(coll_args, team); | ||||||||||||||||||
| *task_h = &task->super; | ||||||||||||||||||
|
||||||||||||||||||
| task = ucc_tl_ucp_init_task(coll_args, team); | |
| *task_h = &task->super; | |
| task = ucc_tl_ucp_init_task(coll_args, team); | |
| if (ucc_unlikely(!task)) { | |
| status = UCC_ERR_NO_MEMORY; | |
| goto out; | |
| } | |
| *task_h = &task->super; |
Uh oh!
There was an error while loading. Please reload this page.