Commit 95b13c3

dsharletg authored and xnnpack-bot committed
Rewrite fully connected and batch matrix multiply tests
- Independent of operator API
- Uncovered bugs in bf16 packing
- Add error check for XNN_FLAG_TRANSPOSE_WEIGHTS in qc4w, which currently just crashes via calling a NULL function pointer (not obvious even under sanitizers).
- This reveals an issue: F32-QC4W and QD8-F32-QC4W actually have different weights formats (int4 vs. uint4).

PiperOrigin-RevId: 743573333
1 parent 2688239 commit 95b13c3

File tree

8 files changed: +1046 −5421 lines changed

src/datatype.c

Lines changed: 54 additions & 0 deletions
@@ -88,6 +88,60 @@ bool xnn_datatype_is_quantized(enum xnn_datatype t) {
   return false;
 }
 
+bool xnn_datatype_is_channelwise_quantized(enum xnn_datatype t) {
+  switch (t) {
+    case xnn_datatype_qcint8:
+    case xnn_datatype_qcint32:
+    case xnn_datatype_qcint4:
+      return true;
+    case xnn_datatype_invalid:
+    case xnn_datatype_qint8:
+    case xnn_datatype_pqint8:
+    case xnn_datatype_quint8:
+    case xnn_datatype_qint32:
+    case xnn_datatype_qdint8:
+    case xnn_datatype_qduint8:
+    case xnn_datatype_qpint8:
+    case xnn_datatype_qbint4:
+    case xnn_datatype_fp32:
+    case xnn_datatype_fp16:
+    case xnn_datatype_bf16:
+    case xnn_datatype_int32:
+    case xnn_datatype_pfp16:
+    case xnn_datatype_pfp32:
+      return false;
+  }
+  XNN_UNREACHABLE;
+  return false;
+}
+
+bool xnn_datatype_is_blockwise_quantized(enum xnn_datatype t) {
+  switch (t) {
+    case xnn_datatype_qbint4:
+      return true;
+    case xnn_datatype_invalid:
+    case xnn_datatype_qint8:
+    case xnn_datatype_pqint8:
+    case xnn_datatype_quint8:
+    case xnn_datatype_qint32:
+    case xnn_datatype_qcint8:
+    case xnn_datatype_qcint32:
+    case xnn_datatype_qcint4:
+    case xnn_datatype_qdint8:
+    case xnn_datatype_qduint8:
+    case xnn_datatype_qpint8:
+    case xnn_datatype_fp32:
+    case xnn_datatype_fp16:
+    case xnn_datatype_bf16:
+    case xnn_datatype_int32:
+    case xnn_datatype_pfp16:
+    case xnn_datatype_pfp32:
+      return false;
+  }
+  XNN_UNREACHABLE;
+  return false;
+}
+
 size_t xnn_datatype_log2_size_bits(enum xnn_datatype t) {
   switch (t) {
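
For context, the new predicates let datatype-generic test code branch on quantization granularity. A minimal sketch (not part of this commit; the helper is hypothetical):

// Hypothetical helper: blockwise quantization carries one scale per block per
// channel, channelwise one per output channel, tensor-wide a single scale.
static size_t num_scales(enum xnn_datatype t, size_t channels, size_t blocks) {
  if (xnn_datatype_is_blockwise_quantized(t)) return channels * blocks;
  if (xnn_datatype_is_channelwise_quantized(t)) return channels;
  return xnn_datatype_is_quantized(t) ? 1 : 0;
}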

src/operators/fully-connected-nc.c

Lines changed: 7 additions & 0 deletions
@@ -1826,6 +1826,13 @@ enum xnn_status xnn_create_fully_connected_nc_f32_qc4w(
     return xnn_status_invalid_parameter;
   }
 
+  if (flags & XNN_FLAG_TRANSPOSE_WEIGHTS) {
+    xnn_log_error(
+      "failed to create %s operator with XNN_FLAG_TRANSPOSE_WEIGHTS: not supported",
+      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_f32_qc4w));
+    return xnn_status_unsupported_parameter;
+  }
+
   for (size_t output_channel = 0; output_channel < output_channels; output_channel++) {
     if (kernel_scale[output_channel] <= 0.0f || !isnormal(kernel_scale[output_channel])) {
       xnn_log_error(
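
With this check, passing the flag yields a recoverable error instead of the NULL-function-pointer crash noted in the commit message. A sketch of caller-side handling (the create call's arguments are elided; this is not code from the commit):

// `status` comes from xnn_create_fully_connected_nc_f32_qc4w(...) invoked
// with XNN_FLAG_TRANSPOSE_WEIGHTS set in its flags argument.
if (status == xnn_status_unsupported_parameter) {
  // f32-qc4w requires untransposed weights; repack and retry without the flag.
}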

src/xnnpack/buffer.h

Lines changed: 32 additions & 10 deletions
@@ -77,8 +77,8 @@ class NumericLimits<xnn_bfloat16> {
   static xnn_bfloat16 max_identity() { return -infinity(); }
 };
 
-template <typename T>
-class NumericLimits<quantized<T>> {
+template <typename T, typename Kind>
+class NumericLimits<quantized<T, Kind>> {
  public:
   static quantized<T> min() { return {std::numeric_limits<T>::lowest()}; }
   static quantized<T> max() { return {std::numeric_limits<T>::max()}; }
@@ -359,7 +359,7 @@ class Tensor {
   }
   size_t size() const {
     assert(is_contiguous());
-    return data_->size();
+    return end_ - begin_;
   }
   T* begin() { return data(); }
   T* end() { return end_; }
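
The distinction matters for views: after a `slice`, `data_->size()` still reports the backing buffer's element count, while `end_ - begin_` counts only the elements the view addresses, keeping `size()` consistent with `begin()`/`end()`.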
@@ -374,12 +374,18 @@ class Tensor {
   // Tensor, they do not affect the memory addressed by the Tensor. To realize
   // the effect of these operations, make a copy with `deep_copy`.
 
-  // Reorder the dimensions to extents = {extent(i) for i in perm}, and similar
-  // for strides.
-  Tensor<T, Alignment> transpose(const std::vector<size_t>& perm) const {
+  // Reorder the dimensions in `perm`. Dimensions not in `perm` maintain their
+  // relative ordering.
+  Tensor<T, Alignment> transpose(std::vector<size_t> perm) const {
+    // Sort `perm` to get the new locations of the permuted dimensions.
+    std::vector<size_t> sorted = perm;
+    std::sort(sorted.begin(), sorted.end());
+
     Tensor<T, Alignment> result(*this);
-    result.extents_ = permute(perm, extents_);
-    result.strides_ = permute(perm, strides_);
+    for (size_t i = 0; i < sorted.size(); i++) {
+      result.extents_[sorted[i]] = extent(perm[i]);
+      result.strides_[sorted[i]] = stride(perm[i]);
+    }
     return result;
   }
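
Under the new semantics `transpose` takes a partial permutation. A sketch (assuming a `Tensor` constructible from its extents, as elsewhere in this header):

Tensor<float> t({2, 3, 5});
Tensor<float> u = t.transpose({2, 0});  // dims 0 and 2 swap: extents {5, 3, 2}
Tensor<float> v = t.transpose({1, 0});  // dims 0 and 1 swap: extents {3, 2, 5}
// The old API required a full permutation such as {2, 1, 0}. The result is a
// view; use deep_copy() to realize it.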

@@ -429,14 +435,16 @@ class Tensor {
 
     Tensor<T, Alignment> result(*this);
     std::vector<size_t> offsets(rank());
+    std::vector<size_t> maxs(rank());
     for (size_t i = 0; i < rank(); ++i) {
       offsets[i] = begins[i] < 0 ? extents_[i] + begins[i] : begins[i];
       result.extents_[i] =
           (ends[i] <= 0 ? extents_[i] + ends[i] : ends[i]) - offsets[i];
+      maxs[i] = result.extents_[i] - 1;
     }
 
     result.begin_ = begin_ + flat_offset(offsets);
-    result.end_ = result.begin_ + result.flat_offset(result.extents_);
+    result.end_ = result.begin_ + result.flat_offset(maxs) + 1;
 
     return result;
   }
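
Why `end_` changed: for a strided view, offsetting by the extents overshoots the last addressed element. A worked example (assuming `flat_offset` is the usual dot product of indices and strides, and row-major strides):

// A 2x2 slice of a 4x4 tensor has extents {2, 2} and strides {4, 1}.
// Old: end_ = begin_ + flat_offset({2, 2})     = begin_ + 2*4 + 2*1     = begin_ + 10
// New: end_ = begin_ + flat_offset({1, 1}) + 1 = begin_ + 1*4 + 1*1 + 1 = begin_ + 6
// i.e. one past the last element the slice actually addresses.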
@@ -460,6 +468,18 @@ class Tensor {
     return slice(dim, at, at + 1);
   }
 
+  // Slice the leading dimensions at the indices of `at`.
+  Tensor<T, Alignment> slice_leading(std::vector<size_t> at) const {
+    std::vector<int64_t> begins(rank());
+    std::vector<int64_t> ends(rank());
+    std::copy(at.begin(), at.end(), begins.begin());
+    std::copy(at.begin(), at.end(), ends.begin());
+    for (size_t i = 0; i < at.size(); ++i) {
+      ends[i] += 1;
+    }
+    return slice(begins, ends);
+  }
+
   // Remove `pre` elements from the beginning of each dimension, and `post`
   // elements from the end of each dimension.
   Tensor<T, Alignment> crop_padding(const index_type& pre,
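
`slice_leading` is handy for batch matrix multiply tests: it narrows the leading (batch) dimensions to extent 1 while leaving the trailing matrix dimensions whole. A usage sketch (tensor and index names hypothetical):

// For a rank-4 tensor `a` of extents {B0, B1, M, K}, view the (i, j)-th matrix:
Tensor<float> a_ij = a.slice_leading({i, j});  // extents {1, 1, M, K}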
@@ -730,7 +750,9 @@ xnn_quantization_params random_quantization(xnn_datatype datatype, Rng& rng,
   std::uniform_real_distribution<float> scale_dist{min_scale, max_scale};
   switch (datatype) {
     case xnn_datatype_qint8:
-      // int8 quantization assumes zero point is 0.
+    case xnn_datatype_qcint8:
+    case xnn_datatype_qcint4:
+      // signed integer quantization assumes zero point is 0.
       return {0, scale_dist(rng)};
     case xnn_datatype_quint8:
       return {u8_dist(rng), scale_dist(rng)};
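
A sketch of the effect (assuming the trailing parameters are the scale bounds, as the `scale_dist` construction suggests, and `rng` is any compatible generator):

// Channelwise int8/int4 now use symmetric quantization: a random scale,
// but a zero point pinned to 0.
xnn_quantization_params p =
    random_quantization(xnn_datatype_qcint8, rng, 0.5f, 2.0f);
assert(p.zero_point == 0);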

src/xnnpack/datatype.h

Lines changed: 19 additions & 10 deletions
@@ -28,6 +28,8 @@ bool xnn_datatype_is_integral(enum xnn_datatype t);
 
 // Returns true if the datatype is a quantized real datatype.
 bool xnn_datatype_is_quantized(enum xnn_datatype t);
+bool xnn_datatype_is_channelwise_quantized(enum xnn_datatype t);
+bool xnn_datatype_is_blockwise_quantized(enum xnn_datatype t);
 
 // Returns the size of an element of the datatype.
 size_t xnn_datatype_log2_size_bits(enum xnn_datatype t);
@@ -46,18 +48,22 @@ const char* xnn_datatype_to_string(enum xnn_datatype type);
 
 namespace xnnpack {
 
+struct channelwise {};
+
 // We need a type that distinguishes an intX_t from a quantized intX_t. We can't
 // do arithmetic on these, because we don't know the quantization parameters.
-template <typename T>
+template <typename T, typename Kind = void>
 struct quantized {
   T value;
   using type = T;
 
   operator T() const { return value; }
+  // Forward operator[] in case T is a sub-byte packed value.
+  auto operator[](size_t i) const { return value[i]; }
 
   quantized() = default;
   quantized(T t) : value(t) {}
-  quantized<T>& operator=(T t) {
+  quantized<T, Kind>& operator=(T t) {
     value = t;
     return *this;
   }
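
The `Kind` tag distinguishes tensor-wide from channelwise quantized element types at compile time without changing their storage. A sketch:

using qint8  = xnnpack::quantized<int8_t>;                        // tensor-wide
using qcint8 = xnnpack::quantized<int8_t, xnnpack::channelwise>;  // per-channel
static_assert(sizeof(qcint8) == sizeof(int8_t), "the tag adds no storage");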
@@ -66,26 +72,26 @@ struct quantized {
 template <typename T>
 struct is_quantized : std::false_type {};
 
-template <typename T>
-struct is_quantized<quantized<T>> : std::true_type {};
+template <typename T, typename Kind>
+struct is_quantized<quantized<T, Kind>> : std::true_type {};
 
 template <typename T>
 struct unwrap_quantized {
   using type = T;
 };
 
-template <>
-struct unwrap_quantized<quantized<int8_t>> {
+template <typename Kind>
+struct unwrap_quantized<quantized<int8_t, Kind>> {
   using type = int8_t;
 };
 
-template <>
-struct unwrap_quantized<quantized<uint8_t>> {
+template <typename Kind>
+struct unwrap_quantized<quantized<uint8_t, Kind>> {
   using type = uint8_t;
 };
 
-template <>
-struct unwrap_quantized<quantized<int32_t>> {
+template <typename Kind>
+struct unwrap_quantized<quantized<int32_t, Kind>> {
   using type = int32_t;
 };

@@ -97,6 +103,9 @@ xnn_datatype xnn_datatype_of() {
     return xnn_datatype_quint8;
   } else if (std::is_same<T, xnnpack::quantized<int8_t>>::value) {
     return xnn_datatype_qint8;
+  } else if (std::is_same<
+                 T, xnnpack::quantized<int8_t, xnnpack::channelwise>>::value) {
+    return xnn_datatype_qcint8;
   } else if (std::is_same<T, xnnpack::quantized<int32_t>>::value) {
     return xnn_datatype_qint32;
   } else if (std::is_same<T, xnn_float16>::value) {
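
With this branch, type-parameterized test code can recover the runtime datatype for the tagged element type as well; a sketch:

// Maps the channelwise-tagged int8 element type to its datatype enum.
assert(xnn_datatype_of<xnnpack::quantized<int8_t, xnnpack::channelwise>>() ==
       xnn_datatype_qcint8);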
