Skip to content

Commit 3d6bab0

Browse files
committed
[TorchToTosa] Avoid i1 gather by casting through i8
TOSA gather does not accept i1 tensors. When the gather element type is i1, cast the inputs to i8, perform the gather (including the gather-nd paths), then cast the result back to i1.

Signed-off-by: Cathal Corbett <cathal.corbett@arm.com>
Change-Id: I8e3034612c2fabec7c9e75d8295a863860a674c2
1 parent 64ca81a commit 3d6bab0

File tree

3 files changed

+129
-13
lines changed

3 files changed

+129
-13
lines changed

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4545,14 +4545,33 @@ LogicalResult ConvertAtenOp<AtenEmbeddingOp>::matchAndRewriteImpl(
45454545
.value();
45464546

45474547
SmallVector<int64_t> intermediateOutShape = {1, numIndices, weightShape[1]};
4548-
auto gatherOp = tosa::GatherOp::create(
4549-
rewriter, op->getLoc(),
4550-
RankedTensorType::get(makeShapeLLVMCompatible(intermediateOutShape),
4551-
weightType.getElementType()),
4552-
reshapedWeight, castIndices);
4548+
auto gatherElemTy = weightType.getElementType();
4549+
auto gatherTy = RankedTensorType::get(
4550+
makeShapeLLVMCompatible(intermediateOutShape), gatherElemTy);
4551+
Value gatherResult;
4552+
if (auto intTy = dyn_cast<IntegerType>(gatherElemTy);
4553+
intTy && intTy.getWidth() == 1) {
4554+
auto i8Ty = rewriter.getI8Type();
4555+
auto reshapedWeightI8 =
4556+
tosa::tosaCastTensorToType(
4557+
rewriter, reshapedWeight,
4558+
RankedTensorType::get(makeShapeLLVMCompatible(newWeightShape),
4559+
i8Ty))
4560+
.value();
4561+
auto gatherTyI8 = RankedTensorType::get(
4562+
makeShapeLLVMCompatible(intermediateOutShape), i8Ty);
4563+
auto gatheredI8 = tosa::GatherOp::create(rewriter, op->getLoc(), gatherTyI8,
4564+
reshapedWeightI8, castIndices);
4565+
gatherResult =
4566+
tosa::tosaCastTensorToType(rewriter, gatheredI8, gatherTy).value();
4567+
} else {
4568+
gatherResult = tosa::GatherOp::create(rewriter, op->getLoc(), gatherTy,
4569+
reshapedWeight, castIndices)
4570+
.getResult();
4571+
}
45534572

45544573
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
4555-
op, outType, gatherOp,
4574+
op, outType, gatherResult,
45564575
tosa::getTosaConstShape(rewriter, op->getLoc(),
45574576
makeShapeTorchCompatible(outType.getShape())));
45584577

