Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions paddle/phi/api/include/compat/ATen/core/TensorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <c10/core/SymIntArrayRef.h>
#include <c10/core/TensorOptions.h>
#include <utils/int_array_ref_conversion.h>
#include <utils/mapped_pinned_tensor.h>
#include <utils/scalar_type_conversion.h>
#include <algorithm>
#include <iostream>
Expand Down Expand Up @@ -112,14 +113,10 @@ class PADDLE_API TensorBase {
return backend_str + scalar_type_str + "Type";
}

// Returns the tensor's current data pointer. Storage mutations flow through
// the compat holder view, so tensor.data_ptr() stays aligned with storage()
// while preserving tensor-specific offsets for views.
// Returns the pointer kernels should use. For CUDA-pinned tensors this is
// the mapped device-visible alias rather than the raw host address.
void* data_ptr() const {
if (!tensor_.defined()) {
return nullptr;
}
return const_cast<void*>(tensor_.data());
return compat::_PD_GetKernelVisibleDataPtr(tensor_);
}
template <typename T>
T* data_ptr() const {
Expand Down
15 changes: 8 additions & 7 deletions paddle/phi/api/include/compat/ATen/core/TensorBody.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <c10/core/Stream.h>
#include <c10/core/SymIntArrayRef.h>
#include <c10/util/OptionalArrayRef.h>
#include <utils/mapped_pinned_tensor.h>
#include "paddle/phi/api/include/api.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/common/int_array.h"
Expand Down Expand Up @@ -126,10 +127,12 @@ class Tensor : public TensorBase {
return *this;
}

void* data_ptr() const { return const_cast<void*>(tensor_.data()); }
void* data_ptr() const {
return compat::_PD_GetKernelVisibleDataPtr(tensor_);
}
template <typename T>
T* data_ptr() const {
return const_cast<T*>(tensor_.data<T>());
return static_cast<T*>(data_ptr());
}

template <typename T>
Expand Down Expand Up @@ -176,9 +179,7 @@ class Tensor : public TensorBase {
#endif
}

const void* const_data_ptr() const {
return const_cast<void*>(tensor_.data());
}
const void* const_data_ptr() const { return data_ptr(); }

template <typename T, std::enable_if_t<!std::is_const_v<T>, int> = 0>
const T* const_data_ptr() const {
Expand All @@ -190,7 +191,7 @@ class Tensor : public TensorBase {
return TensorBase::const_data_ptr<T>();
}

void* mutable_data_ptr() const { return const_cast<void*>(tensor_.data()); }
void* mutable_data_ptr() const { return data_ptr(); }

template <typename T>
T* mutable_data_ptr() const {
Expand Down Expand Up @@ -489,7 +490,7 @@ class Tensor : public TensorBase {
#endif
}

return tensor_.copy_to(pinned_place, true);
return compat::_PD_CopyTensorToPinnedPlace(tensor_, pinned_place);
}

at::Tensor narrow_copy(int64_t dim, int64_t start, int64_t length) const;
Expand Down
8 changes: 4 additions & 4 deletions paddle/phi/api/include/compat/ATen/core/TensorMethods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,24 +36,24 @@ void check_type(const TensorBase& tensor,
template <> \
PADDLE_API const T* TensorBase::const_data_ptr() const { \
check_type(*this, ScalarType::name, #name); \
return const_cast<T*>(tensor_.data<T>()); \
return static_cast<const T*>(data_ptr()); \
} \
\
template <> \
PADDLE_API const T* TensorBase::const_data_ptr<const T>() const { \
check_type(*this, ScalarType::name, #name); \
return const_cast<T*>(tensor_.data<std::remove_const_t<T>>()); \
return static_cast<const std::remove_const_t<T>*>(data_ptr()); \
} \
\
template <> \
PADDLE_API T* TensorBase::mutable_data_ptr() const { \
check_type(*this, ScalarType::name, #name); \
return const_cast<PaddleTensor&>(tensor_).data<T>(); \
return static_cast<T*>(data_ptr()); \
} \
\
template <> \
PADDLE_API T* TensorBase::data_ptr() const { \
return const_cast<T*>(tensor_.data<T>()); \
return static_cast<T*>(data_ptr()); \
}

AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CAST) // missing half and float16
Expand Down
9 changes: 6 additions & 3 deletions paddle/phi/api/include/compat/ATen/ops/empty.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
#include <ATen/core/Tensor.h>
#include <c10/core/TensorOptions.h>
#include <utils/dense_sparse_conversion.h>
#include <utils/mapped_pinned_tensor.h>
#include <utils/pinned_place.h>
#include <optional>
#include <string_view>

Expand All @@ -38,11 +40,12 @@ inline at::Tensor empty(
PD_THROW(
"pin_memory=true requires device to be CPU, but got non-CPU device");
}
auto dense = paddle::experimental::empty(
phi::Place pinned_place =
compat::_PD_GetCreatePinnedPlace(options._PD_GetPlace());
auto dense = compat::_PD_EmptyPinnedTensor(
size._PD_ToPaddleIntArray(),
compat::_PD_AtenScalarTypeToPhiDataType(options.dtype()),
phi::CPUPlace());
dense = dense.copy_to(phi::GPUPinnedPlace(), /*blocking=*/true);
pinned_place);
return compat::_PD_ConvertToSparseIfNeeded(dense, options.layout());
}
auto dense = paddle::experimental::empty(
Expand Down
3 changes: 2 additions & 1 deletion paddle/phi/api/include/compat/ATen/ops/empty_like.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <ATen/core/Tensor.h>
#include <c10/core/TensorOptions.h>
#include <utils/dense_sparse_conversion.h>
#include <utils/mapped_pinned_tensor.h>
#include <utils/pinned_place.h>

#include <optional>
Expand Down Expand Up @@ -49,7 +50,7 @@ inline at::Tensor empty_like(
phi::CPUPlace());
phi::Place base_place = options._PD_GetPlace();
phi::Place pinned_place = compat::_PD_GetCreatePinnedPlace(base_place);
dense = dense_cpu.copy_to(pinned_place, /*blocking=*/true);
dense = compat::_PD_CopyTensorToPinnedPlace(dense_cpu, pinned_place);
} else {
auto place = options.device_opt().value_or(self.device());
dense = paddle::experimental::empty_like(
Expand Down
6 changes: 3 additions & 3 deletions paddle/phi/api/include/compat/ATen/ops/new_empty.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include <ATen/core/Tensor.h>
#include <c10/core/TensorOptions.h>
#include <utils/mapped_pinned_tensor.h>
#include <utils/pinned_place.h>
#include <optional>
#include <string_view>
Expand Down Expand Up @@ -43,9 +44,8 @@ inline Tensor Tensor::new_empty(at::IntArrayRef size,
"pin_memory=true requires device to be CPU, but got non-CPU device");
}
phi::Place pinned_place = compat::_PD_GetCreatePinnedPlace(pd_place);
auto dense_cpu = paddle::experimental::empty(
size._PD_ToPaddleIntArray(), pd_dtype, phi::CPUPlace());
result = dense_cpu.copy_to(pinned_place, /*blocking=*/true);
result = compat::_PD_EmptyPinnedTensor(
size._PD_ToPaddleIntArray(), pd_dtype, pinned_place);
} else {
result = paddle::experimental::empty(
size._PD_ToPaddleIntArray(), pd_dtype, pd_place);
Expand Down
149 changes: 149 additions & 0 deletions paddle/phi/api/include/compat/utils/mapped_pinned_tensor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cstddef>
#include <cstring>
#include <memory>

#include "paddle/common/ddim.h"
#include "paddle/phi/api/include/api.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/allocator.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/tensor_meta.h"

#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#elif defined(PADDLE_WITH_CUDA)
#include <cuda_runtime_api.h>
#endif

namespace compat {

inline void _PD_FreeMappedPinnedAllocation(phi::Allocation* allocation) {
if (allocation == nullptr || allocation->ptr() == nullptr) {
return;
}
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(hipHostFree(allocation->ptr()));
#elif defined(PADDLE_WITH_CUDA)
PADDLE_ENFORCE_GPU_SUCCESS(cudaFreeHost(allocation->ptr()));
#endif
}

inline std::shared_ptr<phi::Allocation> _PD_CreateMappedPinnedAllocation(
size_t bytes, const phi::Place& pinned_place) {
if (bytes == 0) {
return std::make_shared<phi::Allocation>(nullptr, 0, pinned_place);
}

void* ptr = nullptr;
#ifdef PADDLE_WITH_HIP
constexpr unsigned int kMappedPinnedFlags =
hipHostMallocPortable | hipHostMallocMapped;
PADDLE_ENFORCE_GPU_SUCCESS(hipHostMalloc(&ptr, bytes, kMappedPinnedFlags));
#elif defined(PADDLE_WITH_CUDA)
constexpr unsigned int kMappedPinnedFlags =
cudaHostAllocPortable | cudaHostAllocMapped;
PADDLE_ENFORCE_GPU_SUCCESS(cudaHostAlloc(&ptr, bytes, kMappedPinnedFlags));
#else
PD_THROW("Mapped GPU pinned memory requires CUDA or HIP support.");
#endif

return std::make_shared<phi::Allocation>(
ptr, bytes, &_PD_FreeMappedPinnedAllocation, pinned_place);
}

inline paddle::Tensor _PD_EmptyPinnedTensor(const paddle::IntArray& shape,
phi::DataType dtype,
const phi::Place& pinned_place) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (phi::is_cuda_pinned_place(pinned_place)) {
auto dims = common::make_ddim(shape.GetData());
auto meta = phi::DenseTensorMeta(dtype, dims);
auto bytes =
static_cast<size_t>(common::product(dims)) * phi::SizeOf(dtype);
auto holder = _PD_CreateMappedPinnedAllocation(bytes, pinned_place);
return paddle::Tensor(std::make_shared<phi::DenseTensor>(holder, meta));
}
#endif

auto dense = paddle::experimental::empty(shape, dtype, phi::CPUPlace());
return dense.copy_to(pinned_place, /*blocking=*/true);
}

inline paddle::Tensor _PD_CopyTensorToPinnedPlace(
const paddle::Tensor& src, const phi::Place& pinned_place) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (phi::is_cuda_pinned_place(pinned_place)) {
auto src_dense = std::dynamic_pointer_cast<phi::DenseTensor>(src.impl());
if (src_dense && src_dense->meta().is_contiguous() &&
src_dense->meta().offset == 0) {
auto bytes = src_dense->memory_size();
auto holder = _PD_CreateMappedPinnedAllocation(bytes, pinned_place);
if (bytes > 0) {
std::memcpy(holder->ptr(), src_dense->data(), bytes);
}
return paddle::Tensor(
std::make_shared<phi::DenseTensor>(holder, src_dense->meta()));
}
}
#endif

return src.copy_to(pinned_place, /*blocking=*/true);
}

inline void* _PD_GetKernelVisibleDataPtr(const paddle::Tensor& tensor) {
if (!tensor.defined()) {
return nullptr;
}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (phi::is_cuda_pinned_place(tensor.place())) {
auto dense = std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
if (!dense) {
return const_cast<void*>(tensor.data());
}

auto holder = dense->Holder();
if (!holder || holder->ptr() == nullptr) {
return const_cast<void*>(tensor.data());
}

void* mapped_base = nullptr;
#ifdef PADDLE_WITH_HIP
auto err = hipHostGetDevicePointer(&mapped_base, holder->ptr(), 0);
if (err == hipSuccess && mapped_base != nullptr) {
return static_cast<char*>(mapped_base) + dense->meta().offset;
}
(void)hipGetLastError();
#elif defined(PADDLE_WITH_CUDA)
auto err = cudaHostGetDevicePointer(&mapped_base, holder->ptr(), 0);
if (err == cudaSuccess && mapped_base != nullptr) {
return static_cast<char*>(mapped_base) + dense->meta().offset;
}
(void)cudaGetLastError();
#endif
}
#endif

return const_cast<void*>(tensor.data());
}
Copy link

Copilot AI Apr 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_PD_GetKernelVisibleDataPtr() falls back to returning tensor.data() when cudaHostGetDevicePointer/hipHostGetDevicePointer fails. Since the default GPUPinned allocator uses cudaHostAllocPortable (no Mapped flag) (see paddle/phi/core/memory/allocation/pinned_allocator.cc:40-46), this failure path will be hit for pinned tensors created outside the new mapped allocation helpers, and data_ptr() will again return a non-device-visible host address (contradicting the new “pointer kernels should use” contract). Consider either (a) ensuring all compat pinned-tensor creation/copy paths use _PD_CreateMappedPinnedAllocation (or another mapped/registered strategy), or (b) making this function throw/explicitly error when it cannot obtain a device-visible alias for a CUDA-pinned tensor, to avoid silently returning an unsafe pointer for kernels.

Copilot uses AI. Check for mistakes.

} // namespace compat
48 changes: 48 additions & 0 deletions test/cpp/compat/ATen_pin_memory_kernel_test.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#if defined(PADDLE_WITH_CUDA)

#include <ATen/ops/empty.h>
#include <c10/cuda/CUDAStream.h>

#include "gtest/gtest.h"

namespace {

__global__ void WriteScalarKernel(int* dst, int value) {
if (blockIdx.x == 0 && threadIdx.x == 0) {
dst[0] = value;
}
}

} // namespace

TEST(ATenPinMemoryKernelTest, KernelCanWritePinnedTensorDirectly) {
auto stream = at::cuda::getCurrentCUDAStream();
auto tensor =
at::empty({1}, at::TensorOptions().dtype(at::kInt).pinned_memory(true));

ASSERT_TRUE(tensor.is_pinned());
ASSERT_FALSE(tensor.is_cuda());

tensor._PD_GetInner().data<int>()[0] = 0;
WriteScalarKernel<<<1, 1, 0, stream>>>(tensor.data_ptr<int>(), 123);
ASSERT_EQ(cudaGetLastError(), cudaSuccess);
ASSERT_EQ(cudaStreamSynchronize(stream), cudaSuccess);

EXPECT_EQ(tensor._PD_GetInner().data<int>()[0], 123);
}

#endif // PADDLE_WITH_CUDA
1 change: 1 addition & 0 deletions test/cpp/compat/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ if(NOT WIN32)
if(WITH_GPU)
nv_test(ATen_CUDABlas_test SRCS ATen_CUDABlas_test.cc)
nv_test(ATen_cuda_test SRCS ATen_cuda_test.cc)
nv_test(ATen_pin_memory_kernel_test SRCS ATen_pin_memory_kernel_test.cu)
nv_test(c10_cuda_generator_test SRCS c10_cuda_generator_test.cc)
nv_test(c10_generator_impl_test SRCS c10_generator_impl_test.cc)
endif()
Expand Down
Loading