Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
9f9d974
Use thread pool for request handlers.
ehpor Feb 11, 2026
469e5ed
Make functions const.
ehpor Feb 11, 2026
9a60d07
Add networked message broker classes.
ehpor Feb 11, 2026
b223e1c
Add Python bindings for remote message broker classes.
ehpor Feb 11, 2026
23906df
Add PeerConfig constructor.
ehpor Feb 11, 2026
c89e13b
Add tests for remote message broker.
ehpor Feb 11, 2026
466d07b
Add network message broker design file.
ehpor Feb 11, 2026
a4af0ef
Return null message after final message has been published.
ehpor Feb 11, 2026
1750843
Set temporary frame id.
ehpor Feb 11, 2026
2032130
Make destructor of abstract class virtual.
ehpor Feb 11, 2026
91825a3
Fix test layout.
ehpor Feb 11, 2026
d25c7ef
Temporarily add debug prints.
ehpor Feb 11, 2026
5edc48f
Consolidate RemoteMessageBroker functionality into a single file.
ehpor Feb 11, 2026
5a3d2ec
Make message structs for consistent (de)serialization.
ehpor Feb 12, 2026
08c4b98
Add debug print statements.
ehpor Feb 12, 2026
76f0809
Have tests use fixtures to initialize their brokers.
ehpor Feb 12, 2026
56f2d84
Let fixture create its own header memory.
ehpor Feb 12, 2026
b2606b2
Explicitly use a LocalMessageBroker to back the RemoteBrokerServer.
ehpor Feb 14, 2026
1ac3bd9
Refactor out unnecessary functions.
ehpor Feb 14, 2026
11737de
Make GetNextMessage and TryGetNextMessage non-public.
ehpor Feb 14, 2026
8f65fed
Fix GetSize() functions for messages.
ehpor Feb 14, 2026
9b46316
Add array info to messages.
ehpor Feb 14, 2026
92aa25e
Remove unnecessary sleeps.
ehpor Feb 14, 2026
69ed5e1
Fix remote message broker tests.
ehpor Feb 14, 2026
19e6475
Don't log socket identity since it can contain non-ascii characters.
ehpor Feb 14, 2026
57dc379
Don't manually shut down the service.
ehpor Feb 14, 2026
a2777c9
Remove some debug print statements.
ehpor Feb 14, 2026
8819127
Update plan.
ehpor Feb 14, 2026
76f1a3f
Wait for next message in chunks.
ehpor Feb 14, 2026
8ea3f49
Publish received messages on LocalMessageBroker.
ehpor Feb 15, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 126 additions & 16 deletions catkit2/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "Uuid.h"
#include "ArrayView.h"
#include "ProcessStats.h"
#include "RemoteMessageBroker.h"

#include "testbed.pb.h"

Expand Down Expand Up @@ -900,6 +901,8 @@ PYBIND11_MODULE(catkit_bindings, m)
})
.def_property_readonly("filename", &SharedMemory::GetFileName);

py::class_<MessageBroker, std::shared_ptr<MessageBroker>>(m, "MessageBroker");

py::class_<LocalMemory, Memory, std::shared_ptr<LocalMemory>>(m, "LocalMemory")
.def_static("create", [](size_t num_bytes)
{
Expand Down Expand Up @@ -1087,7 +1090,7 @@ PYBIND11_MODULE(catkit_bindings, m)
return py::none();
});

