We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent e420172 commit d9f0c59Copy full SHA for d9f0c59
src/klay/torch/utils.py
@@ -5,17 +5,18 @@
5
CUTOFF = -math.log(2)
6
7
8
-def log1mexp(x, eps=10e-12):
+def log1mexp(x, eps=1e-12):
9
"""
10
Numerically accurate evaluation of log(1 - exp(x)) for x < 0.
11
See [Maechler2012accurate]_ for details.
12
https://github.com/pytorch/pytorch/issues/39242
13
14
mask = CUTOFF < x # x < 0
15
- out = torch.empty_like(x)
16
- out[mask] = (-x[mask].expm1() + eps).log()
17
- out[~mask] = (-x[~mask].exp() + eps).log1p()
18
- return out
+ return torch.where(
+ mask,
+ (-x.clamp(min=CUTOFF).expm1() + eps).log(),
+ (-x.clamp(max=CUTOFF).exp() + eps).log1p()
19
+ )
20
21
22
def negate_real(x, eps):
0 commit comments