@@ -4868,9 +4887,24 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewriteImpl(
48684887
// Duplicate the 1-D index vector across the batch dimension so that we can
48694888
// use a single tosa.gather to materialize the strided slice.
48704889
auto gatherTy = RankedTensorType::get({N, W, C}, elemTy);
4871-
Value gathered =
4872-
tosa::GatherOp::create(rewriter, loc, gatherTy, reshaped, idxNW)
4873-
.getResult();
4890+
Value gathered;
4891+
if (auto intTy = dyn_cast<IntegerType>(elemTy);
4892+
intTy && intTy.getWidth() == 1) {
4893+
auto i8Ty = rewriter.getI8Type();
4894+
auto reshapedI8 =
4895+
tosa::tosaCastTensorToType(
4896+
rewriter, reshaped,
4897+
RankedTensorType::get(makeShapeLLVMCompatible(nkcShape), i8Ty))
4898+
.value();
4899+
auto gatherTyI8 = RankedTensorType::get({N, W, C}, i8Ty);
4900+
auto gatheredI8 =
4901+
tosa::GatherOp::create(rewriter, loc, gatherTyI8, reshapedI8, idxNW);
4902+
gathered =
4903+
tosa::tosaCastTensorToType(rewriter, gatheredI8, gatherTy).value();
4904+
} else {
4905+
gathered = tosa::GatherOp::create(rewriter, loc, gatherTy, reshaped, idxNW)
4906+
.getResult();
4907+
}
48744908

48754909
SmallVector<int64_t> outShape = inputShape;
48764910
outShape[dim] = W;

lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -420,18 +420,40 @@ std::optional<Value> convertGatherNdOp(PatternRewriter &rewriter, Operation *op,
420420
// Now the gather op itself
421421
// %9 = "tosa.gather"(%2, %7) : (tensor<1x12x1xf32>, tensor<1x8xi32>) ->
422422
// tensor<1x8x1xf32>
423+
auto resultElemTy = resultType.getElementType();
424+
Value valuesForGather = tosaValuesReshapeOp.getResult();
425+
Type gatherElemTy = resultElemTy;
426+
if (auto intTy = dyn_cast<IntegerType>(resultElemTy);
427+
intTy && intTy.getWidth() == 1) {
428+
auto i8Ty = rewriter.getI8Type();
429+
valuesForGather = tosa::tosaCastTensorToType(
430+
rewriter, valuesForGather,
431+
GetTypeFromTensorShape(tosaValuesShape, i8Ty))
432+
.value();
433+
gatherElemTy = i8Ty;
434+
}
435+
423436
auto tosaGatherOp = tosa::CreateOpAndInfer<tosa::GatherOp>(
424437
rewriter, op->getLoc(),
425-
GetTypeFromTensorShape(tosaGatherResultShape,
426-
resultType.getElementType()),
427-
tosaValuesReshapeOp.getResult(), tosaIndicesReshapeOp.getResult());
438+
GetTypeFromTensorShape(tosaGatherResultShape, gatherElemTy),
439+
valuesForGather, tosaIndicesReshapeOp.getResult());
428440

429441
// Finally, reshape back to the original output shape of [Indices,
430442
// ParamChannels]. %10 = "tosa.reshape"(%9) {new_shape = [1, 4, 2]} :
431443
// (tensor<1x8x1xf32>) -> tensor<1x4x2xf32> %11 = torch_c.from_builtin_tensor
432444
// %10 : tensor<1x4x2xf32> -> !torch.vtensor<[1,4,2],f32>
445+
Value gatherResult = tosaGatherOp.getResult();
446+
if (auto intTy = dyn_cast<IntegerType>(resultElemTy);
447+
intTy && intTy.getWidth() == 1) {
448+
gatherResult =
449+
tosa::tosaCastTensorToType(
450+
rewriter, gatherResult,
451+
GetTypeFromTensorShape(tosaGatherResultShape, resultElemTy))
452+
.value();
453+
}
454+
433455
return tosa::CreateOpAndInfer<tosa::ReshapeOp>(
434-
rewriter, op->getLoc(), resultType, tosaGatherOp.getResult(),
456+
rewriter, op->getLoc(), resultType, gatherResult,
435457
tosa::getTosaConstShape(rewriter, op->getLoc(),
436458
resultType.getShape()))
437459
.getResult();

test/Conversion/TorchToTosa/basic.mlir

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1349,6 +1349,47 @@ func.func @torch.aten.gather(%arg0: !torch.vtensor<[1,4,3],f32>, %arg1: !torch.v
13491349
return %0 : !torch.vtensor<[1,4,2],f32>
13501350
}
13511351

1352+
// -----
1353+
// CHECK-LABEL: func.func @torch.aten.gather$bool(
1354+
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,4,3],i1>,
1355+
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1,4,2],si64>) -> !torch.vtensor<[1,4,2],i1> {
1356+
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1,4,2],si64> -> tensor<1x4x2xi64>
1357+
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,4,3],i1> -> tensor<1x4x3xi1>
1358+
// CHECK: %[[VAL_4:.*]] = torch.constant.int -1
1359+
// CHECK: %[[VAL_5:.*]] = torch.constant.bool false
1360+
// CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_2]] : (tensor<1x4x2xi64>) -> tensor<1x4x2xi32>
1361+
// CHECK: %[[VAL_7:.*]] = tosa.const_shape {values = dense<[1, 4, 2, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
1362+
// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_6]], %[[VAL_7]] : (tensor<1x4x2xi32>, !tosa.shape<4>) -> tensor<1x4x2x1xi32>
1363+
// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{values = dense<0> : tensor<1x4x2x1xi32>}> : () -> tensor<1x4x2x1xi32>
1364+
// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{values = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]]]]> : tensor<1x4x2x1xi32>}> : () -> tensor<1x4x2x1xi32>
1365+
// CHECK: %[[VAL_11:.*]] = tosa.concat %[[VAL_9]], %[[VAL_10]], %[[VAL_8]] {axis = 3 : i32} : (tensor<1x4x2x1xi32>, tensor<1x4x2x1xi32>, tensor<1x4x2x1xi32>) -> tensor<1x4x2x3xi32>
1366+
// CHECK: %[[VAL_12:.*]] = tosa.const_shape {values = dense<[1, 12, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
1367+
// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_3]], %[[VAL_12]] : (tensor<1x4x3xi1>, !tosa.shape<3>) -> tensor<1x12x1xi1>
1368+
// CHECK: %[[VAL_14:.*]] = tosa.cast %[[VAL_13]] : (tensor<1x12x1xi1>) -> tensor<1x12x1xi8>
1369+
// CHECK: %[[VAL_15:.*]] = tosa.const_shape {values = dense<[8, 3]> : tensor<2xindex>} : () -> !tosa.shape<2>
1370+
// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_11]], %[[VAL_15]] : (tensor<1x4x2x3xi32>, !tosa.shape<2>) -> tensor<8x3xi32>
1371+
// CHECK: %[[VAL_17:.*]] = "tosa.const"() <{values = dense<[12, 3, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
1372+
// CHECK: %[[VAL_18:.*]] = tosa.const_shape {values = dense<[1, 3]> : tensor<2xindex>} : () -> !tosa.shape<2>
1373+
// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_17]], %[[VAL_18]] : (tensor<3xi32>, !tosa.shape<2>) -> tensor<1x3xi32>
1374+
// CHECK: %[[VAL_20:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
1375+
// CHECK: %[[VAL_21:.*]] = tosa.mul %[[VAL_16]], %[[VAL_19]], %[[VAL_20]] : (tensor<8x3xi32>, tensor<1x3xi32>, tensor<1xi8>) -> tensor<8x3xi32>
1376+
// CHECK: %[[VAL_22:.*]] = tosa.reduce_sum %[[VAL_21]] {axis = 1 : i32} : (tensor<8x3xi32>) -> tensor<8x1xi32>
1377+
// CHECK: %[[VAL_23:.*]] = tosa.const_shape {values = dense<[1, 8]> : tensor<2xindex>} : () -> !tosa.shape<2>
1378+
// CHECK: %[[VAL_24:.*]] = tosa.reshape %[[VAL_22]], %[[VAL_23]] : (tensor<8x1xi32>, !tosa.shape<2>) -> tensor<1x8xi32>
1379+
// CHECK: %[[VAL_25:.*]] = tosa.gather %[[VAL_14]], %[[VAL_24]] : (tensor<1x12x1xi8>, tensor<1x8xi32>) -> tensor<1x8x1xi8>
1380+
// CHECK: %[[VAL_26:.*]] = tosa.cast %[[VAL_25]] : (tensor<1x8x1xi8>) -> tensor<1x8x1xi1>
1381+
// CHECK: %[[VAL_27:.*]] = tosa.const_shape {values = dense<[1, 4, 2]> : tensor<3xindex>} : () -> !tosa.shape<3>
1382+
// CHECK: %[[VAL_28:.*]] = tosa.reshape %[[VAL_26]], %[[VAL_27]] : (tensor<1x8x1xi1>, !tosa.shape<3>) -> tensor<1x4x2xi1>
1383+
// CHECK: %[[VAL_29:.*]] = torch_c.from_builtin_tensor %[[VAL_28]] : tensor<1x4x2xi1> -> !torch.vtensor<[1,4,2],i1>
1384+
// CHECK: return %[[VAL_29]] : !torch.vtensor<[1,4,2],i1>
1385+
// CHECK: }
1386+
func.func @torch.aten.gather$bool(%arg0: !torch.vtensor<[1,4,3],i1>, %arg1: !torch.vtensor<[1,4,2],si64>) -> !torch.vtensor<[1,4,2],i1> {
1387+
%int-1 = torch.constant.int -1
1388+
%false = torch.constant.bool false
1389+
%0 = torch.aten.gather %arg0, %int-1, %arg1, %false : !torch.vtensor<[1,4,3],i1>, !torch.int, !torch.vtensor<[1,4,2],si64>, !torch.bool -> !torch.vtensor<[1,4,2],i1>
1390+
return %0 : !torch.vtensor<[1,4,2],i1>
1391+
}
1392+
13521393
// -----
13531394
// CHECK-LABEL: func.func @torch.aten.add$int(
13541395
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,2],si32>,
@@ -1422,6 +1463,25 @@ func.func @torch.aten.slice.negative_start(%arg0: !torch.vtensor<[4,65,256],f32>
14221463
return %0 : !torch.vtensor<[4,16,256],f32>
14231464
}
14241465

