Skip to content

Commit 325ee30

Browse files
authored
[js/webgpu] Reland the optimization of ConvTranspose (#23858)
This PR fixes the errors in the ConvTranspose optimization and adds tests to ensure the correctness of the implementation.
1 parent: 1872527 · commit: 325ee30

File tree

2 files changed

+199
-19
lines changed

2 files changed

+199
-19
lines changed

js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts

+77 -19
Original file line number | Diff line number | Diff line change
@@ -46,6 +46,11 @@ export const createConvTranspose2DProgramInfo = (
4646
const inputChannelsPerGroup = wShape[2] / group;
4747
const outputChannelsPerGroup = wShape[3];
4848
const aComponents = isChannelsLast ? getMaxComponents(inputChannelsPerGroup) : 1;
49+
const packInputAs4 = isChannelsLast && outputChannelsPerGroup === 1 && inputChannelsPerGroup >= 4;
50+
const inputChannelsPerGroupInt = packInputAs4
51+
? Math.floor(inputChannelsPerGroup / 4) * 4
52+
: Math.floor(inputChannelsPerGroup / aComponents) * aComponents;
53+
const inputChannelsRemainder = inputChannelsPerGroup - inputChannelsPerGroupInt;
4954
const components = isChannelsLast ? getMaxComponents(outputChannelsPerGroup) : 1;
5055
const bComponents = isChannelsLast ? (outputChannelsPerGroup === 1 ? aComponents : components) : 1;
5156
const outputSize = ShapeUtil.size(outputShape) / components;
@@ -78,6 +83,7 @@ export const createConvTranspose2DProgramInfo = (
7883
{ type: DataType.uint32, data: dilations },
7984
{ type: DataType.uint32, data: effectiveFilterDims },
8085
{ type: DataType.int32, data: pads },
86+
{ type: DataType.uint32, data: inputChannelsPerGroupInt },
8187
{ type: DataType.uint32, data: inputChannelsPerGroup },
8288
{ type: DataType.uint32, data: outputChannelsPerGroup },
8389
...createTensorShapeVariables(inputs[0].dims, inputs[1].dims),
@@ -96,6 +102,7 @@ export const createConvTranspose2DProgramInfo = (
96102
{ name: 'dilations', type: 'u32', length: filterDims.length },
97103
{ name: 'effective_filter_dims', type: 'u32', length: effectiveFilterDims.length },
98104
{ name: 'pads', type: 'i32', length: pads.length },
105+
{ name: 'input_channels_per_group_int', type: 'u32' },
99106
{ name: 'input_channels_per_group', type: 'u32' },
100107
{ name: 'output_channels_per_group', type: 'u32' },
101108
];
@@ -114,16 +121,40 @@ export const createConvTranspose2DProgramInfo = (
114121

115122
const calculateResult = (): string => {
116123
let calcStr = '';
117-
if (aComponents === 1) {
118-
calcStr += `
119-
let w_offset = ${w.indicesToOffset(`${w.type.indices}(u32(wRPerm), u32(wCPerm), inputChannel, wOutChannel)`)};
120-
let wValue = ${w.getByOffset(`w_offset / ${bComponents}`)};
121-
dotProd = dotProd + xValue * wValue;`;
124+
if (packInputAs4) {
125+
if (aComponents === 4) {
126+
calcStr += `
127+
let xValue = ${dy.getByOffset('x_offset')};
128+
let wValue = ${w.getByOffset('w_offset')};
129+
dotProd = dotProd + dot(xValue, wValue);
130+
x_offset += 1u;
131+
w_offset += 1u;`;
132+
} else if (aComponents === 2) {
133+
calcStr += `
134+
dotProd = dotProd + dot(vec4<${dataType}>(${dy.getByOffset('x_offset')}, ${dy.getByOffset('x_offset + 1u')}), vec4<${dataType}>(${w.getByOffset('w_offset')}, ${w.getByOffset('w_offset + 1u')}));
135+
x_offset += 2u;
136+
w_offset += 2u;`;
137+
} else if (aComponents === 1) {
138+
calcStr += `
139+
dotProd = dotProd + dot(vec4<${dataType}>(${dy.getByOffset('x_offset')}, ${dy.getByOffset('x_offset + 1u')}, ${dy.getByOffset('x_offset + 2u')}, ${dy.getByOffset('x_offset + 3u')}), vec4<${dataType}>(${w.getByOffset('w_offset')}, ${w.getByOffset('w_offset + 1u')}, ${w.getByOffset('w_offset + 2u')}, ${w.getByOffset('w_offset + 3u')}));
140+
x_offset += 4u;
141+
w_offset += 4u;`;
142+
}
122143
} else {
123-
if (outputChannelsPerGroup === 1) {
144+
calcStr += `
145+
let xValue = ${
146+
isChannelsLast
147+
? dy.getByOffset(
148+
`${dy.indicesToOffset(`${dy.type.indices}(batch, idyR, idyC, inputChannel)`)} / ${aComponents}`,
149+
)
150+
: dy.get('batch', 'inputChannel', 'idyR', 'idyC')
151+
};
152+
`;
153+
if (aComponents === 1) {
124154
calcStr += `
125-
let wValue = ${w.getByOffset(`${w.indicesToOffset(`${w.type.indices}(u32(wRPerm), u32(wCPerm), inputChannel, wOutChannel)`)} / ${bComponents}`)};
126-
dotProd = dotProd + dot(xValue, wValue);`;
155+
let w_offset = ${w.indicesToOffset(`${w.type.indices}(u32(wRPerm), u32(wCPerm), inputChannel, wOutChannel)`)};
156+
let wValue = ${w.getByOffset(`w_offset / ${bComponents}`)};
157+
dotProd = dotProd + xValue * wValue;`;
127158
} else {
128159
for (let c = 0; c < aComponents; c++) {
129160
calcStr += `
@@ -134,6 +165,32 @@ export const createConvTranspose2DProgramInfo = (
134165
}
135166
return calcStr;
136167
};
168+
const calculateRemainder = (): string => {
169+
if (inputChannelsRemainder === 0) {
170+
return '';
171+
}
172+
if (!packInputAs4) {
173+
throw new Error(`packInputAs4 ${packInputAs4} is not true.`);
174+
}
175+
let calcStr = '';
176+
if (aComponents === 1) {
177+
calcStr += 'dotProd = dotProd';
178+
for (let i = 0; i < inputChannelsRemainder; i++) {
179+
calcStr += `
180+
+ ${dy.getByOffset(`x_offset + ${i}`)} * ${w.getByOffset(`w_offset + ${i}`)}`;
181+
}
182+
calcStr += ';';
183+
} else if (aComponents === 2) {
184+
if (inputChannelsRemainder !== 2) {
185+
throw new Error(`Invalid inputChannelsRemainder ${inputChannelsRemainder}.`);
186+
}
187+
calcStr += `
188+
let xValue = ${dy.getByOffset('x_offset')};
189+
let wValue = ${w.getByOffset('w_offset')};
190+
dotProd = dotProd + dot(xValue, wValue);`;
191+
}
192+
return calcStr;
193+
};
137194
const codeSnippet = `
138195
let outputIndices = ${output.offsetToIndices(`global_idx * ${components}`)};
139196
let batch = ${output.indicesGet('outputIndices', 0)};
@@ -169,7 +226,6 @@ export const createConvTranspose2DProgramInfo = (
169226
// Minimum wC >= 0 that satisfies (dyCCorner + wC) % (uniforms.strides.y) == 0
170227
wC = u32(((dyCCorner + i32(uniforms.strides.y) - 1) / i32(uniforms.strides.y)) * i32(uniforms.strides.y) - dyCCorner);
171228
}
172-
173229
for (; wC < uniforms.effective_filter_dims.y; wC = wC + 1) {
174230
if (wC % uniforms.dilations.y != 0) {
175231
continue;
@@ -182,17 +238,19 @@ export const createConvTranspose2DProgramInfo = (
182238
}
183239
let idyC: u32 = u32(dyC);
184240
var inputChannel = groupId * uniforms.input_channels_per_group;
185-
for (var d2: u32 = 0; d2 < uniforms.input_channels_per_group; d2 = d2 + ${aComponents}) {
186-
let xValue = ${
187-
isChannelsLast
188-
? dy.getByOffset(
189-
`${dy.indicesToOffset(`${dy.type.indices}(batch, idyR, idyC, inputChannel)`)} / ${aComponents}`,
190-
)
191-
: dy.get('batch', 'inputChannel', 'idyR', 'idyC')
192-
};
241+
${
242+
packInputAs4
243+
? `
244+
var x_offset = ${dy.indicesToOffset(`${dy.type.indices}(batch, idyR, idyC, inputChannel)`)} / ${aComponents};
245+
var w_offset = ${w.indicesToOffset(`${w.type.indices}(wRPerm, wCPerm, inputChannel, wOutChannel)`)} / ${bComponents};
246+
`
247+
: ''
248+
}
249+
for (var d2: u32 = 0; d2 < uniforms.input_channels_per_group_int; d2 = d2 + ${packInputAs4 ? 4 : aComponents}) {
193250
${calculateResult()}
194-
inputChannel = inputChannel + ${aComponents};
251+
inputChannel = inputChannel + ${packInputAs4 ? 4 : aComponents};
195252
}
253+
${calculateRemainder()}
196254
wC = wC + uniforms.strides.y - 1;
197255
}
198256
wR = wR + uniforms.strides[0] - 1;
@@ -211,7 +269,7 @@ export const createConvTranspose2DProgramInfo = (
211269
return {
212270
name: 'ConvTranspose2D',
213271
shaderCache: {
214-
hint: `${attributes.cacheKey};${aComponents}${bComponents}${components}${outputChannelsPerGroup === 1}`,
272+
hint: `${attributes.cacheKey};${aComponents}${bComponents}${components}${packInputAs4}${inputChannelsRemainder}`,
215273
inputDependencies,
216274
},
217275
getRunData: () => ({

js/web/test/data/ops/conv-transpose.jsonc

+122
Original file line number | Diff line number | Diff line change
@@ -348,6 +348,128 @@
348348
}
349349
]
350350
},
351+
{
352+
"name": "ConvTranspose NHWC- group - A",
353+
"operator": "ConvTranspose",
354+
"inputShapeDefinitions": "rankOnly",
355+
"opset": { "domain": "", "version": 17 },
356+
"attributes": [
357+
{ "name": "kernel_shape", "data": [1, 1], "type": "ints" },
358+
{ "name": "group", "data": 2, "type": "int" }
359+
],
360+
"cases": [
361+
{
362+
"name": "T[0]",
363+
"inputs": [
364+
{
365+
"data": [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 18.0, 20.0, 22.0, 24.0, 26.0, 28.0, 30.0, 32.0, 34.0],
366+
"dims": [1, 2, 3, 3],
367+
"type": "float32"
368+
},
369+
{
370+
"data": [1.0, 2.0],
371+
"dims": [2, 1, 1, 1],
372+
"type": "float32"
373+
}
374+
],
375+
"outputs": [
376+
{
377+
"data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 36, 40, 44, 48, 52, 56, 60, 64, 68],
378+
"dims": [1, 2, 3, 3],
379+
"type": "float32"
380+
}
381+
]
382+
}
383+
]
384+
},
385+
{
386+
"name": "ConvTranspose NHWC- group - B",
387+
"operator": "ConvTranspose",
388+
"inputShapeDefinitions": "rankOnly",
389+
"opset": { "domain": "", "version": 17 },
390+
"attributes": [
391+
{ "name": "kernel_shape", "data": [2, 2], "type": "ints" },
392+
{ "name": "group", "data": 3, "type": "int" }
393+
],
394+
"cases": [
395+
{
396+
"name": "T[0]",
397+
"inputs": [
398+
{
399+
"data": [
400+
0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0,
401+
19.0, 20.0, 21.0, 22.0, 23.0, 0, 0, 0
402+
],
403+
"dims": [1, 3, 3, 3],
404+
"type": "float32"
405+
},
406+
{
407+
"data": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0],
408+
"dims": [3, 1, 2, 2],
409+
"type": "float32"
410+
},
411+
{
412+
"data": [0.125, 0.25, 0.375],
413+
"dims": [3],
414+
"type": "float32"
415+
}
416+
],
417+
"outputs": [
418+
{
419+
"data": [
420+
0.125, 1.125, 4.125, 4.125, 3.125, 13.125, 23.125, 18.125, 15.125, 43.125, 53.125, 36.125, 18.125, 45.125,
421+
52.125, 32.125, 45.25, 104.25, 115.25, 66.25, 123.25, 279.25, 305.25, 172.25, 159.25, 357.25, 383.25,
422+
214.25, 105.25, 232.25, 247.25, 136.25, 162.375, 351.375, 370.375, 200.375, 387.375, 833.375, 875.375,
423+
470.375, 231.375, 494.375, 517.375, 276.375, 0.375, 0.375, 0.375, 0.375
424+
],
425+
"dims": [1, 3, 4, 4],
426+
"type": "float32"
427+
}
428+
]
429+
}
430+
]
431+
},
432+
{
433+
"name": "ConvTranspose NHWC- group - C",
434+
"operator": "ConvTranspose",
435+
"inputShapeDefinitions": "rankOnly",
436+
"opset": { "domain": "", "version": 17 },
437+
"attributes": [
438+
{ "name": "kernel_shape", "data": [2, 2], "type": "ints" },
439+
{ "name": "group", "data": 3, "type": "int" }
440+
],
441+
"cases": [
442+
{
443+
"name": "T[0]",
444+
"inputs": [
445+
{
446+
"data": [
447+
0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0,
448+
19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0
449+
],
450+
"dims": [1, 3, 3, 4],
451+
"type": "float32"
452+
},
453+
{
454+
"data": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0],
455+
"dims": [3, 1, 2, 2],
456+
"type": "float32"
457+
}
458+
],
459+
"outputs": [
460+
{
461+
"data": [
462+
0, 1, 4, 7, 6, 4, 16, 26, 36, 26, 20, 56, 66, 76, 50, 24, 59, 66, 73, 44, 60, 137, 148, 159, 90, 164, 368,
463+
394, 420, 234, 212, 472, 498, 524, 290, 140, 307, 322, 337, 184, 216, 465, 484, 503, 270, 516, 1104, 1146,
464+
1188, 634, 596, 1272, 1314, 1356, 722, 352, 747, 770, 793, 420
465+
],
466+
"dims": [1, 3, 4, 5],
467+
"type": "float32"
468+
}
469+
]
470+
}
471+
]
472+
},
351473
{
352474
"name": "ConvTranspose with bias addition C",
353475
"operator": "ConvTranspose",

0 commit comments

Comments
 (0)