Skip to content

Commit da03617

Browse files
committed
Update test__config.py
1 parent 6371e1d commit da03617

1 file changed

Lines changed: 47 additions & 0 deletions

File tree

deeptrack/tests/backend/test__config.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,53 @@ def test__Proxy_set_backend(self):
3636
array = xp.arange(5)
3737
self.assertIsInstance(array, np.ndarray)
3838

39+
def test__Proxy_get_float_dtype(self):
40+
41+
from array_api_compat import numpy as apc_np
42+
43+
xp = _config._Proxy("numpy")
44+
xp.set_backend(apc_np)
45+
46+
# Test default float dtype (NumPy)
47+
dtype_default = xp.get_float_dtype()
48+
self.assertIn(
49+
dtype_default,
50+
("float64", "numpy.float64"), # API compat may return either
51+
)
52+
53+
# Test explicit float32
54+
dtype_32 = xp.get_float_dtype("float32")
55+
self.assertIn(dtype_32, ("float32", "numpy.float32"))
56+
57+
# Test explicit float64
58+
dtype_32 = xp.get_float_dtype("float64")
59+
self.assertIn(dtype_32, ("float64", "numpy.float64"))
60+
61+
if _config.TORCH_AVAILABLE:
62+
from array_api_compat import torch as apc_torch
63+
64+
xp.set_backend(apc_torch)
65+
66+
# Test default float dtype (PyTorch)
67+
dtype_default = xp.get_float_dtype()
68+
self.assertIn(
69+
str(dtype_default),
70+
("float32", "torch.float32"),
71+
)
72+
73+
# Test explicit float32
74+
dtype_32 = xp.get_float_dtype("float32")
75+
self.assertIn(
76+
str(dtype_32),
77+
("float32", "torch.float32"),
78+
)
79+
80+
# Switch bact to NumPy.
81+
xp.set_backend(apc_np)
82+
83+
dtype_default = xp.get_float_dtype()
84+
self.assertIn(dtype_default, ("float64", "numpy.float64"))
85+
3986

4087
if __name__ == "__main__":
4188
unittest.main()

0 commit comments

Comments
 (0)