
Commit 5dd46b3

pc sampling on real+log
1 parent bb7e23f

File tree

1 file changed: +28 -25 lines

src/klay/backends/torch_backend.py

Lines changed: 28 additions & 25 deletions
@@ -1,5 +1,4 @@
 import math
-from functools import reduce
 
 import torch
 from torch import nn
@@ -97,9 +96,9 @@ def _safe_exp(self, x: torch.Tensor):
         x.nan_to_num_(nan=0., posinf=float('inf'), neginf=float('-inf'))
         return torch.exp(x), max_output
 
-    def _logsumexp_scatter(self, x: torch.Tensor, epsilon: float):
+    def _logsumexp_scatter(self, x: torch.Tensor, eps: float):
         x, max_output = self._safe_exp(x)
-        output = torch.full(self.out_shape, epsilon, dtype=x.dtype, device=x.device)
+        output = torch.full(self.out_shape, eps, dtype=x.dtype, device=x.device)
         output = torch.scatter_add(output, 0, index=self.csr, src=x)
         output = torch.log(output) + max_output
         return output
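
A note on `_logsumexp_scatter` above: it is the standard max-subtraction logsumexp, computed per output node via scatter-add. The sketch below reproduces the idea outside the class; the function name, `csr`, `out_size`, and the example values are illustrative only, and a single global max stands in for whatever per-node bookkeeping klay's `_safe_exp` does internally.

import torch

def logsumexp_scatter(x, csr, out_size, eps=1e-15):
    # Subtract the max before exponentiating so exp() cannot overflow,
    # then add it back after the log; eps keeps log() finite for nodes
    # that receive no mass.
    max_x = x.max()
    out = torch.full((out_size,), eps, dtype=x.dtype)
    out = out.scatter_add(0, csr, torch.exp(x - max_x))
    return torch.log(out) + max_x

x = torch.log(torch.tensor([0.2, 0.3, 0.5, 1.0]))  # log-space edge values
csr = torch.tensor([0, 0, 0, 1])  # edges 0-2 feed node 0, edge 3 feeds node 1
print(logsumexp_scatter(x, csr, 2).exp())  # ~tensor([1., 1.])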
@@ -111,6 +110,24 @@ def __init__(self, ptrs, csr):
         super().__init__(ptrs, csr)
         self.weights = nn.Parameter(torch.randn_like(ptrs, dtype=torch.float32))
 
+    def get_edge_weights(self):
+        exp_weights, _ = self._safe_exp(self.weights)
+        norm = self._scatter_forward(exp_weights, "sum")
+        return exp_weights / norm[self.csr]
+
+    def get_log_edge_weights(self, eps):
+        norm = self._logsumexp_scatter(self.weights, eps)
+        return self.weights - norm[self.csr]
+
+    def sample_pc(self, y, eps=10e-16):
+        weights = self.get_log_edge_weights(eps)
+        noise = -(-torch.log(torch.rand_like(weights) + eps) + eps).log()
+        gumbels = weights + noise
+        samples = self._scatter_forward(gumbels, "amax")
+        samples = samples[self.csr] == gumbels
+        samples &= y[self.csr].to(torch.bool)
+        return self._scatter_backward(samples, "sum") > 0
+
 
 class SumLayer(KnowledgeLayer):
     def forward(self, x):
@@ -124,6 +141,9 @@ class ProdLayer(KnowledgeLayer):
     def forward(self, x):
         return self._scatter_forward(x[self.ptrs], "prod")
 
+    def sample_pc(self, y):
+        return self._scatter_backward(y[self.csr], "sum") > 0
+
 
 class MinLayer(KnowledgeLayer):
     def forward(self, x):
@@ -136,37 +156,20 @@ def forward(self, x):
 
 
 class LogSumLayer(KnowledgeLayer):
-    def forward(self, x, epsilon=10e-16):
-        return self._logsumexp_scatter(x[self.ptrs], epsilon)
+    def forward(self, x, eps=10e-16):
+        return self._logsumexp_scatter(x[self.ptrs], eps)
 
 
 class ProbabilisticSumLayer(ProbabilisticKnowledgeLayer):
     def forward(self, x):
         x = self.get_edge_weights() * x[self.ptrs]
         return self._scatter_forward(x, "sum")
 
-    def get_edge_weights(self):
-        exp_weights, _ = self._safe_exp(self.weights)
-        norm = self._scatter_forward(exp_weights, "sum")
-        return exp_weights / norm[self.csr]
-
 
 class ProbabilisticLogSumLayer(ProbabilisticKnowledgeLayer):
-    def forward(self, x, epsilon=10e-16):
-        x = self.get_edge_weights(epsilon) + x[self.ptrs]
-        return self._logsumexp_scatter(x, epsilon)
-
-    def get_edge_weights(self, epsilon):
-        norm = self._logsumexp_scatter(self.weights, epsilon)
-        return self.weights - norm[self.csr]
-
-    def sample_pc(self, y, epsilon=10e-16):
-        weights = self.get_edge_weights(epsilon)
-        gumbels = weights - (-torch.rand_like(weights).log()).log()
-        samples = self._scatter_forward(gumbels, "amax")
-        samples = samples[self.csr] == gumbels
-        samples &= y[self.csr].to(torch.bool)
-        return self._scatter_backward(samples, "sum") > 0
+    def forward(self, x, eps=10e-16):
+        x = self.get_log_edge_weights(eps) + x[self.ptrs]
+        return self._logsumexp_scatter(x, eps)
 
 
 def get_semiring(name: str, probabilistic: bool):
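
For readers tracing the new `sample_pc`: it is the Gumbel-max trick in log space. Adding independent Gumbel(0, 1) noise, -log(-log(U)) with U ~ Uniform(0, 1), to each edge's log-weight and taking the per-node "amax" selects child i with probability proportional to its weight; the "+ eps" terms inside the logs, new in this commit, guard against log(0) when U lands at or near the ends of [0, 1). A minimal self-contained check of the trick, with made-up weights and shapes (nothing here depends on the klay classes):

import torch

# Gumbel-max: argmax(log_w + g), with g = -log(-log(U)),
# picks index i with probability softmax(log_w)[i].
log_w = torch.log(torch.tensor([0.1, 0.3, 0.6]))  # hypothetical, already normalized
eps = 1e-15  # same magnitude as the diff's 10e-16
n_draws = 100_000

u = torch.rand(n_draws, 3)
gumbel = -torch.log(-torch.log(u + eps) + eps)  # eps-stabilized, as in sample_pc
choice = torch.argmax(log_w + gumbel, dim=-1)

# Empirical frequencies converge to the weights [0.1, 0.3, 0.6].
print(torch.bincount(choice, minlength=3) / n_draws)

In the diff itself, the per-node argmax is realized as _scatter_forward(gumbels, "amax") followed by an equality test against the original gumbels and a mask by the active nodes y, so exactly one incoming edge survives per active sum node; ProdLayer.sample_pc simply propagates the mask to all children.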
