Skip to content

Commit 5cbf371

Browse files
committed
Add failing test
1 parent a90fb57 commit 5cbf371

File tree

1 file changed

+30
-0
lines changed

1 file changed

+30
-0
lines changed

tests/test_gradient_stability.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
"""
2+
Test for numerical stability issues in backward pass with log probabilities.
3+
4+
This test identifies cases where forward pass produces finite values but
5+
backward pass introduces NaNs, particularly when inputs are close to -inf
6+
or when log probabilities approach 0 (probability = 1).
7+
"""
8+
import torch
9+
from klay.torch.utils import log1mexp
10+
11+
12+
def test_log1mexp_gradient_stability():
13+
test_cases = [
14+
torch.tensor([-0.01, -0.1, -1.0, -10.0, -100.0], dtype=torch.float32),
15+
torch.tensor([-1000.0], dtype=torch.float32),
16+
torch.tensor([-1e-10], dtype=torch.float32), # Very close to 0
17+
]
18+
19+
for i, x in enumerate(test_cases):
20+
x_test = x.clone().requires_grad_(True)
21+
22+
23+
output = log1mexp(x_test)
24+
25+
assert torch.isfinite(output).all(), f"Output should be finite for case {i + 1}"
26+
27+
loss = output.sum()
28+
loss.backward()
29+
30+
assert torch.isfinite(x_test.grad).all(), f"Gradient should be finite for case {i + 1}"

0 commit comments

Comments
 (0)