File tree Expand file tree Collapse file tree 2 files changed +8
-7
lines changed
experiments/visual_sudoku Expand file tree Collapse file tree 2 files changed +8
-7
lines changed Original file line number Diff line number Diff line change @@ -96,7 +96,7 @@ def get_circuit(grid_size: int):
9696
9797
9898def nll_loss (preds , targets ):
99- neg_preds = klay .torch .log1mexp (preds , eps = 1e-7 )
99+ neg_preds = klay .torch .utils . log1mexp (preds )
100100 nll = - torch .where (targets , preds , neg_preds )
101101 return nll .mean ()
102102
Original file line number Diff line number Diff line change 55CUTOFF = - math .log (2 )
66
77
8- def log1mexp (x , eps = 10e -12 ):
8+ def log1mexp (x , eps = 1e -12 ):
99 """
1010 Numerically accurate evaluation of log(1 - exp(x)) for x < 0.
1111 See [Maechler2012accurate]_ for details.
1212 https://github.com/pytorch/pytorch/issues/39242
1313 """
14- mask = CUTOFF < x # x < 0
15- out = torch .empty_like (x )
16- out [mask ] = (- x [mask ].expm1 () + eps ).log ()
17- out [~ mask ] = (- x [~ mask ].exp () + eps ).log1p ()
18- return out
14+ mask = CUTOFF < x
15+ return torch .where (
16+ mask ,
17+ (- x .clamp (min = CUTOFF ).expm1 () + eps ).log (),
18+ (- x .clamp (max = CUTOFF ).exp () + eps ).log1p ()
19+ )
1920
2021
2122def negate_real (x , eps ):
You can’t perform that action at this time.
0 commit comments