Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ FailureOr<Value> getZeroPointValue(PatternRewriter &rewriter, Operation *op,
// Check if a shaped type has any dimension with size 0.
bool typeHasZeroDim(ShapedType type);

// Check if a type is i1 or a shaped type with i1 element type.
bool isI1Type(Type type);

// Compute scale/offset/border parameters for TOSA resize on one dimension.
void computeResizeParams(int inputSize, int outputSize, bool alignCorners,
tosa::ResizeMode mode, int &scaleN, int &scaleD,
Expand Down
21 changes: 21 additions & 0 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -748,6 +748,27 @@ class ConvertAtenCompareOp : public TorchToTosaOpConversionPattern<AtenOpT> {
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()));
if (isBitwiseOp) {
// TOSA bitwise ops do not support i1. Use logical ops for bool tensors.
if (tosa::isI1Type(resultTy)) {
auto lhsBool =
tosa::tosaCastTensorToType(rewriter, lhs, resultTy).value();
auto rhsBool =
tosa::tosaCastTensorToType(rewriter, rhsTensor, resultTy).value();
Comment on lines +753 to +756
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these casts necessary? Won't the operand types already be same as the result type for the ops under consideration?

if constexpr (std::is_same<AtenOpT, AtenBitwiseAndTensorOp>() ||
std::is_same<AtenOpT, AtenBitwiseAndScalarOp>()) {
rewriter.replaceOpWithNewOp<tosa::LogicalAndOp>(op, resultTy, lhsBool,
rhsBool);
return success();
} else if constexpr (std::is_same<AtenOpT, AtenBitwiseOrTensorOp>()) {
rewriter.replaceOpWithNewOp<tosa::LogicalOrOp>(op, resultTy, lhsBool,
rhsBool);
return success();
} else if constexpr (std::is_same<AtenOpT, AtenBitwiseXorTensorOp>()) {
rewriter.replaceOpWithNewOp<tosa::LogicalXorOp>(op, resultTy, lhsBool,
rhsBool);
return success();
}
}
lhs = tosa::tosaCastTensorToType(rewriter, lhs, resultTy).value();
rhsTensor =
tosa::tosaCastTensorToType(rewriter, rhsTensor, resultTy).value();
Expand Down
8 changes: 8 additions & 0 deletions lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,14 @@ bool typeHasZeroDim(ShapedType type) {
return llvm::any_of(outShape, [](int64_t dim) { return dim == 0; });
}

// Check if a type is i1 or a shaped type with i1 element type.
//
// Returns true for a 1-bit IntegerType itself, or for any ShapedType
// (tensor/vector/memref) whose element type is a 1-bit integer. The width
// check is signedness-agnostic, matching IntegerType::getWidth() == 1.
bool isI1Type(Type type) {
  // Unwrap shaped types down to their element type before testing.
  if (auto shapedTy = dyn_cast<ShapedType>(type))
    type = shapedTy.getElementType();
  // Type::isInteger(width) is the standard MLIR predicate for "IntegerType
  // of this width"; it replaces the hand-rolled dyn_cast + width comparison.
  return type.isInteger(1);
}

void computeResizeParams(int inputSize, int outputSize, bool alignCorners,
tosa::ResizeMode mode, int &scaleN, int &scaleD,
int &offset, int &border) {
Expand Down
54 changes: 54 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -4899,6 +4899,33 @@ def ElementwiseBitwiseAndModule_basic(module, tu: TestUtils):
# ==============================================================================


class ElementwiseBitwiseAndBoolModule(torch.nn.Module):
    # Exercises aten.bitwise_and.Tensor on bool (i1) tensors with dynamic
    # shapes; backends such as TOSA must lower this via logical ops because
    # TOSA bitwise ops do not support i1.
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1], torch.bool, True),
            ([-1, -1], torch.bool, True),
        ]
    )
    def forward(self, x, y):
        # Call torch.bitwise_and directly so the traced op is exactly
        # aten.bitwise_and.Tensor (not an operator decomposition).
        return torch.bitwise_and(x, y)


