Skip to content

Commit 253ece8

Browse files
committed
probabilistic circuits in torch
1 parent ec25b3e commit 253ece8

File tree

2 files changed

+78
-40
lines changed

2 files changed

+78
-40
lines changed

src/klay/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,20 @@
44
from collections.abc import Sequence
55

66

7-
def to_torch_module(self: Circuit, semiring: str = "log", probabilistic: bool = False):
    """
    Convert the circuit into a PyTorch module.

    :param semiring:
        The semiring in which the circuit should be evaluated. Supported options are ("log", "real", "mpe", "godel").
    :param probabilistic:
        If true, construct a probabilistic circuit instead of an arithmetic circuit.
        This means the inputs to a sum node are multiplied by a probability, and
        we can interpret sum nodes as latent Categorical variables.
    """
    # Imported lazily so the torch backend is only required when actually used.
    from .backends import torch_backend

    return torch_backend.KnowledgeModule(
        *self._get_indices(), semiring=semiring, probabilistic=probabilistic
    )
1721

1822

1923
def to_jax_function(self: Circuit, semiring: str = "log"):

src/klay/backends/torch_backend.py

Lines changed: 72 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import math
22

33
import torch
4+
from torch import nn
45

56
CUTOFF = -math.log(2)
67

@@ -33,11 +34,11 @@ def unroll_csr(csr):
3334
return ixs.repeat_interleave(repeats=deltas)
3435

3536

36-
class KnowledgeModule(torch.nn.Module):
37-
def __init__(self, pointers, csrs, semiring='real'):
37+
class KnowledgeModule(nn.Module):
38+
def __init__(self, pointers, csrs, semiring='real', probabilistic=False):
3839
super(KnowledgeModule, self).__init__()
3940
layers = []
40-
sum_layer, prod_layer, self.zero, self.one, self.negate = get_semiring(semiring)
41+
sum_layer, prod_layer, self.zero, self.one, self.negate = get_semiring(semiring, probabilistic)
4142
for i, (ptrs, csr) in enumerate(zip(pointers, csrs)):
4243
ptrs = torch.as_tensor(ptrs)
4344
csr = torch.as_tensor(csr, dtype=torch.long)
@@ -46,85 +47,118 @@ def __init__(self, pointers, csrs, semiring='real'):
4647
layers.append(prod_layer(ptrs, csr))
4748
else:
4849
layers.append(sum_layer(ptrs, csr))
49-
self.layers = torch.nn.Sequential(*layers)
50+
self.layers = nn.Sequential(*layers)
5051

5152
def forward(self, weights, neg_weights=None, eps=0):
    """Evaluate the circuit bottom-up on the given leaf weights.

    :param weights: weights of the positive literals, in the module's semiring.
    :param neg_weights: weights of the negated literals; when omitted they are
        derived from ``weights`` with the semiring's negation function.
    :param eps: stability constant forwarded to the negation function.
    """
    if neg_weights is None:
        neg_weights = self.negate(weights, eps)
    leaves = encode_input(weights, neg_weights, self.zero, self.one)
    return self.layers(leaves)
5657

57-
def sparsity(self, nb_vars: int) -> float:
    """Return the ratio of actual parameters to those of a dense network.

    Compares the number of entries in the sparse layers against a
    fully-connected network with the same layer widths.

    :param nb_vars: width of the input layer (number of variables).
    """
    widths = [nb_vars] + [layer.out_shape[0] for layer in self.layers]
    sparse_count = sum(len(layer.csr) for layer in self.layers)
    dense_count = sum(a * b for a, b in zip(widths, widths[1:]))
    return sparse_count / dense_count
6263

6364

64-
class KnowledgeLayer(nn.Module):
    """Base class for one layer of the circuit.

    ``ptrs`` gathers child values from the previous layer's output and
    ``csr`` assigns every gathered value to its output node; both are kept
    as buffers so they move with the module across devices.
    """

    def __init__(self, ptrs, csr):
        super().__init__()
        self.register_buffer('ptrs', ptrs)
        self.register_buffer('csr', csr)
        # csr holds one (sorted) output index per entry, so the last entry
        # fixes the layer's output width.
        self.out_shape = (self.csr[-1].item() + 1,)

    def _scatter_reduce(self, src: torch.Tensor, reduce: str):
        """Group-reduce ``src`` into one slot per output node, grouped by ``csr``."""
        # include_self=False: the uninitialised `empty` values never take part.
        result = torch.empty(self.out_shape, dtype=src.dtype, device=src.device)
        result = torch.scatter_reduce(result, 0, index=self.csr, src=src, reduce=reduce, include_self=False)
        return result

    def _logsumexp_scatter_reduce(self, x: torch.Tensor, epsilon: float):
        """Numerically stable grouped logsumexp of ``x`` over the ``csr`` groups."""
        # Max-shift each group; detached so the shift itself carries no gradient.
        with torch.no_grad():
            group_max = torch.empty(self.out_shape, dtype=x.dtype, device=x.device)
            group_max = torch.scatter_reduce(group_max, 0, index=self.csr, src=x, reduce="amax", include_self=False)
        shifted = x - group_max[self.csr]
        # An all-(-inf) group yields -inf - -inf = NaN here; zeroing it keeps
        # exp() finite, and adding group_max back below restores the -inf.
        shifted.nan_to_num_(nan=0., posinf=float('inf'), neginf=float('-inf'))
        expd = torch.exp(shifted)
        # epsilon guards the log against an exactly-zero accumulator.
        total = torch.full(self.out_shape, epsilon, dtype=x.dtype, device=x.device)
        total = torch.scatter_add(total, 0, index=self.csr, src=expd)
        return torch.log(total) + group_max
89+
90+
91+
92+
class ProbabilisticKnowledgeLayer(KnowledgeLayer):
    """A knowledge layer with one trainable weight per edge.

    Subclasses normalise these weights over each sum node's group of edges,
    so a sum node can be read as a latent Categorical variable.
    """

    def __init__(self, ptrs, csr):
        super().__init__(ptrs, csr)
        # One weight per edge. `torch.randn_like(ptrs)` would fail because
        # ptrs is an integer index tensor and normal sampling requires a
        # floating dtype — sample a float tensor of the same shape instead.
        self.weights = nn.Parameter(torch.randn(ptrs.shape))
96+
7197

7298
class SumLayer(KnowledgeLayer):
    """Sum nodes in the real semiring: adds the gathered child values."""

    def forward(self, x):
        children = x[self.ptrs]
        return self._scatter_reduce(children, "sum")
77101

78102

79103
class ProdLayer(KnowledgeLayer):
    """Product nodes in the real semiring: multiplies the gathered child values."""

    def forward(self, x):
        children = x[self.ptrs]
        return self._scatter_reduce(children, "prod")
84106

85107

86108
class MinLayer(KnowledgeLayer):
    """Minimum over each node's gathered children (product of the Godel semiring)."""

    def forward(self, x):
        children = x[self.ptrs]
        return self._scatter_reduce(children, "amin")
91111

92112

93113
class MaxLayer(KnowledgeLayer):
    """Maximum over each node's gathered children (sum of the mpe/Godel semirings)."""

    def forward(self, x):
        children = x[self.ptrs]
        return self._scatter_reduce(children, "amax")
98116

99117

100118
class LogSumLayer(KnowledgeLayer):
    """Sum nodes in the log semiring: grouped logsumexp of the children."""

    def forward(self, x, epsilon=10e-16):
        children = x[self.ptrs]
        return self._logsumexp_scatter_reduce(children, epsilon)
122+
123+
class ProbabilisticSumLayer(ProbabilisticKnowledgeLayer):
    """Sum nodes in the real semiring whose inputs are scaled by learned,
    per-node-normalised probabilities."""

    def forward(self, x):
        weighted = self.get_edge_weights() * x[self.ptrs]
        return self._scatter_reduce(weighted, "sum")

    def get_edge_weights(self):
        """Softmax of the raw edge weights within each sum node's group."""
        exp_weights = torch.exp(self.weights)
        # The normaliser is per output node (shape out_shape), while
        # exp_weights is per edge — gather it back to the edges via csr;
        # dividing by `norm` directly mismatches shapes (edges vs. nodes).
        norm = self._scatter_reduce(exp_weights, "sum")
        return exp_weights / norm[self.csr]
132+
133+
134+
class ProbabilisticLogSumLayer(ProbabilisticKnowledgeLayer):
    """Sum nodes in the log semiring whose inputs are shifted by learned
    log-probabilities (log-space analogue of ProbabilisticSumLayer)."""

    def forward(self, x, epsilon=10e-16):
        weighted = self.get_edge_weights(epsilon) + x[self.ptrs]
        return self._logsumexp_scatter_reduce(weighted, epsilon)

    def get_edge_weights(self, epsilon):
        """Log-softmax of the raw edge weights within each sum node's group."""
        # The normaliser is per output node (shape out_shape), while
        # self.weights is per edge — gather it back to the edges via csr;
        # subtracting `norm` directly mismatches shapes (edges vs. nodes).
        norm = self._logsumexp_scatter_reduce(self.weights, epsilon)
        return self.weights - norm[self.csr]
114142

115143

116-
def get_semiring(name: str, probabilistic: bool):
    """
    For a given semiring, returns the sum and product layer,
    the zero and one elements, and a negation function.

    :param name: one of "real", "log", "mpe" or "godel" ("mpe" and "godel"
        have no probabilistic variant).
    :param probabilistic: use sum layers with trainable edge weights.
    :raises ValueError: if the combination is not supported.
    """
    if not probabilistic:
        if name == "real":
            return SumLayer, ProdLayer, 0, 1, negate_real
        if name == "log":
            return LogSumLayer, SumLayer, float('-inf'), 0, log1mexp
        if name == "mpe":
            return MaxLayer, ProdLayer, 0, 1, negate_real
        if name == "godel":
            return MaxLayer, MinLayer, 0, 1, negate_real
        raise ValueError(f"Unknown semiring {name}")
    if name == "real":
        return ProbabilisticSumLayer, ProdLayer, 0, 1, negate_real
    if name == "log":
        return ProbabilisticLogSumLayer, SumLayer, float('-inf'), 0, log1mexp
    raise ValueError(f"Unknown probabilistic semiring {name}")

0 commit comments

Comments
 (0)