|
| 1 | +// Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | +// Licensed under the MIT License. |
| 3 | + |
| 4 | +#pragma once |
| 5 | + |
| 6 | +#include <cassert> |
| 7 | +#include <type_traits> |
| 8 | +#include "core/common/common.h" |
| 9 | +#include <gsl/gsl> |
| 10 | + |
| 11 | +namespace onnxruntime { |
| 12 | + |
// Compile-time description of a 2-bit integer flavour, selected by signedness.
// The primary template is only declared; the two explicit specializations
// below are the only usable instantiations.
template <bool Signed>
struct Int2Traits;

// Unsigned 2-bit integer: unpacks into uint8_t, representable range [0, 3].
template <>
struct Int2Traits<false> {
  using UnpackedType = uint8_t;
  static constexpr UnpackedType min_val{0};
  static constexpr UnpackedType max_val{3};
};

// Signed (two's complement) 2-bit integer: unpacks into int8_t, representable range [-2, 1].
template <>
struct Int2Traits<true> {
  using UnpackedType = int8_t;
  static constexpr UnpackedType min_val{-2};
  static constexpr UnpackedType max_val{1};
};
| 29 | + |
| 30 | +/// <summary> |
| 31 | +/// Stores 4 packed 2-bit elements in 1 byte. |
| 32 | +/// Packing follows ONNX spec: x0 | (x1 << 2) | (x2 << 4) | (x3 << 6) |
| 33 | +/// </summary> |
| 34 | +/// <typeparam name="Signed">Set to true if signed int2, or false if unsigned uint2.</typeparam> |
| 35 | +template <bool Signed> |
| 36 | +struct Int2x4Base { |
| 37 | + using UnpackedType = typename Int2Traits<Signed>::UnpackedType; |
| 38 | + static constexpr UnpackedType min_val = Int2Traits<Signed>::min_val; |
| 39 | + static constexpr UnpackedType max_val = Int2Traits<Signed>::max_val; |
| 40 | + |
| 41 | + std::byte bits_{}; |
| 42 | + |
| 43 | + Int2x4Base() = default; |
| 44 | + |
| 45 | + explicit Int2x4Base(std::byte bits) { |
| 46 | + bits_ = bits; |
| 47 | + } |
| 48 | + |
| 49 | + Int2x4Base(UnpackedType val0, UnpackedType val1, UnpackedType val2, UnpackedType val3) { |
| 50 | + bits_ = static_cast<std::byte>( |
| 51 | + (val0 & 0x3) | |
| 52 | + ((val1 & 0x3) << 2) | |
| 53 | + ((val2 & 0x3) << 4) | |
| 54 | + ((val3 & 0x3) << 6)); |
| 55 | + } |
| 56 | + |
| 57 | + static inline int8_t SignExtendLower2Bits(std::byte bits) { |
| 58 | + // Sign-extend lower 2-bits by left shifting and then doing an arithmetic right shift. |
| 59 | + constexpr uint8_t shift = (sizeof(int32_t) * 8) - 2; |
| 60 | + return static_cast<int8_t>((static_cast<int32_t>(bits) << shift) >> shift); |
| 61 | + } |
| 62 | + |
| 63 | + inline UnpackedType GetElem(size_t index) const { |
| 64 | + assert(index <= 3); |
| 65 | + const uint8_t shift = 2 * static_cast<uint8_t>(index); |
| 66 | + const std::byte val = (bits_ >> shift) & std::byte{0x3}; |
| 67 | + |
| 68 | + if constexpr (Signed) { |
| 69 | + return SignExtendLower2Bits(val); |
| 70 | + } else { |
| 71 | + return static_cast<UnpackedType>(val); |
| 72 | + } |
| 73 | + } |
| 74 | + |
| 75 | + inline void SetElem(size_t index, UnpackedType val) { |
| 76 | + assert(index <= 3); |
| 77 | + const uint8_t shift = 2 * static_cast<uint8_t>(index); |
| 78 | + const std::byte clear_mask = ~(std::byte{0x3} << shift); |
| 79 | + |
| 80 | + bits_ &= clear_mask; // Clear 2-bit element to 0 |
| 81 | + bits_ |= static_cast<std::byte>((val & 0x3) << shift); // Set 2-bit element to val |
| 82 | + } |
| 83 | + |
| 84 | + inline std::byte ToBits() const { |
| 85 | + return bits_; |
| 86 | + } |
| 87 | + |
| 88 | + /// <summary> |
| 89 | + /// Calculates the number of packed byte units needed to store the given number of 2-bit elements. |
| 90 | + /// Each byte stores 4 x 2-bit elements. |
| 91 | + /// </summary> |
| 92 | + static size_t CalcNumInt2Quads(size_t num_int2_elems) { |
| 93 | + return (num_int2_elems + 3) / 4; |
| 94 | + } |
| 95 | + |
| 96 | + /// <summary> |
| 97 | + /// Copy a source buffer of 2-bit elements (packed) into a destination buffer of 8-bit elements (unpacked). |
| 98 | + /// </summary> |
| 99 | + /// <param name="dst">Destination buffer to store unpacked 8-bit elements</param> |
| 100 | + /// <param name="src">Source buffer with 2-bit elements</param> |
| 101 | + /// <returns>True on success</returns> |
| 102 | + static bool Unpack(gsl::span<UnpackedType> dst, gsl::span<const Int2x4Base<Signed>> src) { |
| 103 | + if (CalcNumInt2Quads(dst.size()) != src.size()) { |
| 104 | + return false; |
| 105 | + } |
| 106 | + |
| 107 | + if (src.empty()) { |
| 108 | + return true; |
| 109 | + } |
| 110 | + |
| 111 | + for (size_t i = 0; i < dst.size(); i++) { |
| 112 | + size_t byte_idx = i >> 2; // i / 4 |
| 113 | + size_t elem_idx = i & 0x3; // i % 4 |
| 114 | + dst[i] = src[byte_idx].GetElem(elem_idx); |
| 115 | + } |
| 116 | + |
| 117 | + return true; |
| 118 | + } |
| 119 | + |
| 120 | + /// <summary> |
| 121 | + /// Copy a source buffer of 8-bit elements (unpacked) into a destination buffer of 2-bit elements (packed). |
| 122 | + /// </summary> |
| 123 | + /// <param name="dst">Destination buffer to store packed 2-bit elements</param> |
| 124 | + /// <param name="src">Source buffer with 8-bit elements</param> |
| 125 | + /// <returns>True on success</returns> |
| 126 | + static bool Pack(gsl::span<Int2x4Base<Signed>> dst, gsl::span<const UnpackedType> src) { |
| 127 | + if (CalcNumInt2Quads(src.size()) != dst.size()) { |
| 128 | + return false; |
| 129 | + } |
| 130 | + |
| 131 | + if (src.empty()) { |
| 132 | + return true; |
| 133 | + } |
| 134 | + |
| 135 | + size_t src_i = 0; |
| 136 | + size_t dst_i = 0; |
| 137 | + const size_t full_quads = src.size() / 4; |
| 138 | + |
| 139 | + // Process complete groups of 4 elements |
| 140 | + for (; dst_i < full_quads; dst_i++) { |
| 141 | + dst[dst_i] = Int2x4Base<Signed>(src[src_i], src[src_i + 1], src[src_i + 2], src[src_i + 3]); |
| 142 | + src_i += 4; |
| 143 | + } |
| 144 | + |
| 145 | + // Handle remaining elements (1-3) |
| 146 | + if (src_i < src.size()) { |
| 147 | + UnpackedType vals[4] = {0, 0, 0, 0}; |
| 148 | + size_t remaining = src.size() - src_i; |
| 149 | + for (size_t j = 0; j < remaining; j++) { |
| 150 | + vals[j] = src[src_i + j]; |
| 151 | + } |
| 152 | + dst[dst_i] = Int2x4Base<Signed>(vals[0], vals[1], vals[2], vals[3]); |
| 153 | + } |
| 154 | + |
| 155 | + return true; |
| 156 | + } |
| 157 | + |
| 158 | + /// <summary> |
| 159 | + /// Returns hierarchical indices for a packed int2 element from the given element index. |
| 160 | + /// |
| 161 | + /// Usage: |
| 162 | + /// Int2x4* data = ...; |
| 163 | + /// auto indices = GetTensorElemIndices(5); // 6th int2 element |
| 164 | + /// int8_t elem = data[indices.first].GetElem(indices.second); |
| 165 | + /// </summary> |
| 166 | + /// <param name="index">Index of 2-bit element</param> |
| 167 | + /// <returns>Pair of (byte_index, element_index_within_byte)</returns> |
| 168 | + static inline std::pair<size_t, size_t> GetTensorElemIndices(size_t index) { |
| 169 | + return {index >> 2, index & 0x3}; |
| 170 | + } |
| 171 | +}; |
| 172 | + |
| 173 | +using Int2x4 = Int2x4Base<true>; |
| 174 | +using UInt2x4 = Int2x4Base<false>; |
| 175 | +static_assert(sizeof(Int2x4) == sizeof(std::byte)); |
| 176 | +static_assert(sizeof(UInt2x4) == sizeof(std::byte)); |
| 177 | + |
| 178 | +} // namespace onnxruntime |
0 commit comments