Skip to content

Commit 4c850b7

Browse files
Copilotguschmue
andcommitted
Merge bias validation conditions and simplify error message
Co-authored-by: guschmue <22941064+guschmue@users.noreply.github.com>
1 parent 7c55b81 commit 4c850b7

File tree

2 files changed

+8
-19
lines changed

2 files changed

+8
-19
lines changed

js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -178,17 +178,11 @@ const validateInputs = (inputs: readonly TensorView[], attributes: ConvTranspose
178178
throw new Error('FILTER_IN_CHANNEL should be equal to DATA_CHANNEL');
179179
}
180180

181+
const featureMaps = inputs[1].dims[1] * attributes.group;
182+
181183
// if bias is provided it should be 1D and the number of elements should be equal to the number of feature maps
182-
if (inputs.length === 3) {
183-
if (inputs[2].dims.length !== 1) {
184-
throw new Error('invalid bias: bias must be 1D tensor');
185-
}
186-
const featureMaps = inputs[1].dims[1] * attributes.group;
187-
if (inputs[2].dims[0] !== featureMaps) {
188-
throw new Error(
189-
`invalid bias: bias size (${inputs[2].dims[0]}) must be equal to output channels (${featureMaps})`,
190-
);
191-
}
184+
if (inputs.length === 3 && (inputs[2].dims.length !== 1 || inputs[2].dims[0] !== featureMaps)) {
185+
throw new Error('invalid bias');
192186
}
193187

194188
const spatialRank = inputs[0].dims.length - 2;

onnxruntime/core/providers/webgpu/nn/conv_transpose.cc

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -57,17 +57,12 @@ Status ConvTranspose<is_channels_last>::ComputeInternal(ComputeContext& context)
5757

5858
bool has_bias = context.InputCount() > 2;
5959
const auto* bias = has_bias ? context.Input<Tensor>(2) : nullptr;
60+
6061
// Validate bias shape if provided
61-
if (has_bias) {
62-
const auto& bias_shape = bias->Shape();
63-
if (bias_shape.NumDimensions() != 1) {
64-
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid bias: bias must be 1D tensor");
65-
}
66-
if (bias_shape[0] != num_output_channels) {
67-
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid bias: bias size (", bias_shape[0],
68-
") must be equal to output channels (", num_output_channels, ")");
69-
}
62+
if (has_bias && (bias->Shape().NumDimensions() != 1 || bias->Shape()[0] != num_output_channels)) {
63+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid bias");
7064
}
65+
7166
if (input_shape.NumDimensions() == 3 && filter_shape.NumDimensions() == 3) {
7267
// ConvTranspose1D
7368
TensorShapeVector input_shape_vector = input_shape.AsShapeVector();

0 commit comments

Comments
 (0)