@@ -22,19 +22,9 @@ def get_tensors(input_shape, other_shape, output_shape, require_input_grad, requ
     input = torch.randint(-2, 3, input_shape, dtype=cpu_dtype)
     other = torch.randint(-2, 3, other_shape, dtype=cpu_dtype)
 
-    tt_input = (
-        ttl.tensor.Tensor(input.reshape(-1).tolist(), input_shape, npu_dtype, cpu_layout)
-        .pad_to_tile(1)
-        .to(npu_layout)
-        .to(device)
-    )
+    tt_input = ttl.tensor.Tensor(input, npu_dtype).pad_to_tile(1).to(npu_layout).to(device)
 
-    tt_other = (
-        ttl.tensor.Tensor(other.reshape(-1).tolist(), other_shape, npu_dtype, cpu_layout)
-        .pad_to_tile(float("nan"))
-        .to(npu_layout)
-        .to(device)
-    )
+    tt_other = ttl.tensor.Tensor(other, npu_dtype).pad_to_tile(float("nan")).to(npu_layout).to(device)
 
     torch_input = input.reshape(-1) if is_1d else input
     torch_other = other.reshape(-1) if is_1d else other
@@ -43,36 +33,19 @@ def get_tensors(input_shape, other_shape, output_shape, require_input_grad, requ
     output_grad = tt_output_grad = torch_output_grad = tt_input_grad = tt_other_grad = None
     if require_input_grad or require_other_grad:
         output_grad = torch.randint(-2, 3, output_shape, dtype=cpu_dtype)
-        tt_output_grad = (
-            ttl.tensor.Tensor(output_grad.reshape(-1).tolist(), output_shape, npu_dtype, cpu_layout)
-            .pad_to_tile(float("nan"))
-            .to(npu_layout)
-            .to(device)
-        )
+        tt_output_grad = ttl.tensor.Tensor(output_grad, npu_dtype).pad_to_tile(float("nan")).to(npu_layout).to(device)
         torch_output_grad = output_grad[0][0][0][0] if is_1d else output_grad
 
     if require_input_grad:
         input_grad = torch.full(input_shape, float("nan"), dtype=cpu_dtype)
-        tt_input_grad = (
-            ttl.tensor.Tensor(
-                input_grad.flatten().tolist(),
-                input_shape,
-                npu_dtype,
-                cpu_layout,
-            )
-            .pad_to_tile(float("nan"))
-            .to(npu_layout)
-            .to(device)
-        )
+        tt_input_grad = ttl.tensor.Tensor(input_grad, npu_dtype).pad_to_tile(float("nan")).to(npu_layout).to(device)
 
     if require_other_grad:
         other_grad = torch.full(other_shape, float("nan"), dtype=cpu_dtype)
         tt_other_grad = (
             ttl.tensor.Tensor(
-                other_grad.flatten().tolist(),
-                other_shape,
+                other_grad,
                 npu_dtype,
-                cpu_layout,
             )
             .pad_to_tile(float("nan"))
             .to(npu_layout)