|
4 | 4 |
|
5 | 5 | #include <cmath> |
6 | 6 | #include "core/framework/print_tensor_utils.h" |
| 7 | +#include "core/framework/int2.h" |
7 | 8 |
|
8 | 9 | namespace onnxruntime { |
9 | 10 | namespace utils { |
@@ -94,36 +95,38 @@ void PrintCommonStats(const T* data, size_t count, TensorStatisticsData& tensor_ |
94 | 95 | } |
95 | 96 | } |
96 | 97 |
|
97 | | -#define DEF_PRINT_COMMON_STATS_4BIT(FOUR_BIT_TYPE) \ |
98 | | - template <> \ |
99 | | - inline void PrintCommonStats<FOUR_BIT_TYPE>( \ |
100 | | - const FOUR_BIT_TYPE* data, size_t count, TensorStatisticsData&) { \ |
101 | | - using UnpackedType = typename FOUR_BIT_TYPE::UnpackedType; \ |
102 | | - UnpackedType min = data[0].GetElem(0); \ |
103 | | - UnpackedType max = min; \ |
104 | | - for (size_t i = 1; i < count; i++) { \ |
105 | | - auto indices = FOUR_BIT_TYPE::GetTensorElemIndices(i); \ |
106 | | - auto value = data[indices.first].GetElem(indices.second); \ |
107 | | - if (value > max) { \ |
108 | | - max = value; \ |
109 | | - } \ |
110 | | - if (value < min) { \ |
111 | | - min = value; \ |
112 | | - } \ |
113 | | - } \ |
114 | | - \ |
115 | | - std::cout << "Min="; \ |
116 | | - PrintValue(min); \ |
117 | | - \ |
118 | | - std::cout << ",Max="; \ |
119 | | - PrintValue(max); \ |
| 98 | +#define DEF_PRINT_COMMON_STATS_PACKED(PACKED_TYPE) \ |
| 99 | + template <> \ |
| 100 | + inline void PrintCommonStats<PACKED_TYPE>( \ |
| 101 | + const PACKED_TYPE* data, size_t count, TensorStatisticsData&) { \ |
| 102 | + using UnpackedType = typename PACKED_TYPE::UnpackedType; \ |
| 103 | + UnpackedType min = data[0].GetElem(0); \ |
| 104 | + UnpackedType max = min; \ |
| 105 | + for (size_t i = 1; i < count; i++) { \ |
| 106 | + auto indices = PACKED_TYPE::GetTensorElemIndices(i); \ |
| 107 | + auto value = data[indices.first].GetElem(indices.second); \ |
| 108 | + if (value > max) { \ |
| 109 | + max = value; \ |
| 110 | + } \ |
| 111 | + if (value < min) { \ |
| 112 | + min = value; \ |
| 113 | + } \ |
| 114 | + } \ |
| 115 | + \ |
| 116 | + std::cout << "Min="; \ |
| 117 | + PrintValue(min); \ |
| 118 | + \ |
| 119 | + std::cout << ",Max="; \ |
| 120 | + PrintValue(max); \ |
120 | 121 | } |
121 | 122 |
|
122 | | -DEF_PRINT_COMMON_STATS_4BIT(Int4x2) |
123 | | -DEF_PRINT_COMMON_STATS_4BIT(UInt4x2) |
| 123 | +DEF_PRINT_COMMON_STATS_PACKED(Int4x2) |
| 124 | +DEF_PRINT_COMMON_STATS_PACKED(UInt4x2) |
124 | 125 | #if !defined(DISABLE_FLOAT4_TYPES) |
125 | | -DEF_PRINT_COMMON_STATS_4BIT(Float4E2M1x2) |
| 126 | +DEF_PRINT_COMMON_STATS_PACKED(Float4E2M1x2) |
126 | 127 | #endif |
| 128 | +DEF_PRINT_COMMON_STATS_PACKED(Int2x4) |
| 129 | +DEF_PRINT_COMMON_STATS_PACKED(UInt2x4) |
127 | 130 |
|
128 | 131 | template <typename T> |
129 | 132 | void PrintHalfStats(const T* data, size_t count) { |
|
0 commit comments