py::class_<LocalMessageBroker, std::shared_ptr<LocalMessageBroker>>(m, "LocalMessageBroker")
py::class_<LocalMessageBroker, MessageBroker, std::shared_ptr<LocalMessageBroker>>(m, "LocalMessageBroker")
.def_static("create", [](std::shared_ptr<Memory> header, std::vector<std::shared_ptr<Memory>> memory_blocks)
{
auto stream = StructStream(header);
Expand All @@ -1102,12 +1105,6 @@ PYBIND11_MODULE(catkit_bindings, m)

return std::shared_ptr<LocalMessageBroker>(std::move(broker));
})
.def("prepare_message", [](std::shared_ptr<LocalMessageBroker> broker, const std::string& topic, size_t payload_size, std::uint8_t memory_block_id)
{
auto message = broker->PrepareMessage(topic, payload_size, memory_block_id);

return message;
}, py::arg("topic"), py::arg("payload_size"), py::arg("memory_block_id") = 0)
.def("prepare_message", [](std::shared_ptr<LocalMessageBroker> broker, const std::string& topic, size_t payload_size, py::object trace_id, std::uint8_t memory_block_id)
{
if (trace_id.is_none())
Expand Down Expand Up @@ -1168,26 +1165,23 @@ PYBIND11_MODULE(catkit_bindings, m)
broker->PublishArray(topic, array_view, py::cast<Uuid>(trace_id), memory_block_id);
}
}, py::arg("topic"), py::arg("array"), py::arg("trace_id") = py::none(), py::arg("memory_block_id") = 0)
.def("try_get_message", [](std::shared_ptr<LocalMessageBroker> broker, std::string_view topic, size_t frame_id) -> py::object
.def("get_current_message", [](std::shared_ptr<LocalMessageBroker> broker, std::string_view topic) -> py::object
{
auto res = broker->TryGetMessage(topic, frame_id);
auto res = broker->GetCurrentMessage(topic);
if (res)
return py::cast(res.value());

return py::none();
}, py::arg("topic"), py::arg("frame_id"))
.def("get_current_message", [](std::shared_ptr<LocalMessageBroker> broker, std::string_view topic) -> py::object
}, py::arg("topic"))
.def("get_current_message_id", [](std::shared_ptr<LocalMessageBroker> broker, std::string_view topic) -> py::object
{
auto res = broker->GetCurrentMessage(topic);
auto res = broker->GetCurrentMessageId(topic);

if (res)
return py::cast(res.value());

return py::none();
}, py::arg("topic"))
.def("is_message_available", &LocalMessageBroker::IsMessageAvailable)
.def("will_message_be_available", &LocalMessageBroker::WillMessageBeAvailable)
.def("get_newest_message_id", &LocalMessageBroker::GetNewestMessageId)
.def("get_oldest_message_id", &LocalMessageBroker::GetOldestMessageId)
.def("get_message_rate", &LocalMessageBroker::GetMessageRate)
.def("get_all_message_topics", &LocalMessageBroker::GetAllMessageTopics)
.def("print_debug_info", &LocalMessageBroker::PrintDebugInfo)
Expand Down Expand Up @@ -1292,6 +1286,122 @@ PYBIND11_MODULE(catkit_bindings, m)
.def_property_readonly("memory_usage", &ProcessStats::GetMemoryUsage)
.def_property_readonly("cpu_usage", &ProcessStats::GetCpuUsage);

py::class_<PeerConfig>(m, "PeerConfig")
.def(py::init<>())
.def(py::init<std::string, std::string, int>(),
py::arg("name"),
py::arg("host"),
py::arg("port"))
.def_readwrite("name", &PeerConfig::name)
.def_readwrite("host", &PeerConfig::host)
.def_readwrite("port", &PeerConfig::port);

py::class_<RemoteBrokerServer>(m, "RemoteBrokerServer")
.def(py::init<std::shared_ptr<LocalMessageBroker>, uint16_t, int>(),
py::arg("broker"),
py::arg("port"),
py::arg("num_workers") = 4)
.def("start", &RemoteBrokerServer::Start)
.def("stop", &RemoteBrokerServer::Stop, py::call_guard<py::gil_scoped_release>())
.def_property_readonly("is_running", &RemoteBrokerServer::IsRunning);

