Skip to content

Dtl again #111

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/dyad/dtl/flux_dtl.c
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ dyad_rc_t dyad_dtl_flux_recv (const dyad_ctx_t* ctx, void** buf, size_t* buflen)
dyad_rc = DYAD_RC_OK;
finish_recv:
if (dtl_handle->f != NULL)
flux_future_destroy (dtl_handle->f);
flux_future_reset (dtl_handle->f);
DYAD_C_FUNCTION_UPDATE_INT ("tmp_buflen", tmp_buflen);
DYAD_C_FUNCTION_END();
return dyad_rc;
Expand All @@ -238,7 +238,7 @@ dyad_rc_t dyad_dtl_flux_finalize (const dyad_ctx_t* ctx)
{
DYAD_C_FUNCTION_START();
dyad_rc_t rc = DYAD_RC_OK;
if (ctx->dtl_handle == NULL) {
if (ctx->dtl_handle == NULL || !(ctx->dtl_handle->private_dtl.flux_dtl_handle)) {
goto dtl_flux_finalize_done;
}
ctx->dtl_handle->private_dtl.flux_dtl_handle->h = NULL;
Expand Down
11 changes: 2 additions & 9 deletions src/dyad/dtl/ucx_dtl.c
Original file line number Diff line number Diff line change
Expand Up @@ -64,23 +64,16 @@ static void dyad_recv_callback (void* request, ucs_status_t status, ucp_tag_recv

#if UCP_API_VERSION >= UCP_VERSION(1, 10)
static void dyad_send_callback (void* req, ucs_status_t status, void* ctx)
{
DYAD_C_FUNCTION_START();
DYAD_LOG_INFO ((dyad_ctx_t*) ctx, "Calling send callback");
dyad_ucx_request_t* real_req = (dyad_ucx_request_t*)req;
real_req->completed = 1;
DYAD_C_FUNCTION_END();
}
#else // UCP_API_VERSION
static void dyad_send_callback (void* req, ucs_status_t status)
#endif // UCP_API_VERSION
{
DYAD_C_FUNCTION_START();
DYAD_LOG_STDOUT ("Calling send callback");
DYAD_LOG_STDERR ("Calling send callback");
dyad_ucx_request_t* real_req = (dyad_ucx_request_t*)req;
real_req->completed = 1;
DYAD_C_FUNCTION_END();
}
#endif // UCP_API_VERSION

// Simple function used to wait on the async receive
static ucs_status_t dyad_ucx_request_wait (const dyad_ctx_t* ctx,
Expand Down
71 changes: 62 additions & 9 deletions src/dyad/dtl/ucx_ep_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ dyad_rc_t dyad_ucx_ep_cache_init (const dyad_ctx_t *ctx, ucx_ep_cache_h* cache)
rc = DYAD_RC_BADBUF;
goto ucx_ep_cache_init_done;
}
*cache = reinterpret_cast<ucx_ep_cache_h> (new (std::nothrow) cache_type ());
*cache = static_cast<ucx_ep_cache_h> (new (std::nothrow) cache_type ());
if (*cache == nullptr) {
rc = DYAD_RC_SYSFAIL;
goto ucx_ep_cache_init_done;
Expand All @@ -135,7 +135,7 @@ dyad_rc_t dyad_ucx_ep_cache_find (const dyad_ctx_t *ctx,
goto ucx_ep_cache_find_done;
}
try {
const auto* cpp_cache = reinterpret_cast<const cache_type*> (cache);
const auto* cpp_cache = static_cast<const cache_type*> (cache);
auto key = ctx->dtl_handle->private_dtl.ucx_dtl_handle->consumer_conn_key;
auto cache_it = cpp_cache->find (key);
if (cache_it == cpp_cache->cend ()) {
Expand Down Expand Up @@ -163,7 +163,7 @@ dyad_rc_t dyad_ucx_ep_cache_insert (const dyad_ctx_t *ctx,
DYAD_C_FUNCTION_START();
dyad_rc_t rc = DYAD_RC_OK;
try {
cache_type* cpp_cache = reinterpret_cast<cache_type*> (cache);
cache_type* cpp_cache = static_cast<cache_type*> (cache);
uint64_t key = ctx->dtl_handle->private_dtl.ucx_dtl_handle->consumer_conn_key;
DYAD_C_FUNCTION_UPDATE_INT("cons_key", ctx->dtl_handle->private_dtl.ucx_dtl_handle->consumer_conn_key)
auto cache_it = cpp_cache->find (key);
Expand Down Expand Up @@ -212,7 +212,7 @@ dyad_rc_t dyad_ucx_ep_cache_remove (const dyad_ctx_t *ctx,
DYAD_C_FUNCTION_START();
dyad_rc_t rc = DYAD_RC_OK;
try {
cache_type* cpp_cache = reinterpret_cast<cache_type*> (cache);
cache_type* cpp_cache = static_cast<cache_type*> (cache);
auto key = ctx->dtl_handle->private_dtl.ucx_dtl_handle->consumer_conn_key;
cache_type::iterator cache_it = cpp_cache->find (key);
cache_remove_impl (ctx, cpp_cache, cache_it, worker);
Expand All @@ -230,12 +230,65 @@ dyad_rc_t dyad_ucx_ep_cache_finalize (const dyad_ctx_t *ctx, ucx_ep_cache_h* cac
if (cache == nullptr || *cache == nullptr) {
return DYAD_RC_OK;
}
cache_type* cpp_cache = reinterpret_cast<cache_type*> (*cache);
for (cache_type::iterator it = cpp_cache->begin (); it != cpp_cache->end ();) {
it = cache_remove_impl (ctx, cpp_cache, it, worker);

dyad_rc_t rc = DYAD_RC_OK;
cache_type& cpp_cache = *(static_cast<cache_type*> (*cache));
std::vector<ucs_status_ptr_t> stat_ptrs (cpp_cache.size ());
auto it_stat = stat_ptrs.begin ();

#if UCP_API_VERSION >= UCP_VERSION(1, 10)
ucp_request_param_t close_params;
close_params.op_attr_mask = UCP_OP_ATTR_FIELD_FLAGS;
close_params.flags = UCP_EP_CLOSE_FLAG_FORCE;
#endif

for (cache_type::iterator it = cpp_cache.begin (); it != cpp_cache.end ();) {
ucp_ep_h& ep = it->second;
if (ep != NULL) {
// ucp_tag_send_sync_nbx is the prefered version of this send
// since UCX 1.9 However, some systems (e.g., Lassen) may have
// an older verison This conditional compilation will use
// ucp_tag_send_sync_nbx if using UCX 1.9+, and it will use the
// deprecated ucp_tag_send_sync_nb if using UCX < 1.9.
#if UCP_API_VERSION >= UCP_VERSION(1, 10)
*(it_stat++) = ucp_ep_close_nbx (ep, &close_params);
#else
// TODO change to FORCE if we decide to enable err handleing
// mode
*(it_stat++) = ucp_ep_close_nb (ep, UCP_EP_CLOSE_MODE_FORCE);
#endif
}
}

for (auto& stat_ptr : stat_ptrs)
{
ucs_status_t status = UCS_OK;
// Don't use dyad_ucx_request_wait here because ep_close behaves
// differently than other UCX calls
if (stat_ptr != NULL) {
if (UCS_PTR_IS_PTR (stat_ptr)) {
// Endpoint close is in-progress.
// Wait until finished
do {
ucp_worker_progress (worker);
status = ucp_request_check_status (stat_ptr);
} while (status == UCS_INPROGRESS);
ucp_request_free (stat_ptr);
} else {
// An error occurred during endpoint closure
// However, the endpoint can no longer be used
// Get the status code for reporting
status = UCS_PTR_STATUS (stat_ptr);
}
if (UCX_STATUS_FAIL (status)) {
rc = DYAD_RC_UCXEP_FAIL;
}
}
}
delete cpp_cache;

cpp_cache.clear ();
delete &cpp_cache;
*cache = nullptr;
DYAD_C_FUNCTION_END();
return DYAD_RC_OK;
return rc;
}