Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 61 additions & 99 deletions paddle/phi/api/include/compat/c10/core/ScalarType.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,29 @@
// See the License for the specific language governing permissions and
// limitations under the License.

// #The file has been adapted from pytorch project
// #Licensed under BSD-style license -
// The file has been adapted from pytorch project
// Licensed under BSD-style license -
// https://github.com/pytorch/pytorch/blob/main/LICENSE
Comment on lines +15 to 17
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ShigureNyako 单独提一个 PR 把 repo 里所有这样的声明都参考这里的改动改一下吧

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ShigureNyako 单独提一个 PR 把 repo 里所有这样的声明都参考这里的改动改一下吧

收到。我会把这类 adapted from PyTorch 参考声明的补齐整理成单独 PR 跟进,避免把 repo-wide 清理和这次 ScalarType 兼容性改动混在一起。当前这个 PR 我这边保持通过,后续会单独开任务处理。

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ShigureNyako 单独提一个 PR 把 repo 里所有这样的声明都参考这里的改动改一下吧

已处理,单独提了 repo-wide 清理 PR #78590,按这次 ScalarType.h / 新增 compat wrapper 的声明格式统一了现有 malformed adapted-from-PyTorch reference blocks。

这个跟进 PR 只改参考声明注释,不涉及 ABI、dispatch 行为或调用方式变化。


#pragma once

#include <c10/core/ScalarType.h>
#include <c10/util/BFloat16.h>
#include <c10/util/Exception.h>
#include <c10/util/Float4_e2m1fn_x2.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e4m3fnuz.h>
#include <c10/util/Float8_e5m2.h>
#include <c10/util/Float8_e5m2fnuz.h>
#include <c10/util/Float8_e8m0fnu.h>
#include <c10/util/Half.h>
#include <c10/util/bits.h>
#include <c10/util/complex.h>
#include <c10/util/qint32.h>
#include <c10/util/qint8.h>
#include <c10/util/quint2x4.h>
#include <c10/util/quint4x2.h>
#include <c10/util/quint8.h>
#include <sstream>

#include "paddle/common/macros.h"
Expand All @@ -41,38 +51,53 @@ struct dummy_uint1_7_t {};
// Placeholder tag type for the sub-byte signed-integer ScalarTypes
// (Int1..Int7); N is the bit width. It carries no storage or arithmetic —
// it only gives each X-macro entry a distinct C++ type.
template <unsigned int N>
struct dummy_int1_7_t {};