py::class_<RemoteMessageBroker, MessageBroker, std::shared_ptr<RemoteMessageBroker>>(m, "RemoteMessageBroker")
.def(py::init<std::shared_ptr<LocalMessageBroker>, std::string, std::vector<PeerConfig>>(),
py::arg("local_broker"),
py::arg("local_machine_name"),
py::arg("peers"))
.def("prepare_message", [](std::shared_ptr<RemoteMessageBroker> broker, const std::string& topic, size_t payload_size, py::object trace_id, uint8_t memory_block_id)
{
if (trace_id.is_none())
{
return broker->PrepareMessage(topic, payload_size, memory_block_id);
}
else
{
return broker->PrepareMessage(topic, payload_size, py::cast<Uuid>(trace_id), memory_block_id);
}
}, py::arg("topic"), py::arg("payload_size"), py::arg("trace_id") = py::none(), py::arg("memory_block_id") = 0)
.def("publish_message", [](std::shared_ptr<RemoteMessageBroker> broker, Message& message, bool is_final)
{
broker->PublishMessage(message, is_final);
}, py::arg("message"), py::arg("is_final") = true)
.def("publish_data", [](std::shared_ptr<RemoteMessageBroker> broker, std::string topic, py::bytes data, py::object trace_id, uint8_t memory_block_id)
{
if (trace_id.is_none())
{
broker->PublishData(topic, PyBytes_AsString(data.ptr()), PyBytes_Size(data.ptr()), memory_block_id);
}
else
{
broker->PublishData(topic, PyBytes_AsString(data.ptr()), PyBytes_Size(data.ptr()), py::cast<Uuid>(trace_id), memory_block_id);
}
}, py::arg("topic"), py::arg("data"), py::arg("trace_id") = py::none(), py::arg("memory_block_id") = 0)
.def("publish_array", [](std::shared_ptr<RemoteMessageBroker> broker, std::string topic, py::array array, py::object trace_id, uint8_t memory_block_id)
{
ArrayInfo info;

auto dtype = array.dtype();
info.data_type = dtype.kind();
info.item_size = dtype.itemsize();
info.byte_order = dtype.byteorder();

if (array.ndim() > MAX_NUM_DIMENSIONS)
throw std::runtime_error("Array dimension is too large.");

info.ndim = array.ndim();

for (size_t i = 0; i < info.ndim; ++i)
{
info.shape[i] = array.shape()[i];
info.strides[i] = array.strides()[i];
}

if (!info.IsCContiguous() && !info.IsFContiguous())
throw std::runtime_error("Array has to be either C or F contiguous.");

// All checks are complete. Let's copy/submit the raw data.
const ArrayView array_view{info, array.mutable_data()};
if (trace_id.is_none())
{
broker->PublishArray(topic, array_view, memory_block_id);
}
else
{
broker->PublishArray(topic, array_view, py::cast<Uuid>(trace_id), memory_block_id);
}
}, py::arg("topic"), py::arg("array"), py::arg("trace_id") = py::none(), py::arg("memory_block_id") = 0)
.def("get_current_message", [](std::shared_ptr<RemoteMessageBroker> broker, std::string_view topic) -> py::object
{
auto res = broker->GetCurrentMessage(topic);
if (res)
return py::cast(res.value());

return py::none();
}, py::arg("topic"))
.def("get_current_message_id", [](std::shared_ptr<LocalMessageBroker>broker, std::string_view topic) -> py::object
{
auto res = broker->GetCurrentMessageId(topic);

if (res)
return py::cast(res.value());

return py::none();
}, py::arg("topic"))
.def("get_message_rate", &RemoteMessageBroker::GetMessageRate)
.def("get_all_message_topics", &RemoteMessageBroker::GetAllMessageTopics)
.def("subscribe", [](std::shared_ptr<RemoteMessageBroker> broker, std::string topic, py::object preferred_next_frame_id, MessageSubscriptionMode mode)
{
// Check if the starting frame ID is a number or None.
if (preferred_next_frame_id.is_none())
{
return broker->Subscribe(topic, mode);
}
else
{
return broker->Subscribe(topic, py::cast<std::uint64_t>(preferred_next_frame_id), mode);
}
}, py::arg("topic"), py::arg("preferred_next_frame_id") = py::none(), py::arg("mode") = MessageSubscriptionMode::NewestOnly);

