Skip to content

Commit d2d6440

Browse files
committed
#3364: Use ttl.tensor.Tensor(torch_tensor, data_type) in pytest script
1 parent 6a37935 commit d2d6440

File tree

1 file changed

+5
-32
lines changed

1 file changed

+5
-32
lines changed

tests/tt_eager/python_api_testing/unit_testing/test_moreh_matmul.py

Lines changed: 5 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -22,19 +22,9 @@ def get_tensors(input_shape, other_shape, output_shape, require_input_grad, requ
2222
input = torch.randint(-2, 3, input_shape, dtype=cpu_dtype)
2323
other = torch.randint(-2, 3, other_shape, dtype=cpu_dtype)
2424

25-
tt_input = (
26-
ttl.tensor.Tensor(input.reshape(-1).tolist(), input_shape, npu_dtype, cpu_layout)
27-
.pad_to_tile(1)
28-
.to(npu_layout)
29-
.to(device)
30-
)
25+
tt_input = ttl.tensor.Tensor(input, npu_dtype).pad_to_tile(1).to(npu_layout).to(device)
3126

32-
tt_other = (
33-
ttl.tensor.Tensor(other.reshape(-1).tolist(), other_shape, npu_dtype, cpu_layout)
34-
.pad_to_tile(float("nan"))
35-
.to(npu_layout)
36-
.to(device)
37-
)
27+
tt_other = ttl.tensor.Tensor(other, npu_dtype).pad_to_tile(float("nan")).to(npu_layout).to(device)
3828

3929
torch_input = input.reshape(-1) if is_1d else input
4030
torch_other = other.reshape(-1) if is_1d else other
@@ -43,36 +33,19 @@ def get_tensors(input_shape, other_shape, output_shape, require_input_grad, requ
4333
output_grad = tt_output_grad = torch_output_grad = tt_input_grad = tt_other_grad = None
4434
if require_input_grad or require_other_grad:
4535
output_grad = torch.randint(-2, 3, output_shape, dtype=cpu_dtype)
46-
tt_output_grad = (
47-
ttl.tensor.Tensor(output_grad.reshape(-1).tolist(), output_shape, npu_dtype, cpu_layout)
48-
.pad_to_tile(float("nan"))
49-
.to(npu_layout)
50-
.to(device)
51-
)
36+
tt_output_grad = ttl.tensor.Tensor(output_grad, npu_dtype).pad_to_tile(float("nan")).to(npu_layout).to(device)
5237
torch_output_grad = output_grad[0][0][0][0] if is_1d else output_grad
5338

5439
if require_input_grad:
5540
input_grad = torch.full(input_shape, float("nan"), dtype=cpu_dtype)
56-
tt_input_grad = (
57-
ttl.tensor.Tensor(
58-
input_grad.flatten().tolist(),
59-
input_shape,
60-
npu_dtype,
61-
cpu_layout,
62-
)
63-
.pad_to_tile(float("nan"))
64-
.to(npu_layout)
65-
.to(device)
66-
)
41+
tt_input_grad = ttl.tensor.Tensor(input_grad, npu_dtype).pad_to_tile(float("nan")).to(npu_layout).to(device)
6742

6843
if require_other_grad:
6944
other_grad = torch.full(other_shape, float("nan"), dtype=cpu_dtype)
7045
tt_other_grad = (
7146
ttl.tensor.Tensor(
72-
other_grad.flatten().tolist(),
73-
other_shape,
47+
other_grad,
7448
npu_dtype,
75-
cpu_layout,
7649
)
7750
.pad_to_tile(float("nan"))
7851
.to(npu_layout)

0 commit comments

Comments
 (0)