#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(_) \
_(uint8_t, UINT8, Byte) /* 0 */ \
_(int8_t, INT8, Char) /* 1 */ \
_(int16_t, INT16, Short) /* 2 */ \
_(int, INT32, Int) /* 3 */ \
_(int64_t, INT64, Long) /* 4 */ \
_(at::Half, FLOAT16, Half) \
_(float, FLOAT32, Float) /* 6 */ \
_(double, FLOAT64, Double) /* 7 */ \
_(c10::complex<float>, COMPLEX64, ComplexFloat) /* 9 */ \
_(c10::complex<double>, COMPLEX128, ComplexDouble) /* 10 */ \
_(bool, BOOL, Bool) /* 11 */ \
_(at::BFloat16, BFLOAT16, BFloat16) /* 15 */ \
_(c10::Float8_e5m2, FLOAT8_E5M2, Float8_e5m2) /* 23 */ \
_(c10::Float8_e4m3fn, FLOAT8_E4M3FN, Float8_e4m3fn) /* 24 */ \
_(uint16_t, UINT16, UInt16) /* 27 */ \
_(uint32_t, UINT32, UInt32) /* 28 */ \
_(uint64_t, UINT64, UInt64) /* 29 */ \
_(c10::dummy_uint1_7_t<1>, UInt1, UInt1) /* 30 */ \
_(c10::dummy_uint1_7_t<2>, UInt2, UInt2) /* 31 */ \
_(c10::dummy_uint1_7_t<3>, UInt3, UInt3) /* 32 */ \
_(c10::dummy_uint1_7_t<4>, UInt4, UInt4) /* 33 */ \
_(c10::dummy_uint1_7_t<5>, UInt5, UInt5) /* 34 */ \
_(c10::dummy_uint1_7_t<6>, UInt6, UInt6) /* 35 */ \
_(c10::dummy_uint1_7_t<7>, UInt7, UInt7) /* 36 */ \
_(c10::dummy_int1_7_t<1>, Int1, Int1) /* 37 */ \
_(c10::dummy_int1_7_t<2>, Int2, Int2) /* 38 */ \
_(c10::dummy_int1_7_t<3>, Int3, Int3) /* 39 */ \
_(c10::dummy_int1_7_t<4>, Int4, Int4) /* 40 */ \
_(c10::dummy_int1_7_t<5>, Int5, Int5) /* 41 */ \
_(c10::dummy_int1_7_t<6>, Int6, Int6) /* 42 */ \
_(c10::dummy_int1_7_t<7>, Int7, Int7) /* 43 */
// X-macro enumerating every c10::ScalarType entry in PyTorch enum order.
// Each row is _(cpp_type, second_tag, ScalarTypeName); the trailing /* N */
// comment records the ScalarType enum value so gaps/reorderings are obvious
// in review. NOTE(review): the second column mixes UPPER_SNAKE names (e.g.
// UINT8, FLOAT32 — phi::DataType style) with ScalarType-style names (e.g.
// QInt8, Float8_e5m2fnuz) — confirm that all expanders of this macro expect
// exactly this mix before normalizing it.
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(_)                  \
  _(uint8_t, UINT8, Byte) /* 0 */                                         \
  _(int8_t, INT8, Char) /* 1 */                                           \
  _(int16_t, INT16, Short) /* 2 */                                        \
  _(int, INT32, Int) /* 3 */                                              \
  _(int64_t, INT64, Long) /* 4 */                                         \
  _(at::Half, FLOAT16, Half) /* 5 */                                      \
  _(float, FLOAT32, Float) /* 6 */                                        \
  _(double, FLOAT64, Double) /* 7 */                                      \
  _(c10::complex<at::Half>, ComplexHalf, ComplexHalf) /* 8 */             \
  _(c10::complex<float>, COMPLEX64, ComplexFloat) /* 9 */                 \
  _(c10::complex<double>, COMPLEX128, ComplexDouble) /* 10 */             \
  _(bool, BOOL, Bool) /* 11 */                                            \
  _(c10::qint8, QInt8, QInt8) /* 12 */                                    \
  _(c10::quint8, QUInt8, QUInt8) /* 13 */                                 \
  _(c10::qint32, QInt32, QInt32) /* 14 */                                 \
  _(at::BFloat16, BFLOAT16, BFloat16) /* 15 */                            \
  _(c10::quint4x2, QUInt4x2, QUInt4x2) /* 16 */                           \
  _(c10::quint2x4, QUInt2x4, QUInt2x4) /* 17 */                           \
  _(c10::bits1x8, Bits1x8, Bits1x8) /* 18 */                              \
  _(c10::bits2x4, Bits2x4, Bits2x4) /* 19 */                              \
  _(c10::bits4x2, Bits4x2, Bits4x2) /* 20 */                              \
  _(c10::bits8, Bits8, Bits8) /* 21 */                                    \
  _(c10::bits16, Bits16, Bits16) /* 22 */                                 \
  _(c10::Float8_e5m2, FLOAT8_E5M2, Float8_e5m2) /* 23 */                  \
  _(c10::Float8_e4m3fn, FLOAT8_E4M3FN, Float8_e4m3fn) /* 24 */            \
  _(c10::Float8_e5m2fnuz, Float8_e5m2fnuz, Float8_e5m2fnuz) /* 25 */      \
  _(c10::Float8_e4m3fnuz, Float8_e4m3fnuz, Float8_e4m3fnuz) /* 26 */      \
  _(uint16_t, UINT16, UInt16) /* 27 */                                    \
  _(uint32_t, UINT32, UInt32) /* 28 */                                    \
  _(uint64_t, UINT64, UInt64) /* 29 */                                    \
  _(c10::dummy_uint1_7_t<1>, UInt1, UInt1) /* 30 */                       \
  _(c10::dummy_uint1_7_t<2>, UInt2, UInt2) /* 31 */                       \
  _(c10::dummy_uint1_7_t<3>, UInt3, UInt3) /* 32 */                       \
  _(c10::dummy_uint1_7_t<4>, UInt4, UInt4) /* 33 */                       \
  _(c10::dummy_uint1_7_t<5>, UInt5, UInt5) /* 34 */                       \
  _(c10::dummy_uint1_7_t<6>, UInt6, UInt6) /* 35 */                       \
  _(c10::dummy_uint1_7_t<7>, UInt7, UInt7) /* 36 */                       \
  _(c10::dummy_int1_7_t<1>, Int1, Int1) /* 37 */                          \
  _(c10::dummy_int1_7_t<2>, Int2, Int2) /* 38 */                          \
  _(c10::dummy_int1_7_t<3>, Int3, Int3) /* 39 */                          \
  _(c10::dummy_int1_7_t<4>, Int4, Int4) /* 40 */                          \
  _(c10::dummy_int1_7_t<5>, Int5, Int5) /* 41 */                          \
  _(c10::dummy_int1_7_t<6>, Int6, Int6) /* 42 */                          \
  _(c10::dummy_int1_7_t<7>, Int7, Int7) /* 43 */                          \
  _(c10::Float8_e8m0fnu, Float8_e8m0fnu, Float8_e8m0fnu) /* 44 */         \
  _(c10::Float4_e2m1fn_x2, Float4_e2m1fn_x2, Float4_e2m1fn_x2) /* 45 */