#ifdef VERSION_INFO
m.attr("__version__") = MACRO_STRINGIFY(VERSION_INFO);
#else
Expand Down
9 changes: 7 additions & 2 deletions catkit_core/Client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
using namespace std;
using namespace zmq;

const int SOCKET_TIMEOUT = 60000; // milliseconds.
#define DEBUG_PRINT(msg) std::cerr << "[DEBUG] " << __func__ << ":" << __LINE__ << " - " << msg << std::endl
#define ERROR_PRINT(msg) std::cerr << "[ERROR] " << __func__ << ":" << __LINE__ << " - " << msg << std::endl

const int SOCKET_TIMEOUT = 60000;

Client::Client(std::string host, int port)
: m_Host(host), m_Port(port)
Expand All @@ -29,6 +32,7 @@ Client::~Client()

string Client::MakeRequest(const string &what, const string &request)
{
DEBUG_PRINT("what: " << what << ", request_size: " << request.size() << ", host: " << m_Host << ":" << m_Port);
auto socket = GetSocket();

zmq::multipart_t request_msg;
Expand Down Expand Up @@ -59,6 +63,7 @@ string Client::MakeRequest(const string &what, const string &request)

std::string reply_type = reply_msg.popstr();
std::string reply_data = reply_msg.popstr();
DEBUG_PRINT("reply_type: " << reply_type << ", reply_data: " << reply_data);

if (reply_type == "OK")
{
Expand Down Expand Up @@ -119,4 +124,4 @@ Client::socket_ptr Client::GetSocket()
{
this->m_Sockets.emplace(ptr);
});
}
}
59 changes: 13 additions & 46 deletions catkit_core/LocalMessageBroker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ Message LocalMessageBroker::PrepareMessageImpl(std::string_view topic, size_t pa

if (allocator == nullptr)
{
throw std::runtime_error("Invalid device ID.");
throw std::runtime_error("Invalid device ID: " + std::to_string(memory_block_id));
}

DEBUG_PRINT("Gotten allocator.");
Expand Down Expand Up @@ -344,15 +344,15 @@ Message LocalMessageBroker::PrepareMessageImpl(std::string_view topic, size_t pa

DEBUG_PRINT("Header set");

return Message(header, payload, INVALID_FRAME_ID, false);
return Message(header, payload, INVALID_FRAME_ID);
}

Message LocalMessageBroker::PublishMessage(Message message, bool is_final)
{
DEBUG_PRINT("Publishing message.");

if (message.m_HasBeenPublished)
throw std::runtime_error("Message has already been published.");
if (message.m_Header == nullptr || message.m_Payload == nullptr)
throw std::runtime_error("Message is invalid. Use PrepareMessage() to create a valid message.");

// Set the timestamp.
message.m_Header->producer_timestamp = GetTimeStamp();
Expand Down Expand Up @@ -475,19 +475,14 @@ Message LocalMessageBroker::PublishMessage(Message message, bool is_final)
DEBUG_PRINT("Copied message header.");
}

message.m_HasBeenPublished = is_final;

return message;
}

std::optional<Message> LocalMessageBroker::TryGetMessage(std::string_view topic, size_t frame_id)
{
auto topic_header = GetTopicHeader(topic);

if (!topic_header->IsMessageAvailable(frame_id))
return std::nullopt;

return FetchMessage(topic_header, frame_id);
if (is_final)
{
return Message(nullptr, nullptr, 0);
}
else
{
return message;
}
}

Message LocalMessageBroker::FetchMessage(TopicHeader* topic_header, size_t frame_id)
Expand All @@ -498,7 +493,7 @@ Message LocalMessageBroker::FetchMessage(TopicHeader* topic_header, size_t frame
auto memory = GetMemory(header->payload_info.memory_block_id);
auto payload = memory->GetAddress(offset);

return Message(header, payload, frame_id, true);
return Message(header, payload, frame_id);
}

std::uint64_t LocalMessageBroker::GetNextMessageId(TopicHeader *topic_header, size_t preferred_next_frame_id, MessageSubscriptionMode mode)
Expand Down Expand Up @@ -590,34 +585,6 @@ std::shared_ptr<HybridPoolAllocator> LocalMessageBroker::GetAllocator(uint8_t me
return m_Allocators[memory_block_id];
}

bool LocalMessageBroker::IsMessageAvailable(std::string_view topic, size_t frame_id)
{
auto topic_header = GetTopicHeader(topic);

return topic_header->IsMessageAvailable(frame_id);
}

bool LocalMessageBroker::WillMessageBeAvailable(std::string_view topic, size_t frame_id)
{
auto topic_header = GetTopicHeader(topic);

return topic_header->WillMessageBeAvailable(frame_id);
}

size_t LocalMessageBroker::GetNewestMessageId(std::string_view topic)
{
auto topic_header = GetTopicHeader(topic);

return topic_header->GetNewestMessageId();
}

size_t LocalMessageBroker::GetOldestMessageId(std::string_view topic)
{
auto topic_header = GetTopicHeader(topic);

return topic_header->GetOldestMessageId();
}

double LocalMessageBroker::GetMessageRate(std::string_view topic)
{
auto topic_header = GetTopicHeader(topic);
Expand Down
28 changes: 7 additions & 21 deletions catkit_core/LocalMessageBroker.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,37 +86,23 @@ class LocalMessageBroker : public Shareable, public MessageBroker
// Get the newest message for a topic.
virtual std::optional<Message> GetCurrentMessage(std::string_view topic) override;

// Get the next message for a topic.
virtual std::optional<Message> GetNextMessage(std::string_view topic, size_t preferred_next_frame_id, MessageSubscriptionMode mode = MessageSubscriptionMode::NewestOnly, double timeout_in_seconds = -1, EventWaitMethod wait_type = EventWaitMethod::Default, void (*error_check)() = nullptr) override;

// Try to get the next message for a topic.
virtual std::optional<Message> TryGetNextMessage(std::string_view topic, size_t preferred_next_frame_id, MessageSubscriptionMode mode = MessageSubscriptionMode::NewestOnly) override;

// Get the message rate for a topic.
virtual double GetMessageRate(std::string_view topic) override;

// Get the message topics for all messages in this broker.
virtual std::vector<std::string> GetAllMessageTopics() override;

// Try to get a message by topic and frame ID.
std::optional<Message> TryGetMessage(std::string_view topic, size_t frame_id);

// Check for message availability.
bool IsMessageAvailable(std::string_view topic, size_t frame_id);

// Check if a message will be available in the future.
bool WillMessageBeAvailable(std::string_view topic, size_t frame_id);

// Get the newest message ID for a topic.
size_t GetNewestMessageId(std::string_view topic);

// Get the oldest message ID for a topic.
size_t GetOldestMessageId(std::string_view topic);

ShareableType GetType() const override;

virtual void PrintDebugInfo() const override;

protected:
// Get the next message for a topic.
virtual std::optional<Message> GetNextMessage(std::string_view topic, size_t preferred_next_frame_id, MessageSubscriptionMode mode = MessageSubscriptionMode::NewestOnly, double timeout_in_seconds = -1, EventWaitMethod wait_type = EventWaitMethod::Default, void (*error_check)() = nullptr) override;

// Try to get the next message for a topic.
virtual std::optional<Message> TryGetNextMessage(std::string_view topic, size_t preferred_next_frame_id, MessageSubscriptionMode mode = MessageSubscriptionMode::NewestOnly) override;

private:
Message FetchMessage(TopicHeader *topic_header, size_t frame_id);
std::uint64_t GetNextMessageId(TopicHeader *topic_header, size_t preferred_next_frame_id, MessageSubscriptionMode mode);
Expand Down
Loading
Loading