Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ std::optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
std::optional<Value> tosaCastTensorToType(PatternRewriter &rewriter, Value src,
TensorType destType);

// Create a tosa.gather op. Casts i1 inputs to i8 internally if needed.
Value createGatherOp(PatternRewriter &rewriter, Location loc,
RankedTensorType resultType, Value input, Value indices);

// Creates a TOSA operation and performs shape inference on the individual
// op. This allows shape inference during the framework to TOSA lowering.
template <typename TosaOp, typename... Args>
Expand Down Expand Up @@ -119,6 +123,9 @@ FailureOr<Value> getZeroPointValue(PatternRewriter &rewriter, Operation *op,
// Check if a shaped type has any dimension with size 0.
bool typeHasZeroDim(ShapedType type);

// Check if a type is i1 or a shaped type with i1 element type.
bool isI1Type(Type type);

// Compute scale/offset/border parameters for TOSA resize on one dimension.
void computeResizeParams(int inputSize, int outputSize, bool alignCorners,
tosa::ResizeMode mode, int &scaleN, int &scaleD,
Expand Down
15 changes: 7 additions & 8 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4545,14 +4545,14 @@ LogicalResult ConvertAtenOp<AtenEmbeddingOp>::matchAndRewriteImpl(
.value();

SmallVector<int64_t> intermediateOutShape = {1, numIndices, weightShape[1]};
auto gatherOp = tosa::GatherOp::create(
rewriter, op->getLoc(),
RankedTensorType::get(makeShapeLLVMCompatible(intermediateOutShape),
weightType.getElementType()),
reshapedWeight, castIndices);
auto gatherElemTy = weightType.getElementType();
auto gatherTy = RankedTensorType::get(
makeShapeLLVMCompatible(intermediateOutShape), gatherElemTy);
Value gatherResult = tosa::createGatherOp(rewriter, op->getLoc(), gatherTy,
reshapedWeight, castIndices);

rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
op, outType, gatherOp,
op, outType, gatherResult,
tosa::getTosaConstShape(rewriter, op->getLoc(),
makeShapeTorchCompatible(outType.getShape())));

Expand Down Expand Up @@ -4869,8 +4869,7 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewriteImpl(
// use a single tosa.gather to materialize the strided slice.
auto gatherTy = RankedTensorType::get({N, W, C}, elemTy);
Value gathered =
tosa::GatherOp::create(rewriter, loc, gatherTy, reshaped, idxNW)
.getResult();
tosa::createGatherOp(rewriter, loc, gatherTy, reshaped, idxNW);

SmallVector<int64_t> outShape = inputShape;
outShape[dim] = W;
Expand Down
16 changes: 6 additions & 10 deletions lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -422,18 +422,14 @@ std::optional<Value> convertGatherNdOp(PatternRewriter &rewriter, Operation *op,
// Now the gather op itself
// %9 = "tosa.gather"(%2, %7) : (tensor<1x12x1xf32>, tensor<1x8xi32>) ->
// tensor<1x8x1xf32>
auto tosaGatherOp = tosa::CreateOpAndInfer<tosa::GatherOp>(
rewriter, op->getLoc(),
GetTypeFromTensorShape(tosaGatherResultShape,
resultType.getElementType()),
tosaValuesReshapeOp.getResult(), tosaIndicesReshapeOp.getResult());
auto gatherTy = GetTypeFromTensorShape(tosaGatherResultShape,
resultType.getElementType());
Value gatherResult = tosa::createGatherOp(rewriter, op->getLoc(), gatherTy,
tosaValuesReshapeOp.getResult(),
tosaIndicesReshapeOp.getResult());

// Finally, reshape back to the original output shape of [Indices,
// ParamChannels]. %10 = "tosa.reshape"(%9) {new_shape = [1, 4, 2]} :
// (tensor<1x8x1xf32>) -> tensor<1x4x2xf32> %11 = torch_c.from_builtin_tensor
// %10 : tensor<1x4x2xf32> -> !torch.vtensor<[1,4,2],f32>
return tosa::CreateOpAndInfer<tosa::ReshapeOp>(
rewriter, op->getLoc(), resultType, tosaGatherOp.getResult(),
rewriter, op->getLoc(), resultType, gatherResult,
tosa::getTosaConstShape(rewriter, op->getLoc(),
resultType.getShape()))
.getResult();
Expand Down
27 changes: 27 additions & 0 deletions lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,25 @@ std::optional<Value> tosaCastTensorToType(PatternRewriter &rewriter, Value src,
return tosa::CastOp::create(rewriter, op->getLoc(), castedSrcType, src);
}

// Create a tosa.gather op. Casts i1 inputs to i8 internally if needed.
Value createGatherOp(PatternRewriter &rewriter, Location loc,
                     RankedTensorType resultType, Value input, Value indices) {
  // Common case: element type is directly supported by tosa.gather.
  if (!tosa::isI1Type(resultType))
    return tosa::GatherOp::create(rewriter, loc, resultType, input, indices)
        .getResult();

  // tosa.gather has no i1 support, so route boolean tensors through i8:
  // cast input -> gather in i8 -> cast the result back to i1.
  Type i8Ty = rewriter.getI8Type();
  auto srcTy = cast<RankedTensorType>(input.getType());
  Value srcI8 =
      tosa::tosaCastTensorToType(rewriter, input, srcTy.clone(i8Ty)).value();
  auto gatheredI8 = tosa::GatherOp::create(rewriter, loc,
                                           resultType.clone(i8Ty), srcI8,
                                           indices);
  return tosa::tosaCastTensorToType(rewriter, gatheredI8, resultType).value();
}

// Template instantiation
template std::optional<Value>
getConstTensor<bool>(PatternRewriter &, Operation *, ArrayRef<bool> vec,
Expand Down Expand Up @@ -586,6 +605,14 @@ bool typeHasZeroDim(ShapedType type) {
return llvm::any_of(outShape, [](int64_t dim) { return dim == 0; });
}

// Check if a type is i1 or a shaped type with i1 element type.
bool isI1Type(Type type) {
  // For shaped types (tensor/vector/memref), inspect the element type instead.
  if (auto shapedTy = dyn_cast<ShapedType>(type))
    type = shapedTy.getElementType();
  // Type::isInteger(width) is the canonical MLIR check for an integer type of
  // a given bitwidth; it subsumes the manual dyn_cast<IntegerType> + getWidth
  // comparison.
  return type.isInteger(1);
}

void computeResizeParams(int inputSize, int outputSize, bool alignCorners,
tosa::ResizeMode mode, int &scaleN, int &scaleD,
int &offset, int &border) {
Expand Down
27 changes: 27 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1792,6 +1792,33 @@ def GatherModule_basic(module, tu: TestUtils):
# ==============================================================================


class GatherBoolModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1, -1], torch.bool, True),
            ([-1, -1, -1], torch.int64, True),
        ]
    )
    def forward(self, tensor, indices):
        # Gather along the last dimension of a boolean tensor.
        return tensor.gather(2, indices)


@register_test_case(module_factory=lambda: GatherBoolModule())
def GatherBoolModule_basic(module, tu: TestUtils):
    # Random 0/1 tensor reinterpreted as bool, gathered along dim 2.
    bool_values = tu.randint(2, 3, 4, high=2).to(torch.bool)
    gather_indices = torch.tensor([[[1, 2, 3], [1, 2, 3]]])
    module.forward(bool_values, gather_indices)


# ==============================================================================


class GatherNegativeDimModule(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down
62 changes: 62 additions & 0 deletions test/Conversion/TorchToTosa/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1349,6 +1349,47 @@ func.func @torch.aten.gather(%arg0: !torch.vtensor<[1,4,3],f32>, %arg1: !torch.v
return %0 : !torch.vtensor<[1,4,2],f32>
}

// -----
// CHECK-LABEL: func.func @torch.aten.gather$bool(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,4,3],i1>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1,4,2],si64>) -> !torch.vtensor<[1,4,2],i1> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1,4,2],si64> -> tensor<1x4x2xi64>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,4,3],i1> -> tensor<1x4x3xi1>
// CHECK: %[[VAL_4:.*]] = torch.constant.int -1
// CHECK: %[[VAL_5:.*]] = torch.constant.bool false
// CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_2]] : (tensor<1x4x2xi64>) -> tensor<1x4x2xi32>
// CHECK: %[[VAL_7:.*]] = tosa.const_shape {values = dense<[1, 4, 2, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_6]], %[[VAL_7]] : (tensor<1x4x2xi32>, !tosa.shape<4>) -> tensor<1x4x2x1xi32>
// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{values = dense<0> : tensor<1x4x2x1xi32>}> : () -> tensor<1x4x2x1xi32>
// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{values = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]]]]> : tensor<1x4x2x1xi32>}> : () -> tensor<1x4x2x1xi32>
// CHECK: %[[VAL_11:.*]] = tosa.concat %[[VAL_9]], %[[VAL_10]], %[[VAL_8]] {axis = 3 : i32} : (tensor<1x4x2x1xi32>, tensor<1x4x2x1xi32>, tensor<1x4x2x1xi32>) -> tensor<1x4x2x3xi32>
// CHECK: %[[VAL_12:.*]] = tosa.const_shape {values = dense<[1, 12, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_3]], %[[VAL_12]] : (tensor<1x4x3xi1>, !tosa.shape<3>) -> tensor<1x12x1xi1>
// CHECK: %[[VAL_14:.*]] = tosa.const_shape {values = dense<[8, 3]> : tensor<2xindex>} : () -> !tosa.shape<2>
// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_11]], %[[VAL_14]] : (tensor<1x4x2x3xi32>, !tosa.shape<2>) -> tensor<8x3xi32>
// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{values = dense<[12, 3, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
// CHECK: %[[VAL_17:.*]] = tosa.const_shape {values = dense<[1, 3]> : tensor<2xindex>} : () -> !tosa.shape<2>
// CHECK: %[[VAL_18:.*]] = tosa.reshape %[[VAL_16]], %[[VAL_17]] : (tensor<3xi32>, !tosa.shape<2>) -> tensor<1x3xi32>
// CHECK: %[[VAL_19:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
// CHECK: %[[VAL_20:.*]] = tosa.mul %[[VAL_15]], %[[VAL_18]], %[[VAL_19]] : (tensor<8x3xi32>, tensor<1x3xi32>, tensor<1xi8>) -> tensor<8x3xi32>
// CHECK: %[[VAL_21:.*]] = tosa.reduce_sum %[[VAL_20]] {axis = 1 : i32} : (tensor<8x3xi32>) -> tensor<8x1xi32>
// CHECK: %[[VAL_22:.*]] = tosa.const_shape {values = dense<[1, 8]> : tensor<2xindex>} : () -> !tosa.shape<2>
// CHECK: %[[VAL_23:.*]] = tosa.reshape %[[VAL_21]], %[[VAL_22]] : (tensor<8x1xi32>, !tosa.shape<2>) -> tensor<1x8xi32>
// CHECK: %[[VAL_24:.*]] = tosa.cast %[[VAL_13]] : (tensor<1x12x1xi1>) -> tensor<1x12x1xi8>
// CHECK: %[[VAL_25:.*]] = tosa.gather %[[VAL_24]], %[[VAL_23]] : (tensor<1x12x1xi8>, tensor<1x8xi32>) -> tensor<1x8x1xi8>
// CHECK: %[[VAL_26:.*]] = tosa.cast %[[VAL_25]] : (tensor<1x8x1xi8>) -> tensor<1x8x1xi1>
// CHECK: %[[VAL_27:.*]] = tosa.const_shape {values = dense<[1, 4, 2]> : tensor<3xindex>} : () -> !tosa.shape<3>
// CHECK: %[[VAL_28:.*]] = tosa.reshape %[[VAL_26]], %[[VAL_27]] : (tensor<1x8x1xi1>, !tosa.shape<3>) -> tensor<1x4x2xi1>
// CHECK: %[[VAL_29:.*]] = torch_c.from_builtin_tensor %[[VAL_28]] : tensor<1x4x2xi1> -> !torch.vtensor<[1,4,2],i1>
// CHECK: return %[[VAL_29]] : !torch.vtensor<[1,4,2],i1>
// CHECK: }
// Verifies the i1 path of the gather lowering: since tosa.gather has no i1
// support, the lowering casts the boolean input to i8 (VAL_24), gathers in i8
// (VAL_25), and casts the result back to i1 (VAL_26).
func.func @torch.aten.gather$bool(%arg0: !torch.vtensor<[1,4,3],i1>, %arg1: !torch.vtensor<[1,4,2],si64>) -> !torch.vtensor<[1,4,2],i1> {
  %int-1 = torch.constant.int -1
  %false = torch.constant.bool false
  %0 = torch.aten.gather %arg0, %int-1, %arg1, %false : !torch.vtensor<[1,4,3],i1>, !torch.int, !torch.vtensor<[1,4,2],si64>, !torch.bool -> !torch.vtensor<[1,4,2],i1>
  return %0 : !torch.vtensor<[1,4,2],i1>
}

// -----
// CHECK-LABEL: func.func @torch.aten.add$int(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,2],si32>,
Expand Down Expand Up @@ -1422,6 +1463,27 @@ func.func @torch.aten.slice.negative_start(%arg0: !torch.vtensor<[4,65,256],f32>
return %0 : !torch.vtensor<[4,16,256],f32>
}

// -----
// CHECK-LABEL: func.func @torch.aten.slice.bool_strided(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,64,1],i1>) -> !torch.vtensor<[1,32,1],i1> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,64,1],i1> -> tensor<1x64x1xi1>
// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor<1x64x1xi1>) -> tensor<1x64x1xi8>
// CHECK: %[[VAL_3:.*]] = tosa.gather %[[VAL_2]], %{{.*}} : (tensor<1x64x1xi8>, tensor<1x32xi32>) -> tensor<1x32x1xi8>
// CHECK: %[[VAL_4:.*]] = tosa.cast %[[VAL_3]] : (tensor<1x32x1xi8>) -> tensor<1x32x1xi1>
// CHECK: %[[VAL_5:.*]] = tosa.const_shape {values = dense<[1, 32, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_4]], %[[VAL_5]] : (tensor<1x32x1xi1>, !tosa.shape<3>) -> tensor<1x32x1xi1>
// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<1x32x1xi1> -> !torch.vtensor<[1,32,1],i1>
// CHECK: return %[[VAL_7]] : !torch.vtensor<[1,32,1],i1>
// CHECK: }
// Strided slice (step 2) of a boolean tensor is lowered via tosa.gather; the
// i1 input is cast to i8 around the gather (VAL_2/VAL_4) because tosa.gather
// does not accept i1 element types.
func.func @torch.aten.slice.bool_strided(%arg0: !torch.vtensor<[1,64,1],i1>) -> !torch.vtensor<[1,32,1],i1> {
  %int1 = torch.constant.int 1
  %int0 = torch.constant.int 0
  %int64 = torch.constant.int 64
  %int2 = torch.constant.int 2
  %0 = torch.aten.slice.Tensor %arg0, %int1, %int0, %int64, %int2 : !torch.vtensor<[1,64,1],i1>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,32,1],i1>
  return %0 : !torch.vtensor<[1,32,1],i1>
}

// -----
// CHECK-LABEL: func.func @torch.aten.clamp.min_none(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,128,128],si64>) -> !torch.vtensor<[1,1,128,128],si64> {
Expand Down
Loading