@@ -4899,6 +4899,33 @@ def ElementwiseBitwiseAndModule_basic(module, tu: TestUtils):
48994899# ==============================================================================
49004900
49014901
4902+ class ElementwiseBitwiseAndBoolModule (torch .nn .Module ):
4903+ def __init__ (self ):
4904+ super ().__init__ ()
4905+
4906+ @export
4907+ @annotate_args (
4908+ [
4909+ None ,
4910+ ([- 1 , - 1 ], torch .bool , True ),
4911+ ([- 1 , - 1 ], torch .bool , True ),
4912+ ]
4913+ )
4914+ def forward (self , x , y ):
4915+ return torch .bitwise_and (x , y )
4916+
4917+
4918+ @register_test_case (module_factory = lambda : ElementwiseBitwiseAndBoolModule ())
4919+ def ElementwiseBitwiseAndBoolModule_basic (module , tu : TestUtils ):
4920+ module .forward (
4921+ tu .randint (3 , 4 , low = 0 , high = 2 ).to (torch .bool ),
4922+ tu .randint (3 , 4 , low = 0 , high = 2 ).to (torch .bool ),
4923+ )
4924+
4925+
4926+ # ==============================================================================
4927+
4928+
49024929class ElementwiseBitwiseAndStaticShapeModule (torch .nn .Module ):
49034930 def __init__ (self ):
49044931 super ().__init__ ()
@@ -4953,6 +4980,33 @@ def ElementwiseBitwiseOrModule_basic(module, tu: TestUtils):
49534980# ==============================================================================
49544981
49554982
4983+ class ElementwiseBitwiseOrBoolModule (torch .nn .Module ):
4984+ def __init__ (self ):
4985+ super ().__init__ ()
4986+
4987+ @export
4988+ @annotate_args (
4989+ [
4990+ None ,
4991+ ([- 1 , - 1 ], torch .bool , True ),
4992+ ([- 1 , - 1 ], torch .bool , True ),
4993+ ]
4994+ )
4995+ def forward (self , x , y ):
4996+ return torch .bitwise_or (x , y )
4997+
4998+
4999+ @register_test_case (module_factory = lambda : ElementwiseBitwiseOrBoolModule ())
5000+ def ElementwiseBitwiseOrBoolModule_basic (module , tu : TestUtils ):
5001+ module .forward (
5002+ tu .randint (3 , 4 , low = 0 , high = 2 ).to (torch .bool ),
5003+ tu .randint (3 , 4 , low = 0 , high = 2 ).to (torch .bool ),
5004+ )
5005+
5006+
5007+ # ==============================================================================
5008+
5009+
49565010class ElementwiseBitwiseOrStaticShapeModule (torch .nn .Module ):
49575011 def __init__ (self ):
49585012 super ().__init__ ()
@@ -5102,6 +5156,33 @@ def ElementwiseBitwiseXorModule_basic(module, tu: TestUtils):
51025156# ==============================================================================
51035157
51045158
5159+ class ElementwiseBitwiseXorBoolModule (torch .nn .Module ):
5160+ def __init__ (self ):
5161+ super ().__init__ ()
5162+
5163+ @export
5164+ @annotate_args (
5165+ [
5166+ None ,
5167+ ([- 1 , - 1 ], torch .bool , True ),
5168+ ([- 1 , - 1 ], torch .bool , True ),
5169+ ]
5170+ )
5171+ def forward (self , x , y ):
5172+ return torch .bitwise_xor (x , y )
5173+
5174+
5175+ @register_test_case (module_factory = lambda : ElementwiseBitwiseXorBoolModule ())
5176+ def ElementwiseBitwiseXorBoolModule_basic (module , tu : TestUtils ):
5177+ module .forward (
5178+ tu .randint (3 , 4 , low = 0 , high = 2 ).to (torch .bool ),
5179+ tu .randint (3 , 4 , low = 0 , high = 2 ).to (torch .bool ),
5180+ )
5181+
5182+
5183+ # ==============================================================================
5184+
5185+
51055186class ElementwiseBitwiseXorStaticShapeModule (torch .nn .Module ):
51065187 def __init__ (self ):
51075188 super ().__init__ ()
0 commit comments