Skip to content

Commit d89a2a0

Browse files
committed
fix: torch.logsumexp frontend
1 parent 4fbc67b commit d89a2a0

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

ivy/functional/frontends/torch/reduction_ops.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,12 @@ def dist(input, other, p=2):
9393
return ivy.vector_norm(ivy.subtract(input, other), ord=p)
9494

9595

96+
@with_unsupported_dtypes({"2.2 and below": ("complex",)}, "torch")
9697
@numpy_to_torch_style_args
9798
@to_ivy_arrays_and_back
9899
def logsumexp(input, dim, keepdim=False, *, out=None):
100+
if ivy.is_int_dtype(input):
101+
input = ivy.astype(input, ivy.float32)
99102
c = ivy.max(input, axis=dim, keepdims=True)
100103
if ivy.get_num_dims(c) > 0:
101104
c = ivy.where(ivy.isinf(c), ivy.zeros_like(c), c)

0 commit comments

Comments
 (0)