-
Notifications
You must be signed in to change notification settings - Fork 3.8k
[WebGPU] Add gating logic for subgroup shuffle primitives #18823
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
097d05f
f119bbd
b1e3688
07d011c
3298e94
b1139a9
e9697fe
d95827a
397ac1b
89d6142
9a3edc9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -723,6 +723,10 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { | |||||||||||||||||||||||||||||||
| // Emit warp shuffle calls. | ||||||||||||||||||||||||||||||||
| PrimExpr WarpShuffle(const Op& op, ffi::Optional<Buffer> mask_buffer, PrimExpr val, | ||||||||||||||||||||||||||||||||
| PrimExpr delta_or_lane) { | ||||||||||||||||||||||||||||||||
| // WebGPU's WGSL requires u32 for subgroupShuffle lane/delta arguments. | ||||||||||||||||||||||||||||||||
| if (target_->kind->name == "webgpu") { | ||||||||||||||||||||||||||||||||
| delta_or_lane = cast(DataType::UInt(32, delta_or_lane.dtype().lanes()), delta_or_lane); | ||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||
| ffi::Array<PrimExpr> indices = {0}; | ||||||||||||||||||||||||||||||||
| PrimExpr mask; | ||||||||||||||||||||||||||||||||
| if (mask_buffer.defined()) { | ||||||||||||||||||||||||||||||||
|
|
@@ -742,11 +746,11 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { | |||||||||||||||||||||||||||||||
| bool IsWarpReduction(const std::vector<DataType>& types, int group_extent, int reduce_extent, | ||||||||||||||||||||||||||||||||
| int contiguous_reduce_extent) { | ||||||||||||||||||||||||||||||||
| if ((target_->kind->name != "cuda") && (target_->kind->name != "rocm") && | ||||||||||||||||||||||||||||||||
| (target_->kind->name != "metal")) { | ||||||||||||||||||||||||||||||||
| (target_->kind->name != "metal") && (target_->kind->name != "webgpu")) { | ||||||||||||||||||||||||||||||||
| return false; | ||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Add a check here of the following form: This is to avoid scenarios where a target such as |
||||||||||||||||||||||||||||||||
| need_warp_shuffle_mask_ = target_->kind->name != "metal"; | ||||||||||||||||||||||||||||||||
| need_warp_shuffle_mask_ = target_->kind->name != "metal" && target_->kind->name != "webgpu"; | ||||||||||||||||||||||||||||||||
|
Comment on lines
748
to
+753
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. To improve maintainability, consider using
Suggested change
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| // rocm only supports 32 bit operands for shuffling at the moment | ||||||||||||||||||||||||||||||||
| if ((target_->kind->name == "rocm") && | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -107,6 +107,9 @@ std::string CodeGenWebGPU::Finish() { | |||||||||||||||||||||||
| if (enable_fp16_) { | ||||||||||||||||||||||||
| header_stream << "enable f16;\n\n"; | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| if (enable_subgroups_) { | ||||||||||||||||||||||||
| header_stream << "enable subgroups;\n\n"; | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| return header_stream.str() + decl_stream.str() + this->fwd_decl_stream.str() + stream.str(); | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
|
|
@@ -120,7 +123,9 @@ void CodeGenWebGPU::InitFuncState(const PrimFunc& f) { | |||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| CodeGenWebGPU::CodeGenWebGPU(Target target) : target_(target) {} | ||||||||||||||||||||||||
| CodeGenWebGPU::CodeGenWebGPU(Target target) : target_(target) { | ||||||||||||||||||||||||
| enable_subgroups_ = target_->GetAttr<Bool>("supports_subgroups").value_or(Bool(false)); | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
| enable_subgroups_ = target_->GetAttr<Bool>("supports_subgroups").value_or(Bool(false)); | |
| Bool supports_subgroups = target_->GetAttr<Bool>("supports_subgroups").value_or(Bool(false)); | |
| Optional<Integer> thread_warp_size = target_->GetAttr<Integer>("thread_warp_size"); | |
| bool warp_uses_subgroups = | |
| thread_warp_size.defined() && thread_warp_size.value()->value > 1; | |
| if (warp_uses_subgroups && !supports_subgroups) { | |
| LOG(FATAL) << "WebGPU target has thread_warp_size=" << thread_warp_size.value()->value | |
| << " but does not support subgroups. Either enable the 'supports_subgroups' " | |
| << "target attribute or set thread_warp_size <= 1."; | |
| } | |
| enable_subgroups_ = supports_subgroups || warp_uses_subgroups; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's add the following check here:
Bool supports_subgroups = target_->GetAttr<Bool>("supports_subgroups").value_or(Bool(false));
int64_t thread_warp_size = target_->GetAttr<Integer>("thread_warp_size", 1).value()->value;
if (thread_warp_size > 1 && !supports_subgroups) {
LOG(FATAL) << "WebGPU target has thread_warp_size=" << thread_warp_size
<< " but supports_subgroups is false.";
}
enable_subgroups_ = supports_subgroups;
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -427,8 +427,28 @@ TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan) | |||||||||||||
| // Tags | ||||||||||||||
| .set_default_keys({"vulkan", "gpu"}); | ||||||||||||||
|
|
||||||||||||||
| /*! | ||||||||||||||
| * \brief Update WebGPU target attributes based on subgroup support. | ||||||||||||||
| * When supports_subgroups is true, set thread_warp_size to 32 so that | ||||||||||||||
| * TIR lowering uses warp-level shuffle reductions instead of shared memory. | ||||||||||||||
| */ | ||||||||||||||
| ffi::Map<ffi::String, ffi::Any> UpdateWebGPUAttrs(ffi::Map<ffi::String, ffi::Any> target) { | ||||||||||||||
| if (target.count("supports_subgroups")) { | ||||||||||||||
| bool subgroups = Downcast<Bool>(target.at("supports_subgroups")); | ||||||||||||||
| if (subgroups) { | ||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could you add a comment stating the following:
|
||||||||||||||
| target.Set("thread_warp_size", int64_t(32)); | ||||||||||||||
| } | ||||||||||||||
|
Comment on lines
+438
to
+440
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This implementation unconditionally sets
Suggested change
|
||||||||||||||
| } | ||||||||||||||
| return target; | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| TVM_REGISTER_TARGET_KIND("webgpu", kDLWebGPU) | ||||||||||||||
| .add_attr_option<int64_t>("max_num_threads", refl::DefaultValue(256)) | ||||||||||||||
| .add_attr_option<bool>("supports_subgroups", refl::DefaultValue(false)) | ||||||||||||||
| // thread_warp_size=1: is_subwarp_reduction and is_multiwarp_reduction return false, so no | ||||||||||||||
| // subgroup ops are emitted. | ||||||||||||||
| .add_attr_option<int64_t>("thread_warp_size", refl::DefaultValue(1)) | ||||||||||||||
| .set_target_canonicalizer(UpdateWebGPUAttrs) | ||||||||||||||
| .set_default_keys({"webgpu", "gpu"}); | ||||||||||||||
|
|
||||||||||||||
| TVM_REGISTER_TARGET_KIND("hexagon", kDLHexagon) | ||||||||||||||
|
|
||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe you can move this u32 cast (currently in WarpShuffle) into DispatchWebGPUShuffle by casting call->args[2] to UInt(32) there.