Skip to content

Commit f021456

Browse files
committed
fix(phi): fix CPU
1 parent a41057a commit f021456

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

paddle/phi/kernels/set_kernel.cc

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414
#include "paddle/phi/kernels/set_kernel.h"
15+
#include <cstring>
1516
#include "paddle/phi/common/memory_utils.h"
1617
#include "paddle/phi/core/kernel_registry.h"
1718
#include "paddle/phi/kernels/full_kernel.h"
@@ -59,12 +60,16 @@ void SetKernel(const Context& dev_ctx,
5960
DenseTensor tmp;
6061
std::vector<int64_t> alloc_shape = {required_size};
6162
Full<T, Context>(dev_ctx, alloc_shape, 0, &tmp);
62-
memory_utils::Copy(dev_ctx.GetPlace(),
63-
tmp.data<T>(),
64-
dev_ctx.GetPlace(),
65-
x.data<T>(),
66-
x.numel() * sizeof(T),
67-
nullptr);
63+
if (dev_ctx.GetPlace().GetType() == phi::AllocationType::CPU) {
64+
std::memcpy(tmp.data<T>(), x.data<T>(), x.numel() * sizeof(T));
65+
} else {
66+
memory_utils::Copy(dev_ctx.GetPlace(),
67+
tmp.data<T>(),
68+
dev_ctx.GetPlace(),
69+
x.data<T>(),
70+
x.numel() * sizeof(T),
71+
nullptr);
72+
}
6873
out->clear();
6974
*out = DenseTensor{tmp.Holder(), meta};
7075
} else {

0 commit comments

Comments
 (0)