Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[291] Nanobind for QueryAnswer and DASNode #299

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
44 changes: 44 additions & 0 deletions src/hyperon_das_query_engine/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
load("@nanobind_bazel//:build_defs.bzl", "nanobind_extension")
load("@rules_python//python:packaging.bzl", "py_wheel")

py_library(
name = "hyperon_das_query_engine",
srcs = ["__init__.py"],
data = [":hyperon_das_query_engine_ext"],
visibility = ["//visibility:public"],
)

nanobind_extension(
name = "hyperon_das_query_engine_ext",
srcs = ["hyperon_das_query_engine_ext.cc"],
deps = [
"//query_engine:query_engine_lib"
],
)

py_wheel(
name = "hyperon_das_query_engine_wheel",
abi = "abi3",
author = "Andre Senna",
author_email = "[email protected]",
classifiers = [
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
],
description_content_type = "text/markdown",
description_file = "README_hyperon_das_query_engine.md",
distribution = "hyperon_das_query_engine",
platform = "manylinux_2_28_x86_64",
python_requires = ">=3.10",
python_tag = "cp310",
stamp = 1,
summary = "DAS Node python package",
version = "$(das_query_engine_VERSION)", # must be defined when calling `bazel build` with `--define=DAS_NODE_VERSION=<version>`
deps = [
":hyperon_das_query_engine",
":hyperon_das_query_engine_ext",
],
)
1 change: 1 addition & 0 deletions src/hyperon_das_query_engine/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .hyperon_das_query_engine_ext import *
85 changes: 85 additions & 0 deletions src/hyperon_das_query_engine/hyperon_das_query_engine_ext.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
#include <nanobind/nanobind.h>
#include <nanobind/stl/string.h>
//TODO: remove unused imports
// #include <nanobind/stl/vector.h>
// #include <nanobind/stl/shared_ptr.h>
#include <nanobind/trampoline.h>
#include <type_traits>

#include "query_engine/DASNode.h"
#include "query_engine/HandlesAnswer.h"
#include "query_engine/QueryAnswer.h"

#include "query_engine/query_element/RemoteIterator.h"

namespace nb = nanobind;
using namespace nb::literals;

using namespace std;
using namespace query_engine;

NB_MODULE(hyperon_das_query_engine_ext, m) {
//QueryAnswer.h bindings
nb::class_<QueryAnswer>(m, "QueryAnswer")
.def("tokenize", &QueryAnswer::tokenize)
.def("untokenize", &QueryAnswer::untokenize, "tokens"_a)
.def("to_string", &QueryAnswer::to_string);

//HandlesAnswer.h
nb::class_<HandlesAnswer, QueryAnswer>(m, "HandlesAnswer")
.def(nb::init<>())
.def(nb::init<double>(), "importance"_a)
.def(nb::init<const char*, double>(), "handle"_a, "importance"_a)
//TODO: nanobind is failing to bind the handles field
// error: invalid conversion from 'const char* const*' to 'char'
// .def_ro("handles", &HandlesAnswer::handles)
.def_ro("handles_size", &HandlesAnswer::handles_size)
.def_ro("importance", &HandlesAnswer::importance)
.def("add_handle", &HandlesAnswer::add_handle, "handle"_a)
.def("merge", &HandlesAnswer::merge, "other"_a, "merge_handles"_a)
.def_static("copy", &HandlesAnswer::copy, "base"_a)
;

// RemoteIterator.h
nb::class_<RemoteIterator<HandlesAnswer>>(m, "RemoteIterator")
.def(nb::init<string>(), "local_id"_a)
.def_ro("is_terminal", &RemoteIterator<HandlesAnswer>::is_terminal)
.def("graceful_shutdown", &RemoteIterator<HandlesAnswer>::graceful_shutdown)
.def("setup_buffers", &RemoteIterator<HandlesAnswer>::setup_buffers)
.def("finished", &RemoteIterator<HandlesAnswer>::finished)
.def("pop", &RemoteIterator<HandlesAnswer>::pop)
;

//DASNode.h
nb::class_<DASNode>(m, "DASNode")
.def(nb::init<string>(), "node_id"_a)
.def(nb::init<string, string>(), "node_id"_a, "server_id"_a)
.def_ro_static("PATTERN_MATCHING_QUERY", &DASNode::PATTERN_MATCHING_QUERY)
.def_ro_static("COUNTING_QUERY", &DASNode::COUNTING_QUERY)
.def("pattern_matcher_query", &DASNode::pattern_matcher_query,
"tokens"_a,
"context"_a = "",
"update_attention_broker"_a = false)
.def("count_query", &DASNode::count_query,
"tokens"_a,
"context"_a = "",
"update_attention_broker"_a = false)
.def("next_query_id", &DASNode::next_query_id)
.def("message_factory", &DASNode::message_factory,
"command"_a,
"args"_a)

// inherited from DistributedAlgorithmNode
.def("join_network", &DASNode::join_network)
.def("is_leader", &DASNode::is_leader)
.def("leader_id", &DASNode::leader_id)
.def("has_leader", &DASNode::has_leader)
.def("add_peer", &DASNode::add_peer, "peer_id"_a)
.def("node_id", &DASNode::node_id)
.def("broadcast", &DASNode::broadcast, "command"_a, "args"_a)
.def("send", &DASNode::send, "command"_a, "args"_a, "recipient"_a)
.def("node_joined_network", &DASNode::node_joined_network,
"node_id"_a)
.def("cast_leadership_vote", &DASNode::cast_leadership_vote);

}
9 changes: 9 additions & 0 deletions src/tests/python/unit/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -181,3 +181,12 @@ py_test(
],
)