#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_F8NZ(_) \
_(uint8_t, Byte) \
Expand All @@ -87,7 +112,8 @@ struct dummy_int1_7_t {};
_(c10::complex<double>, ComplexDouble) \
_(bool, Bool) \
_(at::BFloat16, BFloat16) \
_(at::Float8_e5m2, Float8_e5m2)
_(c10::Float8_e5m2, Float8_e5m2) \
_(c10::Float8_e4m3fn, Float8_e4m3fn)

#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(_) \
_(uint8_t, Byte) \
Expand Down Expand Up @@ -225,21 +251,6 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CONSTANT)
#undef DEFINE_CONSTANT

constexpr ScalarType kComplexHalf = ScalarType::ComplexHalf;
constexpr ScalarType kQInt8 = ScalarType::QInt8;
constexpr ScalarType kQUInt8 = ScalarType::QUInt8;
constexpr ScalarType kQInt32 = ScalarType::QInt32;
constexpr ScalarType kQUInt4x2 = ScalarType::QUInt4x2;
constexpr ScalarType kQUInt2x4 = ScalarType::QUInt2x4;
constexpr ScalarType kBits1x8 = ScalarType::Bits1x8;
constexpr ScalarType kBits2x4 = ScalarType::Bits2x4;
constexpr ScalarType kBits4x2 = ScalarType::Bits4x2;
constexpr ScalarType kBits8 = ScalarType::Bits8;
constexpr ScalarType kBits16 = ScalarType::Bits16;
constexpr ScalarType kFloat8_e5m2fnuz = ScalarType::Float8_e5m2fnuz;
constexpr ScalarType kFloat8_e4m3fnuz = ScalarType::Float8_e4m3fnuz;
constexpr ScalarType kFloat8_e8m0fnu = ScalarType::Float8_e8m0fnu;
constexpr ScalarType kFloat4_e2m1fn_x2 = ScalarType::Float4_e2m1fn_x2;
constexpr ScalarType kUndefined = ScalarType::Undefined;

