Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions paddle/phi/api/include/compat/ATen/ops/equal.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@
namespace at {

inline bool equal(const at::Tensor& self, const at::Tensor& other) {
PD_CHECK(self.defined(),
"Expected a proper Tensor but got None (or an undefined Tensor in "
"C++)");
PD_CHECK(other.defined(),
"Expected a proper Tensor but got None (or an undefined Tensor in "
"C++)");
PD_CHECK(self.device() == other.device(),
"Cannot compare two tensors on "
"different devices. Got: ",
Expand Down
28 changes: 25 additions & 3 deletions paddle/phi/api/include/compat/ATen/ops/select.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,35 @@
namespace at {

inline at::Tensor select(const at::Tensor& self, int64_t dim, int64_t index) {
// Normalize dim to positive value for error messages
int64_t orig_dim = dim;
if (dim < 0) {
dim += self.dim();
}
// Handle negative indexing
// Check dim is valid
if (dim < 0 || dim >= self.dim()) {
PD_CHECK(false,
"select(): index ",
orig_dim,
" out of range for tensor of size ",
self.sizes(),
" at dimension ",
orig_dim);
}
// Handle negative index
int64_t orig_index = index;
if (index < 0) {
int64_t dim_size = self.size(dim);
index = dim_size + index;
index = self.size(dim) + index;
}
// Check index is valid
if (index < 0 || index >= self.size(dim)) {
PD_CHECK(false,
"select(): index ",
orig_index,
" out of range for tensor of size ",
self.sizes(),
" at dimension ",
orig_dim < 0 ? orig_dim + self.dim() : orig_dim);
}
Comment on lines +28 to 51
Copy link

Copilot AI Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New error handling was added for invalid dim / index values, but the existing select tests (e.g., test/cpp/compat/ATen_select_test.cc) don't cover these out-of-range branches. Add test cases that assert an exception is thrown for (1) dim out of range (including negative beyond -self.dim()), and (2) index out of range (including negative beyond -size(dim)).

Copilot uses AI. Check for mistakes.

return Tensor(
Expand Down
15 changes: 15 additions & 0 deletions paddle/phi/api/include/compat/ATen/ops/std.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,21 @@ inline Tensor std_impl(const Tensor& self,
const std::vector<int64_t>& dims_vec,
double correction_value,
bool keepdim) {
// Validate dimensions before processing
int64_t ndim = self.dim();
for (int64_t d : dims_vec) {
int64_t dim_idx = d < 0 ? d + ndim : d;
if (dim_idx < 0 || dim_idx >= ndim) {
PD_CHECK(false,
"Dimension out of range (expected to be in range of [",
-ndim,
", ",
ndim - 1,
"], but got ",
d,
")");
}
}
phi::IntArray dims_int_array(dims_vec);
paddle::Tensor tensor = self._PD_GetInner();

Expand Down
115 changes: 8 additions & 107 deletions paddle/phi/api/include/compat/c10/core/ScalarType.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,58 +130,15 @@ struct dummy_int1_7_t {};
_(uint32_t, UINT32, UInt32)

enum class PADDLE_API ScalarType : int8_t {
Byte = 0,
Char = 1,
Short = 2,
Int = 3,
Long = 4,
Half = 5,
Float = 6,
Double = 7,
ComplexHalf = 8,
ComplexFloat = 9,
ComplexDouble = 10,
Bool = 11,
QInt8 = 12,
QUInt8 = 13,
QInt32 = 14,
BFloat16 = 15,
QUInt4x2 = 16,
QUInt2x4 = 17,
Bits1x8 = 18,
Bits2x4 = 19,
Bits4x2 = 20,
Bits8 = 21,
Bits16 = 22,
Float8_e5m2 = 23,
Float8_e4m3fn = 24,
Float8_e5m2fnuz = 25,
Float8_e4m3fnuz = 26,
UInt16 = 27,
UInt32 = 28,
UInt64 = 29,
UInt1 = 30,
UInt2 = 31,
UInt3 = 32,
UInt4 = 33,
UInt5 = 34,
UInt6 = 35,
UInt7 = 36,
Int1 = 37,
Int2 = 38,
Int3 = 39,
Int4 = 40,
Int5 = 41,
Int6 = 42,
Int7 = 43,
Float8_e8m0fnu = 44,
Float4_e2m1fn_x2 = 45,
Undefined = 46,
NumOptions = 47
#define DEFINE_ST_ENUM_VAL_(_1, _2, n) n,
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ST_ENUM_VAL_)
#undef DEFINE_ENUM_ST_ENUM_VAL_
#define DEFINE_ST_ENUM_VAL_FOR_QINTS_(_1, n) n,
AT_FORALL_QINT_TYPES(DEFINE_ST_ENUM_VAL_FOR_QINTS_)
#undef DEFINE_ST_ENUM_VAL_FOR_QINTS_
Undefined,
NumOptions
};

constexpr uint16_t NumScalarTypes =
static_cast<uint16_t>(ScalarType::NumOptions);
namespace impl {

// These are used to map ScalarTypes to C++ types.
Expand Down Expand Up @@ -281,38 +238,6 @@ inline const char* toString(ScalarType t) {

switch (t) {
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE)
case ScalarType::QInt8:
return "QInt8";
case ScalarType::QUInt8:
return "QUInt8";
case ScalarType::QInt32:
return "QInt32";
case ScalarType::QUInt4x2:
return "QUInt4x2";
case ScalarType::QUInt2x4:
return "QUInt2x4";
case ScalarType::ComplexHalf:
return "ComplexHalf";
case ScalarType::Bits1x8:
return "Bits1x8";
case ScalarType::Bits2x4:
return "Bits2x4";
case ScalarType::Bits4x2:
return "Bits4x2";
case ScalarType::Bits8:
return "Bits8";
case ScalarType::Bits16:
return "Bits16";
case ScalarType::Float8_e5m2fnuz:
return "Float8_e5m2fnuz";
case ScalarType::Float8_e4m3fnuz:
return "Float8_e4m3fnuz";
case ScalarType::Float8_e8m0fnu:
return "Float8_e8m0fnu";
case ScalarType::Float4_e2m1fn_x2:
return "Float4_e2m1fn_x2";
case ScalarType::Undefined:
return "Undefined";
default:
return "UNKNOWN_SCALAR";
}
Expand All @@ -326,18 +251,6 @@ inline size_t elementSize(ScalarType t) {

switch (t) {
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(CASE_ELEMENTSIZE_CASE)
case ScalarType::QInt8:
case ScalarType::QUInt8:
case ScalarType::QUInt4x2:
case ScalarType::QUInt2x4:
case ScalarType::Bits1x8:
case ScalarType::Bits2x4:
case ScalarType::Bits4x2:
case ScalarType::Bits8:
return 1;
case ScalarType::QInt32:
case ScalarType::Bits16:
return 4;
default:
TORCH_CHECK(false, "Unknown ScalarType");
}
Comment on lines -294 to 367
Copy link

Copilot AI Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ScalarType still declares the quantized enums via AT_FORALL_QINT_TYPES(...), but toString() and elementSize() no longer handle QInt8/QUInt8/QInt32/QUInt4x2/QUInt2x4 (the switches only expand AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS). This is a behavioral regression: printing these scalar types will return "UNKNOWN_SCALAR", and elementSize() will throw. Add QInt cases back (e.g., by also expanding AT_FORALL_QINT_TYPES in these switches) and ensure their element sizes match the intended semantics.

Copilot uses AI. Check for mistakes.
Expand Down Expand Up @@ -410,7 +323,6 @@ inline bool isSignedType(ScalarType t) {
// Complex types (treated as signed)
case ScalarType::ComplexFloat:
case ScalarType::ComplexDouble:
case ScalarType::ComplexHalf:
return true;

// Signed quantized types (explicitly return true)
Expand Down Expand Up @@ -438,22 +350,11 @@ inline bool isSignedType(ScalarType t) {
case ScalarType::QUInt8:
case ScalarType::QUInt4x2:
case ScalarType::QUInt2x4:
case ScalarType::Bits1x8:
case ScalarType::Bits2x4:
case ScalarType::Bits4x2:
case ScalarType::Bits8:
case ScalarType::Bits16:
return false;

// Bool is unsigned (using numeric_limits)
CASE_ISSIGNED(Bool);

case ScalarType::Float8_e5m2fnuz:
case ScalarType::Float8_e4m3fnuz:
case ScalarType::Float8_e8m0fnu:
case ScalarType::Float4_e2m1fn_x2:
return true;

// Invalid/undefined types - should not happen in normal usage
// If this is hit, it indicates a programming error or unsupported type
case ScalarType::Undefined:
Expand Down
8 changes: 5 additions & 3 deletions paddle/phi/api/include/compat/c10/core/Stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,11 @@ void* Stream::native_handle() const {
return reinterpret_cast<void*>(static_cast<intptr_t>(id_));
}
#endif
PADDLE_THROW(::common::errors::Unimplemented(
"c10::Stream::native_handle() is not supported for device type %d",
static_cast<int>(device_type())));
// Match PyTorch error message format for unsupported device types
PD_CHECK(false,
"native_handle() is not supported for this device type (",
Comment on lines +48 to +49
Copy link

Copilot AI Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PD_CHECK(false, ...) will throw a PD_Exception that appends an additional "Expected false, but it's not satisfied." context. If the goal is to match PyTorch's error message format exactly (as the comment states), this will not produce an exact match. Prefer throwing directly (e.g., TORCH_CHECK(false, ...) / PD_THROW(...)) or pass the real condition into PD_CHECK(...) instead of false so the extra context remains meaningful.

Suggested change
PD_CHECK(false,
"native_handle() is not supported for this device type (",
PD_THROW("native_handle() is not supported for this device type (",

Copilot uses AI. Check for mistakes.
static_cast<int>(device_type()),
")");
}

bool Stream::query() const {
Expand Down
5 changes: 2 additions & 3 deletions paddle/phi/api/include/compat/c10/core/Stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,8 @@ class Stream final {
};

inline std::ostream& operator<<(std::ostream& os, const Stream& s) {
os << "Stream(device_type=" << static_cast<int>(s.device_type())
<< ", device_index=" << static_cast<int>(s.device_index())
<< ", id=" << s.id() << ")";
// Format: "stream {id} on device {device_type}:{device_index}"
os << "stream " << s.id() << " on device " << s.device();
return os;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,5 @@ void synchronize(int64_t device_index = -1);

} // namespace torch::cuda
namespace at::cuda {
using torch::cuda::device_count;
using torch::cuda::is_available;
using torch::cuda::synchronize;
} // namespace at::cuda
76 changes: 0 additions & 76 deletions test/cpp/compat/c10_ScalarType_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDAGuard.h>
#endif
#include <sstream>
#include "ATen/ATen.h"
#include "gtest/gtest.h"
#include "paddle/phi/common/float16.h"
Expand Down Expand Up @@ -91,78 +90,3 @@ TEST(TensorBaseTest, TypeCheckingAPIs) {
ASSERT_FALSE(uint8_tensor.is_signed());
ASSERT_FALSE(bool_tensor.is_signed());
}

// Covers the hand-written (non-macro) branches of the c10 compat ScalarType
// utilities: toString(), elementSize(), the type-category predicates, and the
// ostream operator. NOTE(review): this test is recorded as DELETED in this PR
// (0 additions / 76 deletions) at the same time the explicit switch cases were
// removed from ScalarType.h — verify the macro-expanded switches still cover
// every value exercised here before dropping this coverage.
TEST(ScalarTypeCompatTest, ScalarTypeUtilityBranches) {
// toString() must name each exotic scalar type exactly; an out-of-range
// enum value (-1) falls through to the "UNKNOWN_SCALAR" default branch.
EXPECT_STREQ(c10::toString(c10::ScalarType::Bits1x8), "Bits1x8");
EXPECT_STREQ(c10::toString(c10::ScalarType::Bits16), "Bits16");
EXPECT_STREQ(c10::toString(c10::ScalarType::Float8_e5m2fnuz),
"Float8_e5m2fnuz");
EXPECT_STREQ(c10::toString(c10::ScalarType::Float8_e4m3fnuz),
"Float8_e4m3fnuz");
EXPECT_STREQ(c10::toString(c10::ScalarType::Float8_e8m0fnu),
"Float8_e8m0fnu");
EXPECT_STREQ(c10::toString(c10::ScalarType::Float4_e2m1fn_x2),
"Float4_e2m1fn_x2");
EXPECT_STREQ(c10::toString(c10::ScalarType::Undefined), "Undefined");
EXPECT_STREQ(c10::toString(static_cast<c10::ScalarType>(-1)),
"UNKNOWN_SCALAR");

// elementSize(): sub-byte/1-byte quantized types report 1, QInt32 reports 4.
// NOTE(review): Bits16 is asserted to be 4 bytes here, matching the old
// switch (QInt32 and Bits16 shared a `return 4;`), even though 16 bits is
// 2 bytes — confirm which size is the intended semantics.
EXPECT_EQ(c10::elementSize(c10::ScalarType::QInt8), static_cast<size_t>(1));
EXPECT_EQ(c10::elementSize(c10::ScalarType::QUInt4x2),
static_cast<size_t>(1));
EXPECT_EQ(c10::elementSize(c10::ScalarType::QInt32), static_cast<size_t>(4));
EXPECT_EQ(c10::elementSize(c10::ScalarType::Bits16), static_cast<size_t>(4));
// Undefined has no size; the implementation is expected to throw.
EXPECT_THROW(c10::elementSize(c10::ScalarType::Undefined), ::std::exception);

// Category predicates: Bool counts as integral only when includeBool=true.
EXPECT_TRUE(c10::isIntegralType(c10::ScalarType::Bool, true));
EXPECT_FALSE(c10::isIntegralType(c10::ScalarType::Bool, false));
EXPECT_TRUE(c10::isFloat8Type(c10::ScalarType::Float8_e5m2));
EXPECT_FALSE(c10::isFloat8Type(c10::ScalarType::Float8_e4m3fnuz));
EXPECT_TRUE(c10::isReducedFloatingType(c10::ScalarType::BFloat16));
EXPECT_TRUE(c10::isFloatingType(c10::ScalarType::Float));
EXPECT_FALSE(c10::isComplexType(c10::ScalarType::ComplexHalf));

// Signedness: IntN signed, UIntN/QUInt unsigned; fnuz float8 treated as
// signed; Undefined throws rather than answering.
EXPECT_TRUE(c10::isSignedType(c10::ScalarType::Int1));
EXPECT_FALSE(c10::isSignedType(c10::ScalarType::UInt3));
EXPECT_FALSE(c10::isSignedType(c10::ScalarType::QUInt8));
EXPECT_TRUE(c10::isSignedType(c10::ScalarType::Float8_e5m2fnuz));
EXPECT_THROW(c10::isSignedType(c10::ScalarType::Undefined), ::std::exception);

// operator<< must stream the same name toString() produces.
std::ostringstream oss;
oss << c10::ScalarType::UInt7;
EXPECT_EQ(oss.str(), "UInt7");
}

// Companion to ScalarTypeUtilityBranches: exercises the quantized (QInt*/
// QUInt*), Bits*, and complex branches of toString()/elementSize()/
// isSignedType(). NOTE(review): also recorded as DELETED in this PR — the
// review comment on ScalarType.h warns that the rewritten switches no longer
// handle QInt8/QUInt8/QInt32/QUInt4x2/QUInt2x4; this is exactly the coverage
// that would have caught that regression, so confirm before removing it.
TEST(ScalarTypeCompatTest, AdditionalEnumAndPredicateBranches) {
// toString() names for the quantized, complex-half, and Bits types.
EXPECT_STREQ(c10::toString(c10::ScalarType::QInt8), "QInt8");
EXPECT_STREQ(c10::toString(c10::ScalarType::QUInt8), "QUInt8");
EXPECT_STREQ(c10::toString(c10::ScalarType::QInt32), "QInt32");
EXPECT_STREQ(c10::toString(c10::ScalarType::QUInt4x2), "QUInt4x2");
EXPECT_STREQ(c10::toString(c10::ScalarType::QUInt2x4), "QUInt2x4");
EXPECT_STREQ(c10::toString(c10::ScalarType::ComplexHalf), "ComplexHalf");
EXPECT_STREQ(c10::toString(c10::ScalarType::Bits2x4), "Bits2x4");
EXPECT_STREQ(c10::toString(c10::ScalarType::Bits4x2), "Bits4x2");
EXPECT_STREQ(c10::toString(c10::ScalarType::Bits8), "Bits8");

// All packed/sub-byte quantized and Bits<=8 types occupy one byte.
EXPECT_EQ(c10::elementSize(c10::ScalarType::QUInt8), static_cast<size_t>(1));
EXPECT_EQ(c10::elementSize(c10::ScalarType::QUInt2x4),
static_cast<size_t>(1));
EXPECT_EQ(c10::elementSize(c10::ScalarType::Bits2x4), static_cast<size_t>(1));
EXPECT_EQ(c10::elementSize(c10::ScalarType::Bits4x2), static_cast<size_t>(1));
EXPECT_EQ(c10::elementSize(c10::ScalarType::Bits8), static_cast<size_t>(1));

// Predicate branches not covered by the other test.
EXPECT_TRUE(c10::isIntegralType(c10::ScalarType::UInt64, false));
EXPECT_FALSE(c10::isIntegralType(c10::ScalarType::Float, true));
EXPECT_TRUE(c10::isFloat8Type(c10::ScalarType::Float8_e4m3fn));
EXPECT_TRUE(c10::isReducedFloatingType(c10::ScalarType::Half));
EXPECT_FALSE(c10::isReducedFloatingType(c10::ScalarType::Float));
EXPECT_TRUE(c10::isFloatingType(c10::ScalarType::Half));
EXPECT_TRUE(c10::isComplexType(c10::ScalarType::ComplexFloat));

// Signedness: QInt8 and complex are signed; Byte/Bool are not; the
// NumOptions sentinel is rejected with an exception.
EXPECT_TRUE(c10::isSignedType(c10::ScalarType::QInt8));
EXPECT_TRUE(c10::isSignedType(c10::ScalarType::ComplexHalf));
EXPECT_FALSE(c10::isSignedType(c10::ScalarType::Byte));
EXPECT_FALSE(c10::isSignedType(c10::ScalarType::Bool));
EXPECT_THROW(c10::isSignedType(c10::ScalarType::NumOptions),
::std::exception);
}
Loading