Skip to content

Commit 3c576f3

Browse files
committed
Optimize Trans2ContiguousTensors
1 parent 64a826b commit 3c576f3

File tree

1 file changed

+47
-5
lines changed

1 file changed

+47
-5
lines changed

paddle/fluid/eager/to_static/run_program_func.cc

Lines changed: 47 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,40 @@ std::vector<paddle::Tensor> filter_no_need_buffer_input_var_in_backward(
114114
return filter_x;
115115
}
116116

117-
std::vector<paddle::Tensor> Trans2ContiguousTensors(
117+
std::vector<size_t> GetNonContiguousTensorIndices(
118+
const std::vector<paddle::Tensor>& tensors) {
119+
std::vector<size_t> need_trans_idx;
120+
for (size_t idx = 0; idx < tensors.size(); idx++) {
121+
auto& t = tensors[idx];
122+
if (t.initialized() && t.is_dense_tensor() &&
123+
!std::static_pointer_cast<phi::DenseTensor>(t.impl())
124+
->meta()
125+
.is_contiguous()) {
126+
need_trans_idx.push_back(idx);
127+
}
128+
}
129+
return need_trans_idx;
130+
}
131+
132+
// Build into `*tensors_contig` a copy of `tensors` in which every entry
// listed in `need_trans_idx` (as produced by GetNonContiguousTensorIndices)
// has been replaced by a contiguous version of its DenseTensor. When
// `need_trans_idx` is empty the output vector is deliberately left
// untouched, so callers can keep using the original `tensors` and skip
// the per-tensor copy entirely.
void Trans2ContiguousTensors(const std::vector<paddle::Tensor>& tensors,
                             const std::vector<size_t>& need_trans_idx,
                             std::vector<paddle::Tensor>* tensors_contig) {
  if (need_trans_idx.empty()) {
    return;  // Fast path: nothing to transform, avoid copying the vector.
  }
  // Start from a copy of all inputs, then patch only the offending slots.
  tensors_contig->insert(
      tensors_contig->end(), tensors.begin(), tensors.end());
  for (const size_t idx : need_trans_idx) {
    const auto& src = tensors[idx];
    auto dense_src = std::static_pointer_cast<phi::DenseTensor>(src.impl());
    auto contiguous = std::make_shared<phi::DenseTensor>(
        paddle::experimental::Trans2Contiguous(*dense_src));
    // Rebuild the Tensor around the contiguous storage while preserving
    // its autograd meta and name.
    (*tensors_contig)[idx] =
        paddle::Tensor(std::move(contiguous), src.mutable_autograd_meta(),
                       src.name());
  }
}
149+
150+
std::vector<paddle::Tensor> LegacyTrans2ContiguousTensors(
118151
const std::vector<paddle::Tensor>& tensors) {
119152
std::vector<paddle::Tensor> res;
120153
for (const auto& t : tensors) {
@@ -241,8 +274,17 @@ std::vector<paddle::Tensor> run_program_ad_func(
241274
}
242275
VLOG(2) << "start run run_program with require_any_grad = "
243276
<< require_any_grad << ", is_test = " << is_test;
244-
auto x_tmp = Trans2ContiguousTensors(x);
245-
auto params_tmp = Trans2ContiguousTensors(params);
277+
// Note: We should only perform contiguous transformations in the presence of
278+
// non-contiguous tensors. Otherwise, unnecessary overhead will be incurred
279+
// during Tensor construction.
280+
auto x_need_trans_idx = GetNonContiguousTensorIndices(x);
281+
auto params_need_trans_idx = GetNonContiguousTensorIndices(params);
282+
std::vector<paddle::Tensor> x_contig, params_contig;
283+
Trans2ContiguousTensors(x, x_need_trans_idx, &x_contig);
284+
Trans2ContiguousTensors(params, params_need_trans_idx, &params_contig);
285+
const auto& x_tmp = x_need_trans_idx.empty() ? x : x_contig;
286+
const auto& params_tmp =
287+
params_need_trans_idx.empty() ? params : params_contig;
246288
// Call forward function
247289
// if require_any_grad is False, don't save any middle vars.
248290
int64_t place_hash_key = 0x9e3779b9;
@@ -334,8 +376,8 @@ void legacy_run_program_ad_func(
334376

335377
VLOG(2) << "start run run_program with require_any_grad = "
336378
<< require_any_grad;
337-
auto x_tmp = Trans2ContiguousTensors(x);
338-
auto params_tmp = Trans2ContiguousTensors(params);
379+
auto x_tmp = LegacyTrans2ContiguousTensors(x);
380+
auto params_tmp = LegacyTrans2ContiguousTensors(params);
339381
// Call forward function
340382
// if require_any_grad is False, don't save any middle vars.
341383
int64_t place_hash_key = 0;

0 commit comments

Comments (0)