Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions onnxruntime/core/providers/webgpu/math/gemm_packed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ Status ApplyGemmPacked(const Tensor* a,
const bool output_is_vec4 = output_components == 4;
// The parameter `is_channel_last` is not used for GEMM.
const bool need_split_k = split_k_config.UseSplitK(
&context,
is_vec4 && output_is_vec4, ActivationKind::None, /*batch_size*/ 1, /*is_gemm*/ true, /*is_channels_last*/ true, M, N, K);
if (need_split_k) {
const Tensor* bias = nullptr;
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/webgpu/math/matmul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ Status ComputeMatMul(ComputeContext* context,
uint32_t split_dim_inner = 1;

const SplitKConfig& split_k_config = context->GetSplitKConfig();
const bool need_split_k = split_k_config.UseSplitK(is_vec4, activation.activation_kind_, batch_size, /*is_gemm*/ false, is_channels_last, dim_a_outer, dim_b_outer, dim_inner);
const bool need_split_k = split_k_config.UseSplitK(context, is_vec4, activation.activation_kind_, batch_size, /*is_gemm*/ false, is_channels_last, dim_a_outer, dim_b_outer, dim_inner);
if (need_split_k) {
ORT_ENFORCE(batch_size == 1, "Split-K MatMul only supports batch_size == 1.");
ORT_ENFORCE(is_vec4, "Split-K MatMul only supports bias in vec4 format.");
Expand Down
7 changes: 7 additions & 0 deletions onnxruntime/core/providers/webgpu/webgpu_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "core/providers/webgpu/webgpu_utils.h"

#include <sstream>
#include "core/providers/webgpu/compute_context.h"
#include "core/providers/webgpu/shader_variable.h"

namespace onnxruntime {
Expand Down Expand Up @@ -71,6 +72,7 @@ uint32_t SplitKConfig::GetMaxDimInnerWithSplitK() const {
}

bool SplitKConfig::UseSplitK(
ComputeContext* context,
bool is_vec4,
ActivationKind activation_kind,
uint64_t batch_size,
Expand All @@ -79,6 +81,11 @@ bool SplitKConfig::UseSplitK(
uint32_t dim_a_outer,
uint32_t dim_b_outer,
uint32_t dim_inner) const {
// Current Split-K implementation relies on atomic operations, which are not deterministic.
if (context->KernelContext().GetUseDeterministicCompute()) {
return false;
}

if (!enable_split_k_) {
return false;
}
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/webgpu/webgpu_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
namespace onnxruntime {
namespace webgpu {

class ComputeContext;
class ShaderVariableHelper;

template <typename T>
Expand Down Expand Up @@ -106,6 +107,7 @@ class SplitKConfig {
explicit SplitKConfig(const wgpu::AdapterInfo& adapter_info);

bool UseSplitK(
ComputeContext* context,
bool is_vec4, ActivationKind activation_kind, uint64_t batch_size, bool is_gemm,
bool is_channels_last, uint32_t dim_a_outer,
uint32_t dim_b_outer, uint32_t dim_inner) const;
Expand Down
Loading