Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ FailureOr<Value> getZeroPointValue(PatternRewriter &rewriter, Operation *op,
// Check if a shaped type has any dimension with size 0.
bool typeHasZeroDim(ShapedType type);

// Check if a type is i1 or a shaped type with i1 element type.
bool isI1Type(Type type);

// Compute scale/offset/border parameters for TOSA resize on one dimension.
void computeResizeParams(int inputSize, int outputSize, bool alignCorners,
tosa::ResizeMode mode, int &scaleN, int &scaleD,
Expand Down
21 changes: 21 additions & 0 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -748,6 +748,27 @@ class ConvertAtenCompareOp : public TorchToTosaOpConversionPattern<AtenOpT> {
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()));
if (isBitwiseOp) {
// TOSA bitwise ops do not support i1. Use logical ops for bool tensors.
if (tosa::isI1Type(resultTy)) {
auto lhsBool =
tosa::tosaCastTensorToType(rewriter, lhs, resultTy).value();
auto rhsBool =
tosa::tosaCastTensorToType(rewriter, rhsTensor, resultTy).value();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these casts necessary? Won't the operand types already be same as the result type for the ops under consideration?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I traced the lowering in ConvertAtenCompareOp: resultTy comes from the type converter, but lhs/rhsTensor are still the original operands at that point, before any promotion. So I don't believe we can prove operands are already i1 when resultTy is i1. The casts are therefore defensive to guarantee i1 for Logical*Op.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@catcor01 Can you please clarify this question? Thanks!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sahas3 apologies, I did not realize my comment has been pending since yesterday.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No worries :)

if constexpr (std::is_same<AtenOpT, AtenBitwiseAndTensorOp>() ||
std::is_same<AtenOpT, AtenBitwiseAndScalarOp>()) {
rewriter.replaceOpWithNewOp<tosa::LogicalAndOp>(op, resultTy, lhsBool,
rhsBool);
return success();
} else if constexpr (std::is_same<AtenOpT, AtenBitwiseOrTensorOp>()) {
rewriter.replaceOpWithNewOp<tosa::LogicalOrOp>(op, resultTy, lhsBool,
rhsBool);
return success();
} else if constexpr (std::is_same<AtenOpT, AtenBitwiseXorTensorOp>()) {
rewriter.replaceOpWithNewOp<tosa::LogicalXorOp>(op, resultTy, lhsBool,
rhsBool);
return success();
}
}
lhs = tosa::tosaCastTensorToType(rewriter, lhs, resultTy).value();
rhsTensor =
tosa::tosaCastTensorToType(rewriter, rhsTensor, resultTy).value();
Expand Down
8 changes: 8 additions & 0 deletions lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,14 @@ bool typeHasZeroDim(ShapedType type) {
return llvm::any_of(outShape, [](int64_t dim) { return dim == 0; });
}

bool isI1Type(Type type) {
if (auto shapedTy = dyn_cast<ShapedType>(type))
type = shapedTy.getElementType();
if (auto intTy = dyn_cast<IntegerType>(type))
return intTy.getWidth() == 1;
return false;
}

void computeResizeParams(int inputSize, int outputSize, bool alignCorners,
tosa::ResizeMode mode, int &scaleN, int &scaleD,
int &offset, int &border) {
Expand Down
54 changes: 54 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -4899,6 +4899,33 @@ def ElementwiseBitwiseAndModule_basic(module, tu: TestUtils):
# ==============================================================================


class ElementwiseBitwiseAndBoolModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1, -1], torch.bool, True),
([-1, -1], torch.bool, True),
]
)
def forward(self, x, y):
return torch.bitwise_and(x, y)


@register_test_case(module_factory=lambda: ElementwiseBitwiseAndBoolModule())
def ElementwiseBitwiseAndBoolModule_basic(module, tu: TestUtils):
module.forward(
tu.randint(3, 4, low=0, high=2).to(torch.bool),
tu.randint(3, 4, low=0, high=2).to(torch.bool),
)


# ==============================================================================


class ElementwiseBitwiseAndStaticShapeModule(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -4953,6 +4980,33 @@ def ElementwiseBitwiseOrModule_basic(module, tu: TestUtils):
# ==============================================================================


class ElementwiseBitwiseOrBoolModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1, -1], torch.bool, True),
([-1, -1], torch.bool, True),
]
)
def forward(self, x, y):
return torch.bitwise_or(x, y)


@register_test_case(module_factory=lambda: ElementwiseBitwiseOrBoolModule())
def ElementwiseBitwiseOrBoolModule_basic(module, tu: TestUtils):
module.forward(
tu.randint(3, 4, low=0, high=2).to(torch.bool),
tu.randint(3, 4, low=0, high=2).to(torch.bool),
)


# ==============================================================================


class ElementwiseBitwiseOrStaticShapeModule(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down
17 changes: 17 additions & 0 deletions test/Conversion/TorchToTosa/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -895,6 +895,23 @@ func.func @torch.aten.bitwise_and.Tensor$basic(%arg0: !torch.vtensor<[?,?],si32>

// -----

// CHECK-LABEL: func.func @torch.aten.bitwise_and.Tensor$bool(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],i1>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> {
// CHECK-DAG: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],i1> -> tensor<?x?xi1>
// CHECK-DAG: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],i1> -> tensor<?x?xi1>
// CHECK-NOT: tosa.bitwise_and
// CHECK: %[[VAL_4:.*]] = tosa.logical_and %[[VAL_2]], %[[VAL_3]] : (tensor<?x?xi1>, tensor<?x?xi1>) -> tensor<?x?xi1>
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
// CHECK: }
func.func @torch.aten.bitwise_and.Tensor$bool(%arg0: !torch.vtensor<[?,?],i1>, %arg1: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> {
%0 = torch.aten.bitwise_and.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],i1>, !torch.vtensor<[?,?],i1> -> !torch.vtensor<[?,?],i1>
return %0 : !torch.vtensor<[?,?],i1>
}

// -----

// CHECK-LABEL: func.func @torch.aten.log2$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,17 @@ func.func @torch.aten.bitwise_and.Tensor$mixed_type(%arg0: !torch.vtensor<[?,?],

// -----

// CHECK-LABEL: torch.aten.bitwise_xor.Tensor$bool
// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xi1>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<?x?xi1>
// CHECK: %[[VAL_2:.*]] = tosa.logical_xor %[[VAL_0]], %[[VAL_1]] : (tensor<?x?xi1>, tensor<?x?xi1>) -> tensor<?x?xi1>
func.func @torch.aten.bitwise_xor.Tensor$bool(%arg0: !torch.vtensor<[?,?],i1>, %arg1: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> {
%0 = torch.aten.bitwise_xor.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],i1>, !torch.vtensor<[?,?],i1> -> !torch.vtensor<[?,?],i1>
return %0 : !torch.vtensor<[?,?],i1>
}

// -----

// CHECK-LABEL: func.func @torch.aten.div.Tensor$mixed_type_fp(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<?x?xi32>) -> tensor<?x?xf32> {
Expand Down
Loading