#define AT_FORALL_SCALAR_TYPES_AND(SCALARTYPE, _) \
Expand Down Expand Up @@ -298,36 +309,6 @@ inline const char* toString(ScalarType t) {

switch (t) {
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE)
case ScalarType::ComplexHalf:
return "ComplexHalf";
case ScalarType::QInt8:
return "QInt8";
case ScalarType::QUInt8:
return "QUInt8";
case ScalarType::QInt32:
return "QInt32";
case ScalarType::QUInt4x2:
return "QUInt4x2";
case ScalarType::QUInt2x4:
return "QUInt2x4";
case ScalarType::Bits1x8:
return "Bits1x8";
case ScalarType::Bits2x4:
return "Bits2x4";
case ScalarType::Bits4x2:
return "Bits4x2";
case ScalarType::Bits8:
return "Bits8";
case ScalarType::Bits16:
return "Bits16";
case ScalarType::Float8_e5m2fnuz:
return "Float8_e5m2fnuz";
case ScalarType::Float8_e4m3fnuz:
return "Float8_e4m3fnuz";
case ScalarType::Float8_e8m0fnu:
return "Float8_e8m0fnu";
case ScalarType::Float4_e2m1fn_x2:
return "Float4_e2m1fn_x2";
case ScalarType::Undefined:
return "Undefined";
default:
Expand All @@ -343,25 +324,6 @@ inline size_t elementSize(ScalarType t) {

switch (t) {
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(CASE_ELEMENTSIZE_CASE)
case ScalarType::ComplexHalf:
return sizeof(at::Half) * 2;
case ScalarType::QInt8:
case ScalarType::QUInt8:
case ScalarType::QUInt4x2:
case ScalarType::QUInt2x4:
case ScalarType::Bits1x8:
case ScalarType::Bits2x4:
case ScalarType::Bits4x2:
case ScalarType::Bits8:
case ScalarType::Float8_e5m2fnuz:
case ScalarType::Float8_e4m3fnuz:
case ScalarType::Float8_e8m0fnu:
case ScalarType::Float4_e2m1fn_x2:
return 1;
case ScalarType::QInt32:
return 4;
case ScalarType::Bits16:
return 2;
default:
TORCH_CHECK(false, "Unknown ScalarType");
}
Expand Down
65 changes: 65 additions & 0 deletions paddle/phi/api/include/compat/c10/util/Float4_e2m1fn_x2.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// The file has been adapted from pytorch project
// Licensed under BSD-style license -
// https://github.com/pytorch/pytorch/blob/main/LICENSE

#pragma once

#include <cstdint>

namespace c10 {

/// Defines the Float4_e2m1fn_x2 type (4-bit floating-point, two elements packed
/// into one byte). This is the FP4 dtype from the OCP MX format spec
/// (https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf,
/// Section 5.3.3)
///
/// Given two high precision values val0 and val1, here is the
/// binary configuration of their packed representation, from MSB to LSB:
///
/// original value | val1 : val0
/// ========================================
/// bit index (MSB==7, LSB==0) | 7654 : 3210
/// sign/exponent/mantissa | seem : seem

struct alignas(1) Float4_e2m1fn_x2 {
  uint8_t val_;  // packed payload: val1 in bits 7..4, val0 in bits 3..0
  Float4_e2m1fn_x2() = default;
  /// Wraps a raw packed byte; explicit to avoid accidental integer conversion.
  explicit constexpr Float4_e2m1fn_x2(uint8_t val) : val_(val) {}
};

/// Comparison operators — bitwise equality of the packed byte (i.e. both
/// packed elements compare equal). Marked constexpr/noexcept to match the
/// constexpr constructor and allow compile-time comparison.
inline constexpr bool operator==(const Float4_e2m1fn_x2& a,
                                 const Float4_e2m1fn_x2& b) noexcept {
  return a.val_ == b.val_;
}

inline constexpr bool operator!=(const Float4_e2m1fn_x2& a,
                                 const Float4_e2m1fn_x2& b) noexcept {
  return a.val_ != b.val_;
}

}  // namespace c10

// Re-export into at:: so code written against ATen spellings
// (at::Float4_e2m1fn_x2 and its comparison operators) compiles unchanged.
namespace at {
using c10::Float4_e2m1fn_x2;
using c10::operator!=;
using c10::operator==;
}  // namespace at

// Re-export into torch:: so code written against torch:: spellings
// (torch::Float4_e2m1fn_x2 and its comparison operators) compiles unchanged.
namespace torch {
using c10::Float4_e2m1fn_x2;
using c10::operator!=;
using c10::operator==;
}  // namespace torch
40 changes: 40 additions & 0 deletions paddle/phi/api/include/compat/c10/util/Float8_e4m3fnuz.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// The file has been adapted from pytorch project
// Licensed under BSD-style license -
// https://github.com/pytorch/pytorch/blob/main/LICENSE

#pragma once

#include <cstdint>

namespace c10 {

/// Compat stand-in for PyTorch's Float8_e4m3fnuz: a one-byte wrapper that
/// stores the raw bit pattern only — no arithmetic or conversions are
/// provided here.
struct Float8_e4m3fnuz {
  /// Raw 8-bit payload; zero-initialized by default.
  uint8_t x = 0;

  constexpr Float8_e4m3fnuz() = default;
  /// Wraps a raw byte; explicit to avoid accidental integer conversion.
  explicit constexpr Float8_e4m3fnuz(uint8_t value) : x(value) {}
};

}  // namespace c10

// Re-export into at:: so ATen-style code (at::Float8_e4m3fnuz) compiles.
namespace at {
using c10::Float8_e4m3fnuz;
}  // namespace at

// Re-export into torch:: so torch-style code (torch::Float8_e4m3fnuz) compiles.
namespace torch {
using c10::Float8_e4m3fnuz;
}  // namespace torch
40 changes: 40 additions & 0 deletions paddle/phi/api/include/compat/c10/util/Float8_e5m2fnuz.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// The file has been adapted from pytorch project
// Licensed under BSD-style license -
// https://github.com/pytorch/pytorch/blob/main/LICENSE

#pragma once

#include <cstdint>

namespace c10 {

/// Compat stand-in for PyTorch's Float8_e5m2fnuz: a one-byte wrapper that
/// stores the raw bit pattern only — no arithmetic or conversions are
/// provided here.
struct Float8_e5m2fnuz {
  /// Raw 8-bit payload; zero-initialized by default.
  uint8_t x = 0;

  constexpr Float8_e5m2fnuz() = default;
  /// Wraps a raw byte; explicit to avoid accidental integer conversion.
  explicit constexpr Float8_e5m2fnuz(uint8_t value) : x(value) {}
};

}  // namespace c10

// Re-export into at:: so ATen-style code (at::Float8_e5m2fnuz) compiles.
namespace at {
using c10::Float8_e5m2fnuz;
}  // namespace at

// Re-export into torch:: so torch-style code (torch::Float8_e5m2fnuz) compiles.
namespace torch {
using c10::Float8_e5m2fnuz;
}  // namespace torch
Loading
Loading