diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_linalg.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_linalg.py index fb081e1cd406a..8657cc660b96c 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_linalg.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_linalg.py @@ -899,7 +899,7 @@ def test_jax_svd( if compute_uv: with BackendHandler.update_backend(backend_fw) as ivy_backend: - ret = [ivy_backend.to_numpy(x) for x in ret] + ret = [ivy_backend.to_numpy(x).astype(np.float64) for x in ret] frontend_ret = [np.asarray(x) for x in frontend_ret] u, s, vh = ret @@ -915,10 +915,11 @@ def test_jax_svd( ) else: with BackendHandler.update_backend(backend_fw) as ivy_backend: - ret = ivy_backend.to_numpy(ret) + ret = ivy_backend.to_numpy(ret).astype(np.float64) + frontend_ret = np.asarray(frontend_ret) assert_all_close( ret_np=ret, - ret_from_gt_np=np.asarray(frontend_ret[0]), + ret_from_gt_np=frontend_ret, rtol=1e-2, atol=1e-2, backend=backend_fw,