
Commit e420172

Numerical stability issues in gradient for log1mexp (#18)
1 parent a90fb57 · commit e420172

2 files changed: +31 -6 lines


src/klay/torch/utils.py

Lines changed: 5 additions & 6 deletions
@@ -5,18 +5,17 @@
 CUTOFF = -math.log(2)


-def log1mexp(x, eps):
+def log1mexp(x, eps=10e-12):
     """
     Numerically accurate evaluation of log(1 - exp(x)) for x < 0.
     See [Maechler2012accurate]_ for details.
     https://github.com/pytorch/pytorch/issues/39242
     """
     mask = CUTOFF < x  # x < 0
-    return torch.where(
-        mask,
-        (-x.expm1() + eps).log(),
-        (-x.exp() + eps).log1p(),
-    )
+    out = torch.empty_like(x)
+    out[mask] = (-x[mask].expm1() + eps).log()
+    out[~mask] = (-x[~mask].exp() + eps).log1p()
+    return out


 def negate_real(x, eps):
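Why this helps: torch.where evaluates and differentiates both branches for every element. For an input like -1e-10, float32 exp(x) rounds to exactly 1.0, so the unselected log1p branch computes log1p(-1.0) = -inf; during backward, autograd multiplies that branch's infinite local gradient by the zero coming from the mask, and 0 * inf = NaN, which contaminates x.grad even though the forward value was finite (this is the linked PyTorch issue 39242). Writing each branch into out via boolean indexing means the inappropriate branch is never evaluated at all. The CUTOFF = -log(2) split follows [Maechler2012accurate]_: expm1 is the accurate form close to zero, log1p far from it. A minimal sketch of the pitfall (a standalone repro, not code from this commit; eps dropped for brevity):

import torch

# x so close to 0 that float32 exp(x) rounds to exactly 1.0: the
# unselected branch then computes log1p(-1.0) = -inf.
x = torch.tensor(-1e-10, requires_grad=True)
mask = x > -0.6931  # roughly -math.log(2), the CUTOFF above
out = torch.where(mask, (-x.expm1()).log(), (-x.exp()).log1p())
out.backward()
print(out.item())    # finite: forward reads only the selected branch
print(x.grad.item()) # nan: backward still differentiates the -inf branch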

tests/test_gradient_stability.py

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+"""
+Test for numerical stability issues in the backward pass with log probabilities.
+
+This test identifies cases where the forward pass produces finite values but
+the backward pass introduces NaNs, particularly when inputs are close to -inf
+or when log probabilities approach 0 (probability = 1).
+"""
+import torch
+
+from klay.torch.utils import log1mexp
+
+
+def test_log1mexp_gradient_stability():
+    test_cases = [-1e-10, -0.01, -0.1, -1.0, -10.0, -100.0, -1000.0]
+
+    for x in test_cases:
+        x = torch.tensor(x, dtype=torch.float32).requires_grad_(True)
+        out = log1mexp(x)
+
+        assert torch.isfinite(out), f"Output is not finite for {x}."
+        out.backward()
+        assert torch.isfinite(x.grad), f"Gradient is not finite for {x}."
+
+
+if __name__ == "__main__":
+    test_log1mexp_gradient_stability()
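A quick end-to-end check of the patched function at the two extremes the test exercises (a standalone snippet, assuming klay is importable; the full test runs with pytest tests/test_gradient_stability.py):

import torch
from klay.torch.utils import log1mexp

# Boolean-mask assignment evaluates only the appropriate branch per
# element, so neither the output nor the gradient picks up inf or NaN.
x = torch.tensor([-1e-10, -1000.0], requires_grad=True)
log1mexp(x).sum().backward()
print(x.grad)  # both entries finite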
