Skip to content

Commit f3e95d0

Browse files
committed
updated to support int8 asymmetric weight quantization
1 parent 060fb11 commit f3e95d0

File tree

4 files changed

+32
-14
lines changed

4 files changed

+32
-14
lines changed

src/plugins/intel_gpu/src/graph/impls/onednn/moe_gemm_onednn.cpp

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -65,16 +65,18 @@ struct moe_gemm_onednn : typed_primitive_onednn_impl<moe_gemm> {
6565
dnnl::memory::dim d0 = wei_scales_shape[0];
6666
dnnl::memory::dim d1 = wei_scales_shape[1];
6767
dnnl::memory::dim d2 = wei_scales_shape[2];
68+
dnnl::memory::dims wei_scales_dims = (moe_cfg.weight_group_size == -1) ? dnnl::memory::dims{d0, d2} : dnnl::memory::dims{d0, d1, d2};
69+
dnnl::memory::format_tag wei_scales_fmt = (moe_cfg.weight_group_size == -1) ? dnnl::memory::format_tag::ab : dnnl::memory::format_tag::abc;
6870
dnnl::memory::desc wei_scales_md(
69-
{d0, d1, d2}, convert_data_type(wei_scales.get_layout().data_type), dnnl::memory::format_tag::abc);
71+
wei_scales_dims, convert_data_type(wei_scales.get_layout().data_type), wei_scales_fmt);
7072
dnnl::memory wei_scales_mem = dnnl::ocl_interop::make_memory(wei_scales_md, onednn_engine, dnnl::ocl_interop::memory_kind::usm,
7173
reinterpret_cast<uint8_t*>(wei_scales.buffer_ptr()));
7274
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scales_mem});
7375

7476
if (!moe_cfg.is_weight_symmetric_quantized) {
7577
auto& wei_zp = instance.input_memory(moe_cfg.weight_zp_idx);
7678
dnnl::memory::desc wei_zp_md(
77-
{d0, d1, d2}, convert_data_type(wei_zp.get_layout().data_type), dnnl::memory::format_tag::abc);
79+
wei_scales_dims, convert_data_type(wei_zp.get_layout().data_type), wei_scales_fmt);
7880
dnnl::memory wei_zp_mem = dnnl::ocl_interop::make_memory(wei_zp_md, onednn_engine, dnnl::ocl_interop::memory_kind::usm,
7981
reinterpret_cast<uint8_t*>(wei_zp.buffer_ptr()));
8082
args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS, wei_zp_mem});
@@ -157,16 +159,30 @@ struct moe_gemm_onednn : typed_primitive_onednn_impl<moe_gemm> {
157159
auto moe_cfg = MoEGemmImplementationManager::get_moe_cfg(impl_params);
158160

159161
if (moe_cfg.is_weight_quantized) {
160-
attr->set_scales(DNNL_ARG_WEIGHTS,
161-
(1 << 0) | (1 << 1) | (1 << 2),
162-
{moe_cfg.weight_group_size, 1},
163-
convert_data_type(impl_params.get_input_layout(moe_cfg.weight_scale_idx).data_type));
162+
if (moe_cfg.weight_group_size == -1) {
163+
attr->set_scales(DNNL_ARG_WEIGHTS,
164+
(1 << 0) | (1 << 2),
165+
{},
166+
convert_data_type(impl_params.get_input_layout(moe_cfg.weight_scale_idx).data_type));
167+
} else {
168+
attr->set_scales(DNNL_ARG_WEIGHTS,
169+
(1 << 0) | (1 << 1) | (1 << 2),
170+
{moe_cfg.weight_group_size, 1},
171+
convert_data_type(impl_params.get_input_layout(moe_cfg.weight_scale_idx).data_type));
172+
}
164173

165174
if (!moe_cfg.is_weight_symmetric_quantized) {
166-
attr->set_zero_points(DNNL_ARG_WEIGHTS,
167-
(1 << 0) | (1 << 1) | (1 << 2),
168-
{moe_cfg.weight_group_size, 1},
169-
convert_data_type(impl_params.get_input_layout(moe_cfg.weight_zp_idx).data_type));
175+
if (moe_cfg.weight_group_size == -1) {
176+
attr->set_zero_points(DNNL_ARG_WEIGHTS,
177+
(1 << 0) | (1 << 2),
178+
{},
179+
convert_data_type(impl_params.get_input_layout(moe_cfg.weight_zp_idx).data_type));
180+
} else {
181+
attr->set_zero_points(DNNL_ARG_WEIGHTS,
182+
(1 << 0) | (1 << 1) | (1 << 2),
183+
{moe_cfg.weight_group_size, 1},
184+
convert_data_type(impl_params.get_input_layout(moe_cfg.weight_zp_idx).data_type));
185+
}
170186
}
171187
}
172188

src/plugins/intel_gpu/src/graph/impls/onednn/moe_gemm_onednn.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ struct MoEGemmImplementationManager : public ImplementationManager {
160160
// weight scales : [#experts, num_groups, ofm, 1]
161161
auto scale_group_dim = 1;
162162
auto num_scale_groups = (weight_shape.size() == 4) ? params.input_layouts[moe_cfg.weight_scale_idx].get_shape()[scale_group_dim] : 1;
163-
moe_cfg.weight_group_size = k / num_scale_groups;
163+
moe_cfg.weight_group_size = (num_scale_groups == 1) ? -1 : (k / num_scale_groups);
164164
if (static_cast<int32_t>(params.input_layouts.size()) > moe_cfg.weight_zp_idx) {
165165
moe_cfg.is_weight_symmetric_quantized = false;
166166
} else {

src/plugins/intel_gpu/src/plugin/transformations/convert_moe_to_compressed.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ ConvertMOEToMOECompressed::ConvertMOEToMOECompressed(bool is_pa) {
262262
config.hidden_size = weight_shape[2];
263263
if (weight_shape.size() == 4) config.hidden_size *= weight_shape[3];
264264
config.inter_size = weight_shape[1];
265-
config.group_size = (weight_shape.size() == 3) ? config.hidden_size : scale_shape[3];
265+
config.group_size = (scale_shape.size() == 3) ? std::numeric_limits<size_t>::max() : scale_shape[3];
266266
config.top_k = topk_shape.rbegin()->get_length();
267267
config.out_type = ov::element::dynamic;
268268
config.has_batch_dim = is_pa ? 0 : 1;
@@ -272,7 +272,9 @@ ConvertMOEToMOECompressed::ConvertMOEToMOECompressed(bool is_pa) {
272272
args.push_back(pattern_map.at(topk_indices_gemm2_m));
273273
// params for up
274274
args.push_back(pattern_map.at(compressed_weights_m_up));
275-
auto transposed_index = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{4}, std::vector<int64_t>{0, 2, 1, 3});
275+
auto transposed_index = (config.group_size == std::numeric_limits<size_t>::max()) ?
276+
std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{3}, std::vector<int64_t>{0, 2, 1}) :
277+
std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{4}, std::vector<int64_t>{0, 2, 1, 3});
276278
{
277279
auto scale = std::make_shared<ov::op::v1::Transpose>(pattern_map.at(mul_const_m_up), transposed_index);
278280
args.push_back(scale);

0 commit comments

Comments
 (0)