Skip to content

Commit 93c69ad

Browse files
authored
[Rendezvous] RemoteRendezvous supports FlowControl. (#994)
Signed-off-by: chenbangduo.cbd <[email protected]>
1 parent b2aed96 commit 93c69ad

21 files changed

+903
-30
lines changed

tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc

+212-1
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,13 @@ limitations under the License.
3434
#include "tensorflow/core/platform/logging.h"
3535
#include "tensorflow/core/platform/mutex.h"
3636
#include "tensorflow/core/platform/types.h"
37+
#include "tensorflow/core/util/env_var.h"
3738

3839
namespace tensorflow {
3940

4041
namespace {
4142
uint64 kGlobalStepId = 0x100000000000000uLL;
43+
int64 kFlowControlMaxSize = 16;
4244
} // namespace anonymous
4345

4446
static void StartAbortRendevous(Rendezvous* rendez, const Status& s) {
@@ -127,6 +129,23 @@ void BaseRendezvousMgr::FuseRecvLocalAsync(
127129
rendez->FuseRecvLocalAsync(parsed_keys, std::move(done_cb));
128130
}
129131

132+
// Looks up (or creates) the rendezvous for `step_id` and forwards a
// flow-controlled local receive for `parsed` / `tag` to it.
//
// `done` is wrapped so that the rendezvous obtained from FindOrCreate() is
// Unref()'d right before the caller's callback runs.
// NOTE(review): this assumes FindOrCreate() hands back a reference that this
// function owns -- confirm against its definition.
void BaseRendezvousMgr::FlowControlRecvLocalAsync(
    int64 step_id, const StringPiece& tag, const Rendezvous::ParsedKey& parsed,
    Rendezvous::DoneCallback done) {
  auto step_rendezvous = FindOrCreate(step_id);

  // The caller's callback is passed as the first (bound) argument rather than
  // captured, so it can be moved into the wrapper.
  auto release_then_forward =
      [step_rendezvous](Rendezvous::DoneCallback user_done,
                        // Begin unbound arguments.
                        const Status& status,
                        const Rendezvous::Args& sender_args,
                        const Rendezvous::Args& receiver_args,
                        const Tensor& value, bool is_dead) {
        step_rendezvous->Unref();
        user_done(status, sender_args, receiver_args, value, is_dead);
      };

  namespace ph = std::placeholders;
  Rendezvous::DoneCallback wrapped_done =
      std::bind(std::move(release_then_forward), std::move(done), ph::_1,
                ph::_2, ph::_3, ph::_4, ph::_5);

  step_rendezvous->FlowControlRecvLocalAsync(tag, parsed,
                                             std::move(wrapped_done));
}
148+
130149
void BaseRendezvousMgr::Cleanup(int64 step_id) {
131150
Rendezvous* rendez = nullptr;
132151
{
@@ -174,7 +193,17 @@ BaseRemoteRendezvous::BaseRemoteRendezvous(const WorkerEnv* env, int64 step_id)
174193
: env_(env),
175194
step_id_(step_id),
176195
local_(NewLocalRendezvous()),
177-
session_(nullptr) {}
196+
session_(nullptr),
197+
flow_control_num_(0) {
198+
Status s = ReadInt64FromEnvVar("REMOTE_RENDEZVOUS_FLOW_CONTROL_MAX_SIZE",
199+
kFlowControlMaxSize, &flow_control_max_size_);
200+
if (!s.ok()) {
201+
LOG(ERROR) << "Read REMOTE_RENDEZVOUS_FLOW_CONTROL_MAX_SIZE env error: "
202+
<< s.error_message();
203+
}
204+
VLOG(2) << "BaseRemoteRendezvous set flow control max size: "
205+
<< flow_control_max_size_;
206+
}
178207

179208
BaseRemoteRendezvous::~BaseRemoteRendezvous() {
180209
CHECK(active_.empty());
@@ -221,6 +250,16 @@ Status BaseRemoteRendezvous::Initialize(WorkerSession* session) {
221250
std::move(fuse_call.done));
222251
}
223252

253+
std::vector<DeferredFlowControlCall> deferred_flow_control_calls;
254+
{
255+
mutex_lock l(mu_);
256+
std::swap(deferred_flow_control_calls, deferred_flow_control_calls_);
257+
}
258+
for (auto& fc_call : deferred_flow_control_calls) {
259+
FlowControlRecvLocalAsyncInternal(fc_call.tag, fc_call.parsed,
260+
std::move(fc_call.done));
261+
}
262+
224263
return Status::OK();
225264
}
226265

@@ -271,6 +310,43 @@ Status BaseRemoteRendezvous::Send(const ParsedKey& parsed,
271310
return local_->Send(parsed, args, val, mu, is_dead);
272311
}
273312

313+
// Sends `val` for key `parsed` through the local rendezvous, subject to flow
// control.
//
// Blocks while flow_control_num_ >= flow_control_max_size_ and the rendezvous
// status is OK, for at most `timeout_millis` milliseconds in total (rounded up
// to the next millisecond). On success the global counter and the per-`tag`
// counter are incremented before the tensor is handed to local_->Send().
//
// Returns:
//   - DeadlineExceeded if no slot became available before the deadline;
//   - status_ if the rendezvous is in an error state;
//   - InvalidArgument if `parsed.src_device` is not local to this worker;
//   - otherwise, whatever local_->Send() returns.
Status BaseRemoteRendezvous::FlowControlSend(const StringPiece& tag,
                                             const ParsedKey& parsed,
                                             const Args& args,
                                             const Tensor& val,
                                             const bool is_dead,
                                             const int64 timeout_millis) {
  VLOG(1) << "BaseRemoteRendezvous FlowControlSend " << this << " "
          << parsed.FullKey();
  const std::string tag_string(tag.data(), tag.size());

  // A single deadline for the whole call. Re-arming wait_for() with the full
  // timeout after every wakeup would let a sender that keeps losing the race
  // for a freed slot block for much longer than `timeout_millis`.
  const auto deadline = std::chrono::steady_clock::now() +
                        std::chrono::milliseconds(timeout_millis);
  {
    mutex_lock l(mu_);
    while (status_.ok() && flow_control_num_ >= flow_control_max_size_) {
      const auto now = std::chrono::steady_clock::now();
      if (now >= deadline) {
        return errors::DeadlineExceeded("FlowControlSend has timed out.");
      }
      // Round up to whole milliseconds so that a sub-millisecond remainder
      // still sleeps instead of spinning on a zero-length wait.
      const auto remaining =
          std::chrono::duration_cast<std::chrono::milliseconds>(deadline -
                                                                now) +
          std::chrono::milliseconds(1);
      // The return value is deliberately ignored: the loop condition and the
      // deadline check above decide what happens next.
      flow_control_cv_.wait_for(l, remaining);
    }

    if (!status_.ok()) return status_;
    DCHECK(is_initialized_locked());
    if (!IsLocalDevice(session_->worker_name, parsed.src_device)) {
      return errors::InvalidArgument(
          "Invalid rendezvous key (src): ", parsed.FullKey(), " @ ",
          session_->worker_name);
    }

    ++flow_control_num_;
    // operator[] value-initializes a missing counter to zero.
    ++flow_control_counters_[tag_string];
  }
  // Buffers "val" and "device_context" in local_.
  return local_->Send(parsed, args, val, is_dead);
}
349+
274350
Status BaseRemoteRendezvous::ValidateDevices(const ParsedKey& parsed,
275351
bool is_src) {
276352
// Cache session pointer to avoid repeatedly taking & releasing the lock
@@ -413,6 +489,63 @@ void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed,
413489
}
414490
}
415491

492+
void BaseRemoteRendezvous::FlowControlRecvAsync(const StringPiece& tag,
493+
const ParsedKey& parsed,
494+
const Args& recv_args,
495+
DoneCallback done) {
496+
VLOG(1) << "RemoteRendezvous FlowControlRecvAsync " << this
497+
<< " " << tag << " " << parsed.FullKey();
498+
499+
Status s = ValidateDevices(parsed, false /*!is_src*/);
500+
if (s.ok() && !is_initialized()) {
501+
s.Update(errors::Internal(
502+
"FlowControlRecvAsync called when uninitialized (key:",
503+
parsed.FullKey(), ")."));
504+
}
505+
if (!s.ok()) {
506+
done(s, Args(), recv_args, Tensor(), false);
507+
return;
508+
}
509+
510+
// Are src and dst in the same worker?
511+
if (IsSameWorker(parsed.src, parsed.dst)) {
512+
// Recv the tensor from local_.
513+
local_->RecvAsync(
514+
parsed, recv_args,
515+
[this, tag, parsed, done](
516+
const Status& status, const Rendezvous::Args& send_args,
517+
const Rendezvous::Args& recv_args, const Tensor& in, bool is_dead) {
518+
VLOG(2) << "RemoteRendezvous Finished Recv " << this << " "
519+
<< parsed.FullKey();
520+
Tensor* out = new Tensor;
521+
StatusCallback final_callback = [done, send_args, recv_args, out,
522+
is_dead](const Status& s) {
523+
done(s, send_args, recv_args, *out, is_dead);
524+
delete out;
525+
};
526+
527+
if (status.ok()) {
528+
SameWorkerRecvDone(parsed, send_args, recv_args, in, out,
529+
std::move(final_callback));
530+
const std::string tag_string(tag.data(), tag.size());
531+
{
532+
mutex_lock l(mu_);
533+
flow_control_num_--;
534+
DCHECK(flow_control_counters_.count(tag_string) != 0);
535+
flow_control_counters_[tag_string]--;
536+
}
537+
flow_control_cv_.notify_one();
538+
} else {
539+
final_callback(status);
540+
}
541+
});
542+
return;
543+
} else {
544+
FlowControlRecvFromRemoteAsync(tag, parsed, recv_args, std::move(done));
545+
}
546+
547+
}
548+
416549
void BaseRemoteRendezvous::RecvLocalAsync(const ParsedKey& parsed,
417550
DoneCallback done) {
418551
{
@@ -600,13 +733,71 @@ void BaseRemoteRendezvous::FuseRecvLocalAsyncInternal(
600733
}
601734
}
602735

736+
void BaseRemoteRendezvous::FlowControlRecvLocalAsync(const StringPiece& tag,
737+
const ParsedKey& parsed,
738+
DoneCallback done) {
739+
{
740+
mutex_lock l(mu_);
741+
if (!is_initialized_locked()) {
742+
// FlowControlRecvLocalAsync can be called (due to an incoming RecvTensor
743+
// RPC from a remote worker) before the RunStep (or PartialRunStep) RPC
744+
// from the master arrives. RecvLocalAsync thus buffers the arguments
745+
// until after the RemoteRendezvous is Initialize()'d, when it completes
746+
// the rendezvous logic. At some point after Initialize() is called, a
747+
// Tensor is produced locally that will then be sent in response to the
748+
// incoming RPC.
749+
DeferredFlowControlCall call(tag, parsed, std::move(done));
750+
deferred_flow_control_calls_.push_back(call);
751+
return;
752+
}
753+
}
754+
FlowControlRecvLocalAsyncInternal(tag, parsed, std::move(done));
755+
}
756+
757+
void BaseRemoteRendezvous::FlowControlRecvLocalAsyncInternal(
758+
const StringPiece& tag, const ParsedKey& parsed, DoneCallback done) {
759+
Status s = ValidateDevices(parsed, true /* is_src */);
760+
if (!s.ok()) {
761+
done(s, Args(), Args(), Tensor(), false);
762+
return;
763+
}
764+
765+
using namespace std::placeholders;
766+
Rendezvous::DoneCallback done_cb = std::bind(
767+
[this, tag](Rendezvous::DoneCallback done,
768+
// Begin unbound arguments.
769+
const Status& s, const Rendezvous::Args& send_args,
770+
const Rendezvous::Args& recv_args, const Tensor& v, bool dead) {
771+
done(s, send_args, recv_args, v, dead);
772+
if (s.ok()) {
773+
const std::string tag_string(tag.data(), tag.size());
774+
{
775+
mutex_lock l(mu_);
776+
flow_control_num_--;
777+
DCHECK(flow_control_counters_.count(tag_string) != 0);
778+
flow_control_counters_[tag_string]--;
779+
}
780+
flow_control_cv_.notify_one();
781+
}
782+
},
783+
std::move(done), _1, _2, _3, _4, _5);
784+
785+
local_->RecvAsync(parsed, Args(), std::move(done_cb));
786+
}
787+
603788
void BaseRemoteRendezvous::FuseRecvFromRemoteAsync(
604789
const std::vector<Rendezvous::ParsedKey>& parsed_keys,
605790
const Rendezvous::Args& args,
606791
FuseDoneCallback done) {
607792
CHECK(false) << "FuseRecvFromRemoteAsync Unimplemented";
608793
}
609794

795+
void BaseRemoteRendezvous::FlowControlRecvFromRemoteAsync(
796+
const StringPiece& tag, const Rendezvous::ParsedKey& parsed,
797+
const Rendezvous::Args& args, DoneCallback done) {
798+
CHECK(false) << "FlowControlRecvFromRemoteAsync Unimplemented.";
799+
}
800+
610801
void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed,
611802
const Rendezvous::Args& recv_args,
612803
RefDoneCallback done) {
@@ -636,6 +827,19 @@ void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed,
636827
}
637828
}
638829

830+
// Returns the current value of flow_control_num_, the counter incremented by
// FlowControlSend() and decremented when a flow-controlled receive succeeds.
// The value is read while holding mu_.
int64 BaseRemoteRendezvous::GetAllFlowControlItemNum() {
  mutex_lock lock(mu_);
  const int64 outstanding_items = flow_control_num_;
  return outstanding_items;
}
834+
835+
// Returns the flow-control counter recorded for `tag`, or 0 when no counter
// exists for it. The lookup happens while holding mu_ and never inserts a new
// entry.
int64 BaseRemoteRendezvous::GetFlowControlItemNum(StringPiece tag) {
  const std::string tag_string(tag.data(), tag.size());
  mutex_lock l(mu_);
  // Single lookup instead of count() followed by operator[].
  const auto it = flow_control_counters_.find(tag_string);
  if (it == flow_control_counters_.end()) {
    return 0;
  }
  return it->second;
}
842+
639843
void BaseRemoteRendezvous::StartAbort(const Status& s) {
640844
CHECK(!s.ok());
641845
// Use a "derived" status as the status for the rendezvous. Derived
@@ -656,7 +860,10 @@ void BaseRemoteRendezvous::StartAbort(const Status& s) {
656860
}
657861
active_.clear();
658862
}
863+
flow_control_num_ = 0;
864+
flow_control_counters_.clear();
659865
}
866+
flow_control_cv_.notify_all();
660867
}
661868

662869
void BaseRemoteRendezvous::RegisterCall(BaseRecvTensorCall* call,
@@ -707,4 +914,8 @@ BaseRemoteRendezvous::DeferredFuseCall::DeferredFuseCall(
707914
const std::vector<ParsedKey>& parsed_keys, FuseDoneCallback done)
708915
: parsed_keys(parsed_keys), done(std::move(done)) {}
709916

917+
// Records the arguments of a FlowControlRecvLocalAsync() request so that it
// can be replayed later: `tag` and `parsed` are copied into the members of the
// same name and `done` is moved in.
// NOTE(review): the member types are declared outside this view; if `tag` is
// stored as a StringPiece, only the view is copied and the underlying
// characters must outlive this object -- confirm.
BaseRemoteRendezvous::DeferredFlowControlCall::DeferredFlowControlCall(
    const StringPiece& tag, const ParsedKey& parsed, DoneCallback done)
    : tag(tag),
      parsed(parsed),
      done(std::move(done)) {}
920+
710921
} // end namespace tensorflow

0 commit comments

Comments
 (0)