Skip to content

Commit ec25b3e

Browse files
committed
clean up docstrings a bit
1 parent 8fcde1c commit ec25b3e

File tree

2 files changed

+31
-14
lines changed

2 files changed

+31
-14
lines changed

src/circuit.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -455,14 +455,14 @@ nb::class_<Circuit>(m, "Circuit")
455455
.def("add_SDD_from_file", &Circuit::add_SDD_from_file, "filename"_a, "true_lits"_a = std::vector<int>(), "false_lits"_a = std::vector<int>())
456456
.def("add_D4_from_file", &Circuit::add_D4_from_file, "filename"_a, "true_lits"_a = std::vector<int>(), "false_lits"_a = std::vector<int>())
457457
.def("_get_indices", &Circuit::get_indices)
458-
.def("nb_nodes", &Circuit::nb_nodes, "number of nodes in the circuit")
459-
.def("nb_root_nodes", &Circuit::nb_root_nodes, "number of root nodes in the circuit")
460-
.def("true_node", &Circuit::true_node, "adds a true node to the circuit, and returns a pointer")
461-
.def("false_node", &Circuit::false_node, "adds a false node to the circuit, and returns a pointer")
462-
.def("literal_node", &Circuit::literal_node, "adds a literal node to the circuit ,and returns a pointer")
463-
.def("or_node", &Circuit::or_node, "children"_a, "adds an or node to the circuit, and returns a pointer")
464-
.def("and_node", &Circuit::and_node, "children"_a, "adds an and node to the circuit, and returns a pointer")
465-
.def("set_root", &Circuit::set_root, "root"_a, "marks a node pointer as root")
458+
.def("nb_nodes", &Circuit::nb_nodes, "Number of nodes in the circuit.")
459+
.def("nb_root_nodes", &Circuit::nb_root_nodes, "Number of root nodes in the circuit.")
460+
.def("true_node", &Circuit::true_node, "Adds a true node to the circuit, and returns it as a pointer.")
461+
.def("false_node", &Circuit::false_node, "Adds a false node to the circuit, and returns it as a pointer.")
462+
.def("literal_node", &Circuit::literal_node, "Adds a literal node to the circuit, and returns it as a pointer.", "literal"_a)
463+
.def("or_node", &Circuit::or_node, "children"_a, "Adds an :code:`or` node to the circuit, and returns it as a pointer.")
464+
.def("and_node", &Circuit::and_node, "children"_a, "Adds an :code:`and` node to the circuit, and returns it as a pointer.")
465+
.def("set_root", &Circuit::set_root, "root"_a, "Marks a node pointer as root.")
466466
.def("remove_unused_nodes", &Circuit::remove_unused_nodes, "Removes unused non-root nodes from the circuit.\nCareful! This invalidates any NodePtr refering to an unused node (i.e., a node not conneected to a root node).");
467467

468468
m.def("to_dot_file", &to_dot_file, "circuit"_a, "filename"_a, "Write the given circuit as dot format to a file");

src/klay/__init__.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,42 @@
11
# noinspection PyUnresolvedReferences
22
from .nanobind_ext import Circuit
33

4+
from collections.abc import Sequence
45

5-
def to_torch_module(circuit: Circuit, semiring: str = "log"):
6+
7+
def to_torch_module(self: Circuit, semiring: str = "log"):
8+
"""
9+
Convert the circuit into a PyTorch module.
10+
11+
:param semiring:
12+
The semiring in which the circuit should be evaluated. Supported options are ("log", "real", "mpe", "godel").
13+
"""
614
from .backends import torch_backend
7-
indices = circuit._get_indices()
15+
indices = self._get_indices()
816
return torch_backend.KnowledgeModule(*indices, semiring=semiring)
917

1018

11-
def to_jax_function(circuit: Circuit, semiring: str = "log"):
19+
def to_jax_function(self: Circuit, semiring: str = "log"):
20+
"""
21+
Convert the circuit into a Jax function.
22+
23+
:param semiring:
24+
The semiring in which the circuit should be evaluated. Supported options are ("log", "real").
25+
"""
1226
from .backends import jax_backend
13-
indices = circuit._get_indices()
27+
indices = self._get_indices()
1428
return jax_backend.create_knowledge_layer(*indices, semiring=semiring)
1529

1630

17-
def add_sdd(circuit: Circuit, sdd: "SddNode", **kwargs):
31+
def add_sdd(self: Circuit, sdd: "SddNode", true_lits: Sequence[int] = (), false_lits: Sequence[int] = ()):
32+
"""
33+
Add an SDD to the Circuit.
34+
"""
1835
import os
1936
from pathlib import Path
2037

2138
sdd.save(bytes(Path("tmp.sdd")))
22-
circuit.add_SDD_from_file("tmp.sdd", **kwargs)
39+
self.add_SDD_from_file("tmp.sdd", true_lits, false_lits)
2340
os.remove("tmp.sdd")
2441

2542

0 commit comments

Comments
 (0)