@@ -65,16 +65,18 @@ struct moe_gemm_onednn : typed_primitive_onednn_impl<moe_gemm> {
6565 dnnl::memory::dim d0 = wei_scales_shape[0 ];
6666 dnnl::memory::dim d1 = wei_scales_shape[1 ];
6767 dnnl::memory::dim d2 = wei_scales_shape[2 ];
68+ dnnl::memory::dims wei_scales_dims = (moe_cfg.weight_group_size == -1 ) ? dnnl::memory::dims{d0, d2} : dnnl::memory::dims{d0, d1, d2};
69+ dnnl::memory::format_tag wei_scales_fmt = (moe_cfg.weight_group_size == -1 ) ? dnnl::memory::format_tag::ab : dnnl::memory::format_tag::abc;
6870 dnnl::memory::desc wei_scales_md (
69- {d0, d1, d2}, convert_data_type (wei_scales.get_layout ().data_type ), dnnl::memory::format_tag::abc );
71+ wei_scales_dims, convert_data_type (wei_scales.get_layout ().data_type ), wei_scales_fmt );
7072 dnnl::memory wei_scales_mem = dnnl::ocl_interop::make_memory (wei_scales_md, onednn_engine, dnnl::ocl_interop::memory_kind::usm,
7173 reinterpret_cast <uint8_t *>(wei_scales.buffer_ptr ()));
7274 args.insert ({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scales_mem});
7375
7476 if (!moe_cfg.is_weight_symmetric_quantized ) {
7577 auto & wei_zp = instance.input_memory (moe_cfg.weight_zp_idx );
7678 dnnl::memory::desc wei_zp_md (
77- {d0, d1, d2}, convert_data_type (wei_zp.get_layout ().data_type ), dnnl::memory::format_tag::abc );
79+ wei_scales_dims, convert_data_type (wei_zp.get_layout ().data_type ), wei_scales_fmt );
7880 dnnl::memory wei_zp_mem = dnnl::ocl_interop::make_memory (wei_zp_md, onednn_engine, dnnl::ocl_interop::memory_kind::usm,
7981 reinterpret_cast <uint8_t *>(wei_zp.buffer_ptr ()));
8082 args.insert ({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS, wei_zp_mem});
@@ -157,16 +159,30 @@ struct moe_gemm_onednn : typed_primitive_onednn_impl<moe_gemm> {
157159 auto moe_cfg = MoEGemmImplementationManager::get_moe_cfg (impl_params);
158160
159161 if (moe_cfg.is_weight_quantized ) {
160- attr->set_scales (DNNL_ARG_WEIGHTS,
161- (1 << 0 ) | (1 << 1 ) | (1 << 2 ),
162- {moe_cfg.weight_group_size , 1 },
163- convert_data_type (impl_params.get_input_layout (moe_cfg.weight_scale_idx ).data_type ));
162+ if (moe_cfg.weight_group_size == -1 ) {
163+ attr->set_scales (DNNL_ARG_WEIGHTS,
164+ (1 << 0 ) | (1 << 2 ),
165+ {},
166+ convert_data_type (impl_params.get_input_layout (moe_cfg.weight_scale_idx ).data_type ));
167+ } else {
168+ attr->set_scales (DNNL_ARG_WEIGHTS,
169+ (1 << 0 ) | (1 << 1 ) | (1 << 2 ),
170+ {moe_cfg.weight_group_size , 1 },
171+ convert_data_type (impl_params.get_input_layout (moe_cfg.weight_scale_idx ).data_type ));
172+ }
164173
165174 if (!moe_cfg.is_weight_symmetric_quantized ) {
166- attr->set_zero_points (DNNL_ARG_WEIGHTS,
167- (1 << 0 ) | (1 << 1 ) | (1 << 2 ),
168- {moe_cfg.weight_group_size , 1 },
169- convert_data_type (impl_params.get_input_layout (moe_cfg.weight_zp_idx ).data_type ));
175+ if (moe_cfg.weight_group_size == -1 ) {
176+ attr->set_zero_points (DNNL_ARG_WEIGHTS,
177+ (1 << 0 ) | (1 << 2 ),
178+ {},
179+ convert_data_type (impl_params.get_input_layout (moe_cfg.weight_zp_idx ).data_type ));
180+ } else {
181+ attr->set_zero_points (DNNL_ARG_WEIGHTS,
182+ (1 << 0 ) | (1 << 1 ) | (1 << 2 ),
183+ {moe_cfg.weight_group_size , 1 },
184+ convert_data_type (impl_params.get_input_layout (moe_cfg.weight_zp_idx ).data_type ));
185+ }
170186 }
171187 }
172188
0 commit comments