Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions paddle/phi/api/include/compat/ATen/core/TensorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <c10/core/SymIntArrayRef.h>
#include <c10/core/TensorOptions.h>
#include <utils/int_array_ref_conversion.h>
#include <utils/pinned_tensor_ops.h>
#include <utils/scalar_type_conversion.h>
#include <algorithm>
#include <iostream>
Expand Down Expand Up @@ -112,9 +113,6 @@ class PADDLE_API TensorBase {
return backend_str + scalar_type_str + "Type";
}

// Returns the tensor's current data pointer. Storage mutations flow through
// the compat holder view, so tensor.data_ptr() stays aligned with storage()
// while preserving tensor-specific offsets for views.
void* data_ptr() const {
if (!tensor_.defined()) {
return nullptr;
Expand Down Expand Up @@ -270,12 +268,12 @@ class PADDLE_API TensorBase {
}

const TensorBase& fill_(const at::Scalar& scalar) const {
paddle::experimental::fill_(const_cast<PaddleTensor&>(tensor_), scalar);
compat::_PD_FillTensorInplace(&const_cast<PaddleTensor&>(tensor_), scalar);
return *this;
}

const TensorBase& zero_() const {
paddle::experimental::fill_(const_cast<PaddleTensor&>(tensor_), 0.0);
compat::_PD_FillTensorInplace(&const_cast<PaddleTensor&>(tensor_), 0.0);
return *this;
}

Expand Down
22 changes: 13 additions & 9 deletions paddle/phi/api/include/compat/ATen/core/TensorBody.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <c10/core/Stream.h>
#include <c10/core/SymIntArrayRef.h>
#include <c10/util/OptionalArrayRef.h>
#include <utils/pinned_tensor_ops.h>
#include "paddle/phi/api/include/api.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/common/int_array.h"
Expand Down Expand Up @@ -126,10 +127,15 @@ class Tensor : public TensorBase {
return *this;
}

void* data_ptr() const { return const_cast<void*>(tensor_.data()); }
void* data_ptr() const {
if (!tensor_.defined()) {
return nullptr;
}
return const_cast<void*>(tensor_.data());
}
template <typename T>
T* data_ptr() const {
return const_cast<T*>(tensor_.data<T>());
return static_cast<T*>(data_ptr());
}

template <typename T>
Expand Down Expand Up @@ -176,9 +182,7 @@ class Tensor : public TensorBase {
#endif
}

const void* const_data_ptr() const {
return const_cast<void*>(tensor_.data());
}
const void* const_data_ptr() const { return data_ptr(); }

template <typename T, std::enable_if_t<!std::is_const_v<T>, int> = 0>
const T* const_data_ptr() const {
Expand All @@ -190,7 +194,7 @@ class Tensor : public TensorBase {
return TensorBase::const_data_ptr<T>();
}

void* mutable_data_ptr() const { return const_cast<void*>(tensor_.data()); }
void* mutable_data_ptr() const { return data_ptr(); }

template <typename T>
T* mutable_data_ptr() const {
Expand Down Expand Up @@ -407,12 +411,12 @@ class Tensor : public TensorBase {
at::Tensor unflatten_symint(int64_t dim, c10::SymIntArrayRef sizes) const;

Tensor& fill_(const at::Scalar& value) const {
paddle::experimental::fill_(const_cast<PaddleTensor&>(tensor_), value);
compat::_PD_FillTensorInplace(&const_cast<PaddleTensor&>(tensor_), value);
return const_cast<at::Tensor&>(*this);
}

Tensor& zero_() const {
paddle::experimental::fill_(const_cast<PaddleTensor&>(tensor_), 0.0);
compat::_PD_FillTensorInplace(&const_cast<PaddleTensor&>(tensor_), 0.0);
return const_cast<at::Tensor&>(*this);
}

Expand Down Expand Up @@ -489,7 +493,7 @@ class Tensor : public TensorBase {
#endif
}

return tensor_.copy_to(pinned_place, true);
return tensor_.copy_to(pinned_place, /*blocking=*/true);
}

at::Tensor narrow_copy(int64_t dim, int64_t start, int64_t length) const;
Expand Down
8 changes: 4 additions & 4 deletions paddle/phi/api/include/compat/ATen/core/TensorMethods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,24 +36,24 @@ void check_type(const TensorBase& tensor,
template <> \
PADDLE_API const T* TensorBase::const_data_ptr() const { \
check_type(*this, ScalarType::name, #name); \
return const_cast<T*>(tensor_.data<T>()); \
return static_cast<const T*>(data_ptr()); \
} \
\
template <> \
PADDLE_API const T* TensorBase::const_data_ptr<const T>() const { \
check_type(*this, ScalarType::name, #name); \
return const_cast<T*>(tensor_.data<std::remove_const_t<T>>()); \
return static_cast<const std::remove_const_t<T>*>(data_ptr()); \
} \
\
template <> \
PADDLE_API T* TensorBase::mutable_data_ptr() const { \
check_type(*this, ScalarType::name, #name); \
return const_cast<PaddleTensor&>(tensor_).data<T>(); \
return static_cast<T*>(data_ptr()); \
} \
\
template <> \
PADDLE_API T* TensorBase::data_ptr() const { \
return const_cast<T*>(tensor_.data<T>()); \
return static_cast<T*>(data_ptr()); \
}

AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CAST) // missing half and float16
Expand Down
5 changes: 4 additions & 1 deletion paddle/phi/api/include/compat/ATen/ops/empty.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <ATen/core/Tensor.h>
#include <c10/core/TensorOptions.h>
#include <utils/dense_sparse_conversion.h>
#include <utils/pinned_place.h>
#include <optional>
#include <string_view>

Expand All @@ -38,11 +39,13 @@ inline at::Tensor empty(
PD_THROW(
"pin_memory=true requires device to be CPU, but got non-CPU device");
}
phi::Place pinned_place =
compat::_PD_GetCreatePinnedPlace(options._PD_GetPlace());
auto dense = paddle::experimental::empty(
size._PD_ToPaddleIntArray(),
compat::_PD_AtenScalarTypeToPhiDataType(options.dtype()),
phi::CPUPlace());
dense = dense.copy_to(phi::GPUPinnedPlace(), /*blocking=*/true);
dense = dense.copy_to(pinned_place, /*blocking=*/true);
return compat::_PD_ConvertToSparseIfNeeded(dense, options.layout());
}
auto dense = paddle::experimental::empty(
Expand Down
6 changes: 3 additions & 3 deletions paddle/phi/api/include/compat/ATen/ops/new_empty.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ inline Tensor Tensor::new_empty(at::IntArrayRef size,
"pin_memory=true requires device to be CPU, but got non-CPU device");
}
phi::Place pinned_place = compat::_PD_GetCreatePinnedPlace(pd_place);
auto dense_cpu = paddle::experimental::empty(
size._PD_ToPaddleIntArray(), pd_dtype, phi::CPUPlace());
result = dense_cpu.copy_to(pinned_place, /*blocking=*/true);
result = paddle::experimental::empty(
size._PD_ToPaddleIntArray(), pd_dtype, phi::CPUPlace())
.copy_to(pinned_place, /*blocking=*/true);
} else {
result = paddle::experimental::empty(
size._PD_ToPaddleIntArray(), pd_dtype, pd_place);
Expand Down
47 changes: 47 additions & 0 deletions paddle/phi/api/include/compat/utils/pinned_tensor_ops.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <c10/core/Scalar.h>

#include "paddle/common/ddim.h"
#include "paddle/phi/api/include/api.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/place.h"

namespace compat {

// Returns true when the tensor's storage lives in host-pinned memory
// (either the CUDA- or the XPU-pinned place).
inline bool _PD_IsHostPinnedTensor(const paddle::Tensor& tensor) {
  const phi::Place& tensor_place = tensor.place();
  if (phi::is_cuda_pinned_place(tensor_place)) {
    return true;
  }
  return phi::is_xpu_pinned_place(tensor_place);
}

// Fills `*tensor` with `value` in place.
//
// Non-pinned tensors go straight through paddle::experimental::fill_.
// Host-pinned tensors are instead filled by materializing the values in a
// plain-CPU staging tensor and blocking-copying it back into the original
// place — presumably because fill_ cannot target pinned memory directly
// (TODO confirm); this keeps the destination's pinned placement intact.
inline void _PD_FillTensorInplace(paddle::Tensor* tensor,
                                  const c10::Scalar& value) {
  if (_PD_IsHostPinnedTensor(*tensor)) {
    const phi::IntArray shape(common::vectorize<int64_t>(tensor->dims()));
    auto staging = paddle::experimental::full(
        shape, value, tensor->dtype(), phi::CPUPlace());
    tensor->copy_(staging, tensor->place(), /*blocking=*/true);
  } else {
    paddle::experimental::fill_(*tensor, value);
  }
}

} // namespace compat
24 changes: 24 additions & 0 deletions python/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1134,6 +1134,26 @@ def _attr_offsets_check(offset_val):
return out


@dygraph_only
def _is_host_pinned_tensor(x: Tensor) -> bool:
    """Return True if ``x`` resides in CUDA- or XPU-pinned host memory.

    The ``hasattr`` probes keep this safe on builds whose Place type lacks
    one of the pinned-place predicates.
    """
    place = x.place
    for probe in ("is_cuda_pinned_place", "is_xpu_pinned_place"):
        predicate = getattr(place, probe, None)
        if predicate is not None and predicate():
            return True
    return False


@dygraph_only
def _fill_host_pinned_tensor_inplace(x: Tensor, value: float) -> Tensor:
    """Fill a host-pinned tensor in place via a CPU staging tensor.

    Builds the filled values on the plain CPU place and blocking-copies
    them into ``x`` — presumably because the fill kernel cannot target
    pinned memory directly (TODO confirm); the copy preserves ``x``'s
    pinned placement.

    Returns ``x``.
    """
    staging = paddle.full(
        shape=x.shape,
        fill_value=value,
        dtype=x.dtype,
        device=paddle.CPUPlace(),
    )
    x.copy_(staging, True)  # blocking copy back into the pinned place
    return x


@dygraph_only
def fill_(x: Tensor, value: float) -> Tensor:
"""
Expand Down Expand Up @@ -1165,6 +1185,8 @@ def fill_(x: Tensor, value: float) -> Tensor:
raise TypeError(
f"The type of 'value' must be int or float, but received {type(value)}."
)
if _is_host_pinned_tensor(x):
return _fill_host_pinned_tensor_inplace(x, value)
return _C_ops.fill_(x, value)


Expand Down Expand Up @@ -1194,6 +1216,8 @@ def zero_(x: Tensor) -> Tensor:
[0, 0, 0, 0, 0]

"""
if _is_host_pinned_tensor(x):
return _fill_host_pinned_tensor_inplace(x, 0.0)
return _C_ops.fill_(x, 0.0)


Expand Down
17 changes: 17 additions & 0 deletions test/cpp/compat/ATen_pin_memory_creation_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,23 @@ TEST(ATenPinMemoryCreationTest, EmptyLikePinMemoryWithCUDADeviceErrors) {
std::exception);
}

// fill_ / zero_ on a pinned tensor must keep it pinned and write the
// expected values.
TEST(ATenPinMemoryCreationTest, FillAndZeroPreservePinnedPlace) {
  auto tensor =
      at::empty({2}, at::TensorOptions().dtype(at::kInt).pinned_memory(true));

  AssertPinned(tensor);

  // Checks that the tensor is still pinned and both elements hold `expected`.
  auto expect_pinned_with = [&](int expected) {
    AssertPinned(tensor);
    auto* data = tensor._PD_GetInner().data<int>();
    EXPECT_EQ(data[0], expected);
    EXPECT_EQ(data[1], expected);
  };

  tensor.fill_(123);
  expect_pinned_with(123);

  tensor.zero_();
  expect_pinned_with(0);
}

TEST(ATenPinMemoryCreationTest, ZerosLikePinMemory) {
auto base = at::ones({2, 4}, at::kFloat);

Expand Down
48 changes: 48 additions & 0 deletions test/cpp/compat/ATen_pin_memory_kernel_test.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#if defined(PADDLE_WITH_CUDA)

#include <ATen/ops/empty.h>
#include <c10/cuda/CUDAStream.h>

#include "gtest/gtest.h"

namespace {

// Stores `value` into dst[0] from exactly one thread; every other thread
// in the launch is a no-op.
__global__ void WriteScalarKernel(int* dst, int value) {
  const bool is_first_thread = (blockIdx.x == 0 && threadIdx.x == 0);
  if (is_first_thread) {
    *dst = value;
  }
}

} // namespace

// A CUDA kernel should be able to write through data_ptr() of a
// pinned-memory tensor, with the write visible on the host after a
// stream sync (pinned host memory is device-addressable — TODO confirm
// this holds for every supported platform).
TEST(ATenPinMemoryKernelTest, KernelCanWritePinnedTensorDirectly) {
  auto stream = at::cuda::getCurrentCUDAStream();
  auto tensor =
      at::empty({1}, at::TensorOptions().dtype(at::kInt).pinned_memory(true));

  // Sanity: pinned_memory(true) yields a pinned, host-resident tensor.
  ASSERT_TRUE(tensor.is_pinned());
  ASSERT_FALSE(tensor.is_cuda());

  // Seed from the host, then overwrite from the device.
  tensor._PD_GetInner().data<int>()[0] = 0;
  WriteScalarKernel<<<1, 1, 0, stream>>>(tensor.data_ptr<int>(), 123);
  ASSERT_EQ(cudaGetLastError(), cudaSuccess);
  // Synchronize so the device write is visible before the host read.
  ASSERT_EQ(cudaStreamSynchronize(stream), cudaSuccess);

  EXPECT_EQ(tensor._PD_GetInner().data<int>()[0], 123);
}

#endif // PADDLE_WITH_CUDA
1 change: 1 addition & 0 deletions test/cpp/compat/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ if(NOT WIN32)
if(WITH_GPU)
nv_test(ATen_CUDABlas_test SRCS ATen_CUDABlas_test.cc)
nv_test(ATen_cuda_test SRCS ATen_cuda_test.cc)
nv_test(ATen_pin_memory_kernel_test SRCS ATen_pin_memory_kernel_test.cu)
nv_test(c10_cuda_generator_test SRCS c10_cuda_generator_test.cc)
nv_test(c10_generator_impl_test SRCS c10_generator_impl_test.cc)
endif()
Expand Down
36 changes: 36 additions & 0 deletions test/legacy_test/test_tensor_fill_.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from op_test import get_device, get_places

import paddle
from paddle.base import core


class TensorFill_Test(unittest.TestCase):
Expand Down Expand Up @@ -74,6 +75,41 @@ def test_list():

self.assertRaises(TypeError, test_list)

def test_tensor_fill_pinned_memory(self):
    """fill_/zero_ on a pinned tensor keep it pinned and set all values."""
    has_pinned_backend = core.is_compiled_with_cuda() or (
        hasattr(core, 'is_compiled_with_xpu') and core.is_compiled_with_xpu()
    )
    if not has_pinned_backend:
        self.skipTest("Pinned memory requires CUDA or XPU backend")

    def assert_pinned(t):
        # The tensor must stay in a CUDA- or XPU-pinned place.
        self.assertTrue(
            t.place.is_cuda_pinned_place() or t.place.is_xpu_pinned_place()
        )

    paddle.set_device('cpu')
    tensor = paddle.zeros([2], dtype='int32').pin_memory()
    assert_pinned(tensor)

    tensor.fill_(123)
    assert_pinned(tensor)
    np.testing.assert_array_equal(
        tensor.cpu().numpy(), np.array([123, 123], dtype='int32')
    )

    tensor.zero_()
    assert_pinned(tensor)
    np.testing.assert_array_equal(
        tensor.cpu().numpy(), np.array([0, 0], dtype='int32')
    )


if __name__ == '__main__':
unittest.main()
Loading