Skip to content

Commit d9f0c59

Browse files
committed
log1mexp vmap fix
1 parent e420172 commit d9f0c59

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

src/klay/torch/utils.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,18 @@
55
CUTOFF = -math.log(2)
66

77

8-
def log1mexp(x, eps=10e-12):
8+
def log1mexp(x, eps=1e-12):
99
"""
1010
Numerically accurate evaluation of log(1 - exp(x)) for x < 0.
1111
See [Maechler2012accurate]_ for details.
1212
https://github.com/pytorch/pytorch/issues/39242
1313
"""
1414
mask = CUTOFF < x # x < 0
15-
out = torch.empty_like(x)
16-
out[mask] = (-x[mask].expm1() + eps).log()
17-
out[~mask] = (-x[~mask].exp() + eps).log1p()
18-
return out
15+
return torch.where(
16+
mask,
17+
(-x.clamp(min=CUTOFF).expm1() + eps).log(),
18+
(-x.clamp(max=CUTOFF).exp() + eps).log1p()
19+
)
1920

2021

2122
def negate_real(x, eps):

0 commit comments

Comments
 (0)