Skip to content

Commit 8a7a797

Browse files
authored
Print tensor for new packed type of 2 bits (#27064)
### Description To fix a build error for dump node inputs and outputs build option.
1 parent 39f966e commit 8a7a797

File tree

2 files changed

+123
-111
lines changed

2 files changed

+123
-111
lines changed

onnxruntime/core/framework/print_tensor_statistics_utils.h

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
#include <cmath>
66
#include "core/framework/print_tensor_utils.h"
7+
#include "core/framework/int2.h"
78

89
namespace onnxruntime {
910
namespace utils {
@@ -94,36 +95,38 @@ void PrintCommonStats(const T* data, size_t count, TensorStatisticsData& tensor_
9495
}
9596
}
9697

97-
#define DEF_PRINT_COMMON_STATS_4BIT(FOUR_BIT_TYPE) \
98-
template <> \
99-
inline void PrintCommonStats<FOUR_BIT_TYPE>( \
100-
const FOUR_BIT_TYPE* data, size_t count, TensorStatisticsData&) { \
101-
using UnpackedType = typename FOUR_BIT_TYPE::UnpackedType; \
102-
UnpackedType min = data[0].GetElem(0); \
103-
UnpackedType max = min; \
104-
for (size_t i = 1; i < count; i++) { \
105-
auto indices = FOUR_BIT_TYPE::GetTensorElemIndices(i); \
106-
auto value = data[indices.first].GetElem(indices.second); \
107-
if (value > max) { \
108-
max = value; \
109-
} \
110-
if (value < min) { \
111-
min = value; \
112-
} \
113-
} \
114-
\
115-
std::cout << "Min="; \
116-
PrintValue(min); \
117-
\
118-
std::cout << ",Max="; \
119-
PrintValue(max); \
98+
#define DEF_PRINT_COMMON_STATS_PACKED(PACKED_TYPE) \
99+
template <> \
100+
inline void PrintCommonStats<PACKED_TYPE>( \
101+
const PACKED_TYPE* data, size_t count, TensorStatisticsData&) { \
102+
using UnpackedType = typename PACKED_TYPE::UnpackedType; \
103+
UnpackedType min = data[0].GetElem(0); \
104+
UnpackedType max = min; \
105+
for (size_t i = 1; i < count; i++) { \
106+
auto indices = PACKED_TYPE::GetTensorElemIndices(i); \
107+
auto value = data[indices.first].GetElem(indices.second); \
108+
if (value > max) { \
109+
max = value; \
110+
} \
111+
if (value < min) { \
112+
min = value; \
113+
} \
114+
} \
115+
\
116+
std::cout << "Min="; \
117+
PrintValue(min); \
118+
\
119+
std::cout << ",Max="; \
120+
PrintValue(max); \
120121
}
121122

122-
DEF_PRINT_COMMON_STATS_4BIT(Int4x2)
123-
DEF_PRINT_COMMON_STATS_4BIT(UInt4x2)
123+
DEF_PRINT_COMMON_STATS_PACKED(Int4x2)
124+
DEF_PRINT_COMMON_STATS_PACKED(UInt4x2)
124125
#if !defined(DISABLE_FLOAT4_TYPES)
125-
DEF_PRINT_COMMON_STATS_4BIT(Float4E2M1x2)
126+
DEF_PRINT_COMMON_STATS_PACKED(Float4E2M1x2)
126127
#endif
128+
DEF_PRINT_COMMON_STATS_PACKED(Int2x4)
129+
DEF_PRINT_COMMON_STATS_PACKED(UInt2x4)
127130

128131
template <typename T>
129132
void PrintHalfStats(const T* data, size_t count) {

0 commit comments

Comments
 (0)