Skip to content

Commit 74d9b58

Browse files
ptrendxJanuszL
authored andcommitted
Use pinned memory only when interfacing with MakeContiguous
Signed-off-by: ptredak <[email protected]>
1 parent cc214da commit 74d9b58

File tree

3 files changed

+17
-2
lines changed

3 files changed

+17
-2
lines changed

dali/pipeline/executor/executor.cc

+7-2
Original file line numberDiff line numberDiff line change
@@ -458,8 +458,13 @@ void Executor::SetupDataForGraph(WorkspaceBlob *wsb) {
458458

459459
HostWorkspace &src_ws = wsb->cpu_op_data[parent_idx];
460460
auto input = src_ws.SharedCPUOutput(input_src_idx);
461-
for (auto t : input) {
462-
t->set_pinned(true);
461+
// Use pinned memory only when it is useful
462+
if (node.spec.name() == "MakeContiguous" &&
463+
node.spec.NumOutput() == 1 &&
464+
node.spec.OutputDevice(0) == "gpu") {
465+
for (auto t : input) {
466+
t->set_pinned(true);
467+
}
463468
}
464469
ws.AddInput(input);
465470
}

dali/pipeline/executor/pipelined_executor.cc

+9
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,15 @@ void PipelinedExecutor::SetStageOutputsForIter(
223223
int input_idx = info.con_and_idx[j].second;
224224
wsb->mixed_op_data[mixed_op_id].SetInput(
225225
input_idx, tvp.Get(queue_idx));
226+
const OpNode &node = graph_->mixed_node(mixed_op_id);
227+
// Use pinned memory only when it is useful
228+
if (node.spec.name() == "MakeContiguous" &&
229+
node.spec.NumOutput() == 1 &&
230+
node.spec.OutputDevice(0) == "gpu") {
231+
for (auto& v : tvp.Get(queue_idx)) {
232+
v->set_pinned(true);
233+
}
234+
}
226235
} else if (graph_->NodeType(node_id) == DALI_CPU) {
227236
int cpu_op_id = graph_->NodeIdx(node_id);
228237
int input_idx = info.con_and_idx[j].second;

dali/pipeline/executor/pipelined_executor.h

+1
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ class DLL_PUBLIC PipelinedExecutor : public Executor {
6565
for (int j = 0; j < batch_size; ++j) {
6666
tvs_[i].push_back(std::make_shared<Tensor<Backend>>());
6767
tvs_[i].back()->Resize({(Index)bytes_hint});
68+
tvs_[i].back()->set_pinned(false);
6869
}
6970
}
7071
}

0 commit comments

Comments
 (0)