Skip to content

Commit 397fd8d

Browse files
committed
small bugfix
1 parent 578ebca commit 397fd8d

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

src/circuit.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -504,7 +504,8 @@ m.doc() = "Layerize arithmetic circuits";
504504
nb::class_<NodePtr>(m, "NodePtr")
505505
.def("__repr__", &NodePtr::to_string)
506506
.def(nb::self == nb::self)
507-
.def("__hash__", &NodePtr::as_int);
507+
.def("__hash__", &NodePtr::as_int)
508+
.def("get_ix", [](NodePtr a) {return a.get()->ix;});
508509

509510
nb::class_<Circuit>(m, "Circuit", "Circuits are the main class added by KLay, and require no arguments to construct.\n\n:code:`circuit = klay.Circuit()` ")
510511
.def(nb::init<>())

src/klay/backends/torch_backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def forward(self, x):
129129
def get_edge_weights(self):
130130
exp_weights, _ = self._safe_exp(self.weights)
131131
norm = self._scatter_reduce(exp_weights, "sum")
132-
return exp_weights / norm
132+
return exp_weights / norm[self.csr]
133133

134134

135135
class ProbabilisticLogSumLayer(ProbabilisticKnowledgeLayer):
@@ -139,7 +139,7 @@ def forward(self, x, epsilon=10e-16):
139139

140140
def get_edge_weights(self, epsilon):
141141
norm = self._logsumexp_scatter_reduce(self.weights, epsilon)
142-
return self.weights - norm
142+
return self.weights - norm[self.csr]
143143

144144

145145
def get_semiring(name: str, probabilistic: bool):

0 commit comments

Comments
 (0)