@register_test_case(module_factory=lambda: ElementwiseBitwiseAndBoolModule())
def ElementwiseBitwiseAndBoolModule_basic(module, tu: TestUtils):
    # Random ints in [0, 2) cast to bool give a mix of True/False values
    # in both operands, so the AND result is non-trivial.
    module.forward(
        tu.randint(3, 4, low=0, high=2).to(torch.bool),
        tu.randint(3, 4, low=0, high=2).to(torch.bool),
    )


# ==============================================================================


class ElementwiseBitwiseAndStaticShapeModule(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -4953,6 +4980,33 @@ def ElementwiseBitwiseOrModule_basic(module, tu: TestUtils):
# ==============================================================================


class ElementwiseBitwiseOrBoolModule(torch.nn.Module):
    # Exercises aten.bitwise_or.Tensor on bool (i1) tensors with dynamic
    # shapes; backends such as TOSA must lower this via logical ops because
    # TOSA bitwise ops do not support i1.
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1], torch.bool, True),
            ([-1, -1], torch.bool, True),
        ]
    )
    def forward(self, x, y):
        # Call torch.bitwise_or directly so the traced op is exactly
        # aten.bitwise_or.Tensor (not an operator decomposition).
        return torch.bitwise_or(x, y)


@register_test_case(module_factory=lambda: ElementwiseBitwiseOrBoolModule())
def ElementwiseBitwiseOrBoolModule_basic(module, tu: TestUtils):
    # Random ints in [0, 2) cast to bool give a mix of True/False values
    # in both operands, so the OR result is non-trivial.
    module.forward(
        tu.randint(3, 4, low=0, high=2).to(torch.bool),
        tu.randint(3, 4, low=0, high=2).to(torch.bool),
    )


# ==============================================================================


class ElementwiseBitwiseOrStaticShapeModule(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down
17 changes: 17 additions & 0 deletions test/Conversion/TorchToTosa/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -895,6 +895,23 @@ func.func @torch.aten.bitwise_and.Tensor$basic(%arg0: !torch.vtensor<[?,?],si32>

// -----

// CHECK-LABEL: func.func @torch.aten.bitwise_and.Tensor$bool(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],i1>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> {
// CHECK-DAG: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],i1> -> tensor<?x?xi1>
// CHECK-DAG: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],i1> -> tensor<?x?xi1>
// CHECK-NOT: tosa.bitwise_and
// CHECK: %[[VAL_4:.*]] = tosa.logical_and %[[VAL_2]], %[[VAL_3]] : (tensor<?x?xi1>, tensor<?x?xi1>) -> tensor<?x?xi1>
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
// CHECK: }
func.func @torch.aten.bitwise_and.Tensor$bool(%arg0: !torch.vtensor<[?,?],i1>, %arg1: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> {
  // i1 operands: the lowering must emit tosa.logical_and, never
  // tosa.bitwise_and, since TOSA bitwise ops do not support i1.
  %0 = torch.aten.bitwise_and.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],i1>, !torch.vtensor<[?,?],i1> -> !torch.vtensor<[?,?],i1>
  return %0 : !torch.vtensor<[?,?],i1>
}

// -----

// CHECK-LABEL: func.func @torch.aten.log2$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,17 @@ func.func @torch.aten.bitwise_and.Tensor$mixed_type(%arg0: !torch.vtensor<[?,?],

// -----

// CHECK-LABEL: torch.aten.bitwise_xor.Tensor$bool
// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xi1>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<?x?xi1>
// CHECK: %[[VAL_2:.*]] = tosa.logical_xor %[[VAL_0]], %[[VAL_1]] : (tensor<?x?xi1>, tensor<?x?xi1>) -> tensor<?x?xi1>
func.func @torch.aten.bitwise_xor.Tensor$bool(%arg0: !torch.vtensor<[?,?],i1>, %arg1: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> {
  // i1 operands: expect tosa.logical_xor rather than tosa.bitwise_xor,
  // since TOSA bitwise ops do not support i1.
  %0 = torch.aten.bitwise_xor.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],i1>, !torch.vtensor<[?,?],i1> -> !torch.vtensor<[?,?],i1>
  return %0 : !torch.vtensor<[?,?],i1>
}

// -----

// CHECK-LABEL: func.func @torch.aten.div.Tensor$mixed_type_fp(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<?x?xi32>) -> tensor<?x?xf32> {
Expand Down
Loading