Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RFC] Implement gloo abort for graceful shutdown #388

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions gloo/common/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Source files listed for gloo/common. error.cc defines the global abort flag
# and the condition-variable registry that gloo::abort() notifies.
set(GLOO_COMMON_SRCS
"${CMAKE_CURRENT_SOURCE_DIR}/logging.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/utils.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/error.cc"
)

set(GLOO_COMMON_HDRS
Expand Down
46 changes: 46 additions & 0 deletions gloo/common/error.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/**
* Copyright (c) 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <atomic>
#include <condition_variable>
#include <list>
#include <mutex>

#include "gloo/common/error.h"

namespace gloo {

namespace {

// Condition variables that abort() must wake. Entries are non-owning
// pointers; owners add them with _register_cv() and must remove them with
// _deregister_cv() before the condition variable is destroyed.
std::list<std::condition_variable*> _cvs;

// Guards every access to _cvs.
std::mutex _cvs_mutex;

// Set to true by abort(). Nothing in this file ever resets it.
std::atomic_bool _is_aborted_flag(false);

} // namespace

// Returns true once abort() has been called.
bool _is_aborted() {
  return _is_aborted_flag.load();
}

// Raises the abort flag, wakes every waiter on every registered condition
// variable, and then throws via GLOO_THROW("GLOO ABORTED").
void abort() {
  // Publish the flag first so that woken waiters observe it when they
  // re-evaluate their wait predicate.
  _is_aborted_flag.store(true);
  {
    // Notify while holding the registry lock: _deregister_cv() takes the same
    // lock, so a registered condition variable cannot be removed (and then
    // destroyed by its owner) in the middle of this loop.
    std::lock_guard<std::mutex> guard(_cvs_mutex);
    for (auto* cv : _cvs) {
      if (cv != nullptr) {
        cv->notify_all();
      }
    }
    // NOTE(review): the notification is sent without holding the mutex that
    // each waiter pairs with its condition variable, so a waiter that has
    // already checked its predicate but has not yet blocked could miss this
    // wake-up and only see the flag at its next wake-up or timeout -- confirm
    // this is acceptable for callers that wait without a timeout.
  }
  GLOO_THROW("GLOO ABORTED");
}

// Adds `cv` to the set of condition variables notified by abort().
// The registry does not take ownership of `cv`.
void _register_cv(std::condition_variable* cv) {
  std::lock_guard<std::mutex> guard(_cvs_mutex);
  _cvs.push_back(cv);
}

// Removes every registry entry equal to `cv`. Does nothing if `cv` was never
// registered.
void _deregister_cv(std::condition_variable* cv) {
  std::lock_guard<std::mutex> guard(_cvs_mutex);
  _cvs.remove(cv);
}

} // namespace gloo
6 changes: 6 additions & 0 deletions gloo/common/error.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include <chrono>
#include <exception>
#include <condition_variable>

#include "gloo/common/string.h"

Expand All @@ -20,6 +21,11 @@ namespace gloo {

const std::chrono::milliseconds kNoTimeout = std::chrono::milliseconds::zero();

// Returns true once gloo::abort() has set the global abort flag.
bool _is_aborted();
// Sets the global abort flag, calls notify_all() on every condition variable
// registered through _register_cv(), and then throws via
// GLOO_THROW("GLOO ABORTED").
void abort();
// Adds `cv` to the list of condition variables that abort() notifies.
// The list stores the raw pointer only; it does not take ownership.
void _register_cv(std::condition_variable *cv);
// Removes `cv` from the list of condition variables that abort() notifies.
void _deregister_cv(std::condition_variable *cv);

// A base class for all gloo runtime errors
struct Exception : public std::runtime_error {
Exception() = delete;
Expand Down
22 changes: 17 additions & 5 deletions gloo/transport/tcp/unbound_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,15 @@ UnboundBuffer::UnboundBuffer(
recvRank_(-1),
sendCompletions_(0),
sendRank_(-1),
shareableNonOwningPtr_(this) {}
shareableNonOwningPtr_(this) {
gloo::_register_cv(&recvCv_);
gloo::_register_cv(&sendCv_);
}

UnboundBuffer::~UnboundBuffer() {}
// Removes both of this buffer's condition variables from the global abort
// registry, so gloo::abort() no longer holds pointers to them once the
// buffer is gone.
UnboundBuffer::~UnboundBuffer() {
gloo::_deregister_cv(&recvCv_);
gloo::_deregister_cv(&sendCv_);
}

void UnboundBuffer::handleRecvCompletion(int rank) {
std::lock_guard<std::mutex> lock(m_);
Expand Down Expand Up @@ -60,6 +66,9 @@ bool UnboundBuffer::waitRecv(int* rank, std::chrono::milliseconds timeout) {
if (recvCompletions_ == 0) {
auto done = recvCv_.wait_for(lock, timeout, [&] {
throwIfException();
if(gloo::_is_aborted()) {
abortWaitRecv_ = true;
}
return abortWaitRecv_ || recvCompletions_ > 0;
});
if (!done) {
Expand Down Expand Up @@ -111,9 +120,12 @@ bool UnboundBuffer::waitSend(int* rank, std::chrono::milliseconds timeout) {

if (sendCompletions_ == 0) {
auto done = sendCv_.wait_for(lock, timeout, [&] {
throwIfException();
return abortWaitSend_ || sendCompletions_ > 0;
});
throwIfException();
if(gloo::_is_aborted()) {
abortWaitSend_ = true;
}
return abortWaitSend_ || sendCompletions_ > 0;
});
if (!done) {
// Below, we let all pairs in the transport context know about this
// application side timeout. This in turn will call into all pending
Expand Down
26 changes: 20 additions & 6 deletions gloo/transport/uv/unbound_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,15 @@ UnboundBuffer::UnboundBuffer(
recvRank_(-1),
sendCompletions_(0),
sendRank_(-1),
shareableNonOwningPtr_(this) {}
shareableNonOwningPtr_(this) {
gloo::_register_cv(&recvCv_);
gloo::_register_cv(&sendCv_);
}

UnboundBuffer::~UnboundBuffer() {}
// Removes both of this buffer's condition variables from the global abort
// registry, so gloo::abort() no longer holds pointers to them once the
// buffer is gone.
UnboundBuffer::~UnboundBuffer() {
gloo::_deregister_cv(&recvCv_);
gloo::_deregister_cv(&sendCv_);
}

void UnboundBuffer::handleRecvCompletion(int rank) {
std::lock_guard<std::mutex> lock(mutex_);
Expand Down Expand Up @@ -58,8 +64,12 @@ bool UnboundBuffer::waitRecv(int* rank, std::chrono::milliseconds timeout) {
}

if (recvCompletions_ == 0) {
auto done = recvCv_.wait_for(
lock, timeout, [&] { return abortWaitRecv_ || recvCompletions_ > 0; });
auto done = recvCv_.wait_for(lock, timeout, [&] {
if(gloo::_is_aborted()) {
abortWaitRecv_ = true;
}
return abortWaitRecv_ || recvCompletions_ > 0;
});
if (!done) {
throw ::gloo::IoException(GLOO_ERROR_MSG(
"Timed out waiting ",
Expand Down Expand Up @@ -94,8 +104,12 @@ bool UnboundBuffer::waitSend(int* rank, std::chrono::milliseconds timeout) {
}

if (sendCompletions_ == 0) {
auto done = sendCv_.wait_for(
lock, timeout, [&] { return abortWaitSend_ || sendCompletions_ > 0; });
auto done = sendCv_.wait_for(lock, timeout, [&] {
if(gloo::_is_aborted()) {
abortWaitSend_ = true;
}
return abortWaitSend_ || sendCompletions_ > 0;
});
if (!done) {
throw ::gloo::IoException(GLOO_ERROR_MSG(
"Timed out waiting ",
Expand Down