Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -545,15 +545,12 @@ func.func @dynamic_iota_broadcast_dim0_i64(%arg0 : tensor<2xi64>) -> tensor<5x?x
func.return %0 : tensor<5x?xi32>
}

// Index-typed shapes are skipped (run shape-legalize-to-stablehlo first).
// CHECK-LABEL: @dynamic_iota_broadcast_dim1_index
func.func @dynamic_iota_broadcast_dim1_index(%arg0 : tensor<2xindex>) -> tensor<5x?xi32> {
// CHECK-NEXT: [[CAST:%.+]] = arith.index_cast %arg0 : tensor<2xindex> to tensor<2xi64>
// CHECK-NEXT: [[SLICE:%.+]] = stablehlo.slice [[CAST]] [1:2] : (tensor<2xi64>) -> tensor<1xi64>
// CHECK-NEXT: [[IOTA:%.+]] = stablehlo.dynamic_iota [[SLICE]], dim = 0 : (tensor<1xi64>) -> tensor<?xi32>
// CHECK-NEXT: [[BROADCAST:%.+]] = stablehlo.dynamic_broadcast_in_dim [[IOTA]], %arg0, dims = [1] : (tensor<?xi32>, tensor<2xindex>) -> tensor<5x?xi32>
// CHECK-NEXT: stablehlo.dynamic_iota %arg0
%0 = "stablehlo.dynamic_iota"(%arg0) <{iota_dimension = 1 : i64}> : (tensor<2xindex>) -> tensor<5x?xi32>

// CHECK: return [[BROADCAST]]
// CHECK-NEXT: return
func.return %0 : tensor<5x?xi32>
}