1466+
// -----
1467+
// CHECK-LABEL: func.func @torch.aten.slice.bool_strided(
1468+
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,64,1],i1>) -> !torch.vtensor<[1,32,1],i1> {
1469+
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,64,1],i1> -> tensor<1x64x1xi1>
1470+
// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor<1x64x1xi1>) -> tensor<1x64x1xi8>
1471+
// CHECK: %[[VAL_3:.*]] = tosa.gather %[[VAL_2]], %{{.*}} : (tensor<1x64x1xi8>, tensor<1x32xi32>) -> tensor<1x32x1xi8>
1472+
// CHECK: %[[VAL_4:.*]] = tosa.cast %[[VAL_3]] : (tensor<1x32x1xi8>) -> tensor<1x32x1xi1>
1473+
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<1x32x1xi1> -> !torch.vtensor<[1,32,1],i1>
1474+
// CHECK: return %[[VAL_5]] : !torch.vtensor<[1,32,1],i1>
1475+
// CHECK: }
1476+
func.func @torch.aten.slice.bool_strided(%arg0: !torch.vtensor<[1,64,1],i1>) -> !torch.vtensor<[1,32,1],i1> {
1477+
%int1 = torch.constant.int 1
1478+
%int0 = torch.constant.int 0
1479+
%int64 = torch.constant.int 64
1480+
%int2 = torch.constant.int 2
1481+
%0 = torch.aten.slice.Tensor %arg0, %int1, %int0, %int64, %int2 : !torch.vtensor<[1,64,1],i1>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,32,1],i1>
1482+
return %0 : !torch.vtensor<[1,32,1],i1>
1483+
}
1484+
14251485
// -----
14261486
// CHECK-LABEL: func.func @torch.aten.clamp.min_none(
14271487
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,128,128],si64>) -> !torch.vtensor<[1,1,128,128],si64> {

0 commit comments

Comments
 (0)