Commit 05f76dc

[Runtime] Optimize CUDA multi-stream by merging the copy stream into the compute stream. (#557)
Set MERGE_COMPUTE_COPY_STREAM=true to enable the merged stream. Merging removes the synchronization cost between the compute stream and the copy stream, improving performance when multi-stream (multiple compute streams) is enabled.
1 parent 63f6b9a commit 05f76dc

16 files changed, +200 −89 lines changed
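
The commit is easiest to read with the underlying stream mechanics in mind: with a dedicated copy stream, every host-to-device copy must be ordered against the compute stream through an event, while a merged stream gets that ordering for free from stream order. A minimal sketch in plain CUDA runtime C++ (illustrative only, not the committed TensorFlow code; error handling omitted):

```cpp
// Illustrative sketch in plain CUDA runtime C++ -- not the committed
// TensorFlow code. Error handling omitted for brevity.
#include <cuda_runtime.h>

// Dedicated copy stream: the copy must be ordered against the compute
// stream explicitly, via an event record/wait pair.
void CopyWithSeparateStream(cudaStream_t compute, cudaStream_t copy,
                            void* dst, const void* src, size_t bytes) {
  cudaEvent_t ev;
  cudaEventCreateWithFlags(&ev, cudaEventDisableTiming);
  cudaEventRecord(ev, compute);      // mark the compute stream's position
  cudaStreamWaitEvent(copy, ev, 0);  // copy stream waits for compute work
  cudaMemcpyAsync(dst, src, bytes, cudaMemcpyHostToDevice, copy);
  cudaEventDestroy(ev);
}

// Merged stream (MERGE_COMPUTE_COPY_STREAM=true): enqueue the copy on the
// compute stream itself; stream order already serializes it, so no event
// synchronization is needed.
void CopyWithMergedStream(cudaStream_t compute,
                          void* dst, const void* src, size_t bytes) {
  cudaMemcpyAsync(dst, src, bytes, cudaMemcpyHostToDevice, compute);
}
```

The removed record/wait pair is the synchronization cost the commit message refers to; the trade-off is that copies now queue behind kernels on the compute stream.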

Diff for: tensorflow/core/common_runtime/direct_session.cc (+17 −1)

@@ -766,6 +766,10 @@ DirectSession::DirectSession(const SessionOptions& options,
     LOG(INFO) << "Current DirectSession " << this << " will be pinned to core: " << msg;
     thread_pools_[0].first->SetThreadPoolAffinity(cpuset);
   }
+
+  tensorflow::ReadBoolFromEnvVar("MERGE_COMPUTE_COPY_STREAM",
+                                 /*default_val=*/false,
+                                 &merge_compute_and_copy_stream_);
 }

 DirectSession::~DirectSession() {
@@ -961,8 +965,16 @@ Status DirectSession::RunInternal(

   // Start parallel Executors.
   const size_t num_executors = executors_and_keys->items.size();
+  // ref_send_inputs will be filled while the graph executes.
+  std::vector<TensorReference*> ref_send_inputs;
   ExecutorBarrier* barrier = new ExecutorBarrier(
-      num_executors, run_state.rendez, [&run_state](const Status& ret) {
+      num_executors, run_state.rendez,
+      [&run_state, &ref_send_inputs](const Status& ret) {
+        VLOG(2) << "To unref buffer size: " << ref_send_inputs.size();
+        for (auto& ref : ref_send_inputs) {
+          ref->Unref();
+        }
+        ref_send_inputs.clear();
         {
           mutex_lock l(run_state.mu_);
           run_state.status.Update(ret);
@@ -994,6 +1006,10 @@ Status DirectSession::RunInternal(
     args.executor_policy = ExecutorPolicy::USE_NORMAL_EXECUTOR;
   }

+  args.ref_send_inputs_mu_ptr = std::make_unique<mutex>();
+  args.ref_send_inputs_ptr = &ref_send_inputs;
+  args.merge_compute_and_copy_stream = merge_compute_and_copy_stream_;
+
   const bool do_trace = (run_options.trace_level() > RunOptions::NO_TRACE);

   bool update_cost_model = false;
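
The ref_send_inputs plumbing above is a keep-alive: each executor takes a reference on the input buffer of a qualifying send op, and the ExecutorBarrier done callback drops all of them once every executor has finished, so the CPU tensor outlives the asynchronous copy to the GPU. A self-contained sketch of that lifetime pattern, with a stand-in RefCounted in place of tensorflow::TensorReference (names here are illustrative, not the committed API):

```cpp
// Self-contained sketch of the keep-alive pattern; RefCounted is a
// stand-in for tensorflow::TensorReference, not the real type.
#include <atomic>
#include <iostream>
#include <mutex>
#include <vector>

class RefCounted {
 public:
  void Ref() { refs_.fetch_add(1); }
  void Unref() {
    // The last Unref frees the object, just as dropping the final
    // TensorReference lets the tensor buffer be deallocated.
    if (refs_.fetch_sub(1) == 1) delete this;
  }
 private:
  ~RefCounted() { std::cout << "buffer released\n"; }
  std::atomic<int> refs_{1};
};

int main() {
  std::vector<RefCounted*> ref_send_inputs;  // filled while the graph runs
  std::mutex ref_send_inputs_mu;             // executors append concurrently

  // An executor pins a send op's input for the duration of the run.
  {
    std::lock_guard<std::mutex> l(ref_send_inputs_mu);
    ref_send_inputs.push_back(new RefCounted);
  }

  // Barrier done callback, analogous to the ExecutorBarrier lambda above:
  // once every executor has finished, drop all pinned references.
  auto done = [&] {
    std::lock_guard<std::mutex> l(ref_send_inputs_mu);
    for (auto* ref : ref_send_inputs) ref->Unref();
    ref_send_inputs.clear();
  };
  done();
}
```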

Diff for: tensorflow/core/common_runtime/direct_session.h (+4 −0)

@@ -447,6 +447,10 @@ class DirectSession : public Session {
   int multi_stream_num_ = 0;
   ResourceMgr* multi_stream_shared_rmgr_ = nullptr;

+  // The user decides whether to use the compute stream as the copy stream
+  // by setting the 'MERGE_COMPUTE_COPY_STREAM' environment variable.
+  bool merge_compute_and_copy_stream_ = false;
+
   TF_DISALLOW_COPY_AND_ASSIGN(DirectSession);

   // EXPERIMENTAL: debugger (tfdbg) related

Diff for: tensorflow/core/common_runtime/executor.cc (+33 −10)

@@ -395,6 +395,11 @@ class ExecutorState {
   bool finish_when_deferred_ops_done_ TF_GUARDED_BY(num_deferred_ops_mu_) =
       false;

+  // References to the input tensors of send ops.
+  mutex* ref_send_inputs_mu_ptr_;
+  std::vector<TensorReference*>* ref_send_inputs_ptr_;
+  bool merge_compute_and_copy_stream_;
+
   mutex mu_;
   Status status_ TF_GUARDED_BY(mu_);
 };
@@ -489,8 +494,8 @@ struct SortTaggedNode {
   SortTaggedNode(const std::vector<int64>* immutable_accumulative_cost) :
     immutable_accumulative_cost_(immutable_accumulative_cost) {}
   bool operator()(const TaggedNode& n1, const TaggedNode& n2) {
-    return (*immutable_accumulative_cost_)[n1.get_node_item().node_id] >
-           (*immutable_accumulative_cost_)[n2.get_node_item().node_id];
+    return (*immutable_accumulative_cost_)[n1.get_node_item().node->id()] >
+           (*immutable_accumulative_cost_)[n2.get_node_item().node->id()];
   }
   const std::vector<int64>* immutable_accumulative_cost_;
 };
@@ -537,7 +542,10 @@ ExecutorState<PropagatorStateType>::ExecutorState(
       sync_on_finish_(args.sync_on_finish),
       executor_policy_(args.executor_policy),
       propagator_(immutable_state, step_id_, vlog_),
-      num_outstanding_ops_(0) {
+      num_outstanding_ops_(0),
+      ref_send_inputs_mu_ptr_(args.ref_send_inputs_mu_ptr.get()),
+      ref_send_inputs_ptr_(args.ref_send_inputs_ptr),
+      merge_compute_and_copy_stream_(args.merge_compute_and_copy_stream) {
   // TODO: FIXME Consider function lib executor later
   //if (args.cost_runner == nullptr) {
   //  LOG(FATAL) << "cost_runner is nullptr, please check the args.";
@@ -668,16 +676,31 @@ Status ExecutorState<PropagatorStateType>::ProcessSync(
     NodeExecStatsInterface* stats) {
   Status s;
   OpKernelContext ctx(params, item.num_outputs);
-  nodestats::SetOpStart(stats);
-
-  ExecutorInternal::KernelStatsInfo kernel_stat_buffer;
-  kernel_stats_->StartCollectOp(&item, &kernel_stat_buffer);
-
   OpKernel* op_kernel = item.kernel;
   Device* device = immutable_state_.params().device;
   if (item.virtual_device.get() != nullptr) {
     device = item.virtual_device.get();
   }
+
+  if (merge_compute_and_copy_stream_ &&
+      (op_kernel->type_string() == "_HostSend" ||
+       (op_kernel->type_string() == "_Send" &&
+        device->parsed_name().type == "CPU")) &&
+      item.node->attrs().Find("recv_device")->s().find("GPU") != string::npos &&
+      (*params->inputs)[0].tensor->NumElements() > 0) {
+    CHECK(item.num_inputs == 1);  // a send op carries exactly one tensor
+    TensorReference* ref = new TensorReference(*((*params->inputs)[0].tensor));
+    {
+      mutex_lock l(*ref_send_inputs_mu_ptr_);
+      ref_send_inputs_ptr_->push_back(std::move(ref));
+    }
+  }
+
+  nodestats::SetOpStart(stats);
+
+  ExecutorInternal::KernelStatsInfo kernel_stat_buffer;
+  kernel_stats_->StartCollectOp(&item, &kernel_stat_buffer);
+
   const bool is_expensive = kernel_stats_->IsExpensive(item);

   if (TF_PREDICT_FALSE(MightTrace(item, event_collector_))) {
@@ -748,7 +771,7 @@ void ExecutorState<PropagatorStateType>::ProcessAsync(
   Status s = ProcessOutputs(*state->item, &state->ctx, outputs.data(), stats);
   nodestats::SetMemory(stats, &state->ctx);
   if (vlog_) {
-    VLOG(2) << "Async kernel done: " << state->item->node_id << " step "
+    VLOG(2) << "Async kernel done: " << state->item->node->id() << " step "
             << step_id_ << " " << SummarizeNodeDef(state->item->kernel->def())
             << (state->tagged_node.get_is_dead() ? " is dead" : "")
             << " device: " << device->name();
@@ -898,7 +921,7 @@ void ExecutorState<PropagatorStateType>::BatchProcess(std::vector<TaggedNode> no
   tagged_node = inline_ready.front();
   inline_ready.pop_front();
   const NodeItem& item = tagged_node.get_node_item();
-  const int id = item.node_id;
+  const int id = item.node->id();

   propagator_.MaybeMarkStarted(tagged_node);
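
The condition guarding the new block in ProcessSync only fires for host-side sends whose receiver is a GPU: a `_HostSend`, or a `_Send` whose kernel is placed on a CPU device, with a `recv_device` attribute naming a GPU and a non-empty input. A sketch of that predicate over a simplified node description (the struct is a hypothetical stand-in, not TensorFlow's NodeItem):

```cpp
#include <string>

// Illustrative stand-in for the fields consulted in ProcessSync; not
// TensorFlow's real NodeItem/OpKernel types.
struct SendNodeInfo {
  std::string op_type;      // e.g. "_Send", "_HostSend"
  std::string device_type;  // device the kernel runs on, e.g. "CPU"
  std::string recv_device;  // value of the node's "recv_device" attr
  int num_inputs;
  long long num_elements;   // elements in the (single) input tensor
};

// Mirrors the condition guarding the TensorReference in the diff: only
// host->GPU sends with a non-empty single input get their buffer pinned.
bool ShouldPinSendInput(const SendNodeInfo& n) {
  const bool host_side_send =
      n.op_type == "_HostSend" ||
      (n.op_type == "_Send" && n.device_type == "CPU");
  return host_side_send &&
         n.recv_device.find("GPU") != std::string::npos &&
         n.num_inputs == 1 &&  // a send op carries exactly one tensor
         n.num_elements > 0;
}
```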

Diff for: tensorflow/core/common_runtime/executor.h (+6 −0)

@@ -121,6 +121,12 @@ class Executor {
     CostRunner cost_runner = nullptr;

     ExecutorPolicy executor_policy = ExecutorPolicy::USE_NORMAL_EXECUTOR;
+
+    // Stores refs to CPU tensors that will be sent to the GPU; they are
+    // released when the session run finishes.
+    std::unique_ptr<mutex> ref_send_inputs_mu_ptr;
+    std::vector<TensorReference*>* ref_send_inputs_ptr = nullptr;
+    bool merge_compute_and_copy_stream = false;
   };
   typedef std::function<void(const Status&)> DoneCallback;
   virtual void RunAsync(const Args& args, DoneCallback done) = 0;

Diff for: tensorflow/core/common_runtime/gpu/gpu_device.cc (+36 −11)

@@ -698,17 +698,42 @@ Status BaseGPUDevice::MaybeCopyTensorToGPU(
     return err;
   }

-  StatusCallback wrapped_done = std::bind(
-      [to, copy](StatusCallback done_,
-                 // Begin unbound arguments.
-                 const Status& s) {
-        if (s.ok()) {
-          *to = std::move(*copy);
-        }
-        delete copy;
-        done_(s);
-      },
-      std::move(done), std::placeholders::_1);
+  StatusCallback wrapped_done;
+  if (GPUUtil::MergeComputeAndCopyStream()) {
+    TensorReference input_ref(from);
+    auto recv_host_to_device_stream = device_contexts_[0]->stream();
+    auto event_mgr = em_;
+    wrapped_done = std::bind(
+        [to, copy, recv_host_to_device_stream, event_mgr, input_ref](
+            StatusCallback done_,
+            // Begin unbound arguments.
+            const Status& s) {
+          event_mgr->ThenExecute(
+              recv_host_to_device_stream,
+              [to, copy, recv_host_to_device_stream, done_, &s, input_ref]() {
+                input_ref.Unref();
+                if (!recv_host_to_device_stream->ok()) {
+                  LOG(FATAL) << "CPU->GPU Memcpy failed";
+                }
+                *to = std::move(*copy);
+                delete copy;
+                done_(s);
+              });
+        },
+        std::move(done), std::placeholders::_1);
+  } else {
+    wrapped_done = std::bind(
+        [to, copy](StatusCallback done_,
+                   // Begin unbound arguments.
+                   const Status& s) {
+          if (s.ok()) {
+            *to = std::move(*copy);
+          }
+          delete copy;
+          done_(s);
+        },
+        std::move(done), std::placeholders::_1);
+  }

   tracing::ScopedAnnotation annotation("MakeTensorFromProto");
   device_contexts_[0]->CopyCPUTensorToDevice(
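
In the merged-stream branch above, wrapped_done cannot publish the tensor as soon as the copy is enqueued; EventMgr::ThenExecute defers the work until the compute stream has drained, and only then is `*to` assigned and the reference released. The raw CUDA analogue of that deferral is a stream host callback; a minimal sketch (plain CUDA runtime C++, not EventMgr itself):

```cpp
// Minimal sketch in plain CUDA runtime C++ of "run host code only after
// the stream has drained" -- the role EventMgr::ThenExecute plays above.
#include <cstdio>
#include <cuda_runtime.h>

// Invoked by the CUDA runtime once all work previously enqueued on the
// stream (including the async copy below) has completed.
void CopyDone(void* user_data) {
  std::printf("H2D copy finished; safe to publish %s\n",
              static_cast<const char*>(user_data));
}

int main() {
  cudaStream_t compute;
  cudaStreamCreate(&compute);

  float host_src[4] = {1, 2, 3, 4};
  float* dev_dst = nullptr;
  cudaMalloc(&dev_dst, sizeof(host_src));

  // Copy on the compute stream, then defer the "done" work behind it.
  cudaMemcpyAsync(dev_dst, host_src, sizeof(host_src),
                  cudaMemcpyHostToDevice, compute);
  static char tag[] = "the tensor";
  cudaLaunchHostFunc(compute, CopyDone, tag);

  cudaStreamSynchronize(compute);
  cudaFree(dev_dst);
  cudaStreamDestroy(compute);
  return 0;
}
```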

Diff for: tensorflow/core/common_runtime/gpu/gpu_util.cc (+66 −32)

@@ -36,6 +36,7 @@ limitations under the License.
 #include "tensorflow/core/platform/stream_executor.h"
 #include "tensorflow/core/platform/tensor_coding.h"
 #include "tensorflow/core/platform/tracing.h"
+#include "tensorflow/core/util/env_var.h"
 #include "tensorflow/core/util/util.h"

 // IMPLEMENTATION NOTE:
@@ -111,6 +112,23 @@ void* GetBase(const Tensor* src) {

 void* GetBase(Tensor* dst) { return DMAHelper::base(dst); }

+/*static*/
+bool GPUUtil::MergeComputeAndCopyStream() {
+  static bool merge = false;
+  static bool check_setting = true;
+  if (check_setting) {
+    static mutex mu;
+    mutex_lock l(mu);
+    if (check_setting) {
+      tensorflow::ReadBoolFromEnvVar("MERGE_COMPUTE_COPY_STREAM",
+                                     /*default_val=*/false, &merge);
+      check_setting = false;
+    }
+  }
+
+  return merge;
+}
+
 /*static*/
 void GPUUtil::SetProtoFromGPU(const Tensor& tensor, Device* dev,
                               const DeviceContext* device_context,
@@ -273,16 +291,22 @@ void GPUUtil::CopyGPUTensorToCPU(Device* gpu_device,
     return;
   }

-  auto send_device_to_host_stream =
-      static_cast<const GPUDeviceContext*>(device_context)
-          ->device_to_host_stream();
-  if (send_device_to_host_stream == nullptr) {
-    done(errors::Internal("No send gpu copy-out-stream is available."));
-    return;
-  }
-  // Wait for the sender's main stream to make sure the data are available.
-  if (send_device_to_host_stream != send_stream) {
-    send_device_to_host_stream->ThenWaitFor(send_stream);
+  se::Stream* send_device_to_host_stream = nullptr;
+  if (MergeComputeAndCopyStream()) {
+    send_device_to_host_stream = send_stream;
+  } else {
+    send_device_to_host_stream =
+        static_cast<const GPUDeviceContext*>(device_context)
+            ->device_to_host_stream();
+    if (send_device_to_host_stream == nullptr) {
+      done(errors::Internal("No send gpu copy-out-stream is available."));
+      return;
+    }
+
+    // Wait for the sender's main stream to make sure the data are available.
+    if (send_device_to_host_stream != send_stream) {
+      send_device_to_host_stream->ThenWaitFor(send_stream);
+    }
   }

   const int64 total_bytes = gpu_tensor->TotalBytes();
@@ -320,17 +344,22 @@ void GPUUtil::CopyCPUTensorToGPU(const Tensor* cpu_tensor,
     return;
   }

-  auto recv_host_to_device_stream =
-      static_cast<const GPUDeviceContext*>(device_context)
-          ->host_to_device_stream();
-  if (recv_host_to_device_stream == nullptr) {
-    done(errors::Internal("No send gpu copy-out-stream is available."));
-    return;
-  }
-  // Wait for the recv-stream to make sure the buffer is truly available.
-  if (sync_dst_compute) {
-    if (recv_host_to_device_stream != recv_stream) {
-      recv_host_to_device_stream->ThenWaitFor(recv_stream);
+  se::Stream* recv_host_to_device_stream = nullptr;
+  if (MergeComputeAndCopyStream()) {
+    recv_host_to_device_stream = recv_stream;
+  } else {
+    recv_host_to_device_stream =
+        static_cast<const GPUDeviceContext*>(device_context)
+            ->host_to_device_stream();
+    if (recv_host_to_device_stream == nullptr) {
+      done(errors::Internal("No send gpu copy-out-stream is available."));
+      return;
+    }
+    // Wait for the recv-stream to make sure the buffer is truly available.
+    if (sync_dst_compute) {
+      if (recv_host_to_device_stream != recv_stream) {
+        recv_host_to_device_stream->ThenWaitFor(recv_stream);
+      }
     }
   }

@@ -342,17 +371,22 @@ void GPUUtil::CopyCPUTensorToGPU(const Tensor* cpu_tensor,
     DeviceMemoryBase gpu_dst_ptr(dst_ptr, total_bytes);
     recv_host_to_device_stream->ThenMemcpy(&gpu_dst_ptr, src_ptr, total_bytes);
   }
-  // Use of cpu_tensor may outlive stack scope, so keep a ref.
-  TensorReference input_ref(*cpu_tensor);
-  dev_info->event_mgr->ThenExecute(
-      recv_host_to_device_stream,
-      [recv_host_to_device_stream, done, input_ref]() {
-        input_ref.Unref();
-        if (!recv_host_to_device_stream->ok()) {
-          LOG(FATAL) << "CPU->GPU Memcpy failed";
-        }
-        done(Status::OK());
-      });
+
+  if (MergeComputeAndCopyStream()) {
+    done(Status::OK());
+  } else {
+    // Use of cpu_tensor may outlive stack scope, so keep a ref.
+    TensorReference input_ref(*cpu_tensor);
+    dev_info->event_mgr->ThenExecute(
+        recv_host_to_device_stream,
+        [recv_host_to_device_stream, done, input_ref]() {
+          input_ref.Unref();
+          if (!recv_host_to_device_stream->ok()) {
+            LOG(FATAL) << "CPU->GPU Memcpy failed";
+          }
+          done(Status::OK());
+        });
+  }
 }

 Status GPUUtil::Sync(Device* gpu_device) {
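
GPUUtil::MergeComputeAndCopyStream() caches the environment variable behind a hand-rolled double-checked lock so the parse happens only once, off the hot copy path. The same shape in standard C++, using std::call_once in place of the TF mutex (a sketch; the accepted value spellings below are an assumption, ReadBoolFromEnvVar does its own parsing):

```cpp
#include <cstdlib>
#include <cstring>
#include <mutex>

// Reads MERGE_COMPUTE_COPY_STREAM once and caches the result, mirroring
// the double-checked pattern in GPUUtil::MergeComputeAndCopyStream().
bool MergeComputeAndCopyStream() {
  static bool merge = false;
  static std::once_flag parsed;
  std::call_once(parsed, [] {
    const char* v = std::getenv("MERGE_COMPUTE_COPY_STREAM");
    // Assumed spellings for "enabled"; TF's ReadBoolFromEnvVar parses more.
    merge = v != nullptr &&
            (std::strcmp(v, "true") == 0 || std::strcmp(v, "1") == 0);
  });
  return merge;
}
```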

Diff for: tensorflow/core/common_runtime/gpu/gpu_util.h (+4 −0)

@@ -105,6 +105,10 @@ class GPUUtil {
                                  const Tensor* src_gpu_tensor,
                                  Tensor* dst_gpu_tensor,
                                  StatusCallback done);
+
+  // The user decides whether to use the compute stream as the copy stream
+  // by setting the 'MERGE_COMPUTE_COPY_STREAM' environment variable.
+  static bool MergeComputeAndCopyStream();
 };

 }  // namespace tensorflow

Diff for: tensorflow/core/common_runtime/graph_view.cc (+1 −1)

@@ -36,7 +36,7 @@ limitations under the License.
 namespace tensorflow {

 string NodeItem::DebugString() const {
-  string ret = strings::StrCat("{name:'", kernel->name(), "' id:", node_id);
+  string ret = strings::StrCat("{name:'", kernel->name(), "' id:", node->id());
   if (is_source) {
     strings::StrAppend(&ret, " source}");
   } else {

Diff for: tensorflow/core/common_runtime/graph_view.h (+1 −2)

@@ -58,8 +58,7 @@ struct ControlEdgeInfo {
 //
 // Each NodeItem is an element of exactly one GraphView.
 struct NodeItem {
-  // The index of this node's item in its GraphView.
-  int node_id = -1;
+  const Node* node = nullptr;

   // Cached attributes of this node for fast lookup.
   bool kernel_is_async : 1;  // True iff kernel->AsAsync() != nullptr

Diff for: tensorflow/core/common_runtime/immutable_executor_state.cc (+1 −1)

@@ -126,7 +126,7 @@ Status ImmutableExecutorState::Initialize() {
     FrameInfo* frame_info = EnsureFrameInfo(frame_name);

     NodeItem* item = gview_.node(id);
-    item->node_id = id;
+    item->node = n;

     item->input_start = frame_info->total_inputs;
     frame_info->total_inputs += n->num_inputs();

Diff for: tensorflow/core/common_runtime/immutable_executor_state.h (+1 −1)

@@ -103,7 +103,7 @@ class ImmutableExecutorState {

   const FrameInfo& get_enter_frame_info(const NodeItem& node_item) const {
     DCHECK(node_item.is_enter);
-    return *enter_frame_info_[node_item.node_id];
+    return *enter_frame_info_[node_item.node->id()];
   }

   bool requires_control_flow_support() const { return requires_control_flow_; }
