Skip to content

Commit 983e5b0

Browse files
committed
Add elementwise python e2e tests
Change-Id: I8770a509a395145fae78cbcc93320279a4c25de9 Signed-off-by: Cathal Corbett <cathal.corbett@arm.com>
1 parent e7d5e2d commit 983e5b0

File tree

1 file changed

+81
-0
lines changed

1 file changed

+81
-0
lines changed

projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4899,6 +4899,33 @@ def ElementwiseBitwiseAndModule_basic(module, tu: TestUtils):
48994899
# ==============================================================================
49004900

49014901

4902+
class ElementwiseBitwiseAndBoolModule(torch.nn.Module):
4903+
def __init__(self):
4904+
super().__init__()
4905+
4906+
@export
4907+
@annotate_args(
4908+
[
4909+
None,
4910+
([-1, -1], torch.bool, True),
4911+
([-1, -1], torch.bool, True),
4912+
]
4913+
)
4914+
def forward(self, x, y):
4915+
return torch.bitwise_and(x, y)
4916+
4917+
4918+
@register_test_case(module_factory=lambda: ElementwiseBitwiseAndBoolModule())
4919+
def ElementwiseBitwiseAndBoolModule_basic(module, tu: TestUtils):
4920+
module.forward(
4921+
tu.randint(3, 4, low=0, high=2).to(torch.bool),
4922+
tu.randint(3, 4, low=0, high=2).to(torch.bool),
4923+
)
4924+
4925+
4926+
# ==============================================================================
4927+
4928+
49024929
class ElementwiseBitwiseAndStaticShapeModule(torch.nn.Module):
49034930
def __init__(self):
49044931
super().__init__()
@@ -4953,6 +4980,33 @@ def ElementwiseBitwiseOrModule_basic(module, tu: TestUtils):
49534980
# ==============================================================================
49544981

49554982

4983+
class ElementwiseBitwiseOrBoolModule(torch.nn.Module):
4984+
def __init__(self):
4985+
super().__init__()
4986+
4987+
@export
4988+
@annotate_args(
4989+
[
4990+
None,
4991+
([-1, -1], torch.bool, True),
4992+
([-1, -1], torch.bool, True),
4993+
]
4994+
)
4995+
def forward(self, x, y):
4996+
return torch.bitwise_or(x, y)
4997+
4998+
4999+
@register_test_case(module_factory=lambda: ElementwiseBitwiseOrBoolModule())
5000+
def ElementwiseBitwiseOrBoolModule_basic(module, tu: TestUtils):
5001+
module.forward(
5002+
tu.randint(3, 4, low=0, high=2).to(torch.bool),
5003+
tu.randint(3, 4, low=0, high=2).to(torch.bool),
5004+
)
5005+
5006+
5007+
# ==============================================================================
5008+
5009+
49565010
class ElementwiseBitwiseOrStaticShapeModule(torch.nn.Module):
49575011
def __init__(self):
49585012
super().__init__()
@@ -5102,6 +5156,33 @@ def ElementwiseBitwiseXorModule_basic(module, tu: TestUtils):
51025156
# ==============================================================================
51035157

51045158

5159+
class ElementwiseBitwiseXorBoolModule(torch.nn.Module):
5160+
def __init__(self):
5161+
super().__init__()
5162+
5163+
@export
5164+
@annotate_args(
5165+
[
5166+
None,
5167+
([-1, -1], torch.bool, True),
5168+
([-1, -1], torch.bool, True),
5169+
]
5170+
)
5171+
def forward(self, x, y):
5172+
return torch.bitwise_xor(x, y)
5173+
5174+
5175+
@register_test_case(module_factory=lambda: ElementwiseBitwiseXorBoolModule())
5176+
def ElementwiseBitwiseXorBoolModule_basic(module, tu: TestUtils):
5177+
module.forward(
5178+
tu.randint(3, 4, low=0, high=2).to(torch.bool),
5179+
tu.randint(3, 4, low=0, high=2).to(torch.bool),
5180+
)
5181+
5182+
5183+
# ==============================================================================
5184+
5185+
51055186
class ElementwiseBitwiseXorStaticShapeModule(torch.nn.Module):
51065187
def __init__(self):
51075188
super().__init__()

0 commit comments

Comments
 (0)