@@ -34,11 +34,13 @@ limitations under the License.
34
34
#include " tensorflow/core/platform/logging.h"
35
35
#include " tensorflow/core/platform/mutex.h"
36
36
#include " tensorflow/core/platform/types.h"
37
+ #include " tensorflow/core/util/env_var.h"
37
38
38
39
namespace tensorflow {
39
40
40
41
namespace {
41
42
uint64 kGlobalStepId = 0x100000000000000uLL;
43
+ int64 kFlowControlMaxSize = 16 ;
42
44
} // namespace anonymous
43
45
44
46
static void StartAbortRendevous (Rendezvous* rendez, const Status& s) {
@@ -127,6 +129,23 @@ void BaseRendezvousMgr::FuseRecvLocalAsync(
127
129
rendez->FuseRecvLocalAsync (parsed_keys, std::move (done_cb));
128
130
}
129
131
132
+ void BaseRendezvousMgr::FlowControlRecvLocalAsync (int64 step_id,
133
+ const StringPiece& tag, const Rendezvous::ParsedKey& parsed,
134
+ Rendezvous::DoneCallback done) {
135
+ auto rendez = FindOrCreate (step_id);
136
+ using namespace std ::placeholders;
137
+ Rendezvous::DoneCallback done_cb = std::bind (
138
+ [rendez](Rendezvous::DoneCallback done,
139
+ // Begin unbound arguments.
140
+ const Status& s, const Rendezvous::Args& send_args,
141
+ const Rendezvous::Args& recv_args, const Tensor& v, bool dead) {
142
+ rendez->Unref ();
143
+ done (s, send_args, recv_args, v, dead);
144
+ },
145
+ std::move (done), _1, _2, _3, _4, _5);
146
+ rendez->FlowControlRecvLocalAsync (tag, parsed, std::move (done_cb));
147
+ }
148
+
130
149
void BaseRendezvousMgr::Cleanup (int64 step_id) {
131
150
Rendezvous* rendez = nullptr ;
132
151
{
@@ -174,7 +193,17 @@ BaseRemoteRendezvous::BaseRemoteRendezvous(const WorkerEnv* env, int64 step_id)
174
193
: env_(env),
175
194
step_id_ (step_id),
176
195
local_(NewLocalRendezvous()),
177
- session_(nullptr ) {}
196
+ session_(nullptr ),
197
+ flow_control_num_(0 ) {
198
+ Status s = ReadInt64FromEnvVar (" REMOTE_RENDEZVOUS_FLOW_CONTROL_MAX_SIZE" ,
199
+ kFlowControlMaxSize , &flow_control_max_size_);
200
+ if (!s.ok ()) {
201
+ LOG (ERROR) << " Read REMOTE_RENDEZVOUS_FLOW_CONTROL_MAX_SIZE env error: "
202
+ << s.error_message ();
203
+ }
204
+ VLOG (2 ) << " BaseRemoteRendezvous set flow control max size: "
205
+ << flow_control_max_size_;
206
+ }
178
207
179
208
BaseRemoteRendezvous::~BaseRemoteRendezvous () {
180
209
CHECK (active_.empty ());
@@ -221,6 +250,16 @@ Status BaseRemoteRendezvous::Initialize(WorkerSession* session) {
221
250
std::move (fuse_call.done ));
222
251
}
223
252
253
+ std::vector<DeferredFlowControlCall> deferred_flow_control_calls;
254
+ {
255
+ mutex_lock l (mu_);
256
+ std::swap (deferred_flow_control_calls, deferred_flow_control_calls_);
257
+ }
258
+ for (auto & fc_call : deferred_flow_control_calls) {
259
+ FlowControlRecvLocalAsyncInternal (fc_call.tag , fc_call.parsed ,
260
+ std::move (fc_call.done ));
261
+ }
262
+
224
263
return Status::OK ();
225
264
}
226
265
@@ -271,6 +310,43 @@ Status BaseRemoteRendezvous::Send(const ParsedKey& parsed,
271
310
return local_->Send (parsed, args, val, mu, is_dead);
272
311
}
273
312
313
+ Status BaseRemoteRendezvous::FlowControlSend (const StringPiece& tag,
314
+ const ParsedKey& parsed,
315
+ const Args& args,
316
+ const Tensor& val,
317
+ const bool is_dead,
318
+ const int64 timeout_millis) {
319
+ VLOG (1 ) << " BaseRemoteRendezvous FlowControlSend " << this << " "
320
+ << parsed.FullKey ();
321
+ const std::string tag_string (tag.data (), tag.size ());
322
+ {
323
+ mutex_lock l (mu_);
324
+ while (status_.ok () && flow_control_num_ >= flow_control_max_size_) {
325
+ if (flow_control_cv_.wait_for (
326
+ l, std::chrono::milliseconds (timeout_millis)) == \
327
+ std::cv_status::timeout) {
328
+ return errors::DeadlineExceeded (" FlowControlSend has timed out." );
329
+ }
330
+ }
331
+
332
+ if (!status_.ok ()) return status_;
333
+ DCHECK (is_initialized_locked ());
334
+ if (!IsLocalDevice (session_->worker_name , parsed.src_device )) {
335
+ return errors::InvalidArgument (
336
+ " Invalid rendezvous key (src): " , parsed.FullKey (), " @ " ,
337
+ session_->worker_name );
338
+ }
339
+
340
+ flow_control_num_++;
341
+ if (flow_control_counters_.count (tag_string) == 0 ) {
342
+ flow_control_counters_[tag_string] = 0 ;
343
+ }
344
+ flow_control_counters_[tag_string]++;
345
+ }
346
+ // Buffers "val" and "device_context" in local_.
347
+ return local_->Send (parsed, args, val, is_dead);
348
+ }
349
+
274
350
Status BaseRemoteRendezvous::ValidateDevices (const ParsedKey& parsed,
275
351
bool is_src) {
276
352
// Cache session pointer to avoid repeatedly taking & releasing the lock
@@ -413,6 +489,63 @@ void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed,
413
489
}
414
490
}
415
491
492
+ void BaseRemoteRendezvous::FlowControlRecvAsync (const StringPiece& tag,
493
+ const ParsedKey& parsed,
494
+ const Args& recv_args,
495
+ DoneCallback done) {
496
+ VLOG (1 ) << " RemoteRendezvous FlowControlRecvAsync " << this
497
+ << " " << tag << " " << parsed.FullKey ();
498
+
499
+ Status s = ValidateDevices (parsed, false /* !is_src*/ );
500
+ if (s.ok () && !is_initialized ()) {
501
+ s.Update (errors::Internal (
502
+ " FlowControlRecvAsync called when uninitialized (key:" ,
503
+ parsed.FullKey (), " )." ));
504
+ }
505
+ if (!s.ok ()) {
506
+ done (s, Args (), recv_args, Tensor (), false );
507
+ return ;
508
+ }
509
+
510
+ // Are src and dst in the same worker?
511
+ if (IsSameWorker (parsed.src , parsed.dst )) {
512
+ // Recv the tensor from local_.
513
+ local_->RecvAsync (
514
+ parsed, recv_args,
515
+ [this , tag, parsed, done](
516
+ const Status& status, const Rendezvous::Args& send_args,
517
+ const Rendezvous::Args& recv_args, const Tensor& in, bool is_dead) {
518
+ VLOG (2 ) << " RemoteRendezvous Finished Recv " << this << " "
519
+ << parsed.FullKey ();
520
+ Tensor* out = new Tensor;
521
+ StatusCallback final_callback = [done, send_args, recv_args, out,
522
+ is_dead](const Status& s) {
523
+ done (s, send_args, recv_args, *out, is_dead);
524
+ delete out;
525
+ };
526
+
527
+ if (status.ok ()) {
528
+ SameWorkerRecvDone (parsed, send_args, recv_args, in, out,
529
+ std::move (final_callback));
530
+ const std::string tag_string (tag.data (), tag.size ());
531
+ {
532
+ mutex_lock l (mu_);
533
+ flow_control_num_--;
534
+ DCHECK (flow_control_counters_.count (tag_string) != 0 );
535
+ flow_control_counters_[tag_string]--;
536
+ }
537
+ flow_control_cv_.notify_one ();
538
+ } else {
539
+ final_callback (status);
540
+ }
541
+ });
542
+ return ;
543
+ } else {
544
+ FlowControlRecvFromRemoteAsync (tag, parsed, recv_args, std::move (done));
545
+ }
546
+
547
+ }
548
+
416
549
void BaseRemoteRendezvous::RecvLocalAsync (const ParsedKey& parsed,
417
550
DoneCallback done) {
418
551
{
@@ -600,13 +733,71 @@ void BaseRemoteRendezvous::FuseRecvLocalAsyncInternal(
600
733
}
601
734
}
602
735
736
+ void BaseRemoteRendezvous::FlowControlRecvLocalAsync (const StringPiece& tag,
737
+ const ParsedKey& parsed,
738
+ DoneCallback done) {
739
+ {
740
+ mutex_lock l (mu_);
741
+ if (!is_initialized_locked ()) {
742
+ // FlowControlRecvLocalAsync can be called (due to an incoming RecvTensor
743
+ // RPC from a remote worker) before the RunStep (or PartialRunStep) RPC
744
+ // from the master arrives. RecvLocalAsync thus buffers the arguments
745
+ // until after the RemoteRendezvous is Initialize()'d, when it completes
746
+ // the rendezvous logic. At some point after Initialize() is called, a
747
+ // Tensor is produced locally that will then be sent in response to the
748
+ // incoming RPC.
749
+ DeferredFlowControlCall call (tag, parsed, std::move (done));
750
+ deferred_flow_control_calls_.push_back (call);
751
+ return ;
752
+ }
753
+ }
754
+ FlowControlRecvLocalAsyncInternal (tag, parsed, std::move (done));
755
+ }
756
+
757
+ void BaseRemoteRendezvous::FlowControlRecvLocalAsyncInternal (
758
+ const StringPiece& tag, const ParsedKey& parsed, DoneCallback done) {
759
+ Status s = ValidateDevices (parsed, true /* is_src */ );
760
+ if (!s.ok ()) {
761
+ done (s, Args (), Args (), Tensor (), false );
762
+ return ;
763
+ }
764
+
765
+ using namespace std ::placeholders;
766
+ Rendezvous::DoneCallback done_cb = std::bind (
767
+ [this , tag](Rendezvous::DoneCallback done,
768
+ // Begin unbound arguments.
769
+ const Status& s, const Rendezvous::Args& send_args,
770
+ const Rendezvous::Args& recv_args, const Tensor& v, bool dead) {
771
+ done (s, send_args, recv_args, v, dead);
772
+ if (s.ok ()) {
773
+ const std::string tag_string (tag.data (), tag.size ());
774
+ {
775
+ mutex_lock l (mu_);
776
+ flow_control_num_--;
777
+ DCHECK (flow_control_counters_.count (tag_string) != 0 );
778
+ flow_control_counters_[tag_string]--;
779
+ }
780
+ flow_control_cv_.notify_one ();
781
+ }
782
+ },
783
+ std::move (done), _1, _2, _3, _4, _5);
784
+
785
+ local_->RecvAsync (parsed, Args (), std::move (done_cb));
786
+ }
787
+
603
788
void BaseRemoteRendezvous::FuseRecvFromRemoteAsync (
604
789
const std::vector<Rendezvous::ParsedKey>& parsed_keys,
605
790
const Rendezvous::Args& args,
606
791
FuseDoneCallback done) {
607
792
CHECK (false ) << " FuseRecvFromRemoteAsync Unimplemented" ;
608
793
}
609
794
795
+ void BaseRemoteRendezvous::FlowControlRecvFromRemoteAsync (
796
+ const StringPiece& tag, const Rendezvous::ParsedKey& parsed,
797
+ const Rendezvous::Args& args, DoneCallback done) {
798
+ CHECK (false ) << " FlowControlRecvFromRemoteAsync Unimplemented." ;
799
+ }
800
+
610
801
void BaseRemoteRendezvous::RecvAsync (const ParsedKey& parsed,
611
802
const Rendezvous::Args& recv_args,
612
803
RefDoneCallback done) {
@@ -636,6 +827,19 @@ void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed,
636
827
}
637
828
}
638
829
830
+ int64 BaseRemoteRendezvous::GetAllFlowControlItemNum () {
831
+ mutex_lock l (mu_);
832
+ return flow_control_num_;
833
+ }
834
+
835
+ int64 BaseRemoteRendezvous::GetFlowControlItemNum (StringPiece tag) {
836
+ const std::string tag_string (tag.data (), tag.size ());
837
+ mutex_lock l (mu_);
838
+ if (flow_control_counters_.count (tag_string) == 0 )
839
+ return 0 ;
840
+ return flow_control_counters_[tag_string];
841
+ }
842
+
639
843
void BaseRemoteRendezvous::StartAbort (const Status& s) {
640
844
CHECK (!s.ok ());
641
845
// Use a "derived" status as the status for the rendezvous. Derived
@@ -656,7 +860,10 @@ void BaseRemoteRendezvous::StartAbort(const Status& s) {
656
860
}
657
861
active_.clear ();
658
862
}
863
+ flow_control_num_ = 0 ;
864
+ flow_control_counters_.clear ();
659
865
}
866
+ flow_control_cv_.notify_all ();
660
867
}
661
868
662
869
void BaseRemoteRendezvous::RegisterCall (BaseRecvTensorCall* call,
@@ -707,4 +914,8 @@ BaseRemoteRendezvous::DeferredFuseCall::DeferredFuseCall(
707
914
const std::vector<ParsedKey>& parsed_keys, FuseDoneCallback done)
708
915
: parsed_keys(parsed_keys), done(std::move(done)) {}
709
916
917
+ BaseRemoteRendezvous::DeferredFlowControlCall::DeferredFlowControlCall (
918
+ const StringPiece& tag, const ParsedKey& parsed, DoneCallback done)
919
+ : tag(tag), parsed(parsed), done(std::move(done)) {}
920
+
710
921
} // end namespace tensorflow
0 commit comments