Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions stablehlo/dialect/TypeInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1842,6 +1842,10 @@ LogicalResult inferAllToAllOp(
int64_t splitDimension, int64_t concatDimension, int64_t splitCount,
DenseIntElementsAttr replicaGroups,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
if (operands.empty())
return emitOptionalError(location,
"must have at least 1 operand");

// all_to_all_c5, all_to_all_c7, all_to_all_i5
if (failed(verifyReplicaGroups(location, replicaGroups,
/*allGroupsMustHaveSameSize=*/true,
Expand Down Expand Up @@ -3631,6 +3635,10 @@ LogicalResult verifyAllGatherOp(std::optional<Location> location,
DenseIntElementsAttr replicaGroups,
int64_t channelId, bool useGlobalDeviceIds,
ValueRange results) {
if (operands.empty())
return emitOptionalError(location,
"must have at least 1 operand");

// all_gather_i3, all_gather_c2, all_gather_c4
if (failed(verifyReplicaGroups(location, replicaGroups,
/*allGroupsMustHaveSameSize=*/true,
Expand Down Expand Up @@ -3696,6 +3704,10 @@ LogicalResult verifyAllReduceOp(std::optional<Location> location,
DenseIntElementsAttr replicaGroups,
int64_t channelId, bool useGlobalDeviceIds,
Region& computation) {
if (operands.empty())
return emitOptionalError(location,
"must have at least 1 operand");

// TODO(#498): AllReduceOp does not have rank-2 replicaGroups.
// all_reduce_c1...all_reduce_c3
if (failed(verifyReplicaGroups(location, replicaGroups,
Expand Down
39 changes: 39 additions & 0 deletions stablehlo/tests/verify_collective_ops.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// RUN: stablehlo-opt %s -verify-diagnostics -split-input-file

// -----

func.func @all_reduce_empty_operands() {
// expected-error@+1 {{must have at least 1 operand}}
"stablehlo.all_reduce"() ({
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
%0 = stablehlo.add %arg0, %arg1 : tensor<f32>
stablehlo.return %0 : tensor<f32>
}) {
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
} : () -> ()
func.return
}

// -----

func.func @all_gather_empty_operands() {
// expected-error@+1 {{must have at least 1 operand}}
"stablehlo.all_gather"() {
all_gather_dim = 0 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
} : () -> ()
func.return
}

// -----

func.func @all_to_all_empty_operands() {
// expected-error@+1 {{must have at least 1 operand}}
"stablehlo.all_to_all"() {
split_dimension = 0 : i64,
concat_dimension = 0 : i64,
split_count = 2 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
} : () -> ()
func.return
}