@@ -72,28 +72,28 @@ def test_three_interpolate(dtype, device):
72
72
], [
73
73
2.2060e-01 , 3.4110e-01 , 3.4110e-01 , 2.2060e-01 , 2.2060e-01 , 2.1380e-01
74
74
]],
75
- [[
76
- 8.1773e-01 , 9.5440e-01 , 2.4532e+00 ,
77
- 8.1773e-01 , 8.1773e-01 , 1.1359e+00
78
- ],
79
- [
80
- 8.4689e-01 , 1.9176e+00 , 1.4715e+00 ,
81
- 8.4689e-01 , 8.4689e-01 , 1.3079e+00
82
- ],
83
- [
84
- 6.9473e-01 , 2.7440e-01 , 2.0842e+00 ,
85
- 6.9473e-01 , 6.9473e-01 , 7.8619e-01
86
- ],
87
- [
88
- 7.6789e-01 , 1.5063e+00 , 1.6209e+00 ,
89
- 7.6789e-01 , 7.6789e-01 , 1.1562e+00
90
- ],
91
- [
92
- 3.8760e-01 , 1.0300e-02 , 8.3569e-09 ,
93
- 3.8760e-01 , 3.8760e-01 , 1.9723e-01
94
- ]]],
95
- dtype = dtype ,
96
- device = device )
75
+ [[
76
+ 8.1773e-01 , 9.5440e-01 , 2.4532e+00 ,
77
+ 8.1773e-01 , 8.1773e-01 , 1.1359e+00
78
+ ],
79
+ [
80
+ 8.4689e-01 , 1.9176e+00 , 1.4715e+00 ,
81
+ 8.4689e-01 , 8.4689e-01 , 1.3079e+00
82
+ ],
83
+ [
84
+ 6.9473e-01 , 2.7440e-01 , 2.0842e+00 ,
85
+ 6.9473e-01 , 6.9473e-01 , 7.8619e-01
86
+ ],
87
+ [
88
+ 7.6789e-01 , 1.5063e+00 , 1.6209e+00 ,
89
+ 7.6789e-01 , 7.6789e-01 , 1.1562e+00
90
+ ],
91
+ [
92
+ 3.8760e-01 , 1.0300e-02 , 8.3569e-09 ,
93
+ 3.8760e-01 , 3.8760e-01 , 1.9723e-01
94
+ ]]],
95
+ dtype = dtype ,
96
+ device = device )
97
97
98
98
assert torch .allclose (output , expected_output , 1e-3 , 1e-4 )
99
99
@@ -148,24 +148,16 @@ def torch_type_trans(dtype):
148
148
return np .float64
149
149
150
150
151
@pytest.mark.parametrize('dtype', [torch.half, torch.float])
@pytest.mark.parametrize('device', [
    pytest.param(
        'npu',
        marks=pytest.mark.skipif(
            not IS_NPU_AVAILABLE, reason='requires NPU support'))
])
@pytest.mark.parametrize('shape', [(2, 5, 6, 6), (10, 10, 10, 10),
                                   (20, 21, 13, 4), (2, 10, 2, 18),
                                   (10, 602, 910, 200), (600, 100, 300, 101)])
def test_three_interpolate_npu_dynamic_shape(dtype, device, shape):
    """Check NPU ``three_interpolate`` against the NumPy golden reference
    over a range of dynamic input shapes.

    Args:
        dtype (torch.dtype): dtype of features and weights (half or float).
        device (str): target device; the 'npu' param is skipped when NPU
            support is unavailable.
        shape (tuple): packed problem sizes (bs, cs, ms, ns) -- batch,
            channels, number of source points, number of target points.
    """
    bs = shape[0]
    cs = shape[1]
    # NOTE(review): these two assignments sit in a gap between diff hunks;
    # order inferred from usage (features is (bs, cs, ms), idx/weight are
    # (bs, ns, 3) indexing ms source points) -- TODO confirm against the
    # full file.
    ms = shape[2]
    ns = shape[3]

    # Fix: `.astype` must apply to the sampled array, not the shape tuple.
    features = np.random.uniform(-10.0, 10.0,
                                 (bs, cs, ms)).astype(torch_type_trans(dtype))
    # Fix: `np.random.uniform` has no `dtype` argument; the three nearest
    # source-point indices are integers in [0, ms), so use randint.
    idx = np.random.randint(0, ms, size=(bs, ns, 3), dtype=np.int32)
    # Fix: missing comma made `10.0 (bs, ns, 3)` a call on a float literal.
    weight = np.random.uniform(-10.0, 10.0,
                               (bs, ns, 3)).astype(torch_type_trans(dtype))

    features_npu = torch.tensor(features, dtype=dtype).to(device)
    idx_npu = torch.tensor(idx, dtype=torch.int32).to(device)
    weight_npu = torch.tensor(weight, dtype=dtype).to(device)

    # Golden reference runs on the NumPy inputs; the op runs on device.
    expected_output = three_interpolate_forward_gloden(features, idx, weight)
    output = three_interpolate(features_npu, idx_npu, weight_npu)
    assert np.allclose(output.cpu().numpy(), expected_output, 1e-3, 1e-4)
0 commit comments