Skip to content

Commit 1705d43

Browse files
nikhilJain17claudereeselevine
authored
[ggml-webgpu] Handle buffer overlap / buffer aliasing for concat operator (#24000)
* Only run webgpu CI on my fork * Add webgpu only workflow * handle buffer overlap case for concat operator * restore build-webgpu.yml Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * Run clang-format * Update ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl --------- Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com> Co-authored-by: Reese Levine <reeselevine1@gmail.com>
1 parent 3b3da01 commit 1705d43

3 files changed

Lines changed: 79 additions & 33 deletions

File tree

ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -448,15 +448,19 @@ struct ggml_webgpu_upscale_pipeline_key_hash {
448448
/** Concat **/
449449

450450
struct ggml_webgpu_concat_pipeline_key {
451-
int type;
451+
int type;
452+
bool src_overlap;
452453

453-
bool operator==(const ggml_webgpu_concat_pipeline_key & other) const { return type == other.type; }
454+
bool operator==(const ggml_webgpu_concat_pipeline_key & other) const {
455+
return type == other.type && src_overlap == other.src_overlap;
456+
}
454457
};
455458

456459
struct ggml_webgpu_concat_pipeline_key_hash {
457460
size_t operator()(const ggml_webgpu_concat_pipeline_key & key) const {
458461
size_t seed = 0;
459462
ggml_webgpu_hash_combine(seed, key.type);
463+
ggml_webgpu_hash_combine(seed, key.src_overlap);
460464
return seed;
461465
}
462466
};
@@ -2634,6 +2638,7 @@ class ggml_webgpu_shader_lib {
26342638
webgpu_pipeline get_concat_pipeline(const ggml_webgpu_shader_lib_context & context) {
26352639
ggml_webgpu_concat_pipeline_key key = {};
26362640
key.type = context.dst->type;
2641+
key.src_overlap = ggml_webgpu_tensor_overlap(context.src0, context.src1);
26372642

26382643
auto it = concat_pipelines.find(key);
26392644
if (it != concat_pipelines.end()) {
@@ -2656,11 +2661,17 @@ class ggml_webgpu_shader_lib {
26562661
GGML_ABORT("Unsupported type for concat shader");
26572662
}
26582663

2664+
if (key.src_overlap) {
2665+
defines.push_back("SRC_OVERLAP");
2666+
variant += "_src_overlap";
2667+
}
2668+
26592669
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
26602670

26612671
auto processed = preprocessor.preprocess(wgsl_concat, defines);
2662-
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
2672+
auto decisions = std::make_shared<ggml_webgpu_binary_shader_decisions>();
26632673
decisions->wg_size = context.max_wg_size;
2674+
decisions->src_overlap = key.src_overlap;
26642675
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
26652676
pipeline.context = decisions;
26662677
concat_pipelines[key] = pipeline;

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 46 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2310,42 +2310,59 @@ static webgpu_encoded_op ggml_webgpu_concat(webgpu_context & ctx,
23102310
uint32_t ne = (uint32_t) ggml_nelements(dst);
23112311
uint32_t dim = (uint32_t) dst->op_params[0];
23122312

2313-
std::vector<uint32_t> params = {
2314-
ne,
2315-
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
2316-
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
2317-
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
2318-
(uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
2319-
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
2320-
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
2321-
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
2322-
(uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
2323-
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
2324-
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
2325-
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
2326-
(uint32_t) dst->ne[0],
2327-
(uint32_t) dst->ne[1],
2328-
(uint32_t) dst->ne[2],
2329-
(uint32_t) dst->ne[3],
2330-
dim,
2331-
(uint32_t) src0->ne[dim]
2332-
};
2333-
2334-
std::vector<wgpu::BindGroupEntry> entries = {
2335-
ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0),
2336-
ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1),
2337-
ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst),
2338-
};
2339-
23402313
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
23412314
shader_lib_ctx.src0 = src0;
23422315
shader_lib_ctx.src1 = src1;
23432316
shader_lib_ctx.dst = dst;
23442317
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
23452318

23462319
webgpu_pipeline pipeline = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx);
2347-
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
2348-
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
2320+
auto * decisions = static_cast<ggml_webgpu_binary_shader_decisions *>(pipeline.context.get());
2321+
2322+
uint32_t offset_src0 = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type));
2323+
uint32_t offset_src1 = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type));
2324+
size_t merged_offset = 0;
2325+
size_t merged_size = 0;
2326+
if (decisions->src_overlap) {
2327+
const ggml_webgpu_merged_binding_range merged_range =
2328+
ggml_webgpu_tensor_merged_binding_range(ctx, { src0, src1 });
2329+
merged_offset = merged_range.offset;
2330+
merged_size = merged_range.size;
2331+
offset_src0 = ggml_webgpu_tensor_merged_element_offset(src0, merged_range);
2332+
offset_src1 = ggml_webgpu_tensor_merged_element_offset(src1, merged_range);
2333+
}
2334+
2335+
std::vector<uint32_t> params = { ne,
2336+
offset_src0,
2337+
offset_src1,
2338+
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
2339+
(uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
2340+
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
2341+
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
2342+
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
2343+
(uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
2344+
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
2345+
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
2346+
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
2347+
(uint32_t) dst->ne[0],
2348+
(uint32_t) dst->ne[1],
2349+
(uint32_t) dst->ne[2],
2350+
(uint32_t) dst->ne[3],
2351+
dim,
2352+
(uint32_t) src0->ne[dim] };
2353+
2354+
std::vector<wgpu::BindGroupEntry> entries = {};
2355+
if (decisions->src_overlap) {
2356+
entries.push_back(
2357+
ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), merged_offset, merged_size));
2358+
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst));
2359+
} else {
2360+
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0));
2361+
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1));
2362+
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst));
2363+
}
2364+
2365+
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
23492366
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
23502367
}
23512368

ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,16 @@ struct Params {
3131
#define DataType i32
3232
#endif
3333

34+
#ifdef SRC_OVERLAP
35+
@group(0) @binding(0)
36+
var<storage, read_write> merged_src: array<DataType>;
37+
38+
@group(0) @binding(1)
39+
var<storage, read_write> dst: array<DataType>;
40+
41+
@group(0) @binding(2)
42+
var<uniform> params: Params;
43+
#else
3444
@group(0) @binding(0)
3545
var<storage, read_write> src0: array<DataType>;
3646

@@ -42,7 +52,7 @@ var<storage, read_write> dst: array<DataType>;
4252

4353
@group(0) @binding(3)
4454
var<uniform> params: Params;
45-
55+
#endif
4656
@compute @workgroup_size(WG_SIZE)
4757
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
4858

@@ -62,14 +72,22 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
6272
ni[1] * params.stride_src0_1 +
6373
ni[2] * params.stride_src0_2 +
6474
ni[3] * params.stride_src0_3;
75+
#ifdef SRC_OVERLAP
76+
dst[params.offset_dst + gid.x] = merged_src[params.offset_src0 + src_i];
77+
#else
6578
dst[params.offset_dst + gid.x] = src0[params.offset_src0 + src_i];
79+
#endif
6680
} else {
6781
ni[params.dim] -= params.src0_nedim;
6882
let src_i = ni[0] * params.stride_src1_0 +
6983
ni[1] * params.stride_src1_1 +
7084
ni[2] * params.stride_src1_2 +
7185
ni[3] * params.stride_src1_3;
86+
#ifdef SRC_OVERLAP
87+
dst[params.offset_dst + gid.x] = merged_src[params.offset_src1 + src_i];
88+
#else
7289
dst[params.offset_dst + gid.x] = src1[params.offset_src1 + src_i];
90+
#endif
7391
}
7492
}
7593
}

0 commit comments

Comments
 (0)