@@ -31,7 +31,7 @@ def test_add_1D_tensor_and_scalar(device, scalar, size):
3131@pytest .mark .parametrize ("h" , [2 * 32 ])
3232@pytest .mark .parametrize ("w" , [4 * 32 ])
3333def test_add_scalar (device , s , h , w ):
34- torch_input_tensor = torch .rand ((1 , 1 , h , w ), dtype = torch .bfloat16 )
34+ torch_input_tensor = torch .rand ((h , w ), dtype = torch .bfloat16 )
3535 torch_output_tensor = torch_input_tensor + s
3636
3737 input_tensor = ttnn .from_torch (torch_input_tensor )
@@ -49,7 +49,7 @@ def test_add_scalar(device, s, h, w):
4949@pytest .mark .parametrize ("h" , [1 ])
5050@pytest .mark .parametrize ("w" , [4 ])
5151def test_add_scalar_and_alpha (device , alpha , scalar_input_tensor_b , h , w ):
52- torch_input_tensor = torch .rand ((1 , 1 , h , w ), dtype = torch .bfloat16 )
52+ torch_input_tensor = torch .rand ((h , w ), dtype = torch .bfloat16 )
5353 torch_output_tensor = torch .add (torch_input_tensor , scalar_input_tensor_b , alpha = alpha )
5454
5555 input_tensor = ttnn .from_torch (torch_input_tensor )
@@ -65,8 +65,8 @@ def test_add_scalar_and_alpha(device, alpha, scalar_input_tensor_b, h, w):
6565@pytest .mark .parametrize ("h" , [32 ])
6666@pytest .mark .parametrize ("w" , [2 * 32 ])
6767def test_add (device , h , w ):
68- torch_a = torch .rand ((1 , 1 , h , w ), dtype = torch .bfloat16 )
69- torch_b = torch .rand ((1 , 1 , h , w ), dtype = torch .bfloat16 )
68+ torch_a = torch .rand ((h , w ), dtype = torch .bfloat16 )
69+ torch_b = torch .rand ((h , w ), dtype = torch .bfloat16 )
7070 torch_output = torch .add (torch_a , torch_b )
7171
7272 a = ttnn .from_torch (torch_a )
@@ -106,7 +106,7 @@ def test_add_4D(device, n, c, h, w):
106106@pytest .mark .parametrize ("w" , [2 * 32 ])
107107@pytest .mark .parametrize ("scalar" , [0.42 ])
108108def test_add_scalar (device , h , w , scalar ):
109- torch_a = torch .rand ((1 , 1 , h , w ), dtype = torch .bfloat16 )
109+ torch_a = torch .rand ((h , w ), dtype = torch .bfloat16 )
110110 torch_output = scalar + torch_a
111111
112112 a = ttnn .from_torch (torch_a )
0 commit comments