Skip to content

Commit ccd416e

Browse files
committed
Skip out-of-range broadcast axes instead of erroring
Broadcasting a dimension past the input's rank is a no-op, so match broadcast_like and skip assigning the bitset when the converted slinky dim falls outside the input, rather than validating and returning an error. This also avoids the out-of-range axes_set index.
1 parent fe9e001 commit ccd416e

2 files changed

Lines changed: 12 additions & 6 deletions

File tree

ynnpack/subgraph/broadcast.cc

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,12 @@ ynn_status ynn_define_broadcast(ynn_subgraph_t subgraph, size_t num_axes,
3232

3333
ynn::axes_set axes_set;
3434
for (size_t i = 0; i < num_axes; ++i) {
35-
YNN_RETURN_IF_ERROR(
36-
validate_axis("broadcast", "input", input.rank(), axes[i]));
37-
axes_set[axis_to_slinky_dim(input.rank(), axes[i])] = true;
35+
const int axis = axis_to_slinky_dim(input.rank(), axes[i]);
36+
if (axis < input.rank()) {
37+
// Dimensions past the input's rank are implicit broadcasts; broadcasting
38+
// a broadcast dimension is a no-op, so skip it.
39+
axes_set[axis] = true;
40+
}
3841
}
3942

4043
ynn_node node;

ynnpack/subgraph/test/errors.cc

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,16 @@ TEST(Errors, broadcast_axis_out_of_bounds) {
5454
SubgraphBuilder subgraph(1);
5555
subgraph.AddInput(ynn_type_fp32, 3, in_id);
5656

57-
// An axis far outside [-rank, rank) maps to a slinky dim past the end of the
58-
// axes_set bitset; broadcast must reject it instead of indexing out of range.
57+
// An axis far outside [-rank, rank) maps to a slinky dim past the input's
58+
// dimensions. Such a dimension is an implicit broadcast, so broadcasting it
59+
// is a no-op: it must be skipped rather than indexing the axes_set bitset out
60+
// of range.
5961
const int32_t axes[] = {-100};
6062
uint32_t output_id = YNN_INVALID_VALUE_ID;
6163
EXPECT_EQ(ynn_define_broadcast(subgraph.GetSubgraph(), /*num_axes=*/1, axes,
6264
in_id, &output_id, /*flags=*/0),
63-
ynn_status_invalid_parameter);
65+
ynn_status_success);
66+
EXPECT_EQ(output_id, in_id);
6467
}
6568

6669
} // namespace ynn

0 commit comments

Comments
 (0)