Skip to content

Commit 1f6b420

Browse files
committed
log1mexp vmap fix
1 parent 8106f77 commit 1f6b420

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

experiments/visual_sudoku/run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def get_circuit(grid_size: int):
9696

9797

9898
def nll_loss(preds, targets):
99-
neg_preds = klay.torch.log1mexp(preds, eps=1e-7)
99+
neg_preds = klay.torch.utils.log1mexp(preds)
100100
nll = -torch.where(targets, preds, neg_preds)
101101
return nll.mean()
102102

src/klay/torch/utils.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,18 @@
55
CUTOFF = -math.log(2)
66

77

8-
def log1mexp(x, eps=10e-12):
8+
def log1mexp(x, eps=1e-12):
99
"""
1010
Numerically accurate evaluation of log(1 - exp(x)) for x < 0.
1111
See [Maechler2012accurate]_ for details.
1212
https://github.com/pytorch/pytorch/issues/39242
1313
"""
14-
mask = CUTOFF < x # x < 0
15-
out = torch.empty_like(x)
16-
out[mask] = (-x[mask].expm1() + eps).log()
17-
out[~mask] = (-x[~mask].exp() + eps).log1p()
18-
return out
14+
mask = CUTOFF < x
15+
return torch.where(
16+
mask,
17+
(-x.clamp(min=CUTOFF).expm1() + eps).log(),
18+
(-x.clamp(max=CUTOFF).exp() + eps).log1p()
19+
)
1920

2021

2122
def negate_real(x, eps):

0 commit comments

Comments
 (0)