@@ -4899,6 +4899,33 @@ def ElementwiseBitwiseAndModule_basic(module, tu: TestUtils):
48994899# ==============================================================================
49004900
49014901
4902+ class ElementwiseBitwiseAndBoolModule (torch .nn .Module ):
4903+ def __init__ (self ):
4904+ super ().__init__ ()
4905+
4906+ @export
4907+ @annotate_args (
4908+ [
4909+ None ,
4910+ ([- 1 , - 1 ], torch .bool , True ),
4911+ ([- 1 , - 1 ], torch .bool , True ),
4912+ ]
4913+ )
4914+ def forward (self , x , y ):
4915+ return torch .bitwise_and (x , y )
4916+
4917+
4918+ @register_test_case (module_factory = lambda : ElementwiseBitwiseAndBoolModule ())
4919+ def ElementwiseBitwiseAndBoolModule_basic (module , tu : TestUtils ):
4920+ module .forward (
4921+ tu .randint (3 , 4 , low = 0 , high = 2 ).to (torch .bool ),
4922+ tu .randint (3 , 4 , low = 0 , high = 2 ).to (torch .bool ),
4923+ )
4924+
4925+
4926+ # ==============================================================================
4927+
4928+
49024929class ElementwiseBitwiseAndStaticShapeModule (torch .nn .Module ):
49034930 def __init__ (self ):
49044931 super ().__init__ ()
@@ -4953,6 +4980,33 @@ def ElementwiseBitwiseOrModule_basic(module, tu: TestUtils):
49534980# ==============================================================================
49544981
49554982
4983+ class ElementwiseBitwiseOrBoolModule (torch .nn .Module ):
4984+ def __init__ (self ):
4985+ super ().__init__ ()
4986+
4987+ @export
4988+ @annotate_args (
4989+ [
4990+ None ,
4991+ ([- 1 , - 1 ], torch .bool , True ),
4992+ ([- 1 , - 1 ], torch .bool , True ),
4993+ ]
4994+ )
4995+ def forward (self , x , y ):
4996+ return torch .bitwise_or (x , y )
4997+
4998+
4999+ @register_test_case (module_factory = lambda : ElementwiseBitwiseOrBoolModule ())
5000+ def ElementwiseBitwiseOrBoolModule_basic (module , tu : TestUtils ):
5001+ module .forward (
5002+ tu .randint (3 , 4 , low = 0 , high = 2 ).to (torch .bool ),
5003+ tu .randint (3 , 4 , low = 0 , high = 2 ).to (torch .bool ),
5004+ )
5005+
5006+
5007+ # ==============================================================================
5008+
5009+
49565010class ElementwiseBitwiseOrStaticShapeModule (torch .nn .Module ):
49575011 def __init__ (self ):
49585012 super ().__init__ ()
0 commit comments