[js/webgpu] ConvTranspose1D slower on WebGPU than Wasm #23273

@gianlourbano

Description

Describe the issue

ConvTranspose1D with input shape [8, 4098, 435], weight shape [4098, 1, 4096], stride 1024 and padding 0 appears to be slower on WebGPU than on Wasm. Timings:

EP                        Timing (M1 MacBook Pro)
wasm                      6 s
webgpu (latest Chrome)    30 s
webgpu (Chrome Canary)    18 s

Canary is faster due to this bug.
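For scale, the output length of a 1-D transposed convolution follows L_out = (L_in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1 (the formula given in the PyTorch `ConvTranspose1d` documentation). The sketch below applies it to this configuration, assuming dilation 1, no output padding and a single group, and counts one multiply-accumulate (MAC) per input element, per kernel tap, per output channel. The throughput lines only divide that count by the wall-clock timings reported above, so they also absorb any setup overhead in those runs; they are not separate measurements.

```python
def conv_transpose1d_out_len(l_in, kernel_size, stride=1, padding=0,
                             dilation=1, output_padding=0):
    """Output length of a 1-D transposed convolution (PyTorch/ONNX convention)."""
    return ((l_in - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)


def conv_transpose1d_macs(n, c_in, l_in, c_out, kernel_size):
    """MACs for groups=1: every input element is multiplied by every
    kernel tap for every output channel."""
    return n * c_in * l_in * c_out * kernel_size


# Configuration from the issue: input [8, 4098, 435], weight [4098, 1, 4096].
N, C_IN, L_IN = 8, 4098, 435
C_OUT, K, STRIDE = 1, 4096, 1024

l_out = conv_transpose1d_out_len(L_IN, K, stride=STRIDE)
macs = conv_transpose1d_macs(N, C_IN, L_IN, C_OUT, K)

print(f"output shape: [{N}, {C_OUT}, {l_out}]")
print(f"MACs: {macs:,}")

# Effective throughput implied by the reported end-to-end timings.
for ep, seconds in [("wasm", 6), ("webgpu (latest Chrome)", 30),
                    ("webgpu (Chrome Canary)", 18)]:
    print(f"{ep}: {macs / seconds / 1e9:.2f} GMAC/s")
```

For this model that works out to an output of shape [8, 1, 448512] and roughly 5.84 × 10^10 MACs per run.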

To reproduce

A simple PyTorch script that builds the transposed convolution and exports it to ONNX:

import torch

class ConvTest(torch.nn.Module):
    """Wraps a single 1-D transposed convolution with a fixed weight."""

    def __init__(self, weight, stride, padding=0):
        super().__init__()
        # conv_transpose1d weight layout: (in_channels, out_channels / groups, kernel_size)
        self.weight = weight
        self.stride = stride
        self.padding = padding

    def forward(self, x):
        return torch.nn.functional.conv_transpose1d(
            x, self.weight, stride=self.stride, padding=self.padding
        )

# 4098 input channels -> 1 output channel, kernel size 4096, stride 1024
convtest = ConvTest(weight=torch.randn(4098, 1, 4096), stride=1024)

# (batch, in_channels, length)
dummy_input = torch.randn(8, 4098, 435)

torch.onnx.export(
    convtest,
    (dummy_input,),
    "convtest.onnx",
    input_names=["input"],
    output_names=["output"],
    opset_version=20,
    dynamo=True,
    do_constant_folding=True,
    keep_initializers_as_inputs=True,
    # report=True,
    external_data=None,
    # verify=True
)
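When comparing execution providers it can help to have an EP-independent reference for what the operator computes. One way to view a transposed convolution (assuming groups = 1 and dilation = 1, as in this model) is as a scatter-add: each input sample x[n, ci, i] adds x[n, ci, i] * w[ci, co, k] to output position i * stride + k - padding. The pure-Python sketch below implements only that formulation; the function name `conv_transpose1d_ref` is hypothetical, and it is far too slow for the full-size problem, so it is meant for small shapes.

```python
def conv_transpose1d_ref(x, w, stride=1, padding=0):
    """Naive ConvTranspose1d (groups=1, dilation=1) written as a scatter-add.

    x: nested lists of shape [N][C_in][L_in]
    w: nested lists of shape [C_in][C_out][K]
    returns nested lists of shape [N][C_out][L_out]
    """
    n_batch, c_in, l_in = len(x), len(x[0]), len(x[0][0])
    c_out, k_size = len(w[0]), len(w[0][0])
    # Full (unpadded) length minus `padding` removed from each end.
    l_out = (l_in - 1) * stride + k_size - 2 * padding

    out = [[[0.0] * l_out for _ in range(c_out)] for _ in range(n_batch)]
    for n in range(n_batch):
        for ci in range(c_in):
            for i in range(l_in):
                xv = x[n][ci][i]
                for co in range(c_out):
                    for k in range(k_size):
                        # Target index in the full output, shifted left by padding.
                        j = i * stride + k - padding
                        if 0 <= j < l_out:
                            out[n][co][j] += xv * w[ci][co][k]
    return out


# Tiny example: one input and output channel, kernel of ones, stride 2.
x = [[[1.0, 2.0]]]
w = [[[1.0, 1.0, 1.0]]]
print(conv_transpose1d_ref(x, w, stride=2))  # [[[1.0, 1.0, 3.0, 2.0, 2.0]]]
```

The two innermost loops run C_out * K times per input element, which is where the cost of a large kernel such as K = 4096 comes from.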

To test in the browser:

// Create a WebGPU session for the exported model.
const session = await ort.InferenceSession.create("/convtest.onnx", {
    executionProviders: ["webgpu"],
    // logSeverityLevel: 0
});

// Collect per-kernel GPU timings reported by the WebGPU EP.
const wgpu_profile = [];

ort.env.webgpu.profiling = {
    mode: "default",
    ondata: (data) => {
        wgpu_profile.push(data);
    }
};

const input_dims = [8, 4098, 435];
const size = 8 * 4098 * 435;

// Zero-filled inputs; the values should not affect the timing.
const no_chunks = 1;
const chunks = [];

for (let i = 0; i < no_chunks; i++) {
    chunks.push(new Float32Array(size));
}

// With a single run, the measured time likely also includes one-time
// setup work such as shader compilation.
for (let i = 0; i < no_chunks; i++) {
    console.time("onnx step " + i);
    const input = new ort.Tensor("float32", chunks[i], input_dims);
    const output = await session.run({ input });
    console.timeEnd("onnx step " + i);
}

await session.release();

// Sort kernels by duration (fastest first) and print each one;
// dividing by 1000 * 1000 converts nanoseconds to milliseconds.
wgpu_profile.sort((a, b) => (a.endTime - a.startTime) - (b.endTime - b.startTime));

wgpu_profile.forEach((kernel) => {
    console.log(`${kernel.kernelType} (${kernel.kernelName}) took ${(kernel.endTime - kernel.startTime) / 1000 / 1000} ms`);
});
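To see which kernel types dominate the WebGPU time, the collected records can also be summed per `kernelType` instead of listed one by one. A minimal sketch, assuming the `wgpu_profile` array has been exported as JSON (for example by saving the output of `JSON.stringify(wgpu_profile)` to a file, here the hypothetical `profile.json`) and that `startTime`/`endTime` are in nanoseconds, as the division by 1000 * 1000 in the snippet above implies. The function name `total_ms_by_kernel_type` is hypothetical.

```python
import json
import os
from collections import defaultdict


def total_ms_by_kernel_type(records):
    """Sum kernel durations (ns -> ms) per kernelType, largest total first."""
    totals = defaultdict(float)
    for r in records:
        totals[r["kernelType"]] += (r["endTime"] - r["startTime"]) / 1e6
    return sorted(totals.items(), key=lambda kv: kv[1], reverse=True)


# Hypothetical file holding the output of JSON.stringify(wgpu_profile).
PROFILE_PATH = "profile.json"

if os.path.exists(PROFILE_PATH):
    with open(PROFILE_PATH) as f:
        records = json.load(f)
    for kernel_type, ms in total_ms_by_kernel_type(records):
        print(f"{kernel_type}: {ms:.3f} ms")
```

Summing per type makes it easier to tell whether the time is spent in the ConvTranspose kernel itself or in surrounding work such as layout transposes, which an ascending per-kernel list can scatter.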

Urgency

Urgent

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.21.0-dev.20241224-2d05c4bcd9

Execution Provider

'webgpu' (WebGPU), 'wasm'/'cpu' (WebAssembly CPU)

Metadata

Labels

ep:WebGPU (ort-web webgpu provider), platform:web (issues related to ONNX Runtime web; typically submitted using template), stale (issues that have not been addressed in a while; categorized by a bot)
