Commit 738d5a6

Add precision-dependent epsilon configuration
- Add EPS global constant for numerical stability operations
- Add DEFAULT_EPS_VALUES_PROB and DEFAULT_EPS_VALUES_LOGPROB dicts with precision-specific defaults (float16, bfloat16, float32, float64)
- Add set_eps(eps) function to configure epsilon globally
- Remove eps parameters from all functions (log1mexp, negate_real, layers, forward methods); use the global EPS constant instead
- Export set_eps, EPS, and default value dicts from klay.torch
1 parent 1bfbeb5 commit 738d5a6
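
As a quick illustration of the new configuration API, here is a minimal usage sketch based on the commit message above (it assumes klay is installed and that `klay.torch` exports these names, per the last bullet):

```python
import torch
from klay.torch import set_eps, DEFAULT_EPS_VALUES_PROB, DEFAULT_EPS_VALUES_LOGPROB

# Pick an epsilon floor matched to the working precision and value space;
# probabilities tolerate a coarser floor than log-probabilities.
set_eps(DEFAULT_EPS_VALUES_LOGPROB[torch.float32])  # 1e-16, the built-in default

# Half precision needs a much larger floor:
# set_eps(DEFAULT_EPS_VALUES_PROB[torch.float16])   # 1e-4
```

Note that `set_eps` rebinds a module-level global, so it should be called before the value is read; names already bound via `from ... import EPS` keep the value they captured at import time.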

3 files changed (+58, -27 lines)

src/klay/torch/__init__.py (14 additions, 8 deletions)

```diff
@@ -2,7 +2,13 @@
 from torch import nn
 
 from .layers import ProbabilisticCircuitLayer, get_semiring
-from .utils import unroll_ixs
+from .utils import (
+    unroll_ixs,
+    set_eps,
+    EPS,
+    DEFAULT_EPS_VALUES_PROB,
+    DEFAULT_EPS_VALUES_LOGPROB,
+)
 
 
 def _create_layers(sum_layer, prod_layer, ixs_in, ixs_out):
@@ -24,13 +30,13 @@ def __init__(self, ixs_in, ixs_out, semiring='real'):
         get_semiring(semiring, self.is_probabilistic())
         self.layers = _create_layers(self.sum_layer, self.prod_layer, ixs_in, ixs_out)
 
-    def forward(self, x_pos, x_neg=None, eps=0):
-        x = self.encode_input(x_pos, x_neg, eps)
+    def forward(self, x_pos, x_neg=None):
+        x = self.encode_input(x_pos, x_neg)
         return self.layers(x)
 
-    def encode_input(self, pos, neg, eps):
+    def encode_input(self, pos, neg):
         if neg is None:
-            neg = self.negate(pos, eps)
+            neg = self.negate(pos)
         x = torch.stack([pos, neg], dim=1).flatten()
         units = torch.tensor([self.zero, self.one], dtype=torch.float32, device=pos.device)
         return torch.cat([units, x])
@@ -41,14 +47,14 @@ def sparsity(self, nb_vars: int) -> float:
         dense_params = sum(layer_widths[i] * layer_widths[i + 1] for i in range(len(layer_widths) - 1))
         return sparse_params / dense_params
 
-    def to_pc(self, x_pos, x_neg=None, eps=0):
+    def to_pc(self, x_pos, x_neg=None):
         """ Converts the circuit into a probabilistic circuit."""
         assert self.semiring == "log" or self.semiring == "real"
         pc = ProbabilisticCircuitModule([], [], self.semiring)
         print("Making PC", pc.sum_layer, pc.sum_layer)
         layers = []
 
-        x = self.encode_input(x_pos, x_neg, eps)
+        x = self.encode_input(x_pos, x_neg)
         for i, layer in enumerate(self.layers):
             if isinstance(layer, self.sum_layer):
                 new_layer = pc.sum_layer(layer.ix_in, layer.ix_out)
@@ -76,7 +82,7 @@ def sample(self):
         return y[2::2]
 
     def condition(self, x_pos, x_neg):
-        x = self.encode_input(x_pos, x_neg, None)
+        x = self.encode_input(x_pos, x_neg)
         for layer in self.layers:
             x = layer.condition(x) \
                 if isinstance(layer, ProbabilisticCircuitLayer) \
```
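
To make the `encode_input` layout concrete: it interleaves positive and negative literal values and prepends the semiring's unit constants. A standalone sketch for the 'real' semiring (the 0./1. units and complement negation are assumptions for that semiring, matching `negate_real` in utils.py):

```python
import torch

pos = torch.tensor([0.9, 0.2])                # per-variable P(x_i = 1)
neg = 1 - pos                                 # negate_real in the 'real' semiring
x = torch.stack([pos, neg], dim=1).flatten()  # interleaved: [p0, n0, p1, n1]
units = torch.tensor([0., 1.])                # the semiring's zero and one
print(torch.cat([units, x]))                  # -> [0.0, 1.0, 0.9, 0.1, 0.2, 0.8]
```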

src/klay/torch/layers.py (15 additions, 15 deletions)

```diff
@@ -1,7 +1,7 @@
 import torch
 from torch import nn
 
-from .utils import negate_real, log1mexp
+from .utils import negate_real, log1mexp, EPS
 
 
 class CircuitLayer(nn.Module):
@@ -14,7 +14,7 @@ def __init__(self, ix_in, ix_out):
 
     def _scatter_forward(self, x: torch.Tensor, reduce: str, **kwargs):
         if reduce == "logsumexp":
-            return self._scatter_logsumexp_forward(x, **kwargs)
+            return self._scatter_logsumexp_forward(x)
         output = torch.empty(self.out_shape, dtype=x.dtype, device=x.device)
         output = torch.scatter_reduce(output, 0, index=self.ix_out, src=x, reduce=reduce, include_self=False)
         return output
@@ -31,9 +31,9 @@ def _safe_exp(self, x: torch.Tensor):
         x.nan_to_num_(nan=0., posinf=float('inf'), neginf=float('-inf'))
         return torch.exp(x), max_output
 
-    def _scatter_logsumexp_forward(self, x: torch.Tensor, eps: float):
+    def _scatter_logsumexp_forward(self, x: torch.Tensor):
         x, max_output = self._safe_exp(x)
-        output = torch.full(self.out_shape, eps, dtype=x.dtype, device=x.device)
+        output = torch.full(self.out_shape, EPS, dtype=x.dtype, device=x.device)
         output = torch.scatter_add(output, 0, index=self.ix_out, src=x)
         output = torch.log(output) + max_output
         return output
```
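
`_scatter_logsumexp_forward` is where the global EPS now enters: it floors the scatter-sum accumulator so the final log never sees an exact zero. A self-contained sketch of the same max-shifted pattern (`scatter_logsumexp` and its flat-index signature are illustrative, not klay's API; klay splits the max-shift into `_safe_exp`):

```python
import torch

EPS = 1e-16  # stand-in for the global constant in klay.torch.utils

def scatter_logsumexp(src: torch.Tensor, index: torch.Tensor, out_size: int) -> torch.Tensor:
    # 1. Per-output maximum, so exp() below cannot overflow.
    max_out = torch.full((out_size,), float("-inf"))
    max_out = torch.scatter_reduce(max_out, 0, index=index, src=src, reduce="amax", include_self=True)
    # 2. Scatter-add the shifted exponentials into an EPS-floored accumulator.
    acc = torch.full((out_size,), EPS)
    acc = torch.scatter_add(acc, 0, index=index, src=torch.exp(src - max_out[index]))
    # 3. Undo the shift in log space.
    return torch.log(acc) + max_out

src = torch.tensor([0.1, 0.2, 0.3, 0.4]).log()
index = torch.tensor([0, 0, 1, 1])  # edges 0-1 feed node 0; edges 2-3 feed node 1
print(scatter_logsumexp(src, index, 2).exp())  # -> approximately [0.3, 0.7]
```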
```diff
@@ -63,8 +63,8 @@ def forward(self, x):
 
 
 class LogSumLayer(CircuitLayer):
-    def forward(self, x, eps=10e-16):
-        return self._scatter_forward(x[self.ix_in], "logsumexp", eps=eps)
+    def forward(self, x):
+        return self._scatter_forward(x[self.ix_in], "logsumexp")
 
 
 class ProbabilisticCircuitLayer(CircuitLayer):
@@ -79,15 +79,15 @@ def get_edge_weights(self):
 
     def renorm_weights(self, x):
         with torch.no_grad():
-            self.weights.data = self.get_log_edge_weights(0) + x
+            self.weights.data = self.get_log_edge_weights() + x
 
-    def get_log_edge_weights(self, eps):
-        norm = self._scatter_logsumexp_forward(self.weights, eps)
+    def get_log_edge_weights(self):
+        norm = self._scatter_logsumexp_forward(self.weights)
         return self.weights - norm[self.ix_out]
 
-    def sample(self, y, eps=10e-16):
-        weights = self.get_log_edge_weights(eps)
-        noise = -(-torch.log(torch.rand_like(weights) + eps) + eps).log()
+    def sample(self, y):
+        weights = self.get_log_edge_weights()
+        noise = -(-torch.log(torch.rand_like(weights) + EPS) + EPS).log()
         gumbels = weights + noise
         samples = self._scatter_forward(gumbels, "amax")
         samples = samples[self.ix_out] == gumbels
@@ -107,9 +107,9 @@ def condition(self, x):
 
 
 class ProbabilisticLogSumLayer(ProbabilisticCircuitLayer):
-    def forward(self, x, eps=10e-16):
-        x = self.get_log_edge_weights(eps) + x[self.ix_in]
-        return self._scatter_logsumexp_forward(x, eps)
+    def forward(self, x):
+        x = self.get_log_edge_weights() + x[self.ix_in]
+        return self._scatter_logsumexp_forward(x)
 
     def condition(self, x):
         y = self.forward(x)
```
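
The `sample` method above uses the Gumbel-max trick: adding Gumbel(0, 1) noise, g = -log(-log u), to the log edge weights and taking a per-node `amax` draws an edge with probability proportional to its weight. A standalone check of that identity, with the same EPS-guarded noise as in the diff (the 4-edge weight vector is illustrative):

```python
import torch

EPS = 1e-16
torch.manual_seed(0)

log_w = torch.log(torch.tensor([0.1, 0.2, 0.3, 0.4]))  # one sum node, four edges

u = torch.rand(20_000, 4)
noise = -(-torch.log(u + EPS) + EPS).log()  # Gumbel(0, 1), as in sample()
winners = (log_w + noise).argmax(dim=1)     # winning edge per draw

print(torch.bincount(winners, minlength=4) / 20_000)  # -> approximately [0.1, 0.2, 0.3, 0.4]
```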

src/klay/torch/utils.py (29 additions, 4 deletions)

```diff
@@ -2,10 +2,35 @@
 
 import torch
 
+
+# Default epsilon values for different precisions
+DEFAULT_EPS_VALUES_PROB = {
+    torch.float16: 1e-4,
+    torch.bfloat16: 1e-3,
+    torch.float32: 1e-8,
+    torch.float64: 1e-15,
+}
+
+DEFAULT_EPS_VALUES_LOGPROB = {
+    torch.float16: 1e-4,
+    torch.bfloat16: 1e-3,
+    torch.float32: 1e-16,
+    torch.float64: 1e-30,
+}
+
+# Global epsilon constant - used for all numerical stability operations
+EPS = 1e-16
+
 CUTOFF = -math.log(2)
 
 
-def log1mexp(x, eps):
+def set_eps(eps: float):
+    """Set global epsilon value for numerical stability in operations."""
+    global EPS
+    EPS = eps
+
+
+def log1mexp(x):
     """
     Numerically accurate evaluation of log(1 - exp(x)) for x < 0.
     See [Maechler2012accurate]_ for details.
@@ -14,12 +39,12 @@ def log1mexp(x, eps):
     mask = CUTOFF < x  # x < 0
     return torch.where(
         mask,
-        (-x.expm1() + eps).log(),
-        (-x.exp() + eps).log1p(),
+        (-x.expm1() + EPS).log(),
+        (-x.exp() + EPS).log1p(),
     )
 
 
-def negate_real(x, eps):
+def negate_real(x):
     return 1 - x
 
 
```
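
For reference, `log1mexp` follows Mächler (2012): above CUTOFF = -log 2 it computes log(-expm1(x)), below it log1p(-exp(x)), picking whichever form keeps precision; the global EPS only guards the argument of log/log1p. A quick standalone check of why both branches matter (values chosen so float32 shows the failure of the naive form):

```python
import math
import torch

CUTOFF = -math.log(2)
EPS = 1e-16

def log1mexp(x: torch.Tensor) -> torch.Tensor:
    # Mirrors the patched utils.log1mexp, with the global EPS inlined.
    return torch.where(
        CUTOFF < x,
        (-x.expm1() + EPS).log(),   # x near 0: expm1 avoids 1 - exp(x) cancelling
        (-x.exp() + EPS).log1p(),   # x << 0: log1p keeps the tiny result
    )

x = torch.tensor([-1e-8, -20.0])    # one value on each side of CUTOFF
print(log1mexp(x))                  # -> approximately [-18.42, -2.06e-09]
print(torch.log(1 - torch.exp(x)))  # -> [-inf, 0.]: the naive form fails at both ends
```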
