diff --git a/src/uct/cuda/base/cuda_iface.c b/src/uct/cuda/base/cuda_iface.c index 6586efbe536..4f788a2a27e 100644 --- a/src/uct/cuda/base/cuda_iface.c +++ b/src/uct/cuda/base/cuda_iface.c @@ -315,11 +315,63 @@ ucs_status_t uct_cuda_base_iface_flush(uct_iface_h tl_iface, unsigned flags, return UCS_OK; } -void uct_cuda_base_stream_destroy(CUstream *stream) +static ucs_status_t +uct_cuda_base_ctx_rsc_push(const uct_cuda_ctx_rsc_t *ctx_rsc, int *pushed) { - if (*stream != NULL) { - (void)UCT_CUDADRV_FUNC_LOG_WARN(cuStreamDestroy(*stream)); + CUcontext primary_ctx; + ucs_status_t status; + + *pushed = 0; + if (ctx_rsc->primary_ctx == NULL) { + return UCS_OK; } + + status = uct_cuda_ctx_primary_retain(ctx_rsc->cuda_device, 0, + &primary_ctx); + if (status == UCS_ERR_NO_DEVICE) { + return status; + } else if (status != UCS_OK) { + return status; + } + + if (primary_ctx != ctx_rsc->primary_ctx) { + UCT_CUDADRV_FUNC_LOG_WARN( + cuDevicePrimaryCtxRelease(ctx_rsc->cuda_device)); + return UCS_ERR_NO_DEVICE; + } + + status = UCT_CUDADRV_FUNC_LOG_DEBUG(cuCtxPushCurrent(primary_ctx)); + UCT_CUDADRV_FUNC_LOG_WARN(cuDevicePrimaryCtxRelease(ctx_rsc->cuda_device)); + if (status != UCS_OK) { + return status; + } + + *pushed = 1; + return UCS_OK; +} + +static void uct_cuda_base_ctx_rsc_pop(int pushed) +{ + if (pushed) { + UCT_CUDADRV_FUNC_LOG_WARN(cuCtxPopCurrent(NULL)); + } +} + +void uct_cuda_base_stream_destroy(const uct_cuda_ctx_rsc_t *ctx_rsc, + CUstream *stream) +{ + int pushed; + + if (*stream == NULL) { + return; + } + + if (uct_cuda_base_ctx_rsc_push(ctx_rsc, &pushed) != UCS_OK) { + return; + } + + (void)UCT_CUDADRV_FUNC_LOG_WARN(cuStreamDestroy(*stream)); + uct_cuda_base_ctx_rsc_pop(pushed); } static void @@ -334,8 +386,16 @@ uct_cuda_base_event_desc_init(ucs_mpool_t *mp, void *obj, void *chunk) static void uct_cuda_base_event_desc_cleanup(ucs_mpool_t *mp, void *obj) { uct_cuda_event_desc_t *event_desc = obj; + uct_cuda_ctx_rsc_t *ctx_rsc = ucs_container_of(mp, uct_cuda_ctx_rsc_t, + event_mp); + int pushed; + + if (uct_cuda_base_ctx_rsc_push(ctx_rsc, &pushed) != UCS_OK) { + return; + } (void)UCT_CUDADRV_FUNC_LOG_WARN(cuEventDestroy(event_desc->event)); + uct_cuda_base_ctx_rsc_pop(pushed); } void uct_cuda_base_queue_desc_init(uct_cuda_queue_desc_t *qdesc) @@ -353,7 +413,7 @@ void uct_cuda_base_queue_desc_destroy(const uct_cuda_ctx_rsc_t *ctx_rsc, ucs_queue_length(&qdesc->event_queue)); } - uct_cuda_base_stream_destroy(&qdesc->stream); + uct_cuda_base_stream_destroy(ctx_rsc, &qdesc->stream); } static ucs_mpool_ops_t uct_cuda_event_desc_mpool_ops = { @@ -364,6 +424,52 @@ static ucs_mpool_ops_t uct_cuda_event_desc_mpool_ops = { .obj_str = NULL }; +static void uct_cuda_base_ctx_rsc_release_primary_ctx( + uct_cuda_ctx_rsc_t *ctx_rsc) +{ + if (ctx_rsc->primary_ctx == NULL) { + return; + } + + UCT_CUDADRV_FUNC_LOG_WARN(cuDevicePrimaryCtxRelease(ctx_rsc->cuda_device)); + ctx_rsc->primary_ctx = NULL; + ctx_rsc->cuda_device = CU_DEVICE_INVALID; +} + +static ucs_status_t +uct_cuda_base_ctx_rsc_retain_primary_ctx(uct_cuda_ctx_rsc_t *ctx_rsc) +{ + CUcontext primary_ctx; + CUdevice cuda_device; + ucs_status_t status; + + ctx_rsc->primary_ctx = NULL; + ctx_rsc->cuda_device = CU_DEVICE_INVALID; + + status = UCT_CUDADRV_FUNC_LOG_ERR(cuCtxGetDevice(&cuda_device)); + if (status != UCS_OK) { + return status; + } + + status = uct_cuda_ctx_primary_retain(cuda_device, 0, &primary_ctx); + if (status == UCS_ERR_NO_DEVICE) { + return UCS_OK; + } else if (status != UCS_OK) { + return status; + } + + if (primary_ctx != ctx_rsc->ctx) { + UCT_CUDADRV_FUNC_LOG_WARN(cuDevicePrimaryCtxRelease(cuda_device)); + ucs_debug("cuda context %p is not the primary context on device %d", + ctx_rsc->ctx, cuda_device); + return UCS_OK; + } + + ctx_rsc->primary_ctx = primary_ctx; + ctx_rsc->cuda_device = cuda_device; + return UCS_OK; +} + ucs_status_t uct_cuda_base_ctx_rsc_create(uct_cuda_iface_t *iface, unsigned long long ctx_id, uct_cuda_ctx_rsc_t **ctx_rsc_p) @@ -399,6 +505,14 @@ ucs_status_t uct_cuda_base_ctx_rsc_create(uct_cuda_iface_t *iface, goto err_del_iter; } + ctx_rsc->ctx = ctx; + ctx_rsc->ctx_id = ctx_id; + + status = uct_cuda_base_ctx_rsc_retain_primary_ctx(ctx_rsc); + if (status != UCS_OK) { + goto err_free_ctx_rsc; + } + ucs_mpool_params_reset(&mp_params); mp_params.elem_size = iface->config.event_desc_size; mp_params.elems_per_chunk = 128; @@ -408,27 +522,34 @@ ucs_status_t uct_cuda_base_ctx_rsc_create(uct_cuda_iface_t *iface, status = ucs_mpool_init(&mp_params, &ctx_rsc->event_mp); if (status != UCS_OK) { - goto err_free_ctx_rsc; + goto err_release_primary_ctx; } - ctx_rsc->ctx = ctx; - ctx_rsc->ctx_id = ctx_id; kh_value(&iface->ctx_rscs, iter) = ctx_rsc; *ctx_rsc_p = ctx_rsc; return UCS_OK; +err_release_primary_ctx: + uct_cuda_base_ctx_rsc_release_primary_ctx(ctx_rsc); err_free_ctx_rsc: iface->ops->destroy_rsc(&iface->super.super, ctx_rsc); err_del_iter: kh_del(cuda_ctx_rscs, &iface->ctx_rscs, iter); - return UCS_ERR_NO_MEMORY; + return status; } static void uct_cuda_base_ctx_rsc_destroy(uct_cuda_iface_t *iface, uct_cuda_ctx_rsc_t *ctx_rsc) { + CUcontext primary_ctx = ctx_rsc->primary_ctx; + CUdevice cuda_device = ctx_rsc->cuda_device; + ucs_mpool_cleanup(&ctx_rsc->event_mp, 1); iface->ops->destroy_rsc(&iface->super.super, ctx_rsc); + + if (primary_ctx != NULL) { + UCT_CUDADRV_FUNC_LOG_WARN(cuDevicePrimaryCtxRelease(cuda_device)); + } } static ucs_mpool_ops_t uct_cuda_flush_desc_mpool_ops = { diff --git a/src/uct/cuda/base/cuda_iface.h b/src/uct/cuda/base/cuda_iface.h index b17ddfc55f6..0c6b3fc8303 100644 --- a/src/uct/cuda/base/cuda_iface.h +++ b/src/uct/cuda/base/cuda_iface.h @@ -61,6 +61,10 @@ typedef struct { typedef struct { /* CUDA context handle */ CUcontext ctx; + /* Retained CUDA primary context, if @ctx is a primary context */ + CUcontext primary_ctx; + /* CUDA device of @primary_ctx */ + CUdevice cuda_device; /* CUDA context id */ unsigned long long ctx_id; /* pool of cuda events to check completion of memcpy operations */ @@ -129,7 +133,8 @@ void uct_cuda_base_queue_desc_init(uct_cuda_queue_desc_t *qdesc); void uct_cuda_base_queue_desc_destroy(const uct_cuda_ctx_rsc_t *ctx_rsc, uct_cuda_queue_desc_t *qdesc); -void uct_cuda_base_stream_destroy(CUstream *stream); +void uct_cuda_base_stream_destroy(const uct_cuda_ctx_rsc_t *ctx_rsc, + CUstream *stream); #if (__CUDACC_VER_MAJOR__ >= 100000) void CUDA_CB uct_cuda_base_iface_stream_cb_fxn(void *arg); diff --git a/src/uct/cuda/cuda_copy/cuda_copy_iface.c b/src/uct/cuda/cuda_copy/cuda_copy_iface.c index 7b25a6a57f2..d307447edbe 100644 --- a/src/uct/cuda/cuda_copy/cuda_copy_iface.c +++ b/src/uct/cuda/cuda_copy/cuda_copy_iface.c @@ -310,7 +310,7 @@ static void uct_cuda_copy_ctx_rsc_destroy(uct_iface_h tl_iface, } } - uct_cuda_base_stream_destroy(&ctx_rsc->short_stream); + uct_cuda_base_stream_destroy(cuda_ctx_rsc, &ctx_rsc->short_stream); ucs_free(ctx_rsc); } diff --git a/test/gtest/Makefile.am b/test/gtest/Makefile.am index f30fc8d9a0b..0c4b2af4baf 100644 --- a/test/gtest/Makefile.am +++ b/test/gtest/Makefile.am @@ -267,6 +267,7 @@ endif # HAVE_IB if HAVE_CUDA gtest_SOURCES += \ ucm/cuda_hooks.cc \ + uct/cuda/test_cuda_ctx_cleanup.cc \ uct/cuda/test_switch_cuda_device.cc \ uct/cuda/test_cuda_ipc_md.cc \ uct/cuda/test_cuda_nvml.cc diff --git a/test/gtest/uct/cuda/test_cuda_ctx_cleanup.cc b/test/gtest/uct/cuda/test_cuda_ctx_cleanup.cc new file mode 100644 index 00000000000..09fafa8aed6 --- /dev/null +++ b/test/gtest/uct/cuda/test_cuda_ctx_cleanup.cc @@ -0,0 +1,280 @@ +/** + * Copyright (c) NVIDIA CORPORATION & AFFILIATES, 2026. ALL RIGHTS RESERVED. + * + * See file LICENSE for terms. + */ + +#include + +extern "C" { +#include +#include +} + +#include + + +class test_cuda_ctx_cleanup : public ucs::test { +protected: + void init() override + { + int num_devices; + + ucs::test::init(); + + ASSERT_EQ(CUDA_SUCCESS, cuInit(0)); + ASSERT_EQ(CUDA_SUCCESS, cuDeviceGetCount(&num_devices)); + if (num_devices == 0) { + UCS_TEST_SKIP_R("no cuda devices available"); + } + + ASSERT_EQ(CUDA_SUCCESS, cuDeviceGet(&m_device, 0)); + init_iface(); + push_primary_ctx(); + } + + void cleanup() override + { + destroy_ctx_rsc(); + + kh_destroy_inplace(cuda_ctx_rscs, &m_iface.ctx_rscs); + + if (m_ctx_pushed) { + EXPECT_EQ(CUDA_SUCCESS, cuCtxPopCurrent(NULL)); + m_ctx_pushed = false; + } + + if (m_ctx_retained) { + EXPECT_EQ(CUDA_SUCCESS, cuDevicePrimaryCtxRelease(m_device)); + m_ctx_retained = false; + } + + restore_current_cuda_context(m_restore_current_ctx); + m_restore_current_ctx = 0; + + ucs::test::cleanup(); + } + + void create_ctx_rsc() + { + ASSERT_UCS_OK(uct_cuda_base_ctx_rsc_create(&m_iface, 1, &m_ctx_rsc)); + } + + void release_ctx_rsc_primary_ctx() + { + if ((m_ctx_rsc == NULL) || (m_ctx_rsc->primary_ctx == NULL)) { + return; + } + + EXPECT_EQ(CUDA_SUCCESS, + cuDevicePrimaryCtxRelease(m_ctx_rsc->cuda_device)); + m_ctx_rsc->primary_ctx = NULL; + m_ctx_rsc->cuda_device = CU_DEVICE_INVALID; + } + + void free_ctx_rsc() + { + if (m_ctx_rsc == NULL) { + return; + } + + release_ctx_rsc_primary_ctx(); + destroy_rsc(NULL, m_ctx_rsc); + m_ctx_rsc = NULL; + } + + void destroy_ctx_rsc() + { + if (m_ctx_rsc == NULL) { + return; + } + + ucs_mpool_cleanup(&m_ctx_rsc->event_mp, 1); + free_ctx_rsc(); + } + + void release_primary_ctx() + { + CUcontext cuda_ctx; + + ASSERT_TRUE(m_ctx_pushed); + ASSERT_EQ(CUDA_SUCCESS, cuCtxPopCurrent(&cuda_ctx)); + ASSERT_EQ(m_cuda_ctx, cuda_ctx); + m_ctx_pushed = false; + + ASSERT_TRUE(m_ctx_retained); + ASSERT_EQ(CUDA_SUCCESS, cuDevicePrimaryCtxRelease(m_device)); + m_ctx_retained = false; + } + + void pop_current_cuda_contexts(int *popped) + { + CUcontext cuda_ctx; + + *popped = 0; + for (;;) { + ASSERT_EQ(CUDA_SUCCESS, cuCtxGetCurrent(&cuda_ctx)); + if (cuda_ctx == NULL) { + return; + } + + ASSERT_EQ(CUDA_SUCCESS, cuCtxPopCurrent(NULL)); + *popped = 1; + } + } + + void restore_current_cuda_context(int popped) + { + CUcontext cuda_ctx; + + if (!popped) { + return; + } + + /* Restore the gtest CUDA guard context that was current before reset. */ + ASSERT_EQ(CUDA_SUCCESS, cuDevicePrimaryCtxRetain(&cuda_ctx, m_device)); + ASSERT_EQ(CUDA_SUCCESS, cuCtxPushCurrent(cuda_ctx)); + } + + void reset_primary_ctx() + { + CUcontext cuda_ctx; + int popped; + + ASSERT_TRUE(m_ctx_pushed); + ASSERT_EQ(CUDA_SUCCESS, cuCtxPopCurrent(&cuda_ctx)); + ASSERT_EQ(m_cuda_ctx, cuda_ctx); + m_ctx_pushed = false; + + /* Avoid resetting a primary context which is still current. */ + pop_current_cuda_contexts(&popped); + ASSERT_EQ(CUDA_SUCCESS, cuDevicePrimaryCtxReset(m_device)); + m_restore_current_ctx = popped; + /* + * cuDevicePrimaryCtxReset() does not release the retain taken by + * push_primary_ctx(), so cleanup() should still release it. + */ + ASSERT_TRUE(m_ctx_retained); + } + + static uct_cuda_ctx_rsc_t *create_rsc(uct_iface_h iface) + { + return static_cast( + ucs_calloc(1, sizeof(uct_cuda_ctx_rsc_t), + "test_cuda_ctx_rsc")); + } + + static void destroy_rsc(uct_iface_h iface, uct_cuda_ctx_rsc_t *ctx_rsc) + { + ucs_free(ctx_rsc); + } + + static void complete_event(uct_iface_h iface, uct_cuda_event_desc_t *event) + { + } + + uct_cuda_iface_t m_iface = {}; + uct_cuda_ctx_rsc_t *m_ctx_rsc = NULL; + CUcontext m_cuda_ctx = NULL; + CUdevice m_device = 0; + bool m_ctx_pushed = false; + bool m_ctx_retained = false; + int m_restore_current_ctx = 0; + +private: + void init_iface() + { + static uct_cuda_iface_ops_t iface_ops = { + create_rsc, + destroy_rsc, + complete_event + }; + + m_iface.ops = &iface_ops; + m_iface.config.event_desc_size = sizeof(uct_cuda_event_desc_t); + m_iface.config.max_events = 128; + kh_init_inplace(cuda_ctx_rscs, &m_iface.ctx_rscs); + } + + void push_primary_ctx() + { + ASSERT_EQ(CUDA_SUCCESS, + cuDevicePrimaryCtxRetain(&m_cuda_ctx, m_device)); + m_ctx_retained = true; + ASSERT_EQ(CUDA_SUCCESS, cuCtxPushCurrent(m_cuda_ctx)); + m_ctx_pushed = true; + } +}; + +UCS_TEST_F(test_cuda_ctx_cleanup, retain_primary_ctx_until_rsc_cleanup) +{ + create_ctx_rsc(); + EXPECT_EQ(m_cuda_ctx, m_ctx_rsc->primary_ctx); + EXPECT_EQ(m_device, m_ctx_rsc->cuda_device); + + destroy_ctx_rsc(); +} + +UCS_TEST_F(test_cuda_ctx_cleanup, event_cleanup_after_primary_ctx_release) +{ + uct_cuda_event_desc_t *event_desc; + + create_ctx_rsc(); + + event_desc = static_cast( + ucs_mpool_get(&m_ctx_rsc->event_mp)); + ASSERT_NE(nullptr, event_desc); + ucs_mpool_put(event_desc); + + release_primary_ctx(); + + ucs_mpool_cleanup(&m_ctx_rsc->event_mp, 1); + free_ctx_rsc(); +} + +UCS_TEST_F(test_cuda_ctx_cleanup, stream_cleanup_after_primary_ctx_release) +{ + uct_cuda_queue_desc_t qdesc; + + create_ctx_rsc(); + + uct_cuda_base_queue_desc_init(&qdesc); + ASSERT_UCS_OK(uct_cuda_base_init_stream(&qdesc.stream)); + + release_primary_ctx(); + + uct_cuda_base_queue_desc_destroy(m_ctx_rsc, &qdesc); + destroy_ctx_rsc(); +} + +UCS_TEST_F(test_cuda_ctx_cleanup, event_cleanup_after_primary_ctx_reset) +{ + uct_cuda_event_desc_t *event_desc; + + create_ctx_rsc(); + + event_desc = static_cast( + ucs_mpool_get(&m_ctx_rsc->event_mp)); + ASSERT_NE(nullptr, event_desc); + ucs_mpool_put(event_desc); + + reset_primary_ctx(); + + ucs_mpool_cleanup(&m_ctx_rsc->event_mp, 1); + free_ctx_rsc(); +} + +UCS_TEST_F(test_cuda_ctx_cleanup, stream_cleanup_after_primary_ctx_reset) +{ + uct_cuda_queue_desc_t qdesc; + + create_ctx_rsc(); + + uct_cuda_base_queue_desc_init(&qdesc); + ASSERT_UCS_OK(uct_cuda_base_init_stream(&qdesc.stream)); + + reset_primary_ctx(); + + uct_cuda_base_queue_desc_destroy(m_ctx_rsc, &qdesc); + destroy_ctx_rsc(); +}