Skip to content

Commit cdef4d4

Browse files
authored
Use log1p(x) instead of log(1+x) (#1286)
This function is more accurate than torch.log() for small values of input - https://pytorch.org/docs/stable/generated/torch.log1p.html Found with https://github.com/pytorch-labs/torchfix/
1 parent a308b4e commit cdef4d4

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

mnist_forward_forward/main.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,8 @@ def train(self, x_pos, x_neg):
7272
for i in range(self.num_epochs):
7373
g_pos = self.forward(x_pos).pow(2).mean(1)
7474
g_neg = self.forward(x_neg).pow(2).mean(1)
75-
loss = torch.log(
76-
1
77-
+ torch.exp(
75+
loss = torch.log1p(
76+
torch.exp(
7877
torch.cat([-g_pos + self.threshold, g_neg - self.threshold])
7978
)
8079
).mean()

0 commit comments

Comments
 (0)