diff --git a/paddle/phi/api/include/compat/c10/core/ScalarType.h b/paddle/phi/api/include/compat/c10/core/ScalarType.h index e72a013c70f697..a040f353a5816a 100644 --- a/paddle/phi/api/include/compat/c10/core/ScalarType.h +++ b/paddle/phi/api/include/compat/c10/core/ScalarType.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -// #The file has been adapted from pytorch project -// #Licensed under BSD-style license - +// The file has been adapted from pytorch project +// Licensed under BSD-style license - // https://github.com/pytorch/pytorch/blob/main/LICENSE #pragma once @@ -21,10 +21,20 @@ #include #include #include +#include #include +#include #include +#include +#include #include +#include #include +#include +#include +#include +#include +#include #include #include "paddle/common/macros.h" @@ -41,38 +51,53 @@ struct dummy_uint1_7_t {}; template struct dummy_int1_7_t {}; -#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(_) \ - _(uint8_t, UINT8, Byte) /* 0 */ \ - _(int8_t, INT8, Char) /* 1 */ \ - _(int16_t, INT16, Short) /* 2 */ \ - _(int, INT32, Int) /* 3 */ \ - _(int64_t, INT64, Long) /* 4 */ \ - _(at::Half, FLOAT16, Half) \ - _(float, FLOAT32, Float) /* 6 */ \ - _(double, FLOAT64, Double) /* 7 */ \ - _(c10::complex, COMPLEX64, ComplexFloat) /* 9 */ \ - _(c10::complex, COMPLEX128, ComplexDouble) /* 10 */ \ - _(bool, BOOL, Bool) /* 11 */ \ - _(at::BFloat16, BFLOAT16, BFloat16) /* 15 */ \ - _(c10::Float8_e5m2, FLOAT8_E5M2, Float8_e5m2) /* 23 */ \ - _(c10::Float8_e4m3fn, FLOAT8_E4M3FN, Float8_e4m3fn) /* 24 */ \ - _(uint16_t, UINT16, UInt16) /* 27 */ \ - _(uint32_t, UINT32, UInt32) /* 28 */ \ - _(uint64_t, UINT64, UInt64) /* 29 */ \ - _(c10::dummy_uint1_7_t<1>, UInt1, UInt1) /* 30 */ \ - _(c10::dummy_uint1_7_t<2>, UInt2, UInt2) /* 31 */ \ - _(c10::dummy_uint1_7_t<3>, UInt3, UInt3) /* 32 */ \ - _(c10::dummy_uint1_7_t<4>, UInt4, UInt4) /* 33 */ \ - _(c10::dummy_uint1_7_t<5>, UInt5, UInt5) /* 34 */ \ - _(c10::dummy_uint1_7_t<6>, UInt6, UInt6) /* 35 */ \ - _(c10::dummy_uint1_7_t<7>, UInt7, UInt7) /* 36 */ \ - _(c10::dummy_int1_7_t<1>, Int1, Int1) /* 37 */ \ - _(c10::dummy_int1_7_t<2>, Int2, Int2) /* 38 */ \ - _(c10::dummy_int1_7_t<3>, Int3, Int3) /* 39 */ \ - _(c10::dummy_int1_7_t<4>, Int4, Int4) /* 40 */ \ - _(c10::dummy_int1_7_t<5>, Int5, Int5) /* 41 */ \ - _(c10::dummy_int1_7_t<6>, Int6, Int6) /* 42 */ \ - _(c10::dummy_int1_7_t<7>, Int7, Int7) /* 43 */ +#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(_) \ + _(uint8_t, UINT8, Byte) /* 0 */ \ + _(int8_t, INT8, Char) /* 1 */ \ + _(int16_t, INT16, Short) /* 2 */ \ + _(int, INT32, Int) /* 3 */ \ + _(int64_t, INT64, Long) /* 4 */ \ + _(at::Half, FLOAT16, Half) /* 5 */ \ + _(float, FLOAT32, Float) /* 6 */ \ + _(double, FLOAT64, Double) /* 7 */ \ + _(c10::complex, ComplexHalf, ComplexHalf) /* 8 */ \ + _(c10::complex, COMPLEX64, ComplexFloat) /* 9 */ \ + _(c10::complex, COMPLEX128, ComplexDouble) /* 10 */ \ + _(bool, BOOL, Bool) /* 11 */ \ + _(c10::qint8, QInt8, QInt8) /* 12 */ \ + _(c10::quint8, QUInt8, QUInt8) /* 13 */ \ + _(c10::qint32, QInt32, QInt32) /* 14 */ \ + _(at::BFloat16, BFLOAT16, BFloat16) /* 15 */ \ + _(c10::quint4x2, QUInt4x2, QUInt4x2) /* 16 */ \ + _(c10::quint2x4, QUInt2x4, QUInt2x4) /* 17 */ \ + _(c10::bits1x8, Bits1x8, Bits1x8) /* 18 */ \ + _(c10::bits2x4, Bits2x4, Bits2x4) /* 19 */ \ + _(c10::bits4x2, Bits4x2, Bits4x2) /* 20 */ \ + _(c10::bits8, Bits8, Bits8) /* 21 */ \ + _(c10::bits16, Bits16, Bits16) /* 22 */ \ + _(c10::Float8_e5m2, FLOAT8_E5M2, Float8_e5m2) /* 23 */ \ + _(c10::Float8_e4m3fn, FLOAT8_E4M3FN, Float8_e4m3fn) /* 24 */ \ + _(c10::Float8_e5m2fnuz, Float8_e5m2fnuz, Float8_e5m2fnuz) /* 25 */ \ + _(c10::Float8_e4m3fnuz, Float8_e4m3fnuz, Float8_e4m3fnuz) /* 26 */ \ + _(uint16_t, UINT16, UInt16) /* 27 */ \ + _(uint32_t, UINT32, UInt32) /* 28 */ \ + _(uint64_t, UINT64, UInt64) /* 29 */ \ + _(c10::dummy_uint1_7_t<1>, UInt1, UInt1) /* 30 */ \ + _(c10::dummy_uint1_7_t<2>, UInt2, UInt2) /* 31 */ \ + _(c10::dummy_uint1_7_t<3>, UInt3, UInt3) /* 32 */ \ + _(c10::dummy_uint1_7_t<4>, UInt4, UInt4) /* 33 */ \ + _(c10::dummy_uint1_7_t<5>, UInt5, UInt5) /* 34 */ \ + _(c10::dummy_uint1_7_t<6>, UInt6, UInt6) /* 35 */ \ + _(c10::dummy_uint1_7_t<7>, UInt7, UInt7) /* 36 */ \ + _(c10::dummy_int1_7_t<1>, Int1, Int1) /* 37 */ \ + _(c10::dummy_int1_7_t<2>, Int2, Int2) /* 38 */ \ + _(c10::dummy_int1_7_t<3>, Int3, Int3) /* 39 */ \ + _(c10::dummy_int1_7_t<4>, Int4, Int4) /* 40 */ \ + _(c10::dummy_int1_7_t<5>, Int5, Int5) /* 41 */ \ + _(c10::dummy_int1_7_t<6>, Int6, Int6) /* 42 */ \ + _(c10::dummy_int1_7_t<7>, Int7, Int7) /* 43 */ \ + _(c10::Float8_e8m0fnu, Float8_e8m0fnu, Float8_e8m0fnu) /* 44 */ \ + _(c10::Float4_e2m1fn_x2, Float4_e2m1fn_x2, Float4_e2m1fn_x2) /* 45 */ #define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_F8NZ(_) \ _(uint8_t, Byte) \ @@ -87,7 +112,8 @@ struct dummy_int1_7_t {}; _(c10::complex, ComplexDouble) \ _(bool, Bool) \ _(at::BFloat16, BFloat16) \ - _(at::Float8_e5m2, Float8_e5m2) + _(c10::Float8_e5m2, Float8_e5m2) \ + _(c10::Float8_e4m3fn, Float8_e4m3fn) #define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(_) \ _(uint8_t, Byte) \ @@ -225,21 +251,6 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType) AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CONSTANT) #undef DEFINE_CONSTANT -constexpr ScalarType kComplexHalf = ScalarType::ComplexHalf; -constexpr ScalarType kQInt8 = ScalarType::QInt8; -constexpr ScalarType kQUInt8 = ScalarType::QUInt8; -constexpr ScalarType kQInt32 = ScalarType::QInt32; -constexpr ScalarType kQUInt4x2 = ScalarType::QUInt4x2; -constexpr ScalarType kQUInt2x4 = ScalarType::QUInt2x4; -constexpr ScalarType kBits1x8 = ScalarType::Bits1x8; -constexpr ScalarType kBits2x4 = ScalarType::Bits2x4; -constexpr ScalarType kBits4x2 = ScalarType::Bits4x2; -constexpr ScalarType kBits8 = ScalarType::Bits8; -constexpr ScalarType kBits16 = ScalarType::Bits16; -constexpr ScalarType kFloat8_e5m2fnuz = ScalarType::Float8_e5m2fnuz; -constexpr ScalarType kFloat8_e4m3fnuz = ScalarType::Float8_e4m3fnuz; -constexpr ScalarType kFloat8_e8m0fnu = ScalarType::Float8_e8m0fnu; -constexpr ScalarType kFloat4_e2m1fn_x2 = ScalarType::Float4_e2m1fn_x2; constexpr ScalarType kUndefined = ScalarType::Undefined; #define AT_FORALL_SCALAR_TYPES_AND(SCALARTYPE, _) \ @@ -298,36 +309,6 @@ inline const char* toString(ScalarType t) { switch (t) { AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE) - case ScalarType::ComplexHalf: - return "ComplexHalf"; - case ScalarType::QInt8: - return "QInt8"; - case ScalarType::QUInt8: - return "QUInt8"; - case ScalarType::QInt32: - return "QInt32"; - case ScalarType::QUInt4x2: - return "QUInt4x2"; - case ScalarType::QUInt2x4: - return "QUInt2x4"; - case ScalarType::Bits1x8: - return "Bits1x8"; - case ScalarType::Bits2x4: - return "Bits2x4"; - case ScalarType::Bits4x2: - return "Bits4x2"; - case ScalarType::Bits8: - return "Bits8"; - case ScalarType::Bits16: - return "Bits16"; - case ScalarType::Float8_e5m2fnuz: - return "Float8_e5m2fnuz"; - case ScalarType::Float8_e4m3fnuz: - return "Float8_e4m3fnuz"; - case ScalarType::Float8_e8m0fnu: - return "Float8_e8m0fnu"; - case ScalarType::Float4_e2m1fn_x2: - return "Float4_e2m1fn_x2"; case ScalarType::Undefined: return "Undefined"; default: @@ -343,25 +324,6 @@ inline size_t elementSize(ScalarType t) { switch (t) { AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(CASE_ELEMENTSIZE_CASE) - case ScalarType::ComplexHalf: - return sizeof(at::Half) * 2; - case ScalarType::QInt8: - case ScalarType::QUInt8: - case ScalarType::QUInt4x2: - case ScalarType::QUInt2x4: - case ScalarType::Bits1x8: - case ScalarType::Bits2x4: - case ScalarType::Bits4x2: - case ScalarType::Bits8: - case ScalarType::Float8_e5m2fnuz: - case ScalarType::Float8_e4m3fnuz: - case ScalarType::Float8_e8m0fnu: - case ScalarType::Float4_e2m1fn_x2: - return 1; - case ScalarType::QInt32: - return 4; - case ScalarType::Bits16: - return 2; default: TORCH_CHECK(false, "Unknown ScalarType"); } diff --git a/paddle/phi/api/include/compat/c10/util/Float4_e2m1fn_x2.h b/paddle/phi/api/include/compat/c10/util/Float4_e2m1fn_x2.h new file mode 100644 index 00000000000000..01b73c873eeb81 --- /dev/null +++ b/paddle/phi/api/include/compat/c10/util/Float4_e2m1fn_x2.h @@ -0,0 +1,65 @@ +// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// The file has been adapted from pytorch project +// Licensed under BSD-style license - +// https://github.com/pytorch/pytorch/blob/main/LICENSE + +#pragma once + +#include + +namespace c10 { + +/// Defines the Float4_e2m1fn_x2 type (4-bit floating-point, two elements packed +/// into one byte). This is the FP4 dtype from the OCP MX format spec +/// (https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf, +/// Section 5.3.3) +/// +/// Given two high precision values val0 and val1, here is the +/// binary configuration of their packed representation, from MSB to LSB: +/// +/// original value | val1 : val0 +/// ======================================== +/// bit index (MSB==7, LSB==0) | 7654 : 3210 +/// sign/exponent/mantissa | seem : seem + +struct alignas(1) Float4_e2m1fn_x2 { + uint8_t val_; + Float4_e2m1fn_x2() = default; + explicit constexpr Float4_e2m1fn_x2(uint8_t val) : val_(val) {} +}; + +/// Comparison operators +inline bool operator==(const Float4_e2m1fn_x2& a, const Float4_e2m1fn_x2& b) { + return a.val_ == b.val_; +} + +inline bool operator!=(const Float4_e2m1fn_x2& a, const Float4_e2m1fn_x2& b) { + return a.val_ != b.val_; +} + +} // namespace c10 + +namespace at { +using c10::Float4_e2m1fn_x2; +using c10::operator!=; +using c10::operator==; +} // namespace at + +namespace torch { +using c10::Float4_e2m1fn_x2; +using c10::operator!=; +using c10::operator==; +} // namespace torch diff --git a/paddle/phi/api/include/compat/c10/util/Float8_e4m3fnuz.h b/paddle/phi/api/include/compat/c10/util/Float8_e4m3fnuz.h new file mode 100644 index 00000000000000..4ac3aa8c4716e2 --- /dev/null +++ b/paddle/phi/api/include/compat/c10/util/Float8_e4m3fnuz.h @@ -0,0 +1,40 @@ +// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// The file has been adapted from pytorch project +// Licensed under BSD-style license - +// https://github.com/pytorch/pytorch/blob/main/LICENSE + +#pragma once + +#include + +namespace c10 { + +struct Float8_e4m3fnuz { + constexpr Float8_e4m3fnuz() = default; + explicit constexpr Float8_e4m3fnuz(uint8_t value) : x(value) {} + + uint8_t x{0}; +}; + +} // namespace c10 + +namespace at { +using c10::Float8_e4m3fnuz; +} // namespace at + +namespace torch { +using c10::Float8_e4m3fnuz; +} // namespace torch diff --git a/paddle/phi/api/include/compat/c10/util/Float8_e5m2fnuz.h b/paddle/phi/api/include/compat/c10/util/Float8_e5m2fnuz.h new file mode 100644 index 00000000000000..e04e6fe28ae7f5 --- /dev/null +++ b/paddle/phi/api/include/compat/c10/util/Float8_e5m2fnuz.h @@ -0,0 +1,40 @@ +// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// The file has been adapted from pytorch project +// Licensed under BSD-style license - +// https://github.com/pytorch/pytorch/blob/main/LICENSE + +#pragma once + +#include + +namespace c10 { + +struct Float8_e5m2fnuz { + constexpr Float8_e5m2fnuz() = default; + explicit constexpr Float8_e5m2fnuz(uint8_t value) : x(value) {} + + uint8_t x{0}; +}; + +} // namespace c10 + +namespace at { +using c10::Float8_e5m2fnuz; +} // namespace at + +namespace torch { +using c10::Float8_e5m2fnuz; +} // namespace torch diff --git a/paddle/phi/api/include/compat/c10/util/Float8_e8m0fnu.h b/paddle/phi/api/include/compat/c10/util/Float8_e8m0fnu.h new file mode 100644 index 00000000000000..97ecd5100362cf --- /dev/null +++ b/paddle/phi/api/include/compat/c10/util/Float8_e8m0fnu.h @@ -0,0 +1,40 @@ +// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// The file has been adapted from pytorch project +// Licensed under BSD-style license - +// https://github.com/pytorch/pytorch/blob/main/LICENSE + +#pragma once + +#include + +namespace c10 { + +struct Float8_e8m0fnu { + constexpr Float8_e8m0fnu() = default; + explicit constexpr Float8_e8m0fnu(uint8_t value) : x(value) {} + + uint8_t x{0}; +}; + +} // namespace c10 + +namespace at { +using c10::Float8_e8m0fnu; +} // namespace at + +namespace torch { +using c10::Float8_e8m0fnu; +} // namespace torch diff --git a/paddle/phi/api/include/compat/c10/util/bits.h b/paddle/phi/api/include/compat/c10/util/bits.h new file mode 100644 index 00000000000000..8b3e0ae9843cb7 --- /dev/null +++ b/paddle/phi/api/include/compat/c10/util/bits.h @@ -0,0 +1,76 @@ +// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// The file has been adapted from pytorch project +// Licensed under BSD-style license - +// https://github.com/pytorch/pytorch/blob/main/LICENSE + +#pragma once + +#include + +namespace c10 { + +struct bits1x8 { + constexpr bits1x8() = default; + explicit constexpr bits1x8(uint8_t value) : val_(value) {} + + uint8_t val_{0}; +}; + +struct bits2x4 { + constexpr bits2x4() = default; + explicit constexpr bits2x4(uint8_t value) : val_(value) {} + + uint8_t val_{0}; +}; + +struct bits4x2 { + constexpr bits4x2() = default; + explicit constexpr bits4x2(uint8_t value) : val_(value) {} + + uint8_t val_{0}; +}; + +struct bits8 { + constexpr bits8() = default; + explicit constexpr bits8(uint8_t value) : val_(value) {} + + uint8_t val_{0}; +}; + +struct bits16 { + constexpr bits16() = default; + explicit constexpr bits16(uint16_t value) : val_(value) {} + + uint16_t val_{0}; +}; + +} // namespace c10 + +namespace at { +using c10::bits16; +using c10::bits1x8; +using c10::bits2x4; +using c10::bits4x2; +using c10::bits8; +} // namespace at + +namespace torch { +using c10::bits16; +using c10::bits1x8; +using c10::bits2x4; +using c10::bits4x2; +using c10::bits8; +} // namespace torch diff --git a/paddle/phi/api/include/compat/c10/util/qint32.h b/paddle/phi/api/include/compat/c10/util/qint32.h new file mode 100644 index 00000000000000..02ede07de30a81 --- /dev/null +++ b/paddle/phi/api/include/compat/c10/util/qint32.h @@ -0,0 +1,43 @@ +// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// The file has been adapted from pytorch project +// Licensed under BSD-style license - +// https://github.com/pytorch/pytorch/blob/main/LICENSE + +#pragma once + +#include + +namespace c10 { + +/** + * qint32 is for signed 32 bit quantized Tensors + */ +struct alignas(4) qint32 { + using underlying = int32_t; + int32_t val_; + qint32() = default; + explicit constexpr qint32(int32_t val) : val_(val) {} +}; + +} // namespace c10 + +namespace at { +using c10::qint32; +} // namespace at + +namespace torch { +using c10::qint32; +} // namespace torch diff --git a/paddle/phi/api/include/compat/c10/util/qint8.h b/paddle/phi/api/include/compat/c10/util/qint8.h new file mode 100644 index 00000000000000..17e25d868e63d5 --- /dev/null +++ b/paddle/phi/api/include/compat/c10/util/qint8.h @@ -0,0 +1,45 @@ +// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// The file has been adapted from pytorch project +// Licensed under BSD-style license - +// https://github.com/pytorch/pytorch/blob/main/LICENSE + +#pragma once + +#include + +namespace c10 { + +/** + * This is the data type for quantized Tensors. Right now we only have + * qint8 which is for 8 bit Tensors, and qint32 for 32 bit int Tensors, + * we might have 4 bit, 2 bit or 1 bit data types in the future. + */ +struct alignas(1) qint8 { + using underlying = int8_t; + int8_t val_; + qint8() = default; + explicit constexpr qint8(int8_t val) : val_(val) {} +}; + +} // namespace c10 + +namespace at { +using c10::qint8; +} // namespace at + +namespace torch { +using c10::qint8; +} // namespace torch diff --git a/paddle/phi/api/include/compat/c10/util/quint2x4.h b/paddle/phi/api/include/compat/c10/util/quint2x4.h new file mode 100644 index 00000000000000..ee0dd619642e21 --- /dev/null +++ b/paddle/phi/api/include/compat/c10/util/quint2x4.h @@ -0,0 +1,44 @@ +// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// The file has been adapted from pytorch project +// Licensed under BSD-style license - +// https://github.com/pytorch/pytorch/blob/main/LICENSE + +#pragma once + +#include + +namespace c10 { + +/** + * quint2x4 is for un-signed 2 bit quantized Tensors that are packed to byte + * boundary. + */ +struct alignas(1) quint2x4 { + using underlying = uint8_t; + uint8_t val_; + quint2x4() = default; + explicit constexpr quint2x4(uint8_t val) : val_(val) {} +}; + +} // namespace c10 + +namespace at { +using c10::quint2x4; +} // namespace at + +namespace torch { +using c10::quint2x4; +} // namespace torch diff --git a/paddle/phi/api/include/compat/c10/util/quint4x2.h b/paddle/phi/api/include/compat/c10/util/quint4x2.h new file mode 100644 index 00000000000000..a9af00de07efe6 --- /dev/null +++ b/paddle/phi/api/include/compat/c10/util/quint4x2.h @@ -0,0 +1,44 @@ +// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// The file has been adapted from pytorch project +// Licensed under BSD-style license - +// https://github.com/pytorch/pytorch/blob/main/LICENSE + +#pragma once + +#include + +namespace c10 { + +/** + * quint4x2 is for un-signed 4 bit quantized Tensors that are packed to byte + * boundary. + */ +struct alignas(1) quint4x2 { + using underlying = uint8_t; + uint8_t val_; + quint4x2() = default; + explicit constexpr quint4x2(uint8_t val) : val_(val) {} +}; + +} // namespace c10 + +namespace at { +using c10::quint4x2; +} // namespace at + +namespace torch { +using c10::quint4x2; +} // namespace torch diff --git a/paddle/phi/api/include/compat/c10/util/quint8.h b/paddle/phi/api/include/compat/c10/util/quint8.h new file mode 100644 index 00000000000000..780f42ba6a6378 --- /dev/null +++ b/paddle/phi/api/include/compat/c10/util/quint8.h @@ -0,0 +1,43 @@ +// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// The file has been adapted from pytorch project +// Licensed under BSD-style license - +// https://github.com/pytorch/pytorch/blob/main/LICENSE + +#pragma once + +#include + +namespace c10 { + +/** + * quint8 is for unsigned 8 bit quantized Tensors + */ +struct alignas(1) quint8 { + using underlying = uint8_t; + uint8_t val_; + quint8() = default; + explicit constexpr quint8(uint8_t val) : val_(val) {} +}; + +} // namespace c10 + +namespace at { +using c10::quint8; +} // namespace at + +namespace torch { +using c10::quint8; +} // namespace torch