Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ std::optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
std::optional<Value> tosaCastTensorToType(PatternRewriter &rewriter, Value src,
TensorType destType);

// Create a tosa.gather op. Casts i1 inputs to i8 internally if needed.
std::optional<Value> createGatherOp(PatternRewriter &rewriter, Location loc,
RankedTensorType resultType, Value input,
Value indices);

// Creates a TOSA operation and performs shape inference on the individual
// op. This allows shape inference during the framework to TOSA lowering.
template <typename TosaOp, typename... Args>
Expand Down Expand Up @@ -119,6 +124,9 @@ FailureOr<Value> getZeroPointValue(PatternRewriter &rewriter, Operation *op,
// Check if a shaped type has any dimension with size 0.
bool typeHasZeroDim(ShapedType type);

// Check if a type is i1 or a shaped type with i1 element type.
bool isI1Type(Type type);

// Compute scale/offset/border parameters for TOSA resize on one dimension.
void computeResizeParams(int inputSize, int outputSize, bool alignCorners,
tosa::ResizeMode mode, int &scaleN, int &scaleD,
Expand Down
25 changes: 15 additions & 10 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4545,14 +4545,17 @@ LogicalResult ConvertAtenOp<AtenEmbeddingOp>::matchAndRewriteImpl(
.value();

SmallVector<int64_t> intermediateOutShape = {1, numIndices, weightShape[1]};
auto gatherOp = tosa::GatherOp::create(
rewriter, op->getLoc(),
RankedTensorType::get(makeShapeLLVMCompatible(intermediateOutShape),
weightType.getElementType()),
reshapedWeight, castIndices);
auto gatherElemTy = weightType.getElementType();
auto gatherTy = RankedTensorType::get(
makeShapeLLVMCompatible(intermediateOutShape), gatherElemTy);
auto gatherResult = tosa::createGatherOp(rewriter, op->getLoc(), gatherTy,
reshapedWeight, castIndices);
if (!gatherResult)
return rewriter.notifyMatchFailure(
op, "expected ranked tensor input for gather");

rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
op, outType, gatherOp,
op, outType, *gatherResult,
tosa::getTosaConstShape(rewriter, op->getLoc(),
makeShapeTorchCompatible(outType.getShape())));

Expand Down Expand Up @@ -4868,9 +4871,11 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewriteImpl(
// Duplicate the 1-D index vector across the batch dimension so that we can
// use a single tosa.gather to materialize the strided slice.
auto gatherTy = RankedTensorType::get({N, W, C}, elemTy);
Value gathered =
tosa::GatherOp::create(rewriter, loc, gatherTy, reshaped, idxNW)
.getResult();
auto gathered =
tosa::createGatherOp(rewriter, loc, gatherTy, reshaped, idxNW);
if (!gathered)
return rewriter.notifyMatchFailure(
op, "expected ranked tensor input for gather");

SmallVector<int64_t> outShape = inputShape;
outShape[dim] = W;
Expand All @@ -4879,7 +4884,7 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewriteImpl(

// Restore the original rank with the newly strided dimension size.
Value result =
tosa::ReshapeOp::create(rewriter, loc, convertedResultTy, gathered,
tosa::ReshapeOp::create(rewriter, loc, convertedResultTy, *gathered,
tosa::getTosaConstShape(rewriter, loc, outShape))
.getResult();

Expand Down
18 changes: 8 additions & 10 deletions lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -422,18 +422,16 @@ std::optional<Value> convertGatherNdOp(PatternRewriter &rewriter, Operation *op,
// Now the gather op itself
// %9 = "tosa.gather"(%2, %7) : (tensor<1x12x1xf32>, tensor<1x8xi32>) ->
// tensor<1x8x1xf32>
auto tosaGatherOp = tosa::CreateOpAndInfer<tosa::GatherOp>(
rewriter, op->getLoc(),
GetTypeFromTensorShape(tosaGatherResultShape,
resultType.getElementType()),
tosaValuesReshapeOp.getResult(), tosaIndicesReshapeOp.getResult());
auto gatherTy = GetTypeFromTensorShape(tosaGatherResultShape,
resultType.getElementType());
auto gatherResult = tosa::createGatherOp(rewriter, op->getLoc(), gatherTy,
tosaValuesReshapeOp.getResult(),
tosaIndicesReshapeOp.getResult());
if (!gatherResult)
return std::nullopt;

// Finally, reshape back to the original output shape of [Indices,
// ParamChannels]. %10 = "tosa.reshape"(%9) {new_shape = [1, 4, 2]} :
// (tensor<1x8x1xf32>) -> tensor<1x4x2xf32> %11 = torch_c.from_builtin_tensor
// %10 : tensor<1x4x2xf32> -> !torch.vtensor<[1,4,2],f32>
return tosa::CreateOpAndInfer<tosa::ReshapeOp>(
rewriter, op->getLoc(), resultType, tosaGatherOp.getResult(),
rewriter, op->getLoc(), resultType, *gatherResult,
tosa::getTosaConstShape(rewriter, op->getLoc(),
resultType.getShape()))
.getResult();
Expand Down
30 changes: 30 additions & 0 deletions lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,28 @@ std::optional<Value> tosaCastTensorToType(PatternRewriter &rewriter, Value src,
return tosa::CastOp::create(rewriter, op->getLoc(), castedSrcType, src);
}

// Create a tosa.gather op. Casts i1 inputs to i8 internally if needed, since
// TOSA gather does not operate on i1 element types directly.
//
// Returns the gather result Value, or std::nullopt if the input is not a
// ranked tensor or an internal cast cannot be materialized.
std::optional<Value> createGatherOp(PatternRewriter &rewriter, Location loc,
                                    RankedTensorType resultType, Value input,
                                    Value indices) {
  if (tosa::isI1Type(resultType)) {
    // Round-trip through i8: cast input i1 -> i8, gather, cast back to i1.
    auto i8Ty = rewriter.getI8Type();
    auto inputTy = dyn_cast<RankedTensorType>(input.getType());
    if (!inputTy)
      return std::nullopt;
    auto inputI8 =
        tosa::tosaCastTensorToType(rewriter, input, inputTy.clone(i8Ty));
    // Propagate cast failure instead of asserting on an empty optional.
    if (!inputI8)
      return std::nullopt;
    auto gatheredI8 = tosa::GatherOp::create(
        rewriter, loc, resultType.clone(i8Ty), *inputI8, indices);
    return tosa::tosaCastTensorToType(rewriter, gatheredI8, resultType);
  }

  return tosa::GatherOp::create(rewriter, loc, resultType, input, indices)
      .getResult();
}

// Template instantiation
template std::optional<Value>
getConstTensor<bool>(PatternRewriter &, Operation *, ArrayRef<bool> vec,
Expand Down Expand Up @@ -586,6 +608,14 @@ bool typeHasZeroDim(ShapedType type) {
return llvm::any_of(outShape, [](int64_t dim) { return dim == 0; });
}

// Returns true for an i1 scalar type, or a shaped type whose element type is
// i1 (i.e. a 1-bit integer either way).
bool isI1Type(Type type) {
  Type elemTy = type;
  if (auto shapedTy = dyn_cast<ShapedType>(type))
    elemTy = shapedTy.getElementType();
  auto intTy = dyn_cast<IntegerType>(elemTy);
  return intTy && intTy.getWidth() == 1;
}

void computeResizeParams(int inputSize, int outputSize, bool alignCorners,
tosa::ResizeMode mode, int &scaleN, int &scaleD,
int &offset, int &border) {
Expand Down
27 changes: 27 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1792,6 +1792,33 @@ def GatherModule_basic(module, tu: TestUtils):
# ==============================================================================


class GatherBoolModule(torch.nn.Module):
    """E2E test module: torch.gather over a boolean tensor along dim 2."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1, -1], torch.bool, True),
            ([-1, -1, -1], torch.int64, True),
        ]
    )
    def forward(self, tensor, indices):
        # Gather along the last dimension of the bool input.
        return tensor.gather(2, indices)


@register_test_case(module_factory=lambda: GatherBoolModule())
def GatherBoolModule_basic(module, tu: TestUtils):
    # Random 0/1 tensor cast to bool, plus fixed in-range indices for dim 2.
    bool_input = tu.randint(2, 3, 4, high=2).to(torch.bool)
    index_tensor = torch.tensor([[[1, 2, 3], [1, 2, 3]]])
    module.forward(bool_input, index_tensor)


# ==============================================================================


class GatherNegativeDimModule(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down
62 changes: 62 additions & 0 deletions test/Conversion/TorchToTosa/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1349,6 +1349,47 @@ func.func @torch.aten.gather(%arg0: !torch.vtensor<[1,4,3],f32>, %arg1: !torch.v
return %0 : !torch.vtensor<[1,4,2],f32>
}

// -----
// Checks that a boolean (i1) aten.gather lowers through tosa.gather by
// casting the i1 source to i8 before the gather and casting back to i1 after,
// since TOSA gather does not support i1 element types directly.
// CHECK-LABEL: func.func @torch.aten.gather$bool(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,4,3],i1>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1,4,2],si64>) -> !torch.vtensor<[1,4,2],i1> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1,4,2],si64> -> tensor<1x4x2xi64>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,4,3],i1> -> tensor<1x4x3xi1>
// CHECK: %[[VAL_4:.*]] = torch.constant.int -1
// CHECK: %[[VAL_5:.*]] = torch.constant.bool false
// CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_2]] : (tensor<1x4x2xi64>) -> tensor<1x4x2xi32>
// CHECK: %[[VAL_7:.*]] = tosa.const_shape {values = dense<[1, 4, 2, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_6]], %[[VAL_7]] : (tensor<1x4x2xi32>, !tosa.shape<4>) -> tensor<1x4x2x1xi32>
// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{values = dense<0> : tensor<1x4x2x1xi32>}> : () -> tensor<1x4x2x1xi32>
// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{values = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]]]]> : tensor<1x4x2x1xi32>}> : () -> tensor<1x4x2x1xi32>
// CHECK: %[[VAL_11:.*]] = tosa.concat %[[VAL_9]], %[[VAL_10]], %[[VAL_8]] {axis = 3 : i32} : (tensor<1x4x2x1xi32>, tensor<1x4x2x1xi32>, tensor<1x4x2x1xi32>) -> tensor<1x4x2x3xi32>
// CHECK: %[[VAL_12:.*]] = tosa.const_shape {values = dense<[1, 12, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_3]], %[[VAL_12]] : (tensor<1x4x3xi1>, !tosa.shape<3>) -> tensor<1x12x1xi1>
// CHECK: %[[VAL_14:.*]] = tosa.const_shape {values = dense<[8, 3]> : tensor<2xindex>} : () -> !tosa.shape<2>
// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_11]], %[[VAL_14]] : (tensor<1x4x2x3xi32>, !tosa.shape<2>) -> tensor<8x3xi32>
// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{values = dense<[12, 3, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
// CHECK: %[[VAL_17:.*]] = tosa.const_shape {values = dense<[1, 3]> : tensor<2xindex>} : () -> !tosa.shape<2>
// CHECK: %[[VAL_18:.*]] = tosa.reshape %[[VAL_16]], %[[VAL_17]] : (tensor<3xi32>, !tosa.shape<2>) -> tensor<1x3xi32>
// CHECK: %[[VAL_19:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
// CHECK: %[[VAL_20:.*]] = tosa.mul %[[VAL_15]], %[[VAL_18]], %[[VAL_19]] : (tensor<8x3xi32>, tensor<1x3xi32>, tensor<1xi8>) -> tensor<8x3xi32>
// CHECK: %[[VAL_21:.*]] = tosa.reduce_sum %[[VAL_20]] {axis = 1 : i32} : (tensor<8x3xi32>) -> tensor<8x1xi32>
// CHECK: %[[VAL_22:.*]] = tosa.const_shape {values = dense<[1, 8]> : tensor<2xindex>} : () -> !tosa.shape<2>
// CHECK: %[[VAL_23:.*]] = tosa.reshape %[[VAL_21]], %[[VAL_22]] : (tensor<8x1xi32>, !tosa.shape<2>) -> tensor<1x8xi32>
// CHECK: %[[VAL_24:.*]] = tosa.cast %[[VAL_13]] : (tensor<1x12x1xi1>) -> tensor<1x12x1xi8>
// CHECK: %[[VAL_25:.*]] = tosa.gather %[[VAL_24]], %[[VAL_23]] : (tensor<1x12x1xi8>, tensor<1x8xi32>) -> tensor<1x8x1xi8>
// CHECK: %[[VAL_26:.*]] = tosa.cast %[[VAL_25]] : (tensor<1x8x1xi8>) -> tensor<1x8x1xi1>
// CHECK: %[[VAL_27:.*]] = tosa.const_shape {values = dense<[1, 4, 2]> : tensor<3xindex>} : () -> !tosa.shape<3>
// CHECK: %[[VAL_28:.*]] = tosa.reshape %[[VAL_26]], %[[VAL_27]] : (tensor<1x8x1xi1>, !tosa.shape<3>) -> tensor<1x4x2xi1>
// CHECK: %[[VAL_29:.*]] = torch_c.from_builtin_tensor %[[VAL_28]] : tensor<1x4x2xi1> -> !torch.vtensor<[1,4,2],i1>
// CHECK: return %[[VAL_29]] : !torch.vtensor<[1,4,2],i1>
// CHECK: }
func.func @torch.aten.gather$bool(%arg0: !torch.vtensor<[1,4,3],i1>, %arg1: !torch.vtensor<[1,4,2],si64>) -> !torch.vtensor<[1,4,2],i1> {
%int-1 = torch.constant.int -1
%false = torch.constant.bool false
%0 = torch.aten.gather %arg0, %int-1, %arg1, %false : !torch.vtensor<[1,4,3],i1>, !torch.int, !torch.vtensor<[1,4,2],si64>, !torch.bool -> !torch.vtensor<[1,4,2],i1>
return %0 : !torch.vtensor<[1,4,2],i1>
}

// -----
// CHECK-LABEL: func.func @torch.aten.add$int(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,2],si32>,
Expand Down Expand Up @@ -1422,6 +1463,27 @@ func.func @torch.aten.slice.negative_start(%arg0: !torch.vtensor<[4,65,256],f32>
return %0 : !torch.vtensor<[4,16,256],f32>
}

// -----
// Checks that a strided slice of a boolean (i1) tensor lowers through
// tosa.gather with an i1 -> i8 cast before the gather and an i8 -> i1 cast
// after, since TOSA gather does not support i1 element types directly.
// CHECK-LABEL: func.func @torch.aten.slice.bool_strided(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,64,1],i1>) -> !torch.vtensor<[1,32,1],i1> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,64,1],i1> -> tensor<1x64x1xi1>
// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor<1x64x1xi1>) -> tensor<1x64x1xi8>
// CHECK: %[[VAL_3:.*]] = tosa.gather %[[VAL_2]], %{{.*}} : (tensor<1x64x1xi8>, tensor<1x32xi32>) -> tensor<1x32x1xi8>
// CHECK: %[[VAL_4:.*]] = tosa.cast %[[VAL_3]] : (tensor<1x32x1xi8>) -> tensor<1x32x1xi1>
// CHECK: %[[VAL_5:.*]] = tosa.const_shape {values = dense<[1, 32, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_4]], %[[VAL_5]] : (tensor<1x32x1xi1>, !tosa.shape<3>) -> tensor<1x32x1xi1>
// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<1x32x1xi1> -> !torch.vtensor<[1,32,1],i1>
// CHECK: return %[[VAL_7]] : !torch.vtensor<[1,32,1],i1>
// CHECK: }
func.func @torch.aten.slice.bool_strided(%arg0: !torch.vtensor<[1,64,1],i1>) -> !torch.vtensor<[1,32,1],i1> {
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%int64 = torch.constant.int 64
%int2 = torch.constant.int 2
%0 = torch.aten.slice.Tensor %arg0, %int1, %int0, %int64, %int2 : !torch.vtensor<[1,64,1],i1>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,32,1],i1>
return %0 : !torch.vtensor<[1,32,1],i1>
}

// -----
// CHECK-LABEL: func.func @torch.aten.clamp.min_none(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,128,128],si64>) -> !torch.vtensor<[1,1,128,128],si64> {
Expand Down
Loading