Skip to content

Commit 2c3e8bf

Browse files
committed
fix: support int inputs to ivy.cos
1 parent d6c4d78 commit 2c3e8bf

File tree

4 files changed

+7
-2
lines changed

4 files changed

+7
-2
lines changed

ivy/functional/backends/jax/elementwise.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,8 @@ def ceil(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
152152

153153

154154
def cos(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
155+
if ivy.is_int_dtype(x.dtype):
156+
x = jnp.astype(x, jnp.float32)
155157
return jnp.cos(x)
156158

157159

ivy/functional/backends/numpy/elementwise.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,8 @@ def ceil(x: np.ndarray, /, *, out: Optional[np.ndarray] = None) -> np.ndarray:
210210

211211
@_scalar_output_to_0d_array
212212
def cos(x: np.ndarray, /, *, out: Optional[np.ndarray] = None) -> np.ndarray:
213+
if ivy.is_int_dtype(x.dtype):
214+
x = np.astype(x, np.float32)
213215
return np.cos(x, out=out)
214216

215217

ivy/functional/backends/tensorflow/elementwise.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,13 +202,14 @@ def ceil(
202202
return tf.math.ceil(x)
203203

204204

205-
@with_unsupported_dtypes({"2.15.0 and below": ("integer",)}, backend_version)
206205
def cos(
207206
x: Union[tf.Tensor, tf.Variable],
208207
/,
209208
*,
210209
out: Optional[Union[tf.Tensor, tf.Variable]] = None,
211210
) -> Union[tf.Tensor, tf.Variable]:
211+
if ivy.is_int_dtype(x.dtype):
212+
x = tf.cast(x, tf.float32)
212213
return tf.cos(x)
213214

214215

ivy_tests/test_ivy/test_frontends/test_torch/test_pointwise_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1027,7 +1027,7 @@ def test_torch_copysign(
10271027
@handle_frontend_test(
10281028
fn_tree="torch.cos",
10291029
dtype_and_x=helpers.dtype_and_values(
1030-
available_dtypes=helpers.get_dtypes("float"),
1030+
available_dtypes=helpers.get_dtypes("numeric"),
10311031
),
10321032
)
10331033
def test_torch_cos(

0 commit comments

Comments
 (0)