@@ -114,7 +114,40 @@ std::vector<paddle::Tensor> filter_no_need_buffer_input_var_in_backward(
114114 return filter_x;
115115}
116116
117- std::vector<paddle::Tensor> Trans2ContiguousTensors (
117+ std::vector<size_t > GetNonContiguousTensorIndices (
118+ const std::vector<paddle::Tensor>& tensors) {
119+ std::vector<size_t > need_trans_idx;
120+ for (size_t idx = 0 ; idx < tensors.size (); idx++) {
121+ auto & t = tensors[idx];
122+ if (t.initialized () && t.is_dense_tensor () &&
123+ !std::static_pointer_cast<phi::DenseTensor>(t.impl ())
124+ ->meta ()
125+ .is_contiguous ()) {
126+ need_trans_idx.push_back (idx);
127+ }
128+ }
129+ return need_trans_idx;
130+ }
131+
132+ void Trans2ContiguousTensors (const std::vector<paddle::Tensor>& tensors,
133+ const std::vector<size_t >& need_trans_idx,
134+ std::vector<paddle::Tensor>* tensors_contig) {
135+ if (!need_trans_idx.empty ()) {
136+ tensors_contig->insert (
137+ tensors_contig->end (), tensors.begin (), tensors.end ());
138+ for (auto idx : need_trans_idx) {
139+ auto & t = tensors[idx];
140+ tensors_contig->at (idx) = paddle::Tensor (
141+ std::make_shared<phi::DenseTensor>(
142+ paddle::experimental::Trans2Contiguous (
143+ *(std::static_pointer_cast<phi::DenseTensor>(t.impl ())))),
144+ t.mutable_autograd_meta (),
145+ t.name ());
146+ }
147+ }
148+ }
149+
150+ std::vector<paddle::Tensor> LegacyTrans2ContiguousTensors (
118151 const std::vector<paddle::Tensor>& tensors) {
119152 std::vector<paddle::Tensor> res;
120153 for (const auto & t : tensors) {
@@ -241,8 +274,17 @@ std::vector<paddle::Tensor> run_program_ad_func(
241274 }
242275 VLOG (2 ) << " start run run_program with require_any_grad = "
243276 << require_any_grad << " , is_test = " << is_test;
244- auto x_tmp = Trans2ContiguousTensors (x);
245- auto params_tmp = Trans2ContiguousTensors (params);
277+ // Note: We should only perform contiguous transformations in the presence of
278+ // non-contiguous tensors. Otherwise, unnecessary overhead will be incurred
279+ // during Tensor construction.
280+ auto x_need_trans_idx = GetNonContiguousTensorIndices (x);
281+ auto params_need_trans_idx = GetNonContiguousTensorIndices (params);
282+ std::vector<paddle::Tensor> x_contig, params_contig;
283+ Trans2ContiguousTensors (x, x_need_trans_idx, &x_contig);
284+ Trans2ContiguousTensors (params, params_need_trans_idx, &params_contig);
285+ const auto & x_tmp = x_need_trans_idx.empty () ? x : x_contig;
286+ const auto & params_tmp =
287+ params_need_trans_idx.empty () ? params : params_contig;
246288 // Call forward function
247289 // if require_any_grad is False, don't save any middle vars.
248290 int64_t place_hash_key = 0x9e3779b9 ;
@@ -334,8 +376,8 @@ void legacy_run_program_ad_func(
334376
335377 VLOG (2 ) << " start run run_program with require_any_grad = "
336378 << require_any_grad;
337- auto x_tmp = Trans2ContiguousTensors (x);
338- auto params_tmp = Trans2ContiguousTensors (params);
379+ auto x_tmp = LegacyTrans2ContiguousTensors (x);
380+ auto params_tmp = LegacyTrans2ContiguousTensors (params);
339381 // Call forward function
340382 // if require_any_grad is False, don't save any middle vars.
341383 int64_t place_hash_key = 0 ;
0 commit comments