Skip to content

Commit 2ec6c08

Browse files
committed
reflect review comments
1 parent 8e89d10 commit 2ec6c08

File tree

3 files changed

+23
-0
lines changed

3 files changed

+23
-0
lines changed

onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "contrib_ops/webgpu/quantization/matmul_nbits_common.h"
55
#include <sstream>
66
#include "core/common/common.h"
7+
#include "core/providers/webgpu/webgpu_context.h"
78

89
namespace onnxruntime {
910
namespace contrib {
@@ -54,6 +55,12 @@ fn mm_read_zero(row : u32, col : u32, r_dim: u32, c_dim: u32) -> )"
5455
return ss.str();
5556
}
5657

58+
bool HasDP4ADeviceSupport(int context_id) {
59+
auto& ctx = onnxruntime::webgpu::WebGpuContextFactory::GetContext(context_id);
60+
return ctx.DeviceHasFeature(wgpu::FeatureName::Subgroups) &&
61+
ctx.AdapterInfo().vendor != std::string_view{"apple"};
62+
}
63+
5764
} // namespace webgpu
5865
} // namespace contrib
5966
} // namespace onnxruntime

onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,11 @@ namespace webgpu {
2121
std::string GenerateZeroPointReadingCode(uint32_t nbits, bool has_zero_points,
2222
const std::string& output_type = "output_element_t");
2323

24+
/// Returns true when the default WebGPU device supports the DP4A kernel path
25+
/// (Subgroups feature present and non-Apple vendor).
26+
/// \p context_id is the WebGpuContext slot (0 for the default context).
27+
bool HasDP4ADeviceSupport(int context_id = 0);
28+
2429
} // namespace webgpu
2530
} // namespace contrib
2631
} // namespace onnxruntime

onnxruntime/test/contrib_ops/matmul_2bits_test.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@
2727
#include "core/session/ort_env.h"
2828
#include "core/util/qmath.h"
2929
#include "core/providers/webgpu/webgpu_provider_options.h"
30+
#ifdef USE_WEBGPU
31+
#include "contrib_ops/webgpu/quantization/matmul_nbits_common.h"
32+
#endif
3033

3134
extern std::unique_ptr<Ort::Env> ort_env;
3235

@@ -543,7 +546,15 @@ TEST(MatMul2BitsWebGpu, Float32_ZeroPoint_LargerK) {
543546
// DP4A path tests (accuracy_level=4) — exercises the 1024-entry LUT / dequantization
544547
// path for 2-bit weights with zero_points.
545548
// DP4A constraints: accuracy_level==4, block_size%32==0, K%128==0, N%16==0.
549+
// Skipped when the adapter lacks Subgroups support or is Apple (Metal),
550+
// because the DP4A kernel would silently fall back to the default path.
546551
TEST(MatMul2BitsWebGpu, Float32_ZeroPoint_DP4A) {
552+
// Ensure the WebGPU context is initialized so we can query adapter capabilities.
553+
auto ep = DefaultWebGpuExecutionProvider();
554+
if (!contrib::webgpu::HasDP4ADeviceSupport(ep->GetDeviceId())) {
555+
GTEST_SKIP() << "DP4A requires Subgroups support on a non-Apple adapter";
556+
}
557+
547558
TestOptions2Bits opts{};
548559
opts.accuracy_level = 4;
549560
opts.has_zero_point = true;

0 commit comments

Comments
 (0)