py_test(
name = "test_das_node",
size = "small",
srcs = ["test_das_node.py"],
deps = [
"//hyperon_das_query_engine:hyperon_das_query_engine",
],
)

99 changes: 99 additions & 0 deletions src/tests/python/unit/test_das_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from unittest import TestCase

from hyperon_das_query_engine import (
DASNode,
QueryAnswer,
HandlesAnswer,
RemoteIterator,
)


class TestDASNode(TestCase):
def test_das_node_binding(self) -> None:
"""
Test if all the attributes are present on the binding of DASNode
"""

#inherited attributes
self.assertTrue(hasattr(DASNode, "add_peer"))
self.assertTrue(hasattr(DASNode, "broadcast"))
self.assertTrue(hasattr(DASNode, "cast_leadership_vote"))
self.assertTrue(hasattr(DASNode, "has_leader"))
self.assertTrue(hasattr(DASNode, "is_leader"))
self.assertTrue(hasattr(DASNode, "join_network"))
self.assertTrue(hasattr(DASNode, "message_factory"))
self.assertTrue(hasattr(DASNode, "node_id"))
self.assertTrue(hasattr(DASNode, "node_joined_network"))
self.assertTrue(hasattr(DASNode, "node_joined_network"))
self.assertTrue(hasattr(DASNode, "send"))

#own attributes
self.assertTrue(hasattr(DASNode, "PATTERN_MATCHING_QUERY"))
self.assertTrue(hasattr(DASNode, "COUNTING_QUERY"))
self.assertTrue(hasattr(DASNode, "pattern_matcher_query"))
self.assertTrue(hasattr(DASNode, "count_query"))
self.assertTrue(hasattr(DASNode, "next_query_id"))

def test_query_answer_bindings(self):
self.assertTrue(hasattr(QueryAnswer, "tokenize"))
self.assertTrue(hasattr(QueryAnswer, "untokenize"))
self.assertTrue(hasattr(QueryAnswer, "to_string"))

def test_handles_answer_inherits_query_answer(self):
self.assertTrue(issubclass(HandlesAnswer, QueryAnswer))

def test_handles_answer_bindings(self):
# self.assertTrue(hasattr(HandlesAnswer, "handles"))
self.assertTrue(hasattr(HandlesAnswer, "handles_size"))
self.assertTrue(hasattr(HandlesAnswer, "importance"))
self.assertTrue(hasattr(HandlesAnswer, "add_handle"))
self.assertTrue(hasattr(HandlesAnswer, "merge"))
self.assertTrue(hasattr(HandlesAnswer, "copy"))

# inherited attr
self.assertTrue(hasattr(HandlesAnswer, "tokenize"))
self.assertTrue(hasattr(HandlesAnswer, "untokenize"))
self.assertTrue(hasattr(HandlesAnswer, "to_string"))

def test_remote_iterator_bindings(self):
self.assertTrue(hasattr(RemoteIterator, "is_terminal"))
self.assertTrue(hasattr(RemoteIterator, "graceful_shutdown"))
self.assertTrue(hasattr(RemoteIterator, "setup_buffers"))
self.assertTrue(hasattr(RemoteIterator, "finished"))
self.assertTrue(hasattr(RemoteIterator, "pop"))

def test_das_node(self) -> None:
"""
Test das_node server and client constructors
"""
self.server_id: str = "localhost:35700"
self.client1_id: str = "localhost:35701"
self.client2_id: str = "localhost:35702"

self.server = DASNode(node_id=self.server_id)
self.client1 = DASNode(node_id=self.client1_id, server_id=self.server_id)
self.client2 = DASNode(node_id=self.client2_id, server_id=self.server_id)

# Test id assignment
self.assertEqual(self.server.node_id(), self.server_id)
self.assertEqual(self.client1.node_id(), self.client1_id)
self.assertEqual(self.client2.node_id(), self.client2_id)

# Server should be leader
self.assertTrue(self.server.has_leader())
self.assertTrue(self.client1.has_leader())
self.assertTrue(self.client2.has_leader())

self.assertTrue(self.server.is_leader())
self.assertFalse(self.client1.is_leader())
self.assertFalse(self.client2.is_leader())

self.assertEqual(self.server.leader_id(), self.server_id)
self.assertEqual(self.client1.leader_id(), self.server_id)
self.assertEqual(self.client2.leader_id(), self.server_id)

# Test id assignment
self.assertEqual(self.server.node_id(), self.server_id)
self.assertEqual(self.client1.node_id(), self.client1_id)
self.assertEqual(self.client2.node_id(), self.client2_id)

13 changes: 13 additions & 0 deletions src/tests/python/unit/test_star_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,19 @@


class TestStarNode(TestCase):
def test_star_node_binding(self):
self.assertTrue(hasattr(StarNode, "add_peer"))
self.assertTrue(hasattr(StarNode, "broadcast"))
self.assertTrue(hasattr(StarNode, "cast_leadership_vote"))
self.assertTrue(hasattr(StarNode, "has_leader"))
self.assertTrue(hasattr(StarNode, "is_leader"))
self.assertTrue(hasattr(StarNode, "join_network"))
self.assertTrue(hasattr(StarNode, "message_factory"))
self.assertTrue(hasattr(StarNode, "node_id"))
self.assertTrue(hasattr(StarNode, "node_joined_network"))
self.assertTrue(hasattr(StarNode, "node_joined_network"))
self.assertTrue(hasattr(StarNode, "send"))

def test_star_node(self):
self.server_id: str = "localhost:35700"
self.client1_id: str = "localhost:35701"
Expand Down