@@ -27,24 +27,49 @@ TensorShape ReduceShapeByComponents(const TensorShape& shape, int64_t components
2727
2828SplitKConfig::SplitKConfig (const wgpu::AdapterInfo& adapter_info) {
2929 if (adapter_info.vendor == std::string_view{" intel" }) {
30- if (adapter_info.architecture == std::string_view{" xe-2lpg" } ||
31- adapter_info.architecture == std::string_view{" xe-2hpg" } ||
32- adapter_info.architecture == std::string_view{" xe-lpg" } ||
33- adapter_info.architecture == std::string_view{" gen-12hp" }) {
30+ // Disable Split-K on old Intel GPUs.
31+ if (adapter_info.architecture == std::string_view{" gen-7" } ||
32+ adapter_info.architecture == std::string_view{" gen-8" } ||
33+ adapter_info.architecture == std::string_view{" gen-9" } ||
34+ adapter_info.architecture == std::string_view{" gen-11" }) {
35+ enable_split_k_ = false ;
36+ } else if (adapter_info.architecture == std::string_view{" xe-2lpg" } ||
37+ adapter_info.architecture == std::string_view{" xe-2hpg" } ||
38+ adapter_info.architecture == std::string_view{" gen-12hp" }) {
39+ // The thresholds below are only verified on Intel discrete GPUs and Lunar Lake iGPUs.
3440 enable_split_k_ = true ;
3541
36- // Below thresholds are only verified on the above Intel GPUs without any regressions. The
37- // proper value of `max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_` may be
38- // reduced when we support a larger `dim_inner` because larger `dim_inner` will bring more
39- // atomic calls for each output value.
4042 split_dim_inner_ = 256 ;
4143 min_dim_inner_with_split_k_ = split_dim_inner_ * 2 ;
42- max_dim_inner_with_split_k_ = split_dim_inner_ * 9 ;
43- max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_ = 35 .0f ;
44+
45+ configs_per_dim_inner_range_.emplace_back (768 , 52 .0f );
46+ configs_per_dim_inner_range_.emplace_back (2304 , 35 .0f );
47+ configs_per_dim_inner_range_.emplace_back (3072 , 21 .5f );
48+ configs_per_dim_inner_range_.emplace_back (4096 , 16 .0f );
49+ } else {
50+ // Below are the default thresholds on newer Intel GPUs. These values are chosen on
51+ // Intel "gen-12lp" GPU with 32EUs.
52+ enable_split_k_ = true ;
53+
54+ split_dim_inner_ = 256 ;
55+ min_dim_inner_with_split_k_ = split_dim_inner_ * 2 ;
56+
57+ configs_per_dim_inner_range_.emplace_back (768 , 20 .0f );
58+ configs_per_dim_inner_range_.emplace_back (1792 , 13 .0f );
59+ configs_per_dim_inner_range_.emplace_back (3072 , 8 .0f );
60+ configs_per_dim_inner_range_.emplace_back (4096 , 6 .0f );
4461 }
4562 }
4663}
4764
65+ SplitKConfig::ConfigAtRange::ConfigAtRange (uint32_t max_dim_inner, float rate)
66+ : max_dim_inner_with_rate(max_dim_inner), max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner(rate) {}
67+
68+ uint32_t SplitKConfig::GetMaxDimInnerWithSplitK () const {
69+ assert (!configs_per_dim_inner_range_.empty ());
70+ return configs_per_dim_inner_range_.back ().max_dim_inner_with_rate ;
71+ }
72+
4873bool SplitKConfig::UseSplitK (
4974 bool is_vec4,
5075 ActivationKind activation_kind,
@@ -71,11 +96,20 @@ bool SplitKConfig::UseSplitK(
7196 // Split-K works best when `dim_inner` is relatively large compared with `dim_a_outer` and
7297 // `dim_b_outer`. Currently we use the factor between `(dim_a_outer * dim_b_outer)` and
7398 // `dim_inner` as the metric to decide whether to use Split-K or not.
74- use_split_k &= (dim_inner >= min_dim_inner_with_split_k_);
75- use_split_k &= (dim_inner <= max_dim_inner_with_split_k_);
76- use_split_k &= ((dim_a_outer * dim_b_outer * 1 .0f / dim_inner) <= max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_);
99+ use_split_k &= dim_inner >= min_dim_inner_with_split_k_;
100+ use_split_k &= dim_inner <= GetMaxDimInnerWithSplitK ();
77101
78- return use_split_k;
102+ if (!use_split_k) {
103+ return false ;
104+ }
105+
106+ const float rate = dim_a_outer * dim_b_outer * 1 .0f / dim_inner;
107+ for (const auto & config_at_range : configs_per_dim_inner_range_) {
108+ if (dim_inner <= config_at_range.max_dim_inner_with_rate ) {
109+ return rate <= config_at_range.max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner ;
110+ }
111+ }
112+ return false ;
79113}
80114
81115uint32_t SplitKConfig::GetSplitDimInner () const {
0 commit comments