Skip to content

Commit de5baf7

Browse files
committed
Added fix. Does not seem sensitive to eps value.
1 parent 5cbf371 commit de5baf7

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

src/klay/torch/utils.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
CUTOFF = -math.log(2)
66

77

8-
def log1mexp(x, eps):
8+
def log1mexp(x, eps=1e-12):
99
"""
1010
Numerically accurate evaluation of log(1 - exp(x)) for x < 0.
1111
See [Maechler2012accurate]_ for details.
@@ -14,11 +14,10 @@ def log1mexp(x, eps):
1414
mask = CUTOFF < x # x < 0
1515
return torch.where(
1616
mask,
17-
(-x.expm1() + eps).log(),
18-
(-x.exp() + eps).log1p(),
17+
(torch.clamp(-x.expm1(), min=eps)).log(),
18+
(torch.clamp(-x.exp(), min=eps)).log1p(),
1919
)
2020

21-
2221
def negate_real(x, eps):
2322
return 1 - x
2423

0 commit comments

Comments
 (0)