Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 129 additions & 8 deletions src/uct/cuda/base/cuda_iface.c
Original file line number Diff line number Diff line change
Expand Up @@ -315,11 +315,63 @@ ucs_status_t uct_cuda_base_iface_flush(uct_iface_h tl_iface, unsigned flags,
return UCS_OK;
}

void uct_cuda_base_stream_destroy(CUstream *stream)
static ucs_status_t
uct_cuda_base_ctx_rsc_push(const uct_cuda_ctx_rsc_t *ctx_rsc, int *pushed)
{
if (*stream != NULL) {
(void)UCT_CUDADRV_FUNC_LOG_WARN(cuStreamDestroy(*stream));
CUcontext primary_ctx;
ucs_status_t status;

*pushed = 0;
if (ctx_rsc->primary_ctx == NULL) {
return UCS_OK;
}

status = uct_cuda_ctx_primary_retain(ctx_rsc->cuda_device, 0,
&primary_ctx);
if (status == UCS_ERR_NO_DEVICE) {
return status;
} else if (status != UCS_OK) {
return status;
}

if (primary_ctx != ctx_rsc->primary_ctx) {
UCT_CUDADRV_FUNC_LOG_WARN(
cuDevicePrimaryCtxRelease(ctx_rsc->cuda_device));
return UCS_ERR_NO_DEVICE;
}

status = UCT_CUDADRV_FUNC_LOG_DEBUG(cuCtxPushCurrent(primary_ctx));
UCT_CUDADRV_FUNC_LOG_WARN(cuDevicePrimaryCtxRelease(ctx_rsc->cuda_device));
if (status != UCS_OK) {
return status;
}

*pushed = 1;
return UCS_OK;
}

static void uct_cuda_base_ctx_rsc_pop(int pushed)
{
if (pushed) {
UCT_CUDADRV_FUNC_LOG_WARN(cuCtxPopCurrent(NULL));
}
}

void uct_cuda_base_stream_destroy(const uct_cuda_ctx_rsc_t *ctx_rsc,
CUstream *stream)
{
int pushed;

if (*stream == NULL) {
return;
}

if (uct_cuda_base_ctx_rsc_push(ctx_rsc, &pushed) != UCS_OK) {
return;
}

(void)UCT_CUDADRV_FUNC_LOG_WARN(cuStreamDestroy(*stream));
uct_cuda_base_ctx_rsc_pop(pushed);
}

static void
Expand All @@ -334,8 +386,16 @@ uct_cuda_base_event_desc_init(ucs_mpool_t *mp, void *obj, void *chunk)
static void uct_cuda_base_event_desc_cleanup(ucs_mpool_t *mp, void *obj)
{
uct_cuda_event_desc_t *event_desc = obj;
uct_cuda_ctx_rsc_t *ctx_rsc = ucs_container_of(mp, uct_cuda_ctx_rsc_t,
event_mp);
int pushed;

if (uct_cuda_base_ctx_rsc_push(ctx_rsc, &pushed) != UCS_OK) {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really need to push the context here? Can we just check that retained primary context has the same ID as one that was stored during uct_cuda_ctx_rsc_t creation?

return;
}

(void)UCT_CUDADRV_FUNC_LOG_WARN(cuEventDestroy(event_desc->event));
uct_cuda_base_ctx_rsc_pop(pushed);
}

void uct_cuda_base_queue_desc_init(uct_cuda_queue_desc_t *qdesc)
Expand All @@ -353,7 +413,7 @@ void uct_cuda_base_queue_desc_destroy(const uct_cuda_ctx_rsc_t *ctx_rsc,
ucs_queue_length(&qdesc->event_queue));
}

uct_cuda_base_stream_destroy(&qdesc->stream);
uct_cuda_base_stream_destroy(ctx_rsc, &qdesc->stream);
}

static ucs_mpool_ops_t uct_cuda_event_desc_mpool_ops = {
Expand All @@ -364,6 +424,52 @@ static ucs_mpool_ops_t uct_cuda_event_desc_mpool_ops = {
.obj_str = NULL
};

static void uct_cuda_base_ctx_rsc_release_primary_ctx(
uct_cuda_ctx_rsc_t *ctx_rsc)
{
if (ctx_rsc->primary_ctx == NULL) {
return;
}

UCT_CUDADRV_FUNC_LOG_WARN(cuDevicePrimaryCtxRelease(ctx_rsc->cuda_device));
ctx_rsc->primary_ctx = NULL;
ctx_rsc->cuda_device = CU_DEVICE_INVALID;
}

static ucs_status_t
uct_cuda_base_ctx_rsc_retain_primary_ctx(uct_cuda_ctx_rsc_t *ctx_rsc)
{
CUcontext primary_ctx;
CUdevice cuda_device;
ucs_status_t status;

ctx_rsc->primary_ctx = NULL;
ctx_rsc->cuda_device = CU_DEVICE_INVALID;

status = UCT_CUDADRV_FUNC_LOG_ERR(cuCtxGetDevice(&cuda_device));
if (status != UCS_OK) {
return status;
}

status = uct_cuda_ctx_primary_retain(cuda_device, 0, &primary_ctx);
if (status == UCS_ERR_NO_DEVICE) {
return UCS_OK;
} else if (status != UCS_OK) {
return status;
}

if (primary_ctx != ctx_rsc->ctx) {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's safer to use context ID to check that contexts are the same: cuCtxGetId. The method was added in CUDA 12 to solve the problem when newly created context could have the same pointer address as the previously deleted one.

UCT_CUDADRV_FUNC_LOG_WARN(cuDevicePrimaryCtxRelease(cuda_device));
ucs_debug("cuda context %p is not the primary context on device %d",
ctx_rsc->ctx, cuda_device);
return UCS_OK;
}

ctx_rsc->primary_ctx = primary_ctx;
ctx_rsc->cuda_device = cuda_device;
return UCS_OK;
}

ucs_status_t uct_cuda_base_ctx_rsc_create(uct_cuda_iface_t *iface,
unsigned long long ctx_id,
uct_cuda_ctx_rsc_t **ctx_rsc_p)
Expand Down Expand Up @@ -399,6 +505,14 @@ ucs_status_t uct_cuda_base_ctx_rsc_create(uct_cuda_iface_t *iface,
goto err_del_iter;
}

ctx_rsc->ctx = ctx;
ctx_rsc->ctx_id = ctx_id;

status = uct_cuda_base_ctx_rsc_retain_primary_ctx(ctx_rsc);
if (status != UCS_OK) {
goto err_free_ctx_rsc;
}

ucs_mpool_params_reset(&mp_params);
mp_params.elem_size = iface->config.event_desc_size;
mp_params.elems_per_chunk = 128;
Expand All @@ -408,27 +522,34 @@ ucs_status_t uct_cuda_base_ctx_rsc_create(uct_cuda_iface_t *iface,

status = ucs_mpool_init(&mp_params, &ctx_rsc->event_mp);
if (status != UCS_OK) {
goto err_free_ctx_rsc;
goto err_release_primary_ctx;
}

ctx_rsc->ctx = ctx;
ctx_rsc->ctx_id = ctx_id;
kh_value(&iface->ctx_rscs, iter) = ctx_rsc;
*ctx_rsc_p = ctx_rsc;
return UCS_OK;

err_release_primary_ctx:
uct_cuda_base_ctx_rsc_release_primary_ctx(ctx_rsc);
err_free_ctx_rsc:
iface->ops->destroy_rsc(&iface->super.super, ctx_rsc);
err_del_iter:
kh_del(cuda_ctx_rscs, &iface->ctx_rscs, iter);
return UCS_ERR_NO_MEMORY;
return status;
}

static void uct_cuda_base_ctx_rsc_destroy(uct_cuda_iface_t *iface,
uct_cuda_ctx_rsc_t *ctx_rsc)
{
CUcontext primary_ctx = ctx_rsc->primary_ctx;
CUdevice cuda_device = ctx_rsc->cuda_device;

ucs_mpool_cleanup(&ctx_rsc->event_mp, 1);
iface->ops->destroy_rsc(&iface->super.super, ctx_rsc);

if (primary_ctx != NULL) {
UCT_CUDADRV_FUNC_LOG_WARN(cuDevicePrimaryCtxRelease(cuda_device));
}
}

static ucs_mpool_ops_t uct_cuda_flush_desc_mpool_ops = {
Expand Down
7 changes: 6 additions & 1 deletion src/uct/cuda/base/cuda_iface.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ typedef struct {
typedef struct {
/* CUDA context handle */
CUcontext ctx;
/* Retained CUDA primary context, if @ctx is a primary context */
CUcontext primary_ctx;
/* CUDA device of @primary_ctx */
CUdevice cuda_device;
Comment on lines +64 to +67

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe use cuda_device field as a flag that ctx field represents a primary device context.

/* CUDA context id */
unsigned long long ctx_id;
/* pool of cuda events to check completion of memcpy operations */
Expand Down Expand Up @@ -129,7 +133,8 @@ void uct_cuda_base_queue_desc_init(uct_cuda_queue_desc_t *qdesc);
void uct_cuda_base_queue_desc_destroy(const uct_cuda_ctx_rsc_t *ctx_rsc,
uct_cuda_queue_desc_t *qdesc);

void uct_cuda_base_stream_destroy(CUstream *stream);
void uct_cuda_base_stream_destroy(const uct_cuda_ctx_rsc_t *ctx_rsc,
CUstream *stream);

#if (__CUDACC_VER_MAJOR__ >= 100000)
void CUDA_CB uct_cuda_base_iface_stream_cb_fxn(void *arg);
Expand Down
2 changes: 1 addition & 1 deletion src/uct/cuda/cuda_copy/cuda_copy_iface.c
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ static void uct_cuda_copy_ctx_rsc_destroy(uct_iface_h tl_iface,
}
}

uct_cuda_base_stream_destroy(&ctx_rsc->short_stream);
uct_cuda_base_stream_destroy(cuda_ctx_rsc, &ctx_rsc->short_stream);
ucs_free(ctx_rsc);
}

Expand Down
1 change: 1 addition & 0 deletions test/gtest/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ endif # HAVE_IB
if HAVE_CUDA
gtest_SOURCES += \
ucm/cuda_hooks.cc \
uct/cuda/test_cuda_ctx_cleanup.cc \
uct/cuda/test_switch_cuda_device.cc \
uct/cuda/test_cuda_ipc_md.cc \
uct/cuda/test_cuda_nvml.cc
Expand Down
Loading