Skip to content

Commit a90fb57

Browse files
committed
better eps handling in torch
1 parent 1bfbeb5 commit a90fb57

File tree

3 files changed

+41
-37
lines changed

3 files changed

+41
-37
lines changed

src/klay/__init__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from collections.abc import Sequence
66

77

8-
def to_torch_module(self: Circuit, semiring: str = "log", probabilistic: bool = False):
8+
def to_torch_module(self: Circuit, semiring: str = "log", probabilistic: bool = False, eps: float = 0):
99
"""
1010
Convert the circuit into a PyTorch module.
1111
@@ -15,12 +15,14 @@ def to_torch_module(self: Circuit, semiring: str = "log", probabilistic: bool =
1515
If enabled, construct a probabilistic circuit instead of an arithmetic circuit.
1616
This means the inputs to a sum node are multiplied by a probability, and
1717
we can interpret sum nodes as latent Categorical variables.
18+
:param eps:
19+
Epsilon used by log semiring for numerical stability.
1820
"""
1921
from .torch import CircuitModule, ProbabilisticCircuitModule
2022
indices = self._get_indices()
2123
if probabilistic:
22-
return ProbabilisticCircuitModule(*indices, semiring=semiring)
23-
return CircuitModule(*indices, semiring=semiring)
24+
return ProbabilisticCircuitModule(*indices, semiring=semiring, eps=eps)
25+
return CircuitModule(*indices, semiring=semiring, eps=eps)
2426

2527

2628
def to_jax_function(self: Circuit, semiring: str = "log"):

src/klay/torch/__init__.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,53 +5,54 @@
55
from .utils import unroll_ixs
66

77

8-
def _create_layers(sum_layer, prod_layer, ixs_in, ixs_out):
8+
def _create_layers(sum_layer, prod_layer, ixs_in, ixs_out, eps):
99
layers = []
1010
for i, (ix_in, ix_out) in enumerate(zip(ixs_in, ixs_out)):
1111
ix_in = torch.as_tensor(ix_in, dtype=torch.long)
1212
ix_out = torch.as_tensor(ix_out, dtype=torch.long)
1313
ix_out = unroll_ixs(ix_out)
1414
layer = prod_layer if i % 2 == 0 else sum_layer
15-
layers.append(layer(ix_in, ix_out))
15+
layers.append(layer(ix_in, ix_out, eps))
1616
return nn.Sequential(*layers)
1717

1818

1919
class CircuitModule(nn.Module):
20-
def __init__(self, ixs_in, ixs_out, semiring='real'):
20+
def __init__(self, ixs_in, ixs_out, semiring: str = 'real', eps: float = 0):
2121
super(CircuitModule, self).__init__()
2222
self.semiring = semiring
23+
self._eps = 0
24+
2325
self.sum_layer, self.prod_layer, self.zero, self.one, self.negate = \
2426
get_semiring(semiring, self.is_probabilistic())
25-
self.layers = _create_layers(self.sum_layer, self.prod_layer, ixs_in, ixs_out)
27+
self.layers = _create_layers(self.sum_layer, self.prod_layer, ixs_in, ixs_out, eps)
2628

27-
def forward(self, x_pos, x_neg=None, eps=0):
28-
x = self.encode_input(x_pos, x_neg, eps)
29+
def forward(self, x_pos, x_neg=None):
30+
x = self.encode_input(x_pos, x_neg)
2931
return self.layers(x)
3032

31-
def encode_input(self, pos, neg, eps):
33+
def encode_input(self, pos, neg):
3234
if neg is None:
33-
neg = self.negate(pos, eps)
35+
neg = self.negate(pos, self._eps)
3436
x = torch.stack([pos, neg], dim=1).flatten()
35-
units = torch.tensor([self.zero, self.one], dtype=torch.float32, device=pos.device)
37+
units = torch.tensor([self.zero, self.one], dtype=pos.dtype, device=pos.device)
3638
return torch.cat([units, x])
3739

3840
def sparsity(self, nb_vars: int) -> float:
39-
sparse_params = sum(len(l.ix_out) for l in self.layers)
40-
layer_widths = [nb_vars] + [l.out_shape[0] for l in self.layers]
41+
sparse_params = sum(len(layer.ix_out) for layer in self.layers)
42+
layer_widths = [nb_vars] + [layer.out_shape[0] for layer in self.layers]
4143
dense_params = sum(layer_widths[i] * layer_widths[i + 1] for i in range(len(layer_widths) - 1))
4244
return sparse_params / dense_params
4345

44-
def to_pc(self, x_pos, x_neg=None, eps=0):
46+
def to_pc(self, x_pos, x_neg=None):
4547
""" Converts the circuit into a probabilistic circuit."""
4648
assert self.semiring == "log" or self.semiring == "real"
4749
pc = ProbabilisticCircuitModule([], [], self.semiring)
48-
print("Making PC", pc.sum_layer, pc.sum_layer)
4950
layers = []
5051

51-
x = self.encode_input(x_pos, x_neg, eps)
52+
x = self.encode_input(x_pos, x_neg)
5253
for i, layer in enumerate(self.layers):
5354
if isinstance(layer, self.sum_layer):
54-
new_layer = pc.sum_layer(layer.ix_in, layer.ix_out)
55+
new_layer = pc.sum_layer(layer.ix_in, layer.ix_out, layer._eps)
5556
weights = x.log() if self.semiring == "real" else x
5657
new_layer.weights.data = weights[new_layer.ix_in]
5758
else:
@@ -76,7 +77,7 @@ def sample(self):
7677
return y[2::2]
7778

7879
def condition(self, x_pos, x_neg):
79-
x = self.encode_input(x_pos, x_neg, None)
80+
x = self.encode_input(x_pos, x_neg)
8081
for layer in self.layers:
8182
x = layer.condition(x) \
8283
if isinstance(layer, ProbabilisticCircuitLayer) \

src/klay/torch/layers.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,17 @@
55

66

77
class CircuitLayer(nn.Module):
8-
def __init__(self, ix_in, ix_out):
8+
def __init__(self, ix_in, ix_out, eps):
99
super().__init__()
1010
self.register_buffer('ix_in', ix_in)
1111
self.register_buffer('ix_out', ix_out)
1212
self.out_shape = (self.ix_out[-1].item() + 1,)
1313
self.in_shape = (self.ix_in.max().item() + 1,)
14+
self._eps = eps
1415

15-
def _scatter_forward(self, x: torch.Tensor, reduce: str, **kwargs):
16+
def _scatter_forward(self, x: torch.Tensor, reduce: str):
1617
if reduce == "logsumexp":
17-
return self._scatter_logsumexp_forward(x, **kwargs)
18+
return self._scatter_logsumexp_forward(x)
1819
output = torch.empty(self.out_shape, dtype=x.dtype, device=x.device)
1920
output = torch.scatter_reduce(output, 0, index=self.ix_out, src=x, reduce=reduce, include_self=False)
2021
return output
@@ -31,9 +32,9 @@ def _safe_exp(self, x: torch.Tensor):
3132
x.nan_to_num_(nan=0., posinf=float('inf'), neginf=float('-inf'))
3233
return torch.exp(x), max_output
3334

34-
def _scatter_logsumexp_forward(self, x: torch.Tensor, eps: float):
35+
def _scatter_logsumexp_forward(self, x: torch.Tensor):
3536
x, max_output = self._safe_exp(x)
36-
output = torch.full(self.out_shape, eps, dtype=x.dtype, device=x.device)
37+
output = torch.full(self.out_shape, self._eps, dtype=x.dtype, device=x.device)
3738
output = torch.scatter_add(output, 0, index=self.ix_out, src=x)
3839
output = torch.log(output) + max_output
3940
return output
@@ -63,13 +64,13 @@ def forward(self, x):
6364

6465

6566
class LogSumLayer(CircuitLayer):
66-
def forward(self, x, eps=10e-16):
67-
return self._scatter_forward(x[self.ix_in], "logsumexp", eps=eps)
67+
def forward(self, x):
68+
return self._scatter_forward(x[self.ix_in], "logsumexp")
6869

6970

7071
class ProbabilisticCircuitLayer(CircuitLayer):
71-
def __init__(self, ix_in, ix_out):
72-
super().__init__(ix_in, ix_out)
72+
def __init__(self, ix_in, ix_out, eps):
73+
super().__init__(ix_in, ix_out, eps)
7374
self.weights = nn.Parameter(torch.randn_like(ix_in, dtype=torch.float32))
7475

7576
def get_edge_weights(self):
@@ -79,15 +80,15 @@ def get_edge_weights(self):
7980

8081
def renorm_weights(self, x):
8182
with torch.no_grad():
82-
self.weights.data = self.get_log_edge_weights(0) + x
83+
self.weights.data = self.get_log_edge_weights() + x
8384

84-
def get_log_edge_weights(self, eps):
85-
norm = self._scatter_logsumexp_forward(self.weights, eps)
85+
def get_log_edge_weights(self):
86+
norm = self._scatter_logsumexp_forward(self.weights)
8687
return self.weights - norm[self.ix_out]
8788

88-
def sample(self, y, eps=10e-16):
89-
weights = self.get_log_edge_weights(eps)
90-
noise = -(-torch.log(torch.rand_like(weights) + eps) + eps).log()
89+
def sample(self, y):
90+
weights = self.get_log_edge_weights()
91+
noise = -(-torch.log(torch.rand_like(weights) + self._eps) + self._eps).log()
9192
gumbels = weights + noise
9293
samples = self._scatter_forward(gumbels, "amax")
9394
samples = samples[self.ix_out] == gumbels
@@ -107,9 +108,9 @@ def condition(self, x):
107108

108109

109110
class ProbabilisticLogSumLayer(ProbabilisticCircuitLayer):
110-
def forward(self, x, eps=10e-16):
111-
x = self.get_log_edge_weights(eps) + x[self.ix_in]
112-
return self._scatter_logsumexp_forward(x, eps)
111+
def forward(self, x):
112+
x = self.get_log_edge_weights() + x[self.ix_in]
113+
return self._scatter_logsumexp_forward(x)
113114

114115
def condition(self, x):
115116
y = self.forward(x)

0 commit comments

Comments
 (0)