Skip to content
Open
8 changes: 6 additions & 2 deletions src/s_tir/transform/lower_thread_allreduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -723,6 +723,10 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
// Emit warp shuffle calls.
PrimExpr WarpShuffle(const Op& op, ffi::Optional<Buffer> mask_buffer, PrimExpr val,
PrimExpr delta_or_lane) {
// WebGPU's WGSL requires u32 for subgroupShuffle lane/delta arguments.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe you can move this to DispatchWebGPUShuffle by casting call->args[2] to UInt(32)

if (target_->kind->name == "webgpu") {
delta_or_lane = cast(DataType::UInt(32, delta_or_lane.dtype().lanes()), delta_or_lane);
}
ffi::Array<PrimExpr> indices = {0};
PrimExpr mask;
if (mask_buffer.defined()) {
Expand All @@ -742,11 +746,11 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
bool IsWarpReduction(const std::vector<DataType>& types, int group_extent, int reduce_extent,
int contiguous_reduce_extent) {
if ((target_->kind->name != "cuda") && (target_->kind->name != "rocm") &&
(target_->kind->name != "metal")) {
(target_->kind->name != "metal") && (target_->kind->name != "webgpu")) {
return false;
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a check here of the following form:

if (target_->kind->name == "webgpu" && !supports_subgroups_) {
  return false;
}

This is to avoid scenarios where a target such as {"kind":"webgpu","thread_warp_size":32,"supports_subgroups":false} would still emit subgroup ops, but the WGSL would not contain enable subgroups;.

need_warp_shuffle_mask_ = target_->kind->name != "metal";
need_warp_shuffle_mask_ = target_->kind->name != "metal" && target_->kind->name != "webgpu";
Comment on lines 748 to +753
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To improve maintainability, consider using std::unordered_set for checking the target kind. This makes it easier to add or remove supported targets in the future.

Suggested change
if ((target_->kind->name != "cuda") && (target_->kind->name != "rocm") &&
(target_->kind->name != "metal")) {
(target_->kind->name != "metal") && (target_->kind->name != "webgpu")) {
return false;
}
need_warp_shuffle_mask_ = target_->kind->name != "metal";
need_warp_shuffle_mask_ = target_->kind->name != "metal" && target_->kind->name != "webgpu";
const std::unordered_set<std::string> supported_targets = {"cuda", "rocm", "metal", "webgpu"};
if (!supported_targets.count(target_->kind->name)) {
return false;
}
const std::unordered_set<std::string> no_mask_targets = {"metal", "webgpu"};
need_warp_shuffle_mask_ = !no_mask_targets.count(target_->kind->name);


// rocm only supports 32 bit operands for shuffling at the moment
if ((target_->kind->name == "rocm") &&
Expand Down
7 changes: 6 additions & 1 deletion src/target/source/codegen_webgpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ std::string CodeGenWebGPU::Finish() {
if (enable_fp16_) {
header_stream << "enable f16;\n\n";
}
if (enable_subgroups_) {
header_stream << "enable subgroups;\n\n";
}
return header_stream.str() + decl_stream.str() + this->fwd_decl_stream.str() + stream.str();
}

Expand All @@ -120,7 +123,9 @@ void CodeGenWebGPU::InitFuncState(const PrimFunc& f) {
}
}

CodeGenWebGPU::CodeGenWebGPU(Target target) : target_(target) {
  // WGSL only permits subgroupShuffle* builtins when the shader begins with
  // `enable subgroups;`, which Finish() emits iff enable_subgroups_ is set.
  Bool supports_subgroups = target_->GetAttr<Bool>("supports_subgroups").value_or(Bool(false));
  // Guard against inconsistent targets such as
  // {"kind":"webgpu","thread_warp_size":32,"supports_subgroups":false}:
  // with thread_warp_size > 1 the TIR lowering may emit subgroup shuffle ops,
  // but the generated WGSL would lack the required `enable subgroups;`
  // directive. Fail fast instead of producing an invalid shader.
  int64_t thread_warp_size = target_->GetAttr<Integer>("thread_warp_size", 1).value()->value;
  if (thread_warp_size > 1 && !supports_subgroups) {
    LOG(FATAL) << "WebGPU target has thread_warp_size=" << thread_warp_size
               << " but supports_subgroups is false. Either enable the 'supports_subgroups' "
               << "target attribute or set thread_warp_size <= 1.";
  }
  enable_subgroups_ = supports_subgroups;
}

runtime::FunctionInfo CodeGenWebGPU::AddFunction(const PrimFunc& f, bool skip_readonly_decl) {
// clear previous generated state.
Expand Down
2 changes: 2 additions & 0 deletions src/target/source/codegen_webgpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ class CodeGenWebGPU final : public CodeGenC {

// whether enable fp16
bool enable_fp16_{false};
// whether enable subgroups
bool enable_subgroups_{false};

/*! \brief the header stream for function label and enable directive if any, goes before any other
* declaration */
Expand Down
58 changes: 58 additions & 0 deletions src/target/source/intrin_rule_webgpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,29 @@ namespace intrin {

using tir::FLowerIntrinsic;

// Warp-level primitives. Follows implementation in intrin_rule_metal.cc.
struct WebGPUWarpIntrinsic {
  // Translate a generic TIR warp-shuffle builtin into its WebGPU subgroup
  // counterpart. Any op other than the three shuffle variants indicates a
  // lowering bug and is caught by the ICHECK on the final branch.
  const Op operator()(DataType t, const Op& orig_op) const {
    if (orig_op.same_as(builtin::tvm_warp_shuffle())) {
      return Op::Get("tir.webgpu.subgroup_shuffle");
    }
    if (orig_op.same_as(builtin::tvm_warp_shuffle_up())) {
      return Op::Get("tir.webgpu.subgroup_shuffle_up");
    }
    TVM_FFI_ICHECK(orig_op.same_as(builtin::tvm_warp_shuffle_down()));
    return Op::Get("tir.webgpu.subgroup_shuffle_down");
  }
};

// Lower a generic tvm_warp_shuffle* call to the matching WebGPU subgroup op.
//
// The TIR builtin carries five arguments (mask, value, warp_id/delta, width,
// warp_size); WGSL's subgroupShuffle* builtins take only the value and the
// lane/delta, so the remaining arguments are dropped here.
template <typename T>
static PrimExpr DispatchWebGPUShuffle(const PrimExpr& e) {
  const CallNode* call = e.as<CallNode>();
  TVM_FFI_ICHECK(call != nullptr);
  TVM_FFI_ICHECK_EQ(call->args.size(), 5);  // mask, value, warp_id, width, warp_size
  // WGSL requires the lane/delta argument of subgroupShuffle* to be u32.
  // Cast here so the dispatch is self-contained regardless of what dtype the
  // upstream lowering produced; the cast is a no-op if it is already u32.
  PrimExpr delta_or_lane = call->args[2];
  if (!(delta_or_lane.dtype().is_uint() && delta_or_lane.dtype().bits() == 32)) {
    delta_or_lane = cast(DataType::UInt(32, delta_or_lane.dtype().lanes()), delta_or_lane);
  }
  ffi::Array<PrimExpr> webgpu_args{{call->args[1], delta_or_lane}};
  return Call(call->dtype, T()(call->dtype, Downcast<Op>(call->op)), webgpu_args);
}

// See full list of builtin: https://www.w3.org/TR/WGSL/#builtin-functions

struct ReturnAbs {
Expand Down Expand Up @@ -113,6 +136,41 @@ TVM_REGISTER_OP("tir.trunc")
// extra dispatch
TVM_REGISTER_OP("tir.erf").set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic", DispatchFastErf);

// Warp-level primitives. Follows implementation in intrin_rule_metal.cc.
// Route the three generic TIR warp-shuffle builtins through
// DispatchWebGPUShuffle, which maps each to its subgroup op registered below.
TVM_REGISTER_OP("tir.tvm_warp_shuffle")
    .set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic",
                               DispatchWebGPUShuffle<WebGPUWarpIntrinsic>);

TVM_REGISTER_OP("tir.tvm_warp_shuffle_up")
    .set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic",
                               DispatchWebGPUShuffle<WebGPUWarpIntrinsic>);

TVM_REGISTER_OP("tir.tvm_warp_shuffle_down")
    .set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic",
                               DispatchWebGPUShuffle<WebGPUWarpIntrinsic>);

// Register low-level builtin ops. TGlobalSymbol is the WGSL builtin name the
// codegen emits; kOpaque marks the call as having unspecified side effects so
// it is not folded or moved by generic TIR simplification.
TVM_REGISTER_OP("tir.webgpu.subgroup_shuffle")
    .set_num_inputs(2)
    .add_argument("var", "Expr", "The variable to sync.")
    .add_argument("lane", "Expr", "The source thread id.")
    .set_attr<TGlobalSymbol>("TGlobalSymbol", "subgroupShuffle")
    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

TVM_REGISTER_OP("tir.webgpu.subgroup_shuffle_up")
    .set_num_inputs(2)
    .add_argument("var", "Expr", "The variable to sync.")
    .add_argument("delta", "Expr", "The source lane id offset to be added.")
    .set_attr<TGlobalSymbol>("TGlobalSymbol", "subgroupShuffleUp")
    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

TVM_REGISTER_OP("tir.webgpu.subgroup_shuffle_down")
    .set_num_inputs(2)
    .add_argument("var", "Expr", "The variable to sync.")
    .add_argument("delta", "Expr", "The source lane id offset to be subtracted.")
    .set_attr<TGlobalSymbol>("TGlobalSymbol", "subgroupShuffleDown")
    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

} // namespace intrin
} // namespace codegen
} // namespace tvm
20 changes: 20 additions & 0 deletions src/target/target_kind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -427,8 +427,28 @@ TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan)
// Tags
.set_default_keys({"vulkan", "gpu"});

/*!
 * \brief Canonicalize WebGPU target attributes based on subgroup support.
 *
 * When supports_subgroups is true, thread_warp_size defaults to 32 so that
 * TIR lowering uses warp-level (subgroup) shuffle reductions instead of
 * shared memory. Notes on the constant 32:
 *   1. Runtime routing on the WebLLM side guarantees subgroup size == 32.
 *   2. Runtime routing on the WebLLM side guarantees
 *      maxComputeInvocationsPerWorkgroup >= 1024.
 *   3. This is intentionally constrained for the subgroup-enabled WASM
 *      variant.
 * An explicitly user-provided thread_warp_size is left untouched, so devices
 * with a different subgroup size can still be targeted.
 */
ffi::Map<ffi::String, ffi::Any> UpdateWebGPUAttrs(ffi::Map<ffi::String, ffi::Any> target) {
  if (target.count("supports_subgroups")) {
    bool subgroups = Downcast<Bool>(target.at("supports_subgroups"));
    // Only fill in a default; never overwrite an explicit user setting.
    if (subgroups && !target.count("thread_warp_size")) {
      target.Set("thread_warp_size", int64_t(32));
    }
  }
  return target;
}

TVM_REGISTER_TARGET_KIND("webgpu", kDLWebGPU)
    .add_attr_option<int64_t>("max_num_threads", refl::DefaultValue(256))
    // When true, the canonicalizer below raises thread_warp_size so that
    // subgroup (warp-level) reductions are used.
    .add_attr_option<bool>("supports_subgroups", refl::DefaultValue(false))
    // thread_warp_size=1: is_subwarp_reduction and is_multiwarp_reduction returns false, so no
    // subgroup ops are emitted.
    .add_attr_option<int64_t>("thread_warp_size", refl::DefaultValue(1))
    .set_target_canonicalizer(UpdateWebGPUAttrs)
    .set_default_keys({"webgpu", "gpu"});

TVM_REGISTER_TARGET_KIND("hexagon", kDLHexagon)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -406,5 +406,204 @@ def main(A: T.Buffer((1, 1, 2, 128), "float32"), B: T.Buffer((1, 1, 2), "float32
assert "tvm_storage_sync" in After_script


def test_webgpu_warp_reduce():
    """Single-warp (32-thread) allreduce on a subgroup-enabled WebGPU target.

    With supports_subgroups=True the lowering should replace the
    tvm_thread_allreduce with a shuffle-down reduction tree plus a final
    broadcast shuffle, with no shared-memory staging or mask buffer.
    """
    transform = tvm.s_tir.transform.LowerThreadAllreduce()

    @I.ir_module
    class Before:
        @T.prim_func(private=True)
        def main(A: T.Buffer((128, 32), "float32"), B: T.Buffer(128, "float32")):
            T.func_attr(
                {
                    "target": T.target(
                        {
                            "kind": "webgpu",
                            "supports_subgroups": True,
                            "host": "llvm",
                        }
                    ),
                }
            )
            A_flat = T.Buffer(4096, data=A.data)

            for i in range(128):
                threadIdx_x = T.launch_thread("threadIdx.x", 32)

                reduce_data = T.allocate([1], "float32", "local")
                reduce = T.Buffer(1, data=reduce_data, scope="local")

                with T.attr(
                    T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
                    "reduce_scope",
                    T.reinterpret("handle", T.uint64(0)),
                ):
                    T.tvm_thread_allreduce(
                        T.uint32(1),
                        A_flat[0],
                        T.bool(True),
                        reduce[0],
                        threadIdx_x,
                    )
                if threadIdx_x == 0:
                    B[i] = reduce[0]

    @I.ir_module
    class Expected:
        @T.prim_func(private=True)
        def main(A: T.Buffer((128, 32), "float32"), B: T.Buffer(128, "float32")):
            T.func_attr(
                {
                    "target": T.target(
                        {
                            "kind": "webgpu",
                            "supports_subgroups": True,
                            "host": "llvm",
                        }
                    ),
                }
            )
            A_flat = T.Buffer(4096, data=A.data)

            for i in range(128):
                threadIdx_x = T.launch_thread("threadIdx.x", 32)

                reduce_data = T.allocate([1], "float32", "local")
                reduce = T.Buffer(1, data=reduce_data, scope="local")

                with T.attr(
                    T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
                    "reduce_scope",
                    T.reinterpret("handle", T.uint64(0)),
                ):
                    t0_data = T.allocate([1], "float32", "local")
                    t0 = T.decl_buffer(1, data=t0_data, scope="local")

                    reduce[0] = A_flat[0]

                    # Shuffle-down tree: deltas 16, 8, 4, 2, 1 halve the
                    # active range each step; mask argument is 0 (webgpu
                    # needs no warp mask).
                    t0[0] = T.tvm_warp_shuffle_down(0, reduce[0], T.uint32(16), 32, 32)
                    reduce[0] = reduce[0] + t0[0]
                    t0[0] = T.tvm_warp_shuffle_down(0, reduce[0], T.uint32(8), 32, 32)
                    reduce[0] = reduce[0] + t0[0]
                    t0[0] = T.tvm_warp_shuffle_down(0, reduce[0], T.uint32(4), 32, 32)
                    reduce[0] = reduce[0] + t0[0]
                    t0[0] = T.tvm_warp_shuffle_down(0, reduce[0], T.uint32(2), 32, 32)
                    reduce[0] = reduce[0] + t0[0]
                    t0[0] = T.tvm_warp_shuffle_down(0, reduce[0], T.uint32(1), 32, 32)
                    reduce[0] = reduce[0] + t0[0]
                    # Broadcast the result from lane 0 to all lanes.
                    reduce[0] = T.tvm_warp_shuffle(t0[0], reduce[0], T.uint32(0), 32, 32)
                if threadIdx_x == 0:
                    B[i] = reduce[0]

    After = transform(Before)
    tvm.ir.assert_structural_equal(After, Expected)


def test_webgpu_multi_warp_reduce():
    """Multi-warp (128-thread) allreduce on a subgroup-enabled WebGPU target.

    128 threads span four 32-lane subgroups, so the lowering should produce a
    two-stage reduction: per-subgroup shuffle-down trees, a shared-memory
    staging buffer across subgroups, and a second shuffle tree over the four
    partial results.
    """
    transform = tvm.s_tir.transform.LowerThreadAllreduce()

    @I.ir_module
    class Before:
        @T.prim_func(private=True)
        def main(A: T.Buffer((1, 1, 2, 128), "float32"), B: T.Buffer((1, 1, 2), "float32")):
            T.func_attr(
                {
                    "target": T.target(
                        {
                            "kind": "webgpu",
                            "max_num_threads": 1024,
                            "supports_subgroups": True,
                            "host": "llvm",
                        }
                    ),
                }
            )
            blockIdx_x = T.launch_thread("blockIdx.x", 1)
            cross_thread_B = T.allocate([1], "float32", "local")
            threadIdx_z = T.launch_thread("threadIdx.z", 1)
            threadIdx_y = T.launch_thread("threadIdx.y", 2)
            threadIdx_x = T.launch_thread("threadIdx.x", 128)
            cross_thread_B_1 = T.Buffer((1,), data=cross_thread_B, scope="local")
            with T.attr(
                T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
                "reduce_scope",
                T.reinterpret("handle", T.uint64(0)),
            ):
                A_1 = T.Buffer((256,), data=A.data)
                T.tvm_thread_allreduce(
                    T.uint32(1),
                    A_1[threadIdx_y * 128 + threadIdx_x],
                    T.bool(True),
                    cross_thread_B_1[0],
                    threadIdx_x,
                )
            if threadIdx_x == 0:
                B_1 = T.Buffer((2,), data=B.data)
                B_1[threadIdx_y] = cross_thread_B_1[0]

    @I.ir_module
    class Expected:
        @T.prim_func(private=True)
        def main(A: T.Buffer((1, 1, 2, 128), "float32"), B: T.Buffer((1, 1, 2), "float32")):
            T.func_attr(
                {
                    "target": T.target(
                        {
                            "kind": "webgpu",
                            "max_num_threads": 1024,
                            "supports_subgroups": True,
                            "host": "llvm",
                        }
                    ),
                }
            )
            blockIdx_x = T.launch_thread("blockIdx.x", 1)
            red_result = T.allocate([2], "float32", "shared")
            T.attr(red_result, "volatile_scope", 1)
            threadIdx_z = T.launch_thread("threadIdx.z", 1)
            threadIdx_y = T.launch_thread("threadIdx.y", 2)
            threadIdx_x = T.launch_thread("threadIdx.x", 128)
            red_result_1 = T.Buffer((2,), data=red_result, scope="shared")
            with T.attr(
                T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
                "reduce_scope",
                T.reinterpret("handle", T.uint64(0)),
            ):
                red_buf0 = T.decl_buffer([1], "float32", scope="local")
                t0 = T.decl_buffer([1], "float32", scope="local")
                red_buf0_1 = T.decl_buffer([1], "float32", scope="local")
                t0_1 = T.decl_buffer([1], "float32", scope="local")
                red_buf_staging = T.decl_buffer([8], "float32", scope="shared")
                A_1 = T.Buffer((256,), data=A.data)
                red_buf0_1[0] = A_1[threadIdx_y * 128 + threadIdx_x]
                # Stage 1: shuffle-down tree within each 32-lane subgroup.
                t0_1[0] = T.tvm_warp_shuffle_down(0, red_buf0_1[0], T.uint32(16), 32, 32)
                red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
                t0_1[0] = T.tvm_warp_shuffle_down(0, red_buf0_1[0], T.uint32(8), 32, 32)
                red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
                t0_1[0] = T.tvm_warp_shuffle_down(0, red_buf0_1[0], T.uint32(4), 32, 32)
                red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
                t0_1[0] = T.tvm_warp_shuffle_down(0, red_buf0_1[0], T.uint32(2), 32, 32)
                red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
                t0_1[0] = T.tvm_warp_shuffle_down(0, red_buf0_1[0], T.uint32(1), 32, 32)
                red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
                # Lane 0 of each subgroup writes its partial to shared staging.
                if threadIdx_x % 32 == 0:
                    red_buf_staging[threadIdx_y * 4 + threadIdx_x // 32] = red_buf0_1[0]
                T.tvm_storage_sync("shared")
                # Stage 2: first 4 threads reduce the 4 subgroup partials.
                if threadIdx_x < 4:
                    red_buf0[0] = red_buf_staging[threadIdx_y * 4 + threadIdx_x]
                    t0[0] = T.tvm_warp_shuffle_down(0, red_buf0[0], T.uint32(2), 32, 32)
                    red_buf0[0] = red_buf0[0] + t0[0]
                    t0[0] = T.tvm_warp_shuffle_down(0, red_buf0[0], T.uint32(1), 32, 32)
                    red_buf0[0] = red_buf0[0] + t0[0]
                    if threadIdx_x == 0:
                        red_result_1[threadIdx_y] = red_buf0[0]
                T.tvm_storage_sync("shared")
            if threadIdx_x == 0:
                B_1 = T.Buffer((2,), data=B.data)
                B_1[threadIdx_y] = red_result_1[threadIdx_y]

    After = transform(Before)
    tvm.ir.assert_structural_equal(After, Expected)


if __name__ == "__main__":
tvm.testing.main()
13 changes: 13 additions & 0 deletions tests/python/target/test_target_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,5 +426,18 @@ def test_cli_string_rejected():
Target("llvm -mcpu=cortex-a53")


def test_webgpu_target_subgroup_attrs():
    """Test WebGPU target defaults and supports_subgroups canonicalization."""
    # Default: thread_warp_size=1, supports_subgroups=False
    tgt_default = Target({"kind": "webgpu"})
    assert tgt_default.attrs["thread_warp_size"] == 1
    assert tgt_default.attrs["supports_subgroups"] == 0

    # With supports_subgroups=True and no explicit thread_warp_size, the
    # target canonicalizer (UpdateWebGPUAttrs) sets thread_warp_size to 32.
    tgt_subgroups = Target({"kind": "webgpu", "supports_subgroups": True})
    assert tgt_subgroups.attrs["thread_warp_size"] == 32
    assert tgt_subgroups.attrs["supports_subgroups"] == 1


if __name__ == "__main__":
tvm.testing.main()
2 changes: 2 additions & 0 deletions web/emcc/webgpu_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
#include <iostream>
#include <string>

#include "3rdparty/tvm-ffi/src/ffi/extra/json_parser.cc"
#include "3rdparty/tvm-ffi/src/ffi/extra/json_writer.cc"
#include "../../src/runtime/file_utils.h"
#include "../../src/runtime/metadata.h"
#include "../../src/runtime/workspace_pool.h"
Expand Down
3 changes: 3 additions & 0 deletions web/src/webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@ export async function detectGPUDevice(powerPreference: "low-power" | "high-perfo
if (adapter.features.has("shader-f16")) {
requiredFeatures.push("shader-f16");
}
if (adapter.features.has("subgroups")) {
requiredFeatures.push("subgroups");
}
// requestAdapterInfo() is deprecated, causing requestAdapterInfo to raise
// issue when building. However, it is still needed for older browsers, hence `as any`.
const adapterInfo = adapter.info || await (adapter as any).requestAdapterInfo();
Expand Down