62 changes: 32 additions & 30 deletions onnxruntime/core/providers/webgpu/math/gemm_packed.cc
@@ -104,37 +104,39 @@ Status ApplyGemmPacked(const Tensor* a,
   uint32_t dispatch_z = 1;
   uint32_t split_dim_inner = 1;

-  const SplitKConfig& split_k_config = context.GetSplitKConfig();
-  // Currently we require the components for Y must also be a multiple of 4 when Split-K is used.
-  const bool output_is_vec4 = output_components == 4;
-  // The parameter `is_channel_last` is not used for GEMM.
-  const bool need_split_k = split_k_config.UseSplitK(
-      is_vec4 && output_is_vec4, ActivationKind::None, /*batch_size*/ 1, /*is_gemm*/ true, /*is_channels_last*/ true, M, N, K);
-  if (need_split_k) {
-    const Tensor* bias = nullptr;
-    uint32_t output_components_in_fill_bias_program = 4;
-    if (need_handle_bias) {
-      bias = c;
-      output_components_in_fill_bias_program = c_components;
+  // Current Split-K implementation relies on atomic operations, which are not deterministic.
+  if (!context.KernelContext().GetUseDeterministicCompute()) {
+    const SplitKConfig& split_k_config = context.GetSplitKConfig();
+    // Currently we require the components for Y must also be a multiple of 4 when Split-K is used.
+    const bool output_is_vec4 = output_components == 4;
+    // The parameter `is_channel_last` is not used for GEMM.
+    const bool need_split_k = split_k_config.UseSplitK(is_vec4 && output_is_vec4, ActivationKind::None, /*batch_size*/ 1, /*is_gemm*/ true, /*is_channels_last*/ true, M, N, K);
+    if (need_split_k) {
+      const Tensor* bias = nullptr;
+      uint32_t output_components_in_fill_bias_program = 4;
+      if (need_handle_bias) {
+        bias = c;
+        output_components_in_fill_bias_program = c_components;
+      }
+      const TensorShape output_shape = TensorShape{M, N / output_components_in_fill_bias_program};
+
+      auto fill_bias_program = CreateMatMulFillBiasOrZeroBeforeSplitKProgram(
+          bias, y, /*is_gemm*/ true, beta, output_components_in_fill_bias_program, c_is_scalar, output_shape);
+      ORT_RETURN_IF_ERROR(context.RunProgram(fill_bias_program));
+
+      // When Split-K is used, `bias` will be handled in `MatMulFillBiasOrZeroBeforeSplitKProgram`
+      // instead of here.
+      need_handle_bias = false;
+
+      // With Split-K, `dim_inner` will be split into multiple parts and `dispatch_z` will be the
+      // number of splits along `dim_inner`.
+      split_dim_inner = split_k_config.GetSplitDimInner();
+      dispatch_z = (K + split_dim_inner - 1) / split_dim_inner;
+
+      // The output should be declared in atomic types in `MatMulProgram` for the use of atomic
+      // built-in functions.
+      output.is_atomic = true;
     }
-    const TensorShape output_shape = TensorShape{M, N / output_components_in_fill_bias_program};
-
-    auto fill_bias_program = CreateMatMulFillBiasOrZeroBeforeSplitKProgram(
-        bias, y, /*is_gemm*/ true, beta, output_components_in_fill_bias_program, c_is_scalar, output_shape);
-    ORT_RETURN_IF_ERROR(context.RunProgram(fill_bias_program));
-
-    // When Split-K is used, `bias` will be handled in `MatMulFillBiasOrZeroBeforeSplitKProgram`
-    // instead of here.
-    need_handle_bias = false;
-
-    // With Split-K, `dim_inner` will be split into multiple parts and `dispatch_z` will be the
-    // number of splits along `dim_inner`.
-    split_dim_inner = split_k_config.GetSplitDimInner();
-    dispatch_z = (K + split_dim_inner - 1) / split_dim_inner;
-
-    // The output should be declared in atomic types in `MatMulProgram` for the use of atomic
-    // built-in functions.
-    output.is_atomic = true;
   }

   GemmProgram program{transA, transB, alpha, need_handle_bias, need_handle_matmul, c_components, c_is_scalar, output_components, is_vec4, split_dim_inner};
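Why the new gate matters: with Split-K, several workgroups accumulate partial sums over K-slices into the same output element using floating-point atomics, and floating-point addition is not associative, so the scheduling order can change the low-order bits from run to run. The sketch below only illustrates that ordering effect on the CPU; it is not the WGSL kernel, and the `partials` array merely stands in for the per-slice contributions.

```cpp
// Illustration only: floating-point sums depend on evaluation order, which is
// why per-slice atomic accumulation (as in Split-K) is not bit-exact across runs.
#include <algorithm>
#include <cstdio>
#include <random>
#include <vector>

int main() {
  std::vector<float> partials(1024);
  std::mt19937 rng(42);
  std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
  for (auto& v : partials) v = dist(rng);

  // Accumulate in the original order (one possible scheduling of the atomic adds).
  float sum_a = 0.0f;
  for (float v : partials) sum_a += v;

  // Accumulate in a shuffled order (another possible scheduling).
  std::shuffle(partials.begin(), partials.end(), rng);
  float sum_b = 0.0f;
  for (float v : partials) sum_b += v;

  // The two sums typically differ in the low-order bits.
  std::printf("order A: %.9g\norder B: %.9g\ndiff: %.3g\n", sum_a, sum_b, sum_a - sum_b);
  return 0;
}
```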
49 changes: 26 additions & 23 deletions onnxruntime/core/providers/webgpu/math/matmul.cc
@@ -245,29 +245,32 @@ Status ComputeMatMul(ComputeContext* context,
   bool use_bias_in_matmul = has_bias;
   uint32_t split_dim_inner = 1;

-  const SplitKConfig& split_k_config = context->GetSplitKConfig();
-  const bool need_split_k = split_k_config.UseSplitK(is_vec4, activation.activation_kind_, batch_size, /*is_gemm*/ false, is_channels_last, dim_a_outer, dim_b_outer, dim_inner);
-  if (need_split_k) {
-    ORT_ENFORCE(batch_size == 1, "Split-K MatMul only supports batch_size == 1.");
-    ORT_ENFORCE(is_vec4, "Split-K MatMul only supports bias in vec4 format.");
-    ORT_ENFORCE(is_channels_last, "Split-K MatMul only supports channels-last format.");
-
-    // Initialize `output_tensor` with 0 or bias before MatMulProgram with Split-K enabled.
-    const auto fill_bias_program = CreateMatMulFillBiasOrZeroBeforeSplitKProgram(bias, output_tensor, /*is_gemm*/ false, /*beta*/ 1.0f, /*bias_components*/ 4, /*bias_is_scalar*/ false, output_shape_temp);
-    ORT_RETURN_IF_ERROR(context->RunProgram(fill_bias_program));
-
-    // `bias` has been handled in the execution of `fill_bias_program` so we don't need to set
-    // `bias` again in `MatMulProgram`.
-    use_bias_in_matmul = false;
-
-    // With Split-K, `dim_inner` will be split into multiple parts and `dispatch_z` will be the
-    // number of splits along `dim_inner`.
-    split_dim_inner = split_k_config.GetSplitDimInner();
-    dispatch_z = (dim_inner + split_dim_inner - 1) / split_dim_inner;
-
-    // The output should be declared in atomic types in `MatMulProgram` for the use of atomic
-    // built-in functions.
-    output.is_atomic = true;
+  // Current Split-K implementation relies on atomic operations, which are not deterministic.
+  if (!context->KernelContext().GetUseDeterministicCompute()) {
+    const SplitKConfig& split_k_config = context->GetSplitKConfig();
+    const bool need_split_k = split_k_config.UseSplitK(is_vec4, activation.activation_kind_, batch_size, /*is_gemm*/ false, is_channels_last, dim_a_outer, dim_b_outer, dim_inner);
+    if (need_split_k) {
+      ORT_ENFORCE(batch_size == 1, "Split-K MatMul only supports batch_size == 1.");
+      ORT_ENFORCE(is_vec4, "Split-K MatMul only supports bias in vec4 format.");
+      ORT_ENFORCE(is_channels_last, "Split-K MatMul only supports channels-last format.");
+
+      // Initialize `output_tensor` with 0 or bias before MatMulProgram with Split-K enabled.
+      const auto fill_bias_program = CreateMatMulFillBiasOrZeroBeforeSplitKProgram(bias, output_tensor, /*is_gemm*/ false, /*beta*/ 1.0f, /*bias_components*/ 4, /*bias_is_scalar*/ false, output_shape_temp);
+      ORT_RETURN_IF_ERROR(context->RunProgram(fill_bias_program));
+
+      // `bias` has been handled in the execution of `fill_bias_program` so we don't need to set
+      // `bias` again in `MatMulProgram`.
+      use_bias_in_matmul = false;
+
+      // With Split-K, `dim_inner` will be split into multiple parts and `dispatch_z` will be the
+      // number of splits along `dim_inner`.
+      split_dim_inner = split_k_config.GetSplitDimInner();
+      dispatch_z = (dim_inner + split_dim_inner - 1) / split_dim_inner;
+
+      // The output should be declared in atomic types in `MatMulProgram` for the use of atomic
+      // built-in functions.
+      output.is_atomic = true;
+    }
   }

   MatMulProgram matmul_program{activation, use_bias_in_matmul, is_vec4, elements_per_thread, is_channels_last, split_dim_inner};
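For reference, the dispatch arithmetic shared by both hunks is a plain ceiling division: `dim_inner` (K) is cut into slices of `split_dim_inner` elements, one dispatch layer along z per slice, and each slice accumulates its partial product into an output that the fill program has already initialized with bias (or zero), which is why `need_handle_bias` / `use_bias_in_matmul` are cleared once the fill program has run. Below is a minimal host-side sketch of that bookkeeping; `ComputeDispatchZ` and the numeric values are chosen purely for illustration and are not the actual ORT helpers or defaults.

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

// Hypothetical helper mirroring the computation in the diff:
// dispatch_z = ceil(K / split_dim_inner).
uint32_t ComputeDispatchZ(uint32_t K, uint32_t split_dim_inner) {
  return (K + split_dim_inner - 1) / split_dim_inner;
}

int main() {
  const uint32_t K = 1000;
  const uint32_t split_dim_inner = 256;  // assumed value of SplitKConfig::GetSplitDimInner()
  const uint32_t dispatch_z = ComputeDispatchZ(K, split_dim_inner);  // ceil(1000 / 256) = 4

  std::printf("dispatch_z = %u\n", dispatch_z);
  for (uint32_t z = 0; z < dispatch_z; ++z) {
    const uint32_t k_begin = z * split_dim_inner;
    const uint32_t k_end = std::min(k_begin + split_dim_inner, K);  // last slice may be shorter
    std::printf("slice %u covers k in [%u, %u)\n", z, k_begin, k_end);
  }
  // Each slice adds its partial result into an output already initialized with
  // beta * C (GEMM) or bias (MatMul) by the fill program, so the bias is not
  // applied a second time inside the Split-K matmul shader.
  return 0;
}
```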