@@ -535,24 +535,32 @@ Status DP4AMatMulQuantizeProgram::GenerateShaderCode(ShaderHelper& shader) const
535535 shader.AddInput (" input_a" , ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
536536 shader.AddOutput (" output" , ShaderUsage::UseUniform);
537537 shader.AddOutput (" scales" , ShaderUsage::UseUniform);
538-
538+ shader.AdditionalImplementation () << R"ADDNL_FN(
539+ fn readInput(offset: u32) -> input_a_value_t // Guarded load from input_a: out-of-range offsets return a zero value instead of indexing the buffer.
540+ {
541+ if (offset >= uniforms.input_size) { // assumes input_size is the element count of input_a, so valid offsets are [0, input_size) and offset == input_size must be rejected too.
542+ return input_a_value_t(0);
543+ }
544+ return input_a[offset];
545+ }
546+ )ADDNL_FN" ;
539547 shader.MainFunctionBody () << R"MAIN_FN(
540548 var local_a : array<vec4<input_a_element_t>, 32>;
541549 var max_value:vec4<input_a_element_t> = vec4<input_a_element_t>(0);
542550 for (var idx:u32=0;idx<32;idx+=1)
543551 {
544- local_a[idx] = input_a[workgroup_id.x *32 + idx] ;
552+ local_a[idx] = readInput(workgroup_idx *32 + idx) ;
545553 max_value = max(max_value, abs(local_a[idx]));
546554 }
547555 var scale = max(max_value.x, max_value.y);
548556 scale = max(scale, max_value.z);
549557 scale = max(scale, max_value.w);
550558 for (var idx:u32=0;idx<32;idx+=1)
551559 {
552- output[workgroup_id.x *32+idx] = pack4x8snorm(vec4<f32>(local_a[idx]/scale));
560+ output[workgroup_idx *32+idx] = pack4x8snorm(vec4<f32>(local_a[idx]/scale));
553561 }
554562 // 127 is the max value of signed int8 [-127,127] used by pack4x8snorm for 1.0f.
555- scales[workgroup_id.x ] = scale/127;
563+ scales[workgroup_idx ] = scale/127;
556564)MAIN_FN" ;
557565 return Status::OK ();
558566}
@@ -828,7 +836,8 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
828836 Tensor a_scale = context.CreateGPUTensor (a->DataType (), a_scales_dims);
829837 quantize_program.AddInputs ({{a, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int >(kVec4Components )}})
830838 .AddOutputs ({{&a_quant, ProgramTensorMetadataDependency::Rank, a_quant.Shape (), gsl::narrow<int >(1 )},
831- {&a_scale, ProgramTensorMetadataDependency::Rank, a_scale.Shape (), gsl::narrow<int >(1 )}});
839+ {&a_scale, ProgramTensorMetadataDependency::Rank, a_scale.Shape (), gsl::narrow<int >(1 )}})
840+ .AddUniformVariable ({static_cast <uint32_t >(M * K / kVec4Components )});
832841 ORT_RETURN_IF_ERROR (context.RunProgram (quantize_program));
833842
834843 constexpr uint32_t kTileSize = 64 ;
0 commit comments