Skip to content

Commit 143c53b

Browse files
committed
fix: incorrect implementation of ivy.sum with jax backend
1 parent 83f4610 commit 143c53b

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

ivy/functional/backends/jax/statistical.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,9 @@ def sum(
115115
axis = tuple(axis) if isinstance(axis, list) else axis
116116
if ivy.is_bool_dtype(x):
117117
if jax.config.jax_enable_x64:
118-
return jnp.sum(a=x, axis=axis, dtype=ivy.as_native_dtype("int64"), keepdims=keepdims).astype(dtype)
119-
return jnp.sum(a=x, axis=axis, dtype=ivy.as_native_dtype("int32"), keepdims=keepdims).astype(dtype)
118+
dtype = ivy.as_native_dtype("int64")
119+
else:
120+
dtype = ivy.as_native_dtype("int32")
120121
return jnp.sum(a=x, axis=axis, dtype=dtype, keepdims=keepdims)
121122

122123

ivy_tests/test_ivy/test_frontends/test_torch/test_reduction_ops.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -972,6 +972,8 @@ def test_torch_sum(
972972
dtype=castable_dtype,
973973
atol=1e-02,
974974
rtol=1e-02,
975+
test_dtypes=False,
976+
test_values=input_dtype[0] != "bool",
975977
)
976978

977979

0 commit comments

Comments
 (0)