diff --git a/CMakeLists.txt b/CMakeLists.txt index 38e585a7..df912b05 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -108,6 +108,8 @@ else() option(OPENCMW_ENABLE_COVERAGE "Enable Coverage" OFF) endif() option(OPENCMW_ENABLE_CONCEPTS "Enable Concepts Builds" ${opencmw_MASTER_PROJECT}) +option(OPENCMW_DEBUG_HTTP "Enable verbose HTTP output for debugging" OFF) +option(OPENCMW_PROFILE_HTTP "Enable verbose HTTP output for profiling" OFF) # Very basic PCH example option(ENABLE_PCH "Enable Precompiled Headers" OFF) @@ -124,6 +126,14 @@ if(ENABLE_PCH) ) endif() +if(OPENCMW_DEBUG_HTTP) + target_compile_definitions(opencmw_project_options INTERFACE -DOPENCMW_DEBUG_HTTP) +endif() + +if(OPENCMW_PROFILE_HTTP) + target_compile_definitions(opencmw_project_options INTERFACE -DOPENCMW_PROFILE_HTTP) +endif() + if(OPENCMW_ENABLE_TESTING) enable_testing() message("Building Tests.") diff --git a/cmake/DependenciesNative.cmake b/cmake/DependenciesNative.cmake index b1696538..bcce833e 100644 --- a/cmake/DependenciesNative.cmake +++ b/cmake/DependenciesNative.cmake @@ -65,3 +65,29 @@ FetchContent_Declare( FetchContent_MakeAvailable(cpp-httplib zeromq openssl-source) list(APPEND CMAKE_MODULE_PATH ${catch2_SOURCE_DIR}/contrib) # replace contrib by extras for catch2 v3.x.x + +option(ENABLE_NGHTTP2_DEBUG "Enable verbose nghttp2 debug output" OFF) + +include(ExternalProject) +ExternalProject_Add(Nghttp2Project + GIT_REPOSITORY https://github.com/nghttp2/nghttp2 + GIT_TAG v1.65.0 + GIT_SHALLOW ON + BUILD_BYPRODUCTS ${CMAKE_BINARY_DIR}/nghttp2-install/lib/libnghttp2.a + CMAKE_ARGS + -DCMAKE_INSTALL_PREFIX:PATH=${CMAKE_BINARY_DIR}/nghttp2-install + -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE} + -DENABLE_LIB_ONLY:BOOL=ON + -DENABLE_HTTP3:BOOL=OFF + -DENABLE_DEBUG:BOOL=${ENABLE_NGHTTP2_DEBUG} + -DBUILD_STATIC_LIBS:BOOL=ON + -BUILD_SHARED_LIBS:BOOL=OFF + -DENABLE_DOC:BOOL=OFF +) + +add_library(nghttp2-static STATIC IMPORTED STATIC GLOBAL) +set_target_properties(nghttp2-static PROPERTIES + IMPORTED_LOCATION "${CMAKE_BINARY_DIR}/nghttp2-install/lib/libnghttp2.a" + INTERFACE_INCLUDE_DIRECTORIES "${CMAKE_BINARY_DIR}/nghttp2-install/include" +) +add_dependencies(nghttp2-static Nghttp2Project) diff --git a/concepts/client/CMakeLists.txt b/concepts/client/CMakeLists.txt index a361193d..33f28624 100644 --- a/concepts/client/CMakeLists.txt +++ b/concepts/client/CMakeLists.txt @@ -1,13 +1,10 @@ -if(NOT EMSCRIPTEN) - add_executable(RestSubscription_example RestSubscription_example.cpp) - target_link_libraries( - RestSubscription_example - PRIVATE core - client - opencmw_project_options - opencmw_project_warnings - assets::rest) -endif() +add_executable(LoadTest_client LoadTest_client.cpp) +target_link_libraries( + LoadTest_client + PRIVATE core + client + opencmw_project_options + opencmw_project_warnings) add_executable(RestSubscription_client RestSubscription_client.cpp) target_link_libraries( diff --git a/concepts/client/LoadTest_client.cpp b/concepts/client/LoadTest_client.cpp new file mode 100644 index 00000000..d084e0cb --- /dev/null +++ b/concepts/client/LoadTest_client.cpp @@ -0,0 +1,68 @@ +#include +#include +#include +#include + +#include +#include +#include + +#include "ClientCommon.hpp" +#include "helpers.hpp" + +using namespace std::chrono_literals; + +std::string schema() { + if (auto env = ::getenv("DISABLE_REST_HTTPS"); env != nullptr && std::string_view(env) == "1") { + return "http"; + } else { + return "https"; + } +} + + +int main() { + constexpr auto kServerPort = 8080; + constexpr auto kNClients = 80UZ; + constexpr auto kNSubscriptions = 10UZ; + constexpr auto kNUpdates = 5000UZ; + constexpr auto kIntervalMs = 40UZ; + constexpr auto kPayloadSize = 4096UZ; + + std::array, kNClients> clients; + for (std::size_t i = 0; i < clients.size(); i++) { + clients[i] = std::make_unique(opencmw::client::DefaultContentTypeHeader(opencmw::MIME::BINARY), opencmw::client::VerifyServerCertificates(false)); + } + std::atomic responseCount = 0; + + const auto start = std::chrono::system_clock::now(); + + for (std::size_t i = 0; i < kNClients; i++) { + for (std::size_t j = 0; j < kNSubscriptions; j++) { + opencmw::client::Command cmd; + cmd.command = opencmw::mdp::Command::Subscribe; + cmd.serviceName = "/loadTest"; + cmd.topic = opencmw::URI<>(fmt::format("{}://localhost:{}/loadTest?initialDelayMs=1000&topic={}&intervalMs={}&payloadSize={}&nUpdates={}", schema(), kServerPort, /*i,*/ j, kIntervalMs, kPayloadSize, kNUpdates)); + cmd.callback = [&responseCount](const auto &msg) { + responseCount++; + }; + clients[i]->request(std::move(cmd)); + } + } + + constexpr auto expectedResponses = kNClients * kNSubscriptions * kNUpdates; + + std::uint64_t counter = 0; + while (responseCount < expectedResponses) { + counter += 50; + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + if (counter % 20 == 0) { + fmt::println("Received {} of {} responses", responseCount.load(), expectedResponses); + } + } + + const auto end = std::chrono::system_clock::now(); + const auto elapsed = std::chrono::duration_cast(end - start).count(); + fmt::println("Elapsed time: {} ms", elapsed); + return 0; +} diff --git a/concepts/client/RestSubscription_client.cpp b/concepts/client/RestSubscription_client.cpp index 9d086b6c..b822fc3b 100644 --- a/concepts/client/RestSubscription_client.cpp +++ b/concepts/client/RestSubscription_client.cpp @@ -13,8 +13,11 @@ using namespace std::chrono_literals; // These are not main-local, as JS doesn't end when // C++ main ends namespace test_state { +#ifndef __EMSCRIPTEN__ +opencmw::client::RestClient client(opencmw::client::VerifyServerCertificates(false)); +#else opencmw::client::RestClient client; - +#endif std::string schema() { if (auto env = ::getenv("DISABLE_REST_HTTPS"); env != nullptr && std::string_view(env) == "1") { return "http"; @@ -39,10 +42,6 @@ auto run = rest_test_runner( } // namespace test_state int main() { -#ifndef __EMSCRIPTEN__ - opencmw::client::RestClient::CHECK_CERTIFICATES = false; -#endif - using namespace test_state; #ifndef __EMSCRIPTEN__ diff --git a/concepts/client/RestSubscription_example.cpp b/concepts/client/RestSubscription_example.cpp deleted file mode 100644 index 99cb82dc..00000000 --- a/concepts/client/RestSubscription_example.cpp +++ /dev/null @@ -1,127 +0,0 @@ -#include -#include -#include - -#include -#include -#include - -namespace detail { -class EventDispatcher { - std::mutex _mutex; - std::condition_variable _condition; - std::atomic _id{ 0 }; - std::atomic _cid{ -1 }; - std::string _message; - -public: - void wait_event(httplib::DataSink &sink) { - std::unique_lock lk(_mutex); - int id = std::atomic_load_explicit(&_id, std::memory_order_acquire); - _condition.wait(lk, [id, this] { return _cid == id; }); - if (sink.is_writable()) { - sink.write(_message.data(), _message.size()); - } - } - - void send_event(const std::string_view &message) { - std::scoped_lock lk(_mutex); - _cid = _id++; - _message = message; - _condition.notify_all(); - } -}; -} // namespace detail - -int main() { - using namespace std::chrono_literals; - opencmw::client::RestClient client; - - std::atomic updateCounter{ 0 }; - detail::EventDispatcher eventDispatcher; - httplib::Server server; - server.Get("/event", [&eventDispatcher, &updateCounter](const httplib::Request &req, httplib::Response &res) { - auto acceptType = req.headers.find("accept"); - if (acceptType == req.headers.end() || opencmw::MIME::EVENT_STREAM.typeName() != acceptType->second) { // non-SSE request -> return default response -#if not defined(__EMSCRIPTEN__) and (not defined(__clang__) or (__clang_major__ >= 16)) - res.set_content(fmt::format("update counter = {}", updateCounter.load()), opencmw::MIME::TEXT); -#else - res.set_content(fmt::format("update counter = {}", updateCounter.load()), std::string(opencmw::MIME::TEXT.typeName())); -#endif - return; - } else { - fmt::print("server received SSE request on path '{}' body = '{}'\n", req.path, req.body); -#if not defined(__EMSCRIPTEN__) and (not defined(__clang__) or (__clang_major__ >= 16)) - res.set_chunked_content_provider(opencmw::MIME::EVENT_STREAM, [&eventDispatcher](size_t /*offset*/, httplib::DataSink &sink) { -#else - res.set_chunked_content_provider(std::string(opencmw::MIME::EVENT_STREAM.typeName()), [&eventDispatcher](size_t /*offset*/, httplib::DataSink &sink) { -#endif - eventDispatcher.wait_event(sink); - return true; - }); - } - }); - server.Get("/endPoint", [](const httplib::Request &req, httplib::Response &res) { - fmt::print("server received request on path '{}' body = '{}'\n", req.path, req.body); - res.set_content("Hello World!", "text/plain"); - }); - client.threadPool()->execute<"RestServer">([&server] { server.listen("localhost", 8080); }); - int timeOut = 0; - while (!server.is_running() && timeOut < 10'000) { - std::this_thread::sleep_for(1ms); - timeOut += 1; - } - assert(server.is_running()); - if (!server.is_running()) { - fmt::print("couldn't start server\n"); - std::terminate(); - } - - std::atomic received(false); - opencmw::IoBuffer data; - data.put('A'); - data.put('B'); - data.put('C'); - data.put(0); - opencmw::client::Command command; - command.command = opencmw::mdp::Command::Subscribe; - command.topic = opencmw::URI("http://localhost:8080/event"); - command.data = std::move(data); - command.callback = [&received](const opencmw::mdp::Message &rep) { - fmt::print("SSE client received reply = '{}' - body size: '{}'\n", rep.data.asString(), rep.data.size()); - received.fetch_add(1, std::memory_order_relaxed); - received.notify_all(); - }; - - client.request(command); - - std::cout << "client request launched" << std::endl; - std::this_thread::sleep_for(100ms); - eventDispatcher.send_event("test-event meta data"); - std::jthread([&updateCounter, &eventDispatcher] { - while (updateCounter < 5) { - std::this_thread::sleep_for(500ms); - eventDispatcher.send_event(fmt::format("test-event {}", updateCounter++)); - } - }).join(); - - while (received.load(std::memory_order_relaxed) < 5) { - std::this_thread::sleep_for(100ms); - } - std::cout << "done waiting" << std::endl; - assert(received.load(std::memory_order_acquire) >= 5); - - command.command = opencmw::mdp::Command::Unsubscribe; - client.request(command); - std::this_thread::sleep_for(100ms); - std::cout << "done Unsubscribe" << std::endl; - client.stop(); - std::cout << "client stopped" << std::endl; - - server.stop(); - eventDispatcher.send_event(fmt::format("test-event {}", updateCounter++)); - std::cout << "server stopped" << std::endl; - std::this_thread::sleep_for(5s); - - return 0; -} diff --git a/concepts/client/dns_example.cpp b/concepts/client/dns_example.cpp index f681aa33..159d14c3 100644 --- a/concepts/client/dns_example.cpp +++ b/concepts/client/dns_example.cpp @@ -12,8 +12,9 @@ #include #endif // EMSCRIPTEN +#include + #include -#include using namespace std::chrono_literals; using namespace opencmw; @@ -32,12 +33,17 @@ using namespace opencmw::service::dns; void run_dns_server(std::string_view httpAddress, std::string_view mdpAddress) { majordomo::Broker<> broker{ "Broker", {} }; std::string rootPath{ "./" }; - auto fs = cmrc::assets::get_filesystem(); - majordomo::RestBackend rest_backend{ broker, fs, URI<>{ std::string{ httpAddress } } }; + majordomo::rest::Settings rest; + rest.handlers = { majordomo::rest::cmrcHandler("/assets/*", "", std::make_shared(cmrc::assets::get_filesystem()), "") }; + + if (const auto bound = broker.bindRest(rest); !bound) { + fmt::println(std::cerr, "failed to bind REST: {}", bound.error()); + std::exit(1); + return; + } DnsWorkerType dnsWorker{ broker, DnsHandler{} }; broker.bind(URI<>{ std::string{ mdpAddress } }, majordomo::BindOption::Router); - RunInThread restThread(rest_backend); RunInThread dnsThread(dnsWorker); RunInThread brokerThread(broker); diff --git a/concepts/majordomo/CMakeLists.txt b/concepts/majordomo/CMakeLists.txt index 0ad4a009..daec7361 100644 --- a/concepts/majordomo/CMakeLists.txt +++ b/concepts/majordomo/CMakeLists.txt @@ -38,6 +38,16 @@ target_link_libraries( assets::rest assets::testImages) +add_executable(MajordomoRest_LoadTestServer MajordomoRest_LoadTestServer.cpp) +target_include_directories(MajordomoRest_LoadTestServer PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) +target_link_libraries( + MajordomoRest_LoadTestServer + PRIVATE majordomo + opencmw_project_options + opencmw_project_warnings + assets::rest + assets::testImages) + if(NOT CMAKE_CXX_COMPILER_ID MATCHES diff --git a/concepts/majordomo/MajordomoRest_LoadTestServer.cpp b/concepts/majordomo/MajordomoRest_LoadTestServer.cpp new file mode 100644 index 00000000..84723fe2 --- /dev/null +++ b/concepts/majordomo/MajordomoRest_LoadTestServer.cpp @@ -0,0 +1,59 @@ +#include +#include +#include +#include + +#include "helpers.hpp" + +namespace majordomo = opencmw::majordomo; + +int main(int argc, char **argv) { + using opencmw::URI; + + std::string rootPath = "./"; + std::uint16_t port = 8080; + bool https = true; + + for (int i = 1; i < argc; i++) { + if (std::string_view(argv[i]) == "--port") { + if (i + 1 < argc) { + port = static_cast(std::stoi(argv[i + 1])); + ++i; + continue; + } + } else if (std::string_view(argv[i]) == "--http") { + https = false; + } else { + rootPath = argv[i]; + } + } + + fmt::println(std::cerr, "Starting load test server ({}) for {} on port {}", https ? "HTTPS" : "HTTP", rootPath, port); + + opencmw::majordomo::rest::Settings rest; + rest.port = port; + rest.handlers = { majordomo::rest::cmrcHandler("/assets/*", "", std::make_shared(cmrc::assets::get_filesystem()), "") }; + if (https) { + rest.certificateFilePath = "./demo_public.crt"; + rest.keyFilePath = "./demo_private.key"; + } + + majordomo::Broker broker("/Broker", testSettings()); + opencmw::query::registerTypes(majordomo::load_test::Context(), broker); + + if (const auto bound = broker.bindRest(rest); !bound) { + fmt::println("Could not bind HTTP/2 REST bridge to port {}: {}", rest.port, bound.error()); + return 1; + } + + const auto brokerRouterAddress = broker.bind(URI<>("mds://127.0.0.1:12345")); + if (!brokerRouterAddress) { + std::cerr << "Could not bind to broker address" << std::endl; + return 1; + } + + majordomo::load_test::Worker<"/loadTest"> loadTestWorker(broker); + RunInThread runLoadTest(loadTestWorker); + + broker.run(); +} diff --git a/concepts/majordomo/MajordomoRest_example.cpp b/concepts/majordomo/MajordomoRest_example.cpp index 818ae601..7f0b36e6 100644 --- a/concepts/majordomo/MajordomoRest_example.cpp +++ b/concepts/majordomo/MajordomoRest_example.cpp @@ -1,13 +1,9 @@ #include #include -#include +#include #include -#include -#include -#include #include -#include #include "helpers.hpp" @@ -18,49 +14,59 @@ int main(int argc, char **argv) { using opencmw::URI; std::string rootPath = "./"; - if (argc > 1) { - rootPath = argv[1]; + std::uint16_t port = 8080; + bool https = true; + + for (int i = 1; i < argc; i++) { + if (std::string_view(argv[i]) == "--port") { + if (i + 1 < argc) { + port = static_cast(std::stoi(argv[i + 1])); + ++i; + continue; + } + } else if (std::string_view(argv[i]) == "--http") { + https = false; + } else { + rootPath = argv[i]; + } } - std::cerr << "Starting server for " << rootPath << "\n"; + const auto scheme = https ? "https" : "http"; + auto makeExample = [scheme, port](std::string_view pathAndQuery) { + return fmt::format("{}://localhost:{}/{}", scheme, port, pathAndQuery); + }; + + fmt::println(std::cerr, "Starting {} server for {} on port {}", https ? "HTTPS" : "HTTP", rootPath, port); - std::cerr << "Open up https://localhost:8080/addressbook?contentType=text/html&ctx=FAIR.SELECTOR.ALL in your web browser\n"; - std::cerr << "Or curl -v -k one of the following:\n"; - std::cerr - << "'https://localhost:8080/addressbook?contentType=application/json&ctx=FAIR.SELECTOR.ALL'\n" - << "'https://localhost:8080/addressbook?contentType=application/json&ctx=FAIR.SELECTOR.ALL'\n" - << "'https://localhost:8080/addressbook/addresses?LongPollingId=Next'\n" - << "'https://localhost:8080/addressbook/addresses?LongPollingId=0'\n" - << "'https://localhost:8080/beverages/wine?LongPollingIdx=Subscription'\n"; + fmt::println(std::cerr, "Open up {} in your web browser", makeExample("addressbook?contentType=text/html&ctx=FAIR.SELECTOR.ALL")); + fmt::println(std::cerr, "Or curl -v -k one of the following:"); + fmt::println(std::cerr, "'{}'", makeExample("addressbook?contentType=application/json&ctx=FAIR.SELECTOR.ALL")); + fmt::println(std::cerr, "'{}'", makeExample("addressbook?contentType=application/json&ctx=FAIR.SELECTOR.ALL")); + fmt::println(std::cerr, "'{}'", makeExample("addressbook/addresses?LongPollingIdx=Next")); + fmt::println(std::cerr, "'{}'", makeExample("addressbook/addresses?LongPollingIdx=Last")); + fmt::println(std::cerr, "'{}'", makeExample("addressbook/addresses?LongPollingIdx=0")); + fmt::println(std::cerr, "'{}'", makeExample("beverages/wine?LongPollingIdx=Next")); + fmt::println(std::cerr, "'{}'", makeExample("loadTest?topic=1&intervalMs=40&payloadSize=4096&nUpdates=100&LongPollingIdx=Next")); // note: inconsistency: brokerName as ctor argument, worker's serviceName as NTTP // note: default roles different from java (has: ADMIN, READ_WRITE, READ_ONLY, ANYONE, NULL) majordomo::Broker primaryBroker("/PrimaryBroker", testSettings()); opencmw::query::registerTypes(SimpleContext(), primaryBroker); - - auto fs = cmrc::assets::get_filesystem(); - - std::variant< - std::monostate, - FileServerRestBackend, - FileServerRestBackend> - rest; - if (auto env = ::getenv("DISABLE_REST_HTTPS"); env != nullptr && std::string_view(env) == "1") { - rest.emplace>(primaryBroker, fs, rootPath); - } else { - rest.emplace>(primaryBroker, fs, rootPath); + opencmw::query::registerTypes(majordomo::load_test::Context(), primaryBroker); + + opencmw::majordomo::rest::Settings rest; + rest.port = port; + rest.handlers = { majordomo::rest::cmrcHandler("/assets/*", "", std::make_shared(cmrc::assets::get_filesystem()), "") }; + if (https) { + rest.certificateFilePath = "./demo_public.crt"; + rest.keyFilePath = "./demo_private.key"; + } + if (const auto bound = primaryBroker.bindRest(rest); !bound) { + fmt::println("Could not bind HTTP/2 REST bridge to port {}: {}", rest.port, bound.error()); + return 1; } - std::jthread restServerThread([&rest] { - std::visit([](T &server) { - if constexpr (not std::is_same_v) { - server.run(); - } - }, - rest); - }); - - const auto brokerRouterAddress = primaryBroker.bind(URI<>("mds://127.0.0.1:12345")); + const auto brokerRouterAddress = primaryBroker.bind(URI<>("mds://127.0.0.1:12345")); if (!brokerRouterAddress) { std::cerr << "Could not bind to broker address" << std::endl; return 1; @@ -82,6 +88,7 @@ int main(int argc, char **argv) { majordomo::Worker<"/addressbook", SimpleContext, AddressRequest, AddressEntry> addressbookWorker(primaryBroker, TestAddressHandler()); majordomo::Worker<"/addressbookBackup", SimpleContext, AddressRequest, AddressEntry> addressbookBackupWorker(primaryBroker, TestAddressHandler()); majordomo::BasicWorker<"/beverages"> beveragesWorker(primaryBroker, TestIntHandler(10)); + majordomo::load_test::Worker<"/loadTest"> loadTestWorker(primaryBroker); // ImageServiceWorker<"/testImage", majordomo::description<"Returns an image">> imageWorker(primaryBroker, std::chrono::seconds(10)); @@ -91,6 +98,7 @@ int main(int argc, char **argv) { RunInThread runAddressbook(addressbookWorker); RunInThread runAddressbookBackup(addressbookBackupWorker); RunInThread runBeverages(beveragesWorker); + RunInThread runLoadTest(loadTestWorker); RunInThread runImage(imageWorker); waitUntilWorkerServiceAvailable(primaryBroker.context, addressbookWorker); diff --git a/concepts/majordomo/assets/mustache/ServicesList.mustache b/concepts/majordomo/assets/mustache/ServicesList.mustache index c0e5be6f..8ac6e980 100644 --- a/concepts/majordomo/assets/mustache/ServicesList.mustache +++ b/concepts/majordomo/assets/mustache/ServicesList.mustache @@ -96,7 +96,7 @@ input.topicInput { border: 0; width: 100px; border-bottom: 1px solid silver; } }; listenButton.onclick = () => { - let post = { method: 'GET', headers: { 'X-OPENCMW-METHOD' : 'POLL' } }; + let post = { method: 'GET', headers: { 'x-opencmw-method' : 'POLL' } }; fetch(href + "/" + topicInput.value, post) .then(response => response.text()) diff --git a/concepts/majordomo/assets/mustache/default.mustache b/concepts/majordomo/assets/mustache/default.mustache index c243a8db..da7a30f1 100644 --- a/concepts/majordomo/assets/mustache/default.mustache +++ b/concepts/majordomo/assets/mustache/default.mustache @@ -174,7 +174,7 @@ } function pollingHandler() { - let get = { method: 'GET', headers: { 'X-OPENCMW-METHOD' : 'POLL' } }; + let get = { method: 'GET', headers: { 'x-opencmw-method' : 'POLL' } }; let queryParams = window.opencmwActiveSubscriptionQueryParams; diff --git a/concepts/majordomo/helpers.hpp b/concepts/majordomo/helpers.hpp index c5c08faf..3a7bea37 100644 --- a/concepts/majordomo/helpers.hpp +++ b/concepts/majordomo/helpers.hpp @@ -12,7 +12,6 @@ // OpenCMW Majordomo #include -#include #include CMRC_DECLARE(testImages); @@ -204,43 +203,6 @@ class ImageServiceWorker : public majordomo::Worker -class FileServerRestBackend : public majordomo::RestBackend { -private: - using super_t = majordomo::RestBackend; - std::filesystem::path _serverRoot; - using super_t::_svr; - using super_t::DEFAULT_REST_SCHEME; - -public: - using super_t::RestBackend; - - FileServerRestBackend(majordomo::Broker &broker, const VirtualFS &vfs, std::filesystem::path serverRoot, opencmw::URI<> restAddress = opencmw::URI<>::factory().scheme(DEFAULT_REST_SCHEME).hostName("0.0.0.0").port(majordomo::DEFAULT_REST_PORT).build()) - : super_t(broker, vfs, restAddress), _serverRoot(std::move(serverRoot)) { - } - - void registerHandlers() override { - _svr.set_mount_point("/", _serverRoot.string()); - - _svr.Post("/stdio.html", [](const httplib::Request &request, httplib::Response &response) { - opencmw::debug::log() << "QtWASM:" << request.body; - response.set_content("", "text/plain"); - }); - - auto cmrcHandler = [this](const httplib::Request &request, httplib::Response &response) { - if (super_t::_vfs.is_file(request.path)) { - auto file = super_t::_vfs.open(request.path); - response.set_content(std::string(file.begin(), file.end()), ""); - } - }; - - _svr.Get("/assets/.*", cmrcHandler); - - // Register default handlers - super_t::registerHandlers(); - } -}; - template concept Shutdownable = requires(T s) { s.run(); diff --git a/docs/RestUriMapping.md b/docs/RestUriMapping.md index 6ee04d0e..3a03abb6 100644 --- a/docs/RestUriMapping.md +++ b/docs/RestUriMapping.md @@ -24,11 +24,11 @@ and `POLL` which maps to `LongPoll`. For clients that don't support a custom request operation, the non-standard HTTP requests can be invoked by defining the -`X-OPENCMW-METHOD` HTTP header. +`x-opencmw-method` HTTP header. If this header is defined, its value is used to override the HTTP request method. -This means that if HTTP method is `GET` and `X-OPENCMW-METHOD=POLL`, +This means that if HTTP method is `GET` and `x-opencmw-method=POLL`, that the REST backend will treat it as `LongPoll` request. Alternatively, one can override the Majordomo method diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index dd5e82cc..10c876be 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -8,6 +8,7 @@ if(EMSCRIPTEN) message("Disabled majordomo and mdp client module because they are not compatible with emscripten builds") else() add_subdirectory(majordomo) + add_subdirectory(nghttp2) add_subdirectory(zmq) endif() diff --git a/src/client/CMakeLists.txt b/src/client/CMakeLists.txt index 3eb165fb..9f6b9367 100644 --- a/src/client/CMakeLists.txt +++ b/src/client/CMakeLists.txt @@ -12,11 +12,12 @@ if(NOT EMSCRIPTEN) target_link_libraries( client INTERFACE pthread - core - disruptor - serialiser - majordomo - zmq) + core + disruptor + serialiser + majordomo + nghttp2 + zmq) else() target_link_libraries( client diff --git a/src/client/include/ClientCommon.hpp b/src/client/include/ClientCommon.hpp index 76d74d6a..4b064aac 100644 --- a/src/client/include/ClientCommon.hpp +++ b/src/client/include/ClientCommon.hpp @@ -43,7 +43,7 @@ constexpr auto find_argument_value_helper(Item &item) { } template -requires std::is_invocable_r_v + requires std::is_invocable_r_v constexpr RequiredType find_argument_value(Func defaultGenerator, Items... args) { auto ret = std::tuple_cat(find_argument_value_helper(args)...); if constexpr (std::tuple_size_v == 0) { @@ -53,8 +53,6 @@ constexpr RequiredType find_argument_value(Func defaultGenerator, Items... args) } } -constexpr const char *ACCEPT_HEADER = "accept"; -constexpr const char *CONTENT_TYPE_HEADER = "content-type"; } // namespace detail class DefaultContentTypeHeader { @@ -62,12 +60,22 @@ class DefaultContentTypeHeader { public: DefaultContentTypeHeader(const MIME::MimeType &type) noexcept - : _mimeType(type){}; + : _mimeType(type) {} DefaultContentTypeHeader(const std::string_view type_str) noexcept - : _mimeType(MIME::getType(type_str)){}; + : _mimeType(MIME::getType(type_str)) {} constexpr operator const MIME::MimeType() const noexcept { return _mimeType; }; }; +class VerifyServerCertificates { + const bool verifyServerCertificates = false; + +public: + VerifyServerCertificates() = default; + VerifyServerCertificates(bool value) noexcept + : verifyServerCertificates(value) {} + constexpr operator bool() const noexcept { return verifyServerCertificates; }; +}; + } // namespace opencmw::client #endif // include guard diff --git a/src/client/include/ClientContext.hpp b/src/client/include/ClientContext.hpp index 2ad9f5a1..158d2fff 100644 --- a/src/client/include/ClientContext.hpp +++ b/src/client/include/ClientContext.hpp @@ -1,7 +1,6 @@ -#ifndef OPENCMW_CPP_DATASOUCREPUBLISHER_HPP -#define OPENCMW_CPP_DATASOUCREPUBLISHER_HPP +#ifndef OPENCMW_CPP_CLIENTCONTEXT_HPP +#define OPENCMW_CPP_CLIENTCONTEXT_HPP -#include #include #include #include @@ -123,4 +122,4 @@ class ClientContext { }; } // namespace opencmw::client -#endif // OPENCMW_CPP_DATASOUCREPUBLISHER_HPP +#endif // OPENCMW_CPP_CLIENTCONTEXT_HPP diff --git a/src/client/include/RestClientEmscripten.hpp b/src/client/include/RestClientEmscripten.hpp index 0f4255f9..e4f6acd9 100644 --- a/src/client/include/RestClientEmscripten.hpp +++ b/src/client/include/RestClientEmscripten.hpp @@ -1,11 +1,11 @@ #ifndef OPENCMW_CPP_RESTCLIENT_EMSCRIPTEN_HPP #define OPENCMW_CPP_RESTCLIENT_EMSCRIPTEN_HPP +#include #include #include #include -#include #include #include @@ -17,14 +17,14 @@ using namespace opencmw; namespace opencmw::client { - namespace detail { +namespace detail { /*** * Get the final URL of a possibly redirected HTTP fetch call. * Uses Javascript to return the the url as a string. */ - static std::string getFinalURL(std::uint32_t id) { - auto finalURLChar = static_cast(EM_ASM_PTR({ +static std::string getFinalURL(std::uint32_t id) { + auto finalURLChar = static_cast(EM_ASM_PTR({ var fetch = Fetch.xhrs.get($0); if (fetch) { var finalURL = fetch.responseURL; @@ -33,192 +33,210 @@ namespace opencmw::client { stringToUTF8(finalURL, stringOnWasmHeap, lengthBytes); return stringOnWasmHeap; } - return 0; - }, id)); - std::string finalURL{finalURLChar, strlen(finalURLChar)}; - EM_ASM({ _free($0) }, finalURLChar); - return finalURL; - } + return 0; }, id)); + std::string finalURL{ finalURLChar, strlen(finalURLChar) }; + EM_ASM({ _free($0) }, finalURLChar); + return finalURL; +} + +struct pointer_equals { + using is_transparent = void; + + template + bool operator()(const Left &left, const Right &right) const { + return std::to_address(left) == std::to_address(right); + } +}; - struct pointer_equals { - using is_transparent = void; +struct pointer_hash { + using is_transparent = void; - template - bool operator()(const Left &left, const Right &right) const { - return std::to_address(left) == std::to_address(right); - } - }; + template + std::size_t operator()(const Pointer &ptr) const { + const auto *raw = std::to_address(ptr); + return std::hash{}(raw); + } +}; - struct pointer_hash { - using is_transparent = void; +auto checkedStringViewSize = [](auto numBytes) { + if (numBytes > std::numeric_limits::max()) { + throw fmt::format("We received more data than we can handle {}", numBytes); + } + return static_cast(numBytes); +}; - template - std::size_t operator()(const Pointer &ptr) const { - const auto *raw = std::to_address(ptr); - return std::hash{}(raw); - } - }; +std::array getPreferredContentTypeHeader(const URI &uri, auto _mimeType) { + auto mimeType = std::string(_mimeType.typeName()); + if (const auto acceptHeader = uri.queryParamMap().find("accept"); acceptHeader != uri.queryParamMap().end() && acceptHeader->second) { + mimeType = acceptHeader->second->c_str(); + } + return { "accept", mimeType, "content-type", mimeType }; +} - auto checkedStringViewSize = [](auto numBytes) { - if (numBytes > std::numeric_limits::max()) { - throw fmt::format("We received more data than we can handle {}", numBytes); - } - return static_cast(numBytes); - }; +struct FetchPayload { + Command command; - std::array getPreferredContentTypeHeader(const URI &uri, auto _mimeType) { - auto mimeType = std::string(_mimeType.typeName()); - if (const auto acceptHeader = uri.queryParamMap().find(ACCEPT_HEADER); acceptHeader != uri.queryParamMap().end() && - acceptHeader->second) { - mimeType = acceptHeader->second->c_str(); - } - return {ACCEPT_HEADER, mimeType, CONTENT_TYPE_HEADER, mimeType}; + explicit FetchPayload(Command &&_command) + : command(std::move(_command)) {} + + FetchPayload(const FetchPayload &other) = delete; + + FetchPayload(FetchPayload &&other) noexcept = default; + + FetchPayload &operator=(const FetchPayload &other) = delete; + + FetchPayload &operator=(FetchPayload &&other) noexcept = default; + + void returnMdpMessage(unsigned short status, std::string_view body, std::string_view errorMsgExt = "") noexcept { + if (!command.callback) { + return; + } + const bool msgOK = status >= 200 && status < 400; + try { + command.callback(mdp::Message{ + .id = 0, + .arrivalTime = std::chrono::system_clock::now(), + .protocolName = command.topic.scheme().value(), + .command = mdp::Command::Final, + .clientRequestID = command.clientRequestID, + .topic = command.topic, + .data = msgOK ? IoBuffer(body.data(), body.size()) : IoBuffer(), + .error = msgOK ? std::string(errorMsgExt) : fmt::format("{} - {}{}{}", status, errorMsgExt, body.empty() ? "" : ":", body), + .rbac = IoBuffer() }); + } catch (const std::exception &e) { + std::cerr + << fmt::format("caught exception '{}' in FetchPayload::returnMdpMessage(cmd={}, {}: {})", e.what(), command.topic, status, + body) + << std::endl; + } catch (...) { + std::cerr + << fmt::format("caught unknown exception in FetchPayload::returnMdpMessage(cmd={}, {}: {})", command.topic, status, body) + << std::endl; } + } - struct FetchPayload { - Command command; + void onsuccess(unsigned short status, std::string_view data) { + returnMdpMessage(status, data); + } - FetchPayload(Command &&_command) - : command(std::move(_command)) {} + void onerror(unsigned short status, std::string_view error, std::string_view data) { + returnMdpMessage(status, data, error); + } +}; - FetchPayload(const FetchPayload &other) = delete; +static std::unordered_set, detail::pointer_hash, detail::pointer_equals> fetchPayloads; - FetchPayload(FetchPayload &&other) noexcept = default; +struct SubscriptionPayload; +static std::unordered_set, detail::pointer_hash, detail::pointer_equals> subscriptionPayloads; - FetchPayload &operator=(const FetchPayload &other) = delete; +struct SubscriptionPayload : FetchPayload { + bool _live = true; + MIME::MimeType _mimeType; + std::size_t _update = 0; - FetchPayload &operator=(FetchPayload &&other) noexcept = default; + static constexpr std::size_t kParallelLongPollingRequests = 3; + std::vector _requestedIndexes; - void returnMdpMessage(unsigned short status, std::string_view body, std::string_view errorMsgExt = "") noexcept { - if (!command.callback) { - return; - } - const bool msgOK = status >= 200 && status < 400; - const auto errorMsg = msgOK ? errorMsgExt : fmt::format("{} - {}{}{}", status, errorMsgExt, body.empty() ? "" : ":", body); - try { - command.callback(mdp::Message{ - .id = 0, - .arrivalTime = std::chrono::system_clock::now(), - .protocolName = command.topic.scheme().value(), - .command = mdp::Command::Final, - .clientRequestID = command.clientRequestID, - .topic = command.topic, - .data = msgOK ? IoBuffer(body.data(), body.size()) : IoBuffer(), - .error = std::string{errorMsg}, - .rbac = IoBuffer()}); - } catch (const std::exception &e) { - std::cerr - << fmt::format("caught exception '{}' in FetchPayload::returnMdpMessage(cmd={}, {}: {})", e.what(), command.topic, status, - body) << std::endl; - } catch (...) { - std::cerr - << fmt::format("caught unknown exception in FetchPayload::returnMdpMessage(cmd={}, {}: {})", command.topic, status, body) - << std::endl; - } - } + SubscriptionPayload(Command &&_command, MIME::MimeType mimeType) + : FetchPayload(std::move(_command)), _mimeType(std::move(mimeType)) {} + + SubscriptionPayload(const SubscriptionPayload &other) = delete; + + SubscriptionPayload(SubscriptionPayload &&other) noexcept = default; - void onsuccess(unsigned short status, std::string_view data) { - returnMdpMessage(status, data); + SubscriptionPayload &operator=(const SubscriptionPayload &other) = delete; + + SubscriptionPayload &operator=(SubscriptionPayload &&other) noexcept = default; + + void sendFollowUpRequestsFor(std::uint64_t longPollingIdx) { + auto it = std::ranges::find(_requestedIndexes, longPollingIdx); + if (it != _requestedIndexes.end()) { + _requestedIndexes.erase(it); + } + for (std::uint64_t i = longPollingIdx + 1; i < longPollingIdx + kParallelLongPollingRequests; ++i) { + if (std::ranges::find(_requestedIndexes, i) == _requestedIndexes.end()) { + _requestedIndexes.push_back(i); + request(std::to_string(i)); } + } + } + void request(std::string longPollingIndex) { + auto uri = opencmw::URI::UriFactory(command.topic).addQueryParameter("LongPollingIdx", longPollingIndex).build(); + auto preferredHeader = detail::getPreferredContentTypeHeader(command.topic, _mimeType); - void onerror(unsigned short status, std::string_view error, std::string_view data) { - returnMdpMessage(status, data, error); + std::array preferredHeaderEmscripten; + std::transform(preferredHeader.cbegin(), preferredHeader.cend(), preferredHeaderEmscripten.begin(), + [](const auto &str) { return str.c_str(); }); + preferredHeaderEmscripten[preferredHeaderEmscripten.size() - 1] = nullptr; + + emscripten_fetch_attr_t attr{}; + + emscripten_fetch_attr_init(&attr); + + strcpy(attr.requestMethod, "GET"); + + attr.userData = this; + static auto getPayloadIt = [](emscripten_fetch_t *fetch) { + auto *rawPayload = fetch->userData; + auto it = detail::subscriptionPayloads.find(rawPayload); + if (it == detail::subscriptionPayloads.end()) { + fmt::print("RestClientEmscripten::payloadError: url: {}, bytes: {}\n", fetch->url, fetch->numBytes); + throw fmt::format("Unknown payload for a resulting subscription"); } + return it; }; - static std::unordered_set, detail::pointer_hash, detail::pointer_equals> fetchPayloads; - - struct SubscriptionPayload; - static std::unordered_set, detail::pointer_hash, detail::pointer_equals> subscriptionPayloads; - - struct SubscriptionPayload : FetchPayload { - bool _live = true; - MIME::MimeType _mimeType; - std::size_t _update = 0; - - SubscriptionPayload(Command &&_command, MIME::MimeType mimeType) - : FetchPayload(std::move(_command)), _mimeType(std::move(mimeType)) {} - - SubscriptionPayload(const SubscriptionPayload &other) = delete; - - SubscriptionPayload(SubscriptionPayload &&other) noexcept = default; - - SubscriptionPayload &operator=(const SubscriptionPayload &other) = delete; - - SubscriptionPayload &operator=(SubscriptionPayload &&other) noexcept = default; - - void requestNext() { - auto uri = opencmw::URI::UriFactory(command.topic).addQueryParameter("LongPollingIdx", - (_update == 0) ? "Next" : fmt::format("{}", - _update)).build(); - //fmt::print("URL 1 >>> {}, thread {}\n", uri.relativeRef(), std::this_thread::get_id()); - auto preferredHeader = detail::getPreferredContentTypeHeader(command.topic, _mimeType); - std::array preferredHeaderEmscripten; - std::transform(preferredHeader.cbegin(), preferredHeader.cend(), preferredHeaderEmscripten.begin(), - [](const auto &str) { return str.c_str(); }); - preferredHeaderEmscripten[preferredHeaderEmscripten.size() - 1] = nullptr; - - emscripten_fetch_attr_t attr{}; - - emscripten_fetch_attr_init(&attr); - - strcpy(attr.requestMethod, "GET"); - - attr.userData = this; - static auto getPayloadIt = [](emscripten_fetch_t *fetch) { - auto *rawPayload = fetch->userData; - auto it = detail::subscriptionPayloads.find(rawPayload); - if (it == detail::subscriptionPayloads.end()) { - fmt::print("RestClientEmscripten::payloadError: url: {}, bytes: {}\n", fetch->url, fetch->numBytes); - throw fmt::format("Unknown payload for a resulting subscription"); - } - return it; - }; - - attr.attributes = EMSCRIPTEN_FETCH_LOAD_TO_MEMORY; - attr.requestHeaders = preferredHeaderEmscripten.data(); - attr.onsuccess = [](emscripten_fetch_t *fetch) { - auto payloadIt = getPayloadIt(fetch); - auto &payload = *payloadIt; - //fmt::print("received update: {}, {}\n", fetch->url, payload->_update); - if (payload->_live) { - std::string finalURL = getFinalURL(fetch->id); - std::string longPollingIdxString = opencmw::URI<>(finalURL).queryParamMap().at("LongPollingIdx").value_or("0"); - char *end = longPollingIdxString.data() + longPollingIdxString.size(); - std::size_t longPollingIdx = strtoull(longPollingIdxString.data(), &end, 10); - if (payload->_update != 0 && longPollingIdx != payload->_update) { - fmt::print("received unexpected update: {}, expected {}\n", longPollingIdx, payload->_update); - } - payload->onsuccess(fetch->status, std::string_view(fetch->data, detail::checkedStringViewSize(fetch->numBytes)), static_cast(longPollingIdx) - static_cast(payload->_update)); - emscripten_fetch_close(fetch); - payload->_update = longPollingIdx + 1; - payload->requestNext(); - } else { - detail::subscriptionPayloads.erase(payloadIt); - } - }; - attr.onerror = [](emscripten_fetch_t *fetch) { - auto payloadIt = getPayloadIt(fetch); - auto &payload = *payloadIt; - payload->onerror(fetch->status, std::string_view(fetch->data, detail::checkedStringViewSize(fetch->numBytes)), fetch->statusText); - emscripten_fetch_close(fetch); - }; - emscripten_fetch(&attr, uri.str().data()); - } + attr.attributes = EMSCRIPTEN_FETCH_LOAD_TO_MEMORY; + attr.requestHeaders = preferredHeaderEmscripten.data(); + attr.onsuccess = [](emscripten_fetch_t *fetch) { + auto payloadIt = getPayloadIt(fetch); + auto &payload = *payloadIt; + std::uint64_t longPollingIdx = 0; + // fmt::print("received update: {}, {}\n", fetch->url, payload->_update); + if (payload->_live) { + std::string finalURL = getFinalURL(fetch->id); + std::string longPollingIdxString = opencmw::URI<>(finalURL).queryParamMap().at("LongPollingIdx").value_or("0"); + + char *end = nullptr; + longPollingIdx = strtoull(longPollingIdxString.data(), &end, 10); + if (end != longPollingIdxString.data() + longPollingIdxString.size()) { + fmt::println(std::cerr, "RestClientEmscripten::payloadError: url: {}, bytes: {}\n", fetch->url, fetch->numBytes); + return; + } - void onsuccess(unsigned short status, std::string_view data, long idxDifference = 0) { - std::string skippedWarning; - if (idxDifference != 0) { - skippedWarning = fmt::format("Warning: skipped {} samples", idxDifference); + if (payload->_update != 0 && longPollingIdx != payload->_update) { + fmt::print("received unexpected update: {}, expected {}\n", longPollingIdx, payload->_update); } - returnMdpMessage(status, data, skippedWarning); - } + payload->onsuccess(fetch->status, std::string_view(fetch->data, detail::checkedStringViewSize(fetch->numBytes)), static_cast(longPollingIdx) - static_cast(payload->_update)); + emscripten_fetch_close(fetch); - void onerror(unsigned short status, std::string_view error, std::string_view data) { - returnMdpMessage(status, data, error); + payload->sendFollowUpRequestsFor(longPollingIdx); + } else { + detail::subscriptionPayloads.erase(payloadIt); } }; - } // namespace detail + attr.onerror = [](emscripten_fetch_t *fetch) { + auto payloadIt = getPayloadIt(fetch); + auto &payload = *payloadIt; + payload->onerror(fetch->status, std::string_view(fetch->data, detail::checkedStringViewSize(fetch->numBytes)), fetch->statusText); + emscripten_fetch_close(fetch); + }; + emscripten_fetch(&attr, uri.str().data()); + } + + void onsuccess(unsigned short status, std::string_view data, long idxDifference = 0) { + std::string skippedWarning; + if (idxDifference != 0) { + skippedWarning = fmt::format("Warning: skipped {} samples", idxDifference); + } + returnMdpMessage(status, data, skippedWarning); + } + + void onerror(unsigned short status, std::string_view error, std::string_view data) { + returnMdpMessage(status, data, error); + } +}; +} // namespace detail class RestClient : public ClientBase { std::string _name; @@ -233,7 +251,7 @@ class RestClient : public ClientBase { * Initialises a basic RestClient * * usage example: - * RestClient client("clientName", DefaultContentTypeHeader(MIME::HTML), MinIoThreads(2), MaxIoThreads(5), ClientCertificates(testCertificate)) + * RestClient client("clientName", DefaultContentTypeHeader(MIME::HTML), ClientCertificates(testCertificate)) * * @tparam Args see argument example above. Order is arbitrary. * @param initArgs @@ -243,9 +261,9 @@ class RestClient : public ClientBase { : _name(detail::find_argument_value([] { return "RestClient"; }, initArgs...)) , _mimeType(detail::find_argument_value([] { return MIME::JSON; }, initArgs...)) { } - ~RestClient() { RestClient::stop(); }; + ~RestClient() { RestClient::stop(); } - void stop() override {}; + void stop() override {} std::vector protocols() noexcept override { return { "http", "https" }; } @@ -308,7 +326,7 @@ class RestClient : public ClientBase { attr.attributes = EMSCRIPTEN_FETCH_LOAD_TO_MEMORY; attr.requestHeaders = preferredHeaderEmscripten.data(); attr.onsuccess = [](emscripten_fetch_t *fetch) { - //fmt::print("RestClientEmscripten: got get/set reply: {}\n", fetch->url); + // fmt::print("RestClientEmscripten: got get/set reply: {}\n", fetch->url); getPayload(fetch)->onsuccess(fetch->status, std::string_view(fetch->data, detail::checkedStringViewSize(fetch->numBytes))); emscripten_fetch_close(fetch); }; @@ -324,13 +342,13 @@ class RestClient : public ClientBase { } void startSubscription(Command &&cmd) { - auto payload = std::make_unique(std::move(cmd), _mimeType); + auto payload = std::make_unique(std::move(cmd), _mimeType); auto rawPayload = payload.get(); detail::subscriptionPayloads.insert(std::move(payload)); fmt::print("starting subscription: {}, existing subscriptions: {}, from main thread: \n", cmd.topic.str(), detail::subscriptionPayloads.size(), emscripten_is_main_runtime_thread()); if (emscripten_is_main_runtime_thread()) { try { - rawPayload->requestNext(); + rawPayload->request("Next"); } catch (std::runtime_error &e) { rawPayload->onerror(500, e.what(), ""); } catch (...) { @@ -340,19 +358,18 @@ class RestClient : public ClientBase { emscripten_async_run_in_main_runtime_thread(EM_FUNC_SIG_IP, +[](void *data) { auto subPayload = reinterpret_cast(data); try { - subPayload->requestNext(); + subPayload->request("Next"); } catch (std::runtime_error &e) { subPayload->onerror(500, e.what(), ""); } catch (...) { subPayload->onerror(500, "failed to set up subscription", ""); } - return 0; - }, rawPayload); + return 0; }, rawPayload); } } void stopSubscription(Command &&cmd) { - auto payloadIt = std::find_if(detail::subscriptionPayloads.begin(), detail::subscriptionPayloads.end(), + auto payloadIt = std::ranges::find_if(detail::subscriptionPayloads, [&](const auto &ptr) { return ptr->command.topic == cmd.topic; }); diff --git a/src/client/include/RestClientNative.hpp b/src/client/include/RestClientNative.hpp index aadcf7b1..11fb26d7 100644 --- a/src/client/include/RestClientNative.hpp +++ b/src/client/include/RestClientNative.hpp @@ -1,411 +1,815 @@ -#ifndef OPENCMW_CPP_RESTCLIENT_NATIVE_HPP -#define OPENCMW_CPP_RESTCLIENT_NATIVE_HPP - -#include -#include -#include - -#include -#include -#include -#include -#include - -#include "RestDefaultClientCertificates.hpp" - -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wformat-nonliteral" -#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" -#pragma GCC diagnostic ignored "-Wold-style-cast" -#pragma GCC diagnostic ignored "-Wshadow" -#pragma GCC diagnostic ignored "-Wuninitialized" -#pragma GCC diagnostic ignored "-Wuseless-cast" -#include -#pragma GCC diagnostic pop - +#ifndef OPENCMW_CLIENT_RESTCLIENTNATIVE_HPP +#define OPENCMW_CLIENT_RESTCLIENTNATIVE_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +#include "ClientCommon.hpp" +#include "ClientContext.hpp" +#include "MdpMessage.hpp" +#include "MIME.hpp" +#include "nghttp2/NgHttp2Utils.hpp" +#include "Topic.hpp" +#ifdef OPENCMW_PROFILE_HTTP +#include "LoadTest.hpp" +#endif namespace opencmw::client { +enum class SubscriptionMode { + Next, + Last, + None +}; -inline constexpr static const char *LONG_POLLING_IDX_TAG = "LongPollingIdx"; +namespace detail { -class MinIoThreads { - const int _minThreads = 1; +using namespace opencmw::nghttp2; +using namespace opencmw::nghttp2::detail; -public: - MinIoThreads() = default; - MinIoThreads(int value) noexcept - : _minThreads(value) {}; - constexpr operator int() const noexcept { return _minThreads; }; +template +struct SharedQueue { + // TODO use a lock-free queue? This is only used client-side, so not that critical + std::deque deque; + std::mutex mutex; + + void push(T v) { + std::lock_guard lock(mutex); + deque.push_back(std::move(v)); + } + + std::optional try_get() { + std::lock_guard lock(mutex); + if (deque.empty()) { + return {}; + } + auto result = std::move(deque.front()); + deque.pop_front(); + return result; + } }; -class MaxIoThreads { - const int _maxThreads = 10'000; +struct RequestResponse { + // request data + client::Command request; + std::string normalizedTopic; + + // response data + std::string responseStatus; + std::string location; // for redirects + std::optional longPollingIdx; + std::string payload; + mdp::Message response; + + void fillResponse() { + response.id = 0; + response.arrivalTime = std::chrono::system_clock::now(); + response.protocolName = request.topic.scheme().value_or(""); + response.rbac.clear(); + + switch (request.command) { + case mdp::Command::Get: + case mdp::Command::Set: + response.command = mdp::Command::Final; + response.clientRequestID = request.clientRequestID; + break; + case mdp::Command::Subscribe: + response.command = mdp::Command::Notify; + break; + default: + break; + } + } -public: - MaxIoThreads() = default; - MaxIoThreads(int value) noexcept - : _maxThreads(value) {}; - constexpr operator int() const noexcept { return _maxThreads; }; + void reportError(std::string error) { + if (!request.callback) { + HTTP_DBG("Client::reportError: {}", error); + return; + } + fillResponse(); + response.topic = request.topic; + response.error = std::move(error); + response.data.clear(); + request.callback(std::move(response)); + } }; -struct ClientCertificates { - std::string _certificates; +constexpr std::size_t kParallelLongPollingRequests = 3; - ClientCertificates() = default; - ClientCertificates(const char *X509_ca_bundle) noexcept - : _certificates(X509_ca_bundle) {}; - ClientCertificates(const std::string &X509_ca_bundle) noexcept - : _certificates(X509_ca_bundle) {}; - constexpr operator std::string() const noexcept { return _certificates; }; +struct Subscription { + client::Command request; + SubscriptionMode mode; + std::optional lastReceivedLongPollingIdx; + std::vector> callbacks; }; -namespace detail { -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT -inline int readCertificateBundleFromBuffer(X509_STORE &cert_store, const std::string_view &X509_ca_bundle) { - BIO *cbio = BIO_new_mem_buf(X509_ca_bundle.data(), static_cast(X509_ca_bundle.size())); - if (!cbio) { - return -1; - } - STACK_OF(X509_INFO) *inf = PEM_X509_INFO_read_bio(cbio, nullptr, nullptr, nullptr); +struct Endpoint { + std::string scheme; + std::string host; + uint16_t port; - if (!inf) { - BIO_free(cbio); // cleanup - return -1; - } - // iterate over all entries from the pem file, add them to the x509_store one by one - int count = 0; - for (int i = 0; i < sk_X509_INFO_num(inf); i++) { - X509_INFO *itmp = sk_X509_INFO_value(inf, i); - if (itmp->x509) { - X509_STORE_add_cert(&cert_store, itmp->x509); - count++; - } - if (itmp->crl) { - X509_STORE_add_crl(&cert_store, itmp->crl); - count++; + auto operator<=>(const Endpoint &) const = default; +}; + +struct ClientSession { + struct PendingRequest { + client::Command command; + SubscriptionMode mode; + std::string preferredMimeType; + std::optional longPollIdx; + }; + + TcpSocket _socket; + nghttp2_session *_session = nullptr; + WriteBuffer<1024> _writeBuffer; + std::map _subscriptions; + std::map _requestsByStreamId; + + explicit ClientSession(TcpSocket socket_) + : _socket(std::move(socket_)) { + nghttp2_session_callbacks *callbacks; + nghttp2_session_callbacks_new(&callbacks); + + nghttp2_session_callbacks_set_send_callback2(callbacks, [](nghttp2_session *, const uint8_t *data, size_t length, int flags, void *user_data) { + auto client = static_cast(user_data); + HTTP_DBG("Client::send {}", length); + const auto r = client->_socket.write(data, length, flags); + if (r < 0) { + HTTP_DBG("Client::send failed: {}", client->_socket.lastError()); + } + return r; + }); + + nghttp2_session_callbacks_set_on_data_chunk_recv_callback(callbacks, [](nghttp2_session *, uint8_t /*flags*/, int32_t stream_id, const uint8_t *data, size_t len, void *user_data) { + auto client = static_cast(user_data); + client->_requestsByStreamId[stream_id].payload.append(reinterpret_cast(data), len); + return 0; + }); + + nghttp2_session_callbacks_set_on_header_callback(callbacks, [](nghttp2_session *, const nghttp2_frame *frame, const uint8_t *name, size_t namelen, const uint8_t *value, size_t valuelen, uint8_t /*flags*/, void *user_data) { + auto client = static_cast(user_data); + const auto nameView = std::string_view(reinterpret_cast(name), namelen); + const auto valueView = std::string_view(reinterpret_cast(value), valuelen); + HTTP_DBG("Client::Header: id={} {} = {}", frame->hd.stream_id, nameView, valueView); + if (nameView == ":status") { + client->_requestsByStreamId[frame->hd.stream_id].responseStatus = std::string(valueView); + } else if (nameView == "location") { + client->_requestsByStreamId[frame->hd.stream_id].location = std::string(valueView); + } else if (nameView == "x-opencmw-topic") { + try { + client->_requestsByStreamId[frame->hd.stream_id].response.topic = URI<>(std::string(valueView)); + } catch (const std::exception &e) { + HTTP_DBG("Client::Header: Could not parse URI '{}': {}", valueView, e.what()); + return static_cast(NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE); + } + } else if (nameView == "x-opencmw-service-name") { + client->_requestsByStreamId[frame->hd.stream_id].response.serviceName = std::string(valueView); + } else if (nameView == "x-opencmw-long-polling-idx") { + std::uint64_t longPollingIdx; + if (auto ec = std::from_chars(valueView.data(), valueView.data() + valueView.size(), longPollingIdx); ec.ec != std::errc{}) { + HTTP_DBG("Client::Header: Could not parse x-opencmw-long-polling-idx '{}'", valueView); + return static_cast(NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE); + } + client->_requestsByStreamId[frame->hd.stream_id].longPollingIdx = longPollingIdx; +#ifdef OPENCMW_PROFILE_HTTP + } else if (nameView == "x-timestamp") { + fmt::println(std::cerr, "Client::Header: x-timestamp: {} (latency {} ns)", valueView, latency(valueView).count()); +#endif + } + + return 0; + }); + nghttp2_session_callbacks_set_on_frame_recv_callback(callbacks, [](nghttp2_session *, const nghttp2_frame *frame, void *user_data) { + HTTP_DBG("Client::Frame: id={} {} {}", frame->hd.stream_id, frame->hd.type, (frame->hd.flags & NGHTTP2_FLAG_END_STREAM) ? "END_STREAM" : ""); + switch (frame->hd.type) { + case NGHTTP2_HEADERS: + case NGHTTP2_DATA: + if (frame->hd.flags & NGHTTP2_FLAG_END_STREAM) { + auto client = static_cast(user_data); + return client->processResponse(frame->hd.stream_id); + } + break; + } + return 0; + }); + nghttp2_session_callbacks_set_on_stream_close_callback(callbacks, [](nghttp2_session *, int32_t stream_id, uint32_t /*error_code*/, void *user_data) { + auto client = static_cast(user_data); + client->_requestsByStreamId.erase(stream_id); + HTTP_DBG("Client::Stream closed: {}", stream_id); + return 0; + }); + + nghttp2_session_client_new(&_session, callbacks, this); + nghttp2_session_callbacks_del(callbacks); + + nghttp2_settings_entry iv[1] = { + { NGHTTP2_SETTINGS_MAX_CONCURRENT_STREAMS, 1000 } + }; + + if (nghttp2_submit_settings(_session, NGHTTP2_FLAG_NONE, iv, 1) != 0) { + HTTP_DBG("Client::ClientSession: nghttp2_submit_settings failed"); } } - sk_X509_INFO_pop_free(inf, X509_INFO_free); - BIO_free(cbio); - return count; -} + ClientSession(const ClientSession &) = delete; + ClientSession &operator=(const ClientSession &) = delete; + ClientSession(ClientSession &&other) noexcept = delete; + ClientSession &operator=(ClientSession &&other) noexcept = delete; -inline X509_STORE *createCertificateStore(const std::string_view &X509_ca_bundle) { - X509_STORE *cert_store = X509_STORE_new(); - if (detail::readCertificateBundleFromBuffer(*cert_store, X509_ca_bundle) <= 0) { - X509_STORE_free(cert_store); - throw std::invalid_argument(fmt::format("failed to read certificate bundle from buffer:\n#---start---\n{}\n#---end---\n", X509_ca_bundle)); - } - return cert_store; -} - -inline X509 *readServerCertificateFromFile(const std::string_view &X509_ca_bundle) { - BIO *certBio = BIO_new(BIO_s_mem()); - BIO_write(certBio, X509_ca_bundle.data(), static_cast(X509_ca_bundle.size())); - X509 *certX509 = PEM_read_bio_X509(certBio, nullptr, nullptr, nullptr); - BIO_free(certBio); - if (certX509) { - return certX509; + ~ClientSession() { + nghttp2_session_del(_session); } - X509_free(certX509); - throw std::invalid_argument(fmt::format("failed to read certificate from buffer:\n#---start---\n{}\n#---end---\n", X509_ca_bundle)); -} - -inline EVP_PKEY *readServerPrivateKeyFromFile(const std::string_view &X509_private_key) { - BIO *certBio = BIO_new(BIO_s_mem()); - BIO_write(certBio, X509_private_key.data(), static_cast(X509_private_key.size())); - EVP_PKEY *privateKeyX509 = PEM_read_bio_PrivateKey(certBio, nullptr, nullptr, nullptr); - BIO_free(certBio); - if (privateKeyX509) { - return privateKeyX509; + + bool isReady() const { + return _socket._state == TcpSocket::Connected; } - EVP_PKEY_free(privateKeyX509); - throw std::invalid_argument(fmt::format("failed to read private key from buffer")); -} -#endif + std::expected continueToMakeReady() { + auto makeError = [](std::string_view msg) { + return std::unexpected(fmt::format("Could not connect to endpoint: {}", msg)); + }; + assert(!isReady()); + if (_socket._state == detail::TcpSocket::Connecting) { + if (auto rc = _socket.connect(); !rc) { + return makeError(rc.error()); + } + } -} // namespace detail + if (_socket._state == detail::TcpSocket::SSLConnectWantsRead || _socket._state == detail::TcpSocket::SSLConnectWantsWrite) { + if (auto rc = _socket.continueHandshake(); !rc) { + return makeError(rc.error()); + } + } -class RestClient : public ClientBase { - static const httplib::Headers EVT_STREAM_HEADERS; - using ThreadPoolType = std::shared_ptr>; - - std::string _name; - MIME::MimeType _mimeType; - std::atomic _run = true; - const int _minIoThreads; - const int _maxIoThreads; - ThreadPoolType _thread_pool; - std::string _caCertificate; - - std::mutex _subscriptionLock; - std::map, httplib::Client> _subscription1; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - std::map, httplib::SSLClient> _subscription2; -#endif + return {}; + } -public: - static bool CHECK_CERTIFICATES; - - /** - * Initialises a basic RestClient - * - * usage example: - * RestClient client("clientName", DefaultContentTypeHeader(MIME::HTML), MinIoThreads(2), MaxIoThreads(5), ClientCertificates(testCertificate)) - * - * @tparam Args see argument example above. Order is arbitrary. - * @param initArgs - */ - template - explicit(false) RestClient(Args... initArgs) - : _name(detail::find_argument_value([] { return "RestClient"; }, initArgs...)), // - _mimeType(detail::find_argument_value([] { return MIME::JSON; }, initArgs...)) - , _minIoThreads(detail::find_argument_value([] { return MinIoThreads(); }, initArgs...)) - , _maxIoThreads(detail::find_argument_value([] { return MaxIoThreads(); }, initArgs...)) - , _thread_pool(detail::find_argument_value([this] { return std::make_shared>(_name, _minIoThreads, _maxIoThreads); }, initArgs...)) - , _caCertificate(detail::find_argument_value([] { return rest::DefaultCertificate().get(); }, initArgs...)) { + bool wantsToRead() const { + return _socket._state == TcpSocket::Connected ? nghttp2_session_want_read(_session) : (_socket._state == TcpSocket::Connecting || _socket._state == TcpSocket::SSLConnectWantsRead); } - ~RestClient() override { RestClient::stop(); }; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - std::vector protocols() noexcept override { return { "http", "https" }; } -#else - std::vector protocols() noexcept override { return { "http" }; } + bool wantsToWrite() const { + return _socket._state == TcpSocket::Connected ? _writeBuffer.wantsToWrite(_session) : (_socket._state == TcpSocket::Connecting || _socket._state == TcpSocket::SSLConnectWantsWrite); + } + + void submitRequest(client::Command &&cmd, SubscriptionMode mode, std::string preferredMimeType, std::optional longPollIdx) { + auto topic = cmd.topic; + if (cmd.command == mdp::Command::Set) { + topic = URI<>::UriFactory(topic).addQueryParameter("_bodyOverride", std::string{ cmd.data.asString() }).build(); + } + + std::string longPollIdxParam; + if (longPollIdx) { + longPollIdxParam = std::to_string(*longPollIdx); + } else { + switch (mode) { + case SubscriptionMode::Next: + longPollIdxParam = "Next"; + break; + case SubscriptionMode::Last: + longPollIdxParam = "Last"; + break; + case SubscriptionMode::None: + break; + } + } + if (!longPollIdxParam.empty()) { + topic = URI<>::UriFactory(topic).addQueryParameter("LongPollingIdx", longPollIdxParam).build(); + } + + const auto host = cmd.topic.hostName().value_or(""); + const auto scheme = cmd.topic.scheme().value_or(""); + const auto path = topic.relativeRefNoFragment().value_or("/"); +#ifdef OPENCMW_PROFILE_HTTP + const auto ts = std::to_string(opencmw::load_test::timestamp().count()); +#endif + constexpr uint8_t noCopy = NGHTTP2_NV_FLAG_NO_COPY_NAME | NGHTTP2_NV_FLAG_NO_COPY_VALUE; + auto headers = std::vector{ + nv(u8span(":method"), u8span("GET"), noCopy), // + nv(u8span(":path"), u8span(path)), // + nv(u8span(":scheme"), u8span(scheme)), // + nv(u8span(":authority"), u8span(host)), +#ifdef OPENCMW_PROFILE_HTTP + nv(u8span("x-timestamp"), u8span(ts)) #endif - void stop() noexcept override { stopAllSubscriptions(); }; - [[nodiscard]] std::string name() const noexcept { return _name; } - [[nodiscard]] ThreadPoolType threadPool() const noexcept { return _thread_pool; } - [[nodiscard]] MIME::MimeType defaultMimeType() const noexcept { return _mimeType; } - [[nodiscard]] std::string clientCertificate() const noexcept { return _caCertificate; } + }; + if (!preferredMimeType.empty()) { + headers.push_back(nv(u8span("accept"), u8span(preferredMimeType))); + headers.push_back(nv(u8span("content-type"), u8span(preferredMimeType))); + } + if (cmd.command == mdp::Command::Set) { + headers.push_back(nv(u8span("x-opencmw-method"), u8span("PUT"), noCopy)); + } - void request(Command cmd) override { - switch (cmd.command) { - case mdp::Command::Get: - case mdp::Command::Set: - _thread_pool->execute([this, cmd = std::move(cmd)]() mutable { executeCommand(std::move(cmd)); }); - return; - case mdp::Command::Subscribe: - _thread_pool->execute([this, cmd = std::move(cmd)]() mutable { startSubscription(std::move(cmd)); }); - return; - case mdp::Command::Unsubscribe: // deregister existing subscription URI is key - _thread_pool->execute([this, cmd = std::move(cmd)]() mutable { stopSubscription(cmd); }); + RequestResponse rr; + rr.request = std::move(cmd); + try { + rr.normalizedTopic = mdp::Topic::fromMdpTopic(rr.request.topic).toZmqTopic(); + } catch (...) { + rr.normalizedTopic = rr.request.topic.str(); + } + + const std::int32_t streamId = nghttp2_submit_request2(_session, nullptr, headers.data(), headers.size(), nullptr, nullptr); + if (streamId < 0) { + rr.reportError(fmt::format("Could not submit request: {}", nghttp2_strerror(streamId))); return; - default: - throw std::invalid_argument("command type is undefined"); } + + _requestsByStreamId.emplace(streamId, std::move(rr)); } -private: - httplib::Headers getPreferredContentTypeHeader(const URI &uri) const { - auto mimeType = std::string(_mimeType.typeName()); - if (const auto acceptHeader = uri.queryParamMap().find(detail::ACCEPT_HEADER); acceptHeader != uri.queryParamMap().end() && acceptHeader->second) { - mimeType = acceptHeader->second->c_str(); + int processResponse(std::int32_t streamId) { + auto it = _requestsByStreamId.find(streamId); + assert(it != _requestsByStreamId.end()); + if (it != _requestsByStreamId.end()) { + const auto &request = it->second.request; + if (it->second.responseStatus == "302") { + std::optional> location; + try { + location = URI<>(it->second.location); + } catch (const std::exception &e) { + HTTP_DBG("Client::Header: Could not parse URI '{}': {}", it->second.location, e.what()); + it->second.reportError(fmt::format("Could not parse redirect URI '{}': {}", it->second.location, e.what())); + _requestsByStreamId.erase(it); + return static_cast(NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE); + } + const auto &queryParams = location->queryParamMap(); + auto longPollIdxIt = queryParams.find("LongPollingIdx"); + if (longPollIdxIt == queryParams.end() || !longPollIdxIt->second.has_value()) { + HTTP_DBG("Client::Header: Could not find LongPollingIdx in URI '{}'", it->second.location); + it->second.reportError(fmt::format("Could not find LongPollingIdx in URI '{}'", it->second.location)); + _requestsByStreamId.erase(it); + return static_cast(NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE); + } + std::uint64_t longPollingIdx; + if (auto ec = std::from_chars(longPollIdxIt->second->data(), longPollIdxIt->second->data() + longPollIdxIt->second->size(), longPollingIdx); ec.ec != std::errc{}) { + HTTP_DBG("Client::Header: Could not parse numerical LongPollingIdx from '{}'", longPollIdxIt->second); + it->second.reportError(fmt::format("Could not parse numerical LongPollingIdx from '{}'", longPollIdxIt->second)); + _requestsByStreamId.erase(it); + return static_cast(NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE); + } + // Redirect to long-polling URL + sendLongPollRequests(it->second.normalizedTopic, longPollingIdx, longPollingIdx + kParallelLongPollingRequests - 1); + } else if (it->second.longPollingIdx && it->second.responseStatus == "504") { + // Server timeout on long-poll, resend same request + sendLongPollRequests(it->second.normalizedTopic, it->second.longPollingIdx.value(), it->second.longPollingIdx.value()); + } else { + it->second.fillResponse(); + auto response = std::move(it->second.response); + const auto hasError = !it->second.responseStatus.starts_with("2") && !it->second.responseStatus.starts_with("3"); + if (hasError) { + response.error = std::move(it->second.payload); + } else { + response.data = IoBuffer(it->second.payload.data(), it->second.payload.size()); + } + if (it->second.longPollingIdx) { + // Subscription + handleSubscriptionResponse(it->second.normalizedTopic, it->second.longPollingIdx.value(), std::move(response)); + } else { + // GET/SET + if (request.callback) { + request.callback(std::move(response)); + } + } + } + _requestsByStreamId.erase(it); + } else { + HTTP_DBG("Client::Frame: Could not find request for stream id {}", streamId); + } + return 0; + } + + void reportErrorToAllPendingRequests(std::string_view error) { + for (auto &[streamId, request] : _requestsByStreamId) { + request.reportError(std::string{ error }); } - const httplib::Headers headers = { { detail::ACCEPT_HEADER, mimeType }, { detail::CONTENT_TYPE_HEADER, mimeType } }; - return headers; + _requestsByStreamId.clear(); } - static void returnMdpMessage(Command &cmd, const httplib::Result &result, const std::string &errorMsgExt = "") noexcept { - if (!cmd.callback) { + void handleSubscriptionResponse(std::string zmqTopic, std::uint64_t longPollingIdx, mdp::Message &&response) { + auto subIt = _subscriptions.find(zmqTopic); + if (subIt == _subscriptions.end()) { + HTTP_DBG("Client::handleSubscriptionResponse: Could not find subscription for topic '{}'", zmqTopic); return; } + auto &sub = subIt->second; + sub.lastReceivedLongPollingIdx = longPollingIdx; + auto request = sub.request; - const auto errorMsg = [&]() -> std::optional { - // Result contains a nullptr - if (!result) { - return errorMsgExt.empty() ? "Unknown error, empty result" : errorMsgExt; - } + submitRequest(std::move(request), sub.mode, {}, longPollingIdx + kParallelLongPollingRequests); - // No error - if (result && result->status >= 200 && result->status < 400 && errorMsgExt.empty()) { - return {}; + for (std::size_t i = 0; i < sub.callbacks.size(); ++i) { + if (i < sub.callbacks.size() - 1) { + auto copy = response; + sub.callbacks[i](std::move(copy)); + } else { + sub.callbacks[i](std::move(response)); } + } + } - const auto httpError = httplib::status_message(result->status); - return fmt::format("{} - {}:{}", result->status, httpError, errorMsgExt.empty() ? result->body : errorMsgExt); - }(); + void sendLongPollRequests(std::string zmqTopic, std::uint64_t fromLongPollingIdx, std::uint64_t toLongPollingIdx) { + auto subIt = _subscriptions.find(zmqTopic); + if (subIt == _subscriptions.end()) { + HTTP_DBG("Client::sendLongPollingRequests: Could not find subscription for topic '{}'", zmqTopic); + return; + } + auto &sub = subIt->second; + for (std::uint64_t longPollingIdx = fromLongPollingIdx; longPollingIdx <= toLongPollingIdx; ++longPollingIdx) { + auto request = sub.request; + submitRequest(std::move(request), sub.mode, {}, longPollingIdx); + } + } + void startSubscription(client::Command &&command, SubscriptionMode mode = SubscriptionMode::Next) { + mdp::Topic topic; try { - cmd.callback(mdp::Message{ - .id = 0, - .arrivalTime = std::chrono::system_clock::now(), - .protocolName = cmd.topic.scheme().value(), - .command = mdp::Command::Final, - .clientRequestID = cmd.clientRequestID, - .topic = cmd.topic, - .data = errorMsg ? IoBuffer() : IoBuffer(result->body.data(), result->body.size()), - .error = errorMsg.value_or(""), - .rbac = IoBuffer() }); + topic = mdp::Topic::fromMdpTopic(command.topic); } catch (const std::exception &e) { - std::cerr << fmt::format("caught exception '{}' in RestClient::returnMdpMessage(cmd={}, {}: {})", e.what(), cmd.topic, result->status, result.value().body) << std::endl; - } catch (...) { - std::cerr << fmt::format("caught unknown exception in RestClient::returnMdpMessage(cmd={}, {}: {})", cmd.topic, result->status, result.value().body) << std::endl; + HTTP_DBG("Client::startSubscription: Could not parse topic '{}': {}", command.topic.str(), e.what()); + return; + } + const auto [subIt, inserted] = _subscriptions.try_emplace(topic.toZmqTopic(), Subscription{}); + subIt->second.request = command; + subIt->second.callbacks.emplace_back(command.callback); + if (inserted) { + submitRequest(std::move(command), mode, {}, {}); } } - void executeCommand(Command &&cmd) const { - using namespace std::string_literals; - std::cout << "RestClient::request(" << (cmd.topic.str()) << ")" << std::endl; - auto preferredHeader = getPreferredContentTypeHeader(cmd.topic); + void stopSubscription(client::Command &&command) { + // TODO a single unsubscribe cancels this also in case of multiple subscriptions when the client is shared + // inside an application. Would be great if we could selectively unsubscribe certain callbacks and finally + // stop the subscription when all callbacks are removed. + mdp::Topic topic; + try { + topic = mdp::Topic::fromMdpTopic(command.topic); + } catch (const std::exception &e) { + HTTP_DBG("Client::stopSubscription: Could not parse topic '{}': {}", command.topic.str(), e.what()); + return; + }; + if (auto subIt = _subscriptions.find(topic.toZmqTopic()); subIt != _subscriptions.end()) { + // Cancel all requests for this topic + auto reqIt = _requestsByStreamId.begin(); + while (reqIt != _requestsByStreamId.end()) { + if (reqIt->second.request.topic == command.topic) { + nghttp2_submit_rst_stream(_session, NGHTTP2_FLAG_NONE, reqIt->first, NGHTTP2_CANCEL); + reqIt = _requestsByStreamId.erase(reqIt); + } else { + ++reqIt; + } + } + } + } +}; + +} // namespace detail - auto endpointBuilder = URI<>::factory(cmd.topic); +struct ClientCertificates { + std::string _certificates; - if (cmd.command == mdp::Command::Set) { - preferredHeader.insert(std::make_pair("X-OPENCMW-METHOD"s, "PUT"s)); - endpointBuilder = std::move(endpointBuilder).addQueryParameter("_bodyOverride", std::string(cmd.data.asString())); + ClientCertificates() = default; + ClientCertificates(std::string X509_ca_bundle) noexcept + : _certificates(std::move(X509_ca_bundle)) {} + constexpr operator std::string() const noexcept { return _certificates; }; +}; + +struct RestClient : public ClientBase { + struct SslSettings { + std::string caCertificate; + bool verifyPeers = true; + }; + + std::jthread _worker; + MIME::MimeType _mimeType = opencmw::MIME::JSON; + SslSettings _sslSettings; + std::shared_ptr>> _requestQueue = std::make_shared>>(); + + static std::expected + ensureSession(detail::SSL_CTX_Ptr &ssl_ctx, std::map> &sessions, const SslSettings &sslSettings, URI<> topic) { + if (topic.scheme() != "http" && topic.scheme() != "https") { + return std::unexpected(fmt::format("Unsupported protocol '{}' for endpoint '{}'", topic.scheme().value_or(""), topic.str())); + } + if (topic.hostName().value_or("").empty()) { + return std::unexpected(fmt::format("No host provided for endpoint '{}'", topic.str())); } + const auto port = topic.port().value_or(topic.scheme() == "https" ? 443 : 80); + const auto endpoint = detail::Endpoint{ topic.scheme().value(), topic.hostName().value(), port }; - auto endpoint = endpointBuilder.build(); + auto sessionIt = sessions.find(endpoint); + if (sessionIt != sessions.end()) { + return sessionIt->second.get(); + } - auto callback = [&cmd, &preferredHeader, &endpoint](ClientType &client) { - client.set_follow_location(true); - client.set_read_timeout(cmd.timeout); // default keep-alive value - if (const httplib::Result &result = client.Get(endpoint.relativeRef()->data(), preferredHeader)) { - returnMdpMessage(cmd, result); - } else { - std::stringstream errorStr(fmt::format("\"{}\"", static_cast(result.error()))); -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - if (auto sslResult = client.get_openssl_verify_result(); sslResult) { - errorStr << fmt::format(" - SSL error: '{}'", X509_verify_cert_error_string(sslResult)); + int socketFlags = detail::TcpSocket::None; + if (topic.scheme() == "https") { + if (!ssl_ctx) { + ssl_ctx = detail::SSL_CTX_Ptr(SSL_CTX_new(SSLv23_client_method()), SSL_CTX_free); + if (!ssl_ctx) { + return std::unexpected(fmt::format("Could not create SSL/TLS context: {}", ERR_error_string(ERR_get_error(), nullptr))); + } + SSL_CTX_set_options(ssl_ctx.get(), + SSL_OP_ALL | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | SSL_OP_NO_COMPRESSION | SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); + + SSL_CTX_set_alpn_protos(ssl_ctx.get(), reinterpret_cast("\x02h2"), 3); + + if (sslSettings.verifyPeers) { + SSL_CTX_set_verify(ssl_ctx.get(), SSL_VERIFY_PEER, nullptr); + + if (!sslSettings.caCertificate.empty()) { + auto maybeStore = detail::createCertificateStore(sslSettings.caCertificate); + if (!maybeStore) { + return std::unexpected(fmt::format("Could not create certificate store: {}", maybeStore.error())); + } + SSL_CTX_set_cert_store(ssl_ctx.get(), maybeStore->release()); + } } -#endif - const std::string errorMsg = fmt::format("GET request failed for: '{}' - {} - CHECK_CERTIFICATES: {}", cmd.topic.str(), errorStr.str(), CHECK_CERTIFICATES); - returnMdpMessage(cmd, result, errorMsg); } - }; - if (cmd.topic.scheme() && equal_with_case_ignore(cmd.topic.scheme().value(), "https")) { -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - httplib::SSLClient client(cmd.topic.hostName().value(), cmd.topic.port() ? cmd.topic.port().value() : 443); - // client owns its certificate store and destroys it after use. create a store for each client - client.set_ca_cert_store(detail::createCertificateStore(_caCertificate)); - client.enable_server_certificate_verification(CHECK_CERTIFICATES); - callback(client); -#else - throw std::invalid_argument("https is not supported"); -#endif - } else if (cmd.topic.scheme() && equal_with_case_ignore(cmd.topic.scheme().value(), "http")) { - httplib::Client client(cmd.topic.hostName().value(), cmd.topic.port() ? cmd.topic.port().value() : 80); - callback(client); - return; - } else { - if (cmd.topic.scheme()) { - throw std::invalid_argument(fmt::format("unsupported protocol '{}' for endpoint '{}'", cmd.topic.scheme(), cmd.topic.str())); - } else { - throw std::invalid_argument(fmt::format("no protocol provided for endpoint '{}'", cmd.topic.str())); + auto ssl = detail::create_ssl(ssl_ctx.get()); + if (!ssl) { + return std::unexpected(fmt::format("Failed to create SSL object: {}", ssl.error())); + } + if (sslSettings.verifyPeers) { + socketFlags |= detail::TcpSocket::VerifyPeer; } + auto maybeSocket = detail::TcpSocket::create(std::move(ssl.value()), socket(AF_INET, SOCK_STREAM, 0), socketFlags); + if (!maybeSocket) { + return std::unexpected(fmt::format("Failed to create socket: {}", maybeSocket.error())); + } + auto session = std::make_unique(std::move(maybeSocket.value())); + if (auto rc = session->_socket.prepareConnect(endpoint.host, endpoint.port); !rc) { + return std::unexpected(rc.error()); + } + sessionIt = sessions.emplace(endpoint, std::move(session)).first; + return sessionIt->second.get(); } - } - bool equal_with_case_ignore(const std::string &a, const std::string &b) const { - return std::ranges::equal(a, b, [](const char ca, const char cb) noexcept { return ::tolower(ca) == ::tolower(cb); }); + // HTTP + auto maybeSocket = detail::TcpSocket::create({ nullptr, SSL_free }, socket(AF_INET, SOCK_STREAM, 0), socketFlags); + if (!maybeSocket) { + return std::unexpected(fmt::format("Failed to create socket: {}", maybeSocket.error())); + } + auto session = std::make_unique(std::move(maybeSocket.value())); + if (auto rc = session->_socket.prepareConnect(endpoint.host, endpoint.port); !rc) { + return std::unexpected(rc.error()); + } + sessionIt = sessions.emplace(endpoint, std::move(session)).first; + return sessionIt->second.get(); } - void startSubscription(Command &&cmd) { - std::scoped_lock lock(_subscriptionLock); - if (equal_with_case_ignore(*cmd.topic.scheme(), "http") -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - || equal_with_case_ignore(*cmd.topic.scheme(), "https") -#endif - ) { - auto createNewSubscription = [&](auto &client) { - { - client.set_follow_location(true); - - std::size_t longPollingIdx = 0; - const auto pollHeaders = getPreferredContentTypeHeader(cmd.topic); - client.set_read_timeout(cmd.timeout); // default keep-alive value - while (_run) { - auto endpoint = [&]() { - if (longPollingIdx == 0UZ) { - return URI::factory(cmd.topic).addQueryParameter(LONG_POLLING_IDX_TAG, "Next").build().relativeRef().value(); - } else { - return URI::factory(cmd.topic).addQueryParameter(LONG_POLLING_IDX_TAG, fmt::format("{}", longPollingIdx)).build().relativeRef().value(); - } - }(); - if (const httplib::Result &result = client.Get(endpoint, pollHeaders)) { - returnMdpMessage(cmd, result); - // update long-polling-index - std::string location = result->location.empty() ? endpoint : result->location; - std::string updateIdxString = URI(location).queryParamMap().at(std::string(LONG_POLLING_IDX_TAG)).value_or("0"); - char *end = updateIdxString.data() + updateIdxString.size(); - longPollingIdx = strtoull(updateIdxString.data(), &end, 10) + 1; - } else { // failed or server is down -> wait until retry - if (_run) { - returnMdpMessage(cmd, result, fmt::format("Long-Polling-GET request failed for {}: {}", cmd.topic.str(), static_cast(result.error()))); - } - std::this_thread::sleep_for(cmd.timeout); // time-out until potential retry +public: + template + explicit(false) RestClient(Args... initArgs) + : _mimeType(detail::find_argument_value([] { return MIME::JSON; }, initArgs...)) + , _sslSettings{ + .caCertificate = detail::find_argument_value([] { return ClientCertificates{}; }, initArgs...), + .verifyPeers = detail::find_argument_value([] { return true; }, initArgs...) + } { + _worker = std::jthread([queue = _requestQueue, sslSettings = _sslSettings, mimeType = _mimeType](std::stop_token stopToken) { + auto preferredMimeType = [&mimeType](const URI<> &topic) { + if (const auto contentTypeHeader = topic.queryParamMap().find("contentType"); contentTypeHeader != topic.queryParamMap().end() && contentTypeHeader->second) { + return contentTypeHeader->second.value(); + } + return std::string{ mimeType.typeName() }; + }; + + detail::SSL_CTX_Ptr ssl_ctx{ nullptr, SSL_CTX_free }; + + std::map> sessions; + + auto reportError = [](Command &cmd, std::string error) { + if (!cmd.callback) { + return; + } + mdp::Message msg; + msg.protocolName = cmd.topic.scheme().value_or(""); + msg.arrivalTime = std::chrono::system_clock::now(); + msg.command = mdp::Command::Final; + msg.clientRequestID = cmd.clientRequestID; + msg.topic = cmd.topic; + msg.error = std::move(error); + cmd.callback(msg); + }; + + std::vector pollFds; + + while (!stopToken.stop_requested()) { + while (auto entry = queue->try_get()) { + auto &[cmd, mode] = entry.value(); + switch (cmd.command) { + case mdp::Command::Get: + case mdp::Command::Set: { + auto session = ensureSession(ssl_ctx, sessions, sslSettings, cmd.topic); + auto preferred = preferredMimeType(cmd.topic); + session.value()->submitRequest(std::move(cmd), mode, std::move(preferred), {}); + } break; + case mdp::Command::Subscribe: { + auto session = ensureSession(ssl_ctx, sessions, sslSettings, cmd.topic); + if (!session) { + reportError(cmd, fmt::format("Unsupported endpoint '{}': {}", cmd.topic.str(), session.error())); + continue; } + session.value()->startSubscription(std::move(cmd), mode); + } break; + case mdp::Command::Unsubscribe: { + auto session = ensureSession(ssl_ctx, sessions, sslSettings, cmd.topic); + if (!session) { + reportError(cmd, fmt::format("Unsupported endpoint '{}': {}", cmd.topic.str(), session.error())); + continue; + } + session.value()->stopSubscription(std::move(cmd)); + } break; + case mdp::Command::Final: + case mdp::Command::Partial: + case mdp::Command::Heartbeat: + case mdp::Command::Notify: + case mdp::Command::Invalid: + case mdp::Command::Ready: + case mdp::Command::Disconnect: + assert(false); // unexpected command + break; } } - }; - if (equal_with_case_ignore(*cmd.topic.scheme(), "http")) { - auto it = _subscription1.find(cmd.topic); - if (it == _subscription1.end()) { - _subscription1.emplace(cmd.topic, httplib::Client(cmd.topic.hostName().value(), cmd.topic.port().value())); - createNewSubscription(_subscription1.at(cmd.topic)); + + pollFds.clear(); + pollFds.reserve(sessions.size()); + for (auto &sessionPair : sessions) { + auto &session = sessionPair.second; + struct pollfd pfd = {}; + pfd.fd = session->_socket.fd; + pfd.events = 0; + if (session->wantsToRead()) { + pfd.events |= POLLIN; + } + if (session->wantsToWrite()) { + pfd.events |= POLLOUT; + } + pollFds.push_back(pfd); } - } else { -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - if (auto it = _subscription2.find(cmd.topic); it == _subscription2.end()) { - _subscription2.emplace( - std::piecewise_construct, - std::forward_as_tuple(cmd.topic), - std::forward_as_tuple(cmd.topic.hostName().value(), cmd.topic.port().value())); - auto &client = _subscription2.at(cmd.topic); - client.set_ca_cert_store(detail::createCertificateStore(_caCertificate)); - client.enable_server_certificate_verification(CHECK_CERTIFICATES); - createNewSubscription(_subscription2.at(cmd.topic)); + + if (pollFds.empty()) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + continue; + } + + const auto n = ::poll(pollFds.data(), pollFds.size(), 100); + + if (n < 0) { + HTTP_DBG("poll failed: {}", strerror(errno)); + continue; + } + if (n == 0) { + continue; + } + + auto sessionIt = sessions.begin(); + while (sessionIt != sessions.end()) { + auto &session = sessionIt->second; + + auto pollIt = std::ranges::find_if(pollFds, [&](const struct pollfd &pfd) { + return pfd.fd == session->_socket.fd; + }); + + if (pollIt == pollFds.end()) { + ++sessionIt; + continue; + } + + if (pollIt->revents & POLLERR) { + int error = 0; + socklen_t errlen = sizeof(error); + getsockopt(pollIt->fd, SOL_SOCKET, SO_ERROR, &error, &errlen); + session->reportErrorToAllPendingRequests(strerror(error)); + sessionIt = sessions.erase(sessionIt); + continue; + } + + if (((pollIt->revents & POLLIN) || (pollIt->revents & POLLOUT)) && !session->isReady()) { + if (auto r = session->continueToMakeReady(); !r) { + sessionIt->second->reportErrorToAllPendingRequests(r.error()); + sessionIt = sessions.erase(sessionIt); + continue; + } + } + + if (!session->isReady()) { + ++sessionIt; + continue; + } + + if (pollIt->revents & POLLOUT) { + if (!session->_writeBuffer.write(session->_session, session->_socket)) { + HTTP_DBG("Client: Failed to write to peer (fd={}): {}", session->_socket.fd, session->_socket.lastError()); + sessionIt = sessions.erase(sessionIt); + continue; + } + } + + if (pollIt->revents & POLLIN) { + bool mightHaveMore = true; + bool hasError = false; + + while (mightHaveMore && !hasError) { + std::array buffer; + const auto bytes_read = session->_socket.read(buffer.data(), buffer.size()); + if (bytes_read <= 0 && errno != EAGAIN) { + if (bytes_read < 0) { + HTTP_DBG("Client::read failed: {}", session->_socket.lastError()); + } + hasError = true; + continue; + } + + if (bytes_read > 0 && nghttp2_session_mem_recv2(session->_session, buffer.data(), static_cast(bytes_read)) < 0) { + HTTP_DBG("Client: nghttp2_session_mem_recv2 failed"); + hasError = true; + continue; + } + mightHaveMore = bytes_read == static_cast(buffer.size()); + } + if (hasError) { + sessionIt = sessions.erase(sessionIt); + continue; + } + } + + ++sessionIt; } -#else - throw std::invalid_argument("https is not supported"); -#endif } + }); + } - } else { - throw std::invalid_argument(fmt::format("unsupported scheme '{}' for requested subscription '{}'", cmd.topic.scheme(), cmd.topic.str())); + ~RestClient() { + stop(); + } + + [[nodiscard]] + MIME::MimeType defaultMimeType() const { + return _mimeType; + } + + [[nodiscard]] + bool verifySslPeers() const { + return _sslSettings.verifyPeers; + } + + [[nodiscard]] + std::vector protocols() override { + return { "http", "https" }; + } + + void stop() override { + _worker.request_stop(); + if (_worker.joinable()) { + _worker.join(); } } - void stopSubscription(const Command &cmd) { - // stop subscription that matches URI - std::scoped_lock lock(_subscriptionLock); - if (equal_with_case_ignore(*cmd.topic.scheme(), "http")) { - auto it = _subscription1.find(cmd.topic); - if (it != _subscription1.end()) { - it->second.stop(); - _subscription1.erase(it); - return; - } - } else if (equal_with_case_ignore(*cmd.topic.scheme(), "https")) { -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - auto it = _subscription2.find(cmd.topic); - if (it != _subscription2.end()) { - it->second.stop(); - _subscription2.erase(it); - return; - } -#else - throw std::runtime_error("https is not supported - enable CPPHTTPLIB_OPENSSL_SUPPORT"); -#endif - } else { - throw std::invalid_argument(fmt::format("unsupported scheme '{}' for requested subscription '{}'", cmd.topic.scheme(), cmd.topic.str())); + void request(Command cmd) override { + switch (cmd.command) { + case mdp::Command::Get: + case mdp::Command::Set: + assert(cmd.callback); + _requestQueue->push({ std::move(cmd), SubscriptionMode::None }); + break; + case mdp::Command::Subscribe: + assert(cmd.callback); + _requestQueue->push({ std::move(cmd), SubscriptionMode::Next }); + break; + case mdp::Command::Unsubscribe: + _requestQueue->push({ std::move(cmd), SubscriptionMode::None }); + break; + default: + assert(false); // unexpected command } } - void stopAllSubscriptions() noexcept { - _run = false; - std::scoped_lock lock(_subscriptionLock); - std::ranges::for_each(_subscription1, [](auto &pair) { pair.second.stop(); }); -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - std::ranges::for_each(_subscription2, [](auto &pair) { pair.second.stop(); }); -#endif + mdp::Message blockingRequest(Command cmd) { + std::promise promise; + auto future = promise.get_future(); + cmd.callback = [&promise](const mdp::Message &msg) { + promise.set_value(std::move(msg)); + }; + request(std::move(cmd)); + return future.get(); } }; -inline bool RestClient::CHECK_CERTIFICATES = true; -inline const httplib::Headers RestClient::EVT_STREAM_HEADERS = { { detail::ACCEPT_HEADER, MIME::EVENT_STREAM.typeName().data() } }; } // namespace opencmw::client -#endif // OPENCMW_CPP_RESTCLIENT_NATIVE_HPP +#endif // OPENCMW_CLIENT_RESTCLIENTNATIVE_HPP diff --git a/src/client/test/CMakeLists.txt b/src/client/test/CMakeLists.txt index c60aa813..ea4c7727 100644 --- a/src/client/test/CMakeLists.txt +++ b/src/client/test/CMakeLists.txt @@ -44,15 +44,15 @@ if(NOT EMSCRIPTEN) # TEST_PREFIX "unittests." REPORTER xml OUTPUT_DIR . OUTPUT_PREFIX "unittests." OUTPUT_SUFFIX .xml) catch_discover_tests(client_tests) - add_executable(rest_client_mock_server_tests catch_main.cpp RestClient_tests.cpp) + add_executable(nghttp2_tests catch_main.cpp nghttp2_tests.cpp) target_link_libraries( - rest_client_mock_server_tests + nghttp2_tests PUBLIC opencmw_project_warnings - opencmw_project_options - test_assets_rest - Catch2::Catch2 - client) - catch_discover_tests(rest_client_mock_server_tests) + opencmw_project_options + test_assets_rest + Catch2::Catch2 + client) + catch_discover_tests(nghttp2_tests) add_executable(clientPublisher_tests catch_main.cpp ClientPublisher_tests.cpp) target_link_libraries( @@ -70,7 +70,8 @@ endif() add_executable(rest_client_only_tests RestClientOnly_tests.cpp) target_link_libraries(rest_client_only_tests PUBLIC opencmw_project_warnings opencmw_project_options client) target_include_directories(rest_client_only_tests PRIVATE ${CMAKE_SOURCE_DIR}) -# This test requires a different kind of invocation as it needs the server running + +# These tests require a different kind of invocation as they need the server running # catch_discover_tests(rest_client_only_tests) if(EMSCRIPTEN) diff --git a/src/client/test/RestClientOnly_tests.cpp b/src/client/test/RestClientOnly_tests.cpp index 95326e44..81aa1ab3 100644 --- a/src/client/test/RestClientOnly_tests.cpp +++ b/src/client/test/RestClientOnly_tests.cpp @@ -1,10 +1,4 @@ -#include -#include -#include - -#include #include -#include #include "concepts/client/helpers.hpp" @@ -13,9 +7,13 @@ using namespace std::chrono_literals; // These are not main-local, as JS doesn't end when // C++ main ends namespace test { +#ifndef __EMSCRIPTEN__ +opencmw::client::RestClient client(opencmw::client::VerifyServerCertificates(false)); +#else opencmw::client::RestClient client; +#endif -std::string schema() { +std::string schema() { if (auto env = ::getenv("DISABLE_REST_HTTPS"); env != nullptr && std::string_view(env) == "1") { return "http"; } else { @@ -66,10 +64,6 @@ auto run = rest_test_runner( } // namespace test int main() { -#ifndef __EMSCRIPTEN__ - opencmw::client::RestClient::CHECK_CERTIFICATES = false; -#endif - using namespace test; #ifndef __EMSCRIPTEN__ diff --git a/src/client/test/RestClient_tests.cpp b/src/client/test/RestClient_tests.cpp deleted file mode 100644 index 74f1559e..00000000 --- a/src/client/test/RestClient_tests.cpp +++ /dev/null @@ -1,468 +0,0 @@ -#include - -#include - -#include - -#include "RestClient.hpp" - -#include -CMRC_DECLARE(assets); - -namespace opencmw::rest_client_test { - -constexpr const char *testCertificate = R"( -R"( -GlobalSign Root CA -================== ------BEGIN CERTIFICATE----- -MIIDdTCCAl2gAwIBAgILBAAAAAABFUtaw5QwDQYJKoZIhvcNAQEFBQAwVzELMAkGA1UEBhMCQkUx -GTAXBgNVBAoTEEdsb2JhbFNpZ24gbnYtc2ExEDAOBgNVBAsTB1Jvb3QgQ0ExGzAZBgNVBAMTEkds -b2JhbFNpZ24gUm9vdCBDQTAeFw05ODA5MDExMjAwMDBaFw0yODAxMjgxMjAwMDBaMFcxCzAJBgNV -BAYTAkJFMRkwFwYDVQQKExBHbG9iYWxTaWduIG52LXNhMRAwDgYDVQQLEwdSb290IENBMRswGQYD -VQQDExJHbG9iYWxTaWduIFJvb3QgQ0EwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDa -DuaZjc6j40+Kfvvxi4Mla+pIH/EqsLmVEQS98GPR4mdmzxzdzxtIK+6NiY6arymAZavpxy0Sy6sc -THAHoT0KMM0VjU/43dSMUBUc71DuxC73/OlS8pF94G3VNTCOXkNz8kHp1Wrjsok6Vjk4bwY8iGlb -Kk3Fp1S4bInMm/k8yuX9ifUSPJJ4ltbcdG6TRGHRjcdGsnUOhugZitVtbNV4FpWi6cgKOOvyJBNP -c1STE4U6G7weNLWLBYy5d4ux2x8gkasJU26Qzns3dLlwR5EiUWMWea6xrkEmCMgZK9FGqkjWZCrX -gzT/LCrBbBlDSgeF59N89iFo7+ryUp9/k5DPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV -HRMBAf8EBTADAQH/MB0GA1UdDgQWBBRge2YaRQ2XyolQL30EzTSo//z9SzANBgkqhkiG9w0BAQUF -AAOCAQEA1nPnfE920I2/7LqivjTFKDK1fPxsnCwrvQmeU79rXqoRSLblCKOzyj1hTdNGCbM+w6Dj -Y1Ub8rrvrTnhQ7k4o+YviiY776BQVvnGCv04zcQLcFGUl5gE38NflNUVyRRBnMRddWQVDf9VMOyG -j/8N7yy5Y0b2qvzfvGn9LhJIZJrglfCm7ymPAbEVtQwdpf5pLGkkeB6zpxxxYu7KyJesF12KwvhH -hm4qxFYxldBniYUr+WymXUadDKqC5JlR3XC321Y9YeRq4VzW9v493kHMB65jUr9TU/Qr6cf9tveC -X4XSQRjbgbMEHMUfpIBvFSDJ3gyICh3WZlXi/EjJKSZp4A== ------END CERTIFICATE----- -)"; - -class TestServerCertificates { - const cmrc::embedded_filesystem fileSystem = cmrc::assets::get_filesystem(); - const cmrc::file ca_certificate = fileSystem.open("/assets/ca-cert.pem"); - // server-req.pem -> is usually used to request for the CA signature - const cmrc::file server_cert = fileSystem.open("/assets/server-cert.pem"); - const cmrc::file server_key = fileSystem.open("/assets/server-key.pem"); - const cmrc::file client_cert = fileSystem.open("/assets/client-cert.pem"); - const cmrc::file client_key = fileSystem.open("/assets/client-key.pem"); - const cmrc::file pwd = fileSystem.open("/assets/password.txt"); - -public: - const std::string caCertificate = { ca_certificate.begin(), ca_certificate.end() }; - const std::string serverCertificate = { server_cert.begin(), server_cert.end() }; - const std::string serverKey = { server_key.begin(), server_key.end() }; - const std::string clientCertificate = { client_cert.begin(), client_cert.end() }; - const std::string clientKey = { client_key.begin(), client_key.end() }; - const std::string password = { pwd.begin(), pwd.end() }; -}; -inline static const TestServerCertificates testServerCertificates; - -TEST_CASE("Basic Rest Client Constructor and API Tests", "[Client]") { - using namespace opencmw::client; - RestClient client1; - REQUIRE(client1.name() == "RestClient"); - - RestClient client2(std::make_shared>("RestClient", 1, 10'000)); - REQUIRE(client2.name() == "RestClient"); - - RestClient client3("clientName", std::make_shared>("CustomPoolName", 1, 10'000)); - REQUIRE(client3.name() == "clientName"); - REQUIRE(client3.threadPool()->poolName() == "CustomPoolName"); - - RestClient client4("clientName"); - REQUIRE(client4.threadPool()->poolName() == "clientName"); - - RestClient client5("clientName", DefaultContentTypeHeader(MIME::HTML), MinIoThreads(2), MaxIoThreads(5), ClientCertificates(testCertificate)); - REQUIRE(client5.defaultMimeType() == MIME::HTML); - REQUIRE(client5.threadPool()->poolName() == "clientName"); -} - -TEST_CASE("Basic Rest Client Get/Set Test - HTTP", "[Client]") { - using namespace opencmw::client; - RestClient client; - REQUIRE(client.name() == "RestClient"); - - httplib::Server server; - - std::string acceptHeader; - server.Get("/endPoint", [&acceptHeader](const httplib::Request &req, httplib::Response &res) { - fmt::print("server received request on path '{}' body = '{}'\n", req.path, req.body); - if (req.headers.contains("accept")) { - acceptHeader = req.headers.find("accept")->second; - } else { - FAIL("no accept headers found"); - } - - res.set_content("Hello World!", acceptHeader); - }); - client.threadPool()->execute<"RestServer">([&server] { server.listen("localhost", 8080); }); - while (!server.is_running()) { - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - } - REQUIRE(server.is_running()); - - std::atomic done(false); - IoBuffer data; - data.put('A'); - data.put('B'); - data.put('C'); - Command command; - command.command = mdp::Command::Get; - command.topic = URI("http://localhost:8080/endPoint"); - command.data = std::move(data); - command.callback = [&done](const mdp::Message & /*rep*/) { - done.store(true, std::memory_order_release); - done.notify_all(); - }; - client.request(command); - - done.wait(false); - REQUIRE(done.load(std::memory_order_acquire) == true); - REQUIRE(acceptHeader == MIME::JSON.typeName()); - server.stop(); -} - -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT -TEST_CASE("Multiple Rest Client Get/Set Test - HTTPS", "[Client]") { - using namespace opencmw::client; - RestClient client("TestSSLClient", ClientCertificates(testServerCertificates.caCertificate)); - REQUIRE(RestClient::CHECK_CERTIFICATES); - RestClient::CHECK_CERTIFICATES = true; // 'false' disables certificate check - REQUIRE(client.name() == "TestSSLClient"); - REQUIRE(client.defaultMimeType() == MIME::JSON); - - // HTTP - X509 *cert = opencmw::client::detail::readServerCertificateFromFile(testServerCertificates.serverCertificate); - EVP_PKEY *pkey = opencmw::client::detail::readServerPrivateKeyFromFile(testServerCertificates.serverKey); - if (const X509_STORE *ca_store = opencmw::client::detail::createCertificateStore(testServerCertificates.caCertificate); !cert || !pkey || !ca_store) { - FAIL(fmt::format("Failed to load certificate: {}", ERR_error_string(ERR_get_error(), nullptr))); - } - httplib::SSLServer server(cert, pkey); - - std::string acceptHeader; - server.Get("/endPoint", [&acceptHeader](const httplib::Request &req, httplib::Response &res) { - if (req.headers.contains("accept")) { - acceptHeader = req.headers.find("accept")->second; - } else { - FAIL("no accept headers found"); - } - res.set_content("Hello World!", acceptHeader); - }); - client.threadPool()->execute<"RestServer">([&server] { server.listen("localhost", 8080); }); - while (!server.is_running()) { - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - } - REQUIRE(server.is_running()); - - std::array, 4> dones; - dones[0] = false; - dones[1] = false; - dones[2] = false; - dones[3] = false; - std::atomic counter{ 0 }; - auto makeCommand = [&]() { - IoBuffer data; - data.put('A'); - data.put('B'); - data.put('C'); - data.put(0); - - Command command; - command.command = mdp::Command::Get; - command.topic = URI("https://localhost:8080/endPoint"); - command.data = std::move(data); - command.callback = [&dones, &counter](const mdp::Message &/*rep*/) { - std::size_t currentCounter = counter.fetch_add(1, std::memory_order_relaxed); - dones[currentCounter].store(true, std::memory_order_release); - // Assuming you have access to 'done' variable, uncomment the following line - dones[currentCounter].notify_all(); - }; - client.request(command); - }; - for (int i = 0; i < 4; i++) - makeCommand(); - - for (auto &done : dones) { - done.wait(false); - } - REQUIRE(std::ranges::all_of(dones, [](auto &done) { return done.load(std::memory_order_acquire); })); - REQUIRE(acceptHeader == MIME::JSON.typeName()); - server.stop(); -} - -TEST_CASE("Basic Rest Client Get/Set Test - HTTPS", "[Client]") { - using namespace opencmw::client; - RestClient client("TestSSLClient", ClientCertificates(testServerCertificates.caCertificate)); - REQUIRE(RestClient::CHECK_CERTIFICATES); - RestClient::CHECK_CERTIFICATES = true; // 'false' disables certificate check - REQUIRE(client.name() == "TestSSLClient"); - REQUIRE(client.defaultMimeType() == MIME::JSON); - - // HTTP - X509 *cert = opencmw::client::detail::readServerCertificateFromFile(testServerCertificates.serverCertificate); - EVP_PKEY *pkey = opencmw::client::detail::readServerPrivateKeyFromFile(testServerCertificates.serverKey); - if (const X509_STORE *ca_store = opencmw::client::detail::createCertificateStore(testServerCertificates.caCertificate); !cert || !pkey || !ca_store) { - FAIL(fmt::format("Failed to load certificate: {}", ERR_error_string(ERR_get_error(), nullptr))); - } - httplib::SSLServer server(cert, pkey); - - std::string acceptHeader; - server.Get("/endPoint", [&acceptHeader](const httplib::Request &req, httplib::Response &res) { - if (req.headers.contains("accept")) { - acceptHeader = req.headers.find("accept")->second; - } else { - FAIL("no accept headers found"); - } - res.set_content("Hello World!", acceptHeader); - }); - client.threadPool()->execute<"RestServer">([&server] { server.listen("localhost", 8080); }); - while (!server.is_running()) { - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - } - REQUIRE(server.is_running()); - - std::atomic done(false); - IoBuffer data; - data.put('A'); - data.put('B'); - data.put('C'); - data.put(0); - Command command; - command.command = mdp::Command::Get; - command.topic = URI("https://localhost:8080/endPoint"); - command.data = std::move(data); - command.callback = [&done](const mdp::Message & /*rep*/) { - done.store(true, std::memory_order_release); - done.notify_all(); - }; - client.request(command); - - done.wait(false); - REQUIRE(done.load(std::memory_order_acquire) == true); - REQUIRE(acceptHeader == MIME::JSON.typeName()); - server.stop(); -} -#endif - -namespace detail { -class EventDispatcher { - std::mutex _mutex; - std::condition_variable _condition; - std::atomic _id{ 0 }; - std::atomic _cid{ -1 }; - std::string _message; - -public: - void wait_event(httplib::DataSink &sink) { - std::unique_lock lk(_mutex); - int id = _id; - _condition.wait(lk, [&id, this] { return _cid == id; }); - if (sink.is_writable()) { - sink.write(_message.data(), _message.size()); - } - } - - void send_event(const std::string_view &message) { - std::scoped_lock lk(_mutex); - _cid = _id++; - _message = message; - _condition.notify_all(); - } -}; -} // namespace detail - -TEST_CASE("Basic Rest Client Subscribe/Unsubscribe Test", "[Client]") { - using namespace opencmw::client; - - std::atomic updateCounter{ 0 }; - detail::EventDispatcher eventDispatcher; - httplib::Server server; - server.Get("/event", [&eventDispatcher, &updateCounter](const httplib::Request &req, httplib::Response &res) { - auto acceptType = req.headers.find("accept"); - if (acceptType == req.headers.end() || MIME::EVENT_STREAM.typeName() != acceptType->second) { // non-SSE request -> return default response -#if not defined(__EMSCRIPTEN__) and (not defined(__clang__) or (__clang_major__ >= 16)) - res.set_content(fmt::format("update counter = {}", updateCounter.load()), MIME::TEXT); -#else - res.set_content(fmt::format("update counter = {}", updateCounter.load()), std::string(MIME::TEXT.typeName())); -#endif - return; - } else { - fmt::print("server received SSE request on path '{}' body = '{}'\n", req.path, req.body); -#if not defined(__EMSCRIPTEN__) and (not defined(__clang__) or (__clang_major__ >= 16)) - res.set_chunked_content_provider(MIME::EVENT_STREAM, [&eventDispatcher](size_t /*offset*/, httplib::DataSink &sink) { -#else - res.set_chunked_content_provider(std::string(MIME::EVENT_STREAM.typeName()), [&eventDispatcher](size_t /*offset*/, httplib::DataSink &sink) { -#endif - eventDispatcher.wait_event(sink); - return true; - }); - } - }); - server.Get("/endPoint", [](const httplib::Request &req, httplib::Response &res) { - fmt::print("server received request on path '{}' body = '{}'\n", req.path, req.body); - res.set_content("Hello World!", "text/plain"); - }); - - RestClient client; - client.threadPool()->execute<"RestServer">([&server] { server.listen("localhost", 8080); }); - while (!server.is_running()) { - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - } - REQUIRE(server.is_running()); - - std::atomic receivedRegular(0); - std::atomic receivedError(0); - IoBuffer data; - data.put('A'); - data.put('B'); - data.put('C'); - data.put(0); - - Command command; - command.command = mdp::Command::Subscribe; - command.topic = URI("http://localhost:8080/event"); - command.data = std::move(data); - command.callback = [&receivedRegular, &receivedError](const mdp::Message &rep) { - fmt::print("SSE client received reply = '{}' - body size: '{}'\n", rep.data.asString(), rep.data.size()); - if (rep.error.size() == 0) { - receivedRegular.fetch_add(1, std::memory_order_relaxed); - } else { - receivedError.fetch_add(1, std::memory_order_relaxed); - } - receivedRegular.notify_all(); - receivedError.notify_all(); - }; - - client.request(command); - - std::cout << "client request launched" << std::endl; - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - eventDispatcher.send_event("test-event meta data"); - std::jthread dispatcher([&updateCounter, &eventDispatcher] { - while (updateCounter < 5) { - std::this_thread::sleep_for(std::chrono::milliseconds(20)); - eventDispatcher.send_event(fmt::format("test-event {}", updateCounter++)); - } - }); - dispatcher.join(); - - while (receivedRegular.load(std::memory_order_relaxed) < 5) { - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - } - std::cout << "done waiting" << std::endl; - REQUIRE(receivedRegular.load(std::memory_order_acquire) >= 5); - - command.command = mdp::Command::Unsubscribe; - client.request(command); - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - std::cout << "done Unsubscribe" << std::endl; - - client.stop(); - server.stop(); - eventDispatcher.send_event(fmt::format("test-event {}", updateCounter++)); - std::cout << "server stopped" << std::endl; -} - -TEST_CASE("Basic Rest Client Subscribe/Unsubscribe Test HTTPS", "[Client]") { - // HTTP - X509 *cert = opencmw::client::detail::readServerCertificateFromFile(testServerCertificates.serverCertificate); - EVP_PKEY *pkey = opencmw::client::detail::readServerPrivateKeyFromFile(testServerCertificates.serverKey); - if (const X509_STORE *ca_store = opencmw::client::detail::createCertificateStore(testServerCertificates.caCertificate); !cert || !pkey || !ca_store) { - FAIL(fmt::format("Failed to load certificate: {}", ERR_error_string(ERR_get_error(), nullptr))); - } - using namespace opencmw::client; - - std::atomic updateCounter{ 0 }; - detail::EventDispatcher eventDispatcher; - httplib::SSLServer server(cert, pkey); - server.Get("/event", [&eventDispatcher, &updateCounter](const httplib::Request &req, httplib::Response &res) { - DEBUG_LOG("Server received request"); - auto acceptType = req.headers.find("accept"); - if (acceptType == req.headers.end() || MIME::EVENT_STREAM.typeName() != acceptType->second) { // non-SSE request -> return default response -#if not defined(__EMSCRIPTEN__) and (not defined(__clang__) or (__clang_major__ >= 16)) - res.set_content(fmt::format("update counter = {}", updateCounter.load()), MIME::TEXT); -#else - res.set_content(fmt::format("update counter = {}", updateCounter.load()), std::string(MIME::TEXT.typeName())); -#endif - return; - } else { - fmt::print("server received SSE request on path '{}' body = '{}'\n", req.path, req.body); -#if not defined(__EMSCRIPTEN__) and (not defined(__clang__) or (__clang_major__ >= 16)) - res.set_chunked_content_provider(MIME::EVENT_STREAM, [&eventDispatcher](size_t /*offset*/, httplib::DataSink &sink) { -#else - res.set_chunked_content_provider(std::string(MIME::EVENT_STREAM.typeName()), [&eventDispatcher](size_t /*offset*/, httplib::DataSink &sink) { -#endif - eventDispatcher.wait_event(sink); - return true; - }); - } - }); - server.Get("/endPoint", [](const httplib::Request &req, httplib::Response &res) { - fmt::print("server received request on path '{}' body = '{}'\n", req.path, req.body); - res.set_content("Hello World!", "text/plain"); - }); - - RestClient client("TestSSLClient", ClientCertificates(testServerCertificates.caCertificate)); - - client.threadPool()->execute<"RestServer">([&server] { server.listen("localhost", 8080); }); - while (!server.is_running()) { - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - } - REQUIRE(server.is_running()); - REQUIRE(RestClient::CHECK_CERTIFICATES); - RestClient::CHECK_CERTIFICATES = true; // 'false' disables certificate check - REQUIRE(client.name() == "TestSSLClient"); - REQUIRE(client.defaultMimeType() == MIME::JSON); - - std::atomic receivedRegular(0); - std::atomic receivedError(0); - IoBuffer data; - data.put('A'); - data.put('B'); - data.put('C'); - data.put(0); - - Command command; - command.command = mdp::Command::Subscribe; - command.topic = URI("https://localhost:8080/event"); - command.data = std::move(data); - command.callback = [&receivedRegular, &receivedError](const mdp::Message &rep) { - fmt::print("SSE client received reply = '{}' - body size: '{}'\n", rep.data.asString(), rep.data.size()); - if (rep.error.size() == 0) { - receivedRegular.fetch_add(1, std::memory_order_relaxed); - } else { - receivedError.fetch_add(1, std::memory_order_relaxed); - } - receivedRegular.notify_all(); - receivedError.notify_all(); - }; - - client.request(command); - - std::cout << "client request launched" << std::endl; - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - eventDispatcher.send_event("test-event meta data"); - std::jthread dispatcher([&updateCounter, &eventDispatcher] { - while (updateCounter < 5) { - std::this_thread::sleep_for(std::chrono::milliseconds(20)); - eventDispatcher.send_event(fmt::format("test-event {}", updateCounter++)); - } - }); - dispatcher.join(); - - while (receivedRegular.load(std::memory_order_relaxed) < 5) { - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - } - std::cout << "done waiting" << std::endl; - REQUIRE(receivedRegular.load(std::memory_order_acquire) >= 5); - - command.command = mdp::Command::Unsubscribe; - client.request(command); - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - std::cout << "done Unsubscribe" << std::endl; - - client.stop(); - server.stop(); - eventDispatcher.send_event(fmt::format("test-event {}", updateCounter++)); - std::cout << "server stopped" << std::endl; -} - -} // namespace opencmw::rest_client_test diff --git a/src/client/test/nghttp2_tests.cpp b/src/client/test/nghttp2_tests.cpp new file mode 100644 index 00000000..9a2e341a --- /dev/null +++ b/src/client/test/nghttp2_tests.cpp @@ -0,0 +1,516 @@ +#include +#include + +#include "zmq.h" +#include + +#include +CMRC_DECLARE(assets); + +#include +#include +#include +#include +#include + +constexpr const char *testCertificate = R"( + R"( + GlobalSign Root CA + ================== + -----BEGIN CERTIFICATE----- + MIIDdTCCAl2gAwIBAgILBAAAAAABFUtaw5QwDQYJKoZIhvcNAQEFBQAwVzELMAkGA1UEBhMCQkUx + GTAXBgNVBAoTEEdsb2JhbFNpZ24gbnYtc2ExEDAOBgNVBAsTB1Jvb3QgQ0ExGzAZBgNVBAMTEkds + b2JhbFNpZ24gUm9vdCBDQTAeFw05ODA5MDExMjAwMDBaFw0yODAxMjgxMjAwMDBaMFcxCzAJBgNV + BAYTAkJFMRkwFwYDVQQKExBHbG9iYWxTaWduIG52LXNhMRAwDgYDVQQLEwdSb290IENBMRswGQYD + VQQDExJHbG9iYWxTaWduIFJvb3QgQ0EwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDa + DuaZjc6j40+Kfvvxi4Mla+pIH/EqsLmVEQS98GPR4mdmzxzdzxtIK+6NiY6arymAZavpxy0Sy6sc + THAHoT0KMM0VjU/43dSMUBUc71DuxC73/OlS8pF94G3VNTCOXkNz8kHp1Wrjsok6Vjk4bwY8iGlb + Kk3Fp1S4bInMm/k8yuX9ifUSPJJ4ltbcdG6TRGHRjcdGsnUOhugZitVtbNV4FpWi6cgKOOvyJBNP + c1STE4U6G7weNLWLBYy5d4ux2x8gkasJU26Qzns3dLlwR5EiUWMWea6xrkEmCMgZK9FGqkjWZCrX + gzT/LCrBbBlDSgeF59N89iFo7+ryUp9/k5DPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV + HRMBAf8EBTADAQH/MB0GA1UdDgQWBBRge2YaRQ2XyolQL30EzTSo//z9SzANBgkqhkiG9w0BAQUF + AAOCAQEA1nPnfE920I2/7LqivjTFKDK1fPxsnCwrvQmeU79rXqoRSLblCKOzyj1hTdNGCbM+w6Dj + Y1Ub8rrvrTnhQ7k4o+YviiY776BQVvnGCv04zcQLcFGUl5gE38NflNUVyRRBnMRddWQVDf9VMOyG + j/8N7yy5Y0b2qvzfvGn9LhJIZJrglfCm7ymPAbEVtQwdpf5pLGkkeB6zpxxxYu7KyJesF12KwvhH + hm4qxFYxldBniYUr+WymXUadDKqC5JlR3XC321Y9YeRq4VzW9v493kHMB65jUr9TU/Qr6cf9tveC + X4XSQRjbgbMEHMUfpIBvFSDJ3gyICh3WZlXi/EjJKSZp4A== + -----END CERTIFICATE----- + )"; + +class TestServerCertificates { + const cmrc::embedded_filesystem fileSystem = cmrc::assets::get_filesystem(); + const cmrc::file ca_certificate = fileSystem.open("/assets/ca-cert.pem"); + // server-req.pem -> is usually used to request for the CA signature + const cmrc::file server_cert = fileSystem.open("/assets/server-cert.pem"); + const cmrc::file server_key = fileSystem.open("/assets/server-key.pem"); + const cmrc::file client_cert = fileSystem.open("/assets/client-cert.pem"); + const cmrc::file client_key = fileSystem.open("/assets/client-key.pem"); + const cmrc::file pwd = fileSystem.open("/assets/password.txt"); + +public: + const std::string caCertificate = { ca_certificate.begin(), ca_certificate.end() }; + const std::string serverCertificate = { server_cert.begin(), server_cert.end() }; + const std::string serverKey = { server_key.begin(), server_key.end() }; + const std::string clientCertificate = { client_cert.begin(), client_cert.end() }; + const std::string clientKey = { client_key.begin(), client_key.end() }; + const std::string password = { pwd.begin(), pwd.end() }; +}; +inline static const TestServerCertificates testServerCertificates; + +using namespace opencmw; +using namespace opencmw::majordomo::detail::nghttp2; +using opencmw::URI; + +constexpr uint16_t kServerPort = 33339; + +void ensureMessageReceived(Http2Server &server, std::stop_token stopToken, std::deque &messages, std::chrono::milliseconds timeout = std::chrono::seconds(5)) { + const auto start = std::chrono::system_clock::now(); + while (!stopToken.stop_requested() && std::chrono::system_clock::now() - start < timeout) { + if (!messages.empty()) { + return; + } + + std::vector pollerItems; + server.populatePollerItems(pollerItems); + + const int rc = zmq_poll(pollerItems.data(), static_cast(pollerItems.size()), 500); + + if (rc == 0) { + continue; + } + + for (const auto &item : pollerItems) { + auto ms = server.processReadWrite(item.fd, item.revents & ZMQ_POLLIN, item.revents & ZMQ_POLLOUT); + messages.insert(messages.end(), ms.begin(), ms.end()); + } + } +}; + +bool waitFor(std::atomic &responseCount, int expected, std::chrono::milliseconds timeout = std::chrono::seconds(5)) { + const auto start = std::chrono::system_clock::now(); + while (responseCount.load() < expected && std::chrono::system_clock::now() - start < timeout) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + if (responseCount.load() < expected) { + FAIL(fmt::format("Expected {} responses, but got only {}\n", expected, responseCount.load())); + } + return responseCount.load() == expected; +} + +struct Stopper { + std::stop_source source; + explicit Stopper(std::stop_source s) + : source(s) {} + ~Stopper() { + source.request_stop(); + } +}; + +static std::string normalize(URI<> uri) { + return opencmw::mdp::Topic::fromMdpTopic(uri).toZmqTopic(); +} + +TEST_CASE("Basic Client Constructor and API Tests", "[http2]") { + using namespace opencmw::client; + + RestClient client1; + REQUIRE(client1.defaultMimeType() == MIME::JSON); + REQUIRE(client1.verifySslPeers() == true); + + auto client2 = RestClient(DefaultContentTypeHeader(MIME::HTML), ClientCertificates(testCertificate)); + REQUIRE(client2.defaultMimeType() == MIME::HTML); + REQUIRE(client2.verifySslPeers() == true); + + RestClient client3(DefaultContentTypeHeader(MIME::BINARY), VerifyServerCertificates(false)); + REQUIRE(client3.defaultMimeType() == MIME::BINARY); + REQUIRE(client3.verifySslPeers() == false); +} + +TEST_CASE("GET HTTP", "[http2]") { + using namespace opencmw::client; + + auto serverThread = std::jthread([](std::stop_token stopToken) { + Http2Server server; + REQUIRE(server.bind(kServerPort)); + + std::deque messages; + ensureMessageReceived(server, stopToken, messages); + REQUIRE(messages.size() >= 1); + const auto req0 = std::move(messages[0]); + messages.pop_front(); + REQUIRE(req0.command == mdp::Command::Get); + REQUIRE(req0.topic.path() == "/sayhello"); + REQUIRE(req0.error.empty()); + REQUIRE(req0.data.empty()); + + Message reply0; + reply0.command = mdp::Command::Final; + reply0.clientRequestID = req0.clientRequestID; + reply0.topic = URI<>("/sayhello"); + reply0.data = opencmw::IoBuffer("Hello, World!"); + server.handleResponse(std::move(reply0)); + + ensureMessageReceived(server, stopToken, messages); + }); + + // Client using plain http + RestClient http; + Stopper stopper(serverThread.get_stop_source()); + std::this_thread::sleep_for(std::chrono::milliseconds(300)); // give the server some time to start listening + + std::atomic responseCount = 0; + + client::Command req0; + req0.command = mdp::Command::Get; + req0.clientRequestID = opencmw::IoBuffer("0"); + req0.topic = URI<>(fmt::format("http://localhost:{}/sayhello?client=http", kServerPort)); + req0.callback = [&responseCount](const mdp::Message &msg) { + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString() == "Hello, World!"); + REQUIRE(msg.clientRequestID.asString() == "0"); + REQUIRE(msg.topic.path() == "/sayhello"); + responseCount++; + }; + http.request(std::move(req0)); + + // Client that verifies the server's certificate and trusts its CA + RestClient https(ClientCertificates(testServerCertificates.caCertificate)); + client::Command req1; + req1.command = mdp::Command::Get; + req1.topic = URI<>(fmt::format("https://localhost:{}/sayhello?client=https", kServerPort)); + req1.clientRequestID = opencmw::IoBuffer("0"); + req1.callback = [&responseCount](const mdp::Message &msg) { + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.protocolName == "https"); + REQUIRE(msg.error.contains("Could not connect to endpoint")); + REQUIRE(msg.data.asString() == ""); + REQUIRE(msg.clientRequestID.asString() == "0"); + REQUIRE(msg.topic.path() == "/sayhello"); + responseCount++; + }; + https.request(std::move(req1)); + + REQUIRE(waitFor(responseCount, 2)); +} + +TEST_CASE("HTTPS", "[http2]") { + using namespace opencmw::client; + + auto serverThread = std::jthread([&](std::stop_token stopToken) { + auto server = Http2Server::sslWithBuffers(testServerCertificates.serverCertificate, testServerCertificates.serverKey); + if (!server) { + FAIL(fmt::format("Failed to create server: {}", server.error())); + return; + } + + REQUIRE(server->bind(kServerPort)); + + std::deque messages; + for (int i = 0; i < 2; i++) { + ensureMessageReceived(server.value(), stopToken, messages); + REQUIRE(messages.size() >= 1); + const auto req0 = std::move(messages[0]); + messages.pop_front(); + REQUIRE(req0.command == mdp::Command::Get); + REQUIRE(req0.topic.path() == "/sayhello"); + REQUIRE(req0.error.empty()); + REQUIRE(req0.data.empty()); + + Message reply0; + reply0.command = mdp::Command::Final; + reply0.clientRequestID = req0.clientRequestID; + reply0.topic = URI<>("/sayhello"); + reply0.data = opencmw::IoBuffer("Hello, World!"); + server->handleResponse(std::move(reply0)); + } + + const auto start = std::chrono::system_clock::now(); + while (!stopToken.stop_requested() && std::chrono::system_clock::now() - start < std::chrono::seconds(5)) { + ensureMessageReceived(server.value(), stopToken, messages); + } + }); + + Stopper stopper(serverThread.get_stop_source()); + std::this_thread::sleep_for(std::chrono::milliseconds(300)); // give the server some time to start listening + + std::atomic responseCount = 0; + + // Client that doesn't verify the server's certificate + RestClient doesntCare(VerifyServerCertificates(false)); + client::Command req0; + req0.command = mdp::Command::Get; + req0.topic = URI<>(fmt::format("https://localhost:{}/sayhello?client=doesntCare", kServerPort)); + req0.clientRequestID = opencmw::IoBuffer("0"); + req0.callback = [&responseCount](const mdp::Message &msg) { + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.protocolName == "https"); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString() == "Hello, World!"); + REQUIRE(msg.clientRequestID.asString() == "0"); + REQUIRE(msg.topic == URI<>("/sayhello")); + responseCount++; + }; + doesntCare.request(std::move(req0)); + + // Client that verifies the server's certificate and trusts its CA + RestClient trusting(ClientCertificates(testServerCertificates.caCertificate)); + client::Command req1; + req1.command = mdp::Command::Get; + req1.topic = URI<>(fmt::format("https://localhost:{}/sayhello?client=trusting", kServerPort)); + req1.clientRequestID = opencmw::IoBuffer("0"); + req1.callback = [&responseCount](const mdp::Message &msg) { + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.protocolName == "https"); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString() == "Hello, World!"); + REQUIRE(msg.clientRequestID.asString() == "0"); + REQUIRE(msg.topic.path() == "/sayhello"); + responseCount++; + }; + trusting.request(std::move(req1)); + + // Client that verifies the server's certificate but doesn't trust its CA + RestClient notTrusting; + client::Command req2; + req2.command = mdp::Command::Get; + req2.topic = URI<>(fmt::format("https://localhost:{}/sayhello?client=notTrusting", kServerPort)); + req2.clientRequestID = opencmw::IoBuffer("0"); + req2.callback = [&responseCount](const mdp::Message &msg) { + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.protocolName == "https"); + REQUIRE(msg.error.contains("Could not connect to endpoint")); + REQUIRE(msg.clientRequestID.asString() == "0"); + REQUIRE(msg.topic.path() == "/sayhello"); + responseCount++; + }; + notTrusting.request(std::move(req2)); + + REQUIRE(waitFor(responseCount, 3)); +} + +TEST_CASE("GET/SET", "[http2]") { + auto serverThread = std::jthread([](std::stop_token stopToken) { + Http2Server server; + REQUIRE(server.bind(kServerPort)); + + std::deque messages; + ensureMessageReceived(server, stopToken, messages); + REQUIRE(messages.size() >= 1); + const auto req0 = std::move(messages[0]); + messages.pop_front(); + REQUIRE(req0.command == mdp::Command::Get); + REQUIRE(normalize(req0.topic) == normalize(URI<>("/foo?contentType=application%2Fjson¶m1=1¶m2=foo%2Fbar"))); + REQUIRE(req0.error.empty()); + REQUIRE(req0.data.empty()); + + Message reply0; + reply0.command = mdp::Command::Final; + reply0.clientRequestID = req0.clientRequestID; + reply0.topic = URI<>("/foo?param1=1¶m2=foo%2fbar"); + reply0.data = opencmw::IoBuffer("Hello, World!"); + server.handleResponse(std::move(reply0)); + + ensureMessageReceived(server, stopToken, messages); + REQUIRE(messages.size() >= 1); + const auto req1 = std::move(messages[0]); + messages.pop_front(); + REQUIRE(req1.command == mdp::Command::Get); + REQUIRE(normalize(req1.topic) == normalize(URI<>("/bar?contentType=application%2Fjson¶m1=1¶m2=2"))); + REQUIRE(req1.error.empty()); + REQUIRE(req1.data.empty()); + + Message reply1; + reply1.command = mdp::Command::Final; + reply1.clientRequestID = req1.clientRequestID; + reply1.topic = URI<>("/bar?param1=1¶m2=2"); + reply1.error = "'bar' not found"; + server.handleResponse(std::move(reply1)); + + ensureMessageReceived(server, stopToken, messages); + REQUIRE(messages.size() >= 1); + const auto req2 = std::move(messages[0]); + messages.pop_front(); + REQUIRE(req2.command == mdp::Command::Set); + REQUIRE(normalize(req2.topic) == normalize(URI<>("/setexample?contentType=application%2Fjson"))); + REQUIRE(req2.error.empty()); + REQUIRE(req2.data.asString() == "Some data with\ttabs\nand newlines\x01"); + + Message reply2; + reply2.command = mdp::Command::Final; + reply2.clientRequestID = req2.clientRequestID; + reply2.topic = URI<>("/setexample"); + reply2.data = opencmw::IoBuffer("value set"); + server.handleResponse(std::move(reply2)); + + ensureMessageReceived(server, stopToken, messages); // makes sure responses are sent + }); + + Stopper stopper(serverThread.get_stop_source()); + std::this_thread::sleep_for(std::chrono::milliseconds(300)); // give the server some time to start listening + + std::atomic responseCount = 0; + + client::RestClient client(client::VerifyServerCertificates(false)); + + opencmw::client::Command req0; + req0.command = mdp::Command::Get; + req0.clientRequestID = opencmw::IoBuffer("0"); + req0.topic = URI<>(fmt::format("http://localhost:{}/foo?param1=1¶m2=foo%2fbar", kServerPort)); + req0.callback = [&responseCount](const mdp::Message &msg) { + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.protocolName == "http"); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString() == "Hello, World!"); + REQUIRE(msg.clientRequestID.asString() == "0"); + REQUIRE(msg.topic == URI<>("/foo?param1=1¶m2=foo%2fbar")); + responseCount++; + }; + client.request(std::move(req0)); + + client::Command req1; + req1.command = mdp::Command::Get; + req1.clientRequestID = opencmw::IoBuffer("1"); + req1.topic = URI<>(fmt::format("http://localhost:{}/bar?param1=1¶m2=2", kServerPort)); + req1.callback = [&responseCount](const mdp::Message &msg) { + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.protocolName == "http"); + REQUIRE(msg.error == "'bar' not found"); + REQUIRE(msg.data.asString() == ""); + REQUIRE(msg.clientRequestID.asString() == "1"); + REQUIRE(msg.topic == URI<>("/bar?param1=1¶m2=2")); + responseCount++; + }; + client.request(std::move(req1)); + + client::Command req2; + req2.command = mdp::Command::Set; + req2.clientRequestID = opencmw::IoBuffer("2"); + req2.topic = URI<>(fmt::format("http://localhost:{}/setexample", kServerPort)); + req2.data = opencmw::IoBuffer("Some data with\ttabs\nand newlines\x01"); + req2.callback = [&responseCount](const mdp::Message &msg) { + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.protocolName == "http"); + REQUIRE(msg.clientRequestID.asString() == "2"); + REQUIRE(msg.topic == URI<>("/setexample")); + REQUIRE(msg.data.asString() == "value set"); + responseCount++; + }; + client.request(std::move(req2)); + + waitFor(responseCount, 3); +} + +TEST_CASE("Long polling example", "[http2]") { + constexpr int kFooMessages = 50; + + auto brokerThread = std::jthread([](std::stop_token stopToken) { + Http2Server server; + REQUIRE(server.bind(kServerPort)); + + const auto topic = URI<>("/foo?param1=1¶m2=foo%2Fbar"); + + std::deque messages; + ensureMessageReceived(server, stopToken, messages); + REQUIRE(messages.size() >= 1); + const auto req0 = std::move(messages[0]); + messages.pop_front(); + REQUIRE(req0.command == mdp::Command::Subscribe); + REQUIRE(normalize(req0.topic) == normalize(topic)); + REQUIRE(req0.error.empty()); + REQUIRE(req0.data.empty()); + + for (int i = 0; i < kFooMessages; ++i) { + Message notify; + notify.command = mdp::Command::Notify; + notify.serviceName = "/foo"; + notify.topic = topic; + auto data = std::to_string(i); + notify.data = opencmw::IoBuffer(data.data(), data.size()); + server.handleNotification(opencmw::mdp::Topic::fromMdpTopic(topic), std::move(notify)); + } + + ensureMessageReceived(server, stopToken, messages, std::chrono::seconds(10)); + + // sync with the client to ensure the client unsubscribed before we send more notifications + REQUIRE(messages.size() >= 1); + const auto req1 = std::move(messages[0]); + messages.pop_front(); + REQUIRE(req1.command == mdp::Command::Get); + Message reply1; + reply1.command = mdp::Command::Final; + reply1.clientRequestID = req1.clientRequestID; + reply1.topic = topic; + reply1.data = opencmw::IoBuffer("Hello"); + server.handleResponse(std::move(reply1)); + + // send notifications that should not be received + for (int i = 0; i < kFooMessages; ++i) { + Message notify; + notify.command = mdp::Command::Notify; + notify.serviceName = "/foo"; + notify.topic = topic; + auto data = std::to_string(kFooMessages + i); + notify.data = opencmw::IoBuffer(data.data(), data.size()); + server.handleNotification(opencmw::mdp::Topic::fromMdpTopic(topic), std::move(notify)); + } + + ensureMessageReceived(server, stopToken, messages, std::chrono::seconds(10)); // makes sure messages are processed + }); + + Stopper stopper(brokerThread.get_stop_source()); + std::this_thread::sleep_for(std::chrono::milliseconds(300)); // give the server some time to start listening + + { + std::atomic responseCount = 0; + + client::RestClient client; + opencmw::client::Command sub; + sub.command = mdp::Command::Subscribe; + sub.clientRequestID = opencmw::IoBuffer("0"); + sub.topic = URI<>(fmt::format("http://localhost:{}/foo?param1=1¶m2=foo%2fbar", kServerPort)); + sub.callback = [&responseCount](const mdp::Message &msg) { + REQUIRE(msg.command == mdp::Command::Notify); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString() == std::to_string(responseCount)); + REQUIRE(msg.protocolName == "http"); + REQUIRE(normalize(msg.topic) == normalize(URI<>("/foo?param1=1¶m2=foo%2Fbar"))); + responseCount++; + }; + client.request(std::move(sub)); + waitFor(responseCount, kFooMessages); + responseCount = 0; + + // unsubscribe + client::Command unsub; + unsub.command = mdp::Command::Unsubscribe; + unsub.clientRequestID = opencmw::IoBuffer("0"); + unsub.topic = URI<>(fmt::format("http://localhost:{}/foo?param1=1¶m2=foo%2fbar", kServerPort)); + unsub.callback = [](const mdp::Message &) { + FAIL("Unsubscribe should not receive a message"); + }; + client.request(std::move(unsub)); + + // Send a GET to sync with server + client::Command get; + get.command = mdp::Command::Get; + get.clientRequestID = opencmw::IoBuffer("1"); + get.topic = URI<>(fmt::format("http://localhost:{}/foo?param1=1¶m2=foo%2fbar", kServerPort)); + get.callback = [&responseCount](const mdp::Message &msg) { + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.protocolName == "http"); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString() == "Hello"); + REQUIRE(normalize(msg.topic) == normalize(URI<>("/foo?param1=1¶m2=foo%2Fbar"))); + responseCount++; + }; + client.request(std::move(get)); + + waitFor(responseCount, 1); + responseCount -= 1; + + // wait some time to make sure no more messages are received + std::this_thread::sleep_for(std::chrono::milliseconds(250)); + REQUIRE(responseCount == 0); + } +} diff --git a/src/core/include/LoadTest.hpp b/src/core/include/LoadTest.hpp new file mode 100644 index 00000000..b886193e --- /dev/null +++ b/src/core/include/LoadTest.hpp @@ -0,0 +1,37 @@ +#ifndef OPENCMW_CPP_LOAD_TEST_HPP +#define OPENCMW_CPP_LOAD_TEST_HPP + +#include + +#include + +#include +#include + +namespace opencmw::load_test { + +struct Context { + std::string topic; + std::int64_t intervalMs = 1000; // must be multiple of 10 (enforce?) + std::int64_t payloadSize = 100; + std::int64_t nUpdates = -1; // -1 means infinite + std::int64_t initialDelayMs = 0; + opencmw::MIME::MimeType contentType = opencmw::MIME::BINARY; +}; + +struct Payload { + std::int64_t index; + std::string data; + std::int64_t timestampNs = 0; +}; + +inline std::chrono::nanoseconds timestamp() { + return std::chrono::duration_cast(std::chrono::high_resolution_clock::now().time_since_epoch()); +} + +} // namespace opencmw::load_test + +ENABLE_REFLECTION_FOR(opencmw::load_test::Context, topic, intervalMs, payloadSize, nUpdates, initialDelayMs, contentType) +ENABLE_REFLECTION_FOR(opencmw::load_test::Payload, index, data, timestampNs) + +#endif diff --git a/src/core/include/TimingCtx.hpp b/src/core/include/TimingCtx.hpp index 50bf3642..30c7251f 100644 --- a/src/core/include/TimingCtx.hpp +++ b/src/core/include/TimingCtx.hpp @@ -157,7 +157,7 @@ class TimingCtx { if (WILDCARD != valueString) { int32_t intValue = 0; - if (const auto result = std::from_chars(valueString.begin(), valueString.end(), intValue); result.ec == std::errc::invalid_argument) { + if (const auto result = std::from_chars(valueString.data(), valueString.data() + valueString.size(), intValue); result.ec == std::errc::invalid_argument) { _hash = 0; throw std::invalid_argument(fmt::format("Value: '{}' in '{}' is not a valid integer", valueString, tag)); } diff --git a/src/core/include/Topic.hpp b/src/core/include/Topic.hpp index 35fafac3..a3a7fc27 100644 --- a/src/core/include/Topic.hpp +++ b/src/core/include/Topic.hpp @@ -67,6 +67,7 @@ struct Topic { Params _params; public: + Topic() = default; Topic(const Topic &other) = default; Topic &operator=(const Topic &) = default; Topic(Topic &&) noexcept = default; @@ -167,7 +168,7 @@ struct Topic { , _params(std::move(params)) { if (serviceOrServiceAndQuery.find("?") != std::string::npos) { if (!_params.empty()) { - throw std::invalid_argument(fmt::format("Parameters are not empty ({}), and there are more in the service string ({})\n", _params, serviceOrServiceAndQuery)); + throw std::invalid_argument(fmt::format("Parameters are not empty ({}), and there are more in the service string ({})", _params, serviceOrServiceAndQuery)); } const auto parsed = opencmw::URI(std::string(serviceOrServiceAndQuery)); const auto &queryMap = parsed.queryParamMap(); @@ -175,7 +176,7 @@ struct Topic { } if (!isValidServiceName(_service)) { - throw std::invalid_argument(fmt::format("Invalid service name '{}'\n", _service)); + throw std::invalid_argument(fmt::format("Invalid service name '{}'", _service)); } } }; diff --git a/src/majordomo/CMakeLists.txt b/src/majordomo/CMakeLists.txt index 61e3a8c7..d3ebee60 100644 --- a/src/majordomo/CMakeLists.txt +++ b/src/majordomo/CMakeLists.txt @@ -1,26 +1,28 @@ # setup header only library add_library(majordomo INTERFACE - include/majordomo/Broker.hpp - include/majordomo/MockClient.hpp - include/majordomo/Constants.hpp - include/majordomo/Cryptography.hpp - include/majordomo/Rbac.hpp - include/majordomo/RestBackend.hpp - include/majordomo/Settings.hpp - include/majordomo/SubscriptionMatcher.hpp - include/majordomo/Worker.hpp + include/majordomo/Broker.hpp + include/majordomo/MockClient.hpp + include/majordomo/Constants.hpp + include/majordomo/Cryptography.hpp + include/majordomo/Http2Server.hpp + include/majordomo/Rbac.hpp + include/majordomo/Rest.hpp + include/majordomo/Settings.hpp + include/majordomo/SubscriptionMatcher.hpp + include/majordomo/Worker.hpp ) target_include_directories(majordomo INTERFACE $ $) target_link_libraries(majordomo - INTERFACE - core - serialiser - zmq - httplib::httplib - #OpenSSL::SSL - pthread - sodium - ) + INTERFACE + core + serialiser + zmq + OpenSSL::SSL + OpenSSL::Crypto + nghttp2 + pthread + sodium +) install( TARGETS majordomo diff --git a/src/majordomo/include/majordomo/Broker.hpp b/src/majordomo/include/majordomo/Broker.hpp index cce9b773..6798b540 100644 --- a/src/majordomo/include/majordomo/Broker.hpp +++ b/src/majordomo/include/majordomo/Broker.hpp @@ -15,7 +15,9 @@ #include +#include "Http2Server.hpp" #include "Rbac.hpp" +#include "Rest.hpp" #include "Topic.hpp" #include "URI.hpp" @@ -180,13 +182,16 @@ class Broker { using Timestamp = std::chrono::time_point; struct Client { - const zmq::Socket &socket; - const std::string id; - std::deque requests; - Timestamp expiry; + std::optional> socket; + const std::string id; + std::deque requests; + Timestamp expiry; explicit Client(const zmq::Socket &s, const std::string &id_, Timestamp expiry_) : socket(s), id(std::move(id_)), expiry{ std::move(expiry_) } {} + + explicit Client(const std::string &id_) + : id(std::move(id_)), expiry{} {} }; struct Worker { @@ -260,6 +265,7 @@ class Broker { const std::string brokerName; private: + std::unique_ptr _restServer; Timestamp _heartbeatAt = Clock::now() + settings.heartbeatInterval; SubscriptionMatcher _subscriptionMatcher; std::unordered_map> _subscribedClientsByTopic; // topic -> client IDs @@ -277,11 +283,13 @@ class Broker { std::atomic _shutdownRequested = false; // Sockets collection. The Broker class will be used as the handler - const zmq::Socket _routerSocket; - const zmq::Socket _pubSocket; - const zmq::Socket _subSocket; - const zmq::Socket _dnsSocket; - std::array pollerItems; + const zmq::Socket _routerSocket; + const zmq::Socket _pubSocket; + const zmq::Socket _subSocket; + const zmq::Socket _dnsSocket; + static constexpr std::size_t kNumberOfZmqSockets = 4; + static constexpr std::string_view kRestSourceId = "rest"; // assuming that ZMQ never assigns this ID. Alternatively connect an idle client and use its ID. + std::vector _pollerItems = std::vector(kNumberOfZmqSockets); public: Broker(std::string brokerName_, Settings settings_ = {}) @@ -393,14 +401,14 @@ class Broker { zmq::invoke(zmq_connect, _dnsSocket, INTERNAL_ADDRESS_BROKER.str().data()).assertSuccess(); } - pollerItems[0].socket = _routerSocket.zmq_ptr; - pollerItems[0].events = ZMQ_POLLIN; - pollerItems[1].socket = _pubSocket.zmq_ptr; - pollerItems[1].events = ZMQ_POLLIN; - pollerItems[2].socket = _subSocket.zmq_ptr; - pollerItems[2].events = ZMQ_POLLIN; - pollerItems[3].socket = _dnsSocket.zmq_ptr; - pollerItems[3].events = ZMQ_POLLIN; + _pollerItems[0].socket = _routerSocket.zmq_ptr; + _pollerItems[0].events = ZMQ_POLLIN; + _pollerItems[1].socket = _pubSocket.zmq_ptr; + _pollerItems[1].events = ZMQ_POLLIN; + _pollerItems[2].socket = _subSocket.zmq_ptr; + _pollerItems[2].events = ZMQ_POLLIN; + _pollerItems[3].socket = _dnsSocket.zmq_ptr; + _pollerItems[3].events = ZMQ_POLLIN; } Broker(const Broker &) = delete; @@ -426,6 +434,7 @@ class Broker { */ std::optional> bind(const URI &endpoint, BindOption option = BindOption::DetectFromURI) { assert(!(option == BindOption::DetectFromURI && (endpoint.scheme() == SCHEME_INPROC || endpoint.scheme() == SCHEME_TCP))); + const auto isRouterSocket = option != BindOption::Pub && (option == BindOption::Router || endpoint.scheme() == SCHEME_MDP || endpoint.scheme() == SCHEME_TCP); const auto zmqEndpoint = mdp::toZeroMQEndpoint(endpoint); @@ -445,6 +454,30 @@ class Broker { return adjustedAddressPublic; } + std::expected bindRest(rest::Settings restSettings) { + if (restSettings.certificateFilePath.empty() != restSettings.keyFilePath.empty()) { + return std::unexpected("Provide both certificate and key file paths (for HTTPS) or none (for HTTP)"); + } + if (restSettings.certificateFileBuffer.empty() != restSettings.keyFileBuffer.empty()) { + return std::unexpected("Provide both certificate and key file paths (for HTTPS) or none (for HTTP)"); + } + + std::expected maybeServer; + if (!restSettings.certificateFilePath.empty()) { + maybeServer = detail::nghttp2::Http2Server::sslWithPaths(std::move(restSettings.certificateFilePath), std::move(restSettings.keyFilePath)); + } else if (!restSettings.certificateFileBuffer.empty()) { + maybeServer = detail::nghttp2::Http2Server::sslWithBuffers(std::move(restSettings.certificateFileBuffer), std::move(restSettings.keyFileBuffer)); + } else { + maybeServer = detail::nghttp2::Http2Server::unencrypted(); + } + if (!maybeServer) { + return std::unexpected(maybeServer.error()); + } + _restServer = std::make_unique(std::move(maybeServer.value())); + _restServer->setHandlers(std::move(restSettings.handlers)); + return _restServer->bind(restSettings.port); + } + void run() { sendDnsHeartbeats(true); // initial register of default routes @@ -461,6 +494,17 @@ class Broker { // test interface bool processMessages() { + for (std::size_t i = kNumberOfZmqSockets; i < _pollerItems.size(); ++i) { + const auto read = _pollerItems[i].revents & ZMQ_POLLIN; + const auto write = _pollerItems[i].revents & ZMQ_POLLOUT; + if (read || write) { + auto messages = _restServer->processReadWrite(_pollerItems[i].fd, read, write); + for (auto &message : messages) { + handleRestMessage(std::move(message)); + } + } + } + bool anythingReceived; int loopCount = 0; do { @@ -480,9 +524,13 @@ class Broker { loopCount++; } while (anythingReceived); + _pollerItems.resize(kNumberOfZmqSockets); + if (_restServer) { + _restServer->populatePollerItems(_pollerItems); + } + // N.B. block until data arrived or for at most one heart-beat interval - const auto result = zmq::invoke(zmq_poll, pollerItems.data(), static_cast(pollerItems.size()), settings.heartbeatInterval.count()); - return result.isValid(); + return zmq::invoke(zmq_poll, _pollerItems.data(), static_cast(_pollerItems.size()), settings.heartbeatInterval.count()).isValid(); } void cleanup() { @@ -563,7 +611,53 @@ class Broker { return true; } - bool receiveMessage(const zmq::Socket &socket) { + void handleRestMessage(BrokerMessage &&message) { + message.sourceId = std::string(kRestSourceId); + message.protocolName = mdp::clientProtocol; + switch (message.command) { + case mdp::Command::Get: + case mdp::Command::Set: { + auto [client, inserted] = _clients.try_emplace(message.sourceId, message.sourceId); + client->second.requests.emplace_back(std::move(message)); + } break; + case mdp::Command::Subscribe: { + mdp::Topic subscription; + try { + subscription = mdp::Topic::fromMdpTopic(message.topic); + } catch (...) { + // malformed topic, ignore + return; + } + + subscribe(subscription); + auto [it, inserted] = _subscribedClientsByTopic.try_emplace(subscription); + it->second.emplace(message.sourceId); + } break; + case mdp::Command::Unsubscribe: { + mdp::Topic subscription; + try { + subscription = mdp::Topic::fromMdpTopic(message.topic); + } catch (...) { + // malformed topic, ignore + return; + } + + unsubscribe(subscription); + auto it = _subscribedClientsByTopic.find(subscription); + if (it != _subscribedClientsByTopic.end()) { + it->second.erase(message.sourceId); + if (it->second.empty()) { + _subscribedClientsByTopic.erase(it); + } + } + } break; + default: + break; + } + } + + bool + receiveMessage(const zmq::Socket &socket) { auto maybeMessage = zmq::receive(socket); if (!maybeMessage) { @@ -708,9 +802,15 @@ class Broker { const auto it = _subscribedClientsByTopic.find(topic); if (it != _subscribedClientsByTopic.end()) { for (const auto &clientId : it->second) { - auto clientCopy = message; - clientCopy.sourceId = clientId; - zmq::send(std::move(clientCopy), _routerSocket).assertSuccess(); + auto clientCopy = message; + if (clientId == kRestSourceId) { + if (_restServer) { + _restServer->handleNotification(it->first, std::move(clientCopy)); + } + } else { + clientCopy.sourceId = clientId; + zmq::send(std::move(clientCopy), _routerSocket).assertSuccess(); + } } } } @@ -749,13 +849,19 @@ class Broker { if (client.requests.empty()) continue; - auto clientMessage = std::move(client.requests.back()); - client.requests.pop_back(); + auto clientMessage = std::move(client.requests.front()); + client.requests.pop_front(); if (auto service = bestMatchingService(clientMessage.serviceName)) { if (service->internalHandler) { auto reply = service->internalHandler(std::move(clientMessage)); - zmq::send(std::move(reply), client.socket).assertSuccess(); + if (client.id == kRestSourceId) { + if (_restServer) { + _restServer->handleResponse(std::move(reply)); + } + } else { + zmq::send(std::move(reply), client.socket.value()).assertSuccess(); + } } else { service->putMessage(std::move(clientMessage)); dispatch(*service); @@ -772,7 +878,13 @@ class Broker { reply.error = fmt::format("unknown service (error 501): '{}'", reply.serviceName); reply.rbac = _rbac; - zmq::send(std::move(reply), client.socket).assertSuccess(); + if (client.id == kRestSourceId) { + if (_restServer) { + _restServer->handleResponse(std::move(reply)); + } + } else { + zmq::send(std::move(reply), client.socket.value()).assertSuccess(); + } } } @@ -785,7 +897,7 @@ class Broker { const auto isExpired = [&now](const auto &c) { auto &[senderId, client] = c; - return client.expiry < now; + return senderId != kRestSourceId && client.expiry < now; }; std::erase_if(_clients, isExpired); @@ -875,7 +987,14 @@ class Broker { message.sourceId = message.serviceName; // serviceName=clientSourceID message.serviceName = worker.serviceName; message.protocolName = mdp::clientProtocol; - zmq::send(std::move(message), client->second.socket).assertSuccess(); + if (message.sourceId == kRestSourceId) { + if (_restServer) { + _restServer->handleResponse(std::move(message)); + } + } else { + zmq::send(std::move(message), client->second.socket.value()).assertSuccess(); + } + workerWaiting(worker); } else { disconnectWorker(worker); diff --git a/src/majordomo/include/majordomo/Http2Server.hpp b/src/majordomo/include/majordomo/Http2Server.hpp new file mode 100644 index 00000000..e435992b --- /dev/null +++ b/src/majordomo/include/majordomo/Http2Server.hpp @@ -0,0 +1,927 @@ +#ifndef OPENCMW_MAJORDOMO_HTTP2SERVER_HPP +#define OPENCMW_MAJORDOMO_HTTP2SERVER_HPP + +#include "IoBuffer.hpp" +#include "LoadTest.hpp" +#include "MdpMessage.hpp" +#include "MIME.hpp" +#include "nghttp2/NgHttp2Utils.hpp" +#include "Rest.hpp" +#include "Topic.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +#include + +namespace opencmw::majordomo::detail::nghttp2 { + +using namespace opencmw::nghttp2; +using namespace opencmw::nghttp2::detail; + +inline int alpn_select_proto_cb(SSL *ssl, const unsigned char **out, + unsigned char *outlen, const unsigned char *in, + unsigned int inlen, void *arg) { + int rv; + (void) ssl; + (void) arg; + + rv = nghttp2_select_alpn(out, outlen, in, inlen); + if (rv != 1) { + return SSL_TLSEXT_ERR_NOACK; + } + + return SSL_TLSEXT_ERR_OK; +} + +inline std::expected create_ssl_ctx(EVP_PKEY *key, X509 *cert) { + auto ssl_ctx = SSL_CTX_Ptr(SSL_CTX_new(TLS_server_method()), SSL_CTX_free); + if (!ssl_ctx) { + return std::unexpected(fmt::format("Could not create SSL/TLS context: {}", ERR_error_string(ERR_get_error(), nullptr))); + } + SSL_CTX_set_options(ssl_ctx.get(), SSL_OP_ALL | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | SSL_OP_NO_COMPRESSION | SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); + if (SSL_CTX_set1_curves_list(ssl_ctx.get(), "P-256") != 1) { + return std::unexpected(fmt::format("SSL_CTX_set1_curves_list failed: {}", ERR_error_string(ERR_get_error(), nullptr))); + } + + if (SSL_CTX_use_PrivateKey(ssl_ctx.get(), key) <= 0) { + return std::unexpected(fmt::format("Could not configure private key")); + } + if (SSL_CTX_use_certificate(ssl_ctx.get(), cert) != 1) { + return std::unexpected(fmt::format("Could not configure certificate file")); + } + + if (!SSL_CTX_check_private_key(ssl_ctx.get())) { + return std::unexpected("Private key does not match the certificate"); + } + + SSL_CTX_set_alpn_select_cb(ssl_ctx.get(), alpn_select_proto_cb, nullptr); + + return ssl_ctx; +} + +using Message = mdp::BasicMessage; + +enum class RestMethod { + Get, + LongPoll, + Post, + Invalid +}; + +inline RestMethod parseMethod(std::string_view methodString) { + using enum RestMethod; + return methodString == "POLL" ? LongPoll + : methodString == "PUT" ? Post + : methodString == "POST" ? Post + : methodString == "GET" ? Get + : Invalid; +} + +struct Request { + std::vector> rawHeaders; + mdp::Topic topic; + RestMethod method = RestMethod::Invalid; + std::string longPollIndex; + std::string contentType; + std::string accept; + std::string payload; + bool complete = false; + + std::string_view acceptedMime() const { + static constexpr auto acceptableMimeTypes = std::array{ + opencmw::MIME::JSON.typeName(), MIME::HTML.typeName(), MIME::BINARY.typeName() + }; + auto accepted = [](auto format) { + const auto it = std::find(acceptableMimeTypes.begin(), acceptableMimeTypes.end(), format); + return std::make_pair( + it != acceptableMimeTypes.cend(), + it); + }; + + if (!contentType.empty()) { + if (const auto [found, where] = accepted(contentType); found) { + return *where; + } + } + if (auto it = topic.params().find("contentType"); it != topic.params().end()) { + if (const auto [found, where] = accepted(it->second); found) { + return *where; + } + } + + auto isDelimiter = [](char c) { return c == ' ' || c == ','; }; + auto from = accept.cbegin(); + const auto end = accept.cend(); + + while (from != end) { + from = std::find_if_not(from, end, isDelimiter); + auto to = std::find_if(from, end, isDelimiter); + if (from != end) { + std::string_view format(from, to); + if (const auto [found, where] = accepted(format); found) { + return *where; + } + } + + from = to; + } + + return acceptableMimeTypes[0]; + } +}; + +struct ResponseData { + explicit ResponseData(Message &&m) + : message(std::move(m)) + , errorBuffer(message.error.data(), message.error.size()) {} + + explicit ResponseData(rest::Response &&r) + : restResponse(std::move(r)) {} + + Message message; + IoBuffer errorBuffer; + + rest::Response restResponse; +}; + +struct IdGenerator { + std::uint64_t _nextRequestId = 0; + + std::uint64_t generateId() { + return _nextRequestId++; + } +}; + +struct SubscriptionCacheEntry { + constexpr static std::size_t kCapacity = 100; + std::uint64_t firstIndex = 0; + std::deque messages; + + void add(Message &&message) { + if (messages.size() == kCapacity) { + messages.pop_front(); + firstIndex++; + } + messages.push_back(std::move(message)); + } + std::uint64_t lastIndex() const noexcept { + assert(!messages.empty()); + return firstIndex + messages.size() - 1; + } + std::uint64_t nextIndex() const noexcept { + return firstIndex + messages.size(); + } +}; + +struct SharedData { + std::map _subscriptionCache; + std::vector _handlers; + + rest::Handler *findHandler(std::string_view method, std::string_view path) { + auto bestMatch = _handlers.end(); + + for (auto itHandler = _handlers.begin(); itHandler != _handlers.end(); ++itHandler) { + if (itHandler->method != method) { + continue; + } + + std::string_view handlerPath = itHandler->path; + + if (handlerPath == path) { + // exact match, use this handler + return &*itHandler; + } + // if the handler path ends with '*', do a prefix check and use the most specific (longest) one + if (handlerPath.ends_with("*")) { + handlerPath.remove_suffix(1); + if (path.starts_with(handlerPath) && (bestMatch == _handlers.end() || bestMatch->path.size() < itHandler->path.size())) { + bestMatch = itHandler; + } + } + } + + return bestMatch != _handlers.end() ? &*bestMatch : nullptr; + } +}; + +constexpr int kHttpOk = 200; +constexpr int kHttpError = 500; +constexpr int kHttpTimeout = 504; +constexpr int kFileNotFound = 404; + +struct Session { + using PendingRequest = std::tuple; // requestId, streamId + using PendingPoll = std::tuple; // zmqTopic, PollingIndex, streamId + TcpSocket _socket; + nghttp2_session *_session = nullptr; + WriteBuffer<4096> _writeBuffer; + std::map _requestsByStreamId; + std::map _responsesByStreamId; + std::vector _pendingRequests; + std::vector _pendingPolls; + std::shared_ptr _sharedData; + + explicit Session(TcpSocket &&socket, std::shared_ptr sharedData) + : _socket(std::move(socket)), _sharedData(std::move(sharedData)) { + nghttp2_session_callbacks *callbacks; + nghttp2_session_callbacks_new(&callbacks); + nghttp2_session_callbacks_set_on_frame_recv_callback(callbacks, [](nghttp2_session *, const nghttp2_frame *frame, void *user_data) { + auto session = static_cast(user_data); + return session->frame_recv_callback(frame); + }); + nghttp2_session_callbacks_set_on_frame_send_callback(callbacks, [](nghttp2_session *, const nghttp2_frame *frame, void *user_data) { + auto session = static_cast(user_data); + return session->frame_send_callback(frame); + }); + nghttp2_session_callbacks_set_on_frame_not_send_callback(callbacks, [](nghttp2_session *, const nghttp2_frame *frame, int lib_error_code, void *user_data) { + auto session = static_cast(user_data); + return session->frame_not_send_callback(frame, lib_error_code); + }); + nghttp2_session_callbacks_set_on_data_chunk_recv_callback(callbacks, [](nghttp2_session *, uint8_t flags, int32_t stream_id, const uint8_t *data, size_t len, void *user_data) { + auto session = static_cast(user_data); + return session->data_chunk_recv_callback(flags, stream_id, { reinterpret_cast(data), len }); + }); + nghttp2_session_callbacks_set_on_stream_close_callback(callbacks, [](nghttp2_session *, int32_t stream_id, uint32_t error_code, void *user_data) { + auto session = static_cast(user_data); + return session->stream_closed_callback(stream_id, error_code); + }); + nghttp2_session_callbacks_set_on_header_callback2(callbacks, [](nghttp2_session *, const nghttp2_frame *frame, nghttp2_rcbuf *name, + nghttp2_rcbuf *value, uint8_t flags, void *user_data) { + auto session = static_cast(user_data); + return session->header_callback(frame, as_view(name), as_view(value), flags); + }); + nghttp2_session_callbacks_set_on_invalid_frame_recv_callback(callbacks, [](nghttp2_session *, const nghttp2_frame *frame, int lib_error_code, void *user_data) { + auto session = static_cast(user_data); + return session->invalid_frame_recv_callback(frame, lib_error_code); + }); + nghttp2_session_callbacks_set_error_callback2(callbacks, [](nghttp2_session *, int lib_error_code, const char *msg, size_t len, void *user_data) { + auto session = static_cast(user_data); + return session->error_callback(lib_error_code, msg, len); + }); + nghttp2_session_server_new(&_session, callbacks, this); + nghttp2_session_callbacks_del(callbacks); + } + + ~Session() { + nghttp2_session_del(_session); + } + + Session(const Session &) = delete; + Session &operator=(const Session &) = delete; + Session(Session &&other) = delete; + Session &operator=(Session &&other) = delete; + + bool wantsToRead() const { + return _socket._state == TcpSocket::Connected ? nghttp2_session_want_read(_session) : (_socket._state == TcpSocket::SSLAcceptWantsRead); + } + + bool wantsToWrite() const { + return _socket._state == TcpSocket::Connected ? _writeBuffer.wantsToWrite(_session) : (_socket._state == TcpSocket::SSLAcceptWantsWrite); + } + + std::optional processGetSetRequest(std::int32_t streamId, Request &request, IdGenerator &idGenerator) { + Message result; + request.topic.addParam("contentType", request.acceptedMime()); + result.command = request.method == RestMethod::Get ? mdp::Command::Get : mdp::Command::Set; + result.serviceName = request.topic.service(); + result.topic = request.topic.toMdpTopic(); + result.data = IoBuffer(request.payload.data(), request.payload.size()); + auto id = idGenerator.generateId(); + result.clientRequestID = IoBuffer(std::to_string(id).data(), std::to_string(id).size()); + _pendingRequests.emplace_back(id, streamId); + return result; + }; + + static auto ioBufferCallback() { + return [](nghttp2_session *, int32_t /*stream_id*/, uint8_t *buf, size_t length, uint32_t *data_flags, nghttp2_data_source *source, void * /*user_data*/) { + auto ioBuffer = static_cast(source->ptr); + size_t copy_len = std::min(length, ioBuffer->size() - ioBuffer->position()); + std::copy(ioBuffer->data() + ioBuffer->position(), ioBuffer->data() + ioBuffer->position() + copy_len, buf); + ioBuffer->skip(static_cast(copy_len)); + if (ioBuffer->position() == ioBuffer->size()) { + *data_flags |= NGHTTP2_DATA_FLAG_EOF; + } + return static_cast(copy_len); + }; + } + + void sendResponse(std::int32_t streamId, rest::Response response) { + // store message while sending so we don't need to copy the data + auto &msg = _responsesByStreamId.try_emplace(streamId, ResponseData{ std::move(response) }).first->second; + + constexpr auto noCopy = NGHTTP2_NV_FLAG_NO_COPY_NAME | NGHTTP2_NV_FLAG_NO_COPY_VALUE; + const auto statusStr = std::to_string(msg.restResponse.code); + + std::vector headers; + headers.reserve(msg.restResponse.headers.size() + 2); + // :status must go first, otherwise browsers and curl will not accept the response + headers.push_back(nv(u8span(":status"), u8span(statusStr), NGHTTP2_NV_FLAG_NO_COPY_NAME)); + headers.push_back(nv(u8span("access-control-allow-origin"), u8span("*"), noCopy)); + + for (const auto &[name, value] : msg.restResponse.headers) { + headers.push_back(nv(u8span(name), u8span(value), noCopy)); + } + + nghttp2_data_provider2 data_prd; + + if (msg.restResponse.bodyReader) { + data_prd.source.ptr = &msg.restResponse; + data_prd.read_callback = [](nghttp2_session *, int32_t stream_id, uint8_t *buf, size_t length, uint32_t *data_flags, nghttp2_data_source *source, void * /*user_data*/) -> ssize_t { + std::ignore = stream_id; + auto res = static_cast(source->ptr); + const auto r = res->bodyReader(std::span(buf, length)); + if (!r) { + HTTP_DBG("Server: stream_id={} Error reading body: {}", stream_id, r.error()); + return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE; + } + const auto &[bytesRead, hasMore] = *r; + if (!hasMore) { + *data_flags |= NGHTTP2_DATA_FLAG_EOF; + } + return static_cast(bytesRead); + }; + } else { + data_prd.source.ptr = &msg.restResponse.body; + data_prd.read_callback = ioBufferCallback(); + } + +#ifdef OPENCMW_DEBUG_HTTP + auto formattedHeaders = headers | std::views::transform([](const auto &header) { + return fmt::format("'{}'='{}'", std::string_view(reinterpret_cast(header.name), header.namelen), std::string_view(reinterpret_cast(header.value), header.valuelen)); + }); + HTTP_DBG("Sending response {} to streamId {}. Headers:\n{}\n Body: {}", msg.restResponse.code, streamId, fmt::join(formattedHeaders, "\n"), msg.restResponse.bodyReader ? "reader" : fmt::format("{} bytes", msg.restResponse.body.size())); +#endif + if (auto rc = nghttp2_submit_response2(_session, streamId, headers.data(), headers.size(), &data_prd); rc != 0) { + HTTP_DBG("Server: nghttp2_submit_response2 for stream ID {} failed: {}", streamId, nghttp2_strerror(rc)); + _responsesByStreamId.erase(streamId); + } + } + + void sendResponse(std::int32_t streamId, int responseCode, Message &&responseMessage, std::vector extraHeaders = {}) { + // store message while sending so we don't need to copy the data + auto &msg = _responsesByStreamId.try_emplace(streamId, ResponseData{ std::move(responseMessage) }).first->second; + IoBuffer *buf = msg.errorBuffer.empty() ? &msg.message.data : &msg.errorBuffer; + + auto codeStr = std::to_string(responseCode); + auto contentLength = std::to_string(buf->size()); + constexpr int noCopy = NGHTTP2_NV_FLAG_NO_COPY_NAME | NGHTTP2_NV_FLAG_NO_COPY_VALUE; + // :status must go first + auto headers = std::vector{ nv(u8span(":status"), u8span(codeStr)), nv(u8span("x-opencmw-topic"), u8span(msg.message.topic.str()), noCopy), + nv(u8span("x-opencmw-service-name"), u8span(msg.message.serviceName), noCopy), nv(u8span("access-control-allow-origin"), u8span("*"), noCopy), nv(u8span("content-length"), u8span(contentLength)) }; + + headers.insert(headers.end(), extraHeaders.begin(), extraHeaders.end()); + + nghttp2_data_provider2 data_prd; + data_prd.source.ptr = buf; + data_prd.read_callback = ioBufferCallback(); + +#ifdef OPENCMW_DEBUG_HTTP + auto formattedHeaders = headers | std::views::transform([](const auto &header) { + return fmt::format("'{}'='{}'", std::string_view(reinterpret_cast(header.name), header.namelen), std::string_view(reinterpret_cast(header.value), header.valuelen)); + }); + HTTP_DBG("Sending response {} to streamId {}. Headers:\n{}", responseCode, streamId, fmt::join(formattedHeaders, "\n")); +#endif + if (auto rc = nghttp2_submit_response2(_session, streamId, headers.data(), headers.size(), &data_prd); rc != 0) { + HTTP_DBG("Server: nghttp2_submit_response2 for stream ID {} failed: {}", streamId, nghttp2_strerror(rc)); + _responsesByStreamId.erase(streamId); + } + } + + void respondToLongPoll(std::int32_t streamId, std::uint64_t index, Message &&msg) { + auto timestamp = std::to_string(opencmw::load_test::timestamp().count()); + sendResponse(streamId, kHttpOk, std::move(msg), { nv(u8span("x-opencmw-long-polling-idx"), u8span(std::to_string(index))), nv(u8span("x-timestamp"), u8span(timestamp)) }); + } + + void respondToLongPollWithError(std::int32_t streamId, std::string_view error, int code, std::uint64_t index) { + Message response = {}; + response.error = std::string(error); + sendResponse(streamId, code, std::move(response), { nv(u8span("x-opencmw-long-polling-idx"), u8span(std::to_string(index))) }); + } + + void respondWithError(std::int32_t streamId, std::string_view error, int code = kHttpError, std::vector extraHeaders = {}) { + Message response = {}; + response.error = std::string(error); + sendResponse(streamId, code, std::move(response), std::move(extraHeaders)); + } + + void respondWithRedirect(std::int32_t streamId, std::string_view location) { + HTTP_DBG("Server::respondWithRedirect: streamId={} location={}", streamId, location); + // :status must go first + const auto headers = std::array{ nv(u8span(":status"), u8span("302"), NGHTTP2_NV_FLAG_NO_COPY_NAME | NGHTTP2_NV_FLAG_NO_COPY_VALUE), nv(u8span("location"), u8span(location)) }; + nghttp2_submit_response2(_session, streamId, headers.data(), headers.size(), nullptr); + } + + void respondWithLongPollingRedirect(std::int32_t streamId, const URI<> &topic, std::size_t longPollIdx) { + auto location = URI<>::UriFactory(topic).addQueryParameter("LongPollingIdx", std::to_string(longPollIdx)).build(); + respondWithRedirect(streamId, location.str()); + } + + std::optional processLongPollRequest(std::int32_t streamId, const Request &request) { + std::optional result; + const auto zmqTopic = request.topic.toZmqTopic(); + auto entryIt = _sharedData->_subscriptionCache.find(zmqTopic); + if (entryIt == _sharedData->_subscriptionCache.end()) { + entryIt = _sharedData->_subscriptionCache.try_emplace(zmqTopic, SubscriptionCacheEntry{}).first; + result = Message{}; + result->command = mdp::Command::Subscribe; + result->topic = request.topic.toMdpTopic(); + } + auto &entry = entryIt->second; + if (request.longPollIndex == "Next") { + respondWithLongPollingRedirect(streamId, request.topic.toMdpTopic(), entry.nextIndex()); + return result; + } else if (request.longPollIndex == "Last") { + const std::size_t last = entry.messages.empty() ? entry.nextIndex() : entry.lastIndex(); + respondWithLongPollingRedirect(streamId, request.topic.toMdpTopic(), last); + return result; + } + + std::uint64_t index = 0; + if (auto [ptr, ec] = std::from_chars(request.longPollIndex.data(), request.longPollIndex.data() + request.longPollIndex.size(), index); ec != std::errc()) { + respondWithError(streamId, fmt::format("Malformed LongPollingIdx '{}'", request.longPollIndex)); + return {}; + } + +#ifdef OPENCMW_PROFILE_HTTP + const std::size_t last = entry.messages.empty() ? entry.nextIndex() : entry.lastIndex(); + if (index + 5 < last) { + fmt::println(std::cerr, "Server::LongPoll: index {} < last {} => {}", index, last, last - index); + } +#endif + if (index < entry.firstIndex) { + // index is too old, redirect to the next index + HTTP_DBG("Server::LongPoll: index {} < firstIndex {}", index, entry.firstIndex); + respondWithLongPollingRedirect(streamId, request.topic.toMdpTopic(), entry.nextIndex()); + } else if (entry.messages.empty() || index > entry.lastIndex()) { + // future index, wait for new messages + _pendingPolls.emplace_back(zmqTopic, index, streamId); + } else { + // we have a message for this index, send it + respondToLongPoll(streamId, index - entry.firstIndex, Message(entry.messages[index - entry.firstIndex])); + } + return result; + } + + void processCompletedRequest(std::int32_t streamId) { + auto it = _requestsByStreamId.find(streamId); + assert(it != _requestsByStreamId.end()); + auto &[streamid, request] = *it; + + std::string path; + std::string_view method; + std::string_view xOpencmwMethod; + + for (const auto &[name, value] : request.rawHeaders) { + if (name == ":path") { + path = value; + } else if (name == ":method") { + method = value; + } else if (name == "content-type") { + request.contentType = value; + } else if (name == "accept") { + request.accept = value; + } else if (name == "x-opencmw-method") { + xOpencmwMethod = value; + } + } + + // if we have an externally configured handler for this method/path, use it + if (auto handler = _sharedData->findHandler(method, path); handler) { + rest::Request req; + req.method = method; + req.path = path; + std::swap(req.headers, request.rawHeaders); + auto response = handler->handler(req); + sendResponse(streamId, std::move(response)); + _requestsByStreamId.erase(it); + return; + } + + // redirect "/" request to "/mmi.service" + if (path == "/") { + path = "/mmi.service"; + } + + // Everything else is a service request + try { + auto pathUri = URI<>(path); + auto factory = URI<>::UriFactory(pathUri).setQuery({}); + bool haveSubscriptionContext = false; + for (const auto &[qkey, qvalue] : pathUri.queryParamMap()) { + if (qkey == "LongPollingIdx") { + request.method = RestMethod::LongPoll; + request.longPollIndex = qvalue.value_or(""); + } else if (qkey == "SubscriptionContext") { + request.topic = mdp::Topic::fromMdpTopic(URI<>(qvalue.value_or(""))); + haveSubscriptionContext = true; + } else if (qkey == "_bodyOverride") { + request.payload = qvalue.value_or(""); + } else { + if (qvalue) { + factory = std::move(factory).addQueryParameter(qkey, qvalue.value()); + } else { + factory = std::move(factory).addQueryParameter(qkey); + } + } + } + if (!haveSubscriptionContext) { + request.topic = mdp::Topic::fromMdpTopic(factory.build()); + } + } catch (const std::exception &e) { + HTTP_DBG("Service::Header: Could not parse service URI '{}': {}", path, e.what()); + Message response; + response.error = e.what(); + sendResponse(streamId, kFileNotFound, std::move(response)); + _requestsByStreamId.erase(it); + return; + } + + if (request.method == RestMethod::Invalid && !xOpencmwMethod.empty()) { + request.method = parseMethod(xOpencmwMethod); + } + if (request.method == RestMethod::Invalid) { + request.method = parseMethod(method); + } + + // Set completed for getMessages() to collect + request.complete = true; + } + + std::vector getMessages(IdGenerator &idGenerator) { + const auto completeEnd = std::ranges::partition_point(_requestsByStreamId, [](const auto &pair) { return pair.second.complete; }); + + std::vector result; + result.reserve(static_cast(std::distance(_requestsByStreamId.begin(), completeEnd))); + + for (auto it = _requestsByStreamId.begin(); it != completeEnd; ++it) { + auto &[streamId, request] = *it; + + switch (request.method) { + case RestMethod::Get: + case RestMethod::Post: + if (auto m = processGetSetRequest(streamId, request, idGenerator); m.has_value()) { + result.push_back(std::move(m.value())); + } + break; + case RestMethod::LongPoll: + if (auto m = processLongPollRequest(streamId, request); m.has_value()) { + result.push_back(std::move(m.value())); + } + break; + case RestMethod::Invalid: + respondWithError(it->first, "Invalid REST method", kHttpError); + break; + } + } + + _requestsByStreamId.erase(_requestsByStreamId.begin(), completeEnd); + return result; + } + + int frame_recv_callback(const nghttp2_frame *frame) { + HTTP_DBG("Server::Frame: id={} {} {} {}", frame->hd.stream_id, frame->hd.type, frame->hd.flags, (frame->hd.flags & NGHTTP2_FLAG_END_STREAM) ? "END_STREAM" : ""); + switch (frame->hd.type) { + case NGHTTP2_DATA: + case NGHTTP2_HEADERS: + if (frame->hd.flags & NGHTTP2_FLAG_END_STREAM) { + processCompletedRequest(frame->hd.stream_id); + } + break; + } + return 0; + } + + int frame_send_callback(const nghttp2_frame *frame) { + HTTP_DBG("Server::Frame sent: id={} {} {} {}", frame->hd.stream_id, frame->hd.type, frame->hd.flags, (frame->hd.flags & NGHTTP2_FLAG_END_STREAM) ? "END_STREAM" : ""); + if (frame->hd.flags & NGHTTP2_FLAG_END_STREAM) { + _responsesByStreamId.erase(frame->hd.stream_id); + } + return 0; + } + + int frame_not_send_callback(const nghttp2_frame *frame, int lib_error_code) { + std::ignore = lib_error_code; + HTTP_DBG("Server::Frame not sent: id={} {} {} {}", frame->hd.stream_id, frame->hd.type, frame->hd.flags, (frame->hd.flags & NGHTTP2_FLAG_END_STREAM) ? "END_STREAM" : ""); + if (frame->hd.type == NGHTTP2_DATA) { + _responsesByStreamId.erase(frame->hd.stream_id); + } + return 0; + } + + int data_chunk_recv_callback(uint8_t /*flags*/, int32_t stream_id, std::string_view data) { + HTTP_DBG("Server::Data id={} {} bytes", stream_id, data.size()); + _requestsByStreamId[stream_id].payload += data; + return 0; + } + + int stream_closed_callback(int32_t stream_id, uint32_t error_code) { + std::ignore = error_code; + HTTP_DBG("Server::Stream closed: {} ({})", stream_id, error_code); + const std::size_t erased = _responsesByStreamId.erase(stream_id); + // if this was canceled by the client, remove any pending requests/polls + if (erased > 0) { + std::erase_if(_pendingRequests, [stream_id](const auto &request) { return std::get<1>(request) == stream_id; }); + std::erase_if(_pendingPolls, [stream_id](const auto &poll) { return std::get<2>(poll) == stream_id; }); + } + return 0; + } + + int header_callback(const nghttp2_frame *frame, std::string_view name, std::string_view value, uint8_t /*flags*/) { + HTTP_DBG("Server::Header id={} {} = {}", frame->hd.stream_id, name, value); + const auto [it, inserted] = _requestsByStreamId.try_emplace(frame->hd.stream_id, Request{}); + auto &request = it->second; + request.rawHeaders.emplace_back(name, value); +#ifdef OPENCMW_PROFILE_HTTP + if (name == "x-timestamp") { + fmt::println(std::cerr, "Server::Header: x-timestamp: {} (latency {} ns)", value, opencmw::detail::nghttp2::latency(value).count()); + } +#endif + return 0; + } + + int invalid_frame_recv_callback(const nghttp2_frame *, int lib_error_code) { + std::ignore = lib_error_code; + HTTP_DBG("invalid_frame_recv_callback called error={}", lib_error_code); + return 0; + } + + int error_callback(int lib_error_code, const char *msg, size_t len) { + std::ignore = lib_error_code; + std::ignore = msg; + std::ignore = len; + HTTP_DBG("Server::ERROR: {} ({})", std::string_view(msg, len), lib_error_code); + return 0; + } +}; + +inline std::expected create_server_socket(SSL_CTX *ssl_ctx, std::uint16_t port) { + auto ssl = SSL_Ptr(nullptr, SSL_free); + if (ssl_ctx) { + auto maybeSsl = create_ssl(ssl_ctx); + if (!maybeSsl) { + return std::unexpected(fmt::format("Failed to create SSL object: {}", maybeSsl.error())); + } + ssl = std::move(maybeSsl.value()); + } + + auto serverSocket = TcpSocket::create(std::move(ssl), socket(AF_INET, SOCK_STREAM, 0)); + if (!serverSocket) { + return std::unexpected(serverSocket.error()); + } + + int reuseFlag = 1; + if (setsockopt(serverSocket->fd, SOL_SOCKET, SO_REUSEADDR, &reuseFlag, sizeof(reuseFlag)) < 0) { + return std::unexpected(fmt::format("setsockopt(SO_REUSEADDR) failed: {}", strerror(errno))); + } + + struct sockaddr_in address {}; + address.sin_family = AF_INET; + address.sin_addr.s_addr = INADDR_ANY; + address.sin_port = htons(port); + + if (::bind(serverSocket->fd, reinterpret_cast(&address), sizeof(address)) < 0) { + return std::unexpected(fmt::format("Bind failed: {}", strerror(errno))); + } + + if (listen(serverSocket->fd, 32) < 0) { + return std::unexpected(fmt::format("Listen failed: {}", strerror(errno))); + } + + return serverSocket; +} + +struct Http2Server { + TcpSocket _serverSocket; + SSL_CTX_Ptr _ssl_ctx = SSL_CTX_Ptr(nullptr, SSL_CTX_free); + EVP_PKEY_Ptr _key = EVP_PKEY_Ptr(nullptr, EVP_PKEY_free); + X509_Ptr _cert = X509_Ptr(nullptr, X509_free); + std::shared_ptr _sharedData = std::make_shared(); + std::map> _sessions; + IdGenerator _requestIdGenerator; + + Http2Server() = default; + Http2Server(const Http2Server &) = delete; + Http2Server &operator=(const Http2Server &) = delete; + Http2Server(Http2Server &&) = default; + Http2Server &operator=(Http2Server &&) = default; + + Http2Server(SSL_CTX_Ptr ssl_ctx, EVP_PKEY_Ptr key, X509_Ptr cert) + : _ssl_ctx(std::move(ssl_ctx)), _key(std::move(key)), _cert(std::move(cert)) { + if (_ssl_ctx) { + SSL_library_init(); + SSL_load_error_strings(); + OpenSSL_add_all_algorithms(); + } + } + + static std::expected unencrypted() { + return Http2Server(SSL_CTX_Ptr(nullptr, SSL_CTX_free), EVP_PKEY_Ptr(nullptr, EVP_PKEY_free), X509_Ptr(nullptr, X509_free)); + } + + static std::expected sslWithBuffers(std::string_view certBuffer, std::string_view keyBuffer) { + auto maybeCert = nghttp2::readServerCertificateFromBuffer(certBuffer); + if (!maybeCert) { + return std::unexpected(maybeCert.error()); + } + auto maybeKey = nghttp2::readServerPrivateKeyFromBuffer(keyBuffer); + if (!maybeKey) { + return std::unexpected(maybeKey.error()); + } + auto maybeSslCtx = create_ssl_ctx(maybeKey->get(), maybeCert->get()); + if (!maybeSslCtx) { + return std::unexpected(maybeSslCtx.error()); + } + return Http2Server(std::move(maybeSslCtx.value()), std::move(maybeKey.value()), std::move(maybeCert.value())); + } + + static std::expected sslWithPaths(std::filesystem::path certPath, std::filesystem::path keyPath) { + auto maybeCert = nghttp2::readServerCertificateFromFile(certPath); + if (!maybeCert) { + return std::unexpected(maybeCert.error()); + } + auto maybeKey = nghttp2::readServerPrivateKeyFromFile(keyPath); + if (!maybeKey) { + return std::unexpected(maybeKey.error()); + } + auto maybeSslCtx = create_ssl_ctx(maybeKey->get(), maybeCert->get()); + if (!maybeSslCtx) { + return std::unexpected(maybeSslCtx.error()); + } + return Http2Server(std::move(maybeSslCtx.value()), std::move(maybeKey.value()), std::move(maybeCert.value())); + } + + void setHandlers(std::vector handlers) { + _sharedData->_handlers = std::move(handlers); + } + + void handleResponse(Message &&message) { + auto view = message.clientRequestID.asString(); + std::uint64_t id; + const auto ec = std::from_chars(view.begin(), view.end(), id); + if (ec.ec != std::errc{}) { + HTTP_DBG("Failed to parse request ID: '{}'", view); + return; + } + auto matchesId = [id](const auto &pendingRequest) { return std::get<0>(pendingRequest) == id; }; + + auto it = std::ranges::find_if(_sessions, [matchesId](const auto &session) { + return std::ranges::find_if(session.second->_pendingRequests, matchesId) != session.second->_pendingRequests.end(); + }); + + if (it != _sessions.end()) { + auto &session = it->second; + auto pendingIt = std::ranges::find_if(session->_pendingRequests, matchesId); + const auto &[reqId, streamId] = *pendingIt; + const auto code = message.error.empty() ? kHttpOk : kHttpError; + session->sendResponse(streamId, code, std::move(message)); + session->_pendingRequests.erase(pendingIt); + }; + } + + void handleNotification(const mdp::Topic &topic, Message &&message) { + const auto zmqTopic = topic.toZmqTopic(); + auto entryIt = _sharedData->_subscriptionCache.find(zmqTopic); + if (entryIt == _sharedData->_subscriptionCache.end()) { + HTTP_DBG("Server::handleNotification: No subscription for topic '{}'", zmqTopic); + return; + } + auto &entry = entryIt->second; + entry.add(std::move(message)); + for (auto &session : _sessions | std::views::values) { + auto pollIt = session->_pendingPolls.begin(); + while (pollIt != session->_pendingPolls.end()) { + const auto &[pendingZmqTopic, pollIndex, streamId] = *pollIt; + if (pendingZmqTopic == zmqTopic && entry.lastIndex() == pollIndex) { + session->respondToLongPoll(streamId, pollIndex, Message(entry.messages[pollIndex - entry.firstIndex])); + pollIt = session->_pendingPolls.erase(pollIt); + } else { + ++pollIt; + } + } + } + } + + void populatePollerItems(std::vector &items) { + items.push_back(zmq_pollitem_t{ nullptr, _serverSocket.fd, ZMQ_POLLIN, 0 }); + for (const auto &[_, session] : _sessions) { + const auto wantsRead = session->wantsToRead(); + const auto wantsWrite = session->wantsToWrite(); + if (wantsRead || wantsWrite) { + items.emplace_back(nullptr, session->_socket.fd, static_cast((wantsRead ? ZMQ_POLLIN : 0) | (wantsWrite ? ZMQ_POLLOUT : 0)), 0); + } + } + } + + std::vector processReadWrite(int fd, bool read, bool write) { + if (fd == _serverSocket.fd) { + auto maybeSocket = _serverSocket.accept(_ssl_ctx.get(), TcpSocket::None); + if (!maybeSocket) { + HTTP_DBG("Failed to accept client: {}", maybeSocket.error()); + return {}; + } + + auto clientSocket = std::move(maybeSocket.value()); + if (!clientSocket) { + return {}; + } + + auto newFd = clientSocket->fd; + + auto [newSessionIt, inserted] = _sessions.try_emplace(newFd, std::make_unique(std::move(clientSocket.value()), _sharedData)); + assert(inserted); + auto &newSession = newSessionIt->second; + nghttp2_settings_entry iv[1] = { { NGHTTP2_SETTINGS_MAX_CONCURRENT_STREAMS, 1000 } }; + if (nghttp2_submit_settings(newSession->_session, NGHTTP2_FLAG_NONE, iv, 1) != 0) { + HTTP_DBG("nghttp2_submit_settings failed"); + } + return {}; + } + + auto sessionIt = _sessions.find(fd); + assert(sessionIt != _sessions.end()); + assert(sessionIt->second->_socket.fd == fd); + auto &session = sessionIt->second; + + if (session->_socket._state != TcpSocket::Connected) { + if (auto r = session->_socket.continueHandshake(); !r) { + HTTP_DBG("Handshake failed: {}", r.error()); + _sessions.erase(sessionIt); + return {}; + } + return {}; + } + + if (write) { + if (!session->_writeBuffer.write(session->_session, session->_socket)) { + HTTP_DBG("Failed to write to peer"); + _sessions.erase(sessionIt); + return {}; + } + } + + if (!read) { + return {}; + } + + bool mightHaveMore = true; + + while (mightHaveMore) { + std::array buffer; + ssize_t bytes_read = session->_socket.read(buffer.data(), sizeof(buffer)); + if (bytes_read <= 0 && errno != EAGAIN) { + if (bytes_read < 0) { + HTTP_DBG("Server::read failed: {} {}", bytes_read, session->_socket.lastError()); + } + _sessions.erase(sessionIt); + return {}; + } + if (nghttp2_session_mem_recv2(session->_session, buffer.data(), static_cast(bytes_read)) < 0) { + HTTP_DBG("Server: nghttp2_session_mem_recv2 failed"); + _sessions.erase(sessionIt); + return {}; + } + mightHaveMore = bytes_read == static_cast(buffer.size()); + } + + return session->getMessages(_requestIdGenerator); + } + + std::expected + bind(std::uint16_t port) { + if (_serverSocket.fd != -1) { + return std::unexpected("Server already bound"); + } + auto socket = create_server_socket(_ssl_ctx.get(), port); + if (!socket) { + return std::unexpected(socket.error()); + } + _serverSocket = std::move(socket.value()); + return {}; + } +}; +} // namespace opencmw::majordomo::detail::nghttp2 + +#endif // OPENCMW_MAJORDOMO_HTTP2SERVER_HPP diff --git a/src/majordomo/include/majordomo/LoadTestWorker.hpp b/src/majordomo/include/majordomo/LoadTestWorker.hpp new file mode 100644 index 00000000..c3c5e309 --- /dev/null +++ b/src/majordomo/include/majordomo/LoadTestWorker.hpp @@ -0,0 +1,80 @@ + +#ifndef OPENCMW_MAJORDOMO_LOADTESTWORKER_H +#define OPENCMW_MAJORDOMO_LOADTESTWORKER_H + +#include "QuerySerialiser.hpp" +#include +#include + +#include + +namespace opencmw::majordomo::load_test { + +using opencmw::load_test::Context; +using opencmw::load_test::Payload; + +template +struct Worker : public opencmw::majordomo::Worker { + using super_t = opencmw::majordomo::Worker; + std::jthread _notifier; + + struct SubscriptionInfo : public Context { + SubscriptionInfo(const Context &ctx) + : Context(ctx) {} + + std::int64_t nextIndex = 0; + std::int64_t initialDelayLeftMs = initialDelayMs; + std::int64_t updatesLeft = nUpdates; + }; + std::unordered_map _subscriptions; + + Worker(const majordomo::Broker<> &broker) + : super_t(broker, {}) { + super_t::setCallback([](opencmw::majordomo::RequestContext & /*rawCtx*/, const Context & /*inCtx*/, const Payload &in, Context & /*outCtx*/, Payload &out) { + out.data = in.data; + out.timestampNs = opencmw::load_test::timestamp().count(); + }); + _notifier = std::jthread([this](std::stop_token stopToken) { + std::int64_t timecounter = 0; + + std::unordered_map subscriptions; + + while (!stopToken.stop_requested()) { + const auto stepMs = 10; + std::this_thread::sleep_for(std::chrono::milliseconds(stepMs)); + timecounter += stepMs; + auto activeSubscriptions = super_t::activeSubscriptions(); + std::erase_if(subscriptions, [&activeSubscriptions](const auto &sub) { return !activeSubscriptions.contains(sub.first); }); + for (const auto &active : activeSubscriptions) { + auto it = subscriptions.find(active); + if (it == subscriptions.end()) { + // TODO protect against unexpected queries + auto ctx = query::deserialise(active.params()); + it = subscriptions.try_emplace(active, ctx).first; + } + auto &sub = it->second; + + sub.initialDelayLeftMs -= stepMs; + + if (sub.initialDelayLeftMs > 0) { + continue; + } + if (sub.updatesLeft == 0) { + continue; + } + if (sub.intervalMs > 0 && timecounter % sub.intervalMs != 0) { + continue; + } + Payload payload{ sub.nextIndex, std::string(static_cast(sub.payloadSize), 'x'), opencmw::load_test::timestamp().count() }; + super_t::notify(sub, std::move(payload)); + sub.nextIndex++; + sub.updatesLeft--; + } + } + }); + } +}; + +} // namespace opencmw::majordomo::load_test + +#endif diff --git a/src/majordomo/include/majordomo/Rest.hpp b/src/majordomo/include/majordomo/Rest.hpp new file mode 100644 index 00000000..d65d8cda --- /dev/null +++ b/src/majordomo/include/majordomo/Rest.hpp @@ -0,0 +1,175 @@ +#ifndef OPENCMW_MAJORDOMO_REST_HPP +#define OPENCMW_MAJORDOMO_REST_HPP + +#include + +#include + +#include +#include +#include +#include +#include +#include + +namespace opencmw::majordomo::rest { + + struct Request { + std::string method; + std::string path; + std::vector> headers; + }; + + struct Response { + struct ReadStatus { + std::size_t bytesWritten; + bool hasMore; + }; + using WriterFunction = std::function(std::span)>; + int code; + std::vector> headers; + WriterFunction bodyReader; + IoBuffer body; + }; + + struct Handler { + std::string method; + std::string path; + std::function handler; + }; + + inline auto mimeTypeFromExtension(std::string_view path) { + if (path.ends_with(".html")) { + return "text/html"; + } + if (path.ends_with(".css")) { + return "text/css"; + } + if (path.ends_with(".js")) { + return "application/javascript"; + } + if (path.ends_with(".png")) { + return "image/png"; + } + if (path.ends_with(".jpg") || path.ends_with(".jpeg")) { + return "image/jpeg"; + } + if (path.ends_with(".wasm")) { + return "application/wasm"; + } + return "text/plain"; + } + + namespace detail { + struct CmrcReadState { + cmrc::file file; + std::size_t pos = 0; + }; + + } // namespace detail + + inline auto cmrcHandler(std::string path, std::string prefix, std::shared_ptr vfs, std::string vprefix) { + return Handler{ + .method = "GET", + .path = path, + .handler = [path, prefix, vfs, vprefix](const Request &request) -> Response { + try { + auto p = std::string_view{ request.path }; + if (p.starts_with(prefix)) { + p.remove_prefix(prefix.size()); + } + while (p.starts_with("/")) { + p.remove_prefix(1); + } + auto state = std::make_shared(vfs->open(vprefix + std::string{ p }), 0); + Response response; + response.code = 200; + response.headers.emplace_back("content-type", mimeTypeFromExtension(p)); + response.headers.emplace_back("content-length", std::to_string(state->file.size())); + response.bodyReader = [state = std::move(state)](std::span buffer) -> std::expected { + const std::size_t n = std::min(buffer.size(), state->file.size() - state->pos); + std::copy(state->file.begin() + state->pos, state->file.begin() + state->pos + n, buffer.begin()); + state->pos += n; + return Response::ReadStatus{ n, state->pos < state->file.size() }; + }; + return response; + } catch (...) { + Response response; + response.code = 404; + response.headers.emplace_back("content-type", "text/plain"); + response.body = IoBuffer("Not found"); + return response; + } + } + }; + } + + inline auto fileSystemHandler(std::string path, std::string prefix, std::filesystem::path root, std::vector> extraHeaders = {}) { + return Handler{ + .method = "GET", + .path = path, + .handler = [root, path, prefix, extraHeaders = std::move(extraHeaders)](const Request &request) -> Response { + auto p = std::string_view{ request.path }; + if (p.starts_with(prefix)) { + p.remove_prefix(prefix.size()); + } + while (p.starts_with("/")) { + p.remove_prefix(1); + } + auto file = root / p; + if (!std::filesystem::exists(file)) { + Response response; + response.code = 404; + response.headers.emplace_back("content-type", "text/plain"); + response.body = IoBuffer("Not found"); + return response; + } + + try { + auto fileStream = std::make_shared(file, std::ios::binary); + + if (!fileStream->is_open()) { + Response response; + response.code = 500; + response.headers.emplace_back("content-type", "text/plain"); + response.body = IoBuffer("Internal Server Error"); + return response; + } + + Response response; + response.code = 200; + response.headers = extraHeaders; + response.headers.emplace_back("content-type", mimeTypeFromExtension(request.path)); + response.headers.emplace_back("content-length", std::to_string(std::filesystem::file_size(file))); + response.bodyReader = [fileStream = std::move(fileStream)](std::span buffer) -> std::expected { + fileStream->read(reinterpret_cast(buffer.data()), static_cast(buffer.size())); + if (fileStream->bad()) { + return std::unexpected(fmt::format("Failed to read file: {}", strerror(errno))); + } + return Response::ReadStatus{ static_cast(fileStream->gcount()), !fileStream->eof() }; + }; + return response; + } catch (...) { + Response response; + response.code = 500; + response.headers.emplace_back("content-type", "text/plain"); + response.body = IoBuffer("Internal Server Error"); + return response; + } + } + }; + } + +struct Settings { + uint16_t port = 8080; + std::filesystem::path certificateFilePath; + std::filesystem::path keyFilePath; + std::string certificateFileBuffer; + std::string keyFileBuffer; + std::string dnsAddress; + std::vector handlers; +}; + +} // namespace opencmw::majordomo::rest + +#endif diff --git a/src/majordomo/include/majordomo/RestBackend.hpp b/src/majordomo/include/majordomo/RestBackend.hpp deleted file mode 100644 index 13b53fcd..00000000 --- a/src/majordomo/include/majordomo/RestBackend.hpp +++ /dev/null @@ -1,903 +0,0 @@ -#ifndef OPENCMW_MAJORDOMO_RESTBACKEND_P_H -#define OPENCMW_MAJORDOMO_RESTBACKEND_P_H - -// STD -#include -#include -#include -#include -#include -#include - -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wformat-nonliteral" -#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" -#pragma GCC diagnostic ignored "-Wold-style-cast" -#pragma GCC diagnostic ignored "-Wshadow" -#pragma GCC diagnostic ignored "-Wuninitialized" -#pragma GCC diagnostic ignored "-Wuseless-cast" -#undef CPPHTTPLIB_THREAD_POOL_COUNT -#define CPPHTTPLIB_THREAD_POOL_COUNT 8 -#include -#pragma GCC diagnostic pop - -// Core -#include -#include -#include - -#include -#include -#include -#include - -// Majordomo -#include -#include - -#include -#include -CMRC_DECLARE(assets); - -struct FormData { - std::unordered_map fields; -}; -ENABLE_REFLECTION_FOR(FormData, fields) - -struct Service { - std::string name; - std::string description; - - Service(std::string name_, std::string description_) - : name{ std::move(name_) }, description{ std::move(description_) } {} -}; -ENABLE_REFLECTION_FOR(Service, name, description) - -struct ServicesList { - std::vector services; -}; -ENABLE_REFLECTION_FOR(ServicesList, services) - -namespace opencmw::majordomo { - -using namespace std::chrono_literals; - -constexpr auto HTTP_OK = 200; -constexpr auto HTTP_ERROR = 500; -constexpr auto HTTP_GATEWAY_TIMEOUT = 504; -constexpr auto DEFAULT_REST_PORT = 8080; -constexpr auto UPDATER_POLLING_TIME = 1s; -constexpr auto LONG_POLL_SERVER_TIMEOUT = 30s; -constexpr auto UNUSED_SUBSCRIPTION_EXPIRATION_TIME = 30s; -constexpr std::size_t MAX_CACHED_REPLIES = 32; - -namespace detail { -// Provides a safe alternative to getenv -inline const char *getEnvFilenameOr(const char *field, const char *defaultValue) { - const char *result = ::getenv(field); - if (result == nullptr) { - result = defaultValue; - } - - if (!std::filesystem::exists(result)) { - throw opencmw::startup_error(fmt::format("File {} not found. Current path is {}", result, std::filesystem::current_path().string())); - } - return result; -} -} // namespace detail - -struct HTTPS { - constexpr static std::string_view DEFAULT_REST_SCHEME = "https"; - httplib::SSLServer _svr; - - HTTPS() - : _svr(detail::getEnvFilenameOr("OPENCMW_REST_CERT_FILE", "demo_public.crt"), detail::getEnvFilenameOr("OPENCMW_REST_PRIVATE_KEY_FILE", "demo_private.key")) {} -}; - -struct PLAIN_HTTP { - constexpr static std::string_view DEFAULT_REST_SCHEME = "https"; - httplib::Server _svr; -}; - -namespace detail { -using PollingIndex = std::uint64_t; - -using ReadLock = std::shared_lock; -using WriteLock = std::unique_lock; - -enum class RestMethod { - Get, - Subscribe, - LongPoll, - Post, - Invalid -}; - -inline std::string_view acceptedMimeForRequest(const auto &request) { - static constexpr std::array acceptableMimeTypes = { - MIME::JSON.typeName(), MIME::HTML.typeName(), MIME::BINARY.typeName() - }; - auto accepted = [](auto format) { - const auto it = std::find(acceptableMimeTypes.cbegin(), acceptableMimeTypes.cend(), format); - return std::make_pair( - it != acceptableMimeTypes.cend(), - it); - }; - - if (request.has_header("Content-Type")) { - std::string format = request.get_header_value("Content-Type"); - if (const auto [found, where] = accepted(format); found) { - return *where; - } - } - if (request.has_param("contentType")) { - std::string format = request.get_param_value("contentType"); - if (const auto [found, where] = accepted(format); found) { - return *where; - } - } - - auto isDelimiter = [](char c) { return c == ' ' || c == ','; }; - const auto &acceptHeader = request.get_header_value("Accept"); - auto from = acceptHeader.cbegin(); - const auto end = acceptHeader.cend(); - - while (from != end) { - from = std::find_if_not(from, end, isDelimiter); - auto to = std::find_if(from, end, isDelimiter); - if (from != end) { - std::string_view format(from, to); - if (const auto [found, where] = accepted(format); found) { - return *where; - } - } - - from = to; - } - - return acceptableMimeTypes[0]; -} - -bool respondWithError(auto &response, std::string_view message, int status = HTTP_ERROR) { - response.status = status; - response.set_content(message.data(), MIME::TEXT.typeName().data()); - return true; -}; - -inline bool respondWithServicesList(auto &broker, const httplib::Request &request, httplib::Response &response) { - // Mmi is not a MajordomoWorker, so it doesn't know JSON (TODO) - const auto acceptedFormat = acceptedMimeForRequest(request); - - if (acceptedFormat == MIME::JSON.typeName()) { - std::vector serviceNames; - // TODO: Should this be synchronized? - broker.forEachService([&](std::string_view name, std::string_view) { - serviceNames.emplace_back(name); - }); - response.status = HTTP_OK; - - opencmw::IoBuffer buffer; - // opencmw::serialise(buffer, serviceNames); - IoSerialiser::serialise(buffer, FieldDescriptionShort{}, serviceNames); - response.set_content(buffer.asString().data(), MIME::JSON.typeName().data()); - return true; - - } else if (acceptedFormat == MIME::HTML.typeName()) { - response.set_chunked_content_provider( - MIME::HTML.typeName().data(), - [&broker](std::size_t /*offset*/, httplib::DataSink &sink) { - ServicesList servicesList; - broker.forEachService([&](std::string_view name, std::string_view description) { - servicesList.services.emplace_back(std::string(name), std::string(description)); - }); - - // sort services, move mmi. services to the end - auto serviceLessThan = [](const auto &lhs, const auto &rhs) { - const auto lhsIsMmi = lhs.name.starts_with("/mmi."); - const auto rhsIsMmi = rhs.name.starts_with("/mmi."); - if (lhsIsMmi != rhsIsMmi) { - return rhsIsMmi; - } - return lhs.name < rhs.name; - }; - std::sort(servicesList.services.begin(), servicesList.services.end(), serviceLessThan); - - using namespace std::string_literals; - mustache::serialise(cmrc::assets::get_filesystem(), "ServicesList", sink.os, - std::pair{ "servicesList"s, servicesList }); - sink.done(); - return true; - }); - return true; - - } else { - return respondWithError(response, "Requested an unsupported response type"); - } -} - -struct Connection { - zmq::Socket notificationSubscriptionSocket; - zmq::Socket requestResponseSocket; - std::string subscriptionKey; - - using Timestamp = std::chrono::time_point; - std::atomic lastUsed = std::chrono::system_clock::now(); - -private: - mutable std::shared_mutex _cachedRepliesMutex; - std::deque _cachedReplies; // Ring buffer? - PollingIndex _nextPollingIndex = 0; - std::condition_variable_any _pollingIndexCV; - - std::atomic_int _refCount = 1; - - // Here be dragons! This is not to be used after the connection was involved in any threading code - Connection(Connection &&other) noexcept - : notificationSubscriptionSocket(std::move(other.notificationSubscriptionSocket)) - , requestResponseSocket(std::move(other.requestResponseSocket)) - , subscriptionKey(std::move(other.subscriptionKey)) - , lastUsed(other.lastUsed.load()) - , _cachedReplies(std::move(other._cachedReplies)) - , _nextPollingIndex(other._nextPollingIndex) { - } - -public: - Connection(const zmq::Context &context, std::string _subscriptionKey) - : notificationSubscriptionSocket(context, ZMQ_SUB) - , requestResponseSocket(context, ZMQ_DEALER) - , subscriptionKey(std::move(_subscriptionKey)) {} - - Connection(const Connection &other) = delete; - Connection &operator=(const Connection &) = delete; - Connection &operator=(Connection &&) = delete; - - // Here be dragons! This is not to be used after the connection was involved in any threading code - Connection unsafeMove() && { - return std::move(*this); - } - - const auto &referenceCount() const { return _refCount; } - - // Functions to block connection deletion - void increaseReferenceCount() { ++_refCount; } - void decreaseReferenceCount() { --_refCount; } - - struct KeepAlive { - Connection *_connection; - KeepAlive(Connection *connection) - : _connection(connection) { - _connection->increaseReferenceCount(); - } - - ~KeepAlive() { - _connection->decreaseReferenceCount(); - } - }; - - auto writeLock() { - return WriteLock(_cachedRepliesMutex); - } - - auto readLock() const { - return ReadLock(_cachedRepliesMutex); - } - - bool waitForUpdate(std::chrono::milliseconds timeout) { - // This could also periodically check for the client connection being dropped (e.g. due to client-side timeout) if cpp-httplib had API for that. - auto temporaryLock = writeLock(); - const auto next = _nextPollingIndex; - while (_nextPollingIndex == next) { - if (_pollingIndexCV.wait_for(temporaryLock, timeout) == std::cv_status::timeout) { - return false; - } - } - - return true; - } - - std::size_t cachedRepliesSize(ReadLock & /*lock*/) const { - return _cachedReplies.size(); - } - - std::string cachedReply(ReadLock & /*lock*/, PollingIndex index) const { - const auto firstCachedIndex = _nextPollingIndex - _cachedReplies.size(); - return (index >= firstCachedIndex && index < _nextPollingIndex) ? _cachedReplies[index - firstCachedIndex] : std::string{}; - } - - PollingIndex nextPollingIndex(ReadLock & /*lock*/) const { - return _nextPollingIndex; - } - - void addCachedReply(std::unique_lock & /*lock*/, std::string reply) { - _cachedReplies.push_back(std::move(reply)); - if (_cachedReplies.size() > MAX_CACHED_REPLIES) { - _cachedReplies.erase(_cachedReplies.begin(), _cachedReplies.begin() + long(_cachedReplies.size() - MAX_CACHED_REPLIES)); - } - - _nextPollingIndex++; - _pollingIndexCV.notify_all(); - } -}; - -} // namespace detail - -template -class RestBackend : public Mode { -protected: - Broker &_broker; - const VirtualFS &_vfs; - URI<> _restAddress; - std::atomic _majordomoTimeout = 30000ms; - -private: - std::jthread _mdpConnectionUpdaterThread; - std::shared_mutex _mdpConnectionsMutex; - std::map> _mdpConnectionForService; - -public: - /** - * Timeout used for interaction with majordomo workers, i.e. the time to wait - * for notifications on subscriptions (long-polling) and for responses to Get/Set - * requests. - */ - void setMajordomoTimeout(std::chrono::milliseconds timeout) { - _majordomoTimeout = timeout; - } - - std::chrono::milliseconds majordomoTimeout() const { - return _majordomoTimeout; - } - - using BrokerType = Broker; - // returns a connection with refcount 1. Make sure you lower it to zero at some point - detail::Connection *notificationSubscriptionConnectionFor(const std::string &zmqTopic) { - detail::WriteLock lock(_mdpConnectionsMutex); - // TODO: No need to find + emplace as separate steps - if (auto it = _mdpConnectionForService.find(zmqTopic); it != _mdpConnectionForService.end()) { - auto *connection = it->second.get(); - connection->increaseReferenceCount(); - return connection; - } - - auto [it, inserted] = _mdpConnectionForService.emplace(std::piecewise_construct, - std::forward_as_tuple(zmqTopic), - std::forward_as_tuple(std::make_unique(_broker.context, zmqTopic))); - - if (!inserted) { - assert(inserted); - std::terminate(); - } - - auto *connection = it->second.get(); - - zmq::invoke(zmq_connect, connection->notificationSubscriptionSocket, INTERNAL_ADDRESS_PUBLISHER.str()).template onFailure("Can not connect REST worker to Majordomo broker"); - zmq::invoke(zmq_setsockopt, connection->notificationSubscriptionSocket, ZMQ_SUBSCRIBE, zmqTopic.data(), zmqTopic.size()).assertSuccess(); - - return connection; - } - - // Starts the thread to keep the unused subscriptions alive - void startUpdaterThread() { - _mdpConnectionUpdaterThread = std::jthread([this](const std::stop_token &stopToken) { - thread::setThreadName("RestBackend updater thread"); - - std::vector connections; - std::vector pollItems; - while (!stopToken.stop_requested()) { - std::list keep; - { - // This is a long lock, alternatively, message reading could have separate locks per connection - detail::WriteLock lock(_mdpConnectionsMutex); - - // Expired subscriptions cleanup - std::vector expiredSubscriptions; - for (auto &[subscriptionKey, connection] : _mdpConnectionForService) { - // fmt::print("Reference count is {}\n", connection->referenceCount()); - if (connection->referenceCount() == 0) { - auto connectionLock = connection->writeLock(); - if (connection->referenceCount() != 0) { - continue; - } - if (std::chrono::system_clock::now() - connection->lastUsed.load() > UNUSED_SUBSCRIPTION_EXPIRATION_TIME) { - expiredSubscriptions.push_back(subscriptionKey); - } - } - } - for (const auto &subscriptionKey : expiredSubscriptions) { - _mdpConnectionForService.erase(subscriptionKey); - } - - // setup poller and socket data structures for all connections - const std::size_t connectionCount = _mdpConnectionForService.size(); - connections.resize(connectionCount); - pollItems.resize(connectionCount); - for (std::size_t i = 0UZ; auto &[key, connection] : _mdpConnectionForService) { - connections[i] = connection.get(); - keep.emplace_back(connection.get()); - pollItems[i].events = ZMQ_POLLIN; - pollItems[i].socket = connection->notificationSubscriptionSocket.zmq_ptr; - ++i; - } - } // finished copying local state, keep ensures that connections are kept alive, end of lock on _mdpConnectionsForService - - if (pollItems.empty()) { - std::this_thread::sleep_for(100ms); // prevent spinning on connection cleanup if there are no connections to poll on - continue; - } - - auto pollCount = zmq::invoke(zmq_poll, pollItems.data(), static_cast(pollItems.size()), std::chrono::duration_cast(UPDATER_POLLING_TIME).count()); - if (!pollCount) { - fmt::print("Error while polling for updates from the broker\n"); - std::terminate(); - } - if (pollCount.value() == 0) { - continue; - } - - // Reading messages - for (std::size_t i = 0; i < connections.size(); ++i) { - if (pollItems[i].revents & ZMQ_POLLIN) { - detail::Connection *currentConnection = connections[i]; - std::unique_lock connectionLock = currentConnection->writeLock(); - while (auto responseMessage = zmq::receive(currentConnection->notificationSubscriptionSocket)) { - currentConnection->addCachedReply(connectionLock, std::string(responseMessage->data.asString())); - } - } - } - } - }); - } - - struct RestWorker; - - RestWorker &workerForCurrentThread() { - thread_local static RestWorker worker(*this); - return worker; - } - - using Mode::_svr; - using Mode::DEFAULT_REST_SCHEME; - -public: - explicit RestBackend(Broker &broker, const VirtualFS &vfs, URI<> restAddress = URI<>::factory().scheme(DEFAULT_REST_SCHEME).hostName("0.0.0.0").port(DEFAULT_REST_PORT).build()) - : _broker(broker), _vfs(vfs), _restAddress(restAddress) { - _broker.registerDnsAddress(restAddress); - } - - virtual ~RestBackend() { - _svr.stop(); - // shutdown thread before _connectionForService is destroyed - _mdpConnectionUpdaterThread.request_stop(); - _mdpConnectionUpdaterThread.join(); - } - - auto handleServiceRequest(const httplib::Request &request, httplib::Response &response, const httplib::ContentReader *content_reader_ = nullptr) { - using detail::RestMethod; - - auto convertParams = [](const httplib::Params ¶ms) { - mdp::Topic::Params r; - for (const auto &[key, value] : params) { - if (key == "LongPollingIdx" || key == "SubscriptionContext") { - continue; - } - if (value.empty()) { - r[key] = std::nullopt; - } else { - r[key] = value; - } - } - return r; - }; - - std::optional maybeTopic; - - try { - maybeTopic = mdp::Topic::fromString(request.path, convertParams(request.params)); - } catch (const std::exception &e) { - return detail::respondWithError(response, fmt::format("Error: {}\n", e.what())); - } - auto topic = std::move(*maybeTopic); - - auto restMethod = [&] { - auto methodString = request.has_header("X-OPENCMW-METHOD") ? request.get_header_value("X-OPENCMW-METHOD") : request.method; - // clang-format off - return methodString == "SUB" ? RestMethod::Subscribe : - methodString == "POLL" ? RestMethod::LongPoll : - methodString == "PUT" ? RestMethod::Post : - methodString == "POST" ? RestMethod::Post : - methodString == "GET" ? RestMethod::Get : - RestMethod::Invalid; - // clang-format on - }(); - - for (const auto &[key, value] : request.params) { - if (key == "LongPollingIdx") { - // This parameter is not passed on, it just means we want to use long polling - restMethod = value == "Subscription" ? RestMethod::Subscribe : RestMethod::LongPoll; - } else if (key == "SubscriptionContext") { - topic = mdp::Topic::fromString(value, {}); // params are parsed from value - } - } - - if (restMethod == RestMethod::Invalid) { - return detail::respondWithError(response, "Error: Requested method is not supported\n"); - } - - auto &worker = workerForCurrentThread(); - - switch (restMethod) { - case RestMethod::Get: - case RestMethod::Post: - return worker.respondWithGetSet(request, response, topic, restMethod, content_reader_); - case RestMethod::LongPoll: - return worker.respondWithLongPoll(request, response, topic); - case RestMethod::Subscribe: - return worker.respondWithSubscription(response, topic); - default: - // std::unreachable() is C++23 - assert(false && "We have already checked that restMethod is not Invalid"); - return false; - } - } - - virtual void registerHandlers() { - _svr.Get("/", [this](const httplib::Request &request, httplib::Response &response) { - return detail::respondWithServicesList(_broker, request, response); - }); - - _svr.Options(".*", - [](const httplib::Request & /*req*/, httplib::Response &res) { - res.set_header("Allow", "GET, POST, PUT, OPTIONS"); - res.set_header("Access-Control-Allow-Origin", "*"); - res.set_header("Access-Control-Allow-Methods", "GET, POST, PUT, OPTIONS"); - res.set_header("Access-Control-Allow-Headers", "X-OPENCMW-METHOD,Content-Type"); - res.set_header("Access-Control-Max-Age", "86400"); - }); - - static const char *nonEmptyPath = "..*"; - _svr.Get(nonEmptyPath, [this](const httplib::Request &request, httplib::Response &response) { - return handleServiceRequest(request, response, nullptr); - }); - _svr.Post(nonEmptyPath, [this](const httplib::Request &request, httplib::Response &response, const httplib::ContentReader &content_reader) { - return handleServiceRequest(request, response, &content_reader); - }); - _svr.Put(nonEmptyPath, [this](const httplib::Request &request, httplib::Response &response, const httplib::ContentReader &content_reader) { - return handleServiceRequest(request, response, &content_reader); - }); - } - - void run() { - thread::setThreadName("RestBackend thread"); - - startUpdaterThread(); - - registerHandlers(); - - if (!_restAddress.hostName() || !_restAddress.port()) { - throw opencmw::startup_error(fmt::format("REST server URI is not valid {}", _restAddress.str())); - } - - _svr.set_tcp_nodelay(true); - _svr.set_keep_alive_max_count(1000); - _svr.set_read_timeout(1, 0); - _svr.set_write_timeout(1, 0); - _svr.new_task_queue = []() { return new httplib::ThreadPool(/*num_threads=*/32, /*max_queued_requests=*/20); }; - - bool listening = _svr.listen(_restAddress.hostName().value().data(), _restAddress.port().value()); - if (!listening) { - throw opencmw::startup_error(fmt::format("Can not start REST server on {}:{}", _restAddress.hostName().value().data(), _restAddress.port().value())); - } - } - - void requestStop() { - _svr.stop(); - } - - void shutdown() { - requestStop(); - } -}; - -template -struct RestBackend::RestWorker { - RestBackend &restBackend; - - zmq_pollitem_t pollItem{}; - - explicit RestWorker(RestBackend &rest) - : restBackend(rest) { - } - - RestWorker(RestWorker &&other) noexcept = default; - - detail::Connection connect() { - detail::Connection connection(restBackend._broker.context, {}); - pollItem.events = ZMQ_POLLIN; - - zmq::invoke(zmq_connect, connection.notificationSubscriptionSocket, INTERNAL_ADDRESS_PUBLISHER.str()).template onFailure("Can not connect REST worker to Majordomo broker"); - zmq::invoke(zmq_connect, connection.requestResponseSocket, INTERNAL_ADDRESS_BROKER.str()).template onFailure("Can not connect REST worker to Majordomo broker"); - - return std::move(connection).unsafeMove(); - } - - bool respondWithGetSet(const httplib::Request &request, httplib::Response &response, mdp::Topic topic, detail::RestMethod restMethod, const httplib::ContentReader *content_reader_ = nullptr) { - const mdp::Command mdpMessageCommand = restMethod == detail::RestMethod::Post ? mdp::Command::Set : mdp::Command::Get; - - auto uri = URI<>::factory(); - std::string bodyOverride; - std::string contentType; - int contentLength{ 0 }; - for (const auto &[key, value] : request.params) { - if (key == "_bodyOverride") { - bodyOverride = value; - } else if (key == "LongPollingIdx") { - // This parameter is not passed on, it just means we want to use long polling -- already handled - } else { - uri = std::move(uri).addQueryParameter(key, value); - } - } - - for (const auto &[key, value] : request.headers) { - if (httplib::detail::case_ignore::equal(key, "Content-Length")) { - contentLength = std::stoi(value); - } else if (httplib::detail::case_ignore::equal(key, "Content-Type")) { - contentType = value; - } - } - - mdp::Message message; - message.protocolName = mdp::clientProtocol; - message.command = mdpMessageCommand; - - const auto acceptedFormat = detail::acceptedMimeForRequest(request); - topic.addParam("contentType", acceptedFormat); - message.serviceName = std::string(topic.service()); - message.topic = topic.toMdpTopic(); - - if (request.is_multipart_form_data()) { - if (content_reader_ != nullptr) { - const auto &content_reader = *content_reader_; - - FormData formData; - auto it = formData.fields.begin(); - - content_reader( - [&](const httplib::MultipartFormData &file) { - it = formData.fields.emplace(file.name, "").first; - return true; - }, - [&](const char *data, std::size_t data_length) { - // TODO: This should append to content - it->second.append(std::string_view(data, data_length)); - - return true; - }); - - opencmw::IoBuffer buffer; - // opencmw::serialise(buffer, formData); - IoSerialiser::serialise(buffer, FieldDescriptionShort{}, formData.fields); - - std::string requestData(buffer.asString()); - // Json serialiser (rightfully) does not like bool values in string - static auto replacerRegex = std::regex(R"regex("opencmw_unquoted_value[(](.*)[)]")regex"); - requestData = std::regex_replace(requestData, replacerRegex, "$1"); - auto requestDataBegin = std::find(requestData.cbegin(), requestData.cend(), '{'); - - const auto req = std::string_view(requestDataBegin, requestData.cend()); - message.data = IoBuffer(req.data(), req.size()); - } - } else if (!bodyOverride.empty()) { - message.data = IoBuffer(bodyOverride.data(), bodyOverride.size()); - } else if (!request.body.empty()) { - message.data = IoBuffer(request.body.data(), request.body.size()); - } else if (contentType != "" && contentLength > 0 && content_reader_ != nullptr) { - std::string body; - (*content_reader_)([&body](const char *data, size_t datalength) { body = std::string{data, datalength}; return true; }); - message.data = IoBuffer(body.data(), body.size()); - } - - auto connection = connect(); - - if (!zmq::send(std::move(message), connection.requestResponseSocket)) { - return detail::respondWithError(response, "Error: Failed to send a message to the broker\n"); - } - - // blocks waiting for the response - pollItem.socket = connection.requestResponseSocket.zmq_ptr; - auto pollResult = zmq::invoke(zmq_poll, &pollItem, 1, std::chrono::duration_cast(restBackend.majordomoTimeout()).count()); - if (!pollResult || pollResult.value() == 0) { - detail::respondWithError(response, "Error: No response from broker\n", HTTP_GATEWAY_TIMEOUT); - } else if (auto responseMessage = zmq::receive(connection.requestResponseSocket); !responseMessage) { - detail::respondWithError(response, "Error: Empty response from broker\n"); - } else if (!responseMessage->error.empty()) { - detail::respondWithError(response, responseMessage->error); - } else { - response.status = HTTP_OK; - - response.set_header("X-OPENCMW-TOPIC", responseMessage->topic.str().data()); - response.set_header("X-OPENCMW-SERVICE-NAME", responseMessage->serviceName.data()); - response.set_header("Access-Control-Allow-Origin", "*"); - response.set_header("X-TIMESTAMP", fmt::format("{}", std::chrono::duration_cast(std::chrono::high_resolution_clock::now().time_since_epoch()).count())); - const auto data = responseMessage->data.asString(); - - if (request.method != "GET") { - response.set_content(data.data(), data.size(), MIME::TEXT.typeName().data()); - } else { - response.set_content(data.data(), data.size(), acceptedFormat.data()); - } - } - return true; - } - - bool respondWithSubscription(httplib::Response &response, const mdp::Topic &subscription) { - const auto subscriptionKey = subscription.toZmqTopic(); - auto *connection = restBackend.notificationSubscriptionConnectionFor(subscriptionKey); - assert(connection); - const auto majordomoTimeout = restBackend.majordomoTimeout(); - response.set_header("Access-Control-Allow-Origin", "*"); - response.set_header("X-TIMESTAMP", fmt::format("{}", std::chrono::duration_cast(std::chrono::high_resolution_clock::now().time_since_epoch()).count())); - - response.set_chunked_content_provider( - "application/json", - [connection, majordomoTimeout](std::size_t /*offset*/, httplib::DataSink &sink) mutable { - std::cerr << "Chunked reply...\n"; - - if (!connection->waitForUpdate(majordomoTimeout)) { - return false; - } - - auto connectionCacheLock = connection->readLock(); - auto lastIndex = connection->nextPollingIndex(connectionCacheLock) - 1; - const auto &lastReply = connection->cachedReply(connectionCacheLock, lastIndex); - std::cerr << "Chunk: " << lastIndex << "'" << lastReply << "'\n"; - - sink.os << lastReply << "\n\n"; - - return true; - }, - [connection](bool) { - connection->decreaseReferenceCount(); - }); - - return true; - } - - bool respondWithLongPollRedirect(const httplib::Request &request, httplib::Response &response, const mdp::Topic &subscription, detail::PollingIndex redirectLongPollingIdx) { - auto uri = URI<>::factory() - .path(request.path) - .addQueryParameter("LongPollingIdx", std::to_string(redirectLongPollingIdx)); - - // copy over the original query parameters - addParameters(request, uri); - - const auto redirect = uri.toString(); - response.set_redirect(redirect); - return true; - } - - bool respondWithLongPoll(const httplib::Request &request, httplib::Response &response, const mdp::Topic &subscription) { - // TODO: After the URIs are formalized, rethink service and topic - auto uri = URI<>::factory(); - addParameters(request, uri); - - const auto subscriptionKey = subscription.toZmqTopic(); - - const auto longPollingIdxIt = request.params.find("LongPollingIdx"); - if (longPollingIdxIt == request.params.end()) { - return detail::respondWithError(response, "Error: LongPollingIdx parameter not specified"); - } - - const auto &longPollingIdxParam = longPollingIdxIt->second; - - struct CacheInfo { - detail::PollingIndex firstCachedIndex = 0; - detail::PollingIndex nextPollingIndex = 0; - detail::Connection *connection = nullptr; - }; - auto fetchCache = [this, &subscriptionKey] { - std::shared_lock lock(restBackend._mdpConnectionsMutex); - auto &recycledConnectionForService = restBackend._mdpConnectionForService; - if (auto it = recycledConnectionForService.find(subscriptionKey); it != recycledConnectionForService.cend()) { - auto *connectionCache = it->second.get(); - detail::Connection::KeepAlive keep(connectionCache); - connectionCache->lastUsed = std::chrono::system_clock::now(); - auto connectionCacheLock = connectionCache->readLock(); - return CacheInfo{ - .firstCachedIndex = connectionCache->nextPollingIndex(connectionCacheLock) - connectionCache->cachedRepliesSize(connectionCacheLock), - .nextPollingIndex = connectionCache->nextPollingIndex(connectionCacheLock), - .connection = connectionCache - }; - } else { - // We didn't have this before, means 0 is the next index - return CacheInfo{ - .firstCachedIndex = 0, - .nextPollingIndex = 0, - .connection = nullptr - }; - } - }; - - detail::PollingIndex requestedLongPollingIdx = 0; - - // Hoping we already have the requested value in the cache. Holding this caches blocks all cache entries, so no further updates can be received or other connections initiated. - { - const auto cache = fetchCache(); - response.set_header("Access-Control-Allow-Origin", "*"); - - if (longPollingIdxParam == "Next") { - return respondWithLongPollRedirect(request, response, subscription, cache.nextPollingIndex); - } - - if (longPollingIdxParam == "Last") { - if (cache.connection != nullptr) { - return respondWithLongPollRedirect(request, response, subscription, cache.nextPollingIndex - 1); - } else { - return respondWithLongPollRedirect(request, response, subscription, cache.nextPollingIndex); - } - } - - if (longPollingIdxParam == "FirstAvailable") { - return respondWithLongPollRedirect(request, response, subscription, cache.firstCachedIndex); - } - - if (std::from_chars(longPollingIdxParam.data(), longPollingIdxParam.data() + longPollingIdxParam.size(), requestedLongPollingIdx).ec != std::errc{}) { - return detail::respondWithError(response, "Error: Invalid LongPollingIdx value"); - } - - if (requestedLongPollingIdx > cache.nextPollingIndex) { - return detail::respondWithError(response, "Error: LongPollingIdx tries to read the future"); - } - - if (requestedLongPollingIdx < cache.firstCachedIndex || requestedLongPollingIdx + 15 < cache.nextPollingIndex) { - return respondWithLongPollRedirect(request, response, subscription, cache.nextPollingIndex); - } - - if (cache.connection && requestedLongPollingIdx < cache.nextPollingIndex) { - auto connectionCacheLock = cache.connection->readLock(); - // The result is already ready - response.set_content(cache.connection->cachedReply(connectionCacheLock, requestedLongPollingIdx), MIME::JSON.typeName().data()); - return true; - } - } - - // Fallback to creating a connection and waiting - auto *connection = restBackend.notificationSubscriptionConnectionFor(subscriptionKey); - assert(connection); - detail::Connection::KeepAlive keep(connection); - - // Since we use KeepAlive object, the initial refCount can go away - connection->decreaseReferenceCount(); - - if (!connection->waitForUpdate(restBackend.majordomoTimeout())) { - return detail::respondWithError(response, "Timeout waiting for update", HTTP_GATEWAY_TIMEOUT); - } - - const auto newCache = fetchCache(); - - // This time it needs to exist - assert(newCache.connection != nullptr); - - if (requestedLongPollingIdx >= newCache.firstCachedIndex && requestedLongPollingIdx < newCache.nextPollingIndex) { - auto connectionCacheLock = newCache.connection->readLock(); - response.set_content(newCache.connection->cachedReply(connectionCacheLock, requestedLongPollingIdx), MIME::JSON.typeName().data()); - return true; - } else { - return detail::respondWithError(response, "Error: We waited for the new value, but it was not found"); - } - } - -private: - void addParameters(const httplib::Request &request, URI<>::UriFactory &uri) { - for (const auto &[key, value] : request.params) { - if (key == "LongPollingIdx") { - // This parameter is not passed on, it just means we want to use long polling -- already handled - } else { - uri = std::move(uri).addQueryParameter(key, value); - } - } - } -}; - -} // namespace opencmw::majordomo - -#endif // include guard diff --git a/src/majordomo/test/CMakeLists.txt b/src/majordomo/test/CMakeLists.txt index a031859a..d8cfd566 100644 --- a/src/majordomo/test/CMakeLists.txt +++ b/src/majordomo/test/CMakeLists.txt @@ -29,7 +29,6 @@ function(opencmw_add_test_catch2 name sources) catch_discover_tests(${name}) endfunction() - opencmw_add_test_app(majordomo_testapp testapp.cpp) opencmw_add_test_app(majordomo_benchmark majordomo_benchmark.cpp) @@ -40,3 +39,7 @@ opencmw_add_test_catch2(SubscriptionMatch_tests subscriptionmatcher_tests.cpp) opencmw_add_test_catch2(majordomo_tests majordomo_tests.cpp;subscriptionmatcher_tests.cpp) opencmw_add_test_catch2(majordomo_worker_tests majordomoworker_tests.cpp;subscriptionmatcher_tests.cpp) opencmw_add_test_catch2(majordomo_worker_rest_tests majordomoworker_rest_tests.cpp;subscriptionmatcher_tests.cpp) + +if(NOT ENSCRIPTEN) + opencmw_add_test_catch2(majordomo_worker_load_tests majordomo_load_tests.cpp) +endif() diff --git a/src/majordomo/test/majordomo_load_tests.cpp b/src/majordomo/test/majordomo_load_tests.cpp new file mode 100644 index 00000000..7d0a46ae --- /dev/null +++ b/src/majordomo/test/majordomo_load_tests.cpp @@ -0,0 +1,126 @@ +#include "ClientCommon.hpp" +#include "IoBuffer.hpp" +#include "IoSerialiser.hpp" +#include "IoSerialiserYaS.hpp" +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include + +using namespace opencmw; + +constexpr std::uint16_t kServerPort = 12355; + +namespace { +template +void waitFor(std::atomic &responseCount, T expected, std::chrono::milliseconds timeout = std::chrono::seconds(5)) { + const auto start = std::chrono::system_clock::now(); + while (responseCount.load() < expected && std::chrono::system_clock::now() - start < timeout) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + const auto received = responseCount.load(); + if (received != expected) { + FAIL(fmt::format("Expected {} responses, but got {}\n", expected, received)); + } +} +} // namespace + +TEST_CASE("Load test", "[majordomo][majordomoworker][load_test][http2]") { + majordomo::Broker broker("/TestBroker", testSettings()); + majordomo::rest::Settings rest; + rest.port = kServerPort; + auto bound = broker.bindRest(rest); + if (!bound) { + FAIL(fmt::format("Failed to bind REST server: {}", bound.error())); + return; + } + + query::registerTypes(opencmw::majordomo::load_test::Context(), broker); + + majordomo::load_test::Worker worker(broker); + + RunInThread brokerRun(broker); + RunInThread workerRun(worker); + REQUIRE(waitUntilWorkerServiceAvailable(broker.context, worker)); + + constexpr auto kNClients = 10UZ; + constexpr auto kSubscriptions = 10UZ; + constexpr bool kSeparateSubscriptions = true; + constexpr auto kNUpdates = 100UZ; + constexpr auto kIntervalMs = 40UZ; + constexpr auto kInitialDelayMs = 50UZ; + constexpr auto kPayloadSize = 4096UZ; + + std::atomic responseCount = 0; + + std::array, kNClients> clients; + for (std::size_t i = 0; i < clients.size(); i++) { + clients[i] = std::make_unique(client::DefaultContentTypeHeader(MIME::BINARY)); + } + + const auto start = std::chrono::system_clock::now(); + std::array latencies; + + std::atomic responseCountOfi0j0 = 0; + std::array latenciesOfi0j0; + + for (std::size_t i = 0; i < clients.size(); i++) { + for (std::size_t j = 0; j < kSubscriptions; j++) { + client::Command cmd; + cmd.command = mdp::Command::Subscribe; + cmd.serviceName = "/loadTest"; + const auto topic = fmt::format("{}:{}", kSeparateSubscriptions ? i : 0, j); + cmd.topic = URI<>(fmt::format("http://localhost:{}/loadTest?topic={}&intervalMs={}&payloadSize={}&nUpdates={}&initialDelayMs={}", kServerPort, topic, kIntervalMs, kPayloadSize, kNUpdates, kInitialDelayMs)); + cmd.callback = [&responseCount, &responseCountOfi0j0, &latencies, &latenciesOfi0j0, kPayloadSize, i, j](const auto &msg) { + REQUIRE(msg.command == mdp::Command::Notify); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.size() > 0); + const auto index = responseCount.fetch_add(1); + + majordomo::load_test::Payload payload; + try { + IoBuffer buffer{ msg.data }; + opencmw::deserialise(buffer, payload); + REQUIRE(payload.data.size() == kPayloadSize); + } catch (const opencmw::ProtocolException &e) { + FAIL(fmt::format("Failed to deserialise payload: {}", e.what())); + return; + } + const auto now = opencmw::load_test::timestamp().count(); + const auto latency = now - payload.timestampNs; + assert(latency >= 0); + if (latency < 0) { + fmt::print("Negative latency: {} ({} - {})\n", latency, now, payload.timestampNs); + } + latencies[index] = static_cast(latency); + if (i == 0 && j == 0) { + const auto idx = responseCountOfi0j0.fetch_add(1); + latenciesOfi0j0[idx] = static_cast(latency / 1000); + } + }; + clients[i]->request(std::move(cmd)); + } + } + + waitFor(responseCount, kNUpdates * kSubscriptions * kNClients, 30s); + + fmt::println("Received {} responses in {}ms (Net production time: {}ms)", responseCount.load(), std::chrono::duration_cast(std::chrono::system_clock::now() - start).count(), kInitialDelayMs + (kNUpdates * kIntervalMs)); + // TODO maybe print more detailed distribution + std::uint64_t sumLatency = 0; + for (const auto &latency : latencies) { + sumLatency += latency; + } + + const auto averageLatency = static_cast(sumLatency) / static_cast(responseCount.load()); + fmt::println("Average latency: {}µs", averageLatency / 1000.0); + fmt::println("{}", fmt::join(latenciesOfi0j0, ", ")); + // TODO compute drift over time +} diff --git a/src/majordomo/test/majordomoworker_rest_tests.cpp b/src/majordomo/test/majordomoworker_rest_tests.cpp index 7a1e9381..f353cfcb 100644 --- a/src/majordomo/test/majordomoworker_rest_tests.cpp +++ b/src/majordomo/test/majordomoworker_rest_tests.cpp @@ -1,5 +1,4 @@ #include -#include #include #include @@ -13,55 +12,13 @@ #include #include -#include #include #include -#include // Concepts and tests use common types #include -std::jthread makeGetRequestResponseCheckerThread(const std::string &address, const std::vector &requiredResponses, const std::vector &requiredStatusCodes = {}, [[maybe_unused]] std::source_location location = std::source_location::current()) { - return std::jthread([=] { - httplib::Client http("localhost", majordomo::DEFAULT_REST_PORT); - http.set_follow_location(true); - http.set_keep_alive(true); -#define requireWithSource(arg) \ - if (!(arg)) opencmw::zmq::debug::withLocation(location) << "<- call got a failed requirement:"; \ - REQUIRE(arg) - for (std::size_t i = 0; i < requiredResponses.size(); ++i) { - const auto response = http.Get(address); - requireWithSource(response); - const auto requiredStatusCode = i < requiredStatusCodes.size() ? requiredStatusCodes[i] : 200; - requireWithSource(response->status == requiredStatusCode); - requireWithSource(response->body.find(requiredResponses[i]) != std::string::npos); - } -#undef requireWithSource - }); -} - -std::jthread makeLongPollingRequestResponseCheckerThread(const std::string &address, const std::vector &requiredResponses, const std::vector &requiredStatusCodes = {}, [[maybe_unused]] std::source_location location = std::source_location::current()) { - return std::jthread([=] { - httplib::Client http("localhost", majordomo::DEFAULT_REST_PORT); - http.set_follow_location(true); - http.set_keep_alive(true); -#define requireWithSource(arg) \ - if (!(arg)) opencmw::zmq::debug::withLocation(location) << "<- call got a failed requirement:"; \ - REQUIRE(arg) - for (std::size_t i = 0; i < requiredResponses.size(); ++i) { - const std::string url = fmt::format("{}{}LongPollingIdx={}", address, address.contains('?') ? "&" : "?", i == 0 ? "Next" : fmt::format("{}", i)); - const auto response = http.Get(url); - if (i == 0) { // check forwarding to the explicit index - REQUIRE(response->location.find("LongPollingIdx=0") != std::string::npos); - } - requireWithSource(response); - const auto requiredStatusCode = i < requiredStatusCodes.size() ? requiredStatusCodes[i] : 200; - requireWithSource(response->status == requiredStatusCode); - requireWithSource(response->body.find(requiredResponses[i]) != std::string::npos); - } -#undef requireWithSource - }); -} +constexpr std::uint16_t kServerPort = 12348; struct ColorContext { bool red = false; @@ -87,6 +44,11 @@ class ColorWorker : public majordomo::Worker explicit ColorWorker(const BrokerType &broker, std::vector notificationContexts) : super_t(broker, {}) { + super_t::setCallback([](majordomo::RequestContext & /*rawCtx*/, const ColorContext &inCtx, const majordomo::Empty &, ColorContext &outCtx, SingleString &out) { + outCtx = inCtx; + out.value = fmt::format("red={}, green={}, blue={}\n", inCtx.red, inCtx.green, inCtx.blue); + FAIL(fmt::format("Unexpected GET/SET request: {}", out.value)); + }); notifyThread = std::jthread([this, contexts = std::move(notificationContexts)]() { int counter = 0; for (const auto &context : contexts) { @@ -133,10 +95,10 @@ struct WaitingContext { ENABLE_REFLECTION_FOR(WaitingContext, timeoutMs, contentType) struct UpdateTime { - long updateTime; + long updateTimeµs; std::vector payload; }; -ENABLE_REFLECTION_FOR(UpdateTime, updateTime, payload) +ENABLE_REFLECTION_FOR(UpdateTime, updateTimeµs, payload) template class WaitingWorker : public majordomo::Worker { @@ -168,11 +130,11 @@ class ClockWorker : public majordomo::Worker 0 && !_shutdownRequested) { std::this_thread::sleep_until(updateTime); fmt::print("publishing update\n"); - UpdateTime update{std::chrono::duration_cast(updateTime.time_since_epoch()).count(), std::views::iota(0, payloadSize) | std::ranges::to() }; + UpdateTime update{ std::chrono::duration_cast(updateTime.time_since_epoch()).count(), std::views::iota(0, payloadSize) | std::ranges::to() }; this->notify(SimpleContext(), update); updateTime += _period; _nUpdates--; @@ -187,12 +149,31 @@ class ClockWorker : public majordomo::Worker +void waitFor(std::atomic &responseCount, T expected, std::chrono::milliseconds timeout = std::chrono::seconds(5)) { + const auto start = std::chrono::system_clock::now(); + while (responseCount.load() < expected && std::chrono::system_clock::now() - start < timeout) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + const auto result = responseCount.load() == expected; + if (!result) { + FAIL(fmt::format("Expected {} responses, but got {}\n", expected, responseCount.load())); + } +} +} // namespace + +TEST_CASE("Simple REST example", "[majordomo][majordomoworker][simple_example][http2]") { // We run both broker and worker inproc - majordomo::Broker broker("/TestBroker", testSettings()); - auto fs = cmrc::assets::get_filesystem(); - FileServerRestBackend rest(broker, fs); - RunInThread restServerRun(rest); + majordomo::Broker broker("/TestBroker", testSettings()); + majordomo::rest::Settings rest; + rest.port = kServerPort; + rest.handlers = { majordomo::rest::cmrcHandler("/assets/*", "", std::make_shared(cmrc::assets::get_filesystem()), "") }; + + if (auto bound = broker.bindRest(rest); !bound) { + FAIL(fmt::format("Failed to bind REST server: {}", bound.error())); + return; + } // For subscription matching, it is necessary that broker knows how to handle the query params "ctx" and "contentType". // ("ctx" needs to use the TimingCtxFilter, and "contentType" compare the mime types (currently simply a string comparison)) @@ -205,59 +186,143 @@ TEST_CASE("Simple MajordomoWorker example showing its usage", "[majordomo][major // Create MajordomoWorker with our domain objects, and our TestHandler. majordomo::Worker<"/addressbook", SimpleContext, AddressRequest, AddressEntry> worker(broker, TestAddressHandler()); - // Run worker and broker in separate threads - RunInThread brokerRun(broker); - RunInThread workerRun(worker); - + RunInThread brokerRun(broker); + RunInThread workerRun(worker); REQUIRE(waitUntilWorkerServiceAvailable(broker.context, worker)); - SECTION("request Address information as JSON and as HTML") { - auto httpThreadJSON = makeGetRequestResponseCheckerThread("/addressbook?ctx=FAIR.SELECTOR.ALL&contentType=application%2Fjavascript", { "Santa Claus" }); - - auto httpThreadHTML = makeGetRequestResponseCheckerThread("/addressbook?ctx=FAIR.SELECTOR.ALL&contentType=text%2Fhtml", { "Elf Road" }); - } - - SECTION("post data") { - httplib::Client postData{ "http://localhost:8080" }; - postData.Post("/addressbook?ctx=FAIR.SELECTOR.ALL&contentType=application%json", "{\"streetNumber\": 1882}", "application/json"); - - auto httpThreadJSON = makeGetRequestResponseCheckerThread("/addressbook?ctx=FAIR.SELECTOR.ALL&contentType=application%json", { "1882" }); - } - - SECTION("post data as multipart") { - std::jthread putRequestThread{ - [] { - // set a value on the server - httplib::Client postRequest("http://localhost:8080"); - postRequest.set_keep_alive(true); - - httplib::MultipartFormDataItems items{ - { "name", "Kalle", "name_file", "text" }, - { "street", "calle", "street_file", "text" }, - // { "streetNumber", "8", "number_file", "number" }, // `error(22) parsing number at buffer position: 41"` , deserialiser finds "8" instead of 8 - { "postalCode", "14005", "postal_code_file", "text" }, - { "city", "ciudad", "city_file", "text" } - // "isCurrent", "true", "is_current_file", "text" }, // does not work because true will be quoted. which is not a valid boolean - //{ "isCurrent", "false", "is_current_file", "boolean" }, // content type boolean might not exist, anyway, content_type is not taken into account anyway - }; - - auto r = postRequest.Put("/addressbook?ctx=FAIR.SELECTOR.ALL&contentType=application%2Fjavascript", items); - - REQUIRE(r); - CAPTURE(r->reason); - CAPTURE(r->body); - REQUIRE(r->status == 200); - - auto httpThreadJSON = makeGetRequestResponseCheckerThread("/addressbook?ctx=FAIR.SELECTOR.ALL&contentType=application%2Fjavascript", { "Kalle" }); - } - }; - } + std::atomic responseCount = 0; + + opencmw::client::RestClient client; + + // Invalid port, assuming there's nobody listening on port 44444 + opencmw::client::Command cannotReach; + cannotReach.command = mdp::Command::Get; + cannotReach.topic = opencmw::URI<>("http://localhost:44444/"); + cannotReach.callback = [&responseCount](const auto &msg) { + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.error == "Connection refused"); + REQUIRE(msg.data.asString() == ""); + responseCount++; + }; + client.request(std::move(cannotReach)); + + auto serviceListCallback = [&responseCount](const auto &msg) { + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString() == "/addressbook,/mmi.dns,/mmi.echo,/mmi.openapi,/mmi.service"); + responseCount++; + }; + + // path "" returns service list + opencmw::client::Command serviceList; + serviceList.command = mdp::Command::Get; + serviceList.topic = opencmw::URI<>(fmt::format("http://localhost:{}", kServerPort)); + serviceList.callback = serviceListCallback; + client.request(std::move(serviceList)); + + // path "/" returns service list + opencmw::client::Command serviceList2; + serviceList2.command = mdp::Command::Get; + serviceList2.topic = opencmw::URI<>(fmt::format("http://localhost:{}/", kServerPort)); + serviceList2.callback = serviceListCallback; + client.request(std::move(serviceList2)); + + // Asset file + opencmw::client::Command getAsset; + getAsset.command = mdp::Command::Get; + getAsset.topic = opencmw::URI<>(fmt::format("http://localhost:{}/assets/main.css", kServerPort)); + getAsset.callback = [&responseCount](const auto &msg) { + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString().contains("body {")); + responseCount++; + }; + client.request(std::move(getAsset)); + + // Asset file that does not exist + opencmw::client::Command getAsset404; + getAsset404.command = mdp::Command::Get; + getAsset404.topic = opencmw::URI<>(fmt::format("http://localhost:{}/assets/does-not-exist", kServerPort)); + getAsset404.callback = [&responseCount](const auto &msg) { + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.error == "Not found"); + REQUIRE(msg.data.asString() == ""); + responseCount++; + }; + client.request(std::move(getAsset404)); + + waitFor(responseCount, 5); + responseCount = 0; + + constexpr int kExpectedResponses = 4; + + // The following requests are all to the same service and thus must be processed in order + + // Get worker data as JSON + opencmw::client::Command getJson; + getJson.command = mdp::Command::Get; + getJson.topic = opencmw::URI<>(fmt::format("http://localhost:{}/addressbook?ctx=FAIR.SELECTOR.ALL&contentType=application%2Fjson", kServerPort)); + getJson.callback = [&responseCount](const auto &msg) { + REQUIRE(responseCount == 0); + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString().contains("\"name\": \"Santa Claus\"")); + REQUIRE(msg.topic == opencmw::URI<>("/addressbook?contentType=application%2Fjson&testFilter=&ctx=FAIR.SELECTOR.ALL")); + responseCount++; + }; + client.request(std::move(getJson)); + + // Get worker data as HTML + opencmw::client::Command getHtml; + getHtml.command = mdp::Command::Get; + getHtml.topic = opencmw::URI<>(fmt::format("http://localhost:{}/addressbook?ctx=FAIR.SELECTOR.ALL&contentType=text%2Fhtml", kServerPort)); + getHtml.callback = [&responseCount](const auto &msg) { + REQUIRE(responseCount == 1); + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString().contains("Elf Road")); + REQUIRE(msg.topic == opencmw::URI<>("/addressbook?contentType=text%2Fhtml&testFilter=&ctx=FAIR.SELECTOR.ALL")); + responseCount++; + }; + client.request(std::move(getHtml)); + + // Set worker data + opencmw::client::Command postJson; + postJson.command = mdp::Command::Set; + postJson.topic = opencmw::URI<>(fmt::format("http://localhost:{}/addressbook?ctx=FAIR.SELECTOR.ALL&contentType=application%2Fjson", kServerPort)); + postJson.data = opencmw::IoBuffer("{\"streetNumber\": 1882}"); + postJson.callback = [&responseCount](const auto &msg) { + REQUIRE(responseCount == 2); + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.error == ""); + REQUIRE(msg.topic == opencmw::URI<>("/addressbook?contentType=application%2Fjson&testFilter=&ctx=FAIR.SELECTOR.ALL")); + responseCount++; + }; + client.request(std::move(postJson)); + + // Test that the Set call is correctly applied + opencmw::client::Command getJsonAfterSet; + getJsonAfterSet.command = mdp::Command::Get; + getJsonAfterSet.topic = opencmw::URI<>(fmt::format("http://localhost:{}/addressbook?ctx=FAIR.SELECTOR.ALL&contentType=application%2Fjson", kServerPort)); + getJsonAfterSet.callback = [&responseCount](const auto &msg) { + REQUIRE(responseCount == 3); + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString().contains("\"streetNumber\": 1882")); + REQUIRE(msg.topic == opencmw::URI<>("/addressbook?contentType=application%2Fjson&testFilter=&ctx=FAIR.SELECTOR.ALL")); + responseCount++; + }; + client.request(std::move(getJsonAfterSet)); + + waitFor(responseCount, kExpectedResponses); } + TEST_CASE("Invalid paths", "[majordomo][majordomoworker][rest]") { majordomo::Broker broker("/TestBroker", testSettings()); - auto fs = cmrc::assets::get_filesystem(); - FileServerRestBackend rest(broker, fs); - RunInThread restServerRun(rest); + majordomo::rest::Settings rest; + rest.port = kServerPort; + auto bound = broker.bindRest(rest); + REQUIRE(bound); opencmw::query::registerTypes(PathContext(), broker); @@ -268,16 +333,47 @@ TEST_CASE("Invalid paths", "[majordomo][majordomoworker][rest]") { REQUIRE(waitUntilWorkerServiceAvailable(broker.context, worker)); - auto space = makeGetRequestResponseCheckerThread("/paths/with%20space", { "Invalid service name" }, { 500 }); - auto invalidSubscription = makeGetRequestResponseCheckerThread("/p-a-t-h-s/?LongPollIdx=Next", { "Invalid service name" }, { 500 }); + constexpr int kExpectedResponses = 1; + std::atomic responseCount = 0; + + opencmw::client::RestClient client; + +#if 0 // TODO Currently URI<>() throws on %20 in the path (which is a valid URI but not a valid service name) + opencmw::client::Command getSpace; + getSpace.command = mdp::Command::Get; + getSpace.topic = opencmw::URI<>(fmt::format("http://localhost:{}/paths/with%20space", kServerPort)); + getSpace.callback = [&responseCount](const auto &msg) { + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.error == "Invalid service name"); + REQUIRE(msg.data.empty()); + responseCount++; + }; + client.request(std::move(getSpace)); +#endif + opencmw::client::Command invalidSubscription; + invalidSubscription.command = mdp::Command::Get; + invalidSubscription.topic = opencmw::URI<>(fmt::format("http://localhost:{}//p-a-t-h-s/?LongPollIdx=Next", kServerPort)); + invalidSubscription.callback = [&responseCount](const auto &msg) { + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.error == "Invalid service name '//p-a-t-h-s'"); + REQUIRE(msg.data.empty()); + responseCount++; + }; + + client.request(std::move(invalidSubscription)); + + waitFor(responseCount, kExpectedResponses); } TEST_CASE("Get/Set with subpaths", "[majordomo][majordomoworker][rest]") { - majordomo::Broker broker("/TestBroker", testSettings()); - auto fs = cmrc::assets::get_filesystem(); - FileServerRestBackend rest(broker, fs); - RunInThread restServerRun(rest); - + majordomo::Broker broker("/TestBroker", testSettings()); + majordomo::rest::Settings rest; + rest.port = kServerPort; + auto bound = broker.bindRest(rest); + if (!bound) { + FAIL(fmt::format("Failed to bind REST server: {}", bound.error())); + return; + } opencmw::query::registerTypes(PathContext(), broker); PathWorker<"/paths"> worker(broker); @@ -287,17 +383,59 @@ TEST_CASE("Get/Set with subpaths", "[majordomo][majordomoworker][rest]") { REQUIRE(waitUntilWorkerServiceAvailable(broker.context, worker)); - auto empty = makeGetRequestResponseCheckerThread("/paths", { "path=''" }); - auto one = makeGetRequestResponseCheckerThread("/paths/a", { "path='\\/a'" }); - auto two = makeGetRequestResponseCheckerThread("/paths/a/b", { "path='\\/a\\/b'" }); + constexpr int kExpectedResponses = 3; + std::atomic responseCount = 0; + + opencmw::client::RestClient client; + + opencmw::client::Command getEmpty; + getEmpty.command = mdp::Command::Get; + getEmpty.topic = opencmw::URI<>(fmt::format("http://localhost:{}/paths", kServerPort)); + getEmpty.callback = [&responseCount](const auto &msg) { + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString().contains("path=''")); + REQUIRE(msg.topic == opencmw::URI<>("/paths?contentType=application%2Fjson")); + responseCount++; + }; + client.request(std::move(getEmpty)); + + opencmw::client::Command getOne; + getOne.command = mdp::Command::Get; + getOne.topic = opencmw::URI<>(fmt::format("http://localhost:{}/paths/a", kServerPort)); + getOne.callback = [&responseCount](const auto &msg) { + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString().contains("path='\\/a'")); + REQUIRE(msg.topic == opencmw::URI<>("/paths/a?contentType=application%2Fjson")); + responseCount++; + }; + client.request(std::move(getOne)); + + opencmw::client::Command getTwo; + getTwo.command = mdp::Command::Get; + getTwo.topic = opencmw::URI<>(fmt::format("http://localhost:{}/paths/a/b", kServerPort)); + getTwo.callback = [&responseCount](const auto &msg) { + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString().contains("path='\\/a\\/b'")); + REQUIRE(msg.topic == opencmw::URI<>("/paths/a/b?contentType=application%2Fjson")); + responseCount++; + }; + client.request(std::move(getTwo)); + + waitFor(responseCount, kExpectedResponses); } TEST_CASE("Subscriptions", "[majordomo][majordomoworker][subscription]") { - majordomo::Broker broker("/TestBroker", testSettings()); - auto fs = cmrc::assets::get_filesystem(); - FileServerRestBackend rest(broker, fs); - RunInThread restServerRun(rest); - + majordomo::Broker broker("/TestBroker", testSettings()); + majordomo::rest::Settings rest; + rest.port = kServerPort; + const auto bound = broker.bindRest(rest); + if (!bound) { + FAIL(fmt::format("Failed to bind REST server: {}", bound.error())); + return; + } opencmw::query::registerTypes(ColorContext(), broker); constexpr auto red = ColorContext{ .red = true }; @@ -315,14 +453,93 @@ TEST_CASE("Subscriptions", "[majordomo][majordomoworker][subscription]") { REQUIRE(waitUntilWorkerServiceAvailable(broker.context, worker)); - auto allListener = makeLongPollingRequestResponseCheckerThread("/colors", { "0", "1", "2", "3", "4", "5", "6" }); - auto redListener = makeLongPollingRequestResponseCheckerThread("/colors?red", { "0", "3", "4", "6" }); - auto yellowListener = makeLongPollingRequestResponseCheckerThread("/colors?red&green", { "4", "6" }); - auto whiteListener1 = makeLongPollingRequestResponseCheckerThread("/colors?red&green&blue", { "6" }); - auto whiteListener2 = makeLongPollingRequestResponseCheckerThread("/colors?green&red&blue", { "6" }); - auto whiteListener3 = makeLongPollingRequestResponseCheckerThread("/colors?blue&green&red", { "6" }); - - std::this_thread::sleep_for(50ms); // give time for subscriptions to happen + opencmw::client::RestClient client; + + constexpr auto allExpected = std::array{ "0", "1", "2", "3", "4", "5", "6" }; + constexpr auto redExpected = std::array{ "0", "3", "4", "6" }; + constexpr auto yellowExpected = std::array{ "4", "6" }; + constexpr auto whiteExpected = std::array{ "6" }; + + std::atomic allReceived = 0; + std::atomic redReceived = 0; + std::atomic yellowReceived = 0; + std::atomic whiteReceived1 = 0; + std::atomic whiteReceived2 = 0; + std::atomic whiteReceived3 = 0; + + opencmw::client::Command allSub; + allSub.command = mdp::Command::Subscribe; + allSub.topic = opencmw::URI<>(fmt::format("http://localhost:{}/colors", kServerPort)); + allSub.callback = [&allReceived, &allExpected](const auto &msg) { + REQUIRE(msg.command == mdp::Command::Notify); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString().contains(allExpected[allReceived])); + allReceived++; + }; + client.request(std::move(allSub)); + + opencmw::client::Command redSub; + redSub.command = mdp::Command::Subscribe; + redSub.topic = opencmw::URI<>(fmt::format("http://localhost:{}/colors?red", kServerPort)); + redSub.callback = [&redReceived, &redExpected](const auto &msg) { + REQUIRE(msg.command == mdp::Command::Notify); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString().contains(redExpected[redReceived])); + redReceived++; + }; + client.request(std::move(redSub)); + + opencmw::client::Command yellowSub; + yellowSub.command = mdp::Command::Subscribe; + yellowSub.topic = opencmw::URI<>(fmt::format("http://localhost:{}/colors?red&green", kServerPort)); + yellowSub.callback = [&yellowReceived, &yellowExpected](const auto &msg) { + REQUIRE(msg.command == mdp::Command::Notify); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString().contains(yellowExpected[yellowReceived])); + yellowReceived++; + }; + client.request(std::move(yellowSub)); + + opencmw::client::Command whiteSub1; + whiteSub1.command = mdp::Command::Subscribe; + whiteSub1.topic = opencmw::URI<>(fmt::format("http://localhost:{}/colors?red&green&blue", kServerPort)); + whiteSub1.callback = [&whiteReceived1, &whiteExpected](const auto &msg) { + REQUIRE(msg.command == mdp::Command::Notify); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString().contains(whiteExpected[whiteReceived1])); + whiteReceived1++; + }; + client.request(std::move(whiteSub1)); + + opencmw::client::Command whiteSub2; + whiteSub2.command = mdp::Command::Subscribe; + whiteSub2.topic = opencmw::URI<>(fmt::format("http://localhost:{}/colors?green&red&blue", kServerPort)); + whiteSub2.callback = [&whiteReceived2, &whiteExpected](const auto &msg) { + REQUIRE(msg.command == mdp::Command::Notify); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString().contains(whiteExpected[whiteReceived2])); + whiteReceived2++; + }; + client.request(std::move(whiteSub2)); + + opencmw::client::Command whiteSub3; + whiteSub3.command = mdp::Command::Subscribe; + whiteSub3.topic = opencmw::URI<>(fmt::format("http://localhost:{}/colors?blue&green&red", kServerPort)); + whiteSub3.callback = [&whiteReceived3, &whiteExpected](const auto &msg) { + REQUIRE(msg.command == mdp::Command::Notify); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString().contains(whiteExpected[whiteReceived3])); + whiteReceived3++; + }; + + client.request(std::move(whiteSub3)); + + waitFor(allReceived, allExpected.size()); + waitFor(redReceived, redExpected.size()); + waitFor(yellowReceived, yellowExpected.size()); + waitFor(whiteReceived1, whiteExpected.size()); + waitFor(whiteReceived2, whiteExpected.size()); + waitFor(whiteReceived3, whiteExpected.size()); std::vector subscriptions; for (const auto &subscription : worker.activeSubscriptions()) { @@ -332,40 +549,199 @@ TEST_CASE("Subscriptions", "[majordomo][majordomoworker][subscription]") { REQUIRE(subscriptions == std::vector{ "/colors#", "/colors?blue&green&red#", "/colors?green&red#", "/colors?red#" }); } +TEST_CASE("Handler matching", "[majordomo][rest]") { + majordomo::Broker broker("/TestBroker", testSettings()); + majordomo::rest::Settings rest; + rest.port = kServerPort; + + auto echoHandler = [](std::string method, std::string path) { + return majordomo::rest::Handler{ + .method = method, + .path = path, + .handler = [path, method](const auto &) { + majordomo::rest::Response response; + response.code = 200; + response.body.put(fmt::format("{}|{}", method, path)); + return response; + } + }; + }; + + // Cannot test POST because Command::Set is also using "GET" + rest.handlers = { + echoHandler("GET", "/"), + echoHandler("GET", "/assets/*"), + echoHandler("GET", "/assets/subresource") + }; + + auto bound = broker.bindRest(rest); + if (!bound) { + FAIL(fmt::format("Failed to bind REST server: {}", bound.error())); + return; + } + + RunInThread brokerRun(broker); + + opencmw::client::RestClient client; + + std::atomic responseCount = 0; + + opencmw::client::Command getRoot; + getRoot.command = mdp::Command::Get; + getRoot.topic = opencmw::URI<>(fmt::format("http://localhost:{}/", kServerPort)); + getRoot.callback = [&responseCount](const auto &msg) { + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString() == "GET|/"); + responseCount++; + }; + client.request(std::move(getRoot)); + + opencmw::client::Command getAssets; + getAssets.command = mdp::Command::Get; + getAssets.topic = opencmw::URI<>(fmt::format("http://localhost:{}/assets/test.txt", kServerPort)); + getAssets.callback = [&responseCount](const auto &msg) { + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString() == "GET|/assets/*"); + responseCount++; + }; + client.request(std::move(getAssets)); + + opencmw::client::Command getSubresource; + getSubresource.command = mdp::Command::Get; + getSubresource.topic = opencmw::URI<>(fmt::format("http://localhost:{}/assets/subresource", kServerPort)); + getSubresource.callback = [&responseCount](const auto &msg) { + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString() == "GET|/assets/subresource"); + responseCount++; + }; + client.request(std::move(getSubresource)); + + opencmw::client::Command getSubresource2; + getSubresource2.command = mdp::Command::Get; + getSubresource2.topic = opencmw::URI<>(fmt::format("http://localhost:{}/assets/subresource/extra", kServerPort)); + getSubresource2.callback = [&responseCount](const auto &msg) { + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString() == "GET|/assets/*"); + responseCount++; + }; + client.request(std::move(getSubresource2)); + waitFor(responseCount, 4); +} + +TEST_CASE("File system handler", "[majordomo][majordomoworker][rest]") { + majordomo::Broker broker("/TestBroker", testSettings()); + majordomo::rest::Settings rest; + rest.port = kServerPort; + rest.handlers = { majordomo::rest::fileSystemHandler("/files/*", "/files/", std::filesystem::current_path()) }; + auto bound = broker.bindRest(rest); + if (!bound) { + FAIL(fmt::format("Failed to bind REST server: {}", bound.error())); + return; + } + + constexpr auto kExpectedFileContent = R"(-----BEGIN CERTIFICATE----- +MIIFiTCCA3GgAwIBAgIUDBxaxLvthSz4Knvh6R0/zDrxe3QwDQYJKoZIhvcNAQEL +BQAwVDELMAkGA1UEBhMCREUxEDAOBgNVBAgMB1Vua25vd24xEDAOBgNVBAcMB1Vu +a25vd24xITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0yMTEy +MDUxMDA0MjFaFw0zMTEyMDMxMDA0MjFaMFQxCzAJBgNVBAYTAkRFMRAwDgYDVQQI +DAdVbmtub3duMRAwDgYDVQQHDAdVbmtub3duMSEwHwYDVQQKDBhJbnRlcm5ldCBX +aWRnaXRzIFB0eSBMdGQwggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAwggIKAoICAQDx +1Es3W/5OyMdOPUmCqQuAYV/4xjhP8Fhoi1gjnnlaCrBbBXJl1nW8DdMwVhXF9Yy8 +wPP0SylqkbatiDnUwjviizL6v6DMNQbS+OES5OleCuwbCWAFH3vsDllRZl3LYAdB +6Ec4wNjX5EjE1RgIgT+GEkR13XqyHQi4ELOMEUxxpVcWeBjAFhgiTXvbpnBfBfJo +XPsMoCaWTWhRQksodKM4Mjfn/wxKAfbspNaX5zfPcr/5vNGY2CYiKbsZqwiM5VNq +ml9XoUc5BIWuj4liHerLOqdEj3zpBhn3i+RimGm+N2xDAXOKMlP4w5jyewd5FtLl +vmyLEakiqwSCODkjrP7rbmQ9hRohsF8V5Y4KwaYYEp+pZ4BFCBYdDv+drnVD8MOJ +/7f8LlCE3xN/MvEX2Um1xt5oT4gb3SdSTRZfUEFzkrQK5wodQBUZVhVTg0NRwiWx +hfeILR6/Qd8pciwSXbkT7JBEf1gyghKFDd+GfyDmm/+aIxNAItQ10KUDeLyavP3x +SN9ZR4zITDT9dwzd/yqsKudFB2mqioUv8zXah7lRpQDQmsrwkj4sKOrfS9BxtM8T +tUn56aOeyl7zdWvhI+wJOT07f7l7+c3WW06v9jPXk0IDjhvUdn11uL6aFQiO6cBE +4Jj2z2l4VHsLekcPdXzt6V1IhJuyGWSasOZBwb13tQIDAQABo1MwUTAdBgNVHQ4E +FgQUwA8sghflexohahmoZAFHcpgnJl0wHwYDVR0jBBgwFoAUwA8sghflexohahmo +ZAFHcpgnJl0wDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAgEAXlVh +oCF+7ywVU1ek2wcPIDayqAUA2JnPGvpjClSqqOfAvst07RspDftxTQ0/BVy7N1Ep +DX9uQP1YiuvZDxc4ySCYUnYiXmf7aeyjSGIwVHhWRx/uQQV880tJ+TIK+JJRLiJg +DLnixdyW/uPY0RkjUzHCADI1zZmrJyZpMAFbqwuoVOtEoMmCJkjJRrZzytRQkTbe +9cUAvHcdjbFADf1yZdTELgNTs9Xolb+aZhXb+DolrQiTpDQj9RSt/raaRyrlssFD +9V0ugW2e9nEEe7PlofDGOYqhaadka9s680xT47s6K8WjPVFwgDVKKojW186JCeAR +8W8AEd1tOmp8BQOY7OLY8hZ9kTnnd8XoLy+l8UH03kfIIPAulGARNCYYo7SJfQU7 +1rQEf28BGi9mL0QhkY0xSSTvuLVbG5DAceUVix6Y+NsTgF+YCWsqXMZX6M/qHjiy +7qf5LeKrLqG1RQsYbi4UzakKfwG2uVH/lETc/j1PMZz4WPPJGeeg+urhHvzvZJ5C +xbuJH3fc9TsTET4FiwQwINSwdY4i/iNCt3kE+6EYNf1G13f3ffCMy3JYxakXjDZB +ido85zpB/BELE69ap2g/pHCNEd9y2ZHCYDvIZaxnJtdBBhWmHqAK00HYEwG5LZPU +I0uPpRG4pT+BvLUZo8rKVpFLHj2nnflS5dXqKPo= +-----END CERTIFICATE----- +)"; + + RunInThread brokerRun(broker); + + std::atomic responseCount = 0; + opencmw::client::RestClient client; + + opencmw::client::Command fileExists; + fileExists.command = mdp::Command::Get; + fileExists.topic = opencmw::URI<>(fmt::format("http://localhost:{}/files/demo_public.crt", kServerPort)); + fileExists.callback = [&responseCount, &kExpectedFileContent](const auto &msg) { + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString() == kExpectedFileContent); + responseCount++; + }; + client.request(std::move(fileExists)); + + opencmw::client::Command fileDoesNotExist; + fileDoesNotExist.command = mdp::Command::Get; + fileDoesNotExist.topic = opencmw::URI<>(fmt::format("http://localhost:{}/files/does_not_exist", kServerPort)); + fileDoesNotExist.callback = [&responseCount](const auto &msg) { + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.error == "Not found"); + REQUIRE(msg.data.asString() == ""); + responseCount++; + }; + client.request(std::move(fileDoesNotExist)); + + waitFor(responseCount, 2); +} + TEST_CASE("Subscription latencies", "[majordomo][majordomoworker][rest]") { std::atomic nReceived = 0; std::atomic msLatency = 0; { majordomo::Broker broker("/TestBroker", testSettings()); - auto fs = cmrc::assets::get_filesystem(); - majordomo::RestBackend rest(broker, fs, opencmw::URI<>("http://localhost:12346")); - + majordomo::rest::Settings rest; + rest.port = kServerPort; + auto bound = broker.bindRest(rest); + if (!bound) { + FAIL(fmt::format("Failed to bind REST server: {}", bound.error())); + return; + } ClockWorker<"/clock", 2550> worker(broker, 10ms, 70); - RunInThread restServerRun(rest); RunInThread brokerRun(broker); RunInThread workerRun(worker); REQUIRE(waitUntilWorkerServiceAvailable(broker.context, worker)); - rest.setMajordomoTimeout(800ms); // set timeout to unit-test friendly interval + opencmw::client::RestClient client; - opencmw::client::RestClient client{std::string("RestSubLatencyClient")}; - - opencmw::client::Command _command; - _command.command = opencmw::mdp::Command::Subscribe; - _command.topic = opencmw::URI<>("http://localhost:12346/clock"); - _command.callback = [&nReceived, &msLatency](const opencmw::mdp::Message &reply) { - UpdateTime replyData; + opencmw::client::Command command; + command.command = opencmw::mdp::Command::Subscribe; + command.topic = opencmw::URI<>(fmt::format("http://localhost:{}/clock", kServerPort)); + command.callback = [&nReceived, &msLatency](const opencmw::mdp::Message &reply) { + UpdateTime replyData; opencmw::IoBuffer buffer = reply.data; opencmw::deserialise(buffer, replyData); - auto now = std::chrono::system_clock::now(); - auto latency = now.time_since_epoch() - std::chrono::milliseconds(replyData.updateTime); + auto now = std::chrono::high_resolution_clock::now(); + auto latency = now.time_since_epoch() - std::chrono::microseconds(replyData.updateTimeµs); nReceived++; nReceived.notify_all(); - msLatency.fetch_add(static_cast(std::chrono::duration_cast(latency).count())); - fmt::print("Received {}th update with a latency of {} ms.\n", nReceived.load(), std::chrono::duration_cast(latency).count()); + msLatency.fetch_add(static_cast(std::chrono::duration_cast(latency).count())); + fmt::print("Received {}th update with a latency of {} µs.\n", nReceived.load(), std::chrono::duration_cast(latency).count()); }; - client.request(_command); + client.request(std::move(command)); fmt::print("waiting for 40 samples to be received\n"); int n = nReceived; @@ -375,55 +751,7 @@ TEST_CASE("Subscription latencies", "[majordomo][majordomoworker][rest]") { } } - fmt::print("Received {} updates with an average latency of {} ms.\n", nReceived.load(), nReceived > 0 ? static_cast(msLatency)/nReceived : 0.0); + fmt::print("Received {} updates with an average latency of {} µs.\n", nReceived.load(), nReceived > 0 ? static_cast(msLatency) / nReceived : 0.0); REQUIRE(nReceived > 10); - REQUIRE(static_cast(msLatency)/nReceived < 20); -} - -TEST_CASE("Majordomo timeouts", "[majordomo][majordomoworker][rest]") { - majordomo::Broker broker("/TestBroker", testSettings()); - auto fs = cmrc::assets::get_filesystem(); - FileServerRestBackend rest(broker, fs); - RunInThread restServerRun(rest); - - opencmw::query::registerTypes(WaitingContext(), broker); - - WaitingWorker<"/waiter"> worker(broker); - - RunInThread brokerRun(broker); - RunInThread workerRun(worker); - - REQUIRE(waitUntilWorkerServiceAvailable(broker.context, worker)); - - // set timeout to unit-test friendly interval - rest.setMajordomoTimeout(800ms); - - SECTION("Waiting for notification that doesn't happen in time returns 504 message") { - std::vector clientThreads; - for (int i = 0; i < 16; ++i) { - clientThreads.push_back(makeGetRequestResponseCheckerThread("/waiter?LongPollingIdx=Next", { "Timeout" }, { 504 })); - } - } - - SECTION("Waiting for notification that happens in time gives expected response") { - auto client = makeGetRequestResponseCheckerThread("/waiter?LongPollingIdx=Next", { "This is a notification" }); - std::this_thread::sleep_for(400ms); - worker.notify({}, { "This is a notification" }); - } - - SECTION("Response to request takes too long, timeout status is returned") { - httplib::Client postData{ "http://localhost:8080" }; - auto reply = postData.Post("/waiter?contentType=application%2Fjson&timeoutMs=1200", "{\"value\": \"Hello!\"}", "application/json"); - REQUIRE(reply); - REQUIRE(reply->status == 504); - REQUIRE(reply->body.find("No response") != std::string::npos); - } - - SECTION("Response to request arrives in time") { - httplib::Client postData{ "http://localhost:8080" }; - auto reply = postData.Post("/waiter?contentType=application%2Fjson&timeoutMs=0", "{\"value\": \"Hello!\"}", "application/json"); - REQUIRE(reply); - REQUIRE(reply->status == 200); - REQUIRE(reply->body.find("You said: Hello!") != std::string::npos); - } + REQUIRE(static_cast(msLatency) / nReceived < 20000); // unit is µs } diff --git a/src/nghttp2/CMakeLists.txt b/src/nghttp2/CMakeLists.txt new file mode 100644 index 00000000..501c821d --- /dev/null +++ b/src/nghttp2/CMakeLists.txt @@ -0,0 +1,16 @@ +# setup header only library +add_library(nghttp2 INTERFACE + include/nghttp2/NgHttp2Utils.hpp +) + +target_include_directories(nghttp2 INTERFACE $ $) +target_link_libraries(nghttp2 + INTERFACE + nghttp2-static + ) + +install( + TARGETS nghttp2 + EXPORT opencmwTargets + PUBLIC_HEADER DESTINATION include/opencmw +) diff --git a/src/nghttp2/include/nghttp2/NgHttp2Utils.hpp b/src/nghttp2/include/nghttp2/NgHttp2Utils.hpp new file mode 100644 index 00000000..c2bb9364 --- /dev/null +++ b/src/nghttp2/include/nghttp2/NgHttp2Utils.hpp @@ -0,0 +1,468 @@ +#ifndef OPENCMW_NGHTTP2HELPERS_H +#define OPENCMW_NGHTTP2HELPERS_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include + +#include + +#ifdef OPENCMW_DEBUG_HTTP +#include +#define HTTP_DBG(...) fmt::println(std::cerr, __VA_ARGS__); +#else +#define HTTP_DBG(...) +#endif + +namespace opencmw::nghttp2 { +using SSL_CTX_Ptr = std::unique_ptr; +using SSL_Ptr = std::unique_ptr; +using X509_STORE_Ptr = std::unique_ptr; +using X509_Ptr = std::unique_ptr; +using EVP_PKEY_Ptr = std::unique_ptr; + +inline int readCertificateBundleFromBuffer(X509_STORE &cert_store, const std::string_view &X509_ca_bundle) { + BIO *cbio = BIO_new_mem_buf(X509_ca_bundle.data(), static_cast(X509_ca_bundle.size())); + if (!cbio) { + return -1; + } + STACK_OF(X509_INFO) *inf = PEM_X509_INFO_read_bio(cbio, nullptr, nullptr, nullptr); + + if (!inf) { + BIO_free(cbio); // cleanup + return -1; + } + // iterate over all entries from the pem file, add them to the x509_store one by one + int count = 0; + for (int i = 0; i < sk_X509_INFO_num(inf); i++) { + X509_INFO *itmp = sk_X509_INFO_value(inf, i); + if (itmp->x509) { + X509_STORE_add_cert(&cert_store, itmp->x509); + count++; + } + if (itmp->crl) { + X509_STORE_add_crl(&cert_store, itmp->crl); + count++; + } + } + + sk_X509_INFO_pop_free(inf, X509_INFO_free); + BIO_free(cbio); + return count; +} + +inline std::expected createCertificateStore(std::string_view x509_ca_bundle) { + X509_STORE_Ptr cert_store = X509_STORE_Ptr(X509_STORE_new(), X509_STORE_free); + if (readCertificateBundleFromBuffer(*cert_store.get(), x509_ca_bundle) <= 0) { + return std::unexpected(fmt::format("failed to read certificate bundle from buffer:\n#---start---\n{}\n#---end---\n", x509_ca_bundle)); + } + return cert_store; +} + +inline std::expected readServerCertificateFromBuffer(std::string_view X509_ca_bundle) { + BIO *certBio = BIO_new(BIO_s_mem()); + BIO_write(certBio, X509_ca_bundle.data(), static_cast(X509_ca_bundle.size())); + auto certX509 = X509_Ptr(PEM_read_bio_X509(certBio, nullptr, nullptr, nullptr), X509_free); + BIO_free(certBio); + if (!certX509) { + return std::unexpected(fmt::format("failed to read certificate from buffer:\n#---start---\n{}\n#---end---\n", X509_ca_bundle)); + } + return certX509; +} + +inline std::expected readServerCertificateFromFile(std::filesystem::path fpath) { + auto path = fpath.string(); + BIO *certBio + = BIO_new_file(path.data(), "r"); + if (!certBio) { + return std::unexpected(fmt::format("failed to read certificate from file {}: {}", path, ERR_error_string(ERR_get_error(), nullptr))); + } + auto certX509 = X509_Ptr(PEM_read_bio_X509(certBio, nullptr, nullptr, nullptr), X509_free); + BIO_free(certBio); + if (!certX509) { + return std::unexpected(fmt::format("failed to read certificate key from file: {}", path)); + } + return certX509; +} + +inline std::expected readServerPrivateKeyFromBuffer(std::string_view x509_private_key) { + BIO *certBio = BIO_new(BIO_s_mem()); + BIO_write(certBio, x509_private_key.data(), static_cast(x509_private_key.size())); + EVP_PKEY_Ptr privateKeyX509 = EVP_PKEY_Ptr(PEM_read_bio_PrivateKey(certBio, nullptr, nullptr, nullptr), EVP_PKEY_free); + BIO_free(certBio); + if (!privateKeyX509) { + return std::unexpected(fmt::format("failed to read private key from buffer")); + } + return privateKeyX509; +} + +inline std::expected readServerPrivateKeyFromFile(std::filesystem::path fpath) { + auto path = fpath.string(); + BIO *certBio = BIO_new_file(path.data(), "r"); + if (!certBio) { + return std::unexpected(fmt::format("failed to read private key from file {}: {}", path, ERR_error_string(ERR_get_error(), nullptr))); + } + EVP_PKEY_Ptr privateKeyX509 = EVP_PKEY_Ptr(PEM_read_bio_PrivateKey(certBio, nullptr, nullptr, nullptr), EVP_PKEY_free); + BIO_free(certBio); + if (!privateKeyX509) { + return std::unexpected(fmt::format("failed to read private key from file: {}", path)); + } + return privateKeyX509; +} +namespace detail { + +inline std::expected create_ssl(SSL_CTX *ssl_ctx) { + auto ssl = SSL_Ptr(SSL_new(ssl_ctx), SSL_free); + if (!ssl) { + return std::unexpected(fmt::format("Could not create SSL/TLS session object: {}", ERR_error_string(ERR_get_error(), nullptr))); + } + return ssl; +} + +inline std::span u8span(std::string_view view) { + return { const_cast(reinterpret_cast(view.data())), view.size() }; +} + +inline nghttp2_nv nv(const std::span &name, const std::span &value, uint8_t flags = NGHTTP2_NV_FLAG_NO_COPY_NAME) { + return { name.data(), value.data(), name.size(), value.size(), flags }; +} + +#ifdef OPENCMW_PROFILE_HTTP +inline std::chrono::nanoseconds latency(const std::string_view &value) { + const auto now = std::chrono::high_resolution_clock::now(); + const uint64_t ts = std::stoull(std::string(value)); + const auto parsed = std::chrono::time_point(std::chrono::nanoseconds(ts)); + return std::chrono::duration_cast(now - parsed); +} +#endif + +inline std::string_view as_view(nghttp2_rcbuf *rcbuf) { + auto vec = nghttp2_rcbuf_get_buf(rcbuf); + return { reinterpret_cast(vec.base), vec.len }; +} + +// Convenience/RAII wrapper for a non-blocking/no-delay TCP socket that can be used with SSL or without. +struct TcpSocket { + using AddrinfoPtr = std::unique_ptr; + + enum Flags { + None = 0x0, + VerifyPeer = 0x1, + }; + + enum State { + Uninitialized, + Connecting, + SSLConnectWantsRead, + SSLConnectWantsWrite, + SSLAcceptWantsRead, + SSLAcceptWantsWrite, + Connected + }; + + int fd = -1; + int flags = None; + SSL_Ptr _ssl = SSL_Ptr(nullptr, SSL_free); + State _state = Uninitialized; + AddrinfoPtr address = AddrinfoPtr(nullptr, freeaddrinfo); + + TcpSocket() = default; + + TcpSocket(const TcpSocket &) = delete; + TcpSocket &operator=(const TcpSocket &) = delete; + TcpSocket(TcpSocket &&other) noexcept { + close(); + std::swap(fd, other.fd); + std::swap(flags, other.flags); + std::swap(_state, other._state); + _ssl = std::move(other._ssl); + } + TcpSocket &operator=(TcpSocket &&other) noexcept { + if (this != &other) { + close(); + std::swap(fd, other.fd); + std::swap(flags, other.flags); + std::swap(_state, other._state); + _ssl = std::move(other._ssl); + } + return *this; + } + + ~TcpSocket() { + close(); + } + + static std::expected create(SSL_Ptr ssl_, int fd_, int flags_ = VerifyPeer) { + if (fd_ == -1) { + return std::unexpected(fmt::format("Invalid socket file descriptor: {}", fd_)); + } + TcpSocket socket; + socket.fd = fd_; + socket.flags = flags_; + socket._ssl = std::move(ssl_); + int f = fcntl(socket.fd, F_GETFL, 0); + assert(f != -1); + fcntl(socket.fd, F_SETFL, f | O_NONBLOCK); + + int flag = 1; + setsockopt(socket.fd, IPPROTO_TCP, TCP_NODELAY, &flag, sizeof(flag)); + + if (socket._ssl) { + auto bio = BIO_new_socket(socket.fd, BIO_NOCLOSE); + if (!bio) { + return std::unexpected(fmt::format("Failed to create BIO object: {}", ERR_error_string(ERR_get_error(), nullptr))); + } + // SSL will take ownership of the BIO object + SSL_set_bio(socket._ssl.get(), bio, bio); + } + return socket; + } + + std::string lastError() { + if (_ssl) { + return ERR_error_string(ERR_get_error(), nullptr); + } else { + return strerror(errno); + } + } + + std::expected continueHandshake() { + switch (_state) { + case Uninitialized: + return std::unexpected("Socket not initialized"); + case SSLConnectWantsRead: + case SSLConnectWantsWrite: { + int ret = SSL_connect(_ssl.get()); + if (ret == 1) { + _state = Connected; + return _state; + } + if (ret == 0) { + return std::unexpected(fmt::format("SSL handshake failed: {}", ERR_error_string(ERR_get_error(), nullptr))); + } + int err = SSL_get_error(_ssl.get(), ret); + if (err == SSL_ERROR_WANT_READ) { + _state = SSLConnectWantsRead; + return _state; + } + if (err == SSL_ERROR_WANT_WRITE) { + _state = SSLConnectWantsWrite; + return _state; + } + return std::unexpected(fmt::format("SSL handshake failed: {}", ERR_error_string(ERR_get_error(), nullptr))); + } + case SSLAcceptWantsRead: + case SSLAcceptWantsWrite: { + int ret = SSL_accept(_ssl.get()); + if (ret == 1) { + _state = Connected; + return _state; + } + if (ret == 0) { + return std::unexpected(fmt::format("SSL handshake failed: {}", ERR_error_string(ERR_get_error(), nullptr))); + } + int err = SSL_get_error(_ssl.get(), ret); + if (err == SSL_ERROR_WANT_READ) { + _state = SSLAcceptWantsRead; + return _state; + } + if (err == SSL_ERROR_WANT_WRITE) { + _state = SSLAcceptWantsWrite; + return _state; + } + return std::unexpected(fmt::format("SSL handshake failed: {}", ERR_error_string(ERR_get_error(), nullptr))); + } + case Connecting: + case Connected: + return _state; + } + + return _state; + } + + std::expected prepareConnect(std::string_view host, uint16_t port) { + struct addrinfo hints; + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_INET; // IPv4 + hints.ai_socktype = SOCK_STREAM; // TCP + + struct addrinfo *res; + int status = getaddrinfo(host.data(), nullptr, &hints, &res); + if (status != 0) { + return std::unexpected(fmt::format("Could not resolve address: {}", strerror(status))); + } + address = AddrinfoPtr(res, freeaddrinfo); + reinterpret_cast(address->ai_addr)->sin_port = htons(port); + + if (_ssl && (flags & VerifyPeer) != 0) { + // Use instead of SSL_set_tlsext_host_name() to avoid warning about (void*) cast + if (auto r = SSL_ctrl(_ssl.get(), SSL_CTRL_SET_TLSEXT_HOSTNAME, TLSEXT_NAMETYPE_host_name, const_cast(host.data())); r != SSL_TLSEXT_ERR_OK && r != SSL_TLSEXT_ERR_ALERT_WARNING) { + return std::unexpected(fmt::format("Failed to set the TLS SNI hostname: {}", ERR_error_string(ERR_get_error(), nullptr))); + } + SSL_set_verify(_ssl.get(), SSL_VERIFY_PEER, nullptr); + } + + _state = Connecting; + return {}; + } + + std::expected connect() { + assert(address); + if (::connect(fd, address.get()->ai_addr, sizeof(sockaddr)) < 0) { + if (errno == EINPROGRESS) { + return {}; + } + + return std::unexpected(fmt::format("Connect failed: {}", strerror(errno))); + } + + if (!_ssl) { + _state = Connected; + return {}; + } + + _state = SSLConnectWantsWrite; + if (auto r = continueHandshake(); !r) { + return std::unexpected(r.error()); + } + return {}; + } + + std::expected, std::string> accept(SSL_CTX *ssl_ctx, int flags_) { + struct sockaddr_in client_addr; + socklen_t client_len = sizeof(client_addr); + int client_fd = ::accept(fd, static_cast(static_cast(&client_addr)), &client_len); + if (client_fd < 0) { + if (errno == EAGAIN) { + return {}; + } + return std::unexpected(fmt::format("Accept failed: {}", strerror(errno))); + } + + if (!ssl_ctx) { + auto maybeSocket = TcpSocket::create({ nullptr, SSL_free }, client_fd); + if (maybeSocket) { + maybeSocket->_state = Connected; + } + return maybeSocket; + } + + auto ssl = create_ssl(ssl_ctx); + if (!ssl) { + ::close(client_fd); + return std::unexpected(fmt::format("Failed to create SSL object: {}", ssl.error())); + } + + auto clientSocket = TcpSocket::create(std::move(ssl.value()), client_fd, flags_); + if (clientSocket) { + clientSocket->_state = SSLAcceptWantsRead; + } + return clientSocket; + } + + ssize_t read(uint8_t *data, std::size_t length) { + if (_ssl) { + return SSL_read(_ssl.get(), data, static_cast(length)); + } else { + return ::recv(fd, data, length, 0); + } + } + + ssize_t write(const uint8_t *data, std::size_t length, int wflags = 0) { + if (_ssl) { + return SSL_write(_ssl.get(), data, static_cast(length)); + } else { + return ::send(fd, data, length, wflags); + } + } + + void close() { + if (_ssl) { + SSL_shutdown(_ssl.get()); + _ssl.reset(); + } + if (fd != -1) { + ::close(fd); + fd = -1; + } + } +}; + +template +struct WriteBuffer { + std::vector buffer = std::vector(InitialCapacity); + std::size_t size = 0; + + bool hasData() const { + return size > 0; + } + + bool wantsToWrite(nghttp2_session *session) const { + return size > 0 || nghttp2_session_want_write(session); + } + + bool write(nghttp2_session *session, TcpSocket &socket) { + const uint8_t *chunk; + + while (size < Limit && nghttp2_session_want_write(session)) { + nghttp2_ssize len = nghttp2_session_mem_send2(session, &chunk); + if (len < 0) { + break; // out of memory, try again later + } + if (len == 0) { + continue; + } + const std::size_t newSize = size + static_cast(len); + if (newSize > buffer.size()) { + HTTP_DBG("Resizing buffer from {} to {}", buffer.size(), newSize); + buffer.resize(newSize); + } + + std::copy(chunk, chunk + len, buffer.data() + size); + size += static_cast(len); + } + + std::size_t written = 0; + while (size - written > 0) { + const ssize_t n = socket.write(buffer.data() + written, size - written); + if (n == 0 && errno == EAGAIN) { + break; + } + if (n <= 0) { + return false; + } + written += static_cast(n); + } + HTTP_DBG("Write[{}]: Wrote {} bytes", socket.fd, written); + size -= written; + // TODO could be optimized by using a circular buffer + std::move(buffer.data() + written, buffer.data() + written + size, buffer.data()); + return true; + } +}; + +} // namespace detail + +} // namespace opencmw::nghttp2 +#endif diff --git a/src/serialiser/include/IoSerialiserYAML.hpp b/src/serialiser/include/IoSerialiserYAML.hpp index ec56f618..59eed880 100644 --- a/src/serialiser/include/IoSerialiserYAML.hpp +++ b/src/serialiser/include/IoSerialiserYAML.hpp @@ -153,7 +153,7 @@ inline std::string fieldFormatter(ArithmeticType auto const &value, const int nI template constexpr void fieldParser(const std::string_view &data, std::floating_point auto &value) { - if (const auto result = fast_float::from_chars(data.cbegin(), data.cend(), value); result.ec != std::errc()) { + if (const auto result = fast_float::from_chars(data.data(), data.data() + data.size(), value); result.ec != std::errc()) { throw ProtocolException("parsing ArithmeticType (float) from string '{}'", data); } } diff --git a/src/services/CMakeLists.txt b/src/services/CMakeLists.txt index c17721d0..61a4686c 100644 --- a/src/services/CMakeLists.txt +++ b/src/services/CMakeLists.txt @@ -13,7 +13,7 @@ target_include_directories(services INTERFACE $ #include - -#include - namespace opencmw::service::dns { using DnsWorkerType = majordomo::Worker<"/dns", Context, FlatEntryList, FlatEntryList, majordomo::description<"Register and Query Signals">>; @@ -36,57 +35,45 @@ class DnsHandler { } }; -Entry registerSignals(const std::vector &entries, std::string scheme_host_port = "http://localhost:8080") { - IoBuffer outBuffer; - FlatEntryList entrylist{ entries }; - opencmw::serialise(outBuffer, entrylist); - std::string contentType{ MIME::BINARY.typeName() }; - std::string body{ outBuffer.asString() }; - - // send request to register Signal - httplib::Client client{ - scheme_host_port - }; - - auto response = client.Post("dns", body, contentType); +inline Entry registerSignals(const std::vector &entries, std::string scheme_host_port = "http://localhost:8080") { + client::Command cmd; + cmd.command = mdp::Command::Set; + cmd.serviceName = "/dns"; + cmd.topic = URI<>::UriFactory{ URI<>(std::move(scheme_host_port)) }.path("/dns").build(); + opencmw::serialise(cmd.data, FlatEntryList{ entries }); - if (response.error() == httplib::Error::Read) throw std::runtime_error{ "Server did not send an answer" }; - if (response.error() != httplib::Error::Success || response->status == 500) throw std::runtime_error{ response->reason }; + client::RestClient client{ client::DefaultContentTypeHeader(MIME::BINARY) }; + auto reply = client.blockingRequest(std::move(cmd)); - // deserialise response - IoBuffer inBuffer; - inBuffer.put(response->body); + if (!reply.error.empty()) { + throw std::runtime_error{ reply.error }; + } FlatEntryList res; try { - opencmw::deserialise(inBuffer, res); + opencmw::deserialise(reply.data, res); } catch (const ProtocolException &exc) { throw std::runtime_error{ exc.what() }; // rethrowing, because ProtocolException behaves weird } - return res.toEntries().front(); } -std::vector querySignals(const Entry &a = {}, std::string scheme_host_port = "http://localhost:8080") { - // send request to register Service - httplib::Client client{ - scheme_host_port - }; +inline std::vector querySignals(const Entry &a = {}, std::string scheme_host_port = "http://localhost:8080") { + client::Command cmd; + cmd.command = mdp::Command::Get; + cmd.serviceName = "/dns"; + cmd.topic = URI<>::UriFactory{ URI<>(std::move(scheme_host_port)) }.path("/dns").addQueryParameter("service_name", a.service_name).addQueryParameter("signal_name", a.signal_name).addQueryParameter("signal_unit", a.signal_unit).addQueryParameter("signal_rate", std::to_string(a.signal_rate)).addQueryParameter("signal_type", a.signal_type).build(); - auto response = client.Get("dns", httplib::Params{ { "protocol", a.protocol }, { "hostname", a.hostname }, { "port", std::to_string(a.port) }, { "service_name", a.service_name }, { "service_type", a.service_type }, { "signal_name", a.signal_name }, { "signal_unit", a.signal_unit }, { "signal_rate", std::to_string(a.signal_rate) }, { "signal_type", a.signal_type } }, - httplib::Headers{ - { std::string{ "Content-Type" }, std::string{ MIME::BINARY.typeName() } } }); + client::RestClient client{ client::DefaultContentTypeHeader(MIME::BINARY) }; + auto reply = client.blockingRequest(std::move(cmd)); - if (response.error() == httplib::Error::Read) throw std::runtime_error{ "Server did not send an answer" }; - if (response.error() != httplib::Error::Success || response->status == 500) throw std::runtime_error{ response->reason }; - - // deserialise response - IoBuffer inBuffer; - inBuffer.put(response->body); + if (!reply.error.empty()) { + throw std::runtime_error{ reply.error }; + } FlatEntryList res; try { - opencmw::deserialise(inBuffer, res); + opencmw::deserialise(reply.data, res); } catch (const ProtocolException &exc) { throw std::runtime_error{ exc.what() }; // rethrowing, because ProtocolException behaves weird } diff --git a/src/services/include/services/dns_client.hpp b/src/services/include/services/dns_client.hpp index 196e6a38..a9508d2f 100644 --- a/src/services/include/services/dns_client.hpp +++ b/src/services/include/services/dns_client.hpp @@ -1,14 +1,13 @@ #ifndef DNS_CLIENT_HPP #define DNS_CLIENT_HPP -#include "Debug.hpp" +#include "ClientContext.hpp" #include "dns_types.hpp" #include "MdpMessage.hpp" -#include "RestClient.hpp" #include #include #include -#include +#include #include #include #include diff --git a/src/services/test/dns_tests.cpp b/src/services/test/dns_tests.cpp index 60f134c8..7f541b9f 100644 --- a/src/services/test/dns_tests.cpp +++ b/src/services/test/dns_tests.cpp @@ -3,15 +3,18 @@ #include #include -#include #include +#include + #ifndef __EMSCRIPTEN__ #include +#include +#include +#include // Concepts and tests use common types #include #include -#include #endif namespace fs = std::filesystem; @@ -211,12 +214,14 @@ TEST_CASE("data storage - Deleting Entries") { TEST_CASE("run services", "[DNS]") { FileDeleter fd; majordomo::Broker<> broker{ "/Broker", {} }; - std::string rootPath{ "./" }; - auto fs = cmrc::assets::get_filesystem(); - majordomo::RestBackend rest_backend{ broker, fs }; + majordomo::rest::Settings rest; + const auto bound = broker.bindRest(rest); + if (!bound) { + FAIL(fmt::format("Failed to bind REST: {}", bound.error())); + } + REQUIRE(bound); DnsWorkerType dnsWorker{ broker, {} }; - RunInThread restThread(rest_backend); RunInThread brokerThread(broker); RunInThread dnsThread(dnsWorker); } @@ -224,13 +229,12 @@ TEST_CASE("run services", "[DNS]") { TEST_CASE("client", "[DNS]") { FileDeleter fd; majordomo::Broker<> broker{ "/Broker", {} }; - std::string rootPath{ "./" }; - auto fs = cmrc::assets::get_filesystem(); - majordomo::RestBackend rest_backend{ broker, fs }; + majordomo::rest::Settings rest; + const auto bound = broker.bindRest(rest); + REQUIRE(bound); DnsWorkerType dnsWorker{ broker, DnsHandler{} }; broker.bind(URI<>{ "inproc://dns_server" }, majordomo::BindOption::Router); - RunInThread restThread(rest_backend); RunInThread dnsThread(dnsWorker); RunInThread brokerThread(broker); @@ -268,11 +272,11 @@ TEST_CASE("client", "[DNS]") { TEST_CASE("query", "[DNS]") { FileDeleter fd; majordomo::Broker<> broker{ "/Broker", {} }; - auto fs = cmrc::assets::get_filesystem(); - majordomo::RestBackend rest_backend{ broker, fs }; - DnsWorkerType dnsWorker{ broker, DnsHandler{} }; + majordomo::rest::Settings rest; + const auto bound = broker.bindRest(rest); + REQUIRE(bound); - RunInThread restThread(rest_backend); + DnsWorkerType dnsWorker{ broker, DnsHandler{} }; RunInThread brokerThread(broker); RunInThread dnsThread(dnsWorker); @@ -308,11 +312,10 @@ TEST_CASE("query", "[DNS]") { TEST_CASE("client unregister entries", "[DNS]") { FileDeleter fd; majordomo::Broker<> broker{ "/Broker", {} }; - auto fs = cmrc::assets::get_filesystem(); - majordomo::RestBackend rest_backend{ broker, fs }; + majordomo::rest::Settings rest; + const auto bound = broker.bindRest(rest); + REQUIRE(bound); DnsWorkerType dnsWorker{ broker, DnsHandler{} }; - - RunInThread restThread(rest_backend); RunInThread brokerThread(broker); RunInThread dnsThread(dnsWorker); @@ -320,7 +323,7 @@ TEST_CASE("client unregister entries", "[DNS]") { std::vector> clients; clients.emplace_back(std::make_unique(broker.context, 20ms, "dnsTestClient")); - clients.emplace_back(std::make_unique(opencmw::client::DefaultContentTypeHeader(MIME::BINARY))); + clients.emplace_back(std::make_unique(client::DefaultContentTypeHeader(MIME::BINARY))); client::ClientContext clientContext{ std::move(clients) }; DnsClient restClient{ clientContext, URI<>{ "http://localhost:8080/dns" } }; restClient.registerSignals({ entry_a, entry_b, entry_c });