Expand Down Expand Up @@ -1920,6 +1917,72 @@ func.func @slice_2D_noop(%arg0: tensor<2x2xi64>) -> tensor<2x2xi64> {

// -----

/////////
// ScatterOp

// When the scatter indices tensor has a zero-sized batch dimension there are
// no scatter points at all, so the op must fold away and return the input
// operand unchanged (exercises the ScatterOpEmptyIndices pattern).
// CHECK-LABEL: @scatter_empty_indices
func.func @scatter_empty_indices(%input: tensor<3x4xf32>, %updates: tensor<0x4xf32>) -> tensor<3x4xf32> {
  %indices = stablehlo.constant dense<> : tensor<0xi32>
  // CHECK-NOT: stablehlo.scatter
  // CHECK: return %arg0
  %0 = "stablehlo.scatter"(%input, %indices, %updates) ({
    ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
      stablehlo.return %arg1 : tensor<f32>
  }) {
    scatter_dimension_numbers = #stablehlo.scatter<
      update_window_dims = [1],
      inserted_window_dims = [],
      scatter_dims_to_operand_dims = [0],
      index_vector_dim = 1
    >,
    indices_are_sorted = true,
    unique_indices = true
  } : (tensor<3x4xf32>, tensor<0xi32>, tensor<0x4xf32>) -> tensor<3x4xf32>
  func.return %0 : tensor<3x4xf32>
}

// When every updates operand has zero elements (here tensor<3x0xf32>) the
// scatter writes nothing, so the op must fold to its input operand
// (exercises the ScatterOpZeroExtentUpdates pattern).
// CHECK-LABEL: @scatter_zero_extent_updates
func.func @scatter_zero_extent_updates(%input: tensor<3x4xf32>, %indices: tensor<3xi32>) -> tensor<3x4xf32> {
  %updates = stablehlo.constant dense<> : tensor<3x0xf32>
  // CHECK-NOT: stablehlo.scatter
  // CHECK: return %arg0
  %0 = "stablehlo.scatter"(%input, %indices, %updates) ({
    ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
      stablehlo.return %arg1 : tensor<f32>
  }) {
    scatter_dimension_numbers = #stablehlo.scatter<
      update_window_dims = [1],
      inserted_window_dims = [],
      scatter_dims_to_operand_dims = [0],
      index_vector_dim = 1
    >,
    indices_are_sorted = true,
    unique_indices = true
  } : (tensor<3x4xf32>, tensor<3xi32>, tensor<3x0xf32>) -> tensor<3x4xf32>
  func.return %0 : tensor<3x4xf32>
}

// Negative test: with non-empty indices and non-empty updates neither empty
// pattern may fire, so the scatter op must survive canonicalization.
// CHECK-LABEL: @scatter_nonempty_not_simplified
func.func @scatter_nonempty_not_simplified(%input: tensor<3x4xf32>, %indices: tensor<2xi32>, %updates: tensor<2x4xf32>) -> tensor<3x4xf32> {
  // CHECK: stablehlo.scatter
  %0 = "stablehlo.scatter"(%input, %indices, %updates) ({
    ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
      stablehlo.return %arg1 : tensor<f32>
  }) {
    scatter_dimension_numbers = #stablehlo.scatter<
      update_window_dims = [1],
      inserted_window_dims = [],
      scatter_dims_to_operand_dims = [0],
      index_vector_dim = 1
    >,
    indices_are_sorted = true,
    unique_indices = true
  } : (tensor<3x4xf32>, tensor<2xi32>, tensor<2x4xf32>) -> tensor<3x4xf32>
  func.return %0 : tensor<3x4xf32>
}

// -----

/////////
// SortOp

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
Expand Down Expand Up @@ -506,12 +505,16 @@ struct DynamicIotaOpToBroadcast

Value iotaShape = iota.getOutputShape();
auto iotaShapeType = cast<ShapedType>(iotaShape.getType());
if (iotaShapeType.getElementType().isIndex())
return rewriter.notifyMatchFailure(
iota, "index-typed shapes not supported; run "
"shape-legalize-to-stablehlo first");

auto iotaShapeI64Type =
RankedTensorType::get(iotaShapeType.getShape(), rewriter.getI64Type());
Value iotaShapeI64;
if (iotaShapeType.getElementType().isIndex()) {
iotaShapeI64 = arith::IndexCastOp::create(rewriter, iotaLoc,
iotaShapeI64Type, iotaShape);
if (iotaShapeType.getElementType().isInteger(64)) {
iotaShapeI64 = iotaShape;
} else {
iotaShapeI64 = stablehlo::ConvertOp::create(rewriter, iotaLoc,
iotaShapeI64Type, iotaShape);
Expand Down Expand Up @@ -1267,6 +1270,55 @@ struct SliceOpConcatSimplify : public SimplifyOpRewritePattern<SliceOp> {
}
};

//////////////////////////////////
// ScatterOp
/////////////////////////////////

// Pattern: scatter(inputs, indices, updates) -> inputs
// [when scatter_indices has 0 scatter points]
// Pattern: scatter(inputs, indices, updates) -> inputs
// [when scatter_indices has 0 scatter points]
//
// If any statically-known batch dimension of the indices operand (i.e. any
// dimension other than index_vector_dim) has extent zero, the scatter touches
// no elements and the op is equivalent to its inputs.
struct ScatterOpEmptyIndices : public SimplifyOpRewritePattern<ScatterOp> {
  using SimplifyOpRewritePattern::SimplifyOpRewritePattern;

  LogicalResult matchAndRewrite(ScatterOp op,
                                PatternRewriter& rewriter) const override {
    auto indicesType = cast<ShapedType>(op.getScatterIndices().getType());
    // Dynamic extents could still be nonzero at runtime; bail out.
    if (!indicesType.hasStaticShape()) return failure();

    int64_t indexVectorDim =
        op.getScatterDimensionNumbers().getIndexVectorDim();

    for (int64_t dim = 0, rank = indicesType.getRank(); dim < rank; ++dim) {
      if (dim == indexVectorDim) continue;
      if (indicesType.getDimSize(dim) != 0) continue;
      // Zero scatter points: the results are exactly the inputs.
      rewriter.replaceOp(op, op.getInputs());
      return success();
    }
    return failure();
  }
};

// Pattern: scatter(inputs, indices, updates) -> inputs
// [when all updates have zero extent]
// Pattern: scatter(inputs, indices, updates) -> inputs
// [when all updates have zero extent]
//
// If every updates operand is statically known to contain zero elements,
// nothing is ever written into the inputs, so the op folds to them.
struct ScatterOpZeroExtentUpdates : public SimplifyOpRewritePattern<ScatterOp> {
  using SimplifyOpRewritePattern::SimplifyOpRewritePattern;

  LogicalResult matchAndRewrite(ScatterOp op,
                                PatternRewriter& rewriter) const override {
    for (Value update : op.getUpdates()) {
      auto updateType = cast<ShapedType>(update.getType());
      // A dynamically-shaped or non-empty update may modify the inputs.
      if (!updateType.hasStaticShape() || updateType.getNumElements() != 0)
        return failure();
    }

    rewriter.replaceOp(op, op.getInputs());
    return success();
  }
};

//////////////////////////////////
// SortOp
/////////////////////////////////
Expand Down Expand Up @@ -1670,6 +1722,7 @@ void populateStablehloCanonicalizationPatterns(
GatherOpCanon, IotaOpBroadcast, PadOpBroadcastEmptyTensor,
RealDynamicSliceOpToDynamicSlice, ReduceOpEmptyCanon,
ReduceOpNoopVariableReturn, ReduceOpUnusedResultCanon, SelectOpCanon,
ScatterOpEmptyIndices, ScatterOpZeroExtentUpdates,
SliceOpConcatSimplify, SortOpDropUnusedArgs, SortOpSetDimension,
TransposeIsReshape, TupleIsRepacking, WhileOpImplicitCapture>(
context, options, benefit);
Expand Down
Loading