Skip to content

Commit aadb7fc

Browse files
committed
lint error
1 parent 02b9fc8 commit aadb7fc

File tree

1 file changed

+30
-37
lines changed

1 file changed

+30
-37
lines changed

tests/test_ops/test_three_interpolate.py

Lines changed: 30 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -72,28 +72,28 @@ def test_three_interpolate(dtype, device):
7272
], [
7373
2.2060e-01, 3.4110e-01, 3.4110e-01, 2.2060e-01, 2.2060e-01, 2.1380e-01
7474
]],
75-
[[
76-
8.1773e-01, 9.5440e-01, 2.4532e+00,
77-
8.1773e-01, 8.1773e-01, 1.1359e+00
78-
],
79-
[
80-
8.4689e-01, 1.9176e+00, 1.4715e+00,
81-
8.4689e-01, 8.4689e-01, 1.3079e+00
82-
],
83-
[
84-
6.9473e-01, 2.7440e-01, 2.0842e+00,
85-
6.9473e-01, 6.9473e-01, 7.8619e-01
86-
],
87-
[
88-
7.6789e-01, 1.5063e+00, 1.6209e+00,
89-
7.6789e-01, 7.6789e-01, 1.1562e+00
90-
],
91-
[
92-
3.8760e-01, 1.0300e-02, 8.3569e-09,
93-
3.8760e-01, 3.8760e-01, 1.9723e-01
94-
]]],
95-
dtype=dtype,
96-
device=device)
75+
[[
76+
8.1773e-01, 9.5440e-01, 2.4532e+00,
77+
8.1773e-01, 8.1773e-01, 1.1359e+00
78+
],
79+
[
80+
8.4689e-01, 1.9176e+00, 1.4715e+00,
81+
8.4689e-01, 8.4689e-01, 1.3079e+00
82+
],
83+
[
84+
6.9473e-01, 2.7440e-01, 2.0842e+00,
85+
6.9473e-01, 6.9473e-01, 7.8619e-01
86+
],
87+
[
88+
7.6789e-01, 1.5063e+00, 1.6209e+00,
89+
7.6789e-01, 7.6789e-01, 1.1562e+00
90+
],
91+
[
92+
3.8760e-01, 1.0300e-02, 8.3569e-09,
93+
3.8760e-01, 3.8760e-01, 1.9723e-01
94+
]]],
95+
dtype=dtype,
96+
device=device)
9797

9898
assert torch.allclose(output, expected_output, 1e-3, 1e-4)
9999

@@ -148,24 +148,16 @@ def torch_type_trans(dtype):
148148
return np.float64
149149

150150

151-
@pytest.mark.parametrize('dtype', [
152-
torch.half,
153-
torch.float
154-
])
151+
@pytest.mark.parametrize('dtype', [torch.half, torch.float])
155152
@pytest.mark.parametrize('device', [
156153
pytest.param(
157154
'npu',
158155
marks=pytest.mark.skipif(
159156
not IS_NPU_AVAILABLE, reason='requires NPU support'))
160157
])
161-
@pytest.mark.parametrize('shape', [
162-
(2, 5, 6, 6),
163-
(10, 10, 10, 10),
164-
(20, 21, 13, 4),
165-
(2, 10, 2, 18),
166-
(10, 602, 910, 200),
167-
(600, 100, 300, 101)
168-
])
158+
@pytest.mark.parametrize('shape', [(2, 5, 6, 6), (10, 10, 10, 10),
159+
(20, 21, 13, 4), (2, 10, 2, 18),
160+
(10, 602, 910, 200), (600, 100, 300, 101)])
169161
def test_three_interpolate_npu_dynamic_shape(dtype, device, shape):
170162
bs = shape[0]
171163
cs = shape[1]
@@ -175,13 +167,14 @@ def test_three_interpolate_npu_dynamic_shape(dtype, device, shape):
175167
features = np.random.uniform(-10.0, 10.0,
176168
(bs, cs, ms).astype(torch_type_trans(dtype)))
177169
idx = np.random.uniform(0, ms, size=(bs, ns, 3), dtype=np.int32)
178-
weight = np.random.uniform(-10.0, 10.0 (bs, ns, 3)
179-
).astype(torch_type_trans(dtype))
170+
weight = np.random.uniform(-10.0,
171+
10.0 (bs, ns,
172+
3)).astype(torch_type_trans(dtype))
180173

181174
features_npu = torch.tensor(features, dtype=dtype).to(device)
182175
idx_npu = torch.tensor(idx, dtype=torch.int32).to(device)
183176
weight_npu = torch.tensor(weight, dtype=dtype).to(device)
184177

185178
expected_output = three_interpolate_forward_gloden(features, idx, weight)
186179
output = three_interpolate(features_npu, idx_npu, weight_npu)
187-
assert np.allclose(output.cpu().numpy(), expected_output, 1e-3, 1e-4)
180+
assert np.allclose(output.cpu().numpy(), expected_output, 1e-3, 1e-4)

0 commit comments

Comments
 (0)