@@ -36,6 +36,53 @@ def test__Proxy_set_backend(self):
3636 array = xp .arange (5 )
3737 self .assertIsInstance (array , np .ndarray )
3838
39+ def test__Proxy_get_float_dtype (self ):
40+
41+ from array_api_compat import numpy as apc_np
42+
43+ xp = _config ._Proxy ("numpy" )
44+ xp .set_backend (apc_np )
45+
46+ # Test default float dtype (NumPy)
47+ dtype_default = xp .get_float_dtype ()
48+ self .assertIn (
49+ dtype_default ,
50+ ("float64" , "numpy.float64" ), # API compat may return either
51+ )
52+
53+ # Test explicit float32
54+ dtype_32 = xp .get_float_dtype ("float32" )
55+ self .assertIn (dtype_32 , ("float32" , "numpy.float32" ))
56+
57+ # Test explicit float64
58+ dtype_32 = xp .get_float_dtype ("float64" )
59+ self .assertIn (dtype_32 , ("float64" , "numpy.float64" ))
60+
61+ if _config .TORCH_AVAILABLE :
62+ from array_api_compat import torch as apc_torch
63+
64+ xp .set_backend (apc_torch )
65+
66+ # Test default float dtype (PyTorch)
67+ dtype_default = xp .get_float_dtype ()
68+ self .assertIn (
69+ str (dtype_default ),
70+ ("float32" , "torch.float32" ),
71+ )
72+
73+ # Test explicit float32
74+ dtype_32 = xp .get_float_dtype ("float32" )
75+ self .assertIn (
76+ str (dtype_32 ),
77+ ("float32" , "torch.float32" ),
78+ )
79+
80+ # Switch bact to NumPy.
81+ xp .set_backend (apc_np )
82+
83+ dtype_default = xp .get_float_dtype ()
84+ self .assertIn (dtype_default , ("float64" , "numpy.float64" ))
85+
3986
4087if __name__ == "__main__" :
4188 unittest .main ()
0 commit comments