Skip to content

Commit 526a229

Browse files
committed
Create test_pytorch_features.py
1 parent 01f97bd commit 526a229

1 file changed

Lines changed: 94 additions & 0 deletions

File tree

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# pylint: disable=C0115:missing-class-docstring
2+
# pylint: disable=C0116:missing-function-docstring
3+
# pylint: disable=C0103:invalid-name
4+
5+
import unittest
6+
7+
import numpy as np
8+
import torch
9+
10+
from deeptrack.pytorch import features
11+
12+
13+
class TestTorchFeatures(unittest.TestCase):
14+
15+
def test_ToTensor_numpy(self):
16+
f = features.ToTensor()
17+
x = np.ones((4, 5), dtype=np.float32)
18+
y = f(x)
19+
20+
self.assertIsInstance(y, torch.Tensor)
21+
self.assertEqual(tuple(y.shape), (4, 5))
22+
23+
def test_ToTensor_torch_tensor_passthrough(self):
24+
f = features.ToTensor()
25+
x = torch.ones((4, 5), dtype=torch.float32)
26+
y = f(x)
27+
28+
self.assertIsInstance(y, torch.Tensor)
29+
self.assertTrue(torch.equal(x, y))
30+
self.assertEqual(x.dtype, y.dtype)
31+
32+
def test_ToTensor_numpy_negative_stride(self):
33+
f = features.ToTensor()
34+
x = np.arange(12).reshape(3, 4)[:, ::-1]
35+
y = f(x)
36+
37+
self.assertIsInstance(y, torch.Tensor)
38+
self.assertEqual(tuple(y.shape), (3, 4))
39+
40+
def test_ToTensor_scalar_add_dim(self):
41+
f = features.ToTensor(add_dim_to_number=True)
42+
y = f(3.0)
43+
44+
self.assertIsInstance(y, torch.Tensor)
45+
self.assertEqual(tuple(y.shape), (1,))
46+
47+
def test_ToTensor_scalar_no_add_dim(self):
48+
f = features.ToTensor(add_dim_to_number=False)
49+
y = f(3.0)
50+
51+
self.assertIsInstance(y, float)
52+
53+
def test_ToTensor_permute_always(self):
54+
f = features.ToTensor(permute_mode="always")
55+
x = np.zeros((10, 11, 3), dtype=np.float32)
56+
y = f(x)
57+
58+
self.assertEqual(tuple(y.shape), (3, 10, 11))
59+
60+
def test_ToTensor_permute_never(self):
61+
f = features.ToTensor(permute_mode="never")
62+
x = np.zeros((10, 11, 3), dtype=np.float32)
63+
y = f(x)
64+
65+
self.assertEqual(tuple(y.shape), (10, 11, 3))
66+
67+
def test_ToTensor_permute_numpy_only(self):
68+
f = features.ToTensor(permute_mode="numpy")
69+
x_np = np.zeros((10, 11, 3), dtype=np.float32)
70+
y_np = f(x_np)
71+
72+
x_torch = torch.zeros((10, 11, 3), dtype=torch.float32)
73+
y_torch = f(x_torch)
74+
75+
self.assertEqual(tuple(y_np.shape), (3, 10, 11))
76+
self.assertEqual(tuple(y_torch.shape), (10, 11, 3))
77+
78+
def test_ToTensor_permute_numpy_and_not_int(self):
79+
f = features.ToTensor(permute_mode="numpy_and_not_int")
80+
81+
x_float = np.zeros((10, 11, 3), dtype=np.float32)
82+
y_float = f(x_float)
83+
self.assertEqual(tuple(y_float.shape), (3, 10, 11))
84+
85+
x_int = np.zeros((10, 11, 3), dtype=np.int32)
86+
y_int = f(x_int)
87+
self.assertEqual(tuple(y_int.shape), (10, 11, 3))
88+
89+
def test_ToTensor_dtype(self):
90+
f = features.ToTensor(dtype=torch.float64)
91+
x = np.ones((2, 2), dtype=np.float32)
92+
y = f(x)
93+
94+
self.assertEqual(y.dtype, torch.float64)

0 commit comments

Comments
 (0)