Skip to content

Commit 6ff2026

Browse files
committed
update
1 parent a8bf8b7 commit 6ff2026

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

torch_scatter/composite/logsumexp.py

+12-9
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@
55
from torch_scatter.utils import broadcast
66

77

8-
def scatter_logsumexp(src: torch.Tensor,
9-
index: torch.Tensor,
10-
dim: int = -1,
11-
out: Optional[torch.Tensor] = None,
12-
dim_size: Optional[int] = None,
13-
eps: float = 1e-12) -> torch.Tensor:
8+
def scatter_logsumexp(
9+
src: torch.Tensor,
10+
index: torch.Tensor,
11+
dim: int = -1,
12+
out: Optional[torch.Tensor] = None,
13+
dim_size: Optional[int] = None,
14+
eps: float = 1e-12,
15+
) -> torch.Tensor:
1416
if not torch.is_floating_point(src):
1517
raise ValueError('`scatter_logsumexp` can only be computed over '
1618
'tensors with floating point data types.')
@@ -48,6 +50,7 @@ def scatter_logsumexp(src: torch.Tensor,
4850

4951
if orig_out is None:
5052
return out.nan_to_num_(neginf=0.0)
51-
else:
52-
mask = ~out.isfinite()
53-
out[mask] = orig_out[mask]
53+
54+
mask = ~out.isfinite()
55+
out[mask] = orig_out[mask]
56+
return out

0 commit comments

Comments
 (0)