@@ -2310,42 +2310,59 @@ static webgpu_encoded_op ggml_webgpu_concat(webgpu_context & ctx,
23102310 uint32_t ne = (uint32_t ) ggml_nelements (dst);
23112311 uint32_t dim = (uint32_t ) dst->op_params [0 ];
23122312
2313- std::vector<uint32_t > params = {
2314- ne,
2315- (uint32_t ) (ggml_webgpu_tensor_misalignment (ctx, src0) / ggml_type_size (src0->type )),
2316- (uint32_t ) (ggml_webgpu_tensor_misalignment (ctx, src1) / ggml_type_size (src1->type )),
2317- (uint32_t ) (ggml_webgpu_tensor_misalignment (ctx, dst) / ggml_type_size (dst->type )),
2318- (uint32_t ) (src0->nb [0 ] / ggml_type_size (src0->type )),
2319- (uint32_t ) (src0->nb [1 ] / ggml_type_size (src0->type )),
2320- (uint32_t ) (src0->nb [2 ] / ggml_type_size (src0->type )),
2321- (uint32_t ) (src0->nb [3 ] / ggml_type_size (src0->type )),
2322- (uint32_t ) (src1->nb [0 ] / ggml_type_size (src1->type )),
2323- (uint32_t ) (src1->nb [1 ] / ggml_type_size (src1->type )),
2324- (uint32_t ) (src1->nb [2 ] / ggml_type_size (src1->type )),
2325- (uint32_t ) (src1->nb [3 ] / ggml_type_size (src1->type )),
2326- (uint32_t ) dst->ne [0 ],
2327- (uint32_t ) dst->ne [1 ],
2328- (uint32_t ) dst->ne [2 ],
2329- (uint32_t ) dst->ne [3 ],
2330- dim,
2331- (uint32_t ) src0->ne [dim]
2332- };
2333-
2334- std::vector<wgpu::BindGroupEntry> entries = {
2335- ggml_webgpu_make_tensor_bind_group_entry (ctx, 0 , src0),
2336- ggml_webgpu_make_tensor_bind_group_entry (ctx, 1 , src1),
2337- ggml_webgpu_make_tensor_bind_group_entry (ctx, 2 , dst),
2338- };
2339-
23402313 ggml_webgpu_shader_lib_context shader_lib_ctx = {};
23412314 shader_lib_ctx.src0 = src0;
23422315 shader_lib_ctx.src1 = src1;
23432316 shader_lib_ctx.dst = dst;
23442317 shader_lib_ctx.max_wg_size = ctx->global_ctx ->capabilities .limits .maxComputeInvocationsPerWorkgroup ;
23452318
23462319 webgpu_pipeline pipeline = ctx->shader_lib ->get_concat_pipeline (shader_lib_ctx);
2347- auto * decisions = static_cast <ggml_webgpu_generic_shader_decisions *>(pipeline.context .get ());
2348- uint32_t wg_x = CEIL_DIV (ne, decisions->wg_size );
2320+ auto * decisions = static_cast <ggml_webgpu_binary_shader_decisions *>(pipeline.context .get ());
2321+
2322+ uint32_t offset_src0 = (uint32_t ) (ggml_webgpu_tensor_misalignment (ctx, src0) / ggml_type_size (src0->type ));
2323+ uint32_t offset_src1 = (uint32_t ) (ggml_webgpu_tensor_misalignment (ctx, src1) / ggml_type_size (src1->type ));
2324+ size_t merged_offset = 0 ;
2325+ size_t merged_size = 0 ;
2326+ if (decisions->src_overlap ) {
2327+ const ggml_webgpu_merged_binding_range merged_range =
2328+ ggml_webgpu_tensor_merged_binding_range (ctx, { src0, src1 });
2329+ merged_offset = merged_range.offset ;
2330+ merged_size = merged_range.size ;
2331+ offset_src0 = ggml_webgpu_tensor_merged_element_offset (src0, merged_range);
2332+ offset_src1 = ggml_webgpu_tensor_merged_element_offset (src1, merged_range);
2333+ }
2334+
2335+ std::vector<uint32_t > params = { ne,
2336+ offset_src0,
2337+ offset_src1,
2338+ (uint32_t ) (ggml_webgpu_tensor_misalignment (ctx, dst) / ggml_type_size (dst->type )),
2339+ (uint32_t ) (src0->nb [0 ] / ggml_type_size (src0->type )),
2340+ (uint32_t ) (src0->nb [1 ] / ggml_type_size (src0->type )),
2341+ (uint32_t ) (src0->nb [2 ] / ggml_type_size (src0->type )),
2342+ (uint32_t ) (src0->nb [3 ] / ggml_type_size (src0->type )),
2343+ (uint32_t ) (src1->nb [0 ] / ggml_type_size (src1->type )),
2344+ (uint32_t ) (src1->nb [1 ] / ggml_type_size (src1->type )),
2345+ (uint32_t ) (src1->nb [2 ] / ggml_type_size (src1->type )),
2346+ (uint32_t ) (src1->nb [3 ] / ggml_type_size (src1->type )),
2347+ (uint32_t ) dst->ne [0 ],
2348+ (uint32_t ) dst->ne [1 ],
2349+ (uint32_t ) dst->ne [2 ],
2350+ (uint32_t ) dst->ne [3 ],
2351+ dim,
2352+ (uint32_t ) src0->ne [dim] };
2353+
2354+ std::vector<wgpu::BindGroupEntry> entries = {};
2355+ if (decisions->src_overlap ) {
2356+ entries.push_back (
2357+ ggml_webgpu_make_bind_group_entry (0 , ggml_webgpu_tensor_buf (src0), merged_offset, merged_size));
2358+ entries.push_back (ggml_webgpu_make_tensor_bind_group_entry (ctx, 1 , dst));
2359+ } else {
2360+ entries.push_back (ggml_webgpu_make_tensor_bind_group_entry (ctx, 0 , src0));
2361+ entries.push_back (ggml_webgpu_make_tensor_bind_group_entry (ctx, 1 , src1));
2362+ entries.push_back (ggml_webgpu_make_tensor_bind_group_entry (ctx, 2 , dst));
2363+ }
2364+
2365+ uint32_t wg_x = CEIL_DIV (ne, decisions->wg_size );
23492366 return ggml_backend_webgpu_build (ctx, pipeline, params, entries, wg_x);
23502367}
23512368
0 commit comments