Skip to content

Commit 0d53e11

Browse files
Copilotguschmue
andcommitted
Add bias validation to C++ WebGPU ConvTranspose implementation
Co-authored-by: guschmue <22941064+guschmue@users.noreply.github.com>
1 parent 583f08b commit 0d53e11

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

onnxruntime/core/providers/webgpu/nn/conv_transpose.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,19 @@ Status ConvTranspose<is_channels_last>::ComputeInternal(ComputeContext& context)
5757

5858
bool has_bias = context.InputCount() > 2;
5959
const auto* bias = has_bias ? context.Input<Tensor>(2) : nullptr;
60+
61+
// Validate bias shape if provided
62+
if (has_bias) {
63+
const auto& bias_shape = bias->Shape();
64+
if (bias_shape.NumDimensions() != 1) {
65+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid bias: bias must be 1D tensor");
66+
}
67+
if (bias_shape[0] != num_output_channels) {
68+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid bias: bias size (", bias_shape[0],
69+
") must be equal to output channels (", num_output_channels, ")");
70+
}
71+
}
72+
6073
if (input_shape.NumDimensions() == 3 && filter_shape.NumDimensions() == 3) {
6174
// ConvTranspose1D
6275
TensorShapeVector input_shape_vector = input_shape.AsShapeVector();

0 commit comments

Comments
 (0)