Skip to content

Commit dd7c1c0

Browse files
committed
use smaller tensors for the dtype unittests
1 parent 931194f commit dd7c1c0

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

tests/test_image_interpolation_2d.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,8 @@ def test_insert_into_image_nearest_interp_2d():
161161
[torch.float32, torch.float64, torch.complex64, torch.complex128]
162162
)
163163
def test_insert_into_image_2d_type_consistency(dtype):
164-
image = torch.rand((28, 28), dtype=dtype)
165-
coords = torch.tensor(np.random.uniform(low=0, high=27, size=(3, 4, 2)))
164+
image = torch.rand((4, 4), dtype=dtype)
165+
coords = torch.tensor(np.random.uniform(low=0, high=3, size=(3, 4, 2)))
166166
values = torch.rand(size=(3, 4), dtype=dtype)
167167
# cast the dtype to corresponding float for weights
168168
weights = torch.zeros_like(image, dtype=torch.float64)
@@ -180,8 +180,8 @@ def test_insert_into_image_2d_type_consistency(dtype):
180180

181181

182182
def test_insert_into_image_3d_type_error():
183-
image = torch.rand((28, 28), dtype=torch.complex64)
184-
coords = torch.tensor(np.random.uniform(low=0, high=27, size=(3, 4, 2)))
183+
image = torch.rand((4, 4), dtype=torch.complex64)
184+
coords = torch.tensor(np.random.uniform(low=0, high=3, size=(3, 4, 2)))
185185
values = torch.rand(size=(3, 4), dtype=torch.complex128)
186186
# cast the dtype to corresponding float for weights
187187
with pytest.raises(ValueError):

tests/test_image_interpolation_3d.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,8 @@ def test_insert_multiple_values_into_multichannel_image_2d_nearest():
162162
[torch.float32, torch.float64, torch.complex64, torch.complex128]
163163
)
164164
def test_insert_into_image_2d_type_consistency(dtype):
165-
image = torch.rand((28, 28, 28), dtype=dtype)
166-
coords = torch.tensor(np.random.uniform(low=0, high=27, size=(3, 4, 5, 3)))
165+
image = torch.rand((4, 4, 4), dtype=dtype)
166+
coords = torch.tensor(np.random.uniform(low=0, high=3, size=(3, 4, 5, 3)))
167167
values = torch.rand(size=(3, 4, 5), dtype=dtype)
168168
# cast the dtype to corresponding float for weights
169169
weights = torch.zeros_like(image, dtype=torch.float64)
@@ -181,8 +181,8 @@ def test_insert_into_image_2d_type_consistency(dtype):
181181

182182

183183
def test_insert_into_image_3d_type_error():
184-
image = torch.rand((28, 28, 28), dtype=torch.complex64)
185-
coords = torch.tensor(np.random.uniform(low=0, high=27, size=(3, 4, 5, 3)))
184+
image = torch.rand((4, 4, 4), dtype=torch.complex64)
185+
coords = torch.tensor(np.random.uniform(low=0, high=3, size=(3, 4, 5, 3)))
186186
values = torch.rand(size=(3, 4, 5), dtype=torch.complex128)
187187
# cast the dtype to corresponding float for weights
188188
with pytest.raises(ValueError):

0 commit comments

Comments
 (0)