
Commit 4a858a8

[webgpu] Support Split-K in more situations (#26806)
### Description

This patch supports more `dim_inner` (up to 4096) for `Split-K` to optimize more models. This patch also enables `Split-K` on `gen-12lp`.

### Motivation and Context

With this PR we can achieve about 30% improvement on `jina-clip-v1-text-fp16` and 20% improvement on `jina-embeddings-v2-base-code-fp16` on Lunar Lake iGPUs.
1 parent 541d5da commit 4a858a8

File tree

2 files changed: +57 −16 lines changed


onnxruntime/core/providers/webgpu/webgpu_utils.cc

Lines changed: 48 additions & 14 deletions
@@ -27,24 +27,49 @@ TensorShape ReduceShapeByComponents(const TensorShape& shape, int64_t components
 
 SplitKConfig::SplitKConfig(const wgpu::AdapterInfo& adapter_info) {
   if (adapter_info.vendor == std::string_view{"intel"}) {
-    if (adapter_info.architecture == std::string_view{"xe-2lpg"} ||
-        adapter_info.architecture == std::string_view{"xe-2hpg"} ||
-        adapter_info.architecture == std::string_view{"xe-lpg"} ||
-        adapter_info.architecture == std::string_view{"gen-12hp"}) {
+    // Disable Split-K on old Intel GPUs.
+    if (adapter_info.architecture == std::string_view{"gen-7"} ||
+        adapter_info.architecture == std::string_view{"gen-8"} ||
+        adapter_info.architecture == std::string_view{"gen-9"} ||
+        adapter_info.architecture == std::string_view{"gen-11"}) {
+      enable_split_k_ = false;
+    } else if (adapter_info.architecture == std::string_view{"xe-2lpg"} ||
+               adapter_info.architecture == std::string_view{"xe-2hpg"} ||
+               adapter_info.architecture == std::string_view{"gen-12hp"}) {
+      // Below thresholds are only verified on Intel discrete GPUs and Lunar Lake iGPUs.
       enable_split_k_ = true;
 
-      // Below thresholds are only verified on the above Intel GPUs without any regressions. The
-      // proper value of `max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_` may be
-      // reduced when we support a larger `dim_inner` because larger `dim_inner` will bring more
-      // atomic calls for each output value.
       split_dim_inner_ = 256;
       min_dim_inner_with_split_k_ = split_dim_inner_ * 2;
-      max_dim_inner_with_split_k_ = split_dim_inner_ * 9;
-      max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_ = 35.0f;
+
+      configs_per_dim_inner_range_.emplace_back(768, 52.0f);
+      configs_per_dim_inner_range_.emplace_back(2304, 35.0f);
+      configs_per_dim_inner_range_.emplace_back(3072, 21.5f);
+      configs_per_dim_inner_range_.emplace_back(4096, 16.0f);
+    } else {
+      // Below are the default thresholds on newer Intel GPUs. These values are chosen on
+      // an Intel "gen-12lp" GPU with 32 EUs.
+      enable_split_k_ = true;
+
+      split_dim_inner_ = 256;
+      min_dim_inner_with_split_k_ = split_dim_inner_ * 2;
+
+      configs_per_dim_inner_range_.emplace_back(768, 20.0f);
+      configs_per_dim_inner_range_.emplace_back(1792, 13.0f);
+      configs_per_dim_inner_range_.emplace_back(3072, 8.0f);
+      configs_per_dim_inner_range_.emplace_back(4096, 6.0f);
     }
   }
 }
 
+SplitKConfig::ConfigAtRange::ConfigAtRange(uint32_t max_dim_inner, float rate)
+    : max_dim_inner_with_rate(max_dim_inner), max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner(rate) {}
+
+uint32_t SplitKConfig::GetMaxDimInnerWithSplitK() const {
+  assert(!configs_per_dim_inner_range_.empty());
+  return configs_per_dim_inner_range_.back().max_dim_inner_with_rate;
+}
+
 bool SplitKConfig::UseSplitK(
     bool is_vec4,
     ActivationKind activation_kind,
@@ -71,11 +96,20 @@ bool SplitKConfig::UseSplitK(
   // Split-K works best when `dim_inner` is relatively large compared with `dim_a_outer` and
   // `dim_b_outer`. Currently we use the factor between `(dim_a_outer * dim_b_outer)` and
   // `dim_inner` as the metric to decide whether to use Split-K or not.
-  use_split_k &= (dim_inner >= min_dim_inner_with_split_k_);
-  use_split_k &= (dim_inner <= max_dim_inner_with_split_k_);
-  use_split_k &= ((dim_a_outer * dim_b_outer * 1.0f / dim_inner) <= max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_);
+  use_split_k &= dim_inner >= min_dim_inner_with_split_k_;
+  use_split_k &= dim_inner <= GetMaxDimInnerWithSplitK();
 
-  return use_split_k;
+  if (!use_split_k) {
+    return false;
+  }
+
+  const float rate = dim_a_outer * dim_b_outer * 1.0f / dim_inner;
+  for (const auto& config_at_range : configs_per_dim_inner_range_) {
+    if (dim_inner <= config_at_range.max_dim_inner_with_rate) {
+      return rate <= config_at_range.max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner;
+    }
+  }
+  return false;
 }
 
 uint32_t SplitKConfig::GetSplitDimInner() const {
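For readers skimming the diff, below is a minimal, self-contained sketch of the per-range threshold lookup that the new `configs_per_dim_inner_range_` table implements. The threshold values and the rate formula come from the diff above; the free function `UseSplitKForShape`, the `main()` driver, and the sample GEMM shape are illustrative assumptions, not part of the patch.

```cpp
// Standalone sketch of the new Split-K decision logic (illustrative only).
#include <cstdint>
#include <iostream>
#include <vector>

// Mirrors SplitKConfig::ConfigAtRange from the patch: an upper bound on
// dim_inner and the maximum allowed (dim_a_outer * dim_b_outer) / dim_inner.
struct ConfigAtRange {
  uint32_t max_dim_inner_with_rate;
  float max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner;
};

// Returns whether Split-K should be used for a GEMM of the given shape,
// assuming `configs` is ordered by ascending max_dim_inner_with_rate.
bool UseSplitKForShape(uint32_t dim_a_outer, uint32_t dim_b_outer, uint32_t dim_inner,
                       uint32_t min_dim_inner, const std::vector<ConfigAtRange>& configs) {
  if (configs.empty() || dim_inner < min_dim_inner ||
      dim_inner > configs.back().max_dim_inner_with_rate) {
    return false;
  }
  const float rate = dim_a_outer * dim_b_outer * 1.0f / dim_inner;
  for (const auto& config : configs) {
    if (dim_inner <= config.max_dim_inner_with_rate) {
      // Pick the tightest range containing dim_inner and check its rate threshold.
      return rate <= config.max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner;
    }
  }
  return false;
}

int main() {
  // Thresholds from the xe-2lpg / xe-2hpg / gen-12hp branch of the diff.
  const std::vector<ConfigAtRange> configs = {
      {768, 52.0f}, {2304, 35.0f}, {3072, 21.5f}, {4096, 16.0f}};
  // Hypothetical shape: dim_a_outer = 77, dim_b_outer = 768, dim_inner = 3072.
  // rate = 77 * 768 / 3072 = 19.25 <= 21.5, so Split-K would be chosen.
  std::cout << std::boolalpha
            << UseSplitKForShape(77, 768, 3072, /*min_dim_inner=*/512, configs) << std::endl;
  return 0;
}
```

Because the rate thresholds shrink as the range grows (52 → 35 → 21.5 → 16 on the tuned GPUs), larger `dim_inner` values, which require more atomic accumulations per output (up to 4096 / 256 = 16 partial sums with `split_dim_inner_` of 256), are only split when the output is comparatively small; this mirrors the reasoning in the comment that the patch removes.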

onnxruntime/core/providers/webgpu/webgpu_utils.h

Lines changed: 9 additions & 2 deletions
@@ -116,8 +116,15 @@ class SplitKConfig {
   bool enable_split_k_ = false;
   uint32_t split_dim_inner_ = 0;
   uint32_t min_dim_inner_with_split_k_ = 0;
-  uint32_t max_dim_inner_with_split_k_ = 0;
-  float max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_ = 0.0f;
+
+  uint32_t GetMaxDimInnerWithSplitK() const;
+
+  struct ConfigAtRange {
+    ConfigAtRange(uint32_t max_dim_inner, float rate);
+    uint32_t max_dim_inner_with_rate = 0;
+    float max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner = 0.0f;
+  };
+  std::vector<ConfigAtRange> configs_per_dim_inner_range_;
 };
 
 /**
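
Taken together, the header change replaces the two scalar fields `max_dim_inner_with_split_k_` and `max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_` with an ordered vector of `ConfigAtRange` entries, one per `dim_inner` range. Because the entries are appended in ascending order of `max_dim_inner_with_rate`, `GetMaxDimInnerWithSplitK()` can simply return the last entry's bound, and the loop in `UseSplitK` selects the first (tightest) range that contains `dim_inner`.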
