
Commit c3e7938

rename backend indices

1 parent: 5dd46b3
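At a glance, the commit replaces the CSR vocabulary with explicit edge-index names in both backends:

- pointers / ptrs become ixs_in / ix_in: for each edge, the index of the input value it gathers.
- csrs / csr become ixs_out / ix_out: the CSR-style offsets, or, once unrolled, the per-edge output index each edge scatters to.
- unroll_csr becomes unroll_ix_out in the JAX backend and unroll_ixs in the torch backend.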

2 files changed: 54 additions, 54 deletions


src/klay/backends/jax_backend.py (18 additions, 18 deletions)

@@ -43,49 +43,49 @@ def encode_input_real(pos, neg):
 
 
 
-def create_knowledge_layer(pointers, csrs, semiring):
-    pointers = [np.array(ptrs) for ptrs in pointers]
-    num_segments = [len(csr) - 1 for csr in csrs]  # needed for the jit
-    csrs = [unroll_csr(np.array(csr, dtype=np.int32)) for csr in csrs]
+def create_knowledge_layer(pointers, ix_outs, semiring):
+    ixs_in = [np.array(ix_in) for ix_in in pointers]
+    num_segments = [len(ix_out) - 1 for ix_out in ix_outs]  # needed for the jit
+    ixs_out = [unroll_ix_out(np.array(ix_out, dtype=np.int32)) for ix_out in ix_outs]
     sum_layer, prod_layer = get_semiring(semiring)
     encode_input = {'log': encode_input_log, 'real': encode_input_real}[semiring]
 
     @jax.jit
     def wrapper(pos, neg=None):
         x = encode_input(pos, neg)
-        for i, (ptrs, csr) in enumerate(zip(pointers, csrs)):
+        for i, (ix_in, ix_out) in enumerate(zip(ixs_in, ixs_out)):
             if i % 2 == 0:
-                x = prod_layer(num_segments[i], ptrs, csr, x)
+                x = prod_layer(num_segments[i], ix_in, ix_out, x)
             else:
-                x = sum_layer(num_segments[i], ptrs, csr, x)
+                x = sum_layer(num_segments[i], ix_in, ix_out, x)
         return x
 
     return wrapper
 
 
-def unroll_csr(csr):
-    deltas = np.diff(csr)
+def unroll_ix_out(ix_out):
+    deltas = np.diff(ix_out)
     ixs = np.arange(len(deltas), dtype=jnp.int32)
     return np.repeat(ixs, repeats=deltas)
 
 
-def log_sum_layer(num_segments, ptrs, csr, x):
-    x = x[ptrs]
-    x_max = segment_max(stop_gradient(x), csr, indices_are_sorted=True, num_segments=num_segments)
-    x = x - x_max[csr]
+def log_sum_layer(num_segments, ix_in, ix_out, x):
+    x = x[ix_in]
+    x_max = segment_max(stop_gradient(x), ix_out, indices_are_sorted=True, num_segments=num_segments)
+    x = x - x_max[ix_out]
     x = jnp.nan_to_num(x, copy=False, nan=0.0, posinf=float('inf'), neginf=float('-inf'))
     x = jnp.exp(x)
-    x = segment_sum(x, csr, indices_are_sorted=True, num_segments=num_segments)
+    x = segment_sum(x, ix_out, indices_are_sorted=True, num_segments=num_segments)
     x = jnp.log(x + EPSILON) + x_max
     return x
 
 
-def sum_layer(num_segments, ptrs, csr, x):
-    return segment_sum(x[ptrs], csr, num_segments=num_segments, indices_are_sorted=True)
+def sum_layer(num_segments, ix_in, ix_out, x):
+    return segment_sum(x[ix_in], ix_out, num_segments=num_segments, indices_are_sorted=True)
 
 
-def prod_layer(num_segments, ptrs, csr, x):
-    return segment_prod(x[ptrs], csr, num_segments=num_segments, indices_are_sorted=True)
+def prod_layer(num_segments, ix_in, ix_out, x):
+    return segment_prod(x[ix_in], ix_out, num_segments=num_segments, indices_are_sorted=True)
 
 
 def get_semiring(name: str):
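To make the new names concrete, here is a small self-contained sketch of what unroll_ix_out produces and how the segment ops consume it. The index arrays and values are invented for illustration, and the jax.ops import path for segment_sum is an assumption, since the file's imports fall outside this hunk:

import numpy as np
import jax.numpy as jnp
from jax.ops import segment_sum

def unroll_ix_out(ix_out):
    deltas = np.diff(ix_out)                       # edges per output node
    ixs = np.arange(len(deltas), dtype=jnp.int32)  # one id per output node
    return np.repeat(ixs, repeats=deltas)          # repeat each id once per edge

ix_out = np.array([0, 2, 5, 6], dtype=np.int32)    # offsets: 2, 3 and 1 edges
print(unroll_ix_out(ix_out))                       # [0 0 1 1 1 2]

ix_in = np.array([0, 1, 1, 2, 3, 0])               # input gathered by each edge
x = jnp.arange(4.0)                                # toy input values [0 1 2 3]
y = segment_sum(x[ix_in], unroll_ix_out(ix_out),
                num_segments=len(ix_out) - 1, indices_are_sorted=True)
print(y)                                           # [1. 6. 0.]

So ix_in gathers one input per edge, and the unrolled ix_out tells each edge which output node to reduce into; num_segments = len(ix_out) - 1 is precomputed because jit needs a static segment count.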

src/klay/backends/torch_backend.py (36 additions, 36 deletions)

@@ -28,26 +28,26 @@ def encode_input(pos, neg, zero, one):
     return torch.cat([constants, result])
 
 
-def unroll_csr(csr):
-    deltas = torch.diff(csr)
-    ixs = torch.arange(len(deltas), dtype=torch.long, device=csr.device)
+def unroll_ixs(ixs):
+    deltas = torch.diff(ixs)
+    ixs = torch.arange(len(deltas), dtype=torch.long, device=ixs.device)
     return ixs.repeat_interleave(repeats=deltas)
 
 
 class KnowledgeModule(nn.Module):
-    def __init__(self, pointers, csrs, semiring='real', probabilistic=False):
+    def __init__(self, ixs_in, ixs_out, semiring='real', probabilistic=False):
         super(KnowledgeModule, self).__init__()
         layers = []
         self.probabilistic = probabilistic
         sum_layer, prod_layer, self.zero, self.one, self.negate = get_semiring(semiring, probabilistic)
-        for i, (ptrs, csr) in enumerate(zip(pointers, csrs)):
-            ptrs = torch.as_tensor(ptrs)
-            csr = torch.as_tensor(csr, dtype=torch.long)
-            csr = unroll_csr(csr)
+        for i, (ix_in, ix_out) in enumerate(zip(ixs_in, ixs_out)):
+            ix_in = torch.as_tensor(ix_in)
+            ix_out = torch.as_tensor(ix_out, dtype=torch.long)
+            ix_out = unroll_ixs(ix_out)
             if i % 2 == 0:
-                layers.append(prod_layer(ptrs, csr))
+                layers.append(prod_layer(ix_in, ix_out))
             else:
-                layers.append(sum_layer(ptrs, csr))
+                layers.append(sum_layer(ix_in, ix_out))
         self.layers = nn.Sequential(*layers)
 
     def forward(self, weights, neg_weights=None, eps=0):
@@ -57,7 +57,7 @@ def forward(self, weights, neg_weights=None, eps=0):
         return self.layers(x)
 
     def sparsity(self, nb_vars: int) -> float:
-        sparse_params = sum(len(l.csr) for l in self.layers)
+        sparse_params = sum(len(l.ix_out) for l in self.layers)
         layer_widths = [nb_vars] + [l.out_shape[0] for l in self.layers]
         dense_params = sum(layer_widths[i] * layer_widths[i+1] for i in range(len(layer_widths) - 1))
         return sparse_params / dense_params
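For intuition on the sparsity method above, a toy calculation with invented numbers: after unrolling, each layer stores one output index per edge, so len(l.ix_out) counts edges. With nb_vars = 4 and two layers of widths 3 and 2 holding 6 and 4 edges, sparse_params = 10 while the dense equivalent would need 4*3 + 3*2 = 18 weights, giving a sparsity of 10/18, roughly 0.56.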
@@ -71,104 +71,104 @@ def sample_pc(self):
 
 
 class KnowledgeLayer(nn.Module):
-    def __init__(self, ptrs, csr):
+    def __init__(self, ix_in, ix_out):
         super().__init__()
-        self.register_buffer('ptrs', ptrs)
-        self.register_buffer('csr', csr)
-        self.out_shape = (self.csr[-1].item() + 1,)
-        self.in_shape = (self.ptrs.max() + 1,)
+        self.register_buffer('ix_in', ix_in)
+        self.register_buffer('ix_out', ix_out)
+        self.out_shape = (self.ix_out[-1].item() + 1,)
+        self.in_shape = (self.ix_in.max().item() + 1,)
 
     def _scatter_forward(self, x: torch.Tensor, reduce: str):
         output = torch.empty(self.out_shape, dtype=x.dtype, device=x.device)
-        output = torch.scatter_reduce(output, 0, index=self.csr, src=x, reduce=reduce, include_self=False)
+        output = torch.scatter_reduce(output, 0, index=self.ix_out, src=x, reduce=reduce, include_self=False)
         return output
 
     def _scatter_backward(self, x: torch.Tensor, reduce: str):
         output = torch.empty(self.in_shape, dtype=x.dtype, device=x.device)
-        output = torch.scatter_reduce(output, 0, index=self.ptrs, src=x, reduce=reduce, include_self=False)
+        output = torch.scatter_reduce(output, 0, index=self.ix_in, src=x, reduce=reduce, include_self=False)
         return output
 
 
     def _safe_exp(self, x: torch.Tensor):
         with torch.no_grad():
             max_output = self._scatter_forward(x, "amax")
-        x = x - max_output[self.csr]
+        x = x - max_output[self.ix_out]
         x.nan_to_num_(nan=0., posinf=float('inf'), neginf=float('-inf'))
         return torch.exp(x), max_output
 
     def _logsumexp_scatter(self, x: torch.Tensor, eps: float):
         x, max_output = self._safe_exp(x)
         output = torch.full(self.out_shape, eps, dtype=x.dtype, device=x.device)
-        output = torch.scatter_add(output, 0, index=self.csr, src=x)
+        output = torch.scatter_add(output, 0, index=self.ix_out, src=x)
         output = torch.log(output) + max_output
         return output
 
 
 
 class ProbabilisticKnowledgeLayer(KnowledgeLayer):
-    def __init__(self, ptrs, csr):
-        super().__init__(ptrs, csr)
-        self.weights = nn.Parameter(torch.randn_like(ptrs, dtype=torch.float32))
+    def __init__(self, ix_in, ix_out):
+        super().__init__(ix_in, ix_out)
+        self.weights = nn.Parameter(torch.randn_like(ix_in, dtype=torch.float32))
 
     def get_edge_weights(self):
         exp_weights, _ = self._safe_exp(self.weights)
         norm = self._scatter_forward(exp_weights, "sum")
-        return exp_weights / norm[self.csr]
+        return exp_weights / norm[self.ix_out]
 
     def get_log_edge_weights(self, eps):
         norm = self._logsumexp_scatter(self.weights, eps)
-        return self.weights - norm[self.csr]
+        return self.weights - norm[self.ix_out]
 
     def sample_pc(self, y, eps=10e-16):
         weights = self.get_log_edge_weights(eps)
         noise = -(-torch.log(torch.rand_like(weights) + eps) + eps).log()
         gumbels = weights + noise
         samples = self._scatter_forward(gumbels, "amax")
-        samples = samples[self.csr] == gumbels
-        samples &= y[self.csr].to(torch.bool)
+        samples = samples[self.ix_out] == gumbels
+        samples &= y[self.ix_out].to(torch.bool)
         return self._scatter_backward(samples, "sum") > 0
 
 
 class SumLayer(KnowledgeLayer):
     def forward(self, x):
-        return self._scatter_forward(x[self.ptrs], "sum")
+        return self._scatter_forward(x[self.ix_in], "sum")
 
     def sample_pc(self, y):
-        return self._scatter_backward(y[self.csr], "sum") > 0
+        return self._scatter_backward(y[self.ix_out], "sum") > 0
 
 
 class ProdLayer(KnowledgeLayer):
     def forward(self, x):
-        return self._scatter_forward(x[self.ptrs], "prod")
+        return self._scatter_forward(x[self.ix_in], "prod")
 
     def sample_pc(self, y):
-        return self._scatter_backward(y[self.csr], "sum") > 0
+        return self._scatter_backward(y[self.ix_out], "sum") > 0
 
 
 class MinLayer(KnowledgeLayer):
     def forward(self, x):
-        return self._scatter_forward(x[self.ptrs], "amin")
+        return self._scatter_forward(x[self.ix_in], "amin")
 
 
 class MaxLayer(KnowledgeLayer):
     def forward(self, x):
-        return self._scatter_forward(x[self.ptrs], "amax")
+        return self._scatter_forward(x[self.ix_in], "amax")
 
 
 class LogSumLayer(KnowledgeLayer):
     def forward(self, x, eps=10e-16):
-        return self._logsumexp_scatter(x[self.ptrs], eps)
+        return self._logsumexp_scatter(x[self.ix_in], eps)
 
 
 class ProbabilisticSumLayer(ProbabilisticKnowledgeLayer):
     def forward(self, x):
-        x = self.get_edge_weights() * x[self.ptrs]
+        x = self.get_edge_weights() * x[self.ix_in]
         return self._scatter_forward(x, "sum")
 
 
 class ProbabilisticLogSumLayer(ProbabilisticKnowledgeLayer):
     def forward(self, x, eps=10e-16):
-        x = self.get_log_edge_weights(eps) + x[self.ptrs]
+        x = self.get_log_edge_weights(eps) + x[self.ix_in]
         return self._logsumexp_scatter(x, eps)
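A matching torch-side sketch with the same toy indices (again invented, not taken from the repository) shows that the renamed buffers drive torch.scatter_reduce exactly the way the unrolled ix_out drives segment_sum in the JAX backend:

import torch

def unroll_ixs(ixs):
    deltas = torch.diff(ixs)                       # edges per output node
    ixs = torch.arange(len(deltas), dtype=torch.long, device=ixs.device)
    return ixs.repeat_interleave(repeats=deltas)   # one output index per edge

ix_in = torch.tensor([0, 1, 1, 2, 3, 0])           # input gathered by each edge
ix_out = unroll_ixs(torch.tensor([0, 2, 5, 6]))    # tensor([0, 0, 1, 1, 1, 2])

x = torch.arange(4.0)                              # toy input values [0 1 2 3]
out = torch.empty(3)                               # ix_out[-1] + 1 output slots
# include_self=False ignores the uninitialized values in `out`
out = torch.scatter_reduce(out, 0, index=ix_out, src=x[ix_in],
                           reduce="sum", include_self=False)
print(out)                                         # tensor([1., 6., 0.])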
