
Commit 88be963

#3003: updated ttnn tests

1 parent de44c2c

File tree

10 files changed: +36, -50 lines changed

tests/ttnn/unit_tests/experimental/test_exp.py

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@
 def test_exp(device, h, w):
     torch.manual_seed(0)
 
-    torch_input_tensor = torch.rand((1, 1, h, w), dtype=torch.bfloat16)
+    torch_input_tensor = torch.rand((h, w), dtype=torch.bfloat16)
     torch_output_tensor = torch.exp(torch_input_tensor)
 
     input_tensor = ttnn.from_torch(torch_input_tensor)
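
For context, the removed (1, 1, h, w) wrapper was purely a rank pad; the values are untouched, so the reference side of the test is unaffected. A quick pure-PyTorch check of the equivalence the new 2D input relies on:

import torch

# Padding to rank 4 does not change any values, so exp agrees elementwise
# once the singleton dimensions are stripped again.
x2d = torch.rand((32, 64), dtype=torch.bfloat16)
x4d = x2d.reshape((1, 1, 32, 64))
assert torch.equal(torch.exp(x4d).reshape((32, 64)), torch.exp(x2d))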

tests/ttnn/unit_tests/experimental/test_layer_norm.py

Lines changed: 4 additions & 4 deletions
@@ -18,7 +18,7 @@
 def test_layer_norm(device, h, w):
     torch.manual_seed(0)
 
-    torch_input_tensor = torch.rand((1, 1, h, w), dtype=torch.bfloat16)
+    torch_input_tensor = torch.rand((h, w), dtype=torch.bfloat16)
     torch_output_tensor = torch.nn.functional.layer_norm(torch_input_tensor, normalized_shape=[w])
 
     input_tensor = ttnn.from_torch(torch_input_tensor)
@@ -37,7 +37,7 @@ def test_layer_norm(device, h, w):
 def test_layer_norm_with_weight_and_bias(device, h, w):
     torch.manual_seed(0)
 
-    torch_input_tensor = torch.rand((1, 1, h, w), dtype=torch.bfloat16)
+    torch_input_tensor = torch.rand((h, w), dtype=torch.bfloat16)
     torch_weight = torch.rand((w,), dtype=torch.bfloat16)
     torch_bias = torch.rand((w,), dtype=torch.bfloat16)
     torch_output_tensor = torch.nn.functional.layer_norm(
@@ -66,8 +66,8 @@ def test_layer_norm_with_weight_and_bias(device, h, w):
 def test_layer_norm_with_weight_bias_and_residual_input(device, h, w):
     torch.manual_seed(0)
 
-    torch_input_tensor = torch.rand((1, 1, h, w), dtype=torch.bfloat16)
-    torch_residual_input_tensor = torch.rand((1, 1, h, w), dtype=torch.bfloat16)
+    torch_input_tensor = torch.rand((h, w), dtype=torch.bfloat16)
+    torch_residual_input_tensor = torch.rand((h, w), dtype=torch.bfloat16)
     torch_weight = torch.rand((w,), dtype=torch.bfloat16)
     torch_bias = torch.rand((w,), dtype=torch.bfloat16)
     torch_output_tensor = torch.nn.functional.layer_norm(
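
The reference these tests compare against is plain torch.nn.functional.layer_norm over the last dimension. A self-contained sketch of the weight-and-bias variant's reference computation (shapes chosen arbitrarily):

import torch

h, w = 32, 64
torch_input_tensor = torch.rand((h, w), dtype=torch.bfloat16)
torch_weight = torch.rand((w,), dtype=torch.bfloat16)
torch_bias = torch.rand((w,), dtype=torch.bfloat16)

# Normalize over the trailing w elements, then apply the affine weight and bias.
torch_output_tensor = torch.nn.functional.layer_norm(
    torch_input_tensor, normalized_shape=[w], weight=torch_weight, bias=torch_bias
)
assert torch_output_tensor.shape == (h, w)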

tests/ttnn/unit_tests/test_add.py

Lines changed: 5 additions & 5 deletions
@@ -31,7 +31,7 @@ def test_add_1D_tensor_and_scalar(device, scalar, size):
 @pytest.mark.parametrize("h", [2 * 32])
 @pytest.mark.parametrize("w", [4 * 32])
 def test_add_scalar(device, s, h, w):
-    torch_input_tensor = torch.rand((1, 1, h, w), dtype=torch.bfloat16)
+    torch_input_tensor = torch.rand((h, w), dtype=torch.bfloat16)
     torch_output_tensor = torch_input_tensor + s
 
     input_tensor = ttnn.from_torch(torch_input_tensor)
@@ -49,7 +49,7 @@ def test_add_scalar(device, s, h, w):
 @pytest.mark.parametrize("h", [1])
 @pytest.mark.parametrize("w", [4])
 def test_add_scalar_and_alpha(device, alpha, scalar_input_tensor_b, h, w):
-    torch_input_tensor = torch.rand((1, 1, h, w), dtype=torch.bfloat16)
+    torch_input_tensor = torch.rand((h, w), dtype=torch.bfloat16)
     torch_output_tensor = torch.add(torch_input_tensor, scalar_input_tensor_b, alpha=alpha)
 
     input_tensor = ttnn.from_torch(torch_input_tensor)
@@ -65,8 +65,8 @@ def test_add_scalar_and_alpha(device, alpha, scalar_input_tensor_b, h, w):
 @pytest.mark.parametrize("h", [32])
 @pytest.mark.parametrize("w", [2 * 32])
 def test_add(device, h, w):
-    torch_a = torch.rand((1, 1, h, w), dtype=torch.bfloat16)
-    torch_b = torch.rand((1, 1, h, w), dtype=torch.bfloat16)
+    torch_a = torch.rand((h, w), dtype=torch.bfloat16)
+    torch_b = torch.rand((h, w), dtype=torch.bfloat16)
     torch_output = torch.add(torch_a, torch_b)
 
     a = ttnn.from_torch(torch_a)
@@ -106,7 +106,7 @@ def test_add_4D(device, n, c, h, w):
 @pytest.mark.parametrize("w", [2 * 32])
 @pytest.mark.parametrize("scalar", [0.42])
 def test_add_scalar(device, h, w, scalar):
-    torch_a = torch.rand((1, 1, h, w), dtype=torch.bfloat16)
+    torch_a = torch.rand((h, w), dtype=torch.bfloat16)
     torch_output = scalar + torch_a
 
     a = ttnn.from_torch(torch_a)
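
For reference, the alpha keyword exercised by test_add_scalar_and_alpha scales the second operand before the addition, i.e. torch.add(a, b, alpha=s) computes a + s * b. A quick pure-PyTorch check:

import torch

torch_input_tensor = torch.rand((1, 4), dtype=torch.bfloat16)
out = torch.add(torch_input_tensor, 2.0, alpha=0.5)  # a + 0.5 * 2.0
assert torch.allclose(out, torch_input_tensor + 1.0)
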
Lines changed: 3 additions & 3 deletions
@@ -11,8 +11,8 @@
 
 @pytest.mark.parametrize("h", [32])
 @pytest.mark.parametrize("w", [2 * 32])
-def test_free(device, h, w):
-    torch_input_tensor = torch.rand((1, 1, h, w), dtype=torch.bfloat16)
+def test_deallocate(device, h, w):
+    torch_input_tensor = torch.rand((h, w), dtype=torch.bfloat16)
 
     input_tensor = ttnn.from_torch(torch_input_tensor)
 
@@ -25,7 +25,7 @@ def test_free(device, h, w):
 
     # Create a reference to the same storage by using reshape which will create a new flyweight
     # (If reshape operation changes, then this test might need to be updated)
-    output_tensor_reference = ttnn.reshape(output_tensor, (1, 1, h, w))
+    output_tensor_reference = ttnn.reshape(output_tensor, (h, w))
 
     ttnn.deallocate(output_tensor)
     with pytest.raises(RuntimeError) as exception:
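
A hedged sketch of the renamed test's overall flow, stitched together from the calls visible in this diff; the to_device step and the final read through ttnn.to_torch are assumptions based on the neighboring tests, not lines from this file:

import pytest
import torch
import ttnn

def test_deallocate(device, h=32, w=2 * 32):
    torch_input_tensor = torch.rand((h, w), dtype=torch.bfloat16)
    input_tensor = ttnn.from_torch(torch_input_tensor)
    output_tensor = ttnn.to_device(input_tensor, device)  # assumed step

    # reshape creates a second flyweight over the same device storage...
    output_tensor_reference = ttnn.reshape(output_tensor, (h, w))

    # ...so reading it after deallocate must raise.
    ttnn.deallocate(output_tensor)
    with pytest.raises(RuntimeError):
        ttnn.to_torch(output_tensor_reference)  # assumed read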

tests/ttnn/unit_tests/test_dump_and_load.py

Lines changed: 2 additions & 2 deletions
@@ -16,7 +16,7 @@
 def test_dump_and_load(tmp_path, h, w):
     file_name = tmp_path / pathlib.Path("tensor.bin")
 
-    torch_tensor = torch.rand((1, 1, h, w), dtype=torch.bfloat16)
+    torch_tensor = torch.rand((h, w), dtype=torch.bfloat16)
     tt_tensor = ttnn.from_torch(torch_tensor)
     ttnn.dump_tensor(file_name, tt_tensor)
 
@@ -30,7 +30,7 @@ def test_dump_and_load(tmp_path, h, w):
 def test_dump_and_load_tilized(tmp_path, h, w):
     file_name = tmp_path / pathlib.Path("tensor.bin")
 
-    torch_tensor = torch.rand((1, 1, h, w), dtype=torch.bfloat16)
+    torch_tensor = torch.rand((h, w), dtype=torch.bfloat16)
     tt_tensor = ttnn.from_torch(torch_tensor)
     tt_tensor = ttnn.to_layout(tt_tensor, ttnn.TILE_LAYOUT)
     ttnn.dump_tensor(file_name, tt_tensor)
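
A sketch of the round trip this test exercises. The diff only shows the dump side; ttnn.load_tensor is an assumed counterpart inferred from the test name, not a call confirmed by this page:

import pathlib
import torch
import ttnn

def test_dump_and_load(tmp_path, h=32, w=64):
    file_name = tmp_path / pathlib.Path("tensor.bin")

    torch_tensor = torch.rand((h, w), dtype=torch.bfloat16)
    tt_tensor = ttnn.from_torch(torch_tensor)
    ttnn.dump_tensor(file_name, tt_tensor)

    loaded_tensor = ttnn.load_tensor(file_name)  # assumed counterpart to dump_tensor
    assert torch.allclose(ttnn.to_torch(loaded_tensor), torch_tensor)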

tests/ttnn/unit_tests/test_slicing.py

Lines changed: 0 additions & 24 deletions
This file was deleted.

tests/ttnn/unit_tests/test_softmax.py

Lines changed: 5 additions & 4 deletions
@@ -14,12 +14,13 @@
 
 
 @skip_for_wormhole_b0()
-@pytest.mark.parametrize("h", [32])
-@pytest.mark.parametrize("w", [2 * 32])
-def test_softmax(device, h, w):
+@pytest.mark.parametrize("batch_size", [1, 16])
+@pytest.mark.parametrize("h", [32, 64])
+@pytest.mark.parametrize("w", [32, 64])
+def test_softmax(device, batch_size, h, w):
     torch.manual_seed(0)
 
-    torch_input_tensor = torch_random((1, 16, 4, 4), -10, 10, dtype=torch.bfloat16)
+    torch_input_tensor = torch_random((batch_size, h, w), -10, 10, dtype=torch.bfloat16)
     torch_output_tensor = F.softmax(torch_input_tensor, dim=-1, dtype=torch.bfloat16)
     input_tensor = ttnn.from_torch(torch_input_tensor)
     input_tensor = ttnn.to_device(input_tensor, device)
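
The parametrization now covers batched 3D inputs instead of a single hard-coded (1, 16, 4, 4) shape. As a sanity check of the reference side alone, assuming torch_random(shape, -10, 10, ...) draws uniformly from [-10, 10]:

import torch
import torch.nn.functional as F

batch_size, h, w = 16, 32, 64
torch_input_tensor = torch.empty((batch_size, h, w), dtype=torch.bfloat16).uniform_(-10, 10)
torch_output_tensor = F.softmax(torch_input_tensor, dim=-1, dtype=torch.bfloat16)

# Each row of a softmax sums to 1, up to bfloat16 rounding.
row_sums = torch_output_tensor.float().sum(dim=-1)
assert torch.allclose(row_sums, torch.ones_like(row_sums), atol=0.05)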

tests/ttnn/unit_tests/test_to_and_from_torch.py

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@
 @pytest.mark.parametrize("h", [7])
 @pytest.mark.parametrize("w", [3])
 def test_to_and_from_4D(h, w):
-    torch_input = torch.rand((1, 1, h, w), dtype=torch.bfloat16)
+    torch_input = torch.rand((h, w), dtype=torch.bfloat16)
     tt_output = ttnn.from_torch(torch_input)
     torch_output = ttnn.to_torch(tt_output)
     assert torch.allclose(torch_output, torch_input)

ttnn/core.py

Lines changed: 9 additions & 4 deletions
@@ -1009,7 +1009,7 @@ def ttnn_reshape(ttl_input_tensor, shape):
         ttl_input_tensor, shape
     )
 
-    if len(input_tensor.shape) == 4 and len(shape) == 4:
+    if input_tensor.is_on_device and len(input_tensor.shape) == 4 and len(shape) == 4:
         w, z, y, x = shape
         return Tensor(ttl.tensor.reshape(ttl_input_tensor, w, z, y, x))
     else:
@@ -1063,7 +1063,7 @@ def permute(input_tensor: Tensor, order: Tuple[int, ...]) -> Tensor:
 
     ttl_input_tensor = input_tensor._tensor
 
-    if len(input_tensor.shape) == 4:
+    if input_tensor.is_on_device and len(input_tensor.shape) == 4:
         return Tensor(ttl.tensor.permute(ttl_input_tensor, order))
     else:
@@ -1099,15 +1099,20 @@ def softmax(input_tensor: Tensor, dim: int, memory_config: MemoryConfig = DRAM_M
     """
 
-    rank = len(input_tensor.shape)
+    input_shape = tuple(input_tensor.shape)
+    rank = len(input_shape)
     if dim < 0:
         dim = rank + dim
     if dim != rank - 1:
         raise RuntimeError("Softmax can only operate on the last dimension.")
 
+    input_tensor = _reshape_to_4D(input_tensor)
+
     ttl_input_tensor = input_tensor._tensor
     ttl_output_tensor = ttl.tensor.softmax(ttl_input_tensor, output_mem_config=memory_config)
-    return Tensor(ttl_output_tensor)
+    output_tensor = Tensor(ttl_output_tensor)
+    output_tensor = reshape(output_tensor, input_shape)
+    return output_tensor
 
 
 def embedding(
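
The softmax change is a wrap/unwrap pattern: record the caller's shape, pad the tensor to the rank-4 layout the device kernel expects via _reshape_to_4D, run the kernel, then reshape the result back. The is_on_device guards added to ttnn_reshape and permute presumably keep host tensors on the fallback branch for the same reason. A pure-PyTorch sketch of the pattern, with torch.softmax standing in for ttl.tensor.softmax and a simplified _reshape_to_4D (the real helper lives in ttnn and may differ in details such as error handling):

import torch

def _reshape_to_4D(tensor: torch.Tensor) -> torch.Tensor:
    # Pad the rank by prepending singleton dimensions until the tensor is 4D.
    if tensor.dim() > 4:
        raise RuntimeError("Tensor cannot have more than four dimensions")
    return tensor.reshape((1,) * (4 - tensor.dim()) + tuple(tensor.shape))

def softmax(input_tensor: torch.Tensor, dim: int) -> torch.Tensor:
    input_shape = tuple(input_tensor.shape)
    rank = len(input_shape)
    if dim < 0:
        dim = rank + dim
    if dim != rank - 1:
        raise RuntimeError("Softmax can only operate on the last dimension.")

    # Run the rank-4 kernel, then restore the shape the caller passed in.
    output_tensor = torch.softmax(_reshape_to_4D(input_tensor), dim=-1)
    return output_tensor.reshape(input_shape)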

ttnn/experimental.py

Lines changed: 6 additions & 2 deletions
@@ -15,9 +15,13 @@
 
 
 def exp(input_tensor: Tensor) -> Tensor:
+    original_shape = tuple(input_tensor.shape)
+    input_tensor = _reshape_to_4D(input_tensor)
     ttl_input_tensor = input_tensor._tensor
-    output_tensor = ttl.tensor.exp(ttl_input_tensor)
-    return Tensor(output_tensor)
+    ttl_output_tensor = ttl.tensor.exp(ttl_input_tensor)
+    output_tensor = Tensor(ttl_output_tensor)
+    output_tensor = reshape(output_tensor, original_shape)
+    return output_tensor
 
 
 def gelu(input_tensor: Tensor, fast_and_approx=True) -> Tensor:
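
exp now follows the same save-shape / pad-to-4D / restore recipe as softmax above. Since the other unary ops in this file, such as gelu, share that structure, the recipe factors naturally into a helper. A hypothetical sketch, not code from this commit, again with torch ops standing in for the ttl kernels:

from typing import Callable
import torch

def _with_rank_4(op: Callable[[torch.Tensor], torch.Tensor]) -> Callable[[torch.Tensor], torch.Tensor]:
    def wrapper(input_tensor: torch.Tensor) -> torch.Tensor:
        if input_tensor.dim() > 4:
            raise RuntimeError("Tensor cannot have more than four dimensions")
        original_shape = tuple(input_tensor.shape)
        # Pad to 4D, apply the kernel, then restore the caller's shape.
        padded = input_tensor.reshape((1,) * (4 - input_tensor.dim()) + original_shape)
        return op(padded).reshape(original_shape)
    return wrapper

exp = _with_rank_4(torch.exp)
gelu = _with_rank_4(torch.nn.functional.gelu)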
