diff --git a/CMakeLists.txt b/CMakeLists.txt index 38e585a7..df912b05 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -108,6 +108,8 @@ else() option(OPENCMW_ENABLE_COVERAGE "Enable Coverage" OFF) endif() option(OPENCMW_ENABLE_CONCEPTS "Enable Concepts Builds" ${opencmw_MASTER_PROJECT}) +option(OPENCMW_DEBUG_HTTP "Enable verbose HTTP output for debugging" OFF) +option(OPENCMW_PROFILE_HTTP "Enable verbose HTTP output for profiling" OFF) # Very basic PCH example option(ENABLE_PCH "Enable Precompiled Headers" OFF) @@ -124,6 +126,14 @@ if(ENABLE_PCH) ) endif() +if(OPENCMW_DEBUG_HTTP) + target_compile_definitions(opencmw_project_options INTERFACE -DOPENCMW_DEBUG_HTTP) +endif() + +if(OPENCMW_PROFILE_HTTP) + target_compile_definitions(opencmw_project_options INTERFACE -DOPENCMW_PROFILE_HTTP) +endif() + if(OPENCMW_ENABLE_TESTING) enable_testing() message("Building Tests.") diff --git a/README.md b/README.md index b315f5f7..2569b6d5 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,7 @@ where the frame-work takes care of most of the communication, [data-serialisatio and buffering, settings management, Role-Based-Access-Control (RBAC), and other boring but necessary control system integrations while still being open to expert-level modifications, extensions or improvements. -### General Schematic +## General Schematic OpenCMW combines [ZeroMQ](https://zeromq.org/)'s [Majordomo](https://rfc.zeromq.org/spec/7/) with LMAX's [disruptor](https://lmax-exchange.github.io/disruptor/) ([C++ port](https://github.com/Abc-Arbitrage/Disruptor-cpp)) design pattern that both provide a very efficient lock-free mechanisms @@ -38,7 +38,7 @@ for distributing, streaming and processing of data objects. A schematic outline ![OpenCMW architectural schematic](./assets/FAIR_microservice_schematic.svg) -### Glossary +## Glossary _Majordomo Broker_ or _'Broker':_ is the central authority where multiple workers can register their services, allowing clients to perform get, set or subscriptions requests. There can be multiple brokers for subset of services. @@ -67,11 +67,11 @@ _Publisher:_ the [DataSourcePublisher](DataSourceExample.cpp) provides an interf ring-buffer with events from OpenCMW, REST services or other sources. While using disruptor ring-buffers is the preferred and most performing options, the client also supports classic patterns of registering call-back functions or returning `Future` objects. -### OpenCMW Majordomo Protocol +## OpenCMW Majordomo Protocol The OpenCMW Majordomo [protocol](docs/MajordomoProtocol.md) is based on the [ZeroMQ Majordomo protocol](https://rfc.zeromq.org/spec/7/), both extending and slightly modifying it (see [the comparison](docs/Majordomo_protocol_comparison.pdf)). -#### Service Names +### Service Names Service names must always start with `/`. For consistency, this also applies to the built-in MDP broker services like `/mmi.service` (instead of `mmi.service` without leading slash as in ZeroMQ Majordomo). A service name is a non-empty alphanumerical string (also allowing `.`, `_`), that must start with `/` but not end with `/`. It contain additional `/` to denote a hierarchical structure. @@ -84,7 +84,7 @@ Examples: - `/DeviceName/Acquisition/` - invalid (trailing slash) - `/a-service/` - invalid (`-` not allowed) -#### Topics +### Topics The "topic" field (frame 5 in the [OpenCMW MDP protocol](docs/Majordomo_protocol_comparison.pdf)) specifies the topic for subscriptions and GET/SET requests. It contains a URI with the service name as path and optional query parameters to specify further requests parameters and filter criteria. @@ -99,7 +99,7 @@ Note that the whole path is considered the service name, and that there's no add See also the documentation for [mdp::Topic](src/core/include/Topic.hpp). -#### URL to Service/Topic Mapping (mds/mdp and REST) +### URL to Service/Topic Mapping (mds/mdp and REST) With both the MDS/MDP-based ZeroMQ clients as well as the REST interface, a common scheme is used to map from mdp/hds/http(s) URLs used for subscriptions and requests to the OpenCMW service name and topic fields. @@ -120,7 +120,7 @@ Other examples are: - `mds://example.com:8080/DeviceName/Acquisition?signal=test` => service name `/DeviceName/Acquisition`, topic `/DeviceName/Acquisition?signal=test` (subscription via mds). - `mdp://example.com:8080/dashboards/dashboard1?what=header` => service name `/dashboards/dashboard1`, topic `/dashboards/dashboard1?what=header` (Request via mdp). -### Compile-Time-Reflection +## Compile-Time-Reflection The serialisers are based on a [compile-time-reflection](docs/CompileTimeSerialiser.md) that efficiently transform domain-objects to and from the given wire-format (binary, JSON, ...). Compile-time reflection will become part of [C++23](http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2019/p0592r4.html) as described by [David Sankel et al. , “C++ Extensions for Reflection”, ISO/IEC CD TS 23619, N4856](http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/n4856.pdf). @@ -142,7 +142,7 @@ provides also an optional light-weight `constexpr` annotation template wrapper ` that in turn can be used to (re-)generate and document the class definition (e.g. for other programming languages or projects that do not have the primary domain-object definition at hand) or to generate a generic [OpenAPI](https://swagger.io/specification/) definition. More details can be found [here](docs/CompileTimeSerialiser.md). -### Building from source +## Building from source Note that building from source is only required if you want to modify opencmw-cpp itself. @@ -151,7 +151,7 @@ In that case, rather take a look at the project [opencmw-cpp-example](https://gi For concrete build instructions, please check the [build instructions page](docs/BuildInstructions.md). -### Example +## Example For an example on how to implement a simple, first service using opencmw-cpp, please take a look at the project [opencmw-cpp-example](https://github.com/alexxcons/opencmw-cpp-example). @@ -217,7 +217,7 @@ or RESTful (HTTP)-based high-level protocols, or through a simple RESTful web-in [comment]: <> (The basic HTML rendering is based on XXX template engine and can be customised. For more efficient, complex and cross-platform) [comment]: <> (UI designs it is planned to allow embedding of WebAssembly-based ([WASM](https://en.wikipedia.org/wiki/WebAssembly)) applications.) -### Performance +## Performance The end-to-end transmission achieving roughly 10k messages per second for synchronous communications and about 140k messages per second for asynchronous and or publish-subscribe style data acquisition (TCP link via locahost). @@ -241,17 +241,74 @@ Your mileage may vary depending on the specific domain-object, processing logic, but you can check and compare the results for your platform using the [RoundTripAndNotifyEvaluation](RoundTripAndNotifyEvaluation.cpp) and/or [MdpImplementationBenchmark](MdpImplementationBenchmark.cpp) benchmarks. -### Documentation +## Testing HTTP/3 + +The openCMW REST interface supports both HTTP/2 and HTTP/3. When connecting from a browser, the browser typically first connects +via HTTP/2. HTTP/2 responses contain a "alt-svc" header that informs the browser about the availability of HTTP/3. The browser +then should switch to HTTP/3. Note that if HTTP/3 fails for any reason (certificates, server error etc.), the browser might remember +that and not try HTTP/3 again. How to reset this depends on the browser. What I did in Google Chrome + + 1. open an private tab, open the developer console, go to "Network", enable the "Protocol" column + 2. connect to the service; verify in the developer console that HTTP/3 is used (Protocol should switch to`H3` after the first `H2` request) + 3. To reset browser status, close *all* private tabs and open a new one. + +### SSL Certificates + +QUIC/HTTP/3 requires the use of TLS, unencrypted servers are not possible. + +At least in Google Chrome, the TLS stack used for QUIC seems quite separate from the normal settings, and its much stricter than for HTTP/1/2. *If +anything goes wrong here, Google Chrome will silently stick with HTTP/2, or show you an error, if you're lucky*. + +Caveats: + + 1. The certificate must be trusted inside the Chrome Certificate store (e.g. on Mac, trusting it in Keychain might silence the warning for HTTP1/2, +but HTTP/3 will still fail. Add the server's public key under chrome://certificate-manager/. + 2. The hostname in the certificate must match what you connect via in the browser. "Works-everywhere self-signed certificates" I couldn't get to work. +What I did: Create a certificate for hostname `foobar`, edit `/etc/hosts` to resolve `foobar` to the test host's IP address. Enter e.g. `https://foobar:8080` +in the browser, instead of the IP address. + 3. For HTTP/2, you can ignore Chrome's warning with "Proceed anyway" or similar. This does not make the QUIC stack trust the certificate, HTTP/3 will +not be used. + 4. To get the certificate's fingerprint in base64: +``` +openssl x509 -in /path/to/demo_public.crt -noout -pubkey \ + | openssl pkey -pubin -outform DER \ + | openssl dgst -sha256 -binary \ + | openssl enc -base64 +``` + 3. Start Chrome with these parameters to make Chrome trust the certificate +``` +.../Google\ Chrome --enable-quic \ + --origin-to-force-quic-on=:8080 \ + --ignore-certificate-errors-spki-list= \ + --user-data-dir=/tmp/quic-test-profile \ + --no-sandbox +``` + + For verbose logging (netlog can be viewed in the [https://netlog-viewer.appspot.com/#import](netlog viewer)), add +``` + --enable-logging=stderr --v=3 \ + --log-net-log=netlog.json --quic-version=h3 +``` + +### curl + +curl needs to be built with QUIC/HTTP/3 enabled, which is not the case at least on Ubuntu 24.04. I used the docker image `badouralix/curl-http3`: + +``` +docker run --rm badouralix/curl-http3 curl -k -vvvv --http3 https://:8080/loadTest?topic=1&intervalMs=40&payloadSize=4096&nUpdates=100&LongPollingIdx=Next" +``` + +## Documentation .... more to follow. -### Don't like Cpp? +## Don't like Cpp? For prototyping applications or services that do not interact with hardware-based systems, a Java-based [OpenCMW](https://github.com/fair-acc/opencmw-java) twin-project is being developed which follows the same functional style but takes advantage of more concise implementation and C++-based type safety. -### Acknowledgements +## Acknowledgements The implementation heavily relies upon and re-uses time-tried and well-established concepts from [ZeroMQ](https://zeromq.org/) (notably the [Majordomo](https://rfc.zeromq.org/spec/7/) communication pattern, see [Z-Guide](https://zguide.zeromq.org/docs/chapter4/#Service-Oriented-Reliable-Queuing-Majordomo-Pattern) diff --git a/cmake/DependenciesNative.cmake b/cmake/DependenciesNative.cmake index b1696538..e01c9482 100644 --- a/cmake/DependenciesNative.cmake +++ b/cmake/DependenciesNative.cmake @@ -1,32 +1,118 @@ -# Build a static version of openssl to link into +include(ExternalProject) +include(GNUInstallDirs) + +set(OPENSSL_C_FLAGS "-O3 -march=x86-64-v3" CACHE STRING "OpenSSL custom CFLAGS" FORCE) +set(OPENSSL_CXX_FLAGS "-O3 -march=x86-64-v3" CACHE STRING "OpenSSL custom CXXFLAGS" FORCE) set(OPENSSL_INSTALL_DIR "${CMAKE_BINARY_DIR}/openssl-install") -add_library(OpenSSL::Crypto STATIC IMPORTED GLOBAL) -add_library(OpenSSL::SSL STATIC IMPORTED GLOBAL) -add_dependencies(OpenSSL::Crypto PUBLIC openssl-build) -add_dependencies(OpenSSL::SSL PUBLIC openssl-build) -set_target_properties(OpenSSL::Crypto PROPERTIES + +# Build custom OpenSSL with QUIC support +ExternalProject_Add(OpenSslProject + GIT_REPOSITORY https://github.com/openssl/openssl.git + GIT_TAG openssl-3.5.0 # 3.5.0 required for server-side QUIC support + GIT_SHALLOW ON + BUILD_BYPRODUCTS ${OPENSSL_INSTALL_DIR}/lib64/libcrypto.a ${OPENSSL_INSTALL_DIR}/lib64/libssl.a + CONFIGURE_COMMAND COMMAND ./Configure CFLAGS=${OPENSSL_C_FLAGS} CXXFLAGS=${OPENSSL_CXX_FLAGS} no-shared no-tests --prefix=${OPENSSL_INSTALL_DIR} --openssldir=${OPENSSL_INSTALL_DIR} linux-x86_64 + UPDATE_COMMAND "" + BUILD_COMMAND make -j + INSTALL_COMMAND make install_sw # only installs software components (no docs, etc) + BUILD_IN_SOURCE ON +) + +add_library(openssl-crypto-static STATIC IMPORTED GLOBAL) +add_dependencies(openssl-crypto-static OpenSslProject) +set_target_properties(openssl-crypto-static PROPERTIES IMPORTED_LOCATION "${OPENSSL_INSTALL_DIR}/lib64/libcrypto.a" INTERFACE_INCLUDE_DIRECTORIES "${OPENSSL_INSTALL_DIR}/include" ) -set_target_properties(OpenSSL::SSL PROPERTIES + +add_library(openssl-ssl-static STATIC IMPORTED GLOBAL) +add_dependencies(openssl-ssl-static OpenSslProject) +set_target_properties(openssl-ssl-static PROPERTIES IMPORTED_LOCATION "${OPENSSL_INSTALL_DIR}/lib64/libssl.a" INTERFACE_INCLUDE_DIRECTORIES "${OPENSSL_INSTALL_DIR}/include" ) -get_target_property(libcryptoa OpenSSL::Crypto IMPORTED_LOCATION) -get_target_property(libcryptoaloc OpenSSL::Crypto LOCATION) -set(OPENSSL_C_FLAGS "-O3 -march=x86-64-v3" CACHE STRING "OpenSSL custom CFLAGS" FORCE) -set(OPENSSL_CXX_FLAGS "-O3 -march=x86-64-v3" CACHE STRING "OpenSSL custom CXXFLAGS" FORCE) -add_custom_command( - OUTPUT ${OPENSSL_INSTALL_DIR}/lib64/libcrypto.a ${OPENSSL_INSTALL_DIR}/lib64/libssl.a - COMMAND ${FETCHCONTENT_BASE_DIR}/openssl-source-src/Configure CFLAGS=${OPENSSL_C_FLAGS} CXXFLAGS=${OPENSSL_CXX_FLAGS} no-shared no-tests --prefix=${OPENSSL_INSTALL_DIR} --openssldir=${OPENSSL_INSTALL_DIR} linux-x86_64 - COMMAND make -j - COMMAND make install_sw # only installs software components (no docs, etc) - COMMENT "Build openssl as a static library" - WORKING_DIRECTORY ${FETCHCONTENT_BASE_DIR}/openssl-source-build + +option(ENABLE_NGHTTP_DEBUG "Enable verbose nghttp2 debug output" OFF) + +ExternalProject_Add(Nghttp2Project + GIT_REPOSITORY https://github.com/nghttp2/nghttp2 + GIT_TAG v1.65.0 + GIT_SHALLOW ON + BUILD_BYPRODUCTS ${CMAKE_BINARY_DIR}/nghttp2-install/${CMAKE_INSTALL_LIBDIR}/libnghttp2.a + UPDATE_COMMAND "" + CMAKE_ARGS + -DCMAKE_INSTALL_PREFIX:PATH=${CMAKE_BINARY_DIR}/nghttp2-install + -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE} + -DENABLE_LIB_ONLY:BOOL=ON + -DENABLE_HTTP3:BOOL=OFF + -DENABLE_DEBUG:BOOL=${ENABLE_NGHTTP_DEBUG} + -DBUILD_STATIC_LIBS:BOOL=ON + -BUILD_SHARED_LIBS:BOOL=OFF + -DENABLE_DOC:BOOL=OFF +) + +add_library(nghttp2-static STATIC IMPORTED GLOBAL) +set_target_properties(nghttp2-static PROPERTIES + IMPORTED_LOCATION "${CMAKE_BINARY_DIR}/nghttp2-install/${CMAKE_INSTALL_LIBDIR}/libnghttp2.a" + INTERFACE_INCLUDE_DIRECTORIES "${CMAKE_BINARY_DIR}/nghttp2-install/include" ) -add_custom_target(openssl-build ALL - DEPENDS ${OPENSSL_INSTALL_DIR}/lib64/libcrypto.a ${OPENSSL_INSTALL_DIR}/lib64/libssl.a +add_dependencies(nghttp2-static Nghttp2Project) + +ExternalProject_Add(Nghttp3Project + GIT_REPOSITORY https://github.com/ngtcp2/nghttp3.git + GIT_TAG v1.10.1 + GIT_SHALLOW ON + BUILD_BYPRODUCTS ${CMAKE_BINARY_DIR}/nghttp3-install/${CMAKE_INSTALL_LIBDIR}/libnghttp3.a + UPDATE_COMMAND "" + CMAKE_ARGS + -DCMAKE_INSTALL_PREFIX:PATH=${CMAKE_BINARY_DIR}/nghttp3-install + -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE} + -DENABLE_LIB_ONLY:BOOL=ON + -DENABLE_DEBUG:BOOL=${ENABLE_NGHTTP_DEBUG} + -DBUILD_STATIC_LIBS:BOOL=ON + -DBUILD_SHARED_LIBS:BOOL=OFF + -DENABLE_DOC:BOOL=OFF +) + +ExternalProject_Add(NgTcp2Project + GIT_REPOSITORY https://github.com/ngtcp2/ngtcp2.git + GIT_TAG v1.13.0 + GIT_SHALLOW ON + PREFIX ${CMAKE_BINARY_DIR}/ngtcp2-install + BUILD_BYPRODUCTS ${CMAKE_BINARY_DIR}/ngtcp2-install/${CMAKE_INSTALL_LIBDIR}/libngtcp2.a ${CMAKE_BINARY_DIR}/ngtcp2-install/${CMAKE_INSTALL_LIBDIR}/libngtcp2_crypto_ossl.a + UPDATE_COMMAND "" + CMAKE_ARGS + -DOPENSSL_ROOT_DIR:PATH=${OPENSSL_INSTALL_DIR} + -DENABLE_OPENSSL:BOOL=ON + -DCMAKE_INSTALL_PREFIX:PATH=${CMAKE_BINARY_DIR}/ngtcp2-install + -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE} + -DENABLE_LIB_ONLY:BOOL=ON + -DENABLE_DEBUG:BOOL=${ENABLE_NGHTTP_DEBUG} + -DBUILD_STATIC_LIBS:BOOL=ON + -DBUILD_SHARED_LIBS:BOOL=OFF + DEPENDS openssl-crypto-static openssl-ssl-static +) + +add_library(ngtcp2-static STATIC IMPORTED GLOBAL) +set_target_properties(ngtcp2-static PROPERTIES + IMPORTED_LOCATION "${CMAKE_BINARY_DIR}/ngtcp2-install/${CMAKE_INSTALL_LIBDIR}/libngtcp2.a" + INTERFACE_INCLUDE_DIRECTORIES "${CMAKE_BINARY_DIR}/ngtcp2-install/include" +) +add_dependencies(ngtcp2-static NgTcp2Project) + +add_library(ngtcp2-crypto-ossl-static STATIC IMPORTED GLOBAL) +set_target_properties(ngtcp2-crypto-ossl-static PROPERTIES + IMPORTED_LOCATION "${CMAKE_BINARY_DIR}/ngtcp2-install/${CMAKE_INSTALL_LIBDIR}/libngtcp2_crypto_ossl.a" + INTERFACE_INCLUDE_DIRECTORIES "${CMAKE_BINARY_DIR}/ngtcp2-install/include" ) +add_dependencies(ngtcp2-crypto-ossl-static NgTcp2Project) + +add_library(nghttp3-static STATIC IMPORTED GLOBAL) +set_target_properties(nghttp3-static PROPERTIES + IMPORTED_LOCATION "${CMAKE_BINARY_DIR}/nghttp3-install/${CMAKE_INSTALL_LIBDIR}/libnghttp3.a" + INTERFACE_INCLUDE_DIRECTORIES "${CMAKE_BINARY_DIR}/nghttp3-install/include" +) +add_dependencies(nghttp3-static Nghttp3Project) add_library(mustache INTERFACE) target_include_directories(mustache INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}/3rd_party/kainjow) @@ -38,6 +124,7 @@ FetchContent_Declare( GIT_TAG v4.3.5 # latest as of 2025-03-27 ) set(ZMQ_BUILD_TESTS OFF CACHE BOOL "Build the tests for ZeroMQ") + # suppress warnings for missing zeromq dependencies by disabling some features set(WITH_TLS OFF CACHE BOOL "TLS support for ZeroMQ WebSockets") set(BUILD_SHARED OFF CACHE BOOL "Build cmake shared library") @@ -56,12 +143,6 @@ FetchContent_Declare( GIT_TAG v1.2.12 # latest v1.2.12 ) -FetchContent_Declare( - openssl-source - GIT_REPOSITORY https://github.com/openssl/openssl.git - GIT_TAG openssl-3.4.1 -) - -FetchContent_MakeAvailable(cpp-httplib zeromq openssl-source) +FetchContent_MakeAvailable(cpp-httplib zeromq) list(APPEND CMAKE_MODULE_PATH ${catch2_SOURCE_DIR}/contrib) # replace contrib by extras for catch2 v3.x.x diff --git a/concepts/client/CMakeLists.txt b/concepts/client/CMakeLists.txt index a361193d..33f28624 100644 --- a/concepts/client/CMakeLists.txt +++ b/concepts/client/CMakeLists.txt @@ -1,13 +1,10 @@ -if(NOT EMSCRIPTEN) - add_executable(RestSubscription_example RestSubscription_example.cpp) - target_link_libraries( - RestSubscription_example - PRIVATE core - client - opencmw_project_options - opencmw_project_warnings - assets::rest) -endif() +add_executable(LoadTest_client LoadTest_client.cpp) +target_link_libraries( + LoadTest_client + PRIVATE core + client + opencmw_project_options + opencmw_project_warnings) add_executable(RestSubscription_client RestSubscription_client.cpp) target_link_libraries( diff --git a/concepts/client/LoadTest_client.cpp b/concepts/client/LoadTest_client.cpp new file mode 100644 index 00000000..b74082c4 --- /dev/null +++ b/concepts/client/LoadTest_client.cpp @@ -0,0 +1,68 @@ +#include +#include +#include +#include + +#include +#include +#include + +#include "ClientCommon.hpp" +#include "helpers.hpp" + +using namespace std::chrono_literals; + +std::string schema() { + if (auto env = ::getenv("DISABLE_REST_HTTPS"); env != nullptr && std::string_view(env) == "1") { + return "http"; + } else { + return "https"; + } +} + + +int main() { + constexpr auto kServerPort = 8080; + constexpr auto kNClients = 80UZ; + constexpr auto kNSubscriptions = 10UZ; + constexpr auto kNUpdates = 5000UZ; + constexpr auto kIntervalMs = 40UZ; + constexpr auto kPayloadSize = 4096UZ; + + std::array, kNClients> clients; + for (std::size_t i = 0; i < clients.size(); i++) { + clients[i] = std::make_unique(opencmw::client::DefaultContentTypeHeader(opencmw::MIME::BINARY), opencmw::client::VerifyServerCertificates(false)); + } + std::atomic responseCount = 0; + + const auto start = std::chrono::system_clock::now(); + + for (std::size_t i = 0; i < kNClients; i++) { + for (std::size_t j = 0; j < kNSubscriptions; j++) { + opencmw::client::Command cmd; + cmd.command = opencmw::mdp::Command::Subscribe; + cmd.serviceName = "/loadTest"; + cmd.topic = opencmw::URI<>(std::format("{}://localhost:{}/loadTest?initialDelayMs=1000&topic={}&intervalMs={}&payloadSize={}&nUpdates={}", schema(), kServerPort, /*i,*/ j, kIntervalMs, kPayloadSize, kNUpdates)); + cmd.callback = [&responseCount](const auto &msg) { + responseCount++; + }; + clients[i]->request(std::move(cmd)); + } + } + + constexpr auto expectedResponses = kNClients * kNSubscriptions * kNUpdates; + + std::uint64_t counter = 0; + while (responseCount < expectedResponses) { + counter += 50; + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + if (counter % 20 == 0) { + std::println("Received {} of {} responses", responseCount.load(), expectedResponses); + } + } + + const auto end = std::chrono::system_clock::now(); + const auto elapsed = std::chrono::duration_cast(end - start).count(); + std::println("Elapsed time: {} ms", elapsed); + return 0; +} diff --git a/concepts/client/RestSubscription_client.cpp b/concepts/client/RestSubscription_client.cpp index 4d8e733f..d6c10009 100644 --- a/concepts/client/RestSubscription_client.cpp +++ b/concepts/client/RestSubscription_client.cpp @@ -13,8 +13,11 @@ using namespace std::chrono_literals; // These are not main-local, as JS doesn't end when // C++ main ends namespace test_state { +#ifndef __EMSCRIPTEN__ +opencmw::client::RestClient client(opencmw::client::VerifyServerCertificates(false)); +#else opencmw::client::RestClient client; - +#endif std::string schema() { if (auto env = ::getenv("DISABLE_REST_HTTPS"); env != nullptr && std::string_view(env) == "1") { return "http"; @@ -39,10 +42,6 @@ auto run = rest_test_runner( } // namespace test_state int main() { -#ifndef __EMSCRIPTEN__ - opencmw::client::RestClient::CHECK_CERTIFICATES = false; -#endif - using namespace test_state; #ifndef __EMSCRIPTEN__ diff --git a/concepts/client/RestSubscription_example.cpp b/concepts/client/RestSubscription_example.cpp deleted file mode 100644 index 1e9d2bcf..00000000 --- a/concepts/client/RestSubscription_example.cpp +++ /dev/null @@ -1,127 +0,0 @@ -#include -#include -#include - -#include -#include -#include - -namespace detail { -class EventDispatcher { - std::mutex _mutex; - std::condition_variable _condition; - std::atomic _id{ 0 }; - std::atomic _cid{ -1 }; - std::string _message; - -public: - void wait_event(httplib::DataSink &sink) { - std::unique_lock lk(_mutex); - int id = std::atomic_load_explicit(&_id, std::memory_order_acquire); - _condition.wait(lk, [id, this] { return _cid == id; }); - if (sink.is_writable()) { - sink.write(_message.data(), _message.size()); - } - } - - void send_event(const std::string_view &message) { - std::scoped_lock lk(_mutex); - _cid = _id++; - _message = message; - _condition.notify_all(); - } -}; -} // namespace detail - -int main() { - using namespace std::chrono_literals; - opencmw::client::RestClient client; - - std::atomic updateCounter{ 0 }; - detail::EventDispatcher eventDispatcher; - httplib::Server server; - server.Get("/event", [&eventDispatcher, &updateCounter](const httplib::Request &req, httplib::Response &res) { - auto acceptType = req.headers.find("accept"); - if (acceptType == req.headers.end() || opencmw::MIME::EVENT_STREAM.typeName() != acceptType->second) { // non-SSE request -> return default response -#if not defined(__EMSCRIPTEN__) and (not defined(__clang__) or (__clang_major__ >= 16)) - res.set_content(std::format("update counter = {}", updateCounter.load()), opencmw::MIME::TEXT); -#else - res.set_content(std::format("update counter = {}", updateCounter.load()), std::string(opencmw::MIME::TEXT.typeName())); -#endif - return; - } else { - std::print("server received SSE request on path '{}' body = '{}'\n", req.path, req.body); -#if not defined(__EMSCRIPTEN__) and (not defined(__clang__) or (__clang_major__ >= 16)) - res.set_chunked_content_provider(opencmw::MIME::EVENT_STREAM, [&eventDispatcher](size_t /*offset*/, httplib::DataSink &sink) { -#else - res.set_chunked_content_provider(std::string(opencmw::MIME::EVENT_STREAM.typeName()), [&eventDispatcher](size_t /*offset*/, httplib::DataSink &sink) { -#endif - eventDispatcher.wait_event(sink); - return true; - }); - } - }); - server.Get("/endPoint", [](const httplib::Request &req, httplib::Response &res) { - std::print("server received request on path '{}' body = '{}'\n", req.path, req.body); - res.set_content("Hello World!", "text/plain"); - }); - client.threadPool()->execute<"RestServer">([&server] { server.listen("localhost", 8080); }); - int timeOut = 0; - while (!server.is_running() && timeOut < 10'000) { - std::this_thread::sleep_for(1ms); - timeOut += 1; - } - assert(server.is_running()); - if (!server.is_running()) { - std::print("couldn't start server\n"); - std::terminate(); - } - - std::atomic received(false); - opencmw::IoBuffer data; - data.put('A'); - data.put('B'); - data.put('C'); - data.put(0); - opencmw::client::Command command; - command.command = opencmw::mdp::Command::Subscribe; - command.topic = opencmw::URI("http://localhost:8080/event"); - command.data = std::move(data); - command.callback = [&received](const opencmw::mdp::Message &rep) { - std::print("SSE client received reply = '{}' - body size: '{}'\n", rep.data.asString(), rep.data.size()); - received.fetch_add(1, std::memory_order_relaxed); - received.notify_all(); - }; - - client.request(command); - - std::cout << "client request launched" << std::endl; - std::this_thread::sleep_for(100ms); - eventDispatcher.send_event("test-event meta data"); - std::jthread([&updateCounter, &eventDispatcher] { - while (updateCounter < 5) { - std::this_thread::sleep_for(500ms); - eventDispatcher.send_event(std::format("test-event {}", updateCounter++)); - } - }).join(); - - while (received.load(std::memory_order_relaxed) < 5) { - std::this_thread::sleep_for(100ms); - } - std::cout << "done waiting" << std::endl; - assert(received.load(std::memory_order_acquire) >= 5); - - command.command = opencmw::mdp::Command::Unsubscribe; - client.request(command); - std::this_thread::sleep_for(100ms); - std::cout << "done Unsubscribe" << std::endl; - client.stop(); - std::cout << "client stopped" << std::endl; - - server.stop(); - eventDispatcher.send_event(std::format("test-event {}", updateCounter++)); - std::cout << "server stopped" << std::endl; - std::this_thread::sleep_for(5s); - - return 0; -} diff --git a/concepts/client/dns_example.cpp b/concepts/client/dns_example.cpp index eb0b1b6c..00ddda64 100644 --- a/concepts/client/dns_example.cpp +++ b/concepts/client/dns_example.cpp @@ -12,8 +12,9 @@ #include #endif // EMSCRIPTEN +#include + #include -#include using namespace std::chrono_literals; using namespace opencmw; @@ -32,12 +33,17 @@ using namespace opencmw::service::dns; void run_dns_server(std::string_view httpAddress, std::string_view mdpAddress) { majordomo::Broker<> broker{ "Broker", {} }; std::string rootPath{ "./" }; - auto fs = cmrc::assets::get_filesystem(); - majordomo::RestBackend rest_backend{ broker, fs, URI<>{ std::string{ httpAddress } } }; + majordomo::rest::Settings rest; + rest.handlers = { majordomo::rest::cmrcHandler("/assets/*", "", std::make_shared(cmrc::assets::get_filesystem()), "") }; + + if (const auto bound = broker.bindRest(rest); !bound) { + std::println(std::cerr, "failed to bind REST: {}", bound.error()); + std::exit(1); + return; + } DnsWorkerType dnsWorker{ broker, DnsHandler{} }; broker.bind(URI<>{ std::string{ mdpAddress } }, majordomo::BindOption::Router); - RunInThread restThread(rest_backend); RunInThread dnsThread(dnsWorker); RunInThread brokerThread(broker); diff --git a/concepts/client/pre.js b/concepts/client/pre.js index f717f711..8e7f7538 100644 --- a/concepts/client/pre.js +++ b/concepts/client/pre.js @@ -47,8 +47,8 @@ if (typeof XMLHttpRequest === 'undefined') { // Set the additional headers res.setHeader('Keep-Alive', 'timeout=5, max=5'); - res.setHeader('X-OPENCMW-SERVICE-NAME', 'dns'); - res.setHeader('X-OPENCMW-TOPIC', '//s?signal_type=&signal_unit=&signal_name=&service_type=&service_name=&port=-1&hostname=&protocol=&signal_rate=nan&contextType=application%2Foctet-stream'); + res.setHeader('x-opencmw-service-name', 'dns'); + res.setHeader('x-opencmw-topic', '//s?signal_type=&signal_unit=&signal_name=&service_type=&service_name=&port=-1&hostname=&protocol=&signal_rate=nan&contextType=application%2Foctet-stream'); // Send the binary response const binaryData = Buffer.from('ffffffff040000005961530001000000602f0da036000000e4010000250000006f70656e636d773a3a736572766963653a3a646e733a3a466c6174456e7472794c69737400ff988e0ac51a000000150000000900000070726f746f636f6c00010000000100000001000000050000006874747000ff335c21ee1a0000002100000009000000686f73746e616d650001000000010000000100000011000000746573742e6578616d706c652e636f6d006881983400160000001000000005000000706f72740001000000010000000100000039050000ffd55573151e000000150000000d000000736572766963655f6e616d6500010000000100000001000000050000007465737400ff846a76151e000000110000000d000000736572766963655f74797065000100000001000000010000000100000000ffc2ad1b281d000000160000000c0000007369676e616c5f6e616d650001000000010000000100000005000000746573743100ffbb0c1f281d000000110000000c0000007369676e616c5f756e69740001000000010000000100000001000000006a17801d281d000000100000000c0000007369676e616c5f726174650001000000010000000100000000007a44ff71c21e281d000000110000000c0000007369676e616c5f74797065000100000001000000010000000100000000fe602f0da03600000000000000250000006f70656e636d773a3a736572766963653a3a646e733a3a466c6174456e7472794c69737400', 'hex'); diff --git a/concepts/majordomo/CMakeLists.txt b/concepts/majordomo/CMakeLists.txt index 0ad4a009..daec7361 100644 --- a/concepts/majordomo/CMakeLists.txt +++ b/concepts/majordomo/CMakeLists.txt @@ -38,6 +38,16 @@ target_link_libraries( assets::rest assets::testImages) +add_executable(MajordomoRest_LoadTestServer MajordomoRest_LoadTestServer.cpp) +target_include_directories(MajordomoRest_LoadTestServer PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) +target_link_libraries( + MajordomoRest_LoadTestServer + PRIVATE majordomo + opencmw_project_options + opencmw_project_warnings + assets::rest + assets::testImages) + if(NOT CMAKE_CXX_COMPILER_ID MATCHES diff --git a/concepts/majordomo/MajordomoRest_LoadTestServer.cpp b/concepts/majordomo/MajordomoRest_LoadTestServer.cpp new file mode 100644 index 00000000..2fd52afb --- /dev/null +++ b/concepts/majordomo/MajordomoRest_LoadTestServer.cpp @@ -0,0 +1,62 @@ +#include +#include +#include +#include + +#include "helpers.hpp" + +namespace majordomo = opencmw::majordomo; + +int main(int argc, char **argv) { + using opencmw::URI; + + std::string rootPath = "./"; + std::uint16_t port = 8080; + bool https = true; + + for (int i = 1; i < argc; i++) { + if (std::string_view(argv[i]) == "--port") { + if (i + 1 < argc) { + port = static_cast(std::stoi(argv[i + 1])); + ++i; + continue; + } + } else if (std::string_view(argv[i]) == "--http") { + https = false; + } else { + rootPath = argv[i]; + } + } + + std::println(std::cerr, "Starting load test server ({}) for {} on port {}", https ? "HTTPS" : "HTTP", rootPath, port); + + opencmw::majordomo::rest::Settings rest; + rest.port = port; + rest.handlers = { majordomo::rest::cmrcHandler("/assets/*", "", std::make_shared(cmrc::assets::get_filesystem()), "") }; + if (https) { + rest.certificateFilePath = "./demo_public.crt"; + rest.keyFilePath = "./demo_private.key"; + } else { + rest.protocols = majordomo::rest::Protocol::Http2; + std::println(std::cerr, "HTTP/3 disabled, requires TLS"); + } + + majordomo::Broker broker("/Broker", testSettings()); + opencmw::query::registerTypes(opencmw::load_test::Context(), broker); + + if (const auto bound = broker.bindRest(rest); !bound) { + std::println("Could not bind HTTP/2 REST bridge to port {}: {}", rest.port, bound.error()); + return 1; + } + + const auto brokerRouterAddress = broker.bind(URI<>("mds://127.0.0.1:12345")); + if (!brokerRouterAddress) { + std::cerr << "Could not bind to broker address" << std::endl; + return 1; + } + + majordomo::load_test::Worker<"/loadTest"> loadTestWorker(broker); + RunInThread runLoadTest(loadTestWorker); + + broker.run(); +} diff --git a/concepts/majordomo/MajordomoRest_example.cpp b/concepts/majordomo/MajordomoRest_example.cpp index 818ae601..a26a66c7 100644 --- a/concepts/majordomo/MajordomoRest_example.cpp +++ b/concepts/majordomo/MajordomoRest_example.cpp @@ -1,13 +1,9 @@ #include #include -#include +#include #include -#include -#include -#include #include -#include #include "helpers.hpp" @@ -18,49 +14,61 @@ int main(int argc, char **argv) { using opencmw::URI; std::string rootPath = "./"; - if (argc > 1) { - rootPath = argv[1]; + std::uint16_t port = 8080; + bool https = true; + + for (int i = 1; i < argc; i++) { + if (std::string_view(argv[i]) == "--port") { + if (i + 1 < argc) { + port = static_cast(std::stoi(argv[i + 1])); + ++i; + continue; + } + } else if (std::string_view(argv[i]) == "--http") { + https = false; + } else { + rootPath = argv[i]; + } } - std::cerr << "Starting server for " << rootPath << "\n"; + const auto scheme = https ? "https" : "http"; + auto makeExample = [scheme, port](std::string_view pathAndQuery) { + return std::format("{}://localhost:{}/{}", scheme, port, pathAndQuery); + }; + + std::println(std::cerr, "Starting {} server for {} on port {}", https ? "HTTPS" : "HTTP", rootPath, port); - std::cerr << "Open up https://localhost:8080/addressbook?contentType=text/html&ctx=FAIR.SELECTOR.ALL in your web browser\n"; - std::cerr << "Or curl -v -k one of the following:\n"; - std::cerr - << "'https://localhost:8080/addressbook?contentType=application/json&ctx=FAIR.SELECTOR.ALL'\n" - << "'https://localhost:8080/addressbook?contentType=application/json&ctx=FAIR.SELECTOR.ALL'\n" - << "'https://localhost:8080/addressbook/addresses?LongPollingId=Next'\n" - << "'https://localhost:8080/addressbook/addresses?LongPollingId=0'\n" - << "'https://localhost:8080/beverages/wine?LongPollingIdx=Subscription'\n"; + std::println(std::cerr, "Open up {} in your web browser", makeExample("addressbook?contentType=text/html&ctx=FAIR.SELECTOR.ALL")); + std::println(std::cerr, "Or curl -v -k one of the following:"); + std::println(std::cerr, "'{}'", makeExample("addressbook?contentType=application/json&ctx=FAIR.SELECTOR.ALL")); + std::println(std::cerr, "'{}'", makeExample("addressbook?contentType=application/json&ctx=FAIR.SELECTOR.ALL")); + std::println(std::cerr, "'{}'", makeExample("addressbook/addresses?LongPollingIdx=Next")); + std::println(std::cerr, "'{}'", makeExample("addressbook/addresses?LongPollingIdx=Last")); + std::println(std::cerr, "'{}'", makeExample("addressbook/addresses?LongPollingIdx=0")); + std::println(std::cerr, "'{}'", makeExample("beverages/wine?LongPollingIdx=Next")); + std::println(std::cerr, "'{}'", makeExample("loadTest?topic=1&intervalMs=40&payloadSize=4096&nUpdates=100&LongPollingIdx=Next")); // note: inconsistency: brokerName as ctor argument, worker's serviceName as NTTP // note: default roles different from java (has: ADMIN, READ_WRITE, READ_ONLY, ANYONE, NULL) majordomo::Broker primaryBroker("/PrimaryBroker", testSettings()); opencmw::query::registerTypes(SimpleContext(), primaryBroker); - auto fs = cmrc::assets::get_filesystem(); - - std::variant< - std::monostate, - FileServerRestBackend, - FileServerRestBackend> - rest; - if (auto env = ::getenv("DISABLE_REST_HTTPS"); env != nullptr && std::string_view(env) == "1") { - rest.emplace>(primaryBroker, fs, rootPath); + opencmw::majordomo::rest::Settings rest; + rest.port = port; + rest.handlers = { majordomo::rest::cmrcHandler("/assets/*", "", std::make_shared(cmrc::assets::get_filesystem()), "") }; + if (https) { + rest.certificateFilePath = "./demo_public.crt"; + rest.keyFilePath = "./demo_private.key"; } else { - rest.emplace>(primaryBroker, fs, rootPath); + rest.protocols = majordomo::rest::Protocol::Http2; + std::println(std::cerr, "HTTP/3 disabled, requires TLS"); + } + if (const auto bound = primaryBroker.bindRest(rest); !bound) { + std::println("Could not bind REST bridge to port {}: {}", rest.port, bound.error()); + return 1; } - std::jthread restServerThread([&rest] { - std::visit([](T &server) { - if constexpr (not std::is_same_v) { - server.run(); - } - }, - rest); - }); - - const auto brokerRouterAddress = primaryBroker.bind(URI<>("mds://127.0.0.1:12345")); + const auto brokerRouterAddress = primaryBroker.bind(URI<>("mds://127.0.0.1:12345")); if (!brokerRouterAddress) { std::cerr << "Could not bind to broker address" << std::endl; return 1; @@ -82,6 +90,7 @@ int main(int argc, char **argv) { majordomo::Worker<"/addressbook", SimpleContext, AddressRequest, AddressEntry> addressbookWorker(primaryBroker, TestAddressHandler()); majordomo::Worker<"/addressbookBackup", SimpleContext, AddressRequest, AddressEntry> addressbookBackupWorker(primaryBroker, TestAddressHandler()); majordomo::BasicWorker<"/beverages"> beveragesWorker(primaryBroker, TestIntHandler(10)); + majordomo::load_test::Worker<"/loadTest"> loadTestWorker(primaryBroker); // ImageServiceWorker<"/testImage", majordomo::description<"Returns an image">> imageWorker(primaryBroker, std::chrono::seconds(10)); @@ -91,6 +100,7 @@ int main(int argc, char **argv) { RunInThread runAddressbook(addressbookWorker); RunInThread runAddressbookBackup(addressbookBackupWorker); RunInThread runBeverages(beveragesWorker); + RunInThread runLoadTest(loadTestWorker); RunInThread runImage(imageWorker); waitUntilWorkerServiceAvailable(primaryBroker.context, addressbookWorker); diff --git a/concepts/majordomo/assets/mustache/ServicesList.mustache b/concepts/majordomo/assets/mustache/ServicesList.mustache index c0e5be6f..502acf8a 100644 --- a/concepts/majordomo/assets/mustache/ServicesList.mustache +++ b/concepts/majordomo/assets/mustache/ServicesList.mustache @@ -96,9 +96,7 @@ input.topicInput { border: 0; width: 100px; border-bottom: 1px solid silver; } }; listenButton.onclick = () => { - let post = { method: 'GET', headers: { 'X-OPENCMW-METHOD' : 'POLL' } }; - - fetch(href + "/" + topicInput.value, post) + fetch(href + "/" + topicInput.value) .then(response => response.text()) .then(data => { notificationLabel.innerHTML = data; diff --git a/concepts/majordomo/assets/mustache/default.mustache b/concepts/majordomo/assets/mustache/default.mustache index c243a8db..b0190fc4 100644 --- a/concepts/majordomo/assets/mustache/default.mustache +++ b/concepts/majordomo/assets/mustache/default.mustache @@ -174,14 +174,13 @@ } function pollingHandler() { - let get = { method: 'GET', headers: { 'X-OPENCMW-METHOD' : 'POLL' } }; - let queryParams = window.opencmwActiveSubscriptionQueryParams; if (!window.opencmwActiveSubscriptionQueryParams) { let formMeta = window.opencmwWebForms.replyContextForm; queryParams = new URLSearchParams(); + queryParams.append("LongPollingIdx", "Next"); for (const [key, fieldMeta] of Object.entries(formMeta)) { const field = document.getElementById(fieldMeta.formId + "_" + fieldMeta.name); @@ -205,14 +204,14 @@ } let url = getPageBaseUrl(); - url.searchParams.append("LongPollingIdx", "Next"); + url.searchParams.set("LongPollingIdx", "Next"); let currentObject = (new URL(document.URL)).pathname; - url.searchParams.append("SubscriptionContext", currentObject + "?" + queryParams.toString()); + url.searchParams.set("SubscriptionContext", currentObject + "?" + queryParams.toString()); let address = url.href; console.log(address); - fetch(address, get) + fetch(address) .then(response => response.text()) .then(data => { var checkbox = document.getElementById("subscriptionPollingCheckbox"); @@ -308,12 +307,10 @@ request.send(formData); } else { - address += '&_bodyOverride=' + JSON.stringify(json); let post = { - method: 'POST', - headers: { 'Accept' : 'application/json', - 'Content-Type': 'application/json', - 'X-OPENCMW-METHOD-DIS' : 'SET' }, + method: 'POST', + headers: { 'Accept' : 'application/json', 'Content-Type': 'application/json' }, + body: JSON.stringify(json), }; fetch(address, post) diff --git a/concepts/majordomo/helpers.hpp b/concepts/majordomo/helpers.hpp index 16c7e499..295175eb 100644 --- a/concepts/majordomo/helpers.hpp +++ b/concepts/majordomo/helpers.hpp @@ -12,7 +12,6 @@ // OpenCMW Majordomo #include -#include #include CMRC_DECLARE(testImages); @@ -204,43 +203,6 @@ class ImageServiceWorker : public majordomo::Worker -class FileServerRestBackend : public majordomo::RestBackend { -private: - using super_t = majordomo::RestBackend; - std::filesystem::path _serverRoot; - using super_t::_svr; - using super_t::DEFAULT_REST_SCHEME; - -public: - using super_t::RestBackend; - - FileServerRestBackend(majordomo::Broker &broker, const VirtualFS &vfs, std::filesystem::path serverRoot, opencmw::URI<> restAddress = opencmw::URI<>::factory().scheme(DEFAULT_REST_SCHEME).hostName("0.0.0.0").port(majordomo::DEFAULT_REST_PORT).build()) - : super_t(broker, vfs, restAddress), _serverRoot(std::move(serverRoot)) { - } - - void registerHandlers() override { - _svr.set_mount_point("/", _serverRoot.string()); - - _svr.Post("/stdio.html", [](const httplib::Request &request, httplib::Response &response) { - opencmw::debug::log() << "QtWASM:" << request.body; - response.set_content("", "text/plain"); - }); - - auto cmrcHandler = [this](const httplib::Request &request, httplib::Response &response) { - if (super_t::_vfs.is_file(request.path)) { - auto file = super_t::_vfs.open(request.path); - response.set_content(std::string(file.begin(), file.end()), ""); - } - }; - - _svr.Get("/assets/.*", cmrcHandler); - - // Register default handlers - super_t::registerHandlers(); - } -}; - template concept Shutdownable = requires(T s) { s.run(); diff --git a/docs/RestUriMapping.md b/docs/RestUriMapping.md index 6ee04d0e..d185ea4e 100644 --- a/docs/RestUriMapping.md +++ b/docs/RestUriMapping.md @@ -15,23 +15,13 @@ There is no topic in this case. # Request type specification The REST backend maps HTTP requests to Majordomo requests. -By default, `PUT` and `POST` HTTP requests are mapped to Majordomo's `Post`, +By default, `PUT` and `POST` HTTP requests are mapped to Majordomo's `Set`, and `GET` HTTP request is mapped to Majordomo's `Get` request. -It also defines two non-standard HTTP requests -- `SUB`, -which maps to Majordomo's `Subscribe` -and `POLL` which maps to `LongPoll`. +Requests with a query parameter `LongPollingIdx` are treated as `LongPoll` request subscribing to the +given topic (consisting of URI path and other query parameters), with the possible values: -For clients that don't support a custom request operation, -the non-standard HTTP requests can be invoked by defining the -`X-OPENCMW-METHOD` HTTP header. - -If this header is defined, its value is used to override -the HTTP request method. -This means that if HTTP method is `GET` and `X-OPENCMW-METHOD=POLL`, -that the REST backend will treat it as `LongPoll` request. - -Alternatively, one can override the Majordomo method -by setting the `LongPollingIdx` query parameter. -The `Subscription` value of this parameter is mapped to Majordomo's `Subscription` method, -while an integer value will activate the `LongPoll` method. + - `Next`: Redirects to the next notification message that arrives after the request has been processed. + - `Last`: Redirects to the most recent notification message that is in the cache when the request is + processed. If there is no such entry yet, it's treated like `Next`, i.e. waits for the next notification. + - a positive integer value, to retrieve a specific cache entry. diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index dd5e82cc..bb982178 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -8,6 +8,7 @@ if(EMSCRIPTEN) message("Disabled majordomo and mdp client module because they are not compatible with emscripten builds") else() add_subdirectory(majordomo) + add_subdirectory(rest) add_subdirectory(zmq) endif() diff --git a/src/client/CMakeLists.txt b/src/client/CMakeLists.txt index 3eb165fb..3ffb281a 100644 --- a/src/client/CMakeLists.txt +++ b/src/client/CMakeLists.txt @@ -12,11 +12,12 @@ if(NOT EMSCRIPTEN) target_link_libraries( client INTERFACE pthread - core - disruptor - serialiser - majordomo - zmq) + core + disruptor + serialiser + majordomo + rest + zmq) else() target_link_libraries( client diff --git a/src/client/include/ClientCommon.hpp b/src/client/include/ClientCommon.hpp index 76d74d6a..4b064aac 100644 --- a/src/client/include/ClientCommon.hpp +++ b/src/client/include/ClientCommon.hpp @@ -43,7 +43,7 @@ constexpr auto find_argument_value_helper(Item &item) { } template -requires std::is_invocable_r_v + requires std::is_invocable_r_v constexpr RequiredType find_argument_value(Func defaultGenerator, Items... args) { auto ret = std::tuple_cat(find_argument_value_helper(args)...); if constexpr (std::tuple_size_v == 0) { @@ -53,8 +53,6 @@ constexpr RequiredType find_argument_value(Func defaultGenerator, Items... args) } } -constexpr const char *ACCEPT_HEADER = "accept"; -constexpr const char *CONTENT_TYPE_HEADER = "content-type"; } // namespace detail class DefaultContentTypeHeader { @@ -62,12 +60,22 @@ class DefaultContentTypeHeader { public: DefaultContentTypeHeader(const MIME::MimeType &type) noexcept - : _mimeType(type){}; + : _mimeType(type) {} DefaultContentTypeHeader(const std::string_view type_str) noexcept - : _mimeType(MIME::getType(type_str)){}; + : _mimeType(MIME::getType(type_str)) {} constexpr operator const MIME::MimeType() const noexcept { return _mimeType; }; }; +class VerifyServerCertificates { + const bool verifyServerCertificates = false; + +public: + VerifyServerCertificates() = default; + VerifyServerCertificates(bool value) noexcept + : verifyServerCertificates(value) {} + constexpr operator bool() const noexcept { return verifyServerCertificates; }; +}; + } // namespace opencmw::client #endif // include guard diff --git a/src/client/include/ClientContext.hpp b/src/client/include/ClientContext.hpp index f6072e61..4650f191 100644 --- a/src/client/include/ClientContext.hpp +++ b/src/client/include/ClientContext.hpp @@ -1,7 +1,6 @@ -#ifndef OPENCMW_CPP_DATASOUCREPUBLISHER_HPP -#define OPENCMW_CPP_DATASOUCREPUBLISHER_HPP +#ifndef OPENCMW_CPP_CLIENTCONTEXT_HPP +#define OPENCMW_CPP_CLIENTCONTEXT_HPP -#include #include #include #include @@ -123,4 +122,4 @@ class ClientContext { }; } // namespace opencmw::client -#endif // OPENCMW_CPP_DATASOUCREPUBLISHER_HPP +#endif // OPENCMW_CPP_CLIENTCONTEXT_HPP diff --git a/src/client/include/RestClientEmscripten.hpp b/src/client/include/RestClientEmscripten.hpp index 66cfc244..4dc091b3 100644 --- a/src/client/include/RestClientEmscripten.hpp +++ b/src/client/include/RestClientEmscripten.hpp @@ -1,11 +1,11 @@ #ifndef OPENCMW_CPP_RESTCLIENT_EMSCRIPTEN_HPP #define OPENCMW_CPP_RESTCLIENT_EMSCRIPTEN_HPP +#include #include #include #include -#include #include #include @@ -73,13 +73,13 @@ std::array getPreferredContentTypeHeader(const URI &uri, if (const auto acceptHeader = uri.queryParamMap().find("contentType"); acceptHeader != uri.queryParamMap().end() && acceptHeader->second) { mimeType = acceptHeader->second->c_str(); } - return { ACCEPT_HEADER, mimeType, CONTENT_TYPE_HEADER, mimeType }; + return { "accept", mimeType, "content-type", mimeType }; } struct FetchPayload { Command command; - FetchPayload(Command &&_command) + explicit FetchPayload(Command &&_command) : command(std::move(_command)) {} FetchPayload(const FetchPayload &other) = delete; @@ -94,8 +94,7 @@ struct FetchPayload { if (!command.callback) { return; } - const bool msgOK = status >= 200 && status < 400; - const auto errorMsg = msgOK ? errorMsgExt : std::format("{} - {}{}{}", status, errorMsgExt, body.empty() ? "" : ":", body); + const bool msgOK = status >= 200 && status < 400; try { command.callback(mdp::Message{ .id = 0, @@ -105,7 +104,7 @@ struct FetchPayload { .clientRequestID = command.clientRequestID, .topic = command.topic, .data = msgOK ? IoBuffer(body.data(), body.size()) : IoBuffer(), - .error = std::string{ errorMsg }, + .error = msgOK ? std::string(errorMsgExt) : std::format("{} - {}{}{}", status, errorMsgExt, body.empty() ? "" : ":", body), .rbac = IoBuffer() }); } catch (const std::exception &e) { std::cerr @@ -134,9 +133,12 @@ struct SubscriptionPayload; static std::unordered_set, detail::pointer_hash, detail::pointer_equals> subscriptionPayloads; struct SubscriptionPayload : FetchPayload { - bool _live = true; - MIME::MimeType _mimeType; - std::size_t _update = 0; + bool _live = true; + MIME::MimeType _mimeType; + std::size_t _update = 0; + + static constexpr std::size_t kParallelLongPollingRequests = 3; + std::vector _requestedIndexes; SubscriptionPayload(Command &&_command, MIME::MimeType mimeType) : FetchPayload(std::move(_command)), _mimeType(std::move(mimeType)) {} @@ -149,13 +151,25 @@ struct SubscriptionPayload : FetchPayload { SubscriptionPayload &operator=(SubscriptionPayload &&other) noexcept = default; - void requestNext() { - auto uri = opencmw::URI::UriFactory(command.topic).addQueryParameter("LongPollingIdx", (_update == 0) ? "Next" : std::format("{}", _update)).build(); - // std::print("URL 1 >>> {}, thread {}\n", uri.relativeRef(), std::this_thread::get_id()); + void sendFollowUpRequestsFor(std::uint64_t longPollingIdx) { + auto it = std::ranges::find(_requestedIndexes, longPollingIdx); + if (it != _requestedIndexes.end()) { + _requestedIndexes.erase(it); + } + for (std::uint64_t i = longPollingIdx + 1; i <= longPollingIdx + kParallelLongPollingRequests; ++i) { + if (std::ranges::find(_requestedIndexes, i) == _requestedIndexes.end()) { + _requestedIndexes.push_back(i); + request(std::to_string(i)); + } + } + } + void request(std::string longPollingIndex) { + auto uri = opencmw::URI::UriFactory(command.topic).addQueryParameter("LongPollingIdx", longPollingIndex).build(); auto preferredHeader = detail::getPreferredContentTypeHeader(command.topic, _mimeType); + std::array preferredHeaderEmscripten; std::transform(preferredHeader.cbegin(), preferredHeader.cend(), preferredHeaderEmscripten.begin(), - [](const auto &str) { return str.c_str(); }); + [](const auto &str) { return str.c_str(); }); preferredHeaderEmscripten[preferredHeaderEmscripten.size() - 1] = nullptr; emscripten_fetch_attr_t attr{}; @@ -178,21 +192,28 @@ struct SubscriptionPayload : FetchPayload { attr.attributes = EMSCRIPTEN_FETCH_LOAD_TO_MEMORY; attr.requestHeaders = preferredHeaderEmscripten.data(); attr.onsuccess = [](emscripten_fetch_t *fetch) { - auto payloadIt = getPayloadIt(fetch); - auto &payload = *payloadIt; - // std::print("received update: {}, {}\n", fetch->url, payload->_update); + auto payloadIt = getPayloadIt(fetch); + auto &payload = *payloadIt; + std::uint64_t longPollingIdx = 0; if (payload->_live) { std::string finalURL = getFinalURL(fetch->id); std::string longPollingIdxString = opencmw::URI<>(finalURL).queryParamMap().at("LongPollingIdx").value_or("0"); - char *end = longPollingIdxString.data() + longPollingIdxString.size(); - std::size_t longPollingIdx = strtoull(longPollingIdxString.data(), &end, 10); + + char *end = nullptr; + longPollingIdx = strtoull(longPollingIdxString.data(), &end, 10); + if (end != longPollingIdxString.data() + longPollingIdxString.size()) { + std::println(std::cerr, "RestClientEmscripten::payloadError: url: {}, bytes: {}\n", fetch->url, fetch->numBytes); + return; + } + if (payload->_update != 0 && longPollingIdx != payload->_update) { std::print("received unexpected update: {}, expected {}\n", longPollingIdx, payload->_update); } payload->onsuccess(fetch->status, std::string_view(fetch->data, detail::checkedStringViewSize(fetch->numBytes)), static_cast(longPollingIdx) - static_cast(payload->_update)); emscripten_fetch_close(fetch); - payload->_update = longPollingIdx + 1; - payload->requestNext(); + + payload->_update = longPollingIdx; + payload->sendFollowUpRequestsFor(longPollingIdx); } else { detail::subscriptionPayloads.erase(payloadIt); } @@ -233,7 +254,7 @@ class RestClient : public ClientBase { * Initialises a basic RestClient * * usage example: - * RestClient client("clientName", DefaultContentTypeHeader(MIME::HTML), MinIoThreads(2), MaxIoThreads(5), ClientCertificates(testCertificate)) + * RestClient client("clientName", DefaultContentTypeHeader(MIME::HTML), ClientCertificates(testCertificate)) * * @tparam Args see argument example above. Order is arbitrary. * @param initArgs @@ -243,9 +264,9 @@ class RestClient : public ClientBase { : _name(detail::find_argument_value([] { return "RestClient"; }, initArgs...)) , _mimeType(detail::find_argument_value([] { return MIME::BINARY; }, initArgs...)) { } - ~RestClient() { RestClient::stop(); }; + ~RestClient() { RestClient::stop(); } - void stop() override {}; + void stop() override {} std::vector protocols() noexcept override { return { "http", "https" }; } @@ -282,19 +303,18 @@ class RestClient : public ClientBase { emscripten_fetch_attr_t attr; emscripten_fetch_attr_init(&attr); - std::string body(cmd.data.asString()); + auto payload = std::make_unique(std::move(cmd)); + attr.userData = payload.get(); - if (cmd.command == opencmw::mdp::Command::Set) { + if (payload->command.command == opencmw::mdp::Command::Set) { strcpy(attr.requestMethod, "POST"); - attr.requestData = reinterpret_cast(body.data()); + auto body = payload->command.data.asString(); + attr.requestData = body.data(); attr.requestDataSize = body.size(); } else { strcpy(attr.requestMethod, "GET"); } - auto payload = std::make_unique(std::move(cmd)); - attr.userData = payload.get(); - detail::fetchPayloads.insert(std::move(payload)); static auto getPayload = [](emscripten_fetch_t *fetch) { auto *rawPayload = fetch->userData; auto it = detail::fetchPayloads.find(rawPayload); @@ -308,7 +328,6 @@ class RestClient : public ClientBase { attr.attributes = EMSCRIPTEN_FETCH_LOAD_TO_MEMORY; attr.requestHeaders = preferredHeaderEmscripten.data(); attr.onsuccess = [](emscripten_fetch_t *fetch) { - // std::print("RestClientEmscripten: got get/set reply: {}\n", fetch->url); getPayload(fetch)->onsuccess(fetch->status, std::string_view(fetch->data, detail::checkedStringViewSize(fetch->numBytes))); emscripten_fetch_close(fetch); }; @@ -319,8 +338,8 @@ class RestClient : public ClientBase { // TODO: Pass the payload as POST body: emscripten_fetch(&attr, uri.relativeRef()->data()); - const auto uri = URI<>::factory(cmd.topic).addQueryParameter("_bodyOverride", body).build(); - emscripten_fetch(&attr, uri.str().data()); + emscripten_fetch(&attr, payload->command.topic.str().data()); + detail::fetchPayloads.insert(std::move(payload)); } void startSubscription(Command &&cmd) { @@ -330,7 +349,7 @@ class RestClient : public ClientBase { std::print("starting subscription: {}, existing subscriptions: {}, from main thread: \n", cmd.topic.str(), detail::subscriptionPayloads.size(), emscripten_is_main_runtime_thread()); if (emscripten_is_main_runtime_thread()) { try { - rawPayload->requestNext(); + rawPayload->request("Next"); } catch (std::runtime_error &e) { rawPayload->onerror(500, e.what(), ""); } catch (...) { @@ -340,7 +359,7 @@ class RestClient : public ClientBase { emscripten_async_run_in_main_runtime_thread(EM_FUNC_SIG_IP, +[](void *data) { auto subPayload = reinterpret_cast(data); try { - subPayload->requestNext(); + subPayload->request("Next"); } catch (std::runtime_error &e) { subPayload->onerror(500, e.what(), ""); } catch (...) { @@ -351,7 +370,7 @@ class RestClient : public ClientBase { } void stopSubscription(Command &&cmd) { - auto payloadIt = std::find_if(detail::subscriptionPayloads.begin(), detail::subscriptionPayloads.end(), + auto payloadIt = std::ranges::find_if(detail::subscriptionPayloads, [&](const auto &ptr) { return ptr->command.topic == cmd.topic; }); diff --git a/src/client/include/RestClientNative.hpp b/src/client/include/RestClientNative.hpp index 94d53bf2..76e5c51f 100644 --- a/src/client/include/RestClientNative.hpp +++ b/src/client/include/RestClientNative.hpp @@ -1,411 +1,860 @@ -#ifndef OPENCMW_CPP_RESTCLIENT_NATIVE_HPP -#define OPENCMW_CPP_RESTCLIENT_NATIVE_HPP - -#include -#include -#include - -#include -#include -#include -#include -#include - -#include "RestDefaultClientCertificates.hpp" - -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wformat-nonliteral" -#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" -#pragma GCC diagnostic ignored "-Wold-style-cast" -#pragma GCC diagnostic ignored "-Wshadow" -#pragma GCC diagnostic ignored "-Wuninitialized" -#pragma GCC diagnostic ignored "-Wuseless-cast" -#include -#pragma GCC diagnostic pop - +#ifndef OPENCMW_CLIENT_RESTCLIENTNATIVE_HPP +#define OPENCMW_CLIENT_RESTCLIENTNATIVE_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "ClientCommon.hpp" +#include "ClientContext.hpp" +#include "MdpMessage.hpp" +#include "MIME.hpp" +#include "rest/RestUtils.hpp" +#include "Topic.hpp" +#ifdef OPENCMW_PROFILE_HTTP +#include "LoadTest.hpp" +#endif namespace opencmw::client { +enum class SubscriptionMode { + Next, + Last, + None +}; + +namespace detail { -inline constexpr static const char *LONG_POLLING_IDX_TAG = "LongPollingIdx"; +using namespace opencmw::rest::detail; -class MinIoThreads { - const int _minThreads = 1; +template +struct SharedQueue { + // TODO use a lock-free queue? This is only used client-side, so not that critical + std::deque deque; + std::mutex mutex; -public: - MinIoThreads() = default; - MinIoThreads(int value) noexcept - : _minThreads(value) {}; - constexpr operator int() const noexcept { return _minThreads; }; + void push(T v) { + std::lock_guard lock(mutex); + deque.push_back(std::move(v)); + } + + std::optional try_get() { + std::lock_guard lock(mutex); + if (deque.empty()) { + return {}; + } + auto result = std::move(deque.front()); + deque.pop_front(); + return result; + } }; -class MaxIoThreads { - const int _maxThreads = 10'000; +struct RequestResponse { + // request data + client::Command request; + std::unique_ptr body; + std::string normalizedTopic; + + // response data + std::string responseStatus; + std::string location; // for redirects + std::optional longPollingIdx; + std::string payload; + mdp::Message response; + + void fillResponse() { + response.id = 0; + response.arrivalTime = std::chrono::system_clock::now(); + response.protocolName = request.topic.scheme().value_or(""); + response.rbac.clear(); + + switch (request.command) { + case mdp::Command::Get: + case mdp::Command::Set: + response.command = mdp::Command::Final; + response.clientRequestID = request.clientRequestID; + break; + case mdp::Command::Subscribe: + response.command = mdp::Command::Notify; + break; + default: + break; + } + } -public: - MaxIoThreads() = default; - MaxIoThreads(int value) noexcept - : _maxThreads(value) {}; - constexpr operator int() const noexcept { return _maxThreads; }; + void reportError(std::string error) { + if (!request.callback) { + HTTP_DBG("Client::reportError: {}", error); + return; + } + fillResponse(); + response.topic = request.topic; + response.error = std::move(error); + response.data.clear(); + request.callback(std::move(response)); + } }; -struct ClientCertificates { - std::string _certificates; +constexpr std::size_t kParallelLongPollingRequests = 3; - ClientCertificates() = default; - ClientCertificates(const char *X509_ca_bundle) noexcept - : _certificates(X509_ca_bundle) {}; - ClientCertificates(const std::string &X509_ca_bundle) noexcept - : _certificates(X509_ca_bundle) {}; - constexpr operator std::string() const noexcept { return _certificates; }; +struct Subscription { + client::Command request; + SubscriptionMode mode; + std::optional lastReceivedLongPollingIdx; + std::vector> callbacks; }; -namespace detail { -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT -inline int readCertificateBundleFromBuffer(X509_STORE &cert_store, const std::string_view &X509_ca_bundle) { - BIO *cbio = BIO_new_mem_buf(X509_ca_bundle.data(), static_cast(X509_ca_bundle.size())); - if (!cbio) { - return -1; - } - STACK_OF(X509_INFO) *inf = PEM_X509_INFO_read_bio(cbio, nullptr, nullptr, nullptr); +struct Endpoint { + std::string scheme; + std::string host; + uint16_t port; - if (!inf) { - BIO_free(cbio); // cleanup - return -1; - } - // iterate over all entries from the pem file, add them to the x509_store one by one - int count = 0; - for (int i = 0; i < sk_X509_INFO_num(inf); i++) { - X509_INFO *itmp = sk_X509_INFO_value(inf, i); - if (itmp->x509) { - X509_STORE_add_cert(&cert_store, itmp->x509); - count++; - } - if (itmp->crl) { - X509_STORE_add_crl(&cert_store, itmp->crl); - count++; + auto operator<=>(const Endpoint &) const = default; +}; + +template +struct ClientSessionBase { + struct PendingRequest { + client::Command command; + SubscriptionMode mode; + std::string preferredMimeType; + std::optional longPollIdx; + }; + std::map _subscriptions; + std::map _requestsByStreamId; + + [[nodiscard]] constexpr auto &self() noexcept { return *static_cast(this); } + [[nodiscard]] constexpr const auto &self() const noexcept { return *static_cast(this); } + + bool addHeader(TStreamId streamId, std::string_view nameView, std::string_view valueView) { + HTTP_DBG("Client::Header: id={} {} = {}", streamId, nameView, valueView); + if (nameView == ":status") { + _requestsByStreamId[streamId].responseStatus = std::string(valueView); + } else if (nameView == "location") { + _requestsByStreamId[streamId].location = std::string(valueView); + } else if (nameView == "x-opencmw-topic") { + try { + _requestsByStreamId[streamId].response.topic = URI<>(std::string(valueView)); + } catch (const std::exception &e) { + HTTP_DBG("Client::Header: Could not parse URI '{}': {}", valueView, e.what()); + return false; + } + } else if (nameView == "x-opencmw-service-name") { + _requestsByStreamId[streamId].response.serviceName = std::string(valueView); + } else if (nameView == "x-opencmw-long-polling-idx") { + std::uint64_t longPollingIdx; + if (auto ec = std::from_chars(valueView.data(), valueView.data() + valueView.size(), longPollingIdx); ec.ec != std::errc{}) { + HTTP_DBG("Client::Header: Could not parse x-opencmw-long-polling-idx '{}'", valueView); + return false; + } + _requestsByStreamId[streamId].longPollingIdx = longPollingIdx; +#ifdef OPENCMW_PROFILE_HTTP + } else if (nameView == "x-timestamp") { + std::println(std::cerr, "Client::Header: x-timestamp: {} (latency {} ns)", valueView, latency(valueView).count()); +#endif } + return true; } - sk_X509_INFO_pop_free(inf, X509_INFO_free); - BIO_free(cbio); - return count; -} + void submitRequest(client::Command &&cmd, SubscriptionMode mode, std::string preferredMimeType, std::optional longPollIdx) { + std::string longPollIdxParam; + if (longPollIdx) { + longPollIdxParam = std::to_string(*longPollIdx); + } else { + switch (mode) { + case SubscriptionMode::Next: + longPollIdxParam = "Next"; + break; + case SubscriptionMode::Last: + longPollIdxParam = "Last"; + break; + case SubscriptionMode::None: + break; + } + } -inline X509_STORE *createCertificateStore(const std::string_view &X509_ca_bundle) { - X509_STORE *cert_store = X509_STORE_new(); - if (detail::readCertificateBundleFromBuffer(*cert_store, X509_ca_bundle) <= 0) { - X509_STORE_free(cert_store); - throw std::invalid_argument(std::format("failed to read certificate bundle from buffer:\n#---start---\n{}\n#---end---\n", X509_ca_bundle)); - } - return cert_store; -} - -inline X509 *readServerCertificateFromFile(const std::string_view &X509_ca_bundle) { - BIO *certBio = BIO_new(BIO_s_mem()); - BIO_write(certBio, X509_ca_bundle.data(), static_cast(X509_ca_bundle.size())); - X509 *certX509 = PEM_read_bio_X509(certBio, nullptr, nullptr, nullptr); - BIO_free(certBio); - if (certX509) { - return certX509; - } - X509_free(certX509); - throw std::invalid_argument(std::format("failed to read certificate from buffer:\n#---start---\n{}\n#---end---\n", X509_ca_bundle)); -} - -inline EVP_PKEY *readServerPrivateKeyFromFile(const std::string_view &X509_private_key) { - BIO *certBio = BIO_new(BIO_s_mem()); - BIO_write(certBio, X509_private_key.data(), static_cast(X509_private_key.size())); - EVP_PKEY *privateKeyX509 = PEM_read_bio_PrivateKey(certBio, nullptr, nullptr, nullptr); - BIO_free(certBio); - if (privateKeyX509) { - return privateKeyX509; - } - EVP_PKEY_free(privateKeyX509); - throw std::invalid_argument(std::format("failed to read private key from buffer")); -} + auto topic = cmd.topic; + if (!longPollIdxParam.empty()) { + topic = URI<>::UriFactory(topic).addQueryParameter("LongPollingIdx", longPollIdxParam).build(); + } + const auto host = cmd.topic.hostName().value_or(""); + const auto scheme = cmd.topic.scheme().value_or(""); + const auto path = topic.relativeRefNoFragment().value_or("/"); +#ifdef OPENCMW_PROFILE_HTTP + const auto ts = std::to_string(opencmw::load_test::timestamp().count()); #endif + constexpr uint8_t noCopy = NGHTTP2_NV_FLAG_NO_COPY_NAME | NGHTTP2_NV_FLAG_NO_COPY_VALUE; -} // namespace detail + const auto method = (cmd.command == mdp::Command::Set) ? u8span("POST") : u8span("GET"); -class RestClient : public ClientBase { - static const httplib::Headers EVT_STREAM_HEADERS; - using ThreadPoolType = std::shared_ptr>; - - std::string _name; - MIME::MimeType _mimeType; - std::atomic _run = true; - const int _minIoThreads; - const int _maxIoThreads; - ThreadPoolType _thread_pool; - std::string _caCertificate; - - std::mutex _subscriptionLock; - std::map, httplib::Client> _subscription1; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - std::map, httplib::SSLClient> _subscription2; + auto headers = std::vector{ + nv(u8span(":method"), method, noCopy), // + nv(u8span(":path"), u8span(path)), // + nv(u8span(":scheme"), u8span(scheme)), // + nv(u8span(":authority"), u8span(host)), +#ifdef OPENCMW_PROFILE_HTTP + nv(u8span("x-timestamp"), u8span(ts)) #endif + }; + if (!preferredMimeType.empty()) { + headers.push_back(nv(u8span("accept"), u8span(preferredMimeType))); + headers.push_back(nv(u8span("content-type"), u8span(preferredMimeType))); + } -public: - static bool CHECK_CERTIFICATES; - - /** - * Initialises a basic RestClient - * - * usage example: - * RestClient client("clientName", DefaultContentTypeHeader(MIME::HTML), MinIoThreads(2), MaxIoThreads(5), ClientCertificates(testCertificate)) - * - * @tparam Args see argument example above. Order is arbitrary. - * @param initArgs - */ - template - explicit(false) RestClient(Args... initArgs) - : _name(detail::find_argument_value([] { return "RestClient"; }, initArgs...)), // - _mimeType(detail::find_argument_value([] { return MIME::JSON; }, initArgs...)) - , _minIoThreads(detail::find_argument_value([] { return MinIoThreads(); }, initArgs...)) - , _maxIoThreads(detail::find_argument_value([] { return MaxIoThreads(); }, initArgs...)) - , _thread_pool(detail::find_argument_value([this] { return std::make_shared>(_name, _minIoThreads, _maxIoThreads); }, initArgs...)) - , _caCertificate(detail::find_argument_value([] { return rest::DefaultCertificate().get(); }, initArgs...)) { - } - ~RestClient() override { RestClient::stop(); }; + RequestResponse rr; + rr.request = std::move(cmd); + try { + rr.normalizedTopic = mdp::Topic::fromMdpTopic(rr.request.topic).toZmqTopic(); + } catch (...) { + rr.normalizedTopic = rr.request.topic.str(); + } -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - std::vector protocols() noexcept override { return { "http", "https" }; } -#else - std::vector protocols() noexcept override { return { "http" }; } -#endif - void stop() noexcept override { stopAllSubscriptions(); }; - [[nodiscard]] std::string name() const noexcept { return _name; } - [[nodiscard]] ThreadPoolType threadPool() const noexcept { return _thread_pool; } - [[nodiscard]] MIME::MimeType defaultMimeType() const noexcept { return _mimeType; } - [[nodiscard]] std::string clientCertificate() const noexcept { return _caCertificate; } + if (!rr.request.data.empty()) { + // we need a pointer that survives rr being moved + rr.body = std::make_unique(std::move(rr.request.data)); + } - void request(Command cmd) override { - switch (cmd.command) { - case mdp::Command::Get: - case mdp::Command::Set: - _thread_pool->execute([this, cmd = std::move(cmd)]() mutable { executeCommand(std::move(cmd)); }); - return; - case mdp::Command::Subscribe: - _thread_pool->execute([this, cmd = std::move(cmd)]() mutable { startSubscription(std::move(cmd)); }); - return; - case mdp::Command::Unsubscribe: // deregister existing subscription URI is key - _thread_pool->execute([this, cmd = std::move(cmd)]() mutable { stopSubscription(cmd); }); + const TStreamId streamId = self().submitRequestImpl(headers, rr.body.get()); + if (streamId < 0) { + rr.reportError(std::format("Could not submit request: {}", nghttp2_strerror(streamId))); return; - default: - throw std::invalid_argument("command type is undefined"); } + + _requestsByStreamId.emplace(streamId, std::move(rr)); + } + + int processResponse(TStreamId streamId) { + auto it = _requestsByStreamId.find(streamId); + assert(it != _requestsByStreamId.end()); + if (it != _requestsByStreamId.end()) { + const auto &request = it->second.request; + if (it->second.responseStatus == "302") { + std::optional> location; + try { + location = URI<>(it->second.location); + } catch (const std::exception &e) { + HTTP_DBG("Client::Header: Could not parse URI '{}': {}", it->second.location, e.what()); + it->second.reportError(std::format("Could not parse redirect URI '{}': {}", it->second.location, e.what())); + _requestsByStreamId.erase(it); + return static_cast(NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE); + } + const auto &queryParams = location->queryParamMap(); + auto longPollIdxIt = queryParams.find("LongPollingIdx"); + if (longPollIdxIt == queryParams.end() || !longPollIdxIt->second.has_value()) { + HTTP_DBG("Client::Header: Could not find LongPollingIdx in URI '{}'", it->second.location); + it->second.reportError(std::format("Could not find LongPollingIdx in URI '{}'", it->second.location)); + _requestsByStreamId.erase(it); + return static_cast(NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE); + } + std::uint64_t longPollingIdx; + if (auto ec = std::from_chars(longPollIdxIt->second->data(), longPollIdxIt->second->data() + longPollIdxIt->second->size(), longPollingIdx); ec.ec != std::errc{}) { + HTTP_DBG("Client::Header: Could not parse numerical LongPollingIdx from '{}'", longPollIdxIt->second); + it->second.reportError(std::format("Could not parse numerical LongPollingIdx from '{}'", longPollIdxIt->second)); + _requestsByStreamId.erase(it); + return static_cast(NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE); + } + // Redirect to long-polling URL + sendLongPollRequests(it->second.normalizedTopic, longPollingIdx, longPollingIdx + kParallelLongPollingRequests - 1); + } else if (it->second.longPollingIdx && it->second.responseStatus == "504") { + // Server timeout on long-poll, resend same request + sendLongPollRequests(it->second.normalizedTopic, it->second.longPollingIdx.value(), it->second.longPollingIdx.value()); + } else { + it->second.fillResponse(); + auto response = std::move(it->second.response); + const auto hasError = !it->second.responseStatus.starts_with("2") && !it->second.responseStatus.starts_with("3"); + if (hasError) { + response.error = std::move(it->second.payload); + } else { + response.data = IoBuffer(it->second.payload.data(), it->second.payload.size()); + } + if (it->second.longPollingIdx) { + // Subscription + handleSubscriptionResponse(it->second.normalizedTopic, it->second.longPollingIdx.value(), std::move(response)); + } else { + // GET/SET + if (request.callback) { + request.callback(std::move(response)); + } + } + } + _requestsByStreamId.erase(it); + } else { + HTTP_DBG("Client::Frame: Could not find request for stream id {}", streamId); + } + return 0; } -private: - httplib::Headers getPreferredContentTypeHeader(const URI &uri) const { - auto mimeType = std::string(_mimeType.typeName()); - if (const auto acceptHeader = uri.queryParamMap().find(detail::ACCEPT_HEADER); acceptHeader != uri.queryParamMap().end() && acceptHeader->second) { - mimeType = acceptHeader->second->c_str(); + void reportErrorToAllPendingRequests(std::string_view error) { + for (auto &[streamId, request] : _requestsByStreamId) { + request.reportError(std::string{ error }); } - const httplib::Headers headers = { { detail::ACCEPT_HEADER, mimeType }, { detail::CONTENT_TYPE_HEADER, mimeType } }; - return headers; + _requestsByStreamId.clear(); } - static void returnMdpMessage(Command &cmd, const httplib::Result &result, const std::string &errorMsgExt = "") noexcept { - if (!cmd.callback) { + void handleSubscriptionResponse(std::string zmqTopic, std::uint64_t longPollingIdx, mdp::Message &&response) { + auto subIt = _subscriptions.find(zmqTopic); + if (subIt == _subscriptions.end()) { + HTTP_DBG("Client::handleSubscriptionResponse: Could not find subscription for topic '{}'", zmqTopic); return; } + auto &sub = subIt->second; + sub.lastReceivedLongPollingIdx = longPollingIdx; + auto request = sub.request; - const auto errorMsg = [&]() -> std::optional { - // Result contains a nullptr - if (!result) { - return errorMsgExt.empty() ? "Unknown error, empty result" : errorMsgExt; - } + submitRequest(std::move(request), sub.mode, {}, longPollingIdx + kParallelLongPollingRequests); - // No error - if (result && result->status >= 200 && result->status < 400 && errorMsgExt.empty()) { - return {}; + for (std::size_t i = 0; i < sub.callbacks.size(); ++i) { + if (i < sub.callbacks.size() - 1) { + auto copy = response; + sub.callbacks[i](std::move(copy)); + } else { + sub.callbacks[i](std::move(response)); } + } + } - const auto httpError = httplib::status_message(result->status); - return std::format("{} - {}:{}", result->status, httpError, errorMsgExt.empty() ? result->body : errorMsgExt); - }(); + void sendLongPollRequests(std::string zmqTopic, std::uint64_t fromLongPollingIdx, std::uint64_t toLongPollingIdx) { + auto subIt = _subscriptions.find(zmqTopic); + if (subIt == _subscriptions.end()) { + HTTP_DBG("Client::sendLongPollingRequests: Could not find subscription for topic '{}'", zmqTopic); + return; + } + auto &sub = subIt->second; + for (std::uint64_t longPollingIdx = fromLongPollingIdx; longPollingIdx <= toLongPollingIdx; ++longPollingIdx) { + auto request = sub.request; + submitRequest(std::move(request), sub.mode, {}, longPollingIdx); + } + } + void startSubscription(client::Command &&command, SubscriptionMode mode = SubscriptionMode::Next) { + mdp::Topic topic; try { - cmd.callback(mdp::Message{ - .id = 0, - .arrivalTime = std::chrono::system_clock::now(), - .protocolName = cmd.topic.scheme().value(), - .command = mdp::Command::Final, - .clientRequestID = cmd.clientRequestID, - .topic = cmd.topic, - .data = errorMsg ? IoBuffer() : IoBuffer(result->body.data(), result->body.size()), - .error = errorMsg.value_or(""), - .rbac = IoBuffer() }); + topic = mdp::Topic::fromMdpTopic(command.topic); } catch (const std::exception &e) { - std::cerr << std::format("caught exception '{}' in RestClient::returnMdpMessage(cmd={}, {}: {})", e.what(), cmd.topic, result->status, result.value().body) << std::endl; - } catch (...) { - std::cerr << std::format("caught unknown exception in RestClient::returnMdpMessage(cmd={}, {}: {})", cmd.topic, result->status, result.value().body) << std::endl; + HTTP_DBG("Client::startSubscription: Could not parse topic '{}': {}", command.topic.str(), e.what()); + return; + } + const auto [subIt, inserted] = _subscriptions.try_emplace(topic.toZmqTopic(), Subscription{}); + subIt->second.request = command; + subIt->second.callbacks.emplace_back(command.callback); + if (inserted) { + submitRequest(std::move(command), mode, {}, {}); } } - void executeCommand(Command &&cmd) const { - using namespace std::string_literals; - std::cout << "RestClient::request(" << (cmd.topic.str()) << ")" << std::endl; - auto preferredHeader = getPreferredContentTypeHeader(cmd.topic); + void stopSubscription(client::Command &&command) { + // TODO a single unsubscribe cancels this also in case of multiple subscriptions when the client is shared + // inside an application. Would be great if we could selectively unsubscribe certain callbacks and finally + // stop the subscription when all callbacks are removed. + mdp::Topic topic; + try { + topic = mdp::Topic::fromMdpTopic(command.topic); + } catch (const std::exception &e) { + HTTP_DBG("Client::stopSubscription: Could not parse topic '{}': {}", command.topic.str(), e.what()); + return; + }; + if (auto subIt = _subscriptions.find(topic.toZmqTopic()); subIt != _subscriptions.end()) { + // Cancel all requests for this topic + auto reqIt = _requestsByStreamId.begin(); + while (reqIt != _requestsByStreamId.end()) { + if (reqIt->second.request.topic == command.topic) { + self().cancelStream(reqIt->first); + reqIt = _requestsByStreamId.erase(reqIt); + } else { + ++reqIt; + } + } + } + } +}; - auto endpointBuilder = URI<>::factory(cmd.topic); +struct Http2ClientSession : public ClientSessionBase { + TcpSocket _socket; + nghttp2_session *_session = nullptr; + WriteBuffer<1024> _writeBuffer; + + explicit Http2ClientSession(TcpSocket socket_) + : _socket(std::move(socket_)) { + nghttp2_session_callbacks *callbacks; + nghttp2_session_callbacks_new(&callbacks); + + nghttp2_session_callbacks_set_send_callback2(callbacks, [](nghttp2_session *, const uint8_t *data, size_t length, int flags, void *user_data) { + auto client = static_cast(user_data); + HTTP_DBG("Client::send {}", length); + const auto r = client->_socket.write(data, length, flags); + if (r < 0) { + HTTP_DBG("Client::send failed: {}", client->_socket.lastError()); + } + return r; + }); + + nghttp2_session_callbacks_set_on_data_chunk_recv_callback(callbacks, [](nghttp2_session *, uint8_t /*flags*/, int32_t stream_id, const uint8_t *data, size_t len, void *user_data) { + auto client = static_cast(user_data); + client->_requestsByStreamId[stream_id].payload.append(reinterpret_cast(data), len); + return 0; + }); + + nghttp2_session_callbacks_set_on_header_callback(callbacks, [](nghttp2_session *, const nghttp2_frame *frame, const uint8_t *name, size_t namelen, const uint8_t *value, size_t valuelen, uint8_t /*flags*/, void *user_data) { + auto client = static_cast(user_data); + const auto nameView = std::string_view(reinterpret_cast(name), namelen); + const auto valueView = std::string_view(reinterpret_cast(value), valuelen); + if (!client->addHeader(frame->hd.stream_id, nameView, valueView)) { + return static_cast(NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE); + } + return 0; + }); + nghttp2_session_callbacks_set_on_frame_recv_callback(callbacks, [](nghttp2_session *, const nghttp2_frame *frame, void *user_data) { + HTTP_DBG("Client::Frame: id={} {} {}", frame->hd.stream_id, frame->hd.type, (frame->hd.flags & NGHTTP2_FLAG_END_STREAM) ? "END_STREAM" : ""); + switch (frame->hd.type) { + case NGHTTP2_HEADERS: + case NGHTTP2_DATA: + if (frame->hd.flags & NGHTTP2_FLAG_END_STREAM) { + auto client = static_cast(user_data); + return client->processResponse(frame->hd.stream_id); + } + break; + } + return 0; + }); + nghttp2_session_callbacks_set_on_stream_close_callback(callbacks, [](nghttp2_session *, int32_t stream_id, uint32_t /*error_code*/, void *user_data) { + auto client = static_cast(user_data); + client->_requestsByStreamId.erase(stream_id); + HTTP_DBG("Client::Stream closed: {}", stream_id); + return 0; + }); + + nghttp2_session_client_new(&_session, callbacks, this); + nghttp2_session_callbacks_del(callbacks); + + nghttp2_settings_entry iv[1] = { + { NGHTTP2_SETTINGS_MAX_CONCURRENT_STREAMS, 1000 } + }; - if (cmd.command == mdp::Command::Set) { - preferredHeader.insert(std::make_pair("X-OPENCMW-METHOD"s, "PUT"s)); - endpointBuilder = std::move(endpointBuilder).addQueryParameter("_bodyOverride", std::string(cmd.data.asString())); + if (nghttp2_submit_settings(_session, NGHTTP2_FLAG_NONE, iv, 1) != 0) { + HTTP_DBG("Client::ClientSession: nghttp2_submit_settings failed"); } + } - auto endpoint = endpointBuilder.build(); + Http2ClientSession(const Http2ClientSession &) = delete; + Http2ClientSession &operator=(const Http2ClientSession &) = delete; + Http2ClientSession(Http2ClientSession &&other) noexcept = delete; + Http2ClientSession &operator=(Http2ClientSession &&other) noexcept = delete; - auto callback = [&cmd, &preferredHeader, &endpoint](ClientType &client) { - client.set_follow_location(true); - client.set_read_timeout(cmd.timeout); // default keep-alive value - if (const httplib::Result &result = client.Get(endpoint.relativeRef()->data(), preferredHeader)) { - returnMdpMessage(cmd, result); - } else { - std::stringstream errorStr(std::format("\"{}\"", static_cast(result.error()))); -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - if (auto sslResult = client.get_openssl_verify_result(); sslResult) { - errorStr << std::format(" - SSL error: '{}'", X509_verify_cert_error_string(sslResult)); - } -#endif - const std::string errorMsg = std::format("GET request failed for: '{}' - {} - CHECK_CERTIFICATES: {}", cmd.topic.str(), errorStr.str(), CHECK_CERTIFICATES); - returnMdpMessage(cmd, result, errorMsg); - } + ~Http2ClientSession() { + nghttp2_session_del(_session); + } + + bool isReady() const { + return _socket._state == TcpSocket::Connected; + } + + std::expected continueToMakeReady() { + auto makeError = [](std::string_view msg) { + return std::unexpected(std::format("Could not connect to endpoint: {}", msg)); }; + assert(!isReady()); + if (_socket._state == detail::TcpSocket::Connecting) { + if (auto rc = _socket.connect(); !rc) { + return makeError(rc.error()); + } + } - if (cmd.topic.scheme() && equal_with_case_ignore(cmd.topic.scheme().value(), "https")) { -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - httplib::SSLClient client(cmd.topic.hostName().value(), cmd.topic.port() ? cmd.topic.port().value() : 443); - // client owns its certificate store and destroys it after use. create a store for each client - client.set_ca_cert_store(detail::createCertificateStore(_caCertificate)); - client.enable_server_certificate_verification(CHECK_CERTIFICATES); - callback(client); -#else - throw std::invalid_argument("https is not supported"); -#endif - } else if (cmd.topic.scheme() && equal_with_case_ignore(cmd.topic.scheme().value(), "http")) { - httplib::Client client(cmd.topic.hostName().value(), cmd.topic.port() ? cmd.topic.port().value() : 80); - callback(client); - return; - } else { - if (cmd.topic.scheme()) { - throw std::invalid_argument(std::format("unsupported protocol '{}' for endpoint '{}'", cmd.topic.scheme(), cmd.topic.str())); - } else { - throw std::invalid_argument(std::format("no protocol provided for endpoint '{}'", cmd.topic.str())); + if (_socket._state == detail::TcpSocket::SSLConnectWantsRead || _socket._state == detail::TcpSocket::SSLConnectWantsWrite) { + if (auto rc = _socket.continueHandshake(); !rc) { + return makeError(rc.error()); } } + + return {}; } - bool equal_with_case_ignore(const std::string &a, const std::string &b) const { - return std::ranges::equal(a, b, [](const char ca, const char cb) noexcept { return ::tolower(ca) == ::tolower(cb); }); + bool wantsToRead() const { + return _socket._state == TcpSocket::Connected ? nghttp2_session_want_read(_session) : (_socket._state == TcpSocket::Connecting || _socket._state == TcpSocket::SSLConnectWantsRead); } - void startSubscription(Command &&cmd) { - std::scoped_lock lock(_subscriptionLock); - if (equal_with_case_ignore(*cmd.topic.scheme(), "http") -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - || equal_with_case_ignore(*cmd.topic.scheme(), "https") -#endif - ) { - auto createNewSubscription = [&](auto &client) { - { - client.set_follow_location(true); - - std::size_t longPollingIdx = 0; - const auto pollHeaders = getPreferredContentTypeHeader(cmd.topic); - client.set_read_timeout(cmd.timeout); // default keep-alive value - while (_run) { - auto endpoint = [&]() { - if (longPollingIdx == 0UZ) { - return URI::factory(cmd.topic).addQueryParameter(LONG_POLLING_IDX_TAG, "Next").build().relativeRef().value(); - } else { - return URI::factory(cmd.topic).addQueryParameter(LONG_POLLING_IDX_TAG, std::format("{}", longPollingIdx)).build().relativeRef().value(); - } - }(); - if (const httplib::Result &result = client.Get(endpoint, pollHeaders)) { - returnMdpMessage(cmd, result); - // update long-polling-index - std::string location = result->location.empty() ? endpoint : result->location; - std::string updateIdxString = URI(location).queryParamMap().at(std::string(LONG_POLLING_IDX_TAG)).value_or("0"); - char *end = updateIdxString.data() + updateIdxString.size(); - longPollingIdx = strtoull(updateIdxString.data(), &end, 10) + 1; - } else { // failed or server is down -> wait until retry - if (_run) { - returnMdpMessage(cmd, result, std::format("Long-Polling-GET request failed for {}: {}", cmd.topic.str(), static_cast(result.error()))); - } - std::this_thread::sleep_for(cmd.timeout); // time-out until potential retry + bool wantsToWrite() const { + return _socket._state == TcpSocket::Connected ? _writeBuffer.wantsToWrite(_session) : (_socket._state == TcpSocket::Connecting || _socket._state == TcpSocket::SSLConnectWantsWrite); + } + + int32_t submitRequestImpl(const std::vector &headers, IoBuffer *body) { + nghttp2_data_provider2 data_prd; + data_prd.read_callback = nullptr; + + if (body && !body->empty()) { + data_prd.source.ptr = body; + data_prd.read_callback = [](nghttp2_session *, int32_t /*stream_id*/, uint8_t *buf, size_t length, uint32_t *data_flags, nghttp2_data_source *source, void * /*user_data*/) { + auto ioBuffer = static_cast(source->ptr); + const std::size_t copy_len = std::min(length, ioBuffer->size() - ioBuffer->position()); + std::copy(ioBuffer->data() + ioBuffer->position(), ioBuffer->data() + ioBuffer->position() + copy_len, buf); + ioBuffer->skip(static_cast(copy_len)); + if (ioBuffer->position() == ioBuffer->size()) { + *data_flags |= NGHTTP2_DATA_FLAG_EOF; + } + return static_cast(copy_len); + }; + } + + auto streamId = nghttp2_submit_request2(_session, nullptr, headers.data(), headers.size(), &data_prd, nullptr); + if (streamId < 0) { + HTTP_DBG("Client::submitRequest: nghttp2_submit_request failed: {}", nghttp2_strerror(streamId)); + } + return streamId; + } + + void cancelStream(int32_t streamId) { + HTTP_DBG("Client::cancelStream: id={}", streamId); + if (nghttp2_submit_rst_stream(_session, NGHTTP2_FLAG_NONE, streamId, NGHTTP2_CANCEL) != 0) { + HTTP_DBG("Client::cancelStream: nghttp2_submit_rst_stream failed"); + } + } +}; + +} // namespace detail + +struct ClientCertificates { + std::string _certificates; + + ClientCertificates() = default; + ClientCertificates(std::string X509_ca_bundle) noexcept + : _certificates(std::move(X509_ca_bundle)) {} + constexpr operator std::string() const noexcept { return _certificates; }; +}; + +struct RestClient : public ClientBase { + struct SslSettings { + std::string caCertificate; + bool verifyPeers = true; + }; + bool _forceHttp2 = false; // Force HTTP/2 + std::jthread _worker; + MIME::MimeType _mimeType = opencmw::MIME::JSON; + SslSettings _sslSettings; + std::shared_ptr>> _requestQueue = std::make_shared>>(); + + static std::expected + ensureSession(detail::SSL_CTX_Ptr &ssl_ctx, std::map> &sessions, const SslSettings &sslSettings, URI<> topic) { + if (topic.scheme() != "http" && topic.scheme() != "https") { + return std::unexpected(std::format("Unsupported protocol '{}' for endpoint '{}'", topic.scheme().value_or(""), topic.str())); + } + if (topic.hostName().value_or("").empty()) { + return std::unexpected(std::format("No host provided for endpoint '{}'", topic.str())); + } + const auto port = topic.port().value_or(topic.scheme() == "https" ? 443 : 80); + const auto endpoint = detail::Endpoint{ topic.scheme().value(), topic.hostName().value(), port }; + + auto sessionIt = sessions.find(endpoint); + if (sessionIt != sessions.end()) { + return sessionIt->second.get(); + } + + int socketFlags = detail::TcpSocket::None; + if (topic.scheme() == "https") { + if (!ssl_ctx) { + ssl_ctx = detail::SSL_CTX_Ptr(SSL_CTX_new(SSLv23_client_method()), SSL_CTX_free); + if (!ssl_ctx) { + return std::unexpected(std::format("Could not create SSL/TLS context: {}", ERR_error_string(ERR_get_error(), nullptr))); + } + SSL_CTX_set_options(ssl_ctx.get(), + SSL_OP_ALL | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | SSL_OP_NO_COMPRESSION | SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); + + SSL_CTX_set_alpn_protos(ssl_ctx.get(), reinterpret_cast("\x02h2"), 3); + + if (sslSettings.verifyPeers) { + SSL_CTX_set_verify(ssl_ctx.get(), SSL_VERIFY_PEER, nullptr); + + if (!sslSettings.caCertificate.empty()) { + auto maybeStore = detail::createCertificateStore(sslSettings.caCertificate); + if (!maybeStore) { + return std::unexpected(std::format("Could not create certificate store: {}", maybeStore.error())); } + SSL_CTX_set_cert_store(ssl_ctx.get(), maybeStore->release()); } } + } + + auto ssl = detail::create_ssl(ssl_ctx.get()); + if (!ssl) { + return std::unexpected(std::format("Failed to create SSL object: {}", ssl.error())); + } + if (sslSettings.verifyPeers) { + socketFlags |= detail::TcpSocket::VerifyPeer; + } + auto maybeSocket = detail::TcpSocket::create(std::move(ssl.value()), socket(AF_INET, SOCK_STREAM, 0), socketFlags); + if (!maybeSocket) { + return std::unexpected(std::format("Failed to create socket: {}", maybeSocket.error())); + } + auto session = std::make_unique(std::move(maybeSocket.value())); + if (auto rc = session->_socket.prepareConnect(endpoint.host, endpoint.port); !rc) { + return std::unexpected(rc.error()); + } + sessionIt = sessions.emplace(endpoint, std::move(session)).first; + return sessionIt->second.get(); + } + + // HTTP + auto maybeSocket = detail::TcpSocket::create({ nullptr, SSL_free }, socket(AF_INET, SOCK_STREAM, 0), socketFlags); + if (!maybeSocket) { + return std::unexpected(std::format("Failed to create socket: {}", maybeSocket.error())); + } + auto session = std::make_unique(std::move(maybeSocket.value())); + if (auto rc = session->_socket.prepareConnect(endpoint.host, endpoint.port); !rc) { + return std::unexpected(rc.error()); + } + sessionIt = sessions.emplace(endpoint, std::move(session)).first; + return sessionIt->second.get(); + } + +public: + template + explicit(false) RestClient(Args... initArgs) + : _mimeType(detail::find_argument_value([] { return MIME::JSON; }, initArgs...)) + , _sslSettings{ + .caCertificate = detail::find_argument_value([] { return ClientCertificates{}; }, initArgs...), + .verifyPeers = detail::find_argument_value([] { return true; }, initArgs...) + } { + _worker = std::jthread([queue = _requestQueue, sslSettings = _sslSettings, mimeType = _mimeType](std::stop_token stopToken) { + auto preferredMimeType = [&mimeType](const URI<> &topic) { + if (const auto contentTypeHeader = topic.queryParamMap().find("contentType"); contentTypeHeader != topic.queryParamMap().end() && contentTypeHeader->second) { + return contentTypeHeader->second.value(); + } + return std::string{ mimeType.typeName() }; }; - if (equal_with_case_ignore(*cmd.topic.scheme(), "http")) { - auto it = _subscription1.find(cmd.topic); - if (it == _subscription1.end()) { - _subscription1.emplace(cmd.topic, httplib::Client(cmd.topic.hostName().value(), cmd.topic.port().value())); - createNewSubscription(_subscription1.at(cmd.topic)); + + detail::SSL_CTX_Ptr ssl_ctx{ nullptr, SSL_CTX_free }; + + std::map> sessions; + + auto reportError = [](Command &cmd, std::string error) { + if (!cmd.callback) { + return; } - } else { -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - if (auto it = _subscription2.find(cmd.topic); it == _subscription2.end()) { - _subscription2.emplace( - std::piecewise_construct, - std::forward_as_tuple(cmd.topic), - std::forward_as_tuple(cmd.topic.hostName().value(), cmd.topic.port().value())); - auto &client = _subscription2.at(cmd.topic); - client.set_ca_cert_store(detail::createCertificateStore(_caCertificate)); - client.enable_server_certificate_verification(CHECK_CERTIFICATES); - createNewSubscription(_subscription2.at(cmd.topic)); + mdp::Message msg; + msg.protocolName = cmd.topic.scheme().value_or(""); + msg.arrivalTime = std::chrono::system_clock::now(); + msg.command = mdp::Command::Final; + msg.clientRequestID = cmd.clientRequestID; + msg.topic = cmd.topic; + msg.error = std::move(error); + cmd.callback(msg); + }; + + std::vector pollFds; + + while (!stopToken.stop_requested()) { + while (auto entry = queue->try_get()) { + auto &[cmd, mode] = entry.value(); + switch (cmd.command) { + case mdp::Command::Get: + case mdp::Command::Set: { + auto session = ensureSession(ssl_ctx, sessions, sslSettings, cmd.topic); + auto preferred = preferredMimeType(cmd.topic); + session.value()->submitRequest(std::move(cmd), mode, std::move(preferred), {}); + } break; + case mdp::Command::Subscribe: { + auto session = ensureSession(ssl_ctx, sessions, sslSettings, cmd.topic); + if (!session) { + reportError(cmd, std::format("Unsupported endpoint '{}': {}", cmd.topic.str(), session.error())); + continue; + } + session.value()->startSubscription(std::move(cmd), mode); + } break; + case mdp::Command::Unsubscribe: { + auto session = ensureSession(ssl_ctx, sessions, sslSettings, cmd.topic); + if (!session) { + reportError(cmd, std::format("Unsupported endpoint '{}': {}", cmd.topic.str(), session.error())); + continue; + } + session.value()->stopSubscription(std::move(cmd)); + } break; + case mdp::Command::Final: + case mdp::Command::Partial: + case mdp::Command::Heartbeat: + case mdp::Command::Notify: + case mdp::Command::Invalid: + case mdp::Command::Ready: + case mdp::Command::Disconnect: + assert(false); // unexpected command + break; + } + } + + pollFds.clear(); + pollFds.reserve(sessions.size()); + for (auto &sessionPair : sessions) { + auto &session = sessionPair.second; + struct pollfd pfd = {}; + pfd.fd = session->_socket.fd; + pfd.events = 0; + if (session->wantsToRead()) { + pfd.events |= POLLIN; + } + if (session->wantsToWrite()) { + pfd.events |= POLLOUT; + } + pollFds.push_back(pfd); + } + + if (pollFds.empty()) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + continue; + } + + const auto n = ::poll(pollFds.data(), pollFds.size(), 100); + + if (n < 0) { + HTTP_DBG("poll failed: {}", strerror(errno)); + continue; + } + if (n == 0) { + continue; + } + + auto sessionIt = sessions.begin(); + while (sessionIt != sessions.end()) { + auto &session = sessionIt->second; + + auto pollIt = std::ranges::find_if(pollFds, [&](const struct pollfd &pfd) { + return pfd.fd == session->_socket.fd; + }); + + if (pollIt == pollFds.end()) { + ++sessionIt; + continue; + } + + if (pollIt->revents & POLLERR) { + int error = 0; + socklen_t errlen = sizeof(error); + getsockopt(pollIt->fd, SOL_SOCKET, SO_ERROR, &error, &errlen); + session->reportErrorToAllPendingRequests(strerror(error)); + sessionIt = sessions.erase(sessionIt); + continue; + } + + if (((pollIt->revents & POLLIN) || (pollIt->revents & POLLOUT)) && !session->isReady()) { + if (auto r = session->continueToMakeReady(); !r) { + sessionIt->second->reportErrorToAllPendingRequests(r.error()); + sessionIt = sessions.erase(sessionIt); + continue; + } + } + + if (!session->isReady()) { + ++sessionIt; + continue; + } + + if (pollIt->revents & POLLOUT) { + if (!session->_writeBuffer.write(session->_session, session->_socket)) { + HTTP_DBG("Client: Failed to write to peer (fd={}): {}", session->_socket.fd, session->_socket.lastError()); + sessionIt = sessions.erase(sessionIt); + continue; + } + } + + if (pollIt->revents & POLLIN) { + bool mightHaveMore = true; + bool hasError = false; + + while (mightHaveMore && !hasError) { + std::array buffer; + const auto bytes_read = session->_socket.read(buffer.data(), buffer.size()); + if (bytes_read <= 0 && errno != EAGAIN) { + if (bytes_read < 0) { + HTTP_DBG("Client::read failed: {}", session->_socket.lastError()); + } + hasError = true; + continue; + } + + if (bytes_read > 0 && nghttp2_session_mem_recv2(session->_session, buffer.data(), static_cast(bytes_read)) < 0) { + HTTP_DBG("Client: nghttp2_session_mem_recv2 failed"); + hasError = true; + continue; + } + mightHaveMore = bytes_read == static_cast(buffer.size()); + } + if (hasError) { + sessionIt = sessions.erase(sessionIt); + continue; + } + } + + ++sessionIt; } -#else - throw std::invalid_argument("https is not supported"); -#endif } + }); + } - } else { - throw std::invalid_argument(std::format("unsupported scheme '{}' for requested subscription '{}'", cmd.topic.scheme(), cmd.topic.str())); + ~RestClient() { + stop(); + } + + [[nodiscard]] + MIME::MimeType defaultMimeType() const { + return _mimeType; + } + + [[nodiscard]] + bool verifySslPeers() const { + return _sslSettings.verifyPeers; + } + + [[nodiscard]] + std::vector protocols() override { + return { "http", "https" }; + } + + void stop() override { + _worker.request_stop(); + if (_worker.joinable()) { + _worker.join(); } } - void stopSubscription(const Command &cmd) { - // stop subscription that matches URI - std::scoped_lock lock(_subscriptionLock); - if (equal_with_case_ignore(*cmd.topic.scheme(), "http")) { - auto it = _subscription1.find(cmd.topic); - if (it != _subscription1.end()) { - it->second.stop(); - _subscription1.erase(it); - return; - } - } else if (equal_with_case_ignore(*cmd.topic.scheme(), "https")) { -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - auto it = _subscription2.find(cmd.topic); - if (it != _subscription2.end()) { - it->second.stop(); - _subscription2.erase(it); - return; - } -#else - throw std::runtime_error("https is not supported - enable CPPHTTPLIB_OPENSSL_SUPPORT"); -#endif - } else { - throw std::invalid_argument(std::format("unsupported scheme '{}' for requested subscription '{}'", cmd.topic.scheme(), cmd.topic.str())); + void request(Command cmd) override { + switch (cmd.command) { + case mdp::Command::Get: + case mdp::Command::Set: + assert(cmd.callback); + _requestQueue->push({ std::move(cmd), SubscriptionMode::None }); + break; + case mdp::Command::Subscribe: + assert(cmd.callback); + _requestQueue->push({ std::move(cmd), SubscriptionMode::Next }); + break; + case mdp::Command::Unsubscribe: + _requestQueue->push({ std::move(cmd), SubscriptionMode::None }); + break; + default: + assert(false); // unexpected command } } - void stopAllSubscriptions() noexcept { - _run = false; - std::scoped_lock lock(_subscriptionLock); - std::ranges::for_each(_subscription1, [](auto &pair) { pair.second.stop(); }); -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - std::ranges::for_each(_subscription2, [](auto &pair) { pair.second.stop(); }); -#endif + mdp::Message blockingRequest(Command cmd) { + std::promise promise; + auto future = promise.get_future(); + cmd.callback = [&promise](const mdp::Message &msg) { + promise.set_value(std::move(msg)); + }; + request(std::move(cmd)); + return future.get(); } }; -inline bool RestClient::CHECK_CERTIFICATES = true; -inline const httplib::Headers RestClient::EVT_STREAM_HEADERS = { { detail::ACCEPT_HEADER, MIME::EVENT_STREAM.typeName().data() } }; } // namespace opencmw::client -#endif // OPENCMW_CPP_RESTCLIENT_NATIVE_HPP +#endif // OPENCMW_CLIENT_RESTCLIENTNATIVE_HPP diff --git a/src/client/include/RestDefaultClientCertificates.hpp b/src/client/include/RestDefaultClientCertificates.hpp index 61783625..d5a87134 100644 --- a/src/client/include/RestDefaultClientCertificates.hpp +++ b/src/client/include/RestDefaultClientCertificates.hpp @@ -2,6 +2,7 @@ #define OPENCMW_CPP_RESTDEFAULTCLIENTCERTIFICATE_HPP #include +#include #include namespace opencmw::client::rest { diff --git a/src/client/test/CMakeLists.txt b/src/client/test/CMakeLists.txt index c60aa813..ea4c7727 100644 --- a/src/client/test/CMakeLists.txt +++ b/src/client/test/CMakeLists.txt @@ -44,15 +44,15 @@ if(NOT EMSCRIPTEN) # TEST_PREFIX "unittests." REPORTER xml OUTPUT_DIR . OUTPUT_PREFIX "unittests." OUTPUT_SUFFIX .xml) catch_discover_tests(client_tests) - add_executable(rest_client_mock_server_tests catch_main.cpp RestClient_tests.cpp) + add_executable(nghttp2_tests catch_main.cpp nghttp2_tests.cpp) target_link_libraries( - rest_client_mock_server_tests + nghttp2_tests PUBLIC opencmw_project_warnings - opencmw_project_options - test_assets_rest - Catch2::Catch2 - client) - catch_discover_tests(rest_client_mock_server_tests) + opencmw_project_options + test_assets_rest + Catch2::Catch2 + client) + catch_discover_tests(nghttp2_tests) add_executable(clientPublisher_tests catch_main.cpp ClientPublisher_tests.cpp) target_link_libraries( @@ -70,7 +70,8 @@ endif() add_executable(rest_client_only_tests RestClientOnly_tests.cpp) target_link_libraries(rest_client_only_tests PUBLIC opencmw_project_warnings opencmw_project_options client) target_include_directories(rest_client_only_tests PRIVATE ${CMAKE_SOURCE_DIR}) -# This test requires a different kind of invocation as it needs the server running + +# These tests require a different kind of invocation as they need the server running # catch_discover_tests(rest_client_only_tests) if(EMSCRIPTEN) diff --git a/src/client/test/RestClientOnly_tests.cpp b/src/client/test/RestClientOnly_tests.cpp index c007f45c..2efb58e7 100644 --- a/src/client/test/RestClientOnly_tests.cpp +++ b/src/client/test/RestClientOnly_tests.cpp @@ -1,10 +1,4 @@ -#include -#include -#include - -#include #include -#include #include "concepts/client/helpers.hpp" @@ -13,9 +7,13 @@ using namespace std::chrono_literals; // These are not main-local, as JS doesn't end when // C++ main ends namespace test { +#ifndef __EMSCRIPTEN__ +opencmw::client::RestClient client(opencmw::client::VerifyServerCertificates(false)); +#else opencmw::client::RestClient client; +#endif -std::string schema() { +std::string schema() { if (auto env = ::getenv("DISABLE_REST_HTTPS"); env != nullptr && std::string_view(env) == "1") { return "http"; } else { @@ -66,10 +64,6 @@ auto run = rest_test_runner( } // namespace test int main() { -#ifndef __EMSCRIPTEN__ - opencmw::client::RestClient::CHECK_CERTIFICATES = false; -#endif - using namespace test; #ifndef __EMSCRIPTEN__ diff --git a/src/client/test/RestClient_tests.cpp b/src/client/test/RestClient_tests.cpp deleted file mode 100644 index 329fd4f4..00000000 --- a/src/client/test/RestClient_tests.cpp +++ /dev/null @@ -1,468 +0,0 @@ -#include - -#include - -#include - -#include "RestClient.hpp" - -#include -CMRC_DECLARE(assets); - -namespace opencmw::rest_client_test { - -constexpr const char *testCertificate = R"( -R"( -GlobalSign Root CA -================== ------BEGIN CERTIFICATE----- -MIIDdTCCAl2gAwIBAgILBAAAAAABFUtaw5QwDQYJKoZIhvcNAQEFBQAwVzELMAkGA1UEBhMCQkUx -GTAXBgNVBAoTEEdsb2JhbFNpZ24gbnYtc2ExEDAOBgNVBAsTB1Jvb3QgQ0ExGzAZBgNVBAMTEkds -b2JhbFNpZ24gUm9vdCBDQTAeFw05ODA5MDExMjAwMDBaFw0yODAxMjgxMjAwMDBaMFcxCzAJBgNV -BAYTAkJFMRkwFwYDVQQKExBHbG9iYWxTaWduIG52LXNhMRAwDgYDVQQLEwdSb290IENBMRswGQYD -VQQDExJHbG9iYWxTaWduIFJvb3QgQ0EwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDa -DuaZjc6j40+Kfvvxi4Mla+pIH/EqsLmVEQS98GPR4mdmzxzdzxtIK+6NiY6arymAZavpxy0Sy6sc -THAHoT0KMM0VjU/43dSMUBUc71DuxC73/OlS8pF94G3VNTCOXkNz8kHp1Wrjsok6Vjk4bwY8iGlb -Kk3Fp1S4bInMm/k8yuX9ifUSPJJ4ltbcdG6TRGHRjcdGsnUOhugZitVtbNV4FpWi6cgKOOvyJBNP -c1STE4U6G7weNLWLBYy5d4ux2x8gkasJU26Qzns3dLlwR5EiUWMWea6xrkEmCMgZK9FGqkjWZCrX -gzT/LCrBbBlDSgeF59N89iFo7+ryUp9/k5DPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV -HRMBAf8EBTADAQH/MB0GA1UdDgQWBBRge2YaRQ2XyolQL30EzTSo//z9SzANBgkqhkiG9w0BAQUF -AAOCAQEA1nPnfE920I2/7LqivjTFKDK1fPxsnCwrvQmeU79rXqoRSLblCKOzyj1hTdNGCbM+w6Dj -Y1Ub8rrvrTnhQ7k4o+YviiY776BQVvnGCv04zcQLcFGUl5gE38NflNUVyRRBnMRddWQVDf9VMOyG -j/8N7yy5Y0b2qvzfvGn9LhJIZJrglfCm7ymPAbEVtQwdpf5pLGkkeB6zpxxxYu7KyJesF12KwvhH -hm4qxFYxldBniYUr+WymXUadDKqC5JlR3XC321Y9YeRq4VzW9v493kHMB65jUr9TU/Qr6cf9tveC -X4XSQRjbgbMEHMUfpIBvFSDJ3gyICh3WZlXi/EjJKSZp4A== ------END CERTIFICATE----- -)"; - -class TestServerCertificates { - const cmrc::embedded_filesystem fileSystem = cmrc::assets::get_filesystem(); - const cmrc::file ca_certificate = fileSystem.open("/assets/ca-cert.pem"); - // server-req.pem -> is usually used to request for the CA signature - const cmrc::file server_cert = fileSystem.open("/assets/server-cert.pem"); - const cmrc::file server_key = fileSystem.open("/assets/server-key.pem"); - const cmrc::file client_cert = fileSystem.open("/assets/client-cert.pem"); - const cmrc::file client_key = fileSystem.open("/assets/client-key.pem"); - const cmrc::file pwd = fileSystem.open("/assets/password.txt"); - -public: - const std::string caCertificate = { ca_certificate.begin(), ca_certificate.end() }; - const std::string serverCertificate = { server_cert.begin(), server_cert.end() }; - const std::string serverKey = { server_key.begin(), server_key.end() }; - const std::string clientCertificate = { client_cert.begin(), client_cert.end() }; - const std::string clientKey = { client_key.begin(), client_key.end() }; - const std::string password = { pwd.begin(), pwd.end() }; -}; -inline static const TestServerCertificates testServerCertificates; - -TEST_CASE("Basic Rest Client Constructor and API Tests", "[Client]") { - using namespace opencmw::client; - RestClient client1; - REQUIRE(client1.name() == "RestClient"); - - RestClient client2(std::make_shared>("RestClient", 1, 10'000)); - REQUIRE(client2.name() == "RestClient"); - - RestClient client3("clientName", std::make_shared>("CustomPoolName", 1, 10'000)); - REQUIRE(client3.name() == "clientName"); - REQUIRE(client3.threadPool()->poolName() == "CustomPoolName"); - - RestClient client4("clientName"); - REQUIRE(client4.threadPool()->poolName() == "clientName"); - - RestClient client5("clientName", DefaultContentTypeHeader(MIME::HTML), MinIoThreads(2), MaxIoThreads(5), ClientCertificates(testCertificate)); - REQUIRE(client5.defaultMimeType() == MIME::HTML); - REQUIRE(client5.threadPool()->poolName() == "clientName"); -} - -TEST_CASE("Basic Rest Client Get/Set Test - HTTP", "[Client]") { - using namespace opencmw::client; - RestClient client; - REQUIRE(client.name() == "RestClient"); - - httplib::Server server; - - std::string acceptHeader; - server.Get("/endPoint", [&acceptHeader](const httplib::Request &req, httplib::Response &res) { - std::print("server received request on path '{}' body = '{}'\n", req.path, req.body); - if (req.headers.contains("accept")) { - acceptHeader = req.headers.find("accept")->second; - } else { - FAIL("no accept headers found"); - } - - res.set_content("Hello World!", acceptHeader); - }); - client.threadPool()->execute<"RestServer">([&server] { server.listen("localhost", 8080); }); - while (!server.is_running()) { - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - } - REQUIRE(server.is_running()); - - std::atomic done(false); - IoBuffer data; - data.put('A'); - data.put('B'); - data.put('C'); - Command command; - command.command = mdp::Command::Get; - command.topic = URI("http://localhost:8080/endPoint"); - command.data = std::move(data); - command.callback = [&done](const mdp::Message & /*rep*/) { - done.store(true, std::memory_order_release); - done.notify_all(); - }; - client.request(command); - - done.wait(false); - REQUIRE(done.load(std::memory_order_acquire) == true); - REQUIRE(acceptHeader == MIME::JSON.typeName()); - server.stop(); -} - -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT -TEST_CASE("Multiple Rest Client Get/Set Test - HTTPS", "[Client]") { - using namespace opencmw::client; - RestClient client("TestSSLClient", ClientCertificates(testServerCertificates.caCertificate)); - REQUIRE(RestClient::CHECK_CERTIFICATES); - RestClient::CHECK_CERTIFICATES = true; // 'false' disables certificate check - REQUIRE(client.name() == "TestSSLClient"); - REQUIRE(client.defaultMimeType() == MIME::JSON); - - // HTTP - X509 *cert = opencmw::client::detail::readServerCertificateFromFile(testServerCertificates.serverCertificate); - EVP_PKEY *pkey = opencmw::client::detail::readServerPrivateKeyFromFile(testServerCertificates.serverKey); - if (const X509_STORE *ca_store = opencmw::client::detail::createCertificateStore(testServerCertificates.caCertificate); !cert || !pkey || !ca_store) { - FAIL(std::format("Failed to load certificate: {}", ERR_error_string(ERR_get_error(), nullptr))); - } - httplib::SSLServer server(cert, pkey); - - std::string acceptHeader; - server.Get("/endPoint", [&acceptHeader](const httplib::Request &req, httplib::Response &res) { - if (req.headers.contains("accept")) { - acceptHeader = req.headers.find("accept")->second; - } else { - FAIL("no accept headers found"); - } - res.set_content("Hello World!", acceptHeader); - }); - client.threadPool()->execute<"RestServer">([&server] { server.listen("localhost", 8080); }); - while (!server.is_running()) { - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - } - REQUIRE(server.is_running()); - - std::array, 4> dones; - dones[0] = false; - dones[1] = false; - dones[2] = false; - dones[3] = false; - std::atomic counter{ 0 }; - auto makeCommand = [&]() { - IoBuffer data; - data.put('A'); - data.put('B'); - data.put('C'); - data.put(0); - - Command command; - command.command = mdp::Command::Get; - command.topic = URI("https://localhost:8080/endPoint"); - command.data = std::move(data); - command.callback = [&dones, &counter](const mdp::Message &/*rep*/) { - std::size_t currentCounter = counter.fetch_add(1, std::memory_order_relaxed); - dones[currentCounter].store(true, std::memory_order_release); - // Assuming you have access to 'done' variable, uncomment the following line - dones[currentCounter].notify_all(); - }; - client.request(command); - }; - for (int i = 0; i < 4; i++) - makeCommand(); - - for (auto &done : dones) { - done.wait(false); - } - REQUIRE(std::ranges::all_of(dones, [](auto &done) { return done.load(std::memory_order_acquire); })); - REQUIRE(acceptHeader == MIME::JSON.typeName()); - server.stop(); -} - -TEST_CASE("Basic Rest Client Get/Set Test - HTTPS", "[Client]") { - using namespace opencmw::client; - RestClient client("TestSSLClient", ClientCertificates(testServerCertificates.caCertificate)); - REQUIRE(RestClient::CHECK_CERTIFICATES); - RestClient::CHECK_CERTIFICATES = true; // 'false' disables certificate check - REQUIRE(client.name() == "TestSSLClient"); - REQUIRE(client.defaultMimeType() == MIME::JSON); - - // HTTP - X509 *cert = opencmw::client::detail::readServerCertificateFromFile(testServerCertificates.serverCertificate); - EVP_PKEY *pkey = opencmw::client::detail::readServerPrivateKeyFromFile(testServerCertificates.serverKey); - if (const X509_STORE *ca_store = opencmw::client::detail::createCertificateStore(testServerCertificates.caCertificate); !cert || !pkey || !ca_store) { - FAIL(std::format("Failed to load certificate: {}", ERR_error_string(ERR_get_error(), nullptr))); - } - httplib::SSLServer server(cert, pkey); - - std::string acceptHeader; - server.Get("/endPoint", [&acceptHeader](const httplib::Request &req, httplib::Response &res) { - if (req.headers.contains("accept")) { - acceptHeader = req.headers.find("accept")->second; - } else { - FAIL("no accept headers found"); - } - res.set_content("Hello World!", acceptHeader); - }); - client.threadPool()->execute<"RestServer">([&server] { server.listen("localhost", 8080); }); - while (!server.is_running()) { - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - } - REQUIRE(server.is_running()); - - std::atomic done(false); - IoBuffer data; - data.put('A'); - data.put('B'); - data.put('C'); - data.put(0); - Command command; - command.command = mdp::Command::Get; - command.topic = URI("https://localhost:8080/endPoint"); - command.data = std::move(data); - command.callback = [&done](const mdp::Message & /*rep*/) { - done.store(true, std::memory_order_release); - done.notify_all(); - }; - client.request(command); - - done.wait(false); - REQUIRE(done.load(std::memory_order_acquire) == true); - REQUIRE(acceptHeader == MIME::JSON.typeName()); - server.stop(); -} -#endif - -namespace detail { -class EventDispatcher { - std::mutex _mutex; - std::condition_variable _condition; - std::atomic _id{ 0 }; - std::atomic _cid{ -1 }; - std::string _message; - -public: - void wait_event(httplib::DataSink &sink) { - std::unique_lock lk(_mutex); - int id = _id; - _condition.wait(lk, [&id, this] { return _cid == id; }); - if (sink.is_writable()) { - sink.write(_message.data(), _message.size()); - } - } - - void send_event(const std::string_view &message) { - std::scoped_lock lk(_mutex); - _cid = _id++; - _message = message; - _condition.notify_all(); - } -}; -} // namespace detail - -TEST_CASE("Basic Rest Client Subscribe/Unsubscribe Test", "[Client]") { - using namespace opencmw::client; - - std::atomic updateCounter{ 0 }; - detail::EventDispatcher eventDispatcher; - httplib::Server server; - server.Get("/event", [&eventDispatcher, &updateCounter](const httplib::Request &req, httplib::Response &res) { - auto acceptType = req.headers.find("accept"); - if (acceptType == req.headers.end() || MIME::EVENT_STREAM.typeName() != acceptType->second) { // non-SSE request -> return default response -#if not defined(__EMSCRIPTEN__) and (not defined(__clang__) or (__clang_major__ >= 16)) - res.set_content(std::format("update counter = {}", updateCounter.load()), MIME::TEXT); -#else - res.set_content(std::format("update counter = {}", updateCounter.load()), std::string(MIME::TEXT.typeName())); -#endif - return; - } else { - std::print("server received SSE request on path '{}' body = '{}'\n", req.path, req.body); -#if not defined(__EMSCRIPTEN__) and (not defined(__clang__) or (__clang_major__ >= 16)) - res.set_chunked_content_provider(MIME::EVENT_STREAM, [&eventDispatcher](size_t /*offset*/, httplib::DataSink &sink) { -#else - res.set_chunked_content_provider(std::string(MIME::EVENT_STREAM.typeName()), [&eventDispatcher](size_t /*offset*/, httplib::DataSink &sink) { -#endif - eventDispatcher.wait_event(sink); - return true; - }); - } - }); - server.Get("/endPoint", [](const httplib::Request &req, httplib::Response &res) { - std::print("server received request on path '{}' body = '{}'\n", req.path, req.body); - res.set_content("Hello World!", "text/plain"); - }); - - RestClient client; - client.threadPool()->execute<"RestServer">([&server] { server.listen("localhost", 8080); }); - while (!server.is_running()) { - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - } - REQUIRE(server.is_running()); - - std::atomic receivedRegular(0); - std::atomic receivedError(0); - IoBuffer data; - data.put('A'); - data.put('B'); - data.put('C'); - data.put(0); - - Command command; - command.command = mdp::Command::Subscribe; - command.topic = URI("http://localhost:8080/event"); - command.data = std::move(data); - command.callback = [&receivedRegular, &receivedError](const mdp::Message &rep) { - std::print("SSE client received reply = '{}' - body size: '{}'\n", rep.data.asString(), rep.data.size()); - if (rep.error.size() == 0) { - receivedRegular.fetch_add(1, std::memory_order_relaxed); - } else { - receivedError.fetch_add(1, std::memory_order_relaxed); - } - receivedRegular.notify_all(); - receivedError.notify_all(); - }; - - client.request(command); - - std::cout << "client request launched" << std::endl; - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - eventDispatcher.send_event("test-event meta data"); - std::jthread dispatcher([&updateCounter, &eventDispatcher] { - while (updateCounter < 5) { - std::this_thread::sleep_for(std::chrono::milliseconds(20)); - eventDispatcher.send_event(std::format("test-event {}", updateCounter++)); - } - }); - dispatcher.join(); - - while (receivedRegular.load(std::memory_order_relaxed) < 5) { - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - } - std::cout << "done waiting" << std::endl; - REQUIRE(receivedRegular.load(std::memory_order_acquire) >= 5); - - command.command = mdp::Command::Unsubscribe; - client.request(command); - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - std::cout << "done Unsubscribe" << std::endl; - - client.stop(); - server.stop(); - eventDispatcher.send_event(std::format("test-event {}", updateCounter++)); - std::cout << "server stopped" << std::endl; -} - -TEST_CASE("Basic Rest Client Subscribe/Unsubscribe Test HTTPS", "[Client]") { - // HTTP - X509 *cert = opencmw::client::detail::readServerCertificateFromFile(testServerCertificates.serverCertificate); - EVP_PKEY *pkey = opencmw::client::detail::readServerPrivateKeyFromFile(testServerCertificates.serverKey); - if (const X509_STORE *ca_store = opencmw::client::detail::createCertificateStore(testServerCertificates.caCertificate); !cert || !pkey || !ca_store) { - FAIL(std::format("Failed to load certificate: {}", ERR_error_string(ERR_get_error(), nullptr))); - } - using namespace opencmw::client; - - std::atomic updateCounter{ 0 }; - detail::EventDispatcher eventDispatcher; - httplib::SSLServer server(cert, pkey); - server.Get("/event", [&eventDispatcher, &updateCounter](const httplib::Request &req, httplib::Response &res) { - DEBUG_LOG("Server received request"); - auto acceptType = req.headers.find("accept"); - if (acceptType == req.headers.end() || MIME::EVENT_STREAM.typeName() != acceptType->second) { // non-SSE request -> return default response -#if not defined(__EMSCRIPTEN__) and (not defined(__clang__) or (__clang_major__ >= 16)) - res.set_content(std::format("update counter = {}", updateCounter.load()), MIME::TEXT); -#else - res.set_content(std::format("update counter = {}", updateCounter.load()), std::string(MIME::TEXT.typeName())); -#endif - return; - } else { - std::print("server received SSE request on path '{}' body = '{}'\n", req.path, req.body); -#if not defined(__EMSCRIPTEN__) and (not defined(__clang__) or (__clang_major__ >= 16)) - res.set_chunked_content_provider(MIME::EVENT_STREAM, [&eventDispatcher](size_t /*offset*/, httplib::DataSink &sink) { -#else - res.set_chunked_content_provider(std::string(MIME::EVENT_STREAM.typeName()), [&eventDispatcher](size_t /*offset*/, httplib::DataSink &sink) { -#endif - eventDispatcher.wait_event(sink); - return true; - }); - } - }); - server.Get("/endPoint", [](const httplib::Request &req, httplib::Response &res) { - std::print("server received request on path '{}' body = '{}'\n", req.path, req.body); - res.set_content("Hello World!", "text/plain"); - }); - - RestClient client("TestSSLClient", ClientCertificates(testServerCertificates.caCertificate)); - - client.threadPool()->execute<"RestServer">([&server] { server.listen("localhost", 8080); }); - while (!server.is_running()) { - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - } - REQUIRE(server.is_running()); - REQUIRE(RestClient::CHECK_CERTIFICATES); - RestClient::CHECK_CERTIFICATES = true; // 'false' disables certificate check - REQUIRE(client.name() == "TestSSLClient"); - REQUIRE(client.defaultMimeType() == MIME::JSON); - - std::atomic receivedRegular(0); - std::atomic receivedError(0); - IoBuffer data; - data.put('A'); - data.put('B'); - data.put('C'); - data.put(0); - - Command command; - command.command = mdp::Command::Subscribe; - command.topic = URI("https://localhost:8080/event"); - command.data = std::move(data); - command.callback = [&receivedRegular, &receivedError](const mdp::Message &rep) { - std::print("SSE client received reply = '{}' - body size: '{}'\n", rep.data.asString(), rep.data.size()); - if (rep.error.size() == 0) { - receivedRegular.fetch_add(1, std::memory_order_relaxed); - } else { - receivedError.fetch_add(1, std::memory_order_relaxed); - } - receivedRegular.notify_all(); - receivedError.notify_all(); - }; - - client.request(command); - - std::cout << "client request launched" << std::endl; - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - eventDispatcher.send_event("test-event meta data"); - std::jthread dispatcher([&updateCounter, &eventDispatcher] { - while (updateCounter < 5) { - std::this_thread::sleep_for(std::chrono::milliseconds(20)); - eventDispatcher.send_event(std::format("test-event {}", updateCounter++)); - } - }); - dispatcher.join(); - - while (receivedRegular.load(std::memory_order_relaxed) < 5) { - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - } - std::cout << "done waiting" << std::endl; - REQUIRE(receivedRegular.load(std::memory_order_acquire) >= 5); - - command.command = mdp::Command::Unsubscribe; - client.request(command); - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - std::cout << "done Unsubscribe" << std::endl; - - client.stop(); - server.stop(); - eventDispatcher.send_event(std::format("test-event {}", updateCounter++)); - std::cout << "server stopped" << std::endl; -} - -} // namespace opencmw::rest_client_test diff --git a/src/client/test/nghttp2_tests.cpp b/src/client/test/nghttp2_tests.cpp new file mode 100644 index 00000000..6cfc2b97 --- /dev/null +++ b/src/client/test/nghttp2_tests.cpp @@ -0,0 +1,520 @@ +#include +#include + +#include "zmq.h" +#include + +#include +CMRC_DECLARE(assets); + +#include + +#include +#include +#include + +constexpr const char *testCertificate = R"( + R"( + GlobalSign Root CA + ================== + -----BEGIN CERTIFICATE----- + MIIDdTCCAl2gAwIBAgILBAAAAAABFUtaw5QwDQYJKoZIhvcNAQEFBQAwVzELMAkGA1UEBhMCQkUx + GTAXBgNVBAoTEEdsb2JhbFNpZ24gbnYtc2ExEDAOBgNVBAsTB1Jvb3QgQ0ExGzAZBgNVBAMTEkds + b2JhbFNpZ24gUm9vdCBDQTAeFw05ODA5MDExMjAwMDBaFw0yODAxMjgxMjAwMDBaMFcxCzAJBgNV + BAYTAkJFMRkwFwYDVQQKExBHbG9iYWxTaWduIG52LXNhMRAwDgYDVQQLEwdSb290IENBMRswGQYD + VQQDExJHbG9iYWxTaWduIFJvb3QgQ0EwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDa + DuaZjc6j40+Kfvvxi4Mla+pIH/EqsLmVEQS98GPR4mdmzxzdzxtIK+6NiY6arymAZavpxy0Sy6sc + THAHoT0KMM0VjU/43dSMUBUc71DuxC73/OlS8pF94G3VNTCOXkNz8kHp1Wrjsok6Vjk4bwY8iGlb + Kk3Fp1S4bInMm/k8yuX9ifUSPJJ4ltbcdG6TRGHRjcdGsnUOhugZitVtbNV4FpWi6cgKOOvyJBNP + c1STE4U6G7weNLWLBYy5d4ux2x8gkasJU26Qzns3dLlwR5EiUWMWea6xrkEmCMgZK9FGqkjWZCrX + gzT/LCrBbBlDSgeF59N89iFo7+ryUp9/k5DPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV + HRMBAf8EBTADAQH/MB0GA1UdDgQWBBRge2YaRQ2XyolQL30EzTSo//z9SzANBgkqhkiG9w0BAQUF + AAOCAQEA1nPnfE920I2/7LqivjTFKDK1fPxsnCwrvQmeU79rXqoRSLblCKOzyj1hTdNGCbM+w6Dj + Y1Ub8rrvrTnhQ7k4o+YviiY776BQVvnGCv04zcQLcFGUl5gE38NflNUVyRRBnMRddWQVDf9VMOyG + j/8N7yy5Y0b2qvzfvGn9LhJIZJrglfCm7ymPAbEVtQwdpf5pLGkkeB6zpxxxYu7KyJesF12KwvhH + hm4qxFYxldBniYUr+WymXUadDKqC5JlR3XC321Y9YeRq4VzW9v493kHMB65jUr9TU/Qr6cf9tveC + X4XSQRjbgbMEHMUfpIBvFSDJ3gyICh3WZlXi/EjJKSZp4A== + -----END CERTIFICATE----- + )"; + +class TestServerCertificates { + const cmrc::embedded_filesystem fileSystem = cmrc::assets::get_filesystem(); + const cmrc::file ca_certificate = fileSystem.open("/assets/ca-cert.pem"); + // server-req.pem -> is usually used to request for the CA signature + const cmrc::file server_cert = fileSystem.open("/assets/server-cert.pem"); + const cmrc::file server_key = fileSystem.open("/assets/server-key.pem"); + const cmrc::file client_cert = fileSystem.open("/assets/client-cert.pem"); + const cmrc::file client_key = fileSystem.open("/assets/client-key.pem"); + const cmrc::file pwd = fileSystem.open("/assets/password.txt"); + +public: + const std::string caCertificate = { ca_certificate.begin(), ca_certificate.end() }; + const std::string serverCertificate = { server_cert.begin(), server_cert.end() }; + const std::string serverKey = { server_key.begin(), server_key.end() }; + const std::string clientCertificate = { client_cert.begin(), client_cert.end() }; + const std::string clientKey = { client_key.begin(), client_key.end() }; + const std::string password = { pwd.begin(), pwd.end() }; +}; +inline static const TestServerCertificates testServerCertificates; + +using namespace opencmw; +using namespace opencmw::majordomo::detail::rest; +using opencmw::URI; + +constexpr uint16_t kServerPort = 33339; + +void ensureMessageReceived(RestServer &server, std::stop_token stopToken, std::deque &messages, std::chrono::milliseconds timeout = std::chrono::seconds(5)) { + const auto start = std::chrono::system_clock::now(); + while (!stopToken.stop_requested() && std::chrono::system_clock::now() - start < timeout) { + if (!messages.empty()) { + return; + } + + std::vector pollerItems; + server.populatePollerItems(pollerItems); + + const int rc = zmq_poll(pollerItems.data(), static_cast(pollerItems.size()), 500); + + if (rc == 0) { + continue; + } + + for (const auto &item : pollerItems) { + auto ms = server.processReadWrite(item.fd, item.revents & ZMQ_POLLIN, item.revents & ZMQ_POLLOUT); + messages.insert(messages.end(), ms.begin(), ms.end()); + } + } +}; + +bool waitFor(std::atomic &responseCount, int expected, std::chrono::milliseconds timeout = std::chrono::seconds(5)) { + const auto start = std::chrono::system_clock::now(); + while (responseCount.load() < expected && std::chrono::system_clock::now() - start < timeout) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + if (responseCount.load() < expected) { + FAIL(std::format("Expected {} responses, but got only {}\n", expected, responseCount.load())); + } + return responseCount.load() == expected; +} + +struct Stopper { + std::stop_source source; + explicit Stopper(std::stop_source s) + : source(s) {} + ~Stopper() { + source.request_stop(); + } +}; + +static std::string normalize(URI<> uri) { + return opencmw::mdp::Topic::fromMdpTopic(uri).toZmqTopic(); +} + +TEST_CASE("Basic Client Constructor and API Tests", "[http2]") { + using namespace opencmw::client; + + RestClient client1; + REQUIRE(client1.defaultMimeType() == MIME::JSON); + REQUIRE(client1.verifySslPeers() == true); + + auto client2 = RestClient(DefaultContentTypeHeader(MIME::HTML), ClientCertificates(testCertificate)); + REQUIRE(client2.defaultMimeType() == MIME::HTML); + REQUIRE(client2.verifySslPeers() == true); + + RestClient client3(DefaultContentTypeHeader(MIME::BINARY), VerifyServerCertificates(false)); + REQUIRE(client3.defaultMimeType() == MIME::BINARY); + REQUIRE(client3.verifySslPeers() == false); +} + +TEST_CASE("GET HTTP", "[http2]") { + using namespace opencmw::client; + + auto serverThread = std::jthread([](std::stop_token stopToken) { + RestServer server; + REQUIRE(server.bind(kServerPort, majordomo::rest::Http2)); + + std::deque messages; + ensureMessageReceived(server, stopToken, messages); + REQUIRE(messages.size() >= 1); + const auto req0 = std::move(messages[0]); + messages.pop_front(); + REQUIRE(req0.command == mdp::Command::Get); + REQUIRE(req0.topic.path() == "/sayhello"); + REQUIRE(req0.error.empty()); + REQUIRE(req0.data.empty()); + + Message reply0; + reply0.command = mdp::Command::Final; + reply0.clientRequestID = req0.clientRequestID; + reply0.topic = URI<>("/sayhello"); + reply0.data = opencmw::IoBuffer("Hello, World!"); + server.handleResponse(std::move(reply0)); + + ensureMessageReceived(server, stopToken, messages); + }); + + // Client using plain http + RestClient http; + Stopper stopper(serverThread.get_stop_source()); + std::this_thread::sleep_for(std::chrono::milliseconds(300)); // give the server some time to start listening + + std::atomic responseCount = 0; + + client::Command req0; + req0.command = mdp::Command::Get; + req0.clientRequestID = opencmw::IoBuffer("0"); + req0.topic = URI<>(std::format("http://localhost:{}/sayhello?client=http", kServerPort)); + req0.callback = [&responseCount](const mdp::Message &msg) { + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString() == "Hello, World!"); + REQUIRE(msg.clientRequestID.asString() == "0"); + REQUIRE(msg.topic.path() == "/sayhello"); + responseCount++; + }; + http.request(std::move(req0)); + + // Client that verifies the server's certificate and trusts its CA + RestClient https(ClientCertificates(testServerCertificates.caCertificate)); + client::Command req1; + req1.command = mdp::Command::Get; + req1.topic = URI<>(std::format("https://localhost:{}/sayhello?client=https", kServerPort)); + req1.clientRequestID = opencmw::IoBuffer("0"); + req1.callback = [&responseCount](const mdp::Message &msg) { + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.protocolName == "https"); + REQUIRE(msg.error.contains("Could not connect to endpoint")); + REQUIRE(msg.data.asString() == ""); + REQUIRE(msg.clientRequestID.asString() == "0"); + REQUIRE(msg.topic.path() == "/sayhello"); + responseCount++; + }; + https.request(std::move(req1)); + + REQUIRE(waitFor(responseCount, 2)); +} + +TEST_CASE("HTTPS", "[http2]") { + using namespace opencmw::client; + + auto serverThread = std::jthread([&](std::stop_token stopToken) { + auto server = RestServer::sslWithBuffers(testServerCertificates.serverCertificate, testServerCertificates.serverKey); + if (!server) { + FAIL(std::format("Failed to create server: {}", server.error())); + return; + } + + auto bound = server->bind(kServerPort, majordomo::rest::Http2); + if (!bound) { + FAIL(std::format("Failed to bind server: {}", bound.error())); + return; + } + + std::deque messages; + for (int i = 0; i < 2; i++) { + ensureMessageReceived(server.value(), stopToken, messages); + REQUIRE(messages.size() >= 1); + const auto req0 = std::move(messages[0]); + messages.pop_front(); + REQUIRE(req0.command == mdp::Command::Get); + REQUIRE(req0.topic.path() == "/sayhello"); + REQUIRE(req0.error.empty()); + REQUIRE(req0.data.empty()); + + Message reply0; + reply0.command = mdp::Command::Final; + reply0.clientRequestID = req0.clientRequestID; + reply0.topic = URI<>("/sayhello"); + reply0.data = opencmw::IoBuffer("Hello, World!"); + server->handleResponse(std::move(reply0)); + } + + const auto start = std::chrono::system_clock::now(); + while (!stopToken.stop_requested() && std::chrono::system_clock::now() - start < std::chrono::seconds(5)) { + ensureMessageReceived(server.value(), stopToken, messages); + } + }); + + Stopper stopper(serverThread.get_stop_source()); + std::this_thread::sleep_for(std::chrono::milliseconds(300)); // give the server some time to start listening + + std::atomic responseCount = 0; + + // Client that doesn't verify the server's certificate + RestClient doesntCare(VerifyServerCertificates(false)); + client::Command req0; + req0.command = mdp::Command::Get; + req0.topic = URI<>(std::format("https://localhost:{}/sayhello?client=doesntCare", kServerPort)); + req0.clientRequestID = opencmw::IoBuffer("0"); + req0.callback = [&responseCount](const mdp::Message &msg) { + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.protocolName == "https"); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString() == "Hello, World!"); + REQUIRE(msg.clientRequestID.asString() == "0"); + REQUIRE(msg.topic == URI<>("/sayhello")); + responseCount++; + }; + doesntCare.request(std::move(req0)); + + // Client that verifies the server's certificate and trusts its CA + RestClient trusting(ClientCertificates(testServerCertificates.caCertificate)); + client::Command req1; + req1.command = mdp::Command::Get; + req1.topic = URI<>(std::format("https://localhost:{}/sayhello?client=trusting", kServerPort)); + req1.clientRequestID = opencmw::IoBuffer("0"); + req1.callback = [&responseCount](const mdp::Message &msg) { + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.protocolName == "https"); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString() == "Hello, World!"); + REQUIRE(msg.clientRequestID.asString() == "0"); + REQUIRE(msg.topic.path() == "/sayhello"); + responseCount++; + }; + trusting.request(std::move(req1)); + + // Client that verifies the server's certificate but doesn't trust its CA + RestClient notTrusting; + client::Command req2; + req2.command = mdp::Command::Get; + req2.topic = URI<>(std::format("https://localhost:{}/sayhello?client=notTrusting", kServerPort)); + req2.clientRequestID = opencmw::IoBuffer("0"); + req2.callback = [&responseCount](const mdp::Message &msg) { + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.protocolName == "https"); + REQUIRE(msg.error.contains("Could not connect to endpoint")); + REQUIRE(msg.clientRequestID.asString() == "0"); + REQUIRE(msg.topic.path() == "/sayhello"); + responseCount++; + }; + notTrusting.request(std::move(req2)); + + REQUIRE(waitFor(responseCount, 3)); +} + +TEST_CASE("GET/SET", "[http2]") { + auto serverThread = std::jthread([](std::stop_token stopToken) { + RestServer server; + REQUIRE(server.bind(kServerPort, majordomo::rest::Http2)); + + std::deque messages; + ensureMessageReceived(server, stopToken, messages); + REQUIRE(messages.size() >= 1); + const auto req0 = std::move(messages[0]); + messages.pop_front(); + REQUIRE(req0.command == mdp::Command::Get); + REQUIRE(normalize(req0.topic) == normalize(URI<>("/foo?contentType=application%2Fjson¶m1=1¶m2=foo%2Fbar"))); + REQUIRE(req0.error.empty()); + REQUIRE(req0.data.empty()); + + Message reply0; + reply0.command = mdp::Command::Final; + reply0.clientRequestID = req0.clientRequestID; + reply0.topic = URI<>("/foo?param1=1¶m2=foo%2fbar"); + reply0.data = opencmw::IoBuffer("Hello, World!"); + server.handleResponse(std::move(reply0)); + + ensureMessageReceived(server, stopToken, messages); + REQUIRE(messages.size() >= 1); + const auto req1 = std::move(messages[0]); + messages.pop_front(); + REQUIRE(req1.command == mdp::Command::Get); + REQUIRE(normalize(req1.topic) == normalize(URI<>("/bar?contentType=application%2Fjson¶m1=1¶m2=2"))); + REQUIRE(req1.error.empty()); + REQUIRE(req1.data.empty()); + + Message reply1; + reply1.command = mdp::Command::Final; + reply1.clientRequestID = req1.clientRequestID; + reply1.topic = URI<>("/bar?param1=1¶m2=2"); + reply1.error = "'bar' not found"; + server.handleResponse(std::move(reply1)); + + ensureMessageReceived(server, stopToken, messages); + REQUIRE(messages.size() >= 1); + const auto req2 = std::move(messages[0]); + messages.pop_front(); + REQUIRE(req2.command == mdp::Command::Set); + REQUIRE(normalize(req2.topic) == normalize(URI<>("/setexample?contentType=application%2Fjson"))); + REQUIRE(req2.error.empty()); + REQUIRE(req2.data.asString() == "Some data with\ttabs\nand newlines\x01"); + + Message reply2; + reply2.command = mdp::Command::Final; + reply2.clientRequestID = req2.clientRequestID; + reply2.topic = URI<>("/setexample"); + reply2.data = opencmw::IoBuffer("value set"); + server.handleResponse(std::move(reply2)); + + ensureMessageReceived(server, stopToken, messages); // makes sure responses are sent + }); + + Stopper stopper(serverThread.get_stop_source()); + std::this_thread::sleep_for(std::chrono::milliseconds(300)); // give the server some time to start listening + + std::atomic responseCount = 0; + + client::RestClient client(client::VerifyServerCertificates(false)); + + opencmw::client::Command req0; + req0.command = mdp::Command::Get; + req0.clientRequestID = opencmw::IoBuffer("0"); + req0.topic = URI<>(std::format("http://localhost:{}/foo?param1=1¶m2=foo%2fbar", kServerPort)); + req0.callback = [&responseCount](const mdp::Message &msg) { + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.protocolName == "http"); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString() == "Hello, World!"); + REQUIRE(msg.clientRequestID.asString() == "0"); + REQUIRE(msg.topic == URI<>("/foo?param1=1¶m2=foo%2fbar")); + responseCount++; + }; + client.request(std::move(req0)); + + client::Command req1; + req1.command = mdp::Command::Get; + req1.clientRequestID = opencmw::IoBuffer("1"); + req1.topic = URI<>(std::format("http://localhost:{}/bar?param1=1¶m2=2", kServerPort)); + req1.callback = [&responseCount](const mdp::Message &msg) { + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.protocolName == "http"); + REQUIRE(msg.error == "'bar' not found"); + REQUIRE(msg.data.asString() == ""); + REQUIRE(msg.clientRequestID.asString() == "1"); + REQUIRE(msg.topic == URI<>("/bar?param1=1¶m2=2")); + responseCount++; + }; + client.request(std::move(req1)); + + client::Command req2; + req2.command = mdp::Command::Set; + req2.clientRequestID = opencmw::IoBuffer("2"); + req2.topic = URI<>(std::format("http://localhost:{}/setexample", kServerPort)); + req2.data = opencmw::IoBuffer("Some data with\ttabs\nand newlines\x01"); + req2.callback = [&responseCount](const mdp::Message &msg) { + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.protocolName == "http"); + REQUIRE(msg.clientRequestID.asString() == "2"); + REQUIRE(msg.topic == URI<>("/setexample")); + REQUIRE(msg.data.asString() == "value set"); + responseCount++; + }; + client.request(std::move(req2)); + + waitFor(responseCount, 3); +} + +TEST_CASE("Long polling example", "[http2]") { + constexpr int kFooMessages = 50; + + auto brokerThread = std::jthread([](std::stop_token stopToken) { + RestServer server; + REQUIRE(server.bind(kServerPort, majordomo::rest::Http2)); + + const auto topic = URI<>("/foo?param1=1¶m2=foo%2Fbar"); + + std::deque messages; + ensureMessageReceived(server, stopToken, messages); + REQUIRE(messages.size() >= 1); + const auto req0 = std::move(messages[0]); + messages.pop_front(); + REQUIRE(req0.command == mdp::Command::Subscribe); + REQUIRE(normalize(req0.topic) == normalize(topic)); + REQUIRE(req0.error.empty()); + REQUIRE(req0.data.empty()); + + for (int i = 0; i < kFooMessages; ++i) { + Message notify; + notify.command = mdp::Command::Notify; + notify.serviceName = "/foo"; + notify.topic = topic; + auto data = std::to_string(i); + notify.data = opencmw::IoBuffer(data.data(), data.size()); + server.handleNotification(opencmw::mdp::Topic::fromMdpTopic(topic), std::move(notify)); + } + + ensureMessageReceived(server, stopToken, messages, std::chrono::seconds(10)); + + // sync with the client to ensure the client unsubscribed before we send more notifications + REQUIRE(messages.size() >= 1); + const auto req1 = std::move(messages[0]); + messages.pop_front(); + REQUIRE(req1.command == mdp::Command::Get); + Message reply1; + reply1.command = mdp::Command::Final; + reply1.clientRequestID = req1.clientRequestID; + reply1.topic = topic; + reply1.data = opencmw::IoBuffer("Hello"); + server.handleResponse(std::move(reply1)); + + // send notifications that should not be received + for (int i = 0; i < kFooMessages; ++i) { + Message notify; + notify.command = mdp::Command::Notify; + notify.serviceName = "/foo"; + notify.topic = topic; + auto data = std::to_string(kFooMessages + i); + notify.data = opencmw::IoBuffer(data.data(), data.size()); + server.handleNotification(opencmw::mdp::Topic::fromMdpTopic(topic), std::move(notify)); + } + + ensureMessageReceived(server, stopToken, messages, std::chrono::seconds(10)); // makes sure messages are processed + }); + + Stopper stopper(brokerThread.get_stop_source()); + std::this_thread::sleep_for(std::chrono::milliseconds(300)); // give the server some time to start listening + + { + std::atomic responseCount = 0; + + client::RestClient client; + opencmw::client::Command sub; + sub.command = mdp::Command::Subscribe; + sub.clientRequestID = opencmw::IoBuffer("0"); + sub.topic = URI<>(std::format("http://localhost:{}/foo?param1=1¶m2=foo%2fbar", kServerPort)); + sub.callback = [&responseCount](const mdp::Message &msg) { + REQUIRE(msg.command == mdp::Command::Notify); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString() == std::to_string(responseCount)); + REQUIRE(msg.protocolName == "http"); + REQUIRE(normalize(msg.topic) == normalize(URI<>("/foo?param1=1¶m2=foo%2Fbar"))); + responseCount++; + }; + client.request(std::move(sub)); + waitFor(responseCount, kFooMessages); + responseCount = 0; + + // unsubscribe + client::Command unsub; + unsub.command = mdp::Command::Unsubscribe; + unsub.clientRequestID = opencmw::IoBuffer("0"); + unsub.topic = URI<>(std::format("http://localhost:{}/foo?param1=1¶m2=foo%2fbar", kServerPort)); + unsub.callback = [](const mdp::Message &) { + FAIL("Unsubscribe should not receive a message"); + }; + client.request(std::move(unsub)); + + // Send a GET to sync with server + client::Command get; + get.command = mdp::Command::Get; + get.clientRequestID = opencmw::IoBuffer("1"); + get.topic = URI<>(std::format("http://localhost:{}/foo?param1=1¶m2=foo%2fbar", kServerPort)); + get.callback = [&responseCount](const mdp::Message &msg) { + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.protocolName == "http"); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString() == "Hello"); + REQUIRE(normalize(msg.topic) == normalize(URI<>("/foo?param1=1¶m2=foo%2Fbar"))); + responseCount++; + }; + client.request(std::move(get)); + + waitFor(responseCount, 1); + responseCount -= 1; + + // wait some time to make sure no more messages are received + std::this_thread::sleep_for(std::chrono::milliseconds(250)); + REQUIRE(responseCount == 0); + } +} diff --git a/src/client/test/pre.js b/src/client/test/pre.js index 370808c9..8b633e3b 100644 --- a/src/client/test/pre.js +++ b/src/client/test/pre.js @@ -38,8 +38,8 @@ if (typeof XMLHttpRequest === 'undefined') { // Set the additional headers //res.setHeader('Access-Control-Allow-Origin', '127.0.0.1,localhost'); res.setHeader('Keep-Alive', 'timeout=5, max=5'); - res.setHeader('X-OPENCMW-SERVICE-NAME', 'dns'); - res.setHeader('X-OPENCMW-TOPIC', '//s?signal_type=&signal_unit=&signal_name=&service_type=&service_name=&port=-1&hostname=&protocol=&signal_rate=nan&contextType=application%2Foctet-stream'); + res.setHeader('x-opencmw-service-name', 'dns'); + res.setHeader('x-opencmw-topic', '//s?signal_type=&signal_unit=&signal_name=&service_type=&service_name=&port=-1&hostname=&protocol=&signal_rate=nan&contextType=application%2Foctet-stream'); // Send the binary response const binaryData = Buffer.from('ffffffff040000005961530001000000602f0da036000000e4010000250000006f70656e636d773a3a736572766963653a3a646e733a3a466c6174456e7472794c69737400ff988e0ac51a000000150000000900000070726f746f636f6c00010000000100000001000000050000006874747000ff335c21ee1a0000002100000009000000686f73746e616d650001000000010000000100000011000000746573742e6578616d706c652e636f6d006881983400160000001000000005000000706f72740001000000010000000100000039050000ffd55573151e000000150000000d000000736572766963655f6e616d6500010000000100000001000000050000007465737400ff846a76151e000000110000000d000000736572766963655f74797065000100000001000000010000000100000000ffc2ad1b281d000000160000000c0000007369676e616c5f6e616d650001000000010000000100000005000000746573743100ffbb0c1f281d000000110000000c0000007369676e616c5f756e69740001000000010000000100000001000000006a17801d281d000000100000000c0000007369676e616c5f726174650001000000010000000100000000007a44ff71c21e281d000000110000000c0000007369676e616c5f74797065000100000001000000010000000100000000fe602f0da03600000000000000250000006f70656e636d773a3a736572766963653a3a646e733a3a466c6174456e7472794c69737400', 'hex'); diff --git a/src/core/include/LoadTest.hpp b/src/core/include/LoadTest.hpp new file mode 100644 index 00000000..b886193e --- /dev/null +++ b/src/core/include/LoadTest.hpp @@ -0,0 +1,37 @@ +#ifndef OPENCMW_CPP_LOAD_TEST_HPP +#define OPENCMW_CPP_LOAD_TEST_HPP + +#include + +#include + +#include +#include + +namespace opencmw::load_test { + +struct Context { + std::string topic; + std::int64_t intervalMs = 1000; // must be multiple of 10 (enforce?) + std::int64_t payloadSize = 100; + std::int64_t nUpdates = -1; // -1 means infinite + std::int64_t initialDelayMs = 0; + opencmw::MIME::MimeType contentType = opencmw::MIME::BINARY; +}; + +struct Payload { + std::int64_t index; + std::string data; + std::int64_t timestampNs = 0; +}; + +inline std::chrono::nanoseconds timestamp() { + return std::chrono::duration_cast(std::chrono::high_resolution_clock::now().time_since_epoch()); +} + +} // namespace opencmw::load_test + +ENABLE_REFLECTION_FOR(opencmw::load_test::Context, topic, intervalMs, payloadSize, nUpdates, initialDelayMs, contentType) +ENABLE_REFLECTION_FOR(opencmw::load_test::Payload, index, data, timestampNs) + +#endif diff --git a/src/core/include/Topic.hpp b/src/core/include/Topic.hpp index 7ce1392b..6a851cb0 100644 --- a/src/core/include/Topic.hpp +++ b/src/core/include/Topic.hpp @@ -68,6 +68,7 @@ struct Topic { Params _params; public: + Topic() = default; Topic(const Topic &other) = default; Topic &operator=(const Topic &) = default; Topic(Topic &&) noexcept = default; @@ -168,7 +169,7 @@ struct Topic { , _params(std::move(params)) { if (serviceOrServiceAndQuery.find("?") != std::string::npos) { if (!_params.empty()) { - throw std::invalid_argument(std::format("Parameters are not empty ({}), and there are more in the service string ({})\n", _params, serviceOrServiceAndQuery)); + throw std::invalid_argument(std::format("Parameters are not empty ({}), and there are more in the service string ({})", _params, serviceOrServiceAndQuery)); } const auto parsed = opencmw::URI(std::string(serviceOrServiceAndQuery)); const auto &queryMap = parsed.queryParamMap(); @@ -176,7 +177,7 @@ struct Topic { } if (!isValidServiceName(_service)) { - throw std::invalid_argument(std::format("Invalid service name '{}'\n", _service)); + throw std::invalid_argument(std::format("Invalid service name '{}'", _service)); } } }; diff --git a/src/majordomo/CMakeLists.txt b/src/majordomo/CMakeLists.txt index 61e3a8c7..912ca06e 100644 --- a/src/majordomo/CMakeLists.txt +++ b/src/majordomo/CMakeLists.txt @@ -1,31 +1,35 @@ -# setup header only library add_library(majordomo INTERFACE - include/majordomo/Broker.hpp - include/majordomo/MockClient.hpp - include/majordomo/Constants.hpp - include/majordomo/Cryptography.hpp - include/majordomo/Rbac.hpp - include/majordomo/RestBackend.hpp - include/majordomo/Settings.hpp - include/majordomo/SubscriptionMatcher.hpp - include/majordomo/Worker.hpp + include/majordomo/Broker.hpp + include/majordomo/MockClient.hpp + include/majordomo/Constants.hpp + include/majordomo/Cryptography.hpp + include/majordomo/NgTcp2Util.hpp + include/majordomo/Rbac.hpp + include/majordomo/Rest.hpp + include/majordomo/RestServer.hpp + include/majordomo/Settings.hpp + include/majordomo/SubscriptionMatcher.hpp + include/majordomo/Worker.hpp + include/majordomo/TlsServerSession_Ossl.hpp + include/majordomo/TlsSessionBase_Ossl.hpp ) + target_include_directories(majordomo INTERFACE $ $) + target_link_libraries(majordomo - INTERFACE - core - serialiser - zmq - httplib::httplib - #OpenSSL::SSL - pthread - sodium - ) + INTERFACE + core + serialiser + zmq + pthread + rest + sodium +) install( - TARGETS majordomo - EXPORT opencmwTargets - PUBLIC_HEADER DESTINATION include/opencmw + TARGETS majordomo + EXPORT opencmwTargets + PUBLIC_HEADER DESTINATION include/opencmw ) # setup tests diff --git a/src/majordomo/include/majordomo/Broker.hpp b/src/majordomo/include/majordomo/Broker.hpp index d7b309c9..0664713a 100644 --- a/src/majordomo/include/majordomo/Broker.hpp +++ b/src/majordomo/include/majordomo/Broker.hpp @@ -16,6 +16,8 @@ #include #include "Rbac.hpp" +#include "Rest.hpp" +#include "RestServer.hpp" #include "Topic.hpp" #include "URI.hpp" @@ -181,13 +183,16 @@ class Broker { using Timestamp = std::chrono::time_point; struct Client { - const zmq::Socket &socket; - const std::string id; - std::deque requests; - Timestamp expiry; + std::optional> socket; + const std::string id; + std::deque requests; + Timestamp expiry; explicit Client(const zmq::Socket &s, const std::string &id_, Timestamp expiry_) : socket(s), id(std::move(id_)), expiry{ std::move(expiry_) } {} + + explicit Client(const std::string &id_) + : id(std::move(id_)), expiry{} {} }; struct Worker { @@ -261,6 +266,7 @@ class Broker { const std::string brokerName; private: + std::unique_ptr _restServer; Timestamp _heartbeatAt = Clock::now() + settings.heartbeatInterval; SubscriptionMatcher _subscriptionMatcher; std::unordered_map> _subscribedClientsByTopic; // topic -> client IDs @@ -278,11 +284,13 @@ class Broker { std::atomic _shutdownRequested = false; // Sockets collection. The Broker class will be used as the handler - const zmq::Socket _routerSocket; - const zmq::Socket _pubSocket; - const zmq::Socket _subSocket; - const zmq::Socket _dnsSocket; - std::array pollerItems; + const zmq::Socket _routerSocket; + const zmq::Socket _pubSocket; + const zmq::Socket _subSocket; + const zmq::Socket _dnsSocket; + static constexpr std::size_t kNumberOfZmqSockets = 4; + static constexpr std::string_view kRestSourceId = "rest"; // assuming that ZMQ never assigns this ID. Alternatively connect an idle client and use its ID. + std::vector _pollerItems = std::vector(kNumberOfZmqSockets); public: Broker(std::string brokerName_, Settings settings_ = {}) @@ -394,14 +402,14 @@ class Broker { zmq::invoke(zmq_connect, _dnsSocket, INTERNAL_ADDRESS_BROKER.str().data()).assertSuccess(); } - pollerItems[0].socket = _routerSocket.zmq_ptr; - pollerItems[0].events = ZMQ_POLLIN; - pollerItems[1].socket = _pubSocket.zmq_ptr; - pollerItems[1].events = ZMQ_POLLIN; - pollerItems[2].socket = _subSocket.zmq_ptr; - pollerItems[2].events = ZMQ_POLLIN; - pollerItems[3].socket = _dnsSocket.zmq_ptr; - pollerItems[3].events = ZMQ_POLLIN; + _pollerItems[0].socket = _routerSocket.zmq_ptr; + _pollerItems[0].events = ZMQ_POLLIN; + _pollerItems[1].socket = _pubSocket.zmq_ptr; + _pollerItems[1].events = ZMQ_POLLIN; + _pollerItems[2].socket = _subSocket.zmq_ptr; + _pollerItems[2].events = ZMQ_POLLIN; + _pollerItems[3].socket = _dnsSocket.zmq_ptr; + _pollerItems[3].events = ZMQ_POLLIN; } Broker(const Broker &) = delete; @@ -427,6 +435,7 @@ class Broker { */ std::optional> bind(const URI &endpoint, BindOption option = BindOption::DetectFromURI) { assert(!(option == BindOption::DetectFromURI && (endpoint.scheme() == SCHEME_INPROC || endpoint.scheme() == SCHEME_TCP))); + const auto isRouterSocket = option != BindOption::Pub && (option == BindOption::Router || endpoint.scheme() == SCHEME_MDP || endpoint.scheme() == SCHEME_TCP); const auto zmqEndpoint = mdp::toZeroMQEndpoint(endpoint); @@ -446,6 +455,30 @@ class Broker { return adjustedAddressPublic; } + std::expected bindRest(rest::Settings restSettings) { + if (restSettings.certificateFilePath.empty() != restSettings.keyFilePath.empty()) { + return std::unexpected("Provide both certificate and key file paths (for HTTPS) or none (for HTTP)"); + } + if (restSettings.certificateFileBuffer.empty() != restSettings.keyFileBuffer.empty()) { + return std::unexpected("Provide both certificate and key file paths (for HTTPS) or none (for HTTP)"); + } + + std::expected maybeServer; + if (!restSettings.certificateFilePath.empty()) { + maybeServer = detail::rest::RestServer::sslWithPaths(std::move(restSettings.certificateFilePath), std::move(restSettings.keyFilePath)); + } else if (!restSettings.certificateFileBuffer.empty()) { + maybeServer = detail::rest::RestServer::sslWithBuffers(std::move(restSettings.certificateFileBuffer), std::move(restSettings.keyFileBuffer)); + } else { + maybeServer = detail::rest::RestServer::unencrypted(); + } + if (!maybeServer) { + return std::unexpected(maybeServer.error()); + } + _restServer = std::make_unique(std::move(maybeServer.value())); + _restServer->setHandlers(std::move(restSettings.handlers)); + return _restServer->bind(restSettings.port, restSettings.protocols); + } + void run() { sendDnsHeartbeats(true); // initial register of default routes @@ -462,6 +495,17 @@ class Broker { // test interface bool processMessages() { + for (std::size_t i = kNumberOfZmqSockets; i < _pollerItems.size(); ++i) { + const auto read = _pollerItems[i].revents & ZMQ_POLLIN; + const auto write = _pollerItems[i].revents & ZMQ_POLLOUT; + if (read || write) { + auto messages = _restServer->processReadWrite(_pollerItems[i].fd, read, write); + for (auto &message : messages) { + handleRestMessage(std::move(message)); + } + } + } + bool anythingReceived; int loopCount = 0; do { @@ -481,9 +525,13 @@ class Broker { loopCount++; } while (anythingReceived); + _pollerItems.resize(kNumberOfZmqSockets); + if (_restServer) { + _restServer->populatePollerItems(_pollerItems); + } + // N.B. block until data arrived or for at most one heart-beat interval - const auto result = zmq::invoke(zmq_poll, pollerItems.data(), static_cast(pollerItems.size()), settings.heartbeatInterval.count()); - return result.isValid(); + return zmq::invoke(zmq_poll, _pollerItems.data(), static_cast(_pollerItems.size()), settings.heartbeatInterval.count()).isValid(); } void cleanup() { @@ -564,7 +612,53 @@ class Broker { return true; } - bool receiveMessage(const zmq::Socket &socket) { + void handleRestMessage(BrokerMessage &&message) { + message.sourceId = std::string(kRestSourceId); + message.protocolName = mdp::clientProtocol; + switch (message.command) { + case mdp::Command::Get: + case mdp::Command::Set: { + auto [client, inserted] = _clients.try_emplace(message.sourceId, message.sourceId); + client->second.requests.emplace_back(std::move(message)); + } break; + case mdp::Command::Subscribe: { + mdp::Topic subscription; + try { + subscription = mdp::Topic::fromMdpTopic(message.topic); + } catch (...) { + // malformed topic, ignore + return; + } + + subscribe(subscription); + auto [it, inserted] = _subscribedClientsByTopic.try_emplace(subscription); + it->second.emplace(message.sourceId); + } break; + case mdp::Command::Unsubscribe: { + mdp::Topic subscription; + try { + subscription = mdp::Topic::fromMdpTopic(message.topic); + } catch (...) { + // malformed topic, ignore + return; + } + + unsubscribe(subscription); + auto it = _subscribedClientsByTopic.find(subscription); + if (it != _subscribedClientsByTopic.end()) { + it->second.erase(message.sourceId); + if (it->second.empty()) { + _subscribedClientsByTopic.erase(it); + } + } + } break; + default: + break; + } + } + + bool + receiveMessage(const zmq::Socket &socket) { auto maybeMessage = zmq::receive(socket); if (!maybeMessage) { @@ -709,9 +803,15 @@ class Broker { const auto it = _subscribedClientsByTopic.find(topic); if (it != _subscribedClientsByTopic.end()) { for (const auto &clientId : it->second) { - auto clientCopy = message; - clientCopy.sourceId = clientId; - zmq::send(std::move(clientCopy), _routerSocket).assertSuccess(); + auto clientCopy = message; + if (clientId == kRestSourceId) { + if (_restServer) { + _restServer->handleNotification(it->first, std::move(clientCopy)); + } + } else { + clientCopy.sourceId = clientId; + zmq::send(std::move(clientCopy), _routerSocket).assertSuccess(); + } } } } @@ -750,13 +850,19 @@ class Broker { if (client.requests.empty()) continue; - auto clientMessage = std::move(client.requests.back()); - client.requests.pop_back(); + auto clientMessage = std::move(client.requests.front()); + client.requests.pop_front(); if (auto service = bestMatchingService(clientMessage.serviceName)) { if (service->internalHandler) { auto reply = service->internalHandler(std::move(clientMessage)); - zmq::send(std::move(reply), client.socket).assertSuccess(); + if (client.id == kRestSourceId) { + if (_restServer) { + _restServer->handleResponse(std::move(reply)); + } + } else { + zmq::send(std::move(reply), client.socket.value()).assertSuccess(); + } } else { service->putMessage(std::move(clientMessage)); dispatch(*service); @@ -773,7 +879,13 @@ class Broker { reply.error = std::format("unknown service (error 501): '{}'", reply.serviceName); reply.rbac = _rbac; - zmq::send(std::move(reply), client.socket).assertSuccess(); + if (client.id == kRestSourceId) { + if (_restServer) { + _restServer->handleResponse(std::move(reply)); + } + } else { + zmq::send(std::move(reply), client.socket.value()).assertSuccess(); + } } } @@ -786,7 +898,7 @@ class Broker { const auto isExpired = [&now](const auto &c) { auto &[senderId, client] = c; - return client.expiry < now; + return senderId != kRestSourceId && client.expiry < now; }; std::erase_if(_clients, isExpired); @@ -876,7 +988,14 @@ class Broker { message.sourceId = message.serviceName; // serviceName=clientSourceID message.serviceName = worker.serviceName; message.protocolName = mdp::clientProtocol; - zmq::send(std::move(message), client->second.socket).assertSuccess(); + if (message.sourceId == kRestSourceId) { + if (_restServer) { + _restServer->handleResponse(std::move(message)); + } + } else { + zmq::send(std::move(message), client->second.socket.value()).assertSuccess(); + } + workerWaiting(worker); } else { disconnectWorker(worker); diff --git a/src/majordomo/include/majordomo/LoadTestWorker.hpp b/src/majordomo/include/majordomo/LoadTestWorker.hpp new file mode 100644 index 00000000..0241954c --- /dev/null +++ b/src/majordomo/include/majordomo/LoadTestWorker.hpp @@ -0,0 +1,83 @@ + +#ifndef OPENCMW_MAJORDOMO_LOADTESTWORKER_H +#define OPENCMW_MAJORDOMO_LOADTESTWORKER_H + +#include "QuerySerialiser.hpp" +#include +#include + +#include + +namespace opencmw::majordomo::load_test { + +using opencmw::load_test::Context; +using opencmw::load_test::Payload; + +template +struct Worker : public opencmw::majordomo::Worker { + using super_t = opencmw::majordomo::Worker; + std::jthread _notifier; + + struct SubscriptionInfo : public Context { + SubscriptionInfo(const Context &ctx) + : Context(ctx) {} + + std::int64_t nextIndex = 0; + std::int64_t initialDelayLeftMs = initialDelayMs; + std::int64_t updatesLeft = nUpdates; + }; + std::unordered_map _subscriptions; + + template + explicit Worker(BrokerType &broker) + : super_t(broker, {}) { + opencmw::query::registerTypes(Context(), broker); + + super_t::setCallback([](opencmw::majordomo::RequestContext & /*rawCtx*/, const Context & /*inCtx*/, const Payload &in, Context & /*outCtx*/, Payload &out) { + out.data = in.data; + out.timestampNs = opencmw::load_test::timestamp().count(); + }); + _notifier = std::jthread([this](std::stop_token stopToken) { + std::int64_t timecounter = 0; + + std::unordered_map subscriptions; + + while (!stopToken.stop_requested()) { + const auto stepMs = 10; + std::this_thread::sleep_for(std::chrono::milliseconds(stepMs)); + timecounter += stepMs; + auto activeSubscriptions = super_t::activeSubscriptions(); + std::erase_if(subscriptions, [&activeSubscriptions](const auto &sub) { return !activeSubscriptions.contains(sub.first); }); + for (const auto &active : activeSubscriptions) { + auto it = subscriptions.find(active); + if (it == subscriptions.end()) { + // TODO protect against unexpected queries + auto ctx = query::deserialise(active.params()); + it = subscriptions.try_emplace(active, ctx).first; + } + auto &sub = it->second; + + sub.initialDelayLeftMs -= stepMs; + + if (sub.initialDelayLeftMs > 0) { + continue; + } + if (sub.updatesLeft == 0) { + continue; + } + if (sub.intervalMs > 0 && timecounter % sub.intervalMs != 0) { + continue; + } + Payload payload{ sub.nextIndex, std::string(static_cast(sub.payloadSize), 'x'), opencmw::load_test::timestamp().count() }; + super_t::notify(sub, std::move(payload)); + sub.nextIndex++; + sub.updatesLeft--; + } + } + }); + } +}; + +} // namespace opencmw::majordomo::load_test + +#endif diff --git a/src/majordomo/include/majordomo/NgTcp2Util.hpp b/src/majordomo/include/majordomo/NgTcp2Util.hpp new file mode 100644 index 00000000..7f62180e --- /dev/null +++ b/src/majordomo/include/majordomo/NgTcp2Util.hpp @@ -0,0 +1,994 @@ +/* + * ngtcp2 + * + * Copyright (c) 2017 ngtcp2 contributors + * Copyright (c) 2012 nghttp2 contributors + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files (the + * "Software"), to deal in the Software without restriction, including + * without limitation the rights to use, copy, modify, merge, publish, + * distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to + * the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE + * LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION + * OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION + * WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ +#ifndef OPENCMW_NGTCP2UTIL_H +#define OPENCMW_NGTCP2UTIL_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +#include +#include +#include + +// This is code collected from the ngtcp2 examples (network.h, shared.h, util.h, etc.) +namespace opencmw::majordomo::detail::rest { + +inline auto CRYPTO_load_u64_le(std::span in) { + uint64_t v; + + memcpy(&v, in.data(), sizeof(v)); + + if constexpr (std::endian::native == std::endian::big) { + return std::byteswap(v); + } + + return v; +} + +constexpr void siphash_round(uint64_t v[4]) { + v[0] += v[1]; + v[2] += v[3]; + v[1] = std::rotl(v[1], 13); + v[3] = std::rotl(v[3], 16); + v[1] ^= v[0]; + v[3] ^= v[2]; + v[0] = std::rotl(v[0], 32); + v[2] += v[1]; + v[0] += v[3]; + v[1] = std::rotl(v[1], 17); + v[3] = std::rotl(v[3], 21); + v[1] ^= v[2]; + v[3] ^= v[0]; + v[2] = std::rotl(v[2], 32); +} + +// SipHash is a fast, secure PRF that is often used for hash tables. + +// siphash24 implements SipHash-2-4. See +// https://131002.net/siphash/siphash.pdf +inline uint64_t siphash24(std::span key, std::span input) { + const auto orig_input_len = input.size(); + uint64_t v[]{ + key[0] ^ UINT64_C(0x736f6d6570736575), + key[1] ^ UINT64_C(0x646f72616e646f6d), + key[0] ^ UINT64_C(0x6c7967656e657261), + key[1] ^ UINT64_C(0x7465646279746573), + }; + + while (input.size() >= sizeof(uint64_t)) { + auto m = CRYPTO_load_u64_le(input.first()); + v[3] ^= m; + siphash_round(v); + siphash_round(v); + v[0] ^= m; + + input = input.subspan(sizeof(uint64_t)); + } + + std::array last_block{}; + std::ranges::copy(input, std::begin(last_block)); + last_block.back() = orig_input_len & 0xff; + + auto last_block_word = CRYPTO_load_u64_le(last_block); + v[3] ^= last_block_word; + siphash_round(v); + siphash_round(v); + v[0] ^= last_block_word; + + v[2] ^= 0xff; + siphash_round(v); + siphash_round(v); + siphash_round(v); + siphash_round(v); + + return v[0] ^ v[1] ^ v[2] ^ v[3]; +} + +// Define here to be usable in tests. +template +T byteswap(T v) { + auto c = std::bit_cast>(v); + std::ranges::reverse(c); + return std::bit_cast(c); +} + +// inspired by , but our +// template can take functions returning other than void. +template +struct Defer { + Defer(F &&f, T &&...t) + : f(std::bind(std::forward(f), std::forward(t)...)) {} + Defer(Defer &&o) noexcept + : f(std::move(o.f)) {} + ~Defer() { f(); } + + using ResultType = std::invoke_result_t; + std::function f; +}; + +template +Defer defer(F &&f, T &&...t) { + return Defer(std::forward(f), std::forward(t)...); +} + +template +constexpr size_t array_size(T (&)[N]) { + return N; +} + +template +constexpr size_t str_size(T (&)[N]) { + return N - 1; +} + +template +[[nodiscard]] std::span +as_writable_uint8_span(std::span s) noexcept { + return std::span < uint8_t, + N == std::dynamic_extent + ? std::dynamic_extent + : N * sizeof(T) > { reinterpret_cast(s.data()), s.size_bytes() }; +} + +enum network_error { + NETWORK_ERR_OK = 0, + NETWORK_ERR_FATAL = -10, + NETWORK_ERR_SEND_BLOCKED = -11, + NETWORK_ERR_CLOSE_WAIT = -12, + NETWORK_ERR_RETRY = -13, + NETWORK_ERR_DROP_CONN = -14, +}; + +union in_addr_union { + in_addr in; + in6_addr in6; +}; + +union sockaddr_union { + sockaddr_storage storage; + sockaddr sa; + sockaddr_in6 in6; + sockaddr_in in; +}; + +struct Address { + socklen_t len; + union sockaddr_union su; + uint32_t ifindex; +}; + +enum class AppProtocol { + H3, + HQ, +}; + +constexpr uint8_t HQ_ALPN[] = "\xahq-interop"; +constexpr uint8_t HQ_ALPN_V1[] = "\xahq-interop"; + +constexpr uint8_t H3_ALPN[] = "\x2h3"; +constexpr uint8_t H3_ALPN_V1[] = "\x2h3"; + +// msghdr_get_ecn gets ECN bits from |msg|. |family| is the address +// family from which packet is received. +inline unsigned int msghdr_get_ecn(msghdr *msg, int family) { + switch (family) { + case AF_INET: + for (auto cmsg = CMSG_FIRSTHDR(msg); cmsg; cmsg = CMSG_NXTHDR(msg, cmsg)) { + if (cmsg->cmsg_level == IPPROTO_IP && +#ifdef __APPLE__ + cmsg->cmsg_type == IP_RECVTOS +#else // !defined(__APPLE__) + cmsg->cmsg_type == IP_TOS +#endif // !defined(__APPLE__) + && cmsg->cmsg_len) { + return *reinterpret_cast(CMSG_DATA(cmsg)) & IPTOS_ECN_MASK; + } + } + break; + case AF_INET6: + for (auto cmsg = CMSG_FIRSTHDR(msg); cmsg; cmsg = CMSG_NXTHDR(msg, cmsg)) { + if (cmsg->cmsg_level == IPPROTO_IPV6 && cmsg->cmsg_type == IPV6_TCLASS && cmsg->cmsg_len) { + unsigned int tos; + + memcpy(&tos, CMSG_DATA(cmsg), sizeof(int)); + + return tos & IPTOS_ECN_MASK; + } + } + break; + } + + return 0; +} + +// fd_set_recv_ecn sets socket option to |fd| so that it can receive +// ECN bits. +inline void fd_set_recv_ecn(int fd, int family) { + unsigned int tos = 1; + switch (family) { + case AF_INET: + if (setsockopt(fd, IPPROTO_IP, IP_RECVTOS, &tos, + static_cast(sizeof(tos))) + == -1) { + std::cerr << "setsockopt: " << strerror(errno) << std::endl; + } + break; + case AF_INET6: + if (setsockopt(fd, IPPROTO_IPV6, IPV6_RECVTCLASS, &tos, + static_cast(sizeof(tos))) + == -1) { + std::cerr << "setsockopt: " << strerror(errno) << std::endl; + } + break; + } +} + +// fd_set_ip_mtu_discover sets IP(V6)_MTU_DISCOVER socket option to |fd|. +inline void fd_set_ip_mtu_discover(int fd, int family) { +#if defined(IP_MTU_DISCOVER) && defined(IPV6_MTU_DISCOVER) + int val; + + switch (family) { + case AF_INET: + val = IP_PMTUDISC_DO; + if (setsockopt(fd, IPPROTO_IP, IP_MTU_DISCOVER, &val, + static_cast(sizeof(val))) + == -1) { + std::cerr << "setsockopt: IP_MTU_DISCOVER: " << strerror(errno) << std::endl; + } + break; + case AF_INET6: + val = IPV6_PMTUDISC_DO; + if (setsockopt(fd, IPPROTO_IPV6, IPV6_MTU_DISCOVER, &val, + static_cast(sizeof(val))) + == -1) { + std::cerr << "setsockopt: IPV6_MTU_DISCOVER: " << strerror(errno) << std::endl; + } + break; + } +#endif // defined(IP_MTU_DISCOVER) && defined(IPV6_MTU_DISCOVER) +} + +// fd_set_ip_dontfrag sets IP(V6)_DONTFRAG socket option to |fd|. +inline void fd_set_ip_dontfrag(int fd, int family) { +#if defined(IP_DONTFRAG) && defined(IPV6_DONTFRAG) + int val = 1; + + switch (family) { + case AF_INET: + if (setsockopt(fd, IPPROTO_IP, IP_DONTFRAG, &val, + static_cast(sizeof(val))) + == -1) { + std::cerr << "setsockopt: IP_DONTFRAG: " << strerror(errno) << std::endl; + } + break; + case AF_INET6: + if (setsockopt(fd, IPPROTO_IPV6, IPV6_DONTFRAG, &val, + static_cast(sizeof(val))) + == -1) { + std::cerr << "setsockopt: IPV6_DONTFRAG: " << strerror(errno) + << std::endl; + } + break; + } +#else + std::ignore = fd; + std::ignore = family; +#endif // defined(IP_DONTFRAG) && defined(IPV6_DONTFRAG) +} + +inline std::optional
msghdr_get_local_addr(msghdr *msg, int family) { + switch (family) { + case AF_INET: + for (auto cmsg = CMSG_FIRSTHDR(msg); cmsg; cmsg = CMSG_NXTHDR(msg, cmsg)) { + if (cmsg->cmsg_level == IPPROTO_IP && cmsg->cmsg_type == IP_PKTINFO) { + in_pktinfo pktinfo; + memcpy(&pktinfo, CMSG_DATA(cmsg), sizeof(pktinfo)); + Address res; + memset(&res, 0, sizeof(res)); + res.len = sizeof(res.su.in); + res.ifindex = static_cast(pktinfo.ipi_ifindex); + auto &sa = res.su.in; + sa.sin_family = AF_INET; + sa.sin_addr = pktinfo.ipi_addr; + return res; + } + } + return {}; + case AF_INET6: + for (auto cmsg = CMSG_FIRSTHDR(msg); cmsg; cmsg = CMSG_NXTHDR(msg, cmsg)) { + if (cmsg->cmsg_level == IPPROTO_IPV6 && cmsg->cmsg_type == IPV6_PKTINFO) { + in6_pktinfo pktinfo; + memcpy(&pktinfo, CMSG_DATA(cmsg), sizeof(pktinfo)); + Address res; + memset(&res, 0, sizeof(res)); + res.len = sizeof(res.su.in6); + res.ifindex = static_cast(pktinfo.ipi6_ifindex); + auto &sa = res.su.in6; + sa.sin6_family = AF_INET6; + sa.sin6_addr = pktinfo.ipi6_addr; + return res; + } + } + return {}; + } + return {}; +} + +// msghdr_get_udp_gro returns UDP_GRO value from |msg|. If UDP_GRO is +// not found, or UDP_GRO is not supported, this function returns 0. +inline size_t msghdr_get_udp_gro(msghdr *msg) { + int gso_size = 0; + +#ifdef UDP_GRO + for (auto cmsg = CMSG_FIRSTHDR(msg); cmsg; cmsg = CMSG_NXTHDR(msg, cmsg)) { + if (cmsg->cmsg_level == SOL_UDP && cmsg->cmsg_type == UDP_GRO) { + memcpy(&gso_size, CMSG_DATA(cmsg), sizeof(gso_size)); + + break; + } + } +#else + std::ignore = msg; +#endif // defined(UDP_GRO) + + return static_cast(gso_size); +} + +inline void set_port(Address &dst, Address &src) { + switch (dst.su.storage.ss_family) { + case AF_INET: + assert(AF_INET == src.su.storage.ss_family); + dst.su.in.sin_port = src.su.in.sin_port; + return; + case AF_INET6: + assert(AF_INET6 == src.su.storage.ss_family); + dst.su.in6.sin6_port = src.su.in6.sin6_port; + return; + default: + assert(0); + } +} + +inline nghttp3_nv make_nv(const std::string_view &name, + const std::string_view &value, uint8_t flags) { + return nghttp3_nv{ + reinterpret_cast(const_cast(std::data(name))), + reinterpret_cast(const_cast(std::data(value))), + name.size(), + value.size(), + flags, + }; +} + +inline nghttp3_nv make_nv_cc(const std::string_view &name, + const std::string_view &value) { + return make_nv(name, value, NGHTTP3_NV_FLAG_NONE); +} + +inline nghttp3_nv make_nv_nc(const std::string_view &name, + const std::string_view &value) { + return make_nv(name, value, NGHTTP3_NV_FLAG_NO_COPY_NAME); +} + +inline nghttp3_nv make_nv_nn(const std::string_view &name, + const std::string_view &value) { + return make_nv(name, value, + NGHTTP3_NV_FLAG_NO_COPY_NAME | NGHTTP3_NV_FLAG_NO_COPY_VALUE); +} + +inline ngtcp2_tstamp timestamp() { + auto c = std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); + return static_cast(c); +} + +inline char lowcase(char c) { + constexpr static unsigned char tbl[] = { + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + 30, + 31, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + 50, + 51, + 52, + 53, + 54, + 55, + 56, + 57, + 58, + 59, + 60, + 61, + 62, + 63, + 64, + 'a', + 'b', + 'c', + 'd', + 'e', + 'f', + 'g', + 'h', + 'i', + 'j', + 'k', + 'l', + 'm', + 'n', + 'o', + 'p', + 'q', + 'r', + 's', + 't', + 'u', + 'v', + 'w', + 'x', + 'y', + 'z', + 91, + 92, + 93, + 94, + 95, + 96, + 97, + 98, + 99, + 100, + 101, + 102, + 103, + 104, + 105, + 106, + 107, + 108, + 109, + 110, + 111, + 112, + 113, + 114, + 115, + 116, + 117, + 118, + 119, + 120, + 121, + 122, + 123, + 124, + 125, + 126, + 127, + 128, + 129, + 130, + 131, + 132, + 133, + 134, + 135, + 136, + 137, + 138, + 139, + 140, + 141, + 142, + 143, + 144, + 145, + 146, + 147, + 148, + 149, + 150, + 151, + 152, + 153, + 154, + 155, + 156, + 157, + 158, + 159, + 160, + 161, + 162, + 163, + 164, + 165, + 166, + 167, + 168, + 169, + 170, + 171, + 172, + 173, + 174, + 175, + 176, + 177, + 178, + 179, + 180, + 181, + 182, + 183, + 184, + 185, + 186, + 187, + 188, + 189, + 190, + 191, + 192, + 193, + 194, + 195, + 196, + 197, + 198, + 199, + 200, + 201, + 202, + 203, + 204, + 205, + 206, + 207, + 208, + 209, + 210, + 211, + 212, + 213, + 214, + 215, + 216, + 217, + 218, + 219, + 220, + 221, + 222, + 223, + 224, + 225, + 226, + 227, + 228, + 229, + 230, + 231, + 232, + 233, + 234, + 235, + 236, + 237, + 238, + 239, + 240, + 241, + 242, + 243, + 244, + 245, + 246, + 247, + 248, + 249, + 250, + 251, + 252, + 253, + 254, + 255, + }; + return static_cast(tbl[static_cast(c)]); +} + +struct CaseCmp { + bool operator()(char lhs, char rhs) const { + return lowcase(lhs) == lowcase(rhs); + } +}; + +template +bool istarts_with(InputIterator1 first1, InputIterator1 last1, + InputIterator2 first2, InputIterator2 last2) { + if (last1 - first1 < last2 - first2) { + return false; + } + return std::equal(first2, last2, first1, CaseCmp()); +} + +template +bool istarts_with(const S &a, const T &b) { + return istarts_with(a.begin(), a.end(), b.begin(), b.end()); +} + +// make_cid_key returns the key for |cid|. +std::string_view make_cid_key(const ngtcp2_cid *cid); +inline ngtcp2_cid make_cid_key(std::span cid) { + assert(cid.size() <= NGTCP2_MAX_CIDLEN); + + ngtcp2_cid res; + + std::ranges::copy(cid, std::begin(res.data)); + res.datalen = cid.size(); + + return res; +} + +// straddr stringifies |sa| of length |salen| in a format "[IP]:PORT". +inline std::string straddr(const sockaddr *sa, socklen_t salen) { + std::array host; + std::array port; + + auto rv = getnameinfo(sa, salen, host.data(), host.size(), port.data(), + port.size(), NI_NUMERICHOST | NI_NUMERICSERV); + if (rv != 0) { + std::cerr << "getnameinfo: " << gai_strerror(rv) << std::endl; + return ""; + } + std::string res = "["; + res.append(host.data(), strlen(host.data())); + res += "]:"; + res.append(port.data(), strlen(port.data())); + return res; +} + +constexpr char B64_CHARS[] = { + 'A', + 'B', + 'C', + 'D', + 'E', + 'F', + 'G', + 'H', + 'I', + 'J', + 'K', + 'L', + 'M', + 'N', + 'O', + 'P', + 'Q', + 'R', + 'S', + 'T', + 'U', + 'V', + 'W', + 'X', + 'Y', + 'Z', + 'a', + 'b', + 'c', + 'd', + 'e', + 'f', + 'g', + 'h', + 'i', + 'j', + 'k', + 'l', + 'm', + 'n', + 'o', + 'p', + 'q', + 'r', + 's', + 't', + 'u', + 'v', + 'w', + 'x', + 'y', + 'z', + '0', + '1', + '2', + '3', + '4', + '5', + '6', + '7', + '8', + '9', + '+', + '/', +}; + +template +std::string b64encode(InputIt first, InputIt last) { + std::string res; + size_t len = last - first; + if (len == 0) { + return res; + } + size_t r = len % 3; + res.resize((len + 2) / 3 * 4); + auto j = last - r; + auto p = std::begin(res); + while (first != j) { + uint32_t n = static_cast(*first++) << 16; + n += static_cast(*first++) << 8; + n += static_cast(*first++); + *p++ = B64_CHARS[n >> 18]; + *p++ = B64_CHARS[(n >> 12) & 0x3fu]; + *p++ = B64_CHARS[(n >> 6) & 0x3fu]; + *p++ = B64_CHARS[n & 0x3fu]; + } + + if (r == 2) { + uint32_t n = static_cast(*first++) << 16; + n += static_cast(*first++) << 8; + *p++ = B64_CHARS[n >> 18]; + *p++ = B64_CHARS[(n >> 12) & 0x3fu]; + *p++ = B64_CHARS[(n >> 6) & 0x3fu]; + *p++ = '='; + } else if (r == 1) { + uint32_t n = static_cast(*first++) << 16; + *p++ = B64_CHARS[n >> 18]; + *p++ = B64_CHARS[(n >> 12) & 0x3fu]; + *p++ = '='; + *p++ = '='; + } + return res; +} + +// format_uint converts |n| into string. +template +std::string format_uint(T n) { + if (n == 0) { + return "0"; + } + size_t nlen = 0; + for (auto t = n; t; t /= 10, ++nlen); + std::string res(nlen, '\0'); + for (; n; n /= 10) { + res[--nlen] = (n % 10) + '0'; + } + return res; +} + +// format_uint_iec converts |n| into string with the IEC unit (either +// "G", "M", or "K"). It chooses the largest unit which does not drop +// precision. +template +std::string format_uint_iec(T n) { + if (n >= (1 << 30) && (n & ((1 << 30) - 1)) == 0) { + return format_uint(n / (1 << 30)) + 'G'; + } + if (n >= (1 << 20) && (n & ((1 << 20) - 1)) == 0) { + return format_uint(n / (1 << 20)) + 'M'; + } + if (n >= (1 << 10) && (n & ((1 << 10) - 1)) == 0) { + return format_uint(n / (1 << 10)) + 'K'; + } + return format_uint(n); +} + +// generate_secure_random generates a cryptographically secure pseudo +// random data of |data|. +inline int generate_secure_random(std::span data) { + if (RAND_bytes(data.data(), static_cast(data.size())) != 1) { + return -1; + } + + return 0; +} + +// generate_secret generates secret and writes it to |secret|. +// Currently, |secret| must be 32 bytes long. +inline int generate_secret(std::span secret) { + std::array rand; + + if (generate_secure_random(rand) != 0) { + return -1; + } + + auto ctx = EVP_MD_CTX_new(); + if (ctx == nullptr) { + return -1; + } + + auto ctx_deleter = defer(EVP_MD_CTX_free, ctx); + + unsigned int mdlen = static_cast(secret.size()); + if (!EVP_DigestInit_ex(ctx, EVP_sha256(), nullptr) || !EVP_DigestUpdate(ctx, rand.data(), rand.size()) || !EVP_DigestFinal_ex(ctx, secret.data(), &mdlen)) { + return -1; + } + + return 0; +} + +constexpr bool is_digit(const char c) { return '0' <= c && c <= '9'; } + +constexpr bool is_hex_digit(const char c) { + return is_digit(c) || ('A' <= c && c <= 'F') || ('a' <= c && c <= 'f'); +} + +// Returns integer corresponding to hex notation |c|. If +// is_hex_digit(c) is false, it returns 256. +constexpr uint32_t hex_to_uint(char c) { + if (c <= '9') { + return static_cast(c - '0'); + } + if (c <= 'Z') { + return static_cast(c - 'A' + 10); + } + if (c <= 'z') { + return static_cast(c - 'a' + 10); + } + return 256; +} + +template +std::string percent_decode(InputIt first, InputIt last) { + std::string result; + result.resize(last - first); + auto p = std::begin(result); + for (; first != last; ++first) { + if (*first != '%') { + *p++ = *first; + continue; + } + + if (first + 1 != last && first + 2 != last && is_hex_digit(*(first + 1)) && is_hex_digit(*(first + 2))) { + *p++ = (hex_to_uint(*(first + 1)) << 4) + hex_to_uint(*(first + 2)); + first += 2; + continue; + } + + *p++ = *first; + } + result.resize(static_cast(p - std::begin(result))); + return result; +} + +inline int create_nonblock_socket(int domain, int type, int protocol) { +#ifdef SOCK_NONBLOCK + auto fd = socket(domain, type | SOCK_NONBLOCK, protocol); + if (fd == -1) { + return -1; + } +#else // !defined(SOCK_NONBLOCK) + auto fd = socket(domain, type, protocol); + if (fd == -1) { + return -1; + } + + make_socket_nonblocking(fd); +#endif // !defined(SOCK_NONBLOCK) + + return fd; +} + +} // namespace opencmw::majordomo::detail::rest + +namespace std { +template<> +struct hash { + hash() { + auto rv = opencmw::majordomo::detail::rest::generate_secure_random(opencmw::majordomo::detail::rest::as_writable_uint8_span(std::span(key))); + std::ignore = rv; + assert(rv == 0 && "Failed to generate hash key"); + } + + std::size_t operator()(const ngtcp2_cid &cid) const noexcept { + return static_cast(opencmw::majordomo::detail::rest::siphash24(key, { cid.data, cid.datalen })); + } + + std::array key; +}; +} // namespace std + +inline bool operator==(const ngtcp2_cid &lhs, const ngtcp2_cid &rhs) { + return ngtcp2_cid_eq(&lhs, &rhs); +} + +#endif // UTIL_H diff --git a/src/majordomo/include/majordomo/Rest.hpp b/src/majordomo/include/majordomo/Rest.hpp new file mode 100644 index 00000000..a97165bd --- /dev/null +++ b/src/majordomo/include/majordomo/Rest.hpp @@ -0,0 +1,272 @@ +#ifndef OPENCMW_MAJORDOMO_REST_HPP +#define OPENCMW_MAJORDOMO_REST_HPP + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace opencmw::majordomo::rest { + +struct Request { + std::string method; + std::string path; + std::vector> headers; +}; + +struct Response { + struct ReadStatus { + std::size_t bytesWritten; + bool hasMore; + }; + using WriterFunction = std::function(std::span)>; + int code; + std::vector> headers; + // To provide a body, set one of the following: + WriterFunction bodyReader; + IoBuffer body; //< owned data + // Data for the body not owned by the view. Handler must ensure lifetime + // (Note that the IoBuffer API allows non-owning cases, but they fail when the buffer is moved) + std::span bodyView; +}; + +struct Handler { + std::string method; + std::string path; + std::function handler; +}; + +inline auto mimeTypeFromExtension(std::string_view path) { + if (path.ends_with(".html")) { + return "text/html"; + } + if (path.ends_with(".css")) { + return "text/css"; + } + if (path.ends_with(".js")) { + return "application/javascript"; + } + if (path.ends_with(".png")) { + return "image/png"; + } + if (path.ends_with(".jpg") || path.ends_with(".jpeg")) { + return "image/jpeg"; + } + if (path.ends_with(".wasm")) { + return "application/wasm"; + } + return "text/plain"; +} + +inline auto cmrcHandler(std::string path, std::string prefix, std::shared_ptr vfs, std::string vprefix) { + return Handler{ + .method = "GET", + .path = path, + .handler = [path, prefix, vfs, vprefix](const Request &request) -> Response { + try { + auto p = std::string_view{ request.path }; + if (p.starts_with(prefix)) { + p.remove_prefix(prefix.size()); + } + while (p.starts_with("/")) { + p.remove_prefix(1); + } + auto file = vfs->open(vprefix + std::string{ p }); + Response response; + response.code = 200; + response.headers.emplace_back("content-type", mimeTypeFromExtension(p)); + response.headers.emplace_back("content-length", std::to_string(file.size())); + response.body = IoBuffer(file.begin(), file.size()); + return response; + } catch (...) { + Response response; + response.code = 404; + response.headers.emplace_back("content-type", "text/plain"); + response.body = IoBuffer("Not found"); + return response; + } + } + }; +} + +namespace detail { +struct MappedFile { + int _fd = -1; + size_t _size = 0; + std::uint8_t *_data = nullptr; + + static std::expected map(const std::filesystem::path &path) { + MappedFile f; + f._fd = open(path.c_str(), O_RDONLY); + if (f._fd == -1) { + return std::unexpected(strerror(errno)); + } + f._size = std::filesystem::file_size(path); + f._data = static_cast(mmap(nullptr, f._size, PROT_READ, MAP_PRIVATE, f._fd, 0)); + if (f._data == MAP_FAILED) { + return std::unexpected(strerror(errno)); + } + return f; + } + + MappedFile() = default; + + ~MappedFile() { + if (_data && _data != MAP_FAILED) { + munmap(_data, _size); + } + if (_fd != -1) { + ::close(_fd); + } + } + + MappedFile(const MappedFile &) = delete; + MappedFile &operator=(const MappedFile &) = delete; + MappedFile(MappedFile &&other) noexcept { + if (_fd != -1) { + munmap(_data, _size); + ::close(_fd); + } + _fd = std::exchange(other._fd, -1); + _size = std::exchange(other._size, 0); + _data = std::exchange(other._data, nullptr); + } + + MappedFile &operator=(MappedFile &&other) noexcept { + if (this != &other) { + if (_fd != -1) { + munmap(_data, _size); + ::close(_fd); + } + _fd = std::exchange(other._fd, -1); + _size = std::exchange(other._size, 0); + _data = std::exchange(other._data, nullptr); + } + return *this; + } + + std::span view() const { + return { _data, _size }; + } + + size_t size() const { + return _size; + } +}; + +} // namespace detail + +inline auto fileSystemHandler(std::string path, std::string prefix, std::filesystem::path root, std::vector> extraHeaders = {}, std::size_t mmapThreshold = 1024 * 1024 * 100) { + return Handler{ + .method = "GET", + .path = path, + .handler = [mappedFiles = std::make_shared>(), root, path, prefix, extraHeaders = std::move(extraHeaders), mmapThreshold](const Request &request) -> Response { + auto p = std::string_view{ request.path }; + if (p.starts_with(prefix)) { + p.remove_prefix(prefix.size()); + } + while (p.starts_with("/")) { + p.remove_prefix(1); + } + auto file = root / p; + if (!std::filesystem::exists(file)) { + Response response; + response.code = 404; + response.headers.emplace_back("content-type", "text/plain"); + response.body = IoBuffer("Not found"); + return response; + } + + // mmap if file is large enough + if (std::filesystem::file_size(file) >= mmapThreshold) { + auto it = mappedFiles->find(file); + + if (it == mappedFiles->end()) { + auto mapped = detail::MappedFile::map(file); + if (!mapped) { + auto &error = mapped.error(); + Response response; + response.code = 500; + response.headers.emplace_back("content-type", "text/plain"); + response.headers.emplace_back("content-length", std::to_string(error.size())); + response.body = IoBuffer(error.data(), error.size()); + return response; + } + it = mappedFiles->emplace(file, std::move(mapped.value())).first; + } + + Response response; + response.code = 200; + response.headers.emplace_back("content-type", mimeTypeFromExtension(file.string())); + auto view = it->second.view(); + response.headers.emplace_back("content-length", std::to_string(view.size())); + response.bodyView = view; + return response; + } + + // otherwise use ifstream + try { + auto fileStream = std::make_shared(file, std::ios::binary); + + if (!fileStream->is_open()) { + Response response; + response.code = 500; + response.headers.emplace_back("content-type", "text/plain"); + response.body = IoBuffer("Internal Server Error"); + return response; + } + + Response response; + response.code = 200; + response.headers = extraHeaders; + response.headers.emplace_back("content-type", mimeTypeFromExtension(request.path)); + response.headers.emplace_back("content-length", std::to_string(std::filesystem::file_size(file))); + response.bodyReader = [fileStream = std::move(fileStream)](std::span buffer) -> std::expected { + fileStream->read(reinterpret_cast(buffer.data()), static_cast(buffer.size())); + if (fileStream->bad()) { + return std::unexpected(std::format("Failed to read file: {}", strerror(errno))); + } + return Response::ReadStatus{ static_cast(fileStream->gcount()), !fileStream->eof() }; + }; + return response; + } catch (...) { + Response response; + response.code = 500; + response.headers.emplace_back("content-type", "text/plain"); + response.body = IoBuffer("Internal Server Error"); + return response; + } + } + }; +} + +enum Protocol { + Http2 = 0x1, + Http3 = 0x2, +}; + +struct Settings { + uint16_t port = 8080; + std::filesystem::path certificateFilePath; + std::filesystem::path keyFilePath; + std::string certificateFileBuffer; + std::string keyFileBuffer; + std::string dnsAddress; + std::vector handlers; + int protocols = Protocol::Http2 | Protocol::Http3; +}; + +} // namespace opencmw::majordomo::rest + +#endif diff --git a/src/majordomo/include/majordomo/RestBackend.hpp b/src/majordomo/include/majordomo/RestBackend.hpp deleted file mode 100644 index ee599d87..00000000 --- a/src/majordomo/include/majordomo/RestBackend.hpp +++ /dev/null @@ -1,903 +0,0 @@ -#ifndef OPENCMW_MAJORDOMO_RESTBACKEND_P_H -#define OPENCMW_MAJORDOMO_RESTBACKEND_P_H - -// STD -#include -#include -#include -#include -#include -#include - -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wformat-nonliteral" -#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" -#pragma GCC diagnostic ignored "-Wold-style-cast" -#pragma GCC diagnostic ignored "-Wshadow" -#pragma GCC diagnostic ignored "-Wuninitialized" -#pragma GCC diagnostic ignored "-Wuseless-cast" -#undef CPPHTTPLIB_THREAD_POOL_COUNT -#define CPPHTTPLIB_THREAD_POOL_COUNT 8 -#include -#pragma GCC diagnostic pop - -// Core -#include -#include -#include - -#include -#include -#include -#include - -// Majordomo -#include -#include - -#include -#include -CMRC_DECLARE(assets); - -struct FormData { - std::unordered_map fields; -}; -ENABLE_REFLECTION_FOR(FormData, fields) - -struct Service { - std::string name; - std::string description; - - Service(std::string name_, std::string description_) - : name{ std::move(name_) }, description{ std::move(description_) } {} -}; -ENABLE_REFLECTION_FOR(Service, name, description) - -struct ServicesList { - std::vector services; -}; -ENABLE_REFLECTION_FOR(ServicesList, services) - -namespace opencmw::majordomo { - -using namespace std::chrono_literals; - -constexpr auto HTTP_OK = 200; -constexpr auto HTTP_ERROR = 500; -constexpr auto HTTP_GATEWAY_TIMEOUT = 504; -constexpr auto DEFAULT_REST_PORT = 8080; -constexpr auto UPDATER_POLLING_TIME = 1s; -constexpr auto LONG_POLL_SERVER_TIMEOUT = 30s; -constexpr auto UNUSED_SUBSCRIPTION_EXPIRATION_TIME = 30s; -constexpr std::size_t MAX_CACHED_REPLIES = 32; - -namespace detail { -// Provides a safe alternative to getenv -inline const char *getEnvFilenameOr(const char *field, const char *defaultValue) { - const char *result = ::getenv(field); - if (result == nullptr) { - result = defaultValue; - } - - if (!std::filesystem::exists(result)) { - throw opencmw::startup_error(std::format("File {} not found. Current path is {}", result, std::filesystem::current_path().string())); - } - return result; -} -} // namespace detail - -struct HTTPS { - constexpr static std::string_view DEFAULT_REST_SCHEME = "https"; - httplib::SSLServer _svr; - - HTTPS() - : _svr(detail::getEnvFilenameOr("OPENCMW_REST_CERT_FILE", "demo_public.crt"), detail::getEnvFilenameOr("OPENCMW_REST_PRIVATE_KEY_FILE", "demo_private.key")) {} -}; - -struct PLAIN_HTTP { - constexpr static std::string_view DEFAULT_REST_SCHEME = "https"; - httplib::Server _svr; -}; - -namespace detail { -using PollingIndex = std::uint64_t; - -using ReadLock = std::shared_lock; -using WriteLock = std::unique_lock; - -enum class RestMethod { - Get, - Subscribe, - LongPoll, - Post, - Invalid -}; - -inline std::string_view acceptedMimeForRequest(const auto &request) { - static constexpr std::array acceptableMimeTypes = { - MIME::JSON.typeName(), MIME::HTML.typeName(), MIME::BINARY.typeName() - }; - auto accepted = [](auto format) { - const auto it = std::find(acceptableMimeTypes.cbegin(), acceptableMimeTypes.cend(), format); - return std::make_pair( - it != acceptableMimeTypes.cend(), - it); - }; - - if (request.has_header("Content-Type")) { - std::string format = request.get_header_value("Content-Type"); - if (const auto [found, where] = accepted(format); found) { - return *where; - } - } - if (request.has_param("contentType")) { - std::string format = request.get_param_value("contentType"); - if (const auto [found, where] = accepted(format); found) { - return *where; - } - } - - auto isDelimiter = [](char c) { return c == ' ' || c == ','; }; - const auto &acceptHeader = request.get_header_value("Accept"); - auto from = acceptHeader.cbegin(); - const auto end = acceptHeader.cend(); - - while (from != end) { - from = std::find_if_not(from, end, isDelimiter); - auto to = std::find_if(from, end, isDelimiter); - if (from != end) { - std::string_view format(from, to); - if (const auto [found, where] = accepted(format); found) { - return *where; - } - } - - from = to; - } - - return acceptableMimeTypes[0]; -} - -bool respondWithError(auto &response, std::string_view message, int status = HTTP_ERROR) { - response.status = status; - response.set_content(message.data(), MIME::TEXT.typeName().data()); - return true; -}; - -inline bool respondWithServicesList(auto &broker, const httplib::Request &request, httplib::Response &response) { - // Mmi is not a MajordomoWorker, so it doesn't know JSON (TODO) - const auto acceptedFormat = acceptedMimeForRequest(request); - - if (acceptedFormat == MIME::JSON.typeName()) { - std::vector serviceNames; - // TODO: Should this be synchronized? - broker.forEachService([&](std::string_view name, std::string_view) { - serviceNames.emplace_back(name); - }); - response.status = HTTP_OK; - - opencmw::IoBuffer buffer; - // opencmw::serialise(buffer, serviceNames); - IoSerialiser::serialise(buffer, FieldDescriptionShort{}, serviceNames); - response.set_content(buffer.asString().data(), MIME::JSON.typeName().data()); - return true; - - } else if (acceptedFormat == MIME::HTML.typeName()) { - response.set_chunked_content_provider( - MIME::HTML.typeName().data(), - [&broker](std::size_t /*offset*/, httplib::DataSink &sink) { - ServicesList servicesList; - broker.forEachService([&](std::string_view name, std::string_view description) { - servicesList.services.emplace_back(std::string(name), std::string(description)); - }); - - // sort services, move mmi. services to the end - auto serviceLessThan = [](const auto &lhs, const auto &rhs) { - const auto lhsIsMmi = lhs.name.starts_with("/mmi."); - const auto rhsIsMmi = rhs.name.starts_with("/mmi."); - if (lhsIsMmi != rhsIsMmi) { - return rhsIsMmi; - } - return lhs.name < rhs.name; - }; - std::sort(servicesList.services.begin(), servicesList.services.end(), serviceLessThan); - - using namespace std::string_literals; - mustache::serialise(cmrc::assets::get_filesystem(), "ServicesList", sink.os, - std::pair{ "servicesList"s, servicesList }); - sink.done(); - return true; - }); - return true; - - } else { - return respondWithError(response, "Requested an unsupported response type"); - } -} - -struct Connection { - zmq::Socket notificationSubscriptionSocket; - zmq::Socket requestResponseSocket; - std::string subscriptionKey; - - using Timestamp = std::chrono::time_point; - std::atomic lastUsed = std::chrono::system_clock::now(); - -private: - mutable std::shared_mutex _cachedRepliesMutex; - std::deque _cachedReplies; // Ring buffer? - PollingIndex _nextPollingIndex = 0; - std::condition_variable_any _pollingIndexCV; - - std::atomic_int _refCount = 1; - - // Here be dragons! This is not to be used after the connection was involved in any threading code - Connection(Connection &&other) noexcept - : notificationSubscriptionSocket(std::move(other.notificationSubscriptionSocket)) - , requestResponseSocket(std::move(other.requestResponseSocket)) - , subscriptionKey(std::move(other.subscriptionKey)) - , lastUsed(other.lastUsed.load()) - , _cachedReplies(std::move(other._cachedReplies)) - , _nextPollingIndex(other._nextPollingIndex) { - } - -public: - Connection(const zmq::Context &context, std::string _subscriptionKey) - : notificationSubscriptionSocket(context, ZMQ_SUB) - , requestResponseSocket(context, ZMQ_DEALER) - , subscriptionKey(std::move(_subscriptionKey)) {} - - Connection(const Connection &other) = delete; - Connection &operator=(const Connection &) = delete; - Connection &operator=(Connection &&) = delete; - - // Here be dragons! This is not to be used after the connection was involved in any threading code - Connection unsafeMove() && { - return std::move(*this); - } - - const auto &referenceCount() const { return _refCount; } - - // Functions to block connection deletion - void increaseReferenceCount() { ++_refCount; } - void decreaseReferenceCount() { --_refCount; } - - struct KeepAlive { - Connection *_connection; - KeepAlive(Connection *connection) - : _connection(connection) { - _connection->increaseReferenceCount(); - } - - ~KeepAlive() { - _connection->decreaseReferenceCount(); - } - }; - - auto writeLock() { - return WriteLock(_cachedRepliesMutex); - } - - auto readLock() const { - return ReadLock(_cachedRepliesMutex); - } - - bool waitForUpdate(std::chrono::milliseconds timeout) { - // This could also periodically check for the client connection being dropped (e.g. due to client-side timeout) if cpp-httplib had API for that. - auto temporaryLock = writeLock(); - const auto next = _nextPollingIndex; - while (_nextPollingIndex == next) { - if (_pollingIndexCV.wait_for(temporaryLock, timeout) == std::cv_status::timeout) { - return false; - } - } - - return true; - } - - std::size_t cachedRepliesSize(ReadLock & /*lock*/) const { - return _cachedReplies.size(); - } - - std::string cachedReply(ReadLock & /*lock*/, PollingIndex index) const { - const auto firstCachedIndex = _nextPollingIndex - _cachedReplies.size(); - return (index >= firstCachedIndex && index < _nextPollingIndex) ? _cachedReplies[index - firstCachedIndex] : std::string{}; - } - - PollingIndex nextPollingIndex(ReadLock & /*lock*/) const { - return _nextPollingIndex; - } - - void addCachedReply(std::unique_lock & /*lock*/, std::string reply) { - _cachedReplies.push_back(std::move(reply)); - if (_cachedReplies.size() > MAX_CACHED_REPLIES) { - _cachedReplies.erase(_cachedReplies.begin(), _cachedReplies.begin() + long(_cachedReplies.size() - MAX_CACHED_REPLIES)); - } - - _nextPollingIndex++; - _pollingIndexCV.notify_all(); - } -}; - -} // namespace detail - -template -class RestBackend : public Mode { -protected: - Broker &_broker; - const VirtualFS &_vfs; - URI<> _restAddress; - std::atomic _majordomoTimeout = 30000ms; - -private: - std::jthread _mdpConnectionUpdaterThread; - std::shared_mutex _mdpConnectionsMutex; - std::map> _mdpConnectionForService; - -public: - /** - * Timeout used for interaction with majordomo workers, i.e. the time to wait - * for notifications on subscriptions (long-polling) and for responses to Get/Set - * requests. - */ - void setMajordomoTimeout(std::chrono::milliseconds timeout) { - _majordomoTimeout = timeout; - } - - std::chrono::milliseconds majordomoTimeout() const { - return _majordomoTimeout; - } - - using BrokerType = Broker; - // returns a connection with refcount 1. Make sure you lower it to zero at some point - detail::Connection *notificationSubscriptionConnectionFor(const std::string &zmqTopic) { - detail::WriteLock lock(_mdpConnectionsMutex); - // TODO: No need to find + emplace as separate steps - if (auto it = _mdpConnectionForService.find(zmqTopic); it != _mdpConnectionForService.end()) { - auto *connection = it->second.get(); - connection->increaseReferenceCount(); - return connection; - } - - auto [it, inserted] = _mdpConnectionForService.emplace(std::piecewise_construct, - std::forward_as_tuple(zmqTopic), - std::forward_as_tuple(std::make_unique(_broker.context, zmqTopic))); - - if (!inserted) { - assert(inserted); - std::terminate(); - } - - auto *connection = it->second.get(); - - zmq::invoke(zmq_connect, connection->notificationSubscriptionSocket, INTERNAL_ADDRESS_PUBLISHER.str()).template onFailure("Can not connect REST worker to Majordomo broker"); - zmq::invoke(zmq_setsockopt, connection->notificationSubscriptionSocket, ZMQ_SUBSCRIBE, zmqTopic.data(), zmqTopic.size()).assertSuccess(); - - return connection; - } - - // Starts the thread to keep the unused subscriptions alive - void startUpdaterThread() { - _mdpConnectionUpdaterThread = std::jthread([this](const std::stop_token &stopToken) { - thread::setThreadName("RestBackend updater thread"); - - std::vector connections; - std::vector pollItems; - while (!stopToken.stop_requested()) { - std::list keep; - { - // This is a long lock, alternatively, message reading could have separate locks per connection - detail::WriteLock lock(_mdpConnectionsMutex); - - // Expired subscriptions cleanup - std::vector expiredSubscriptions; - for (auto &[subscriptionKey, connection] : _mdpConnectionForService) { - // std::print("Reference count is {}\n", connection->referenceCount()); - if (connection->referenceCount() == 0) { - auto connectionLock = connection->writeLock(); - if (connection->referenceCount() != 0) { - continue; - } - if (std::chrono::system_clock::now() - connection->lastUsed.load() > UNUSED_SUBSCRIPTION_EXPIRATION_TIME) { - expiredSubscriptions.push_back(subscriptionKey); - } - } - } - for (const auto &subscriptionKey : expiredSubscriptions) { - _mdpConnectionForService.erase(subscriptionKey); - } - - // setup poller and socket data structures for all connections - const std::size_t connectionCount = _mdpConnectionForService.size(); - connections.resize(connectionCount); - pollItems.resize(connectionCount); - for (std::size_t i = 0UZ; auto &[key, connection] : _mdpConnectionForService) { - connections[i] = connection.get(); - keep.emplace_back(connection.get()); - pollItems[i].events = ZMQ_POLLIN; - pollItems[i].socket = connection->notificationSubscriptionSocket.zmq_ptr; - ++i; - } - } // finished copying local state, keep ensures that connections are kept alive, end of lock on _mdpConnectionsForService - - if (pollItems.empty()) { - std::this_thread::sleep_for(100ms); // prevent spinning on connection cleanup if there are no connections to poll on - continue; - } - - auto pollCount = zmq::invoke(zmq_poll, pollItems.data(), static_cast(pollItems.size()), std::chrono::duration_cast(UPDATER_POLLING_TIME).count()); - if (!pollCount) { - std::print("Error while polling for updates from the broker\n"); - std::terminate(); - } - if (pollCount.value() == 0) { - continue; - } - - // Reading messages - for (std::size_t i = 0; i < connections.size(); ++i) { - if (pollItems[i].revents & ZMQ_POLLIN) { - detail::Connection *currentConnection = connections[i]; - std::unique_lock connectionLock = currentConnection->writeLock(); - while (auto responseMessage = zmq::receive(currentConnection->notificationSubscriptionSocket)) { - currentConnection->addCachedReply(connectionLock, std::string(responseMessage->data.asString())); - } - } - } - } - }); - } - - struct RestWorker; - - RestWorker &workerForCurrentThread() { - thread_local static RestWorker worker(*this); - return worker; - } - - using Mode::_svr; - using Mode::DEFAULT_REST_SCHEME; - -public: - explicit RestBackend(Broker &broker, const VirtualFS &vfs, URI<> restAddress = URI<>::factory().scheme(DEFAULT_REST_SCHEME).hostName("0.0.0.0").port(DEFAULT_REST_PORT).build()) - : _broker(broker), _vfs(vfs), _restAddress(restAddress) { - _broker.registerDnsAddress(restAddress); - } - - virtual ~RestBackend() { - _svr.stop(); - // shutdown thread before _connectionForService is destroyed - _mdpConnectionUpdaterThread.request_stop(); - _mdpConnectionUpdaterThread.join(); - } - - auto handleServiceRequest(const httplib::Request &request, httplib::Response &response, const httplib::ContentReader *content_reader_ = nullptr) { - using detail::RestMethod; - - auto convertParams = [](const httplib::Params ¶ms) { - mdp::Topic::Params r; - for (const auto &[key, value] : params) { - if (key == "LongPollingIdx" || key == "SubscriptionContext") { - continue; - } - if (value.empty()) { - r[key] = std::nullopt; - } else { - r[key] = value; - } - } - return r; - }; - - std::optional maybeTopic; - - try { - maybeTopic = mdp::Topic::fromString(request.path, convertParams(request.params)); - } catch (const std::exception &e) { - return detail::respondWithError(response, std::format("Error: {}\n", e.what())); - } - auto topic = std::move(*maybeTopic); - - auto restMethod = [&] { - auto methodString = request.has_header("X-OPENCMW-METHOD") ? request.get_header_value("X-OPENCMW-METHOD") : request.method; - // clang-format off - return methodString == "SUB" ? RestMethod::Subscribe : - methodString == "POLL" ? RestMethod::LongPoll : - methodString == "PUT" ? RestMethod::Post : - methodString == "POST" ? RestMethod::Post : - methodString == "GET" ? RestMethod::Get : - RestMethod::Invalid; - // clang-format on - }(); - - for (const auto &[key, value] : request.params) { - if (key == "LongPollingIdx") { - // This parameter is not passed on, it just means we want to use long polling - restMethod = value == "Subscription" ? RestMethod::Subscribe : RestMethod::LongPoll; - } else if (key == "SubscriptionContext") { - topic = mdp::Topic::fromString(value, {}); // params are parsed from value - } - } - - if (restMethod == RestMethod::Invalid) { - return detail::respondWithError(response, "Error: Requested method is not supported\n"); - } - - auto &worker = workerForCurrentThread(); - - switch (restMethod) { - case RestMethod::Get: - case RestMethod::Post: - return worker.respondWithGetSet(request, response, topic, restMethod, content_reader_); - case RestMethod::LongPoll: - return worker.respondWithLongPoll(request, response, topic); - case RestMethod::Subscribe: - return worker.respondWithSubscription(response, topic); - default: - // std::unreachable() is C++23 - assert(false && "We have already checked that restMethod is not Invalid"); - return false; - } - } - - virtual void registerHandlers() { - _svr.Get("/", [this](const httplib::Request &request, httplib::Response &response) { - return detail::respondWithServicesList(_broker, request, response); - }); - - _svr.Options(".*", - [](const httplib::Request & /*req*/, httplib::Response &res) { - res.set_header("Allow", "GET, POST, PUT, OPTIONS"); - res.set_header("Access-Control-Allow-Origin", "*"); - res.set_header("Access-Control-Allow-Methods", "GET, POST, PUT, OPTIONS"); - res.set_header("Access-Control-Allow-Headers", "X-OPENCMW-METHOD,Content-Type"); - res.set_header("Access-Control-Max-Age", "86400"); - }); - - static const char *nonEmptyPath = "..*"; - _svr.Get(nonEmptyPath, [this](const httplib::Request &request, httplib::Response &response) { - return handleServiceRequest(request, response, nullptr); - }); - _svr.Post(nonEmptyPath, [this](const httplib::Request &request, httplib::Response &response, const httplib::ContentReader &content_reader) { - return handleServiceRequest(request, response, &content_reader); - }); - _svr.Put(nonEmptyPath, [this](const httplib::Request &request, httplib::Response &response, const httplib::ContentReader &content_reader) { - return handleServiceRequest(request, response, &content_reader); - }); - } - - void run() { - thread::setThreadName("RestBackend thread"); - - startUpdaterThread(); - - registerHandlers(); - - if (!_restAddress.hostName() || !_restAddress.port()) { - throw opencmw::startup_error(std::format("REST server URI is not valid {}", _restAddress.str())); - } - - _svr.set_tcp_nodelay(true); - _svr.set_keep_alive_max_count(1000); - _svr.set_read_timeout(1, 0); - _svr.set_write_timeout(1, 0); - _svr.new_task_queue = []() { return new httplib::ThreadPool(/*num_threads=*/32, /*max_queued_requests=*/20); }; - - bool listening = _svr.listen(_restAddress.hostName().value().data(), _restAddress.port().value()); - if (!listening) { - throw opencmw::startup_error(std::format("Can not start REST server on {}:{}", _restAddress.hostName().value().data(), _restAddress.port().value())); - } - } - - void requestStop() { - _svr.stop(); - } - - void shutdown() { - requestStop(); - } -}; - -template -struct RestBackend::RestWorker { - RestBackend &restBackend; - - zmq_pollitem_t pollItem{}; - - explicit RestWorker(RestBackend &rest) - : restBackend(rest) { - } - - RestWorker(RestWorker &&other) noexcept = default; - - detail::Connection connect() { - detail::Connection connection(restBackend._broker.context, {}); - pollItem.events = ZMQ_POLLIN; - - zmq::invoke(zmq_connect, connection.notificationSubscriptionSocket, INTERNAL_ADDRESS_PUBLISHER.str()).template onFailure("Can not connect REST worker to Majordomo broker"); - zmq::invoke(zmq_connect, connection.requestResponseSocket, INTERNAL_ADDRESS_BROKER.str()).template onFailure("Can not connect REST worker to Majordomo broker"); - - return std::move(connection).unsafeMove(); - } - - bool respondWithGetSet(const httplib::Request &request, httplib::Response &response, mdp::Topic topic, detail::RestMethod restMethod, const httplib::ContentReader *content_reader_ = nullptr) { - const mdp::Command mdpMessageCommand = restMethod == detail::RestMethod::Post ? mdp::Command::Set : mdp::Command::Get; - - auto uri = URI<>::factory(); - std::string bodyOverride; - std::string contentType; - int contentLength{ 0 }; - for (const auto &[key, value] : request.params) { - if (key == "_bodyOverride") { - bodyOverride = value; - } else if (key == "LongPollingIdx") { - // This parameter is not passed on, it just means we want to use long polling -- already handled - } else { - uri = std::move(uri).addQueryParameter(key, value); - } - } - - for (const auto &[key, value] : request.headers) { - if (httplib::detail::case_ignore::equal(key, "Content-Length")) { - contentLength = std::stoi(value); - } else if (httplib::detail::case_ignore::equal(key, "Content-Type")) { - contentType = value; - } - } - - mdp::Message message; - message.protocolName = mdp::clientProtocol; - message.command = mdpMessageCommand; - - const auto acceptedFormat = detail::acceptedMimeForRequest(request); - topic.addParam("contentType", acceptedFormat); - message.serviceName = std::string(topic.service()); - message.topic = topic.toMdpTopic(); - - if (request.is_multipart_form_data()) { - if (content_reader_ != nullptr) { - const auto &content_reader = *content_reader_; - - FormData formData; - auto it = formData.fields.begin(); - - content_reader( - [&](const httplib::MultipartFormData &file) { - it = formData.fields.emplace(file.name, "").first; - return true; - }, - [&](const char *data, std::size_t data_length) { - // TODO: This should append to content - it->second.append(std::string_view(data, data_length)); - - return true; - }); - - opencmw::IoBuffer buffer; - // opencmw::serialise(buffer, formData); - IoSerialiser::serialise(buffer, FieldDescriptionShort{}, formData.fields); - - std::string requestData(buffer.asString()); - // Json serialiser (rightfully) does not like bool values in string - static auto replacerRegex = std::regex(R"regex("opencmw_unquoted_value[(](.*)[)]")regex"); - requestData = std::regex_replace(requestData, replacerRegex, "$1"); - auto requestDataBegin = std::find(requestData.cbegin(), requestData.cend(), '{'); - - const auto req = std::string_view(requestDataBegin, requestData.cend()); - message.data = IoBuffer(req.data(), req.size()); - } - } else if (!bodyOverride.empty()) { - message.data = IoBuffer(bodyOverride.data(), bodyOverride.size()); - } else if (!request.body.empty()) { - message.data = IoBuffer(request.body.data(), request.body.size()); - } else if (contentType != "" && contentLength > 0 && content_reader_ != nullptr) { - std::string body; - (*content_reader_)([&body](const char *data, size_t datalength) { body = std::string{data, datalength}; return true; }); - message.data = IoBuffer(body.data(), body.size()); - } - - auto connection = connect(); - - if (!zmq::send(std::move(message), connection.requestResponseSocket)) { - return detail::respondWithError(response, "Error: Failed to send a message to the broker\n"); - } - - // blocks waiting for the response - pollItem.socket = connection.requestResponseSocket.zmq_ptr; - auto pollResult = zmq::invoke(zmq_poll, &pollItem, 1, std::chrono::duration_cast(restBackend.majordomoTimeout()).count()); - if (!pollResult || pollResult.value() == 0) { - detail::respondWithError(response, "Error: No response from broker\n", HTTP_GATEWAY_TIMEOUT); - } else if (auto responseMessage = zmq::receive(connection.requestResponseSocket); !responseMessage) { - detail::respondWithError(response, "Error: Empty response from broker\n"); - } else if (!responseMessage->error.empty()) { - detail::respondWithError(response, responseMessage->error); - } else { - response.status = HTTP_OK; - - response.set_header("X-OPENCMW-TOPIC", responseMessage->topic.str().data()); - response.set_header("X-OPENCMW-SERVICE-NAME", responseMessage->serviceName.data()); - response.set_header("Access-Control-Allow-Origin", "*"); - response.set_header("X-TIMESTAMP", std::format("{}", std::chrono::duration_cast(std::chrono::high_resolution_clock::now().time_since_epoch()).count())); - const auto data = responseMessage->data.asString(); - - if (request.method != "GET") { - response.set_content(data.data(), data.size(), MIME::TEXT.typeName().data()); - } else { - response.set_content(data.data(), data.size(), acceptedFormat.data()); - } - } - return true; - } - - bool respondWithSubscription(httplib::Response &response, const mdp::Topic &subscription) { - const auto subscriptionKey = subscription.toZmqTopic(); - auto *connection = restBackend.notificationSubscriptionConnectionFor(subscriptionKey); - assert(connection); - const auto majordomoTimeout = restBackend.majordomoTimeout(); - response.set_header("Access-Control-Allow-Origin", "*"); - response.set_header("X-TIMESTAMP", std::format("{}", std::chrono::duration_cast(std::chrono::high_resolution_clock::now().time_since_epoch()).count())); - - response.set_chunked_content_provider( - "application/json", - [connection, majordomoTimeout](std::size_t /*offset*/, httplib::DataSink &sink) mutable { - std::cerr << "Chunked reply...\n"; - - if (!connection->waitForUpdate(majordomoTimeout)) { - return false; - } - - auto connectionCacheLock = connection->readLock(); - auto lastIndex = connection->nextPollingIndex(connectionCacheLock) - 1; - const auto &lastReply = connection->cachedReply(connectionCacheLock, lastIndex); - std::cerr << "Chunk: " << lastIndex << "'" << lastReply << "'\n"; - - sink.os << lastReply << "\n\n"; - - return true; - }, - [connection](bool) { - connection->decreaseReferenceCount(); - }); - - return true; - } - - bool respondWithLongPollRedirect(const httplib::Request &request, httplib::Response &response, const mdp::Topic &subscription, detail::PollingIndex redirectLongPollingIdx) { - auto uri = URI<>::factory() - .path(request.path) - .addQueryParameter("LongPollingIdx", std::to_string(redirectLongPollingIdx)); - - // copy over the original query parameters - addParameters(request, uri); - - const auto redirect = uri.toString(); - response.set_redirect(redirect); - return true; - } - - bool respondWithLongPoll(const httplib::Request &request, httplib::Response &response, const mdp::Topic &subscription) { - // TODO: After the URIs are formalized, rethink service and topic - auto uri = URI<>::factory(); - addParameters(request, uri); - - const auto subscriptionKey = subscription.toZmqTopic(); - - const auto longPollingIdxIt = request.params.find("LongPollingIdx"); - if (longPollingIdxIt == request.params.end()) { - return detail::respondWithError(response, "Error: LongPollingIdx parameter not specified"); - } - - const auto &longPollingIdxParam = longPollingIdxIt->second; - - struct CacheInfo { - detail::PollingIndex firstCachedIndex = 0; - detail::PollingIndex nextPollingIndex = 0; - detail::Connection *connection = nullptr; - }; - auto fetchCache = [this, &subscriptionKey] { - std::shared_lock lock(restBackend._mdpConnectionsMutex); - auto &recycledConnectionForService = restBackend._mdpConnectionForService; - if (auto it = recycledConnectionForService.find(subscriptionKey); it != recycledConnectionForService.cend()) { - auto *connectionCache = it->second.get(); - detail::Connection::KeepAlive keep(connectionCache); - connectionCache->lastUsed = std::chrono::system_clock::now(); - auto connectionCacheLock = connectionCache->readLock(); - return CacheInfo{ - .firstCachedIndex = connectionCache->nextPollingIndex(connectionCacheLock) - connectionCache->cachedRepliesSize(connectionCacheLock), - .nextPollingIndex = connectionCache->nextPollingIndex(connectionCacheLock), - .connection = connectionCache - }; - } else { - // We didn't have this before, means 0 is the next index - return CacheInfo{ - .firstCachedIndex = 0, - .nextPollingIndex = 0, - .connection = nullptr - }; - } - }; - - detail::PollingIndex requestedLongPollingIdx = 0; - - // Hoping we already have the requested value in the cache. Holding this caches blocks all cache entries, so no further updates can be received or other connections initiated. - { - const auto cache = fetchCache(); - response.set_header("Access-Control-Allow-Origin", "*"); - - if (longPollingIdxParam == "Next") { - return respondWithLongPollRedirect(request, response, subscription, cache.nextPollingIndex); - } - - if (longPollingIdxParam == "Last") { - if (cache.connection != nullptr) { - return respondWithLongPollRedirect(request, response, subscription, cache.nextPollingIndex - 1); - } else { - return respondWithLongPollRedirect(request, response, subscription, cache.nextPollingIndex); - } - } - - if (longPollingIdxParam == "FirstAvailable") { - return respondWithLongPollRedirect(request, response, subscription, cache.firstCachedIndex); - } - - if (std::from_chars(longPollingIdxParam.data(), longPollingIdxParam.data() + longPollingIdxParam.size(), requestedLongPollingIdx).ec != std::errc{}) { - return detail::respondWithError(response, "Error: Invalid LongPollingIdx value"); - } - - if (requestedLongPollingIdx > cache.nextPollingIndex) { - return detail::respondWithError(response, "Error: LongPollingIdx tries to read the future"); - } - - if (requestedLongPollingIdx < cache.firstCachedIndex || requestedLongPollingIdx + 15 < cache.nextPollingIndex) { - return respondWithLongPollRedirect(request, response, subscription, cache.nextPollingIndex); - } - - if (cache.connection && requestedLongPollingIdx < cache.nextPollingIndex) { - auto connectionCacheLock = cache.connection->readLock(); - // The result is already ready - response.set_content(cache.connection->cachedReply(connectionCacheLock, requestedLongPollingIdx), MIME::BINARY.typeName().data()); - return true; - } - } - - // Fallback to creating a connection and waiting - auto *connection = restBackend.notificationSubscriptionConnectionFor(subscriptionKey); - assert(connection); - detail::Connection::KeepAlive keep(connection); - - // Since we use KeepAlive object, the initial refCount can go away - connection->decreaseReferenceCount(); - - if (!connection->waitForUpdate(restBackend.majordomoTimeout())) { - return detail::respondWithError(response, "Timeout waiting for update", HTTP_GATEWAY_TIMEOUT); - } - - const auto newCache = fetchCache(); - - // This time it needs to exist - assert(newCache.connection != nullptr); - - if (requestedLongPollingIdx >= newCache.firstCachedIndex && requestedLongPollingIdx < newCache.nextPollingIndex) { - auto connectionCacheLock = newCache.connection->readLock(); - response.set_content(newCache.connection->cachedReply(connectionCacheLock, requestedLongPollingIdx), MIME::BINARY.typeName().data()); - return true; - } else { - return detail::respondWithError(response, "Error: We waited for the new value, but it was not found"); - } - } - -private: - void addParameters(const httplib::Request &request, URI<>::UriFactory &uri) { - for (const auto &[key, value] : request.params) { - if (key == "LongPollingIdx") { - // This parameter is not passed on, it just means we want to use long polling -- already handled - } else { - uri = std::move(uri).addQueryParameter(key, value); - } - } - } -}; - -} // namespace opencmw::majordomo - -#endif // include guard diff --git a/src/majordomo/include/majordomo/RestServer.hpp b/src/majordomo/include/majordomo/RestServer.hpp new file mode 100644 index 00000000..3830bafd --- /dev/null +++ b/src/majordomo/include/majordomo/RestServer.hpp @@ -0,0 +1,3083 @@ +#ifndef OPENCMW_MAJORDOMO_RESTSERVER_HPP +#define OPENCMW_MAJORDOMO_RESTSERVER_HPP + +#include "IoBuffer.hpp" +#include "LoadTest.hpp" +#include "MdpMessage.hpp" +#include "MIME.hpp" +#include "NgTcp2Util.hpp" +#include "Rest.hpp" +#include "rest/RestUtils.hpp" +#include "TlsServerSession_Ossl.hpp" +#include "Topic.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include + +namespace opencmw::majordomo::detail::rest { + +#ifdef OPENCMW_DEBUG_HTTP +inline void print_debug2(const char *format, va_list args) { + vfprintf(stderr, format, args); +} + +inline void print_debug(void * /*user_data*/, const char *fmt, ...) { + va_list ap; + std::array buf; + + va_start(ap, fmt); + auto n = vsnprintf(buf.data(), buf.size(), fmt, ap); + va_end(ap); + + if (static_cast(n) >= buf.size()) { + n = buf.size() - 1; + } + + buf[n++] = '\n'; + + while (write(fileno(stderr), buf.data(), n) == -1 && errno == EINTR); +} +#endif + +constexpr size_t NGTCP2_SV_SCIDLEN = 18; + +constexpr size_t max_preferred_versionslen = 4; + +constexpr size_t NGTCP2_STATELESS_RESET_BURST = 100; + +// Endpoint is a local endpoint. +struct Endpoint { + Address addr; + int fd; +}; + +struct Buffer { + Buffer(const std::uint8_t *data, std::size_t datalen) + : buf{ data, data + datalen }, begin(buf.data()), tail(begin + datalen) {} + explicit Buffer(std::size_t datalen) + : buf(datalen), begin(buf.data()), tail(begin) {} + + std::size_t size() const { return static_cast(tail - begin); } + std::size_t left() const { return static_cast(buf.data() + buf.size() - tail); } + std::uint8_t *wpos() { return tail; } + std::span data() const { return { begin, size() }; } + void push(std::size_t len) { tail += len; } + void reset() { tail = begin; } + + std::vector buf; + // begin points to the beginning of the buffer. This might point to + // buf.data() if a buffer space is allocated by this object. It is + // also allowed to point to the external shared buffer. + std::uint8_t *begin; + // tail points to the position of the buffer where write should + // occur. + std::uint8_t *tail; +}; + +inline std::expected create_sock(Address &local_addr, std::string_view addr, std::string_view port, int family) { + addrinfo hints; + memset(&hints, 0, sizeof(hints)); + hints.ai_flags = AI_PASSIVE; + hints.ai_family = family; + hints.ai_socktype = SOCK_DGRAM; + + addrinfo *res, *rp; + int val = 1; + + auto paddr = addr == "*" ? nullptr : addr.data(); + + if (auto rv = getaddrinfo(paddr, port.data(), &hints, &res); rv != 0) { + return std::unexpected(std::format("getaddrinfo: {}", gai_strerror(rv))); + } + + auto res_d = defer(freeaddrinfo, res); + + int fd = -1; + + for (rp = res; rp; rp = rp->ai_next) { + fd = create_nonblock_socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); + if (fd == -1) { + continue; + } + + if (rp->ai_family == AF_INET6) { + if (setsockopt(fd, IPPROTO_IPV6, IPV6_V6ONLY, &val, static_cast(sizeof(val))) == -1) { + ::close(fd); + continue; + } + + if (setsockopt(fd, IPPROTO_IPV6, IPV6_RECVPKTINFO, &val, static_cast(sizeof(val))) == -1) { + ::close(fd); + continue; + } + } else if (setsockopt(fd, IPPROTO_IP, IP_PKTINFO, &val, static_cast(sizeof(val))) == -1) { + ::close(fd); + continue; + } + + if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &val, static_cast(sizeof(val))) == -1) { + ::close(fd); + continue; + } + + fd_set_recv_ecn(fd, rp->ai_family); + fd_set_ip_mtu_discover(fd, rp->ai_family); + fd_set_ip_dontfrag(fd, family); + + if (bind(fd, rp->ai_addr, rp->ai_addrlen) != -1) { + break; + } + + ::close(fd); + } + + if (!rp) { + return std::unexpected(std::format("Could not bind to address {}:{}", addr, port)); + } + + socklen_t len = sizeof(local_addr.su.storage); + if (getsockname(fd, &local_addr.su.sa, &len) == -1) { + ::close(fd); + return std::unexpected(std::format("getsockname: {}", strerror(errno))); + } + local_addr.len = len; + local_addr.ifindex = 0; + + return fd; +} + +inline uint32_t generate_reserved_version(const sockaddr *sa, socklen_t salen, + uint32_t version) { + uint32_t h = 0x811C9DC5u; + const uint8_t *p = reinterpret_cast(sa); + const uint8_t *ep = p + salen; + for (; p != ep; ++p) { + h ^= *p; + h *= 0x01000193u; + } + version = htonl(version); + p = reinterpret_cast(&version); + ep = p + sizeof(version); + for (; p != ep; ++p) { + h ^= *p; + h *= 0x01000193u; + } + h &= 0xf0f0f0f0u; + h |= 0x0a0a0a0au; + return h; +} + +using namespace opencmw::rest::detail; + +inline int alpn_select_proto_cb(SSL * /*ssl*/, const unsigned char **out, unsigned char *outlen, const unsigned char *in, unsigned int inlen, void * /*arg*/) { + static const unsigned char supported_alpns[] = { 2, 'h', '2', 2, 'h', '3' }; + + HTTP_DBG("Client-supported ALPN protocols:"); + for (unsigned int i = 0; i < inlen;) { + unsigned char proto_len = in[i]; + std::string_view proto(reinterpret_cast(&in[i + 1]), proto_len); + HTTP_DBG(" - {}", proto); + i += 1 + proto_len; + } + + int ret = SSL_select_next_proto(const_cast(out), outlen, supported_alpns, sizeof(supported_alpns), in, inlen); + + HTTP_DBG("alpn_select_proto_cb: Selected ALPN: {}", ret == OPENSSL_NPN_NEGOTIATED ? std::string_view(reinterpret_cast(*out), *outlen) : "none"); + if (ret != OPENSSL_NPN_NEGOTIATED) { + return SSL_TLSEXT_ERR_NOACK; + } + + return SSL_TLSEXT_ERR_OK; +} + +inline std::expected create_ssl_ctx(EVP_PKEY *key, X509 *cert) { + auto ssl_ctx = SSL_CTX_Ptr(SSL_CTX_new(TLS_server_method()), SSL_CTX_free); + if (!ssl_ctx) { + return std::unexpected(std::format("Could not create SSL/TLS context: {}", ERR_error_string(ERR_get_error(), nullptr))); + } + SSL_CTX_set_options(ssl_ctx.get(), SSL_OP_ALL | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | SSL_OP_NO_COMPRESSION | SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); + if (SSL_CTX_set1_curves_list(ssl_ctx.get(), "P-256") != 1) { + return std::unexpected(std::format("SSL_CTX_set1_curves_list failed: {}", ERR_error_string(ERR_get_error(), nullptr))); + } + + if (SSL_CTX_use_PrivateKey(ssl_ctx.get(), key) <= 0) { + return std::unexpected(std::format("Could not configure private key")); + } + if (SSL_CTX_use_certificate(ssl_ctx.get(), cert) != 1) { + return std::unexpected(std::format("Could not configure certificate file")); + } + + if (!SSL_CTX_check_private_key(ssl_ctx.get())) { + return std::unexpected("Private key does not match the certificate"); + } + + SSL_CTX_set_alpn_select_cb(ssl_ctx.get(), alpn_select_proto_cb, nullptr); + + return ssl_ctx; +} + +using Message = mdp::BasicMessage; + +enum class RestMethod { + Get, + LongPoll, + Post, + Invalid +}; + +inline RestMethod parseMethod(std::string_view methodString) { + using enum RestMethod; + return methodString == "PUT" ? Post + : methodString == "POST" ? Post + : methodString == "GET" ? Get + : Invalid; +} + +struct Request { + std::vector> rawHeaders; + mdp::Topic topic; + RestMethod method = RestMethod::Invalid; + std::string longPollIndex; + std::string contentType; + std::string accept; + std::string payload; + bool complete = false; + + std::string_view acceptedMime() const { + static constexpr auto acceptableMimeTypes = std::array{ + opencmw::MIME::JSON.typeName(), MIME::HTML.typeName(), MIME::BINARY.typeName() + }; + auto accepted = [](auto format) { + const auto it = std::find(acceptableMimeTypes.begin(), acceptableMimeTypes.end(), format); + return std::make_pair( + it != acceptableMimeTypes.cend(), + it); + }; + + if (!contentType.empty()) { + if (const auto [found, where] = accepted(contentType); found) { + return *where; + } + } + if (auto it = topic.params().find("contentType"); it != topic.params().end()) { + if (const auto [found, where] = accepted(it->second); found) { + return *where; + } + } + + auto isDelimiter = [](char c) { return c == ' ' || c == ','; }; + auto from = accept.cbegin(); + const auto end = accept.cend(); + + while (from != end) { + from = std::find_if_not(from, end, isDelimiter); + auto to = std::find_if(from, end, isDelimiter); + if (from != end) { + std::string_view format(from, to); + if (const auto [found, where] = accepted(format); found) { + return *where; + } + } + + from = to; + } + + return acceptableMimeTypes[0]; + } +}; + +struct BodyReaderChunk { + std::array buffer; + std::size_t unacked = 0; + std::size_t offset = 0; + + std::span data() { + return { buffer.data() + offset, buffer.size() - offset }; + } + + void bytesAdded(std::size_t len) { + unacked += len; + offset += len; + } + + void clear() { + unacked = 0; + offset = 0; + } + + bool full() const { + return offset == buffer.size(); + } + + BodyReaderChunk() = default; + BodyReaderChunk &operator=(BodyReaderChunk &&other) noexcept = delete; + BodyReaderChunk(BodyReaderChunk &&other) noexcept = delete; + BodyReaderChunk(const BodyReaderChunk &) = delete; + BodyReaderChunk &operator=(const BodyReaderChunk &) = delete; +}; + +struct IdGenerator { + std::uint64_t _nextRequestId = 0; + + std::uint64_t generateId() { + return _nextRequestId++; + } +}; + +struct SubscriptionCacheEntry { + constexpr static std::size_t kCapacity = 100; + std::uint64_t firstIndex = 0; + std::deque messages; + + void add(Message &&message) { + if (messages.size() == kCapacity) { + messages.pop_front(); + firstIndex++; + } + messages.push_back(std::move(message)); + } + std::uint64_t lastIndex() const noexcept { + assert(!messages.empty()); + return firstIndex + messages.size() - 1; + } + std::uint64_t nextIndex() const noexcept { + return firstIndex + messages.size(); + } +}; + +struct SharedData { + std::string _altSvcHeaderValue; + nghttp2_nv _altSvcHeader; + std::map _subscriptionCache; + std::vector _handlers; + std::deque> _bodyReaderChunks; + + const std::array _static_secret = [] { + std::array secret{}; + generate_secret(secret); + return secret; + }(); + + majordomo::rest::Handler *findHandler(std::string_view method, std::string_view path) { + auto bestMatch = _handlers.end(); + + for (auto itHandler = _handlers.begin(); itHandler != _handlers.end(); ++itHandler) { + if (itHandler->method != method) { + continue; + } + + std::string_view handlerPath = itHandler->path; + + if (handlerPath == path) { + // exact match, use this handler + return &*itHandler; + } + // if the handler path ends with '*', do a prefix check and use the most specific (longest) one + if (handlerPath.ends_with("*")) { + handlerPath.remove_suffix(1); + if (path.starts_with(handlerPath) && (bestMatch == _handlers.end() || bestMatch->path.size() < itHandler->path.size())) { + bestMatch = itHandler; + } + } + } + + return bestMatch != _handlers.end() ? &*bestMatch : nullptr; + } + + std::unique_ptr acquireChunk() { + if (!_bodyReaderChunks.empty()) { + auto chunk = std::move(_bodyReaderChunks.back()); + _bodyReaderChunks.pop_back(); + return chunk; + } + return std::make_unique(); + } + + void releaseChunk(std::unique_ptr chunk) { + chunk->clear(); + _bodyReaderChunks.push_back(std::move(chunk)); + } +}; + +struct ResponseData { + explicit ResponseData(std::shared_ptr sharedData_, Message &&m) + : sharedData(std::move(sharedData_)) + , message(std::move(m)) + , errorBuffer(message.error.data(), message.error.size()) {} + + explicit ResponseData(std::shared_ptr sharedData_, majordomo::rest::Response &&r) + : sharedData(std::move(sharedData_)) + , restResponse(std::move(r)) {} + + std::shared_ptr sharedData; + Message message; + IoBuffer errorBuffer; + + majordomo::rest::Response restResponse; + // Http3-specific (TODO: try to unify with Http2) + // Used when streaming data from restResponse.bodyReader + std::deque> bodyReaderChunks; + // Used when streaming data from an IoBuffer + IoBuffer *bodyBuffer = nullptr; +}; + +constexpr int kHttpOk = 200; +constexpr int kHttpError = 500; +constexpr int kFileNotFound = 404; + +template +struct SessionBase { + using PendingRequest = std::tuple; // requestId, streamId + using PendingPoll = std::tuple; // zmqTopic, PollingIndex, streamId + std::map _requestsByStreamId; + std::map _responsesByStreamId; + std::vector _pendingRequests; + std::vector _pendingPolls; + std::shared_ptr _sharedData; + + explicit SessionBase(std::shared_ptr sharedData) + : _sharedData(std::move(sharedData)) { + } + + SessionBase(const SessionBase &) = delete; + SessionBase &operator=(const SessionBase &) = delete; + SessionBase(SessionBase &&other) = delete; + SessionBase &operator=(SessionBase &&other) = delete; + + [[nodiscard]] constexpr auto &self() noexcept { return *static_cast(this); } + [[nodiscard]] constexpr const auto &self() const noexcept { return *static_cast(this); } + + std::optional processGetSetRequest(TStreamId streamId, Request &request, IdGenerator &idGenerator) { + Message result; + request.topic.addParam("contentType", request.acceptedMime()); + result.command = request.method == RestMethod::Get ? mdp::Command::Get : mdp::Command::Set; + result.serviceName = request.topic.service(); + result.topic = request.topic.toMdpTopic(); + result.data = IoBuffer(request.payload.data(), request.payload.size()); + auto id = idGenerator.generateId(); + result.clientRequestID = IoBuffer(std::to_string(id).data(), std::to_string(id).size()); + _pendingRequests.emplace_back(id, streamId); + return result; + }; + + void addData(TStreamId streamId, std::string_view data) { + _requestsByStreamId[streamId].payload += data; + } + + void addHeader(TStreamId streamId, std::string_view name, std::string_view value) { + const auto [it, inserted] = _requestsByStreamId.try_emplace(streamId, Request{}); + auto &request = it->second; + request.rawHeaders.emplace_back(name, value); +#ifdef OPENCMW_PROFILE_HTTP + if (name == "x-timestamp") { + std::println(std::cerr, "x-timestamp: {} (latency {} ns)", value, opencmw::rest::detail::latency(value).count()); + } +#endif + } + + void respondToLongPoll(TStreamId streamId, std::uint64_t index, Message &&msg) { + auto timestamp = std::to_string(opencmw::load_test::timestamp().count()); + self().sendResponse(streamId, kHttpOk, std::move(msg), { nv(u8span("x-opencmw-long-polling-idx"), u8span(std::to_string(index))), nv(u8span("x-timestamp"), u8span(timestamp)) }); + } + + void respondToLongPollWithError(TStreamId streamId, std::string_view error, int code, std::uint64_t index) { + Message response = {}; + response.error = std::string(error); + self().sendResponse(streamId, code, std::move(response), { nv(u8span("x-opencmw-long-polling-idx"), u8span(std::to_string(index))) }); + } + + void respondWithError(TStreamId streamId, std::string_view error, int code = kHttpError, std::vector extraHeaders = {}) { + Message response = {}; + response.error = std::string(error); + self().sendResponse(streamId, code, std::move(response), std::move(extraHeaders)); + } + + void respondWithLongPollingRedirect(TStreamId streamId, const URI<> &topic, std::size_t longPollIdx) { + auto location = URI<>::UriFactory(topic).addQueryParameter("LongPollingIdx", std::to_string(longPollIdx)).build(); + self().respondWithRedirect(streamId, location.str()); + } + + std::optional processLongPollRequest(TStreamId streamId, const Request &request) { + std::optional result; + const auto zmqTopic = request.topic.toZmqTopic(); + auto entryIt = _sharedData->_subscriptionCache.find(zmqTopic); + if (entryIt == _sharedData->_subscriptionCache.end()) { + entryIt = _sharedData->_subscriptionCache.try_emplace(zmqTopic, SubscriptionCacheEntry{}).first; + result = Message{}; + result->command = mdp::Command::Subscribe; + result->topic = request.topic.toMdpTopic(); + } + auto &entry = entryIt->second; + if (request.longPollIndex == "Next") { + respondWithLongPollingRedirect(streamId, request.topic.toMdpTopic(), entry.nextIndex()); + return result; + } else if (request.longPollIndex == "Last") { + const std::size_t last = entry.messages.empty() ? entry.nextIndex() : entry.lastIndex(); + respondWithLongPollingRedirect(streamId, request.topic.toMdpTopic(), last); + return result; + } + + std::uint64_t index = 0; + if (auto [ptr, ec] = std::from_chars(request.longPollIndex.data(), request.longPollIndex.data() + request.longPollIndex.size(), index); ec != std::errc()) { + respondWithError(streamId, std::format("Malformed LongPollingIdx '{}'", request.longPollIndex)); + return {}; + } + +#ifdef OPENCMW_PROFILE_HTTP + if (index % 100 == 0) { + const std::size_t last = entry.messages.empty() ? entry.nextIndex() : entry.lastIndex(); + std::println(std::cerr, "{}: {}; delta: {}", zmqTopic, index, index <= last ? (last - index) : 0); + } +#endif + + if (index < entry.firstIndex) { + // index is too old, redirect to the next index + HTTP_DBG("Server::LongPoll: index {} < firstIndex {}", index, entry.firstIndex); + respondWithLongPollingRedirect(streamId, request.topic.toMdpTopic(), entry.nextIndex()); + } else if (entry.messages.empty() || index > entry.lastIndex()) { + // future index, wait for new messages + _pendingPolls.emplace_back(zmqTopic, index, streamId); + } else { + // we have a message for this index, send it + respondToLongPoll(streamId, index, Message(entry.messages[index - entry.firstIndex])); + } + return result; + } + + void processCompletedRequest(TStreamId streamId) { + auto it = _requestsByStreamId.find(streamId); + assert(it != _requestsByStreamId.end()); + auto &[streamid, request] = *it; + + std::string path; + std::string_view method; + + for (const auto &[name, value] : request.rawHeaders) { + if (name == ":path") { + path = value; + } else if (name == ":method") { + method = value; + } else if (name == "content-type") { + request.contentType = value; + } else if (name == "accept") { + request.accept = value; + } + } + + // if we have an externally configured handler for this method/path, use it + if (auto handler = _sharedData->findHandler(method, path); handler) { + majordomo::rest::Request req; + req.method = method; + req.path = path; + std::swap(req.headers, request.rawHeaders); + auto response = handler->handler(req); + self().sendResponse(streamId, std::move(response)); + _requestsByStreamId.erase(it); + return; + } + + // redirect "/" request to "/mmi.service" + if (path == "/") { + path = "/mmi.service"; + } + + // Everything else is a service request + try { + auto pathUri = URI<>(path); + auto factory = URI<>::UriFactory(pathUri).setQuery({}); + bool haveSubscriptionContext = false; + for (const auto &[qkey, qvalue] : pathUri.queryParamMap()) { + if (qkey == "LongPollingIdx") { + request.method = RestMethod::LongPoll; + request.longPollIndex = qvalue.value_or(""); + } else if (qkey == "SubscriptionContext") { + request.topic = mdp::Topic::fromMdpTopic(URI<>(qvalue.value_or(""))); + haveSubscriptionContext = true; + } else { + if (qvalue) { + factory = std::move(factory).addQueryParameter(qkey, qvalue.value()); + } else { + factory = std::move(factory).addQueryParameter(qkey); + } + } + } + if (!haveSubscriptionContext) { + request.topic = mdp::Topic::fromMdpTopic(factory.build()); + } + } catch (const std::exception &e) { + HTTP_DBG("Service::Header: Could not parse service URI '{}': {}", path, e.what()); + Message response; + response.error = e.what(); + self().sendResponse(streamId, kFileNotFound, std::move(response)); + _requestsByStreamId.erase(it); + return; + } + + if (request.method == RestMethod::Invalid) { + request.method = parseMethod(method); + } + + // Set completed for getMessages() to collect + request.complete = true; + } + + std::vector getMessages(IdGenerator &idGenerator) { + const auto completeEnd = std::ranges::partition_point(_requestsByStreamId, [](const auto &pair) { return pair.second.complete; }); + + std::vector result; + result.reserve(static_cast(std::distance(_requestsByStreamId.begin(), completeEnd))); + + for (auto it = _requestsByStreamId.begin(); it != completeEnd; ++it) { + auto &[streamId, request] = *it; + + switch (request.method) { + case RestMethod::Get: + case RestMethod::Post: + if (auto m = processGetSetRequest(streamId, request, idGenerator); m.has_value()) { + result.push_back(std::move(m.value())); + } + break; + case RestMethod::LongPoll: + if (auto m = processLongPollRequest(streamId, request); m.has_value()) { + result.push_back(std::move(m.value())); + } + break; + case RestMethod::Invalid: + respondWithError(it->first, "Invalid REST method", kHttpError); + break; + } + } + + _requestsByStreamId.erase(_requestsByStreamId.begin(), completeEnd); + return result; + } + + void handleNotification(std::string_view zmqTopic, std::uint64_t index, const Message &msg) { + auto pollIt = _pendingPolls.begin(); + while (pollIt != _pendingPolls.end()) { + const auto &[pendingZmqTopic, pollIndex, streamId] = *pollIt; + if (pendingZmqTopic == zmqTopic && index == pollIndex) { + respondToLongPoll(streamId, pollIndex, Message(msg)); + pollIt = _pendingPolls.erase(pollIt); + } else { + ++pollIt; + } + } + } +}; + +struct Http2Session : public SessionBase { + TcpSocket _socket; + WriteBuffer<4096> _writeBuffer; + nghttp2_session *_session = nullptr; + + explicit Http2Session(TcpSocket &&socket, std::shared_ptr sharedData) + : SessionBase(std::move(sharedData)), _socket{ std::move(socket) } { + nghttp2_session_callbacks *callbacks; + nghttp2_session_callbacks_new(&callbacks); + nghttp2_session_callbacks_set_on_frame_recv_callback(callbacks, [](nghttp2_session *, const nghttp2_frame *frame, void *user_data) { + auto session = static_cast(user_data); + return session->frame_recv_callback(frame); + }); + nghttp2_session_callbacks_set_on_frame_send_callback(callbacks, [](nghttp2_session *, const nghttp2_frame *frame, void *user_data) { + auto session = static_cast(user_data); + return session->frame_send_callback(frame); + }); + nghttp2_session_callbacks_set_on_frame_not_send_callback(callbacks, [](nghttp2_session *, const nghttp2_frame *frame, int lib_error_code, void *user_data) { + auto session = static_cast(user_data); + return session->frame_not_send_callback(frame, lib_error_code); + }); + nghttp2_session_callbacks_set_on_data_chunk_recv_callback(callbacks, [](nghttp2_session *, uint8_t flags, int32_t stream_id, const uint8_t *data, size_t len, void *user_data) { + auto session = static_cast(user_data); + return session->data_chunk_recv_callback(flags, stream_id, { reinterpret_cast(data), len }); + }); + nghttp2_session_callbacks_set_on_stream_close_callback(callbacks, [](nghttp2_session *, int32_t stream_id, uint32_t error_code, void *user_data) { + auto session = static_cast(user_data); + return session->stream_closed_callback(stream_id, error_code); + }); + nghttp2_session_callbacks_set_on_header_callback2(callbacks, [](nghttp2_session *, const nghttp2_frame *frame, nghttp2_rcbuf *name, + nghttp2_rcbuf *value, uint8_t flags, void *user_data) { + auto session = static_cast(user_data); + return session->header_callback(frame, as_view(name), as_view(value), flags); + }); + nghttp2_session_callbacks_set_on_invalid_frame_recv_callback(callbacks, [](nghttp2_session *, const nghttp2_frame *frame, int lib_error_code, void *user_data) { + auto session = static_cast(user_data); + return session->invalid_frame_recv_callback(frame, lib_error_code); + }); + nghttp2_session_callbacks_set_error_callback2(callbacks, [](nghttp2_session *, int lib_error_code, const char *msg, size_t len, void *user_data) { + auto session = static_cast(user_data); + return session->error_callback(lib_error_code, msg, len); + }); + nghttp2_session_server_new(&_session, callbacks, this); + nghttp2_session_callbacks_del(callbacks); + } + + ~Http2Session() { + nghttp2_session_del(_session); + } + + Http2Session(const Http2Session &) = delete; + Http2Session &operator=(const Http2Session &) = delete; + Http2Session(Http2Session &&other) = delete; + Http2Session &operator=(Http2Session &&other) = delete; + + bool wantsToRead() const { + return _socket._state == TcpSocket::Connected ? nghttp2_session_want_read(_session) : (_socket._state == TcpSocket::SSLAcceptWantsRead); + } + + bool wantsToWrite() const { + return _socket._state == TcpSocket::Connected ? _writeBuffer.wantsToWrite(_session) : (_socket._state == TcpSocket::SSLAcceptWantsWrite); + } + + static auto ioBufferCallback() { + return [](nghttp2_session *, int32_t /*stream_id*/, uint8_t *buf, size_t length, uint32_t *data_flags, nghttp2_data_source *source, void * /*user_data*/) { + auto ioBuffer = static_cast(source->ptr); + const std::size_t copy_len = std::min(length, ioBuffer->size() - ioBuffer->position()); + std::copy(ioBuffer->data() + ioBuffer->position(), ioBuffer->data() + ioBuffer->position() + copy_len, buf); + ioBuffer->skip(static_cast(copy_len)); + if (ioBuffer->position() == ioBuffer->size()) { + *data_flags |= NGHTTP2_DATA_FLAG_EOF; + } + return static_cast(copy_len); + }; + } + + static auto viewCallback() { + return [](nghttp2_session *, int32_t /*stream_id*/, uint8_t *buf, size_t length, uint32_t *data_flags, nghttp2_data_source *source, void * /*user_data*/) { + auto view = static_cast *>(source->ptr); + const std::size_t copy_len = std::min(length, view->size()); + std::copy(view->data(), view->data() + copy_len, buf); + *view = view->subspan(copy_len); + if (view->empty()) { + *data_flags |= NGHTTP2_DATA_FLAG_EOF; + } + return static_cast(copy_len); + }; + } + + void sendResponse(std::int32_t streamId, majordomo::rest::Response response) { + // store message while sending so we don't need to copy the data + auto &msg = _responsesByStreamId.try_emplace(streamId, ResponseData{ _sharedData, std::move(response) }).first->second; + + constexpr auto noCopy = NGHTTP2_NV_FLAG_NO_COPY_NAME | NGHTTP2_NV_FLAG_NO_COPY_VALUE; + const auto statusStr = std::to_string(msg.restResponse.code); + + std::vector headers; + headers.reserve(msg.restResponse.headers.size() + 2); + // :status must go first, otherwise browsers and curl will not accept the response + headers.push_back(nv(u8span(":status"), u8span(statusStr), NGHTTP2_NV_FLAG_NO_COPY_NAME)); + headers.push_back(nv(u8span("access-control-allow-origin"), u8span("*"), noCopy)); + if (!_sharedData->_altSvcHeaderValue.empty()) { + headers.push_back(_sharedData->_altSvcHeader); + } + + for (const auto &[name, value] : msg.restResponse.headers) { + headers.push_back(nv(u8span(name), u8span(value), noCopy)); + } + + nghttp2_data_provider2 data_prd; + data_prd.read_callback = nullptr; + + if (msg.restResponse.bodyReader) { + data_prd.source.ptr = &msg.restResponse; + data_prd.read_callback = [](nghttp2_session *, int32_t stream_id, uint8_t *buf, size_t length, uint32_t *data_flags, nghttp2_data_source *source, void * /*user_data*/) -> ssize_t { + std::ignore = stream_id; + auto res = static_cast(source->ptr); + const auto r = res->bodyReader(std::span(buf, length)); + if (!r) { + HTTP_DBG("Server: stream_id={} Error reading body: {}", stream_id, r.error()); + return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE; + } + const auto &[bytesRead, hasMore] = *r; + if (!hasMore) { + *data_flags |= NGHTTP2_DATA_FLAG_EOF; + } + return static_cast(bytesRead); + }; + } else if (!msg.restResponse.bodyView.empty()) { + data_prd.source.ptr = &msg.restResponse.bodyView; + data_prd.read_callback = viewCallback(); + } else if (!msg.restResponse.body.empty()) { + data_prd.source.ptr = &msg.restResponse.body; + data_prd.read_callback = ioBufferCallback(); + } + +#ifdef OPENCMW_DEBUG_HTTP + auto formattedHeaders = headers | std::views::transform([](const auto &header) { + return std::format("'{}'='{}'", std::string_view(reinterpret_cast(header.name), header.namelen), std::string_view(reinterpret_cast(header.value), header.valuelen)); + }); + HTTP_DBG("Server::H2: Sending response {} to streamId {}. Headers:\n{}\n Body: {}", msg.restResponse.code, streamId, opencmw::join(formattedHeaders, "\n"), msg.restResponse.bodyReader ? "reader" : std::format("{} bytes", msg.restResponse.body.size())); +#endif + auto prd = data_prd.read_callback ? &data_prd : nullptr; + if (auto rc = nghttp2_submit_response2(_session, streamId, headers.data(), headers.size(), prd); rc != 0) { + HTTP_DBG("Server::H2: nghttp2_submit_response2 for stream ID {} failed: {}", streamId, nghttp2_strerror(rc)); + _responsesByStreamId.erase(streamId); + } + } + + void sendResponse(std::int32_t streamId, int responseCode, Message &&responseMessage, std::vector extraHeaders = {}) { + // store message while sending so we don't need to copy the data + auto &msg = _responsesByStreamId.try_emplace(streamId, ResponseData{ _sharedData, std::move(responseMessage) }).first->second; + IoBuffer *buf = msg.errorBuffer.empty() ? &msg.message.data : &msg.errorBuffer; + + auto codeStr = std::to_string(responseCode); + auto contentLength = std::to_string(buf->size()); + constexpr int noCopy = NGHTTP2_NV_FLAG_NO_COPY_NAME | NGHTTP2_NV_FLAG_NO_COPY_VALUE; + // :status must go first + auto headers = std::vector{ nv(u8span(":status"), u8span(codeStr)), nv(u8span("x-opencmw-topic"), u8span(msg.message.topic.str()), noCopy), + nv(u8span("x-opencmw-service-name"), u8span(msg.message.serviceName), noCopy), nv(u8span("access-control-allow-origin"), u8span("*"), noCopy), nv(u8span("content-length"), u8span(contentLength)) }; + + if (!_sharedData->_altSvcHeaderValue.empty()) { + headers.push_back(_sharedData->_altSvcHeader); + } + headers.insert(headers.end(), extraHeaders.begin(), extraHeaders.end()); + + nghttp2_data_provider2 data_prd; + data_prd.source.ptr = buf; + data_prd.read_callback = ioBufferCallback(); + +#ifdef OPENCMW_DEBUG_HTTP + auto formattedHeaders = headers | std::views::transform([](const auto &header) { + return std::format("'{}'='{}'", std::string_view(reinterpret_cast(header.name), header.namelen), std::string_view(reinterpret_cast(header.value), header.valuelen)); + }); + HTTP_DBG("Server::H2: Sending response {} to streamId {}. Headers:\n{}", responseCode, streamId, opencmw::join(formattedHeaders, "\n")); +#endif + if (auto rc = nghttp2_submit_response2(_session, streamId, headers.data(), headers.size(), &data_prd); rc != 0) { + HTTP_DBG("Server::H2: nghttp2_submit_response2 for stream ID {} failed: {}", streamId, nghttp2_strerror(rc)); + _responsesByStreamId.erase(streamId); + } + } + + void respondWithRedirect(std::int32_t streamId, std::string_view location) { + HTTP_DBG("Server::respondWithRedirect: streamId={} location={}", streamId, location); + // :status must go first + constexpr auto noCopy = NGHTTP2_NV_FLAG_NO_COPY_NAME | NGHTTP2_NV_FLAG_NO_COPY_VALUE; + auto headers = std::vector{ nv(u8span(":status"), u8span("302"), noCopy), nv(u8span("location"), u8span(location)) }; + if (!_sharedData->_altSvcHeaderValue.empty()) { + headers.push_back(_sharedData->_altSvcHeader); + } + nghttp2_submit_response2(_session, streamId, headers.data(), headers.size(), nullptr); + } + + void respondWithLongPollingRedirect(std::int32_t streamId, const URI<> &topic, std::size_t longPollIdx) { + auto location = URI<>::UriFactory(topic).addQueryParameter("LongPollingIdx", std::to_string(longPollIdx)).build(); + respondWithRedirect(streamId, location.str()); + } + + int frame_recv_callback(const nghttp2_frame *frame) { + HTTP_DBG("Server::Frame: id={} {} {} {}", frame->hd.stream_id, frame->hd.type, frame->hd.flags, (frame->hd.flags & NGHTTP2_FLAG_END_STREAM) ? "END_STREAM" : ""); + switch (frame->hd.type) { + case NGHTTP2_DATA: + case NGHTTP2_HEADERS: + if (frame->hd.flags & NGHTTP2_FLAG_END_STREAM) { + processCompletedRequest(frame->hd.stream_id); + } + break; + } + return 0; + } + + int frame_send_callback(const nghttp2_frame *frame) { + if (frame->hd.flags & NGHTTP2_FLAG_END_STREAM) { + HTTP_DBG("Server::Frame sent: id={} {} {} END_STREAM", frame->hd.stream_id, frame->hd.type, frame->hd.flags); + _responsesByStreamId.erase(frame->hd.stream_id); + } + return 0; + } + + int frame_not_send_callback(const nghttp2_frame *frame, int lib_error_code) { + std::ignore = lib_error_code; + HTTP_DBG("Server::Frame not sent: id={} {} {} {}", frame->hd.stream_id, frame->hd.type, frame->hd.flags, (frame->hd.flags & NGHTTP2_FLAG_END_STREAM) ? "END_STREAM" : ""); + if (frame->hd.type == NGHTTP2_DATA) { + _responsesByStreamId.erase(frame->hd.stream_id); + } + return 0; + } + + int data_chunk_recv_callback(uint8_t /*flags*/, int32_t stream_id, std::string_view data) { + HTTP_DBG("Server::Data id={} {} bytes", stream_id, data.size()); + addData(stream_id, data); + return 0; + } + + int stream_closed_callback(int32_t stream_id, uint32_t error_code) { + std::ignore = error_code; + HTTP_DBG("Server::Stream closed: {} ({})", stream_id, error_code); + const std::size_t erased = _responsesByStreamId.erase(stream_id); + // if this was canceled by the client, remove any pending requests/polls + if (erased > 0) { + std::erase_if(_pendingRequests, [stream_id](const auto &request) { return std::get<1>(request) == stream_id; }); + std::erase_if(_pendingPolls, [stream_id](const auto &poll) { return std::get<2>(poll) == stream_id; }); + } + return 0; + } + + int header_callback(const nghttp2_frame *frame, std::string_view name, std::string_view value, uint8_t /*flags*/) { + HTTP_DBG("Server::Header id={} {} = {}", frame->hd.stream_id, name, value); + addHeader(frame->hd.stream_id, name, value); + return 0; + } + + int invalid_frame_recv_callback(const nghttp2_frame *, int lib_error_code) { + std::ignore = lib_error_code; + HTTP_DBG("invalid_frame_recv_callback called error={}", lib_error_code); + return 0; + } + + int error_callback(int lib_error_code, const char *msg, size_t len) { + std::ignore = lib_error_code; + std::ignore = msg; + std::ignore = len; + HTTP_DBG("Server::ERROR: {} ({})", std::string_view(msg, len), lib_error_code); + return 0; + } +}; + +template +struct Http3Session : public SessionBase, std::int64_t>, public std::enable_shared_from_this> { + enum class State { + Initial, + Closing, + Draining, + }; + State _state = State::Initial; + TServer *server_; + ngtcp2_conn *_conn = nullptr; + ngtcp2_crypto_conn_ref _conn_ref; + nghttp3_conn *_httpconn = nullptr; + ngtcp2_cid _scid; + TLSServerSession _tlsSession; + ngtcp2_ccerr _last_error; + // conn_closebuf_ contains a packet which contains CONNECTION_CLOSE. + // This packet is repeatedly sent as a response to the incoming + // packet in draining period. + std::unique_ptr _conn_closebuf; + bool _no_gso = false; + struct { + size_t bytes_recv; + size_t bytes_sent; + size_t num_pkts_recv; + size_t next_pkts_recv = 1; + } _close_wait; + + struct { + bool send_blocked = false; + size_t num_blocked = 0; + size_t num_blocked_sent = 0; + // blocked field is effective only when send_blocked is true. + struct { + Endpoint *endpoint = nullptr; + Address local_addr; + Address remote_addr; + unsigned int ecn = 0; + std::span data; + size_t gso_size = 0; + } blocked[2]; + std::unique_ptr data = std::unique_ptr(new uint8_t[64 * 1024]); + } _tx; + + static ngtcp2_conn *get_conn(ngtcp2_crypto_conn_ref *conn_ref) { + auto session = static_cast(conn_ref->user_data); + return session->_conn; + } + + explicit Http3Session(TServer *server_, std::shared_ptr sharedData) + : SessionBase(std::move(sharedData)), server_(server_), _conn_ref({ .get_conn = get_conn, .user_data = this }) { + } + + ~Http3Session() { + nghttp3_conn_del(_httpconn); + } + + Http3Session(const Http3Session &) + = delete; + Http3Session &operator=(const Http3Session &) = delete; + Http3Session(Http3Session &&other) = delete; + Http3Session &operator=(Http3Session &&other) = delete; + + static auto viewCallback() { + return [](nghttp3_conn *, int64_t stream_id, nghttp3_vec *vec, std::size_t /*veccnt*/, std::uint32_t *pflags, void */*conn_user_data*/, void *stream_user_data) -> nghttp3_ssize { + std::ignore = stream_id; + auto res = static_cast(stream_user_data); + auto &bodyView = res->restResponse.bodyView; + vec[0].base = const_cast(bodyView.data()); + vec[0].len = bodyView.size(); + HTTP_DBG("Server::H3::viewCallback: stream_id={} sz={} vec[0].len={}", stream_id, bodyView.size(), vec[0].len); + *pflags = NGHTTP3_DATA_FLAG_EOF; + return 1; + }; + } + + static auto ioBufferCallback() { + return [](nghttp3_conn *, int64_t stream_id, nghttp3_vec *vec, std::size_t /*veccnt*/, std::uint32_t *pflags, void * /*conn_user_data*/, void *stream_user_data) -> nghttp3_ssize { + std::ignore = stream_id; + auto res = static_cast(stream_user_data); + vec[0].base = res->bodyBuffer->data(); + vec[0].len = res->bodyBuffer->size(); + HTTP_DBG("Server::H3::ioBufferCallback: stream_id={} sz={} vec[0].len={}", stream_id, res->bodyBuffer->size(), vec[0].len); + *pflags = NGHTTP3_DATA_FLAG_EOF; + return 1; + }; + } + + void sendResponse(std::int64_t streamId, majordomo::rest::Response response) { + // store message while sending so we don't need to copy the data + auto &msg = this->_responsesByStreamId.try_emplace(streamId, ResponseData{ this->_sharedData, std::move(response) }).first->second; + nghttp3_conn_set_stream_user_data(_httpconn, streamId, &msg); + + constexpr auto noCopy = NGHTTP3_NV_FLAG_NO_COPY_NAME | NGHTTP3_NV_FLAG_NO_COPY_VALUE; + const auto statusStr = std::to_string(msg.restResponse.code); + + std::vector headers; + headers.reserve(msg.restResponse.headers.size() + 2); + // :status must go first, otherwise browsers and curl will not accept the response + headers.push_back(nv3(u8span(":status"), u8span(statusStr), NGHTTP3_NV_FLAG_NO_COPY_NAME)); + headers.push_back(nv3(u8span("access-control-allow-origin"), u8span("*"), noCopy)); + + for (const auto &[name, value] : msg.restResponse.headers) { + headers.push_back(nv3(u8span(name), u8span(value), noCopy)); + } + + nghttp3_data_reader data_prd; + data_prd.read_data = nullptr; + + if (msg.restResponse.bodyReader) { + data_prd.read_data = [](nghttp3_conn *, int64_t stream_id, nghttp3_vec *vec, std::size_t /*veccnt*/, std::uint32_t *pflags, void * /*conn_user_data*/, void *stream_user_data) -> nghttp3_ssize { + std::ignore = stream_id; + auto res = static_cast(stream_user_data); + + // We need to cache the data from the reader until nghttp3 stream_data_acked confirms that the data was sent. + // Use a chunk pool to avoid constant allocation/deallocation. + if (res->bodyReaderChunks.empty() || res->bodyReaderChunks.back()->full()) { + res->bodyReaderChunks.push_back(res->sharedData->acquireChunk()); + } + auto &chunk = res->bodyReaderChunks.back(); + + const auto r = res->restResponse.bodyReader(chunk->data()); + if (!r) { + HTTP_DBG("Server::H3::bodyReaderCallback: stream_id={} Error reading body: {}", stream_id, r.error()); + return NGHTTP3_ERR_CALLBACK_FAILURE; + } + const auto &[bytesRead, hasMore] = *r; + HTTP_DBG("Server::H3::bodyReaderCallback: stream_id={} bytesRead={} hasMore={}", stream_id, bytesRead, hasMore); + + vec[0].base = &chunk->buffer[chunk->offset]; + vec[0].len = bytesRead; + chunk->bytesAdded(bytesRead); + + HTTP_DBG("Server::H3::bodyReaderCallback: stream_id={} chunks={} vec[0].len={}", stream_id, res->bodyReaderChunks.size(), vec[0].len); + + if (!hasMore) { + *pflags = NGHTTP3_DATA_FLAG_EOF; + } + return 1; + }; + } else if (!msg.restResponse.bodyView.empty()) { + data_prd.read_data = viewCallback(); + } else { + msg.bodyBuffer = &msg.restResponse.body; + data_prd.read_data = ioBufferCallback(); + } + +#ifdef OPENCMW_DEBUG_HTTP + auto formattedHeaders = headers | std::views::transform([](const auto &header) { + return std::format("'{}'='{}'", std::string_view(reinterpret_cast(header.name), header.namelen), std::string_view(reinterpret_cast(header.value), header.valuelen)); + }); + HTTP_DBG("Server::H3: Sending response {} to streamId {}. Headers:\n{}\n Body: {}", msg.restResponse.code, streamId, opencmw::join(formattedHeaders, "\n"), msg.restResponse.bodyReader ? "reader" : std::format("{} bytes", msg.restResponse.body.size())); +#endif + auto prd = data_prd.read_data ? &data_prd : nullptr; + if (auto rc = nghttp3_conn_submit_response(_httpconn, streamId, headers.data(), headers.size(), prd); rc != 0) { + HTTP_DBG("Server::H3: nghttp3_conn_submit_response for stream ID {} failed: {}", streamId, nghttp3_strerror(rc)); + this->_responsesByStreamId.erase(streamId); + } + } + + void sendResponse(std::int64_t streamId, int responseCode, Message &&responseMessage, std::vector extraHeaders = {}) { + // store message while sending so we don't need to copy the data + auto &msg = this->_responsesByStreamId.try_emplace(streamId, ResponseData{ this->_sharedData, std::move(responseMessage) }).first->second; + msg.bodyBuffer = msg.errorBuffer.empty() ? &msg.message.data : &msg.errorBuffer; + + auto codeStr = std::to_string(responseCode); + auto contentLength = std::to_string(msg.bodyBuffer->size()); + constexpr std::uint8_t noCopy = NGHTTP3_NV_FLAG_NO_COPY_NAME | NGHTTP3_NV_FLAG_NO_COPY_VALUE; + // :status must go first + auto headers = std::vector{ nv3(u8span(":status"), u8span(codeStr)), nv3(u8span("x-opencmw-topic"), u8span(msg.message.topic.str()), noCopy), + nv3(u8span("x-opencmw-service-name"), u8span(msg.message.serviceName), noCopy), nv3(u8span("access-control-allow-origin"), u8span("*"), noCopy), nv3(u8span("content-length"), u8span(contentLength)) }; + + auto nv2ToNv3 = [](const auto &nv) { + auto mapFlags = [](uint8_t flags) { + uint8_t r = NGHTTP2_NV_FLAG_NONE; + if (flags & NGHTTP2_NV_FLAG_NO_COPY_NAME) { + r = NGHTTP3_NV_FLAG_NO_COPY_NAME; + } + if (flags & NGHTTP2_NV_FLAG_NO_COPY_VALUE) { + r |= NGHTTP3_NV_FLAG_NO_COPY_VALUE; + } + return r; + }; + return nghttp3_nv{ nv.name, nv.value, nv.namelen, nv.valuelen, mapFlags(nv.flags) }; + }; + std::transform(extraHeaders.begin(), extraHeaders.end(), std::back_inserter(headers), nv2ToNv3); + + nghttp3_data_reader data_prd; + nghttp3_conn_set_stream_user_data(_httpconn, streamId, &msg); + data_prd.read_data = ioBufferCallback(); + +#ifdef OPENCMW_DEBUG_HTTP + auto formattedHeaders = headers | std::views::transform([](const auto &header) { + return std::format("'{}'='{}'", std::string_view(reinterpret_cast(header.name), header.namelen), std::string_view(reinterpret_cast(header.value), header.valuelen)); + }); + HTTP_DBG("Server::H3: Sending response {} to streamId {}. Headers:\n{}", responseCode, streamId, opencmw::join(formattedHeaders, "\n")); +#endif + if (auto rc = nghttp3_conn_submit_response(_httpconn, streamId, headers.data(), headers.size(), &data_prd); rc != 0) { + HTTP_DBG("Server::H3: nghttp3_conn_submit_response for stream ID {} failed: {}", streamId, nghttp2_strerror(rc)); + this->_responsesByStreamId.erase(streamId); + } + } + + void respondWithRedirect(std::int64_t streamId, std::string_view location) { + HTTP_DBG("Server::H3::respondWithRedirect: streamId={} location={}", streamId, location); + // :status must go first + constexpr auto noCopy = NGHTTP3_NV_FLAG_NO_COPY_NAME | NGHTTP3_NV_FLAG_NO_COPY_VALUE; + const auto headers = std::array{ nv3(u8span(":status"), u8span("302"), noCopy), nv3(u8span("location"), u8span(location)) }; + nghttp3_conn_submit_response(_httpconn, streamId, headers.data(), headers.size(), nullptr); + } + + int init(const Endpoint &ep, const Address &local_addr, const sockaddr *sa, socklen_t salen, const ngtcp2_cid *dcid, const ngtcp2_cid *scid, const ngtcp2_cid *ocid, std::span token, ngtcp2_token_type token_type, std::uint32_t version, TLSServerContext &tls_ctx) { + auto handshakeCompleted = [](ngtcp2_conn *, void *user_data) { + auto session = static_cast(user_data); + if (session->handshake_completed() != 0) { + HTTP_DBG("Server::H3: handshake_completed failed"); + return NGTCP2_ERR_CALLBACK_FAILURE; + } + return 0; + }; + + auto recvStreamData = [](ngtcp2_conn *, std::uint32_t flags, std::int64_t stream_id, std::uint64_t /*offset*/, const std::uint8_t *data, std::size_t datalen, void *user_data, void * /*stream_user_data*/) { + auto session = static_cast(user_data); + if (session->recv_stream_data(flags, stream_id, { data, datalen }) != 0) { + return NGTCP2_ERR_CALLBACK_FAILURE; + } + return 0; + }; + + auto ackedStreamDataOffset = [](ngtcp2_conn *, int64_t stream_id, uint64_t /*offset*/, uint64_t datalen, void *user_data, void * /*stream_user_data*/) { + auto session = static_cast(user_data); + if (session->acked_stream_data_offset(stream_id, datalen) != 0) { + return NGTCP2_ERR_CALLBACK_FAILURE; + } + return 0; + }; + + auto streamOpen = [](ngtcp2_conn *, int64_t /*stream_id*/, void */*user_data*/) { return 0; }; + + auto streamClose = [](ngtcp2_conn *, uint32_t flags, int64_t stream_id, uint64_t app_error_code, void *user_data, void * /*stream_user_data*/) { + auto session = static_cast(user_data); + + if (!(flags & NGTCP2_STREAM_CLOSE_FLAG_APP_ERROR_CODE_SET)) { + app_error_code = NGHTTP3_H3_NO_ERROR; + } + + if (session->on_stream_close(stream_id, app_error_code) != 0) { + return NGTCP2_ERR_CALLBACK_FAILURE; + } + return 0; + }; + + auto randCb = [](uint8_t *dest, size_t destlen, const ngtcp2_rand_ctx *) { + auto rv = generate_secure_random({ dest, destlen }); + if (rv != 0) { + assert(0); + abort(); + } + }; + + auto getNewConnectionId = [](ngtcp2_conn *, ngtcp2_cid *cid, uint8_t *token, size_t cidlen, void *user_data) { + if (generate_secure_random({ cid->data, cidlen }) != 0) { + return NGTCP2_ERR_CALLBACK_FAILURE; + } + + auto session = static_cast(user_data); + auto &static_secret = session->_sharedData->_static_secret; + + cid->datalen = cidlen; + if (ngtcp2_crypto_generate_stateless_reset_token(token, static_secret.data(), static_secret.size(), cid) != 0) { + return NGTCP2_ERR_CALLBACK_FAILURE; + } + session->server_->associate_cid(*cid, session->shared_from_this()); + + return 0; + }; + + auto removeConnectionId = [](ngtcp2_conn *, const ngtcp2_cid *cid, void *user_data) { + auto session = static_cast(user_data); + session->server_->dissociate_cid(*cid); + return 0; + }; + + auto updateKey = [](ngtcp2_conn *, uint8_t *rx_secret, uint8_t *tx_secret, ngtcp2_crypto_aead_ctx *rx_aead_ctx, uint8_t *rx_iv, ngtcp2_crypto_aead_ctx *tx_aead_ctx, uint8_t *tx_iv, const uint8_t *current_rx_secret, const uint8_t *current_tx_secret, size_t secretlen, void *user_data) { + auto session = static_cast(user_data); + if (session->update_key(rx_secret, tx_secret, rx_aead_ctx, rx_iv, tx_aead_ctx, tx_iv, current_rx_secret, current_tx_secret, secretlen) != 0) { + return NGTCP2_ERR_CALLBACK_FAILURE; + } + return 0; + }; + + auto pathValidation = [](ngtcp2_conn *conn, uint32_t flags, const ngtcp2_path *path, const ngtcp2_path * /*old_path*/, ngtcp2_path_validation_result res, void *user_data) { + if (res != NGTCP2_PATH_VALIDATION_RESULT_SUCCESS || !(flags & NGTCP2_PATH_VALIDATION_FLAG_NEW_TOKEN)) { + return 0; + } + + std::array token; + auto t = std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count(); + + auto session = static_cast(user_data); + auto &static_secret = session->_sharedData->_static_secret; + auto tokenlen = ngtcp2_crypto_generate_regular_token(token.data(), static_secret.data(), static_secret.size(), path->remote.addr, path->remote.addrlen, static_cast(t)); + if (tokenlen < 0) { + HTTP_DBG("ngtcp2_crypto_generate_regular_token: {}", ngtcp2_strerror(static_cast(tokenlen))); + return 0; + } + + if (auto rv = ngtcp2_conn_submit_new_token(conn, token.data(), static_cast(tokenlen)); rv != 0) { + HTTP_DBG("ngtcp2_conn_submit_new_token: {}", ngtcp2_strerror(rv)); + return NGTCP2_ERR_CALLBACK_FAILURE; + } + + return 0; + }; + + auto streamReset = [](ngtcp2_conn *, int64_t stream_id, uint64_t /*final_size*/, uint64_t /*app_error_code*/, void *user_data, void * /*stream_user_data*/) { + auto session = static_cast(user_data); + if (session->on_stream_reset(stream_id) != 0) { + return NGTCP2_ERR_CALLBACK_FAILURE; + } + return 0; + }; + + auto extendMaxRemoteStreamsBidi = [](ngtcp2_conn *, uint64_t max_streams, void *user_data) { + auto session = static_cast(user_data); + session->extend_max_remote_streams_bidi(max_streams); + return 0; + }; + + auto extendMaxStreamData = [](ngtcp2_conn *, int64_t stream_id, uint64_t max_data, void *user_data, void * /*stream_user_data*/) { + auto session = static_cast(user_data); + if (session->extend_max_stream_data(stream_id, max_data) != 0) { + return NGTCP2_ERR_CALLBACK_FAILURE; + } + return 0; + }; + + auto streamStopSending = [](ngtcp2_conn *, int64_t stream_id, uint64_t /*app_error_code*/, void *user_data, void * /*stream_user_data*/) { + auto session = static_cast(user_data); + if (session->on_stream_stop_sending(stream_id) != 0) { + return NGTCP2_ERR_CALLBACK_FAILURE; + } + return 0; + }; + + auto recvTxKey = [](ngtcp2_conn *, ngtcp2_encryption_level level, void *user_data) { + if (level != NGTCP2_ENCRYPTION_LEVEL_1RTT) { + return 0; + } + auto session = static_cast(user_data); + if (session->setup_httpconn() != 0) { + return NGTCP2_ERR_CALLBACK_FAILURE; + } + return 0; + }; + + ngtcp2_callbacks callbacks; + memset(&callbacks, 0, sizeof(callbacks)); + callbacks.client_initial = ngtcp2_crypto_client_initial_cb; + callbacks.recv_client_initial = ngtcp2_crypto_recv_client_initial_cb; + callbacks.recv_crypto_data = ngtcp2_crypto_recv_crypto_data_cb; + callbacks.handshake_completed = handshakeCompleted; + callbacks.encrypt = ngtcp2_crypto_encrypt_cb; + callbacks.decrypt = ngtcp2_crypto_decrypt_cb; + callbacks.hp_mask = ngtcp2_crypto_hp_mask_cb; + callbacks.recv_stream_data = recvStreamData; + callbacks.acked_stream_data_offset = ackedStreamDataOffset; + callbacks.stream_open = streamOpen; + callbacks.stream_close = streamClose; + callbacks.rand = randCb; + callbacks.get_new_connection_id = getNewConnectionId; + callbacks.remove_connection_id = removeConnectionId; + callbacks.update_key = updateKey; + callbacks.path_validation = pathValidation; + callbacks.stream_reset = streamReset; + callbacks.extend_max_remote_streams_bidi = extendMaxRemoteStreamsBidi; + callbacks.extend_max_stream_data = extendMaxStreamData; + callbacks.delete_crypto_aead_ctx = ngtcp2_crypto_delete_crypto_aead_ctx_cb; + callbacks.delete_crypto_cipher_ctx = ngtcp2_crypto_delete_crypto_cipher_ctx_cb; + callbacks.get_path_challenge_data = ngtcp2_crypto_get_path_challenge_data_cb; + callbacks.stream_stop_sending = streamStopSending; + callbacks.version_negotiation = ngtcp2_crypto_version_negotiation_cb; + callbacks.recv_tx_key = recvTxKey; + + _scid.datalen = NGTCP2_SV_SCIDLEN; + if (generate_secure_random({ _scid.data, _scid.datalen }) != 0) { + HTTP_DBG("Could not generate connection ID"); + return -1; + } + + ngtcp2_settings settings; + ngtcp2_settings_default(&settings); + settings.token = token.data(); + settings.tokenlen = token.size(); + settings.token_type = token_type; +#ifdef OPENCMW_DEBUG_HTTP + settings.log_printf = print_debug; +#endif + ngtcp2_transport_params params; + ngtcp2_transport_params_default(¶ms); + params.initial_max_stream_data_bidi_local = 65535; + params.initial_max_stream_data_bidi_remote = 65535; + params.initial_max_stream_data_uni = 65535; + params.initial_max_data = 128 * 1024; + params.initial_max_streams_bidi = 100; + params.initial_max_streams_uni = 3; + params.max_idle_timeout = 3600 * NGTCP2_SECONDS; + params.stateless_reset_token_present = 0; + params.active_connection_id_limit = 8; + + if (ocid) { + params.original_dcid = *ocid; + params.retry_scid = *scid; + params.retry_scid_present = 1; + } else { + params.original_dcid = *scid; + } + + params.original_dcid_present = 1; + + if (ngtcp2_crypto_generate_stateless_reset_token(params.stateless_reset_token, this->_sharedData->_static_secret.data(), this->_sharedData->_static_secret.size(), &_scid) != 0) { + return -1; + } + + auto path = ngtcp2_path{ + .local = { + .addr = const_cast(&local_addr.su.sa), + .addrlen = local_addr.len, + }, + .remote = { + .addr = const_cast(sa), + .addrlen = salen, + }, + .user_data = const_cast(&ep), + }; + if (auto rv = ngtcp2_conn_server_new(&_conn, dcid, &_scid, &path, version, &callbacks, &settings, ¶ms, nullptr, this); rv != 0) { + HTTP_DBG("ngtcp2_conn_server_new: {}", ngtcp2_strerror(rv)); + return -1; + } + + if (ngtcp2_crypto_ossl_init() != 0) { + HTTP_DBG("ngtcp2_crypto_ossl_init: failed"); + return -1; + } + + if (_tlsSession.init(tls_ctx, &_conn_ref) != 0) { + return -1; + } + + _tlsSession.enable_keylog(); + + ngtcp2_conn_set_tls_native_handle(_conn, _tlsSession.get_native_handle()); + + return 0; + } + + int on_stream_reset(int64_t stream_id) { + if (_httpconn) { + if (auto rv = nghttp3_conn_shutdown_stream_read(_httpconn, stream_id); + rv != 0) { + HTTP_DBG("Server::H3::on_stream_reset: nghttp3_conn_shutdown_stream_read: {}", nghttp3_strerror(rv)); + return -1; + } + } + return 0; + } + + int on_stream_close(int64_t stream_id, uint64_t app_error_code) { + HTTP_DBG("Server::H3::on_stream_close: stream_id={} app_error_code={}", stream_id, app_error_code); + + if (_httpconn) { + if (app_error_code == 0) { + app_error_code = NGHTTP3_H3_NO_ERROR; + } + auto rv = nghttp3_conn_close_stream(_httpconn, stream_id, app_error_code); + switch (rv) { + case 0: + break; + case NGHTTP3_ERR_STREAM_NOT_FOUND: + if (ngtcp2_is_bidi_stream(stream_id)) { + assert(!ngtcp2_conn_is_local_stream(_conn, stream_id)); + ngtcp2_conn_extend_max_streams_bidi(_conn, 1); + } + break; + default: + HTTP_DBG("Server::H3::on_stream_close: nghttp3_conn_close_stream: {}", nghttp3_strerror(rv)); + ngtcp2_ccerr_set_application_error(&_last_error, nghttp3_err_infer_quic_app_error_code(rv), nullptr, 0); + return -1; + } + } + + return 0; + } + + void extend_max_remote_streams_bidi(uint64_t max_streams) { + if (!_httpconn) { + return; + } + + nghttp3_conn_set_max_client_streams_bidi(_httpconn, max_streams); + } + + int extend_max_stream_data(int64_t stream_id, uint64_t max_data) { + std::ignore = max_data; + if (auto rv = nghttp3_conn_unblock_stream(_httpconn, stream_id); rv != 0) { + HTTP_DBG("Server::H3::extend_max_stream_data: nghttp3_conn_unblock_stream: {}", nghttp3_strerror(rv)); + return -1; + } + return 0; + } + + int setup_httpconn() { + if (_httpconn) { + return 0; + } + + if (const auto n = ngtcp2_conn_get_streams_uni_left(_conn); n < 3) { + HTTP_DBG("peer does not allow at least 3 unidirectional streams. (allows {})", n); + return -1; + } + + nghttp3_callbacks callbacks; + memset(&callbacks, 0, sizeof(callbacks)); +#ifdef OPENCMW_DEBUG_HTTP + nghttp3_set_debug_vprintf_callback(print_debug2); +#endif + callbacks.acked_stream_data = [](nghttp3_conn *, std::int64_t stream_id, std::uint64_t datalen, void *conn_user_data, void *) { + auto session = static_cast(conn_user_data); + return session->acked_stream_data(stream_id, datalen); + }; + callbacks.stream_close = [](nghttp3_conn *, std::int64_t stream_id, std::uint64_t error_code, void *conn_user_data, void *) { + std::ignore = error_code; + HTTP_DBG("Server::H3::stream_close: stream_id={} error_code={}", stream_id, error_code); + auto session = static_cast(conn_user_data); + + if (ngtcp2_is_bidi_stream(stream_id)) { + assert(!ngtcp2_conn_is_local_stream(session->_conn, stream_id)); + ngtcp2_conn_extend_max_streams_bidi(session->_conn, 1); + } + + return 0; + }; + callbacks.recv_data = [](nghttp3_conn *, std::int64_t stream_id, const uint8_t *data, std::size_t datalen, void *conn_user_data, void * /*stream_user_data*/) { + HTTP_DBG("Server::H3::recv_data: stream_id={} datalen={}", stream_id, datalen); + auto dataView = std::string_view(reinterpret_cast(data), datalen); + auto session = static_cast(conn_user_data); + session->addData(stream_id, dataView); + return 0; + }; + callbacks.deferred_consume = [](nghttp3_conn *, std::int64_t stream_id, std::size_t datalen, void * /*conn_user_data*/, void *) { + std::ignore = stream_id; + std::ignore = datalen; + HTTP_DBG("Server::H3::deferred_consume: stream_id={} datalen={}", stream_id, datalen); + return 0; + }; + callbacks.begin_headers = [](nghttp3_conn *, std::int64_t stream_id, void * /*conn_user_data*/, void *) { + std::ignore = stream_id; + HTTP_DBG("Server::H3::begin_headers: stream_id={}", stream_id); + return 0; + }; + callbacks.recv_header = [](nghttp3_conn *, std::int64_t stream_id, std::int32_t token, nghttp3_rcbuf *name, nghttp3_rcbuf *value, uint8_t /*flags*/, void *conn_user_data, void *) { + std::ignore = token; + auto nameView = as_view(name); + auto valueView = as_view(value); + HTTP_DBG("Server::H3::recv_header: stream_id={} token={} name={} value={}", stream_id, token, nameView, valueView); + auto session = static_cast(conn_user_data); + session->addHeader(stream_id, nameView, valueView); + return 0; + }; + callbacks.end_headers = [](nghttp3_conn *, std::int64_t stream_id, int fin, void * /*conn_user_data*/, void *) { + std::ignore = stream_id; + std::ignore = fin; + HTTP_DBG("Server::H3::end_headers: stream_id={} fin={}", stream_id, fin); + return 0; + }; + callbacks.stop_sending = [](nghttp3_conn *, std::int64_t stream_id, std::uint64_t app_error_code, void * /*conn_user_data*/, void *) { + std::ignore = stream_id; + std::ignore = app_error_code; + HTTP_DBG("Server::H3::stop_sending: stream_id={} error_code={}", stream_id, app_error_code); + return 0; + }; + callbacks.end_stream = [](nghttp3_conn *, std::int64_t stream_id, void *conn_user_data, void *) { + HTTP_DBG("Server::H3::end_stream: stream_id={}", stream_id); + auto session = static_cast(conn_user_data); + session->processCompletedRequest(stream_id); + return 0; + }; + callbacks.reset_stream = [](nghttp3_conn *, std::int64_t stream_id, std::uint64_t error_code, void * /*conn_user_data*/, void *) { + std::ignore = stream_id; + std::ignore = error_code; + HTTP_DBG("Server::H3::reset_stream: stream_id={} error_code={}", stream_id, error_code); + return 0; + }; + callbacks.shutdown = [](nghttp3_conn *, std::int64_t id, void * /*conn_user_data*/) { + std::ignore = id; + HTTP_DBG("Server::H3::shutdown: id={}", id); + return 0; + }; + callbacks.recv_settings = [](nghttp3_conn *, const nghttp3_settings *, void * /*conn_user_data*/) { + HTTP_DBG("Server::H3::recv_settings"); + return 0; + }; + + nghttp3_settings settings; + nghttp3_settings_default(&settings); + settings.qpack_max_dtable_capacity = 4096; + settings.qpack_blocked_streams = 100; + + auto mem = nghttp3_mem_default(); + + if (auto rv = nghttp3_conn_server_new(&_httpconn, &callbacks, &settings, mem, this); rv != 0) { + HTTP_DBG("nghttp3_conn_server_new: {}", nghttp3_strerror(rv)); + return -1; + } + + auto params = ngtcp2_conn_get_local_transport_params(_conn); + + nghttp3_conn_set_max_client_streams_bidi(_httpconn, params->initial_max_streams_bidi); + + int64_t ctrl_stream_id; + + if (auto rv = ngtcp2_conn_open_uni_stream(_conn, &ctrl_stream_id, nullptr); rv != 0) { + HTTP_DBG("ngtcp2_conn_open_uni_stream: {}", ngtcp2_strerror(rv)); + return -1; + } + if (auto rv = nghttp3_conn_bind_control_stream(_httpconn, ctrl_stream_id); rv != 0) { + HTTP_DBG("nghttp3_conn_bind_control_stream: {}", nghttp3_strerror(rv)); + return -1; + } + + HTTP_DBG("Server::H3::setup_httpconn: stream_id={}", ctrl_stream_id); + + int64_t qpack_enc_stream_id, qpack_dec_stream_id; + + if (auto rv = ngtcp2_conn_open_uni_stream(_conn, &qpack_enc_stream_id, nullptr); rv != 0) { + HTTP_DBG("ngtcp2_conn_open_uni_stream: {}", ngtcp2_strerror(rv)); + return -1; + } + + if (auto rv = ngtcp2_conn_open_uni_stream(_conn, &qpack_dec_stream_id, nullptr); rv != 0) { + HTTP_DBG("ngtcp2_conn_open_uni_stream: {}", ngtcp2_strerror(rv)); + return -1; + } + + if (auto rv = nghttp3_conn_bind_qpack_streams(_httpconn, qpack_enc_stream_id, qpack_dec_stream_id); rv != 0) { + HTTP_DBG("nghttp3_conn_bind_qpack_streams: {}", nghttp3_strerror(rv)); + return -1; + } + + HTTP_DBG("Server::H3::setup_httpconn: qpack streams encoder={} decoder={}", qpack_enc_stream_id, qpack_dec_stream_id); + return 0; + } + + int acked_stream_data(std::int64_t stream_id, std::size_t datalen) { + HTTP_DBG("Server::H3::acked_stream_data: stream_id={} datalen={}", stream_id, datalen); + auto it = this->_responsesByStreamId.find(stream_id); + assert(it != this->_responsesByStreamId.end()); + auto &response = it->second; + if (response.restResponse.bodyReader) { + while (datalen > 0) { + assert(!response.bodyReaderChunks.empty()); + auto &front = response.bodyReaderChunks.front(); + const auto toAck = std::min(datalen, front->unacked); + front->unacked -= toAck; + datalen -= toAck; + HTTP_DBG("Server::H3::acked_stream_data: stream_id={} acked {} bytes; left={}", stream_id, toAck, front->unacked); + if (front->unacked == 0) { + this->_sharedData->releaseChunk(std::move(response.bodyReaderChunks.front())); + response.bodyReaderChunks.pop_front(); + } + HTTP_DBG("Server::H3::acked_stream_data: stream_id={} chunks={}", stream_id, response.bodyReaderChunks.size()); + } + } else { + // BodyBuffer/BodyView, nothing to do + HTTP_DBG("Server::H3::acked_stream_data: stream_id={} acked {}", stream_id, datalen); + } + + return 0; + } + + int acked_stream_data_offset(int64_t stream_id, uint64_t datalen) { + HTTP_DBG("Server::H3::acked_stream_data_offset: stream_id={} datalen={}", stream_id, datalen); + if (!_httpconn) { + return 0; + } + + if (auto rv = nghttp3_conn_add_ack_offset(_httpconn, stream_id, datalen); rv != 0) { + HTTP_DBG("nghttp3_conn_add_ack_offset: {}", nghttp3_strerror(rv)); + return -1; + } + + return 0; + } + + int handshake_completed() { + if (_tlsSession.send_session_ticket() != 0) { + HTTP_DBG("Unable to send session ticket"); + } + + std::array token; + + auto path = ngtcp2_conn_get_path(_conn); + auto t = std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(); + + auto tokenlen = ngtcp2_crypto_generate_regular_token(token.data(), this->_sharedData->_static_secret.data(), this->_sharedData->_static_secret.size(), path->remote.addr, path->remote.addrlen, static_cast(t)); + if (tokenlen < 0) { + HTTP_DBG("Server::H3::handshake_completed: ngtcp2_crypto_generate_regular_token failed"); + return 0; + } + + if (auto rv = ngtcp2_conn_submit_new_token(_conn, token.data(), static_cast(tokenlen)); rv != 0) { + HTTP_DBG("Server::H3::handshake_completed: ngtcp2_conn_submit_new_token failed: {}", ngtcp2_strerror(rv)); + return -1; + } + + return 0; + } + + int on_read(const Endpoint &ep, const Address &local_addr, const sockaddr *sa, socklen_t salen, const ngtcp2_pkt_info *pi, std::span data) { + auto path = ngtcp2_path{ + .local = { + .addr = const_cast(&local_addr.su.sa), + .addrlen = local_addr.len, + }, + .remote = { + .addr = const_cast(sa), + .addrlen = salen, + }, + .user_data = const_cast(&ep), + }; + if (auto rv = ngtcp2_conn_read_pkt(_conn, &path, pi, data.data(), data.size(), timestamp()); rv != 0) { + switch (rv) { + case NGTCP2_ERR_DRAINING: + start_draining_period(); + return NETWORK_ERR_CLOSE_WAIT; + case NGTCP2_ERR_RETRY: + return NETWORK_ERR_RETRY; + case NGTCP2_ERR_DROP_CONN: + return NETWORK_ERR_DROP_CONN; + case NGTCP2_ERR_CRYPTO: + if (!_last_error.error_code) { + ngtcp2_ccerr_set_tls_alert(&_last_error, ngtcp2_conn_get_tls_alert(_conn), nullptr, 0); + } + break; + default: + if (!_last_error.error_code) { + ngtcp2_ccerr_set_liberr(&_last_error, rv, nullptr, 0); + } + } + return handle_error(); + } + + return 0; + } + + void start_draining_period() { + HTTP_DBG("Server::H3::start_draining_period"); + _state = State::Draining; + } + + int on_stream_stop_sending(std::int64_t stream_id) { + if (!_httpconn) { + return 0; + } + + if (auto rv = nghttp3_conn_shutdown_stream_read(_httpconn, stream_id); rv != 0) { + HTTP_DBG("nghttp3_conn_shutdown_stream_read: {}", nghttp3_strerror(rv)); + return -1; + } + + return 0; + } + + int send_blocked_packet() { + assert(_tx.send_blocked); + + for (; _tx.num_blocked_sent < _tx.num_blocked; ++_tx.num_blocked_sent) { + auto &p = _tx.blocked[_tx.num_blocked_sent]; + + ngtcp2_addr local_addr{ + .addr = &p.local_addr.su.sa, + .addrlen = p.local_addr.len, + }; + ngtcp2_addr remote_addr{ + .addr = &p.remote_addr.su.sa, + .addrlen = p.remote_addr.len, + }; + + auto [rest, rv] = server_->send_packet(*p.endpoint, _no_gso, local_addr, remote_addr, p.ecn, p.data, p.gso_size); + if (rv != 0) { + assert(NETWORK_ERR_SEND_BLOCKED == rv); + p.data = rest; + return 0; + } + } + + _tx.send_blocked = false; + _tx.num_blocked = 0; + _tx.num_blocked_sent = 0; + + return 0; + } + + int on_write() { + if (ngtcp2_conn_in_closing_period(_conn) || ngtcp2_conn_in_draining_period(_conn)) { + return 0; + } + + if (_tx.send_blocked) { + if (auto rv = send_blocked_packet(); rv != 0) { + return rv; + } + + if (_tx.send_blocked) { + return 0; + } + } + + if (auto rv = write_streams(); rv != 0) { + return rv; + } + + return 0; + } + + void write_handler() { + // this tries to mimick the state handling from the example, where different ev callbacks are set at different times + + switch (_state) { + case State::Initial: // writecb + switch (on_write()) { + case 0: + case NETWORK_ERR_CLOSE_WAIT: + return; + default: + server_->remove(_conn); + } + break; + case State::Closing: + case State::Draining: { // close_waitcb + if (ngtcp2_conn_in_closing_period(_conn)) { + HTTP_DBG("Server::H3::write_handler: closing period is over"); + server_->remove(_conn); + return; + } + if (ngtcp2_conn_in_draining_period(_conn)) { + HTTP_DBG("Server::H3::write_handler: draining period is over"); + server_->remove(_conn); + return; + } + + assert(0); + } break; + } + } + + void on_send_blocked(Endpoint &ep, const ngtcp2_addr &local_addr, const ngtcp2_addr &remote_addr, unsigned int ecn, std::span data, size_t gso_size) { + assert(_tx.num_blocked || !_tx.send_blocked); + assert(_tx.num_blocked < 2); + assert(gso_size); + + _tx.send_blocked = true; + + auto &p = _tx.blocked[_tx.num_blocked++]; + + memcpy(&p.local_addr.su, local_addr.addr, local_addr.addrlen); + memcpy(&p.remote_addr.su, remote_addr.addr, remote_addr.addrlen); + + p.local_addr.len = local_addr.addrlen; + p.remote_addr.len = remote_addr.addrlen; + p.endpoint = &ep; + p.ecn = ecn; + p.data = data; + p.gso_size = gso_size; + } + + int write_streams() { + std::array vec; + ngtcp2_path_storage ps, prev_ps; + uint32_t prev_ecn = 0; + auto max_udp_payload_size = ngtcp2_conn_get_max_tx_udp_payload_size(_conn); + auto path_max_udp_payload_size = ngtcp2_conn_get_path_max_tx_udp_payload_size(_conn); + ngtcp2_pkt_info pi; + size_t gso_size = 0; + auto ts = timestamp(); + auto txbuf = std::span{ _tx.data.get(), std::max(ngtcp2_conn_get_send_quantum(_conn), path_max_udp_payload_size) }; + auto buf = txbuf; + + ngtcp2_path_storage_zero(&ps); + ngtcp2_path_storage_zero(&prev_ps); + + for (;;) { + int64_t stream_id = -1; + int fin = 0; + nghttp3_ssize sveccnt = 0; + + if (_httpconn && ngtcp2_conn_get_max_data_left(_conn)) { + sveccnt = nghttp3_conn_writev_stream(_httpconn, &stream_id, &fin, vec.data(), vec.size()); + if (sveccnt < 0) { + HTTP_DBG("nghttp3_conn_writev_stream: {}", nghttp3_strerror(static_cast(sveccnt))); + ngtcp2_ccerr_set_application_error(&_last_error, nghttp3_err_infer_quic_app_error_code(static_cast(sveccnt)), nullptr, 0); + return handle_error(); + } + } + + ngtcp2_ssize ndatalen; + auto v = vec.data(); + auto vcnt = static_cast(sveccnt); + + uint32_t flags = NGTCP2_WRITE_STREAM_FLAG_MORE; + if (fin) { + flags |= NGTCP2_WRITE_STREAM_FLAG_FIN; + } + + auto buflen = buf.size() >= max_udp_payload_size + ? max_udp_payload_size + : path_max_udp_payload_size; + + auto nwrite = ngtcp2_conn_writev_stream(_conn, &ps.path, &pi, buf.data(), buflen, &ndatalen, flags, stream_id, reinterpret_cast(v), vcnt, ts); + if (nwrite < 0) { + switch (nwrite) { + case NGTCP2_ERR_STREAM_DATA_BLOCKED: + assert(ndatalen == -1); + nghttp3_conn_block_stream(_httpconn, stream_id); + continue; + case NGTCP2_ERR_STREAM_SHUT_WR: + assert(ndatalen == -1); + nghttp3_conn_shutdown_stream_write(_httpconn, stream_id); + continue; + case NGTCP2_ERR_WRITE_MORE: + assert(ndatalen >= 0); + if (auto rv = nghttp3_conn_add_write_offset(_httpconn, stream_id, static_cast(ndatalen)); rv != 0) { + HTTP_DBG("nghttp3_conn_add_write_offset: {}", nghttp3_strerror(rv)); + ngtcp2_ccerr_set_application_error(&_last_error, nghttp3_err_infer_quic_app_error_code(rv), nullptr, 0); + return handle_error(); + } + continue; + } + + assert(ndatalen == -1); + + HTTP_DBG("ngtcp2_conn_writev_stream: {}", ngtcp2_strerror(static_cast(nwrite))); + ngtcp2_ccerr_set_liberr(&_last_error, static_cast(nwrite), nullptr, 0); + return handle_error(); + } else if (ndatalen >= 0) { + if (auto rv = nghttp3_conn_add_write_offset(_httpconn, stream_id, static_cast(ndatalen)); rv != 0) { + HTTP_DBG("nghttp3_conn_add_write_offset: {}", nghttp3_strerror(rv)); + ngtcp2_ccerr_set_application_error(&_last_error, nghttp3_err_infer_quic_app_error_code(rv), nullptr, 0); + return handle_error(); + } + } + + if (nwrite == 0) { + auto data = std::span{ std::begin(txbuf), std::begin(buf) }; + if (!data.empty()) { + auto &ep = *static_cast(prev_ps.path.user_data); + + if (auto [rest, rv] = server_->send_packet(ep, _no_gso, prev_ps.path.local, prev_ps.path.remote, prev_ecn, data, gso_size); rv != NETWORK_ERR_OK) { + assert(NETWORK_ERR_SEND_BLOCKED == rv); + + on_send_blocked(ep, prev_ps.path.local, prev_ps.path.remote, prev_ecn, rest, gso_size); + } + } + + // We are congestion limited. + ngtcp2_conn_update_pkt_tx_time(_conn, ts); + return 0; + } + + auto last_pkt = std::begin(buf); + + buf = buf.subspan(static_cast(nwrite)); + + if (last_pkt == std::begin(txbuf)) { + ngtcp2_path_copy(&prev_ps.path, &ps.path); + prev_ecn = pi.ecn; + gso_size = static_cast(nwrite); + } else if (!ngtcp2_path_eq(&prev_ps.path, &ps.path) || prev_ecn != pi.ecn || static_cast(nwrite) > gso_size || (gso_size > path_max_udp_payload_size && static_cast(nwrite) != gso_size)) { + auto &ep = *static_cast(prev_ps.path.user_data); + auto data = std::span{ std::begin(txbuf), last_pkt }; + + if (auto [rest, rv] = server_->send_packet(ep, _no_gso, prev_ps.path.local, prev_ps.path.remote, prev_ecn, data, gso_size); rv != 0) { + assert(NETWORK_ERR_SEND_BLOCKED == rv); + + on_send_blocked(ep, prev_ps.path.local, prev_ps.path.remote, prev_ecn, rest, gso_size); + + data = std::span{ last_pkt, std::begin(buf) }; + on_send_blocked(*static_cast(ps.path.user_data), ps.path.local, ps.path.remote, pi.ecn, data, data.size()); + } + + ngtcp2_conn_update_pkt_tx_time(_conn, ts); + return 0; + } + + if (buf.size() < path_max_udp_payload_size || static_cast(nwrite) < gso_size) { + auto &ep = *static_cast(ps.path.user_data); + auto data = std::span{ std::begin(txbuf), std::begin(buf) }; + + if (auto [rest, rv] = server_->send_packet(ep, _no_gso, ps.path.local, ps.path.remote, pi.ecn, data, gso_size); rv != 0) { + assert(NETWORK_ERR_SEND_BLOCKED == rv); + + on_send_blocked(ep, ps.path.local, ps.path.remote, pi.ecn, rest, gso_size); + } + + ngtcp2_conn_update_pkt_tx_time(_conn, ts); + return 0; + } + } + } + + int recv_stream_data(uint32_t flags, int64_t stream_id, std::span data) { + HTTP_DBG("Server::QUIC::recv_stream_data: stream_id={} datalen={}", stream_id, data.size()); + + auto nconsumed = nghttp3_conn_read_stream(_httpconn, stream_id, data.data(), data.size(), flags & NGTCP2_STREAM_DATA_FLAG_FIN); + if (nconsumed < 0) { + HTTP_DBG("nghttp3_conn_read_stream: {}", nghttp3_strerror(static_cast(nconsumed))); + ngtcp2_ccerr_set_application_error(&_last_error, nghttp3_err_infer_quic_app_error_code(static_cast(nconsumed)), nullptr, 0); + return -1; + } + + ngtcp2_conn_extend_max_stream_offset(_conn, stream_id, static_cast(nconsumed)); + ngtcp2_conn_extend_max_offset(_conn, static_cast(nconsumed)); + + return 0; + } + + int send_conn_close() { + HTTP_DBG("Server::QUIC::send_conn_close"); + assert(_conn_closebuf && _conn_closebuf->size()); + assert(_conn); + assert(!ngtcp2_conn_in_draining_period(_conn)); + + auto path = ngtcp2_conn_get_path(_conn); + + return server_->send_packet(*static_cast(path->user_data), path->local, path->remote, /* ecn = */ 0, _conn_closebuf->data()); + } + + int update_key(uint8_t *rx_secret, uint8_t *tx_secret, ngtcp2_crypto_aead_ctx *rx_aead_ctx, uint8_t *rx_iv, ngtcp2_crypto_aead_ctx *tx_aead_ctx, uint8_t *tx_iv, const uint8_t *current_rx_secret, const uint8_t *current_tx_secret, size_t secretlen) { + std::array rx_key, tx_key; + if (ngtcp2_crypto_update_key(_conn, rx_secret, tx_secret, rx_aead_ctx, rx_key.data(), rx_iv, tx_aead_ctx, tx_key.data(), tx_iv, current_rx_secret, current_tx_secret, secretlen) + != 0) { + return -1; + } + return 0; + } + + int send_conn_close(const Endpoint &ep, const Address &local_addr, const sockaddr *sa, socklen_t salen, const ngtcp2_pkt_info * /*pi*/, std::span data) { + assert(_conn_closebuf && _conn_closebuf->size()); + + _close_wait.bytes_recv += data.size(); + ++_close_wait.num_pkts_recv; + + if (_close_wait.num_pkts_recv < _close_wait.next_pkts_recv || _close_wait.bytes_recv * 3 < _close_wait.bytes_sent + _conn_closebuf->size()) { + return 0; + } + + auto path = ngtcp2_path{ + .local = { + .addr = const_cast(&local_addr.su.sa), + .addrlen = local_addr.len, + }, + .remote = { + .addr = const_cast(sa), + .addrlen = salen, + }, + .user_data = const_cast(&ep), + }; + + auto rv = server_->send_packet(ep, path.local, path.remote, + /* ecn = */ 0, _conn_closebuf->data()); + if (rv != 0) { + return rv; + } + + _close_wait.bytes_sent += _conn_closebuf->size(); + _close_wait.next_pkts_recv *= 2; + + return 0; + } + + int start_closing_period() { + if (!_conn || ngtcp2_conn_in_closing_period(_conn) || ngtcp2_conn_in_draining_period(_conn)) { + return 0; + } + + _state = State::Closing; + + _conn_closebuf = std::make_unique(NGTCP2_MAX_UDP_PAYLOAD_SIZE); + + ngtcp2_path_storage ps; + + ngtcp2_path_storage_zero(&ps); + + ngtcp2_pkt_info pi; + auto n = ngtcp2_conn_write_connection_close( + _conn, &ps.path, &pi, _conn_closebuf->wpos(), _conn_closebuf->left(), + &_last_error, timestamp()); + if (n < 0) { + HTTP_DBG("ngtcp2_conn_write_connection_close: {}", ngtcp2_strerror(static_cast(n))); + return -1; + } + + if (n == 0) { + return 0; + } + + _conn_closebuf->push(static_cast(n)); + return 0; + } + + int handle_error() { + if (_last_error.type == NGTCP2_CCERR_TYPE_IDLE_CLOSE) { + return -1; + } + + if (start_closing_period() != 0) { + return -1; + } + + if (ngtcp2_conn_in_draining_period(_conn)) { + return NETWORK_ERR_CLOSE_WAIT; + } + + if (auto rv = send_conn_close(); rv != NETWORK_ERR_OK) { + return rv; + } + + return NETWORK_ERR_CLOSE_WAIT; + } +}; + +struct Http3ServerSocket { + int fd = -1; + Address addr; + + Http3ServerSocket() = default; + + Http3ServerSocket(int fd_, Address addr_) + : fd(fd_), addr(std::move(addr_)) { + } + + Http3ServerSocket(const Http3ServerSocket &) = delete; + Http3ServerSocket &operator=(const Http3ServerSocket &) = delete; + + Http3ServerSocket(Http3ServerSocket &&other) noexcept + : fd(std::exchange(other.fd, -1)), addr(other.addr) {} + + Http3ServerSocket &operator=(Http3ServerSocket &&other) noexcept { + if (this != &other) { + close(); + std::swap(fd, other.fd); + addr = other.addr; + } + return *this; + } + + void close() { + if (fd != -1) { + ::close(fd); + fd = -1; + } + } + + static std::expected create(uint16_t port) { + Address address; + auto maybeFd = create_sock(address, "*", std::to_string(port).c_str(), AF_INET); + if (!maybeFd) { + return std::unexpected(std::format("Failed to create HTTP/3 server socket: {}", maybeFd.error())); + } + return Http3ServerSocket{ maybeFd.value(), address }; + } + + ~Http3ServerSocket() { + close(); + } +}; + +inline std::expected createTcpServerSocket(SSL_CTX *ssl_ctx, uint16_t port) { + auto ssl = SSL_Ptr(nullptr, SSL_free); + if (ssl_ctx) { + auto maybeSsl = create_ssl(ssl_ctx); + if (!maybeSsl) { + return std::unexpected(std::format("Failed to set up TCP server socket: {}", maybeSsl.error())); + } + ssl = std::move(maybeSsl.value()); + } + + auto serverSocket = TcpSocket::create(std::move(ssl), socket(AF_INET, SOCK_STREAM, 0)); + if (!serverSocket) { + return std::unexpected(serverSocket.error()); + } + + int reuseFlag = 1; + if (setsockopt(serverSocket->fd, SOL_SOCKET, SO_REUSEADDR, &reuseFlag, sizeof(reuseFlag)) < 0) { + return std::unexpected(std::format("setsockopt(SO_REUSEADDR) failed: {}", strerror(errno))); + } + + sockaddr_in address; + memset(&address, 0, sizeof(address)); + address.sin_family = AF_INET; + address.sin_addr.s_addr = INADDR_ANY; + address.sin_port = htons(port); + memset(address.sin_zero, 0, sizeof(address.sin_zero)); + if (::bind(serverSocket->fd, reinterpret_cast(&address), sizeof(address)) < 0) { + return std::unexpected(std::format("Bind failed: {}", strerror(errno))); + } + + if (listen(serverSocket->fd, 32) < 0) { + return std::unexpected(std::format("Listen failed: {}", strerror(errno))); + } + + return serverSocket; +} + +struct RestServer { + TcpSocket _tcpServerSocket; + Http3ServerSocket _quicServerSocket; + SSL_CTX_Ptr _sslCtxTcp = SSL_CTX_Ptr(nullptr, SSL_CTX_free); + EVP_PKEY_Ptr _key = EVP_PKEY_Ptr(nullptr, EVP_PKEY_free); + X509_Ptr _cert = X509_Ptr(nullptr, X509_free); + std::shared_ptr _sharedData = std::make_shared(); + std::map> _h2Sessions; + std::unordered_map>> _h3Sessions; + std::vector _touchedSessions; + std::vector _sessionsToRemove; + IdGenerator _requestIdGenerator; + + Endpoint _endpoint; + TLSServerContext _tls_ctx; + std::mt19937 _randgen; + + std::size_t _stateless_reset_bucket = NGTCP2_STATELESS_RESET_BURST; + + RestServer() = default; + RestServer(const RestServer &) = delete; + RestServer &operator=(const RestServer &) = delete; + RestServer(RestServer &&) = default; + RestServer &operator=(RestServer &&) = default; + + RestServer(SSL_CTX_Ptr sslCtxTcp, EVP_PKEY_Ptr key, X509_Ptr cert) + : _sslCtxTcp(std::move(sslCtxTcp)), _key(std::move(key)), _cert(std::move(cert)) { + if (_sslCtxTcp) { + SSL_library_init(); + SSL_load_error_strings(); + OpenSSL_add_all_algorithms(); + } + } + + static std::expected unencrypted() { + return RestServer(SSL_CTX_Ptr(nullptr, SSL_CTX_free), EVP_PKEY_Ptr(nullptr, EVP_PKEY_free), X509_Ptr(nullptr, X509_free)); + } + + static std::expected sslWithBuffers(std::string_view certBuffer, std::string_view keyBuffer) { + auto maybeCert = readServerCertificateFromBuffer(certBuffer); + if (!maybeCert) { + return std::unexpected(maybeCert.error()); + } + auto maybeKey = readServerPrivateKeyFromBuffer(keyBuffer); + if (!maybeKey) { + return std::unexpected(maybeKey.error()); + } + auto maybeSslCtxTcp = create_ssl_ctx(maybeKey->get(), maybeCert->get()); + if (!maybeSslCtxTcp) { + return std::unexpected(maybeSslCtxTcp.error()); + } + return RestServer(std::move(maybeSslCtxTcp.value()), std::move(maybeKey.value()), std::move(maybeCert.value())); + } + + static std::expected sslWithPaths(std::filesystem::path certPath, std::filesystem::path keyPath) { + auto maybeCert = readServerCertificateFromFile(certPath); + if (!maybeCert) { + return std::unexpected(maybeCert.error()); + } + auto maybeKey = readServerPrivateKeyFromFile(keyPath); + if (!maybeKey) { + return std::unexpected(maybeKey.error()); + } + auto maybeSslCtxTcp = create_ssl_ctx(maybeKey->get(), maybeCert->get()); + if (!maybeSslCtxTcp) { + return std::unexpected(maybeSslCtxTcp.error()); + } + + return RestServer(std::move(maybeSslCtxTcp.value()), std::move(maybeKey.value()), std::move(maybeCert.value())); + } + + void setHandlers(std::vector handlers) { + _sharedData->_handlers = std::move(handlers); + } + + void handleResponse(Message &&message) { + auto view = message.clientRequestID.asString(); + std::uint64_t id; + const auto ec = std::from_chars(view.begin(), view.end(), id); + if (ec.ec != std::errc{}) { + HTTP_DBG("Failed to parse request ID: '{}'", view); + return; + } + + auto handle = [&message, id](auto &sessions) { + auto matchesId = [id](const auto &pendingRequest) { return std::get<0>(pendingRequest) == id; }; + + auto it = std::ranges::find_if(sessions, [matchesId](const auto &session) { + return std::ranges::find_if(session.second->_pendingRequests, matchesId) != session.second->_pendingRequests.end(); + }); + + if (it != sessions.end()) { + auto &session = it->second; + auto pendingIt = std::ranges::find_if(session->_pendingRequests, matchesId); + const auto &[reqId, streamId] = *pendingIt; + const auto code = message.error.empty() ? kHttpOk : kHttpError; + session->sendResponse(streamId, code, std::move(message)); + session->_pendingRequests.erase(pendingIt); + return true; + }; + return false; + }; + + if (!handle(_h2Sessions)) { + handle(_h3Sessions); + } + } + + void handleNotification(const mdp::Topic &topic, Message &&message) { + const auto zmqTopic = topic.toZmqTopic(); + auto entryIt = _sharedData->_subscriptionCache.find(zmqTopic); + if (entryIt == _sharedData->_subscriptionCache.end()) { + HTTP_DBG("Server::handleNotification: No subscription for topic '{}'", zmqTopic); + return; + } + auto &entry = entryIt->second; + entry.add(std::move(message)); + const auto index = entry.lastIndex(); + + for (auto &session : _h2Sessions | std::views::values) { + session->handleNotification(zmqTopic, index, entry.messages.back()); + } + for (auto &session : _h3Sessions | std::views::values) { + session->handleNotification(zmqTopic, index, entry.messages.back()); + } + } + + void populatePollerItems(std::vector &items) { + if (_tcpServerSocket.fd != -1) { + items.push_back(zmq_pollitem_t{ nullptr, _tcpServerSocket.fd, ZMQ_POLLIN, 0 }); + } + if (_quicServerSocket.fd != -1) { + items.push_back(zmq_pollitem_t{ nullptr, _quicServerSocket.fd, ZMQ_POLLIN | ZMQ_POLLOUT, 0 }); + } + for (const auto &[_, session] : _h2Sessions) { + const auto wantsRead = session->wantsToRead(); + const auto wantsWrite = session->wantsToWrite(); + if (wantsRead || wantsWrite) { + items.emplace_back(nullptr, session->_socket.fd, static_cast((wantsRead ? ZMQ_POLLIN : 0) | (wantsWrite ? ZMQ_POLLOUT : 0)), 0); + } + } + } + + std::vector processReadWrite(int fd, bool read, bool write) { + if (fd == _tcpServerSocket.fd) { + auto maybeSocket = _tcpServerSocket.accept(_sslCtxTcp.get(), TcpSocket::None); + if (!maybeSocket) { + HTTP_DBG("Failed to accept client: {}", maybeSocket.error()); + return {}; + } + + auto clientSocket = std::move(maybeSocket.value()); + if (!clientSocket) { + return {}; + } + + auto newFd = clientSocket->fd; + + auto [newSessionIt, inserted] = _h2Sessions.try_emplace(newFd, std::make_unique(std::move(clientSocket.value()), _sharedData)); + assert(inserted); + auto &newSession = newSessionIt->second; + nghttp2_settings_entry iv[1] = { { NGHTTP2_SETTINGS_MAX_CONCURRENT_STREAMS, 3000 } }; + if (nghttp2_submit_settings(newSession->_session, NGHTTP2_FLAG_NONE, iv, 1) != 0) { + HTTP_DBG("nghttp2_submit_settings failed"); + } + return {}; + } + + if (fd == _quicServerSocket.fd) { + if (read) { + on_read(); + } + if (write) { + for (auto &session : _h3Sessions | std::views::values) { + session->write_handler(); + } + std::erase_if(_h3Sessions, [this](const auto &pair) { + return std::ranges::contains(_sessionsToRemove, pair.first); + }); + _sessionsToRemove.clear(); + } + std::vector messages; + for (const auto &dcid : _touchedSessions) { + auto sessionIt = _h3Sessions.find(dcid); + if (sessionIt == _h3Sessions.end()) { + continue; + } + auto ms = sessionIt->second->getMessages(_requestIdGenerator); + messages.insert(messages.end(), std::make_move_iterator(ms.begin()), std::make_move_iterator(ms.end())); + } + _touchedSessions.clear(); + return messages; + } + + auto sessionIt = _h2Sessions.find(fd); + assert(sessionIt != _h2Sessions.end()); + assert(sessionIt->second->_socket.fd == fd); + auto &session = sessionIt->second; + + if (session->_socket._state != TcpSocket::Connected) { + if (auto r = session->_socket.continueHandshake(); !r) { + HTTP_DBG("Handshake failed: {}", r.error()); + _h2Sessions.erase(sessionIt); + return {}; + } + return {}; + } + + if (write) { + if (!session->_writeBuffer.write(session->_session, session->_socket)) { + HTTP_DBG("Failed to write to peer"); + _h2Sessions.erase(sessionIt); + return {}; + } + } + + if (!read) { + return {}; + } + + bool mightHaveMore = true; + + while (mightHaveMore) { + std::array buffer; + ssize_t bytes_read = session->_socket.read(buffer.data(), sizeof(buffer)); + if (bytes_read <= 0 && errno != EAGAIN) { + if (bytes_read < 0) { + HTTP_DBG("Server::read failed: {} {}", bytes_read, session->_socket.lastError()); + } + _h2Sessions.erase(sessionIt); + return {}; + } + if (nghttp2_session_mem_recv2(session->_session, buffer.data(), static_cast(bytes_read)) < 0) { + HTTP_DBG("Server: nghttp2_session_mem_recv2 failed"); + _h2Sessions.erase(sessionIt); + return {}; + } + mightHaveMore = bytes_read == static_cast(buffer.size()); + } + + return session->getMessages(_requestIdGenerator); + } + + std::expected + bind(std::uint16_t port, int protocols) { + using enum majordomo::rest::Protocol; + if ((protocols & Http2) == 0 && (protocols & Http3) == 0) { + return std::unexpected("At least one protocol must be enabled (HTTP/2 or HTTP/3)"); + } + + if (_tcpServerSocket.fd != -1) { + return std::unexpected("Server already bound"); + } + + if ((protocols & Http2) != 0) { + auto tcpSocket = createTcpServerSocket(_sslCtxTcp.get(), port); + if (!tcpSocket) { + return std::unexpected(tcpSocket.error()); + } + _tcpServerSocket = std::move(tcpSocket.value()); + } + + if ((protocols & Http3) == 0) { + return {}; + } + + if (!_sslCtxTcp) { + return std::unexpected("HTTP/3 requires TLS"); + } + + auto quicSocket = Http3ServerSocket::create(port); + if (!quicSocket) { + return std::unexpected(quicSocket.error()); + } + if (_tls_ctx.init(_key, _cert, AppProtocol::H3) != 0) { + return std::unexpected("Failed to initialize TLS context for HTTP/3"); + } + _sharedData->_altSvcHeaderValue = std::format("h3=\":{}\"; ma=86400", port); + _sharedData->_altSvcHeader = nv(u8span("alt-svc"), u8span(_sharedData->_altSvcHeaderValue), NGHTTP2_NV_FLAG_NO_COPY_NAME | NGHTTP2_NV_FLAG_NO_COPY_VALUE); + _quicServerSocket = std::move(quicSocket.value()); + _endpoint.fd = _quicServerSocket.fd; + return {}; + } + + inline int send_version_negotiation(uint32_t version, std::span dcid, std::span scid, Endpoint &ep, const Address &local_addr, const sockaddr *sa, socklen_t salen) { + Buffer buf{ NGTCP2_MAX_UDP_PAYLOAD_SIZE }; + std::array sv; + + auto p = std::begin(sv); + + *p++ = generate_reserved_version(sa, salen, version); + + *p++ = NGTCP2_PROTO_VER_V1; + *p++ = NGTCP2_PROTO_VER_V2; + + auto nwrite = ngtcp2_pkt_write_version_negotiation(buf.wpos(), buf.left(), std::uniform_int_distribution()(_randgen), dcid.data(), dcid.size(), scid.data(), scid.size(), sv.data(), static_cast(p - std::begin(sv))); + if (nwrite < 0) { + HTTP_DBG("ngtcp2_pkt_write_version_negotiation: {}", ngtcp2_strerror(static_cast(nwrite))); + return -1; + } + + buf.push(static_cast(nwrite)); + + ngtcp2_addr laddr{ + .addr = const_cast(&local_addr.su.sa), + .addrlen = local_addr.len, + }; + ngtcp2_addr raddr{ + .addr = const_cast(sa), + .addrlen = salen, + }; + + if (send_packet(ep, laddr, raddr, /* ecn = */ 0, buf.data()) != NETWORK_ERR_OK) { + return -1; + } + + return 0; + } + + int on_read() { + sockaddr_union su; + std::array buf; + size_t pktcnt = 0; + ngtcp2_pkt_info pi; + + iovec msg_iov{ + .iov_base = buf.data(), + .iov_len = buf.size(), + }; + + uint8_t msg_ctrl[CMSG_SPACE(sizeof(int)) + CMSG_SPACE(sizeof(in6_pktinfo)) + CMSG_SPACE(sizeof(int))]; + + msghdr msg{ + .msg_name = &su, + .msg_namelen = sizeof(su), + .msg_iov = &msg_iov, + .msg_iovlen = 1, + .msg_control = msg_ctrl, + .msg_controllen = sizeof(msg_ctrl), + .msg_flags = 0, + }; + + for (; pktcnt < 10;) { + msg.msg_namelen = sizeof(su); + msg.msg_controllen = sizeof(msg_ctrl); + + auto nread = ::recvmsg(_quicServerSocket.fd, &msg, 0); + if (nread == -1) { + if (!(errno == EAGAIN || errno == ENOTCONN)) { + HTTP_DBG("recvmsg: {}", strerror(errno)); + } + return 0; + } + + // Packets less than 21 bytes never be a valid QUIC packet. + if (nread < 21) { + ++pktcnt; + + continue; + } + + pi.ecn = static_cast(msghdr_get_ecn(&msg, su.storage.ss_family)); + auto local_addr = msghdr_get_local_addr(&msg, su.storage.ss_family); + if (!local_addr) { + ++pktcnt; + HTTP_DBG("Unable to obtain local address"); + continue; + } + + auto gso_size = msghdr_get_udp_gro(&msg); + if (gso_size == 0) { + gso_size = static_cast(nread); + } + + set_port(*local_addr, _quicServerSocket.addr); + + auto data = std::span{ buf.data(), static_cast(nread) }; + + for (; !data.empty();) { + auto datalen = std::min(data.size(), gso_size); + + ++pktcnt; + + // Packets less than 21 bytes never be a valid QUIC packet. + if (datalen < 21) { + break; + } + + // Endpoint, Address kept for upstream with example code + + _endpoint.addr = *local_addr; + _endpoint.fd = _quicServerSocket.fd; + + auto dcid = read_pkt(_endpoint, _endpoint.addr, &su.sa, msg.msg_namelen, &pi, { data.data(), datalen }); + if (dcid) { + _touchedSessions.push_back(*dcid); + } + + data = data.subspan(datalen); + } + } + + return 0; + } + + int send_stateless_connection_close(const ngtcp2_pkt_hd *chd, Endpoint &ep, const Address &local_addr, const sockaddr *sa, socklen_t salen) { + HTTP_DBG("Server::QUIC::send_stateless_connection_close"); + Buffer buf{ NGTCP2_MAX_UDP_PAYLOAD_SIZE }; + + auto nwrite = ngtcp2_crypto_write_connection_close(buf.wpos(), buf.left(), chd->version, &chd->scid, &chd->dcid, NGTCP2_INVALID_TOKEN, nullptr, 0); + if (nwrite < 0) { + HTTP_DBG("ngtcp2_crypto_write_connection_close: {}", ngtcp2_strerror(static_cast(nwrite))); + return -1; + } + + buf.push(static_cast(nwrite)); + + ngtcp2_addr laddr{ + .addr = const_cast(&local_addr.su.sa), + .addrlen = local_addr.len, + }; + ngtcp2_addr raddr{ + .addr = const_cast(sa), + .addrlen = salen, + }; + + if (send_packet(ep, laddr, raddr, /* ecn = */ 0, buf.data()) != NETWORK_ERR_OK) { + return -1; + } + + return 0; + } + + int send_stateless_reset(size_t pktlen, std::span dcid, Endpoint &ep, const Address &local_addr, const sockaddr *sa, socklen_t salen) { + if (_stateless_reset_bucket == 0) { + return 0; + } + + --_stateless_reset_bucket; + + ngtcp2_cid cid; + + ngtcp2_cid_init(&cid, dcid.data(), dcid.size()); + + std::array token; + + if (ngtcp2_crypto_generate_stateless_reset_token(token.data(), _sharedData->_static_secret.data(), _sharedData->_static_secret.size(), &cid) != 0) { + return -1; + } + + // SCID + minimum expansion - NGTCP2_STATELESS_RESET_TOKENLEN + constexpr size_t max_rand_byteslen = NGTCP2_MAX_CIDLEN + 22 - NGTCP2_STATELESS_RESET_TOKENLEN; + + size_t rand_byteslen; + + if (pktlen <= 43) { + // As per https://datatracker.ietf.org/doc/html/rfc9000#section-10.3 + rand_byteslen = pktlen - NGTCP2_STATELESS_RESET_TOKENLEN - 1; + } else { + rand_byteslen = max_rand_byteslen; + } + + std::array rand_bytes; + + if (generate_secure_random({ rand_bytes.data(), rand_byteslen }) != 0) { + return -1; + } + + Buffer buf{ NGTCP2_MAX_UDP_PAYLOAD_SIZE }; + + auto nwrite = ngtcp2_pkt_write_stateless_reset(buf.wpos(), buf.left(), token.data(), rand_bytes.data(), rand_byteslen); + if (nwrite < 0) { + HTTP_DBG("ngtcp2_pkt_write_stateless_reset: {}", ngtcp2_strerror(static_cast(nwrite))); + return -1; + } + + buf.push(static_cast(nwrite)); + + ngtcp2_addr laddr{ + .addr = const_cast(&local_addr.su.sa), + .addrlen = local_addr.len, + }; + ngtcp2_addr raddr{ + .addr = const_cast(sa), + .addrlen = salen, + }; + + if (send_packet(ep, laddr, raddr, /* ecn = */ 0, buf.data()) != NETWORK_ERR_OK) { + return -1; + } + + return 0; + } + + int verify_retry_token(ngtcp2_cid *ocid, const ngtcp2_pkt_hd *hd, const sockaddr *sa, socklen_t salen) { + int rv; + + HTTP_DBG("Received Retry token from [{}]:{}", straddr(sa, salen), straddr(sa, salen)); + + auto t = std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count(); + + rv = ngtcp2_crypto_verify_retry_token2(ocid, hd->token, hd->tokenlen, _sharedData->_static_secret.data(), _sharedData->_static_secret.size(), hd->version, sa, salen, &hd->dcid, 10 * NGTCP2_SECONDS, static_cast(t)); + switch (rv) { + case 0: + break; + case NGTCP2_CRYPTO_ERR_VERIFY_TOKEN: + HTTP_DBG("Could not verify Retry token"); + return -1; + default: + HTTP_DBG("Could not verify Retry token: {}. Continue without the token", ngtcp2_strerror(rv)); + return 1; + } + + HTTP_DBG("Token was successfully validated"); + return 0; + } + + int verify_token(const ngtcp2_pkt_hd *hd, const sockaddr *sa, socklen_t salen) { + std::array host; + std::array port; + + if (auto rv = getnameinfo(sa, salen, host.data(), host.size(), port.data(), port.size(), NI_NUMERICHOST | NI_NUMERICSERV); rv != 0) { + HTTP_DBG("getnameinfo: {}", gai_strerror(rv)); + return -1; + } + + HTTP_DBG("Received token from [{}]:{}", host.data(), port.data()); + + auto t = std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(); + + if (ngtcp2_crypto_verify_regular_token(hd->token, hd->tokenlen, _sharedData->_static_secret.data(), _sharedData->_static_secret.size(), sa, salen, 3600 * NGTCP2_SECONDS, static_cast(t)) != 0) { + HTTP_DBG("Could not verify token"); + return -1; + } + + HTTP_DBG("Token was successfully validated"); + return 0; + } + + int send_packet(const Endpoint &ep, const ngtcp2_addr &local_addr, const ngtcp2_addr &remote_addr, unsigned int ecn, std::span data) { + auto no_gso = false; + auto [_, rv] = send_packet(ep, no_gso, local_addr, remote_addr, ecn, data, data.size()); + return rv; + } + + std::pair, int> send_packet(const Endpoint &ep, bool &no_gso, const ngtcp2_addr &local_addr, const ngtcp2_addr &remote_addr, unsigned int ecn, std::span data, size_t gso_size) { + assert(gso_size); + + if (no_gso && data.size() > gso_size) { + for (; !data.empty();) { + auto len = std::min(gso_size, data.size()); + + auto [_, rv] = send_packet(ep, no_gso, local_addr, remote_addr, ecn, { std::begin(data), len }, len); + if (rv != 0) { + return { data, rv }; + } + + data = data.subspan(len); + } + + return { {}, 0 }; + } + + iovec msg_iov{ + .iov_base = const_cast(data.data()), + .iov_len = data.size(), + }; + + uint8_t msg_ctrl[CMSG_SPACE(sizeof(int)) + CMSG_SPACE(sizeof(uint16_t)) + CMSG_SPACE(sizeof(in6_pktinfo))]{}; + + msghdr msg{ + .msg_name = const_cast(remote_addr.addr), + .msg_namelen = remote_addr.addrlen, + .msg_iov = &msg_iov, + .msg_iovlen = 1, + .msg_control = msg_ctrl, + .msg_controllen = sizeof(msg_ctrl), + .msg_flags = 0, + }; + + size_t controllen = 0; + + auto cm = CMSG_FIRSTHDR(&msg); + + switch (local_addr.addr->sa_family) { + case AF_INET: { + controllen += CMSG_SPACE(sizeof(in_pktinfo)); + cm->cmsg_level = IPPROTO_IP; + cm->cmsg_type = IP_PKTINFO; + cm->cmsg_len = CMSG_LEN(sizeof(in_pktinfo)); + auto addrin = reinterpret_cast(local_addr.addr); + in_pktinfo pktinfo; + pktinfo.ipi_spec_dst = addrin->sin_addr, + memcpy(CMSG_DATA(cm), &pktinfo, sizeof(pktinfo)); + + break; + } + case AF_INET6: { + controllen += CMSG_SPACE(sizeof(in6_pktinfo)); + cm->cmsg_level = IPPROTO_IPV6; + cm->cmsg_type = IPV6_PKTINFO; + cm->cmsg_len = CMSG_LEN(sizeof(in6_pktinfo)); + auto addrin = reinterpret_cast(local_addr.addr); + in6_pktinfo pktinfo; + pktinfo.ipi6_addr = addrin->sin6_addr; + memcpy(CMSG_DATA(cm), &pktinfo, sizeof(pktinfo)); + + break; + } + default: + assert(0); + } + + if (data.size() > gso_size) { + controllen += CMSG_SPACE(sizeof(uint16_t)); + cm = CMSG_NXTHDR(&msg, cm); + cm->cmsg_level = SOL_UDP; + cm->cmsg_type = UDP_SEGMENT; + cm->cmsg_len = CMSG_LEN(sizeof(uint16_t)); + uint16_t n = static_cast(gso_size); + memcpy(CMSG_DATA(cm), &n, sizeof(n)); + } + + controllen += CMSG_SPACE(sizeof(int)); + cm = CMSG_NXTHDR(&msg, cm); + cm->cmsg_len = CMSG_LEN(sizeof(int)); + memcpy(CMSG_DATA(cm), &ecn, sizeof(ecn)); + + switch (local_addr.addr->sa_family) { + case AF_INET: + cm->cmsg_level = IPPROTO_IP; + cm->cmsg_type = IP_TOS; + + break; + case AF_INET6: + cm->cmsg_level = IPPROTO_IPV6; + cm->cmsg_type = IPV6_TCLASS; + + break; + default: + assert(0); + } + + msg.msg_controllen = controllen; + + ssize_t nwrite = 0; + + do { + nwrite = sendmsg(ep.fd, &msg, 0); + } while (nwrite == -1 && errno == EINTR); + + if (nwrite == -1) { + switch (errno) { + case EAGAIN: +#if EAGAIN != EWOULDBLOCK + case EWOULDBLOCK: +#endif // EAGAIN != EWOULDBLOCK + return { data, NETWORK_ERR_SEND_BLOCKED }; +#ifdef UDP_SEGMENT + case EIO: + if (data.size() > gso_size) { + // GSO failure; send each packet in a separate sendmsg call. + HTTP_DBG("sendmsg: disabling GSO due to {}", strerror(errno)); + no_gso = true; + return send_packet(ep, no_gso, local_addr, remote_addr, ecn, data, gso_size); + } + break; +#endif // defined(UDP_SEGMENT) + } + + HTTP_DBG("sendmsg on fd {}: {} ({})", ep.fd, strerror(errno), errno); + // TODO We have packet which is expected to fail to send (e.g., + // path validation to old path). + return { {}, NETWORK_ERR_OK }; + } + + return { {}, NETWORK_ERR_OK }; + } + + int send_retry(const ngtcp2_pkt_hd *chd, Endpoint &ep, const Address &local_addr, const sockaddr *sa, socklen_t salen, size_t max_pktlen) { + std::array host; + std::array port; + + if (auto rv = getnameinfo(sa, salen, host.data(), host.size(), port.data(), port.size(), NI_NUMERICHOST | NI_NUMERICSERV); rv != 0) { + HTTP_DBG("getnameinfo: {}", gai_strerror(rv)); + return -1; + } + + HTTP_DBG("Server::send_retry: host={} port={}", host.data(), port.data()); + + ngtcp2_cid scid; + + scid.datalen = NGTCP2_SV_SCIDLEN; + if (generate_secure_random({ scid.data, scid.datalen }) != 0) { + return -1; + } + + std::array token; + + auto t = std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count(); + + auto tokenlen = ngtcp2_crypto_generate_retry_token2(token.data(), _sharedData->_static_secret.data(), _sharedData->_static_secret.size(), chd->version, sa, salen, &scid, &chd->dcid, static_cast(t)); + if (tokenlen < 0) { + return -1; + } + + HTTP_DBG("Generated address validation token"); + + Buffer buf{ + std::min(static_cast(NGTCP2_MAX_UDP_PAYLOAD_SIZE), max_pktlen) + }; + + auto nwrite = ngtcp2_crypto_write_retry(buf.wpos(), buf.left(), chd->version, &chd->scid, &scid, &chd->dcid, token.data(), static_cast(tokenlen)); + if (nwrite < 0) { + HTTP_DBG("ngtcp2_crypto_write_retry: {}", ngtcp2_strerror(static_cast(nwrite))); + return -1; + } + + buf.push(static_cast(nwrite)); + + ngtcp2_addr laddr{ + .addr = const_cast(&local_addr.su.sa), + .addrlen = local_addr.len, + }; + ngtcp2_addr raddr{ + .addr = const_cast(sa), + .addrlen = salen, + }; + + if (send_packet(ep, laddr, raddr, /* ecn = */ 0, buf.data()) != NETWORK_ERR_OK) { + return -1; + } + + return 0; + } + + std::optional read_pkt(Endpoint &ep, const Address &local_addr, const sockaddr *sa, socklen_t salen, const ngtcp2_pkt_info *pi, std::span data) { + ngtcp2_version_cid vc; + + switch (auto rv = ngtcp2_pkt_decode_version_cid(&vc, data.data(), data.size(), NGTCP2_SV_SCIDLEN); rv) { + case 0: + break; + case NGTCP2_ERR_VERSION_NEGOTIATION: + send_version_negotiation(vc.version, { vc.scid, vc.scidlen }, { vc.dcid, vc.dcidlen }, ep, local_addr, sa, salen); + return std::nullopt; + default: + HTTP_DBG("Could not decode version and CID from QUIC packet header: {}", ngtcp2_strerror(rv)); + return std::nullopt; + } + + auto dcid_key = make_cid_key({ vc.dcid, vc.dcidlen }); + + auto handler_it = _h3Sessions.find(dcid_key); + if (handler_it == std::end(_h3Sessions)) { + ngtcp2_pkt_hd hd; + + if (auto rv = ngtcp2_accept(&hd, data.data(), data.size()); rv != 0) { + HTTP_DBG("Unexpected packet received: length={}", data.size()); + + if (!(data[0] & 0x80) && data.size() >= NGTCP2_SV_SCIDLEN + 21) { + send_stateless_reset(data.size(), { vc.dcid, vc.dcidlen }, ep, local_addr, sa, salen); + } + + return std::nullopt; + } + + ngtcp2_cid ocid; + ngtcp2_cid *pocid = nullptr; + ngtcp2_token_type token_type = NGTCP2_TOKEN_TYPE_UNKNOWN; + + assert(hd.type == NGTCP2_PKT_INITIAL); + + if (hd.tokenlen) { + HTTP_DBG("Perform stateless address validation"); + if (hd.tokenlen == 0) { + send_retry(&hd, ep, local_addr, sa, salen, data.size() * 3); + return std::nullopt; + } + + if (hd.token[0] != NGTCP2_CRYPTO_TOKEN_MAGIC_RETRY2 && hd.dcid.datalen < NGTCP2_MIN_INITIAL_DCIDLEN) { + send_stateless_connection_close(&hd, ep, local_addr, sa, salen); + return std::nullopt; + } + + switch (hd.token[0]) { + case NGTCP2_CRYPTO_TOKEN_MAGIC_RETRY2: + switch (verify_retry_token(&ocid, &hd, sa, salen)) { + case 0: + pocid = &ocid; + token_type = NGTCP2_TOKEN_TYPE_RETRY; + break; + case -1: + send_stateless_connection_close(&hd, ep, local_addr, sa, salen); + return std::nullopt; + case 1: + hd.token = nullptr; + hd.tokenlen = 0; + break; + } + + break; + case NGTCP2_CRYPTO_TOKEN_MAGIC_REGULAR: + if (verify_token(&hd, sa, salen) != 0) { + hd.token = nullptr; + hd.tokenlen = 0; + } else { + token_type = NGTCP2_TOKEN_TYPE_NEW_TOKEN; + } + break; + default: + HTTP_DBG("Ignore unrecognized token"); + hd.token = nullptr; + hd.tokenlen = 0; + break; + } + } + + auto h = std::make_shared>(this, _sharedData); + if (h->init(ep, local_addr, sa, salen, &hd.scid, &hd.dcid, pocid, { hd.token, hd.tokenlen }, token_type, hd.version, _tls_ctx) != 0) { + return std::nullopt; + } + + switch (h->on_read(ep, local_addr, sa, salen, pi, data)) { + case 0: + break; + case NETWORK_ERR_RETRY: + send_retry(&hd, ep, local_addr, sa, salen, data.size() * 3); + return std::nullopt; + default: + return std::nullopt; + } + + if (h->on_write() != 0) { + return std::nullopt; + } + + std::array scids; + + auto num_scid = ngtcp2_conn_get_scid(h->_conn, nullptr); + + assert(num_scid <= scids.size()); + + ngtcp2_conn_get_scid(h->_conn, scids.data()); + + for (size_t i = 0; i < num_scid; ++i) { + associate_cid(scids[i], h); + } + + _h3Sessions.emplace(dcid_key, std::move(h)); + + return dcid_key; + } + + auto h = (*handler_it).second; + auto conn = h->_conn; + if (ngtcp2_conn_in_closing_period(conn)) { + if (h->send_conn_close(ep, local_addr, sa, salen, pi, data) != 0) { + remove(conn); + } + return std::nullopt; + } + if (ngtcp2_conn_in_draining_period(conn)) { + return std::nullopt; + } + + if (auto rv = h->on_read(ep, local_addr, sa, salen, pi, data); rv != 0) { + if (rv != NETWORK_ERR_CLOSE_WAIT) { + remove(conn); + } + return std::nullopt; + } + + return dcid_key; + } + + void associate_cid(const ngtcp2_cid &cid, const std::shared_ptr> &session) { + _h3Sessions.emplace(cid, session); + } + + void dissociate_cid(const ngtcp2_cid &cid) { + _sessionsToRemove.push_back(cid); + } + + void remove(ngtcp2_conn *conn) { + dissociate_cid(*ngtcp2_conn_get_client_initial_dcid(conn)); + + std::vector cids(ngtcp2_conn_get_scid(conn, nullptr)); + ngtcp2_conn_get_scid(conn, cids.data()); + + for (const auto &cid : cids) { + dissociate_cid(cid); + } + } +}; + +} // namespace opencmw::majordomo::detail::rest + +#endif // OPENCMW_MAJORDOMO_RESTSERVER_HPP diff --git a/src/majordomo/include/majordomo/TlsServerSession_Ossl.hpp b/src/majordomo/include/majordomo/TlsServerSession_Ossl.hpp new file mode 100644 index 00000000..c75c3629 --- /dev/null +++ b/src/majordomo/include/majordomo/TlsServerSession_Ossl.hpp @@ -0,0 +1,286 @@ +/* + * ngtcp2 + * + * Copyright (c) 2025 ngtcp2 contributors + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files (the + * "Software"), to deal in the Software without restriction, including + * without limitation the rights to use, copy, modify, merge, publish, + * distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to + * the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE + * LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION + * OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION + * WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ +#ifndef TLS_SERVER_SESSION_OSSL_H +#define TLS_SERVER_SESSION_OSSL_H + +#include "NgTcp2Util.hpp" +#include "TlsSessionBase_Ossl.hpp" + +#include +#include + +#include + +#include + +namespace opencmw::majordomo::detail::rest { + +inline int alpn_select_proto_h3_cb(SSL *ssl, const unsigned char **out, unsigned char *outlen, const unsigned char *in, unsigned int inlen, void * /*arg*/) { + auto conn_ref = static_cast(SSL_get_app_data(ssl)); + auto conn = conn_ref->get_conn(conn_ref); + const uint8_t *alpn; + size_t alpnlen; + // This should be the negotiated version, but we have not set the negotiated version when this callback is called. + auto version = ngtcp2_conn_get_client_chosen_version(conn); + + switch (version) { + case NGTCP2_PROTO_VER_V1: + case NGTCP2_PROTO_VER_V2: + alpn = H3_ALPN_V1; + alpnlen = str_size(H3_ALPN_V1); + break; + default: + std::println(std::cerr, "Unexpected quic protocol version: 0x{:x}", version); + return SSL_TLSEXT_ERR_ALERT_FATAL; + } + + for (auto p = in, end = in + inlen; p + alpnlen <= end; p += *p + 1) { + if (std::equal(alpn, alpn + alpnlen, p)) { + *out = p + 1; + *outlen = *p; + return SSL_TLSEXT_ERR_OK; + } + } + + return SSL_TLSEXT_ERR_ALERT_FATAL; +} + +inline int alpn_select_proto_hq_cb(SSL *ssl, const unsigned char **out, + unsigned char *outlen, const unsigned char *in, + unsigned int inlen, void * /*arg*/) { + auto conn_ref = static_cast(SSL_get_app_data(ssl)); + auto conn = conn_ref->get_conn(conn_ref); + + // This should be the negotiated version, but we have not set the negotiated version when this callback is called. + auto version = ngtcp2_conn_get_client_chosen_version(conn); + + const uint8_t *alpn; + size_t alpnlen; + + switch (version) { + case NGTCP2_PROTO_VER_V1: + case NGTCP2_PROTO_VER_V2: + alpn = HQ_ALPN_V1; + alpnlen = str_size(HQ_ALPN_V1); + break; + default: + + std::println(std::cerr, "Unexpected quic protocol version: 0x{:x}", version); + return SSL_TLSEXT_ERR_ALERT_FATAL; + } + + for (auto p = in, end = in + inlen; p + alpnlen <= end; p += *p + 1) { + if (std::equal(alpn, alpn + alpnlen, p)) { + *out = p + 1; + *outlen = *p; + return SSL_TLSEXT_ERR_OK; + } + } + + HTTP_DBG("Client did not present ALPN"); + + return SSL_TLSEXT_ERR_ALERT_FATAL; +} + +inline int verify_cb(int /*preverify_ok*/, X509_STORE_CTX *) { + // We don't verify the client certificate. Just request it for the testing purpose. + return 1; +} + +inline int gen_ticket_cb(SSL *ssl, void * /*arg*/) { + auto conn = static_cast(SSL_get_app_data(ssl)); + auto ver = htonl(ngtcp2_conn_get_negotiated_version(conn)); + if (!SSL_SESSION_set1_ticket_appdata(SSL_get0_session(ssl), &ver, + sizeof(ver))) { + return 0; + } + + return 1; +} + +inline SSL_TICKET_RETURN decrypt_ticket_cb(SSL *ssl, SSL_SESSION *session, const unsigned char * /*keyname*/, size_t /*keynamelen*/, SSL_TICKET_STATUS status, void * /*arg*/) { + switch (status) { + case SSL_TICKET_EMPTY: + case SSL_TICKET_NO_DECRYPT: + return SSL_TICKET_RETURN_IGNORE_RENEW; + } + + uint8_t *pver; + uint32_t ver; + size_t verlen; + + if (!SSL_SESSION_get0_ticket_appdata( + session, reinterpret_cast(&pver), &verlen) + || verlen != sizeof(ver)) { + switch (status) { + case SSL_TICKET_SUCCESS: + return SSL_TICKET_RETURN_IGNORE; + case SSL_TICKET_SUCCESS_RENEW: + default: + return SSL_TICKET_RETURN_IGNORE_RENEW; + } + } + + memcpy(&ver, pver, sizeof(ver)); + auto conn_ref = static_cast(SSL_get_app_data(ssl)); + auto conn = conn_ref->get_conn(conn_ref); + + if (ngtcp2_conn_get_client_chosen_version(conn) != ntohl(ver)) { + switch (status) { + case SSL_TICKET_SUCCESS: + return SSL_TICKET_RETURN_IGNORE; + case SSL_TICKET_SUCCESS_RENEW: + default: + return SSL_TICKET_RETURN_IGNORE_RENEW; + } + } + + switch (status) { + case SSL_TICKET_SUCCESS: + return SSL_TICKET_RETURN_USE; + case SSL_TICKET_SUCCESS_RENEW: + default: + return SSL_TICKET_RETURN_USE_RENEW; + } +} + +class TLSServerContext { +public: + TLSServerContext() + : ssl_ctx_{ nullptr } {} + + ~TLSServerContext() { + if (ssl_ctx_) { + SSL_CTX_free(ssl_ctx_); + } + } + + TLSServerContext(const TLSServerContext &) = delete; + TLSServerContext &operator=(const TLSServerContext &) = delete; + TLSServerContext(TLSServerContext &&other) noexcept + : ssl_ctx_{ std::exchange(other.ssl_ctx_, nullptr) } {} + TLSServerContext &operator=(TLSServerContext &&other) noexcept { + if (this != &other) { + if (ssl_ctx_) { + SSL_CTX_free(ssl_ctx_); + } + ssl_ctx_ = std::exchange(other.ssl_ctx_, nullptr); + } + return *this; + } + + int init(const opencmw::rest::detail::EVP_PKEY_Ptr &key, const opencmw::rest::detail::X509_Ptr &cert, AppProtocol app_proto) { + constexpr static unsigned char sid_ctx[] = "ngtcp2 server"; + + ssl_ctx_ = SSL_CTX_new(TLS_server_method()); + if (!ssl_ctx_) { + std::cerr << "SSL_CTX_new: " << ERR_error_string(ERR_get_error(), nullptr) + << std::endl; + return -1; + } + + SSL_CTX_set_max_early_data(ssl_ctx_, UINT32_MAX); + + constexpr auto ssl_opts = (SSL_OP_ALL & ~SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS) | SSL_OP_SINGLE_ECDH_USE | SSL_OP_CIPHER_SERVER_PREFERENCE | SSL_OP_NO_ANTI_REPLAY; + + SSL_CTX_set_options(ssl_ctx_, ssl_opts); + + SSL_CTX_set_mode(ssl_ctx_, SSL_MODE_RELEASE_BUFFERS); + + switch (app_proto) { + case AppProtocol::H3: + SSL_CTX_set_alpn_select_cb(ssl_ctx_, alpn_select_proto_h3_cb, nullptr); + break; + case AppProtocol::HQ: + SSL_CTX_set_alpn_select_cb(ssl_ctx_, alpn_select_proto_hq_cb, nullptr); + break; + } + + SSL_CTX_set_default_verify_paths(ssl_ctx_); + + if (SSL_CTX_use_PrivateKey(ssl_ctx_, key.get()) + != 1) { + std::cerr << "SSL_CTX_use_PrivateKey_file: " << ERR_error_string(ERR_get_error(), nullptr) << std::endl; + return -1; + } + + if (SSL_CTX_use_certificate(ssl_ctx_, cert.get()) != 1) { + std::cerr << "SSL_CTX_use_certificate_chain_file: " << ERR_error_string(ERR_get_error(), nullptr) << std::endl; + return -1; + } + + if (SSL_CTX_check_private_key(ssl_ctx_) != 1) { + std::cerr << "SSL_CTX_check_private_key: " << ERR_error_string(ERR_get_error(), nullptr) << std::endl; + return -1; + } + + SSL_CTX_set_session_id_context(ssl_ctx_, sid_ctx, sizeof(sid_ctx) - 1); + SSL_CTX_set_session_ticket_cb(ssl_ctx_, gen_ticket_cb, decrypt_ticket_cb, nullptr); + + return 0; + } + + SSL_CTX *get_native_handle() const { + return ssl_ctx_; + } + + void enable_keylog() { + } + +private: + SSL_CTX *ssl_ctx_; +}; + +class TLSServerSession : public TLSSessionBase { +public: + int init(const TLSServerContext &tls_ctx, ngtcp2_crypto_conn_ref *conn_ref) { + auto ssl_ctx = tls_ctx.get_native_handle(); + + auto ssl = SSL_new(ssl_ctx); + if (!ssl) { + std::cerr << "SSL_new: " << ERR_error_string(ERR_get_error(), nullptr) << std::endl; + return -1; + } + + ngtcp2_crypto_ossl_ctx_set_ssl(ossl_ctx_, ssl); + + if (ngtcp2_crypto_ossl_configure_server_session(ssl) != 0) { + std::cerr << "ngtcp2_crypto_ossl_configure_server_session failed" << std::endl; + return -1; + } + SSL_set_app_data(ssl, conn_ref); + SSL_set_accept_state(ssl); + SSL_set_quic_tls_early_data_enabled(ssl, 1); + + return 0; + } + + // ticket is sent automatically. + int send_session_ticket() { return 0; } +}; + +} // namespace opencmw::majordomo::detail::rest + +#endif // TLS_SERVER_SESSION_OSSL_H diff --git a/src/majordomo/include/majordomo/TlsSessionBase_Ossl.hpp b/src/majordomo/include/majordomo/TlsSessionBase_Ossl.hpp new file mode 100644 index 00000000..7f96963b --- /dev/null +++ b/src/majordomo/include/majordomo/TlsSessionBase_Ossl.hpp @@ -0,0 +1,91 @@ +/* + * ngtcp2 + * + * Copyright (c) 2025 ngtcp2 contributors + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files (the + * "Software"), to deal in the Software without restriction, including + * without limitation the rights to use, copy, modify, merge, publish, + * distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to + * the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE + * LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION + * OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION + * WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ +#ifndef TLS_SESSION_BASE_OSSL_H +#define TLS_SESSION_BASE_OSSL_H + +#include +#include + +#include + +#include + +namespace opencmw::majordomo::detail::rest { + +class TLSSessionBase { +public: + TLSSessionBase() { + ngtcp2_crypto_ossl_ctx_new(&ossl_ctx_, nullptr); + } + ~TLSSessionBase() { + auto ssl = ngtcp2_crypto_ossl_ctx_get_ssl(ossl_ctx_); + + if (ssl) { + SSL_set_app_data(ssl, NULL); + SSL_free(ssl); + } + + ngtcp2_crypto_ossl_ctx_del(ossl_ctx_); + } + + ngtcp2_crypto_ossl_ctx *get_native_handle() const { + return ossl_ctx_; + } + + std::string get_cipher_name() const { + return SSL_get_cipher_name(ngtcp2_crypto_ossl_ctx_get_ssl(ossl_ctx_)); + } + + std::string_view get_negotiated_group() const { + auto ssl = ngtcp2_crypto_ossl_ctx_get_ssl(ossl_ctx_); + auto name = SSL_get0_group_name(ssl); + + if (!name) { + return std::string_view{ "" }; + } + + return name; + } + + std::string get_selected_alpn() const { + auto ssl = ngtcp2_crypto_ossl_ctx_get_ssl(ossl_ctx_); + const unsigned char *alpn = nullptr; + unsigned int alpnlen; + + SSL_get0_alpn_selected(ssl, &alpn, &alpnlen); + + return std::string{ alpn, alpn + alpnlen }; + } + + // Keylog is enabled per SSL_CTX. + void enable_keylog() {} + +protected: + ngtcp2_crypto_ossl_ctx *ossl_ctx_; +}; + +} // namespace opencmw::majordomo::detail::rest + +#endif // TLS_SESSION_BASE_OSSL_H diff --git a/src/majordomo/test/CMakeLists.txt b/src/majordomo/test/CMakeLists.txt index a031859a..d8cfd566 100644 --- a/src/majordomo/test/CMakeLists.txt +++ b/src/majordomo/test/CMakeLists.txt @@ -29,7 +29,6 @@ function(opencmw_add_test_catch2 name sources) catch_discover_tests(${name}) endfunction() - opencmw_add_test_app(majordomo_testapp testapp.cpp) opencmw_add_test_app(majordomo_benchmark majordomo_benchmark.cpp) @@ -40,3 +39,7 @@ opencmw_add_test_catch2(SubscriptionMatch_tests subscriptionmatcher_tests.cpp) opencmw_add_test_catch2(majordomo_tests majordomo_tests.cpp;subscriptionmatcher_tests.cpp) opencmw_add_test_catch2(majordomo_worker_tests majordomoworker_tests.cpp;subscriptionmatcher_tests.cpp) opencmw_add_test_catch2(majordomo_worker_rest_tests majordomoworker_rest_tests.cpp;subscriptionmatcher_tests.cpp) + +if(NOT ENSCRIPTEN) + opencmw_add_test_catch2(majordomo_worker_load_tests majordomo_load_tests.cpp) +endif() diff --git a/src/majordomo/test/cryptography_tests.cpp b/src/majordomo/test/cryptography_tests.cpp index 536d0745..d40c703f 100644 --- a/src/majordomo/test/cryptography_tests.cpp +++ b/src/majordomo/test/cryptography_tests.cpp @@ -1,4 +1,4 @@ -#include .hpp> +#include #include diff --git a/src/majordomo/test/majordomo_load_tests.cpp b/src/majordomo/test/majordomo_load_tests.cpp new file mode 100644 index 00000000..6bbcdfe8 --- /dev/null +++ b/src/majordomo/test/majordomo_load_tests.cpp @@ -0,0 +1,126 @@ +#include "ClientCommon.hpp" +#include "IoBuffer.hpp" +#include "IoSerialiser.hpp" +#include "IoSerialiserYaS.hpp" +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include + +using namespace opencmw; + +constexpr std::uint16_t kServerPort = 12355; + +namespace { +template +void waitFor(std::atomic &responseCount, T expected, std::chrono::milliseconds timeout = std::chrono::seconds(5)) { + const auto start = std::chrono::system_clock::now(); + while (responseCount.load() < expected && std::chrono::system_clock::now() - start < timeout) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + const auto received = responseCount.load(); + if (received != expected) { + FAIL(std::format("Expected {} responses, but got {}\n", expected, received)); + } +} +} // namespace + +TEST_CASE("Load test", "[majordomo][majordomoworker][load_test][http2]") { + majordomo::Broker broker("/TestBroker", testSettings()); + majordomo::rest::Settings rest; + rest.port = kServerPort; + rest.protocols = majordomo::rest::Protocol::Http2; + auto bound = broker.bindRest(rest); + if (!bound) { + FAIL(std::format("Failed to bind REST server: {}", bound.error())); + return; + } + + query::registerTypes(opencmw::load_test::Context(), broker); + + majordomo::load_test::Worker worker(broker); + + RunInThread brokerRun(broker); + RunInThread workerRun(worker); + REQUIRE(waitUntilWorkerServiceAvailable(broker.context, worker)); + + constexpr auto kNClients = 10UZ; + constexpr auto kSubscriptions = 10UZ; + constexpr bool kSeparateSubscriptions = true; + constexpr auto kNUpdates = 100UZ; + constexpr auto kIntervalMs = 40UZ; + constexpr auto kInitialDelayMs = 50UZ; + constexpr auto kPayloadSize = 4096UZ; + + std::atomic responseCount = 0; + + std::array, kNClients> clients; + for (std::size_t i = 0; i < clients.size(); i++) { + clients[i] = std::make_unique(client::DefaultContentTypeHeader(MIME::BINARY)); + } + + const auto start = std::chrono::system_clock::now(); + std::array latencies; + + std::atomic responseCountOfi0j0 = 0; + std::array latenciesOfi0j0; + + for (std::size_t i = 0; i < clients.size(); i++) { + for (std::size_t j = 0; j < kSubscriptions; j++) { + client::Command cmd; + cmd.command = mdp::Command::Subscribe; + cmd.serviceName = "/loadTest"; + const auto topic = std::format("{}:{}", kSeparateSubscriptions ? i : 0, j); + cmd.topic = URI<>(std::format("http://localhost:{}/loadTest?topic={}&intervalMs={}&payloadSize={}&nUpdates={}&initialDelayMs={}", kServerPort, topic, kIntervalMs, kPayloadSize, kNUpdates, kInitialDelayMs)); + cmd.callback = [&responseCount, &responseCountOfi0j0, &latencies, &latenciesOfi0j0, kPayloadSize, i, j](const auto &msg) { + REQUIRE(msg.command == mdp::Command::Notify); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.size() > 0); + const auto index = responseCount.fetch_add(1); + + load_test::Payload payload; + try { + IoBuffer buffer{ msg.data }; + opencmw::deserialise(buffer, payload); + REQUIRE(payload.data.size() == kPayloadSize); + } catch (const opencmw::ProtocolException &e) { + FAIL(std::format("Failed to deserialise payload: {}", e.what())); + return; + } + const auto now = opencmw::load_test::timestamp().count(); + const auto latency = now - payload.timestampNs; + assert(latency >= 0); + if (latency < 0) { + std::print("Negative latency: {} ({} - {})\n", latency, now, payload.timestampNs); + } + latencies[index] = static_cast(latency); + if (i == 0 && j == 0) { + const auto idx = responseCountOfi0j0.fetch_add(1); + latenciesOfi0j0[idx] = static_cast(latency / 1000); + } + }; + clients[i]->request(std::move(cmd)); + } + } + + waitFor(responseCount, kNUpdates * kSubscriptions * kNClients, 30s); + + std::println("Received {} responses in {}ms (Net production time: {}ms)", responseCount.load(), std::chrono::duration_cast(std::chrono::system_clock::now() - start).count(), kInitialDelayMs + (kNUpdates * kIntervalMs)); + // TODO maybe print more detailed distribution + std::uint64_t sumLatency = 0; + for (const auto &latency : latencies) { + sumLatency += latency; + } + + const auto averageLatency = static_cast(sumLatency) / static_cast(responseCount.load()); + std::println("Average latency: {}µs", averageLatency / 1000.0); + // TODO compute drift over time +} diff --git a/src/majordomo/test/majordomoworker_rest_tests.cpp b/src/majordomo/test/majordomoworker_rest_tests.cpp index 00d25a04..6b845aff 100644 --- a/src/majordomo/test/majordomoworker_rest_tests.cpp +++ b/src/majordomo/test/majordomoworker_rest_tests.cpp @@ -1,5 +1,5 @@ +#include "majordomo/Rest.hpp" #include -#include #include #include @@ -14,55 +14,13 @@ #include #include -#include #include #include -#include // Concepts and tests use common types #include -std::jthread makeGetRequestResponseCheckerThread(const std::string &address, const std::vector &requiredResponses, const std::vector &requiredStatusCodes = {}, [[maybe_unused]] std::source_location location = std::source_location::current()) { - return std::jthread([=] { - httplib::Client http("localhost", majordomo::DEFAULT_REST_PORT); - http.set_follow_location(true); - http.set_keep_alive(true); -#define requireWithSource(arg) \ - if (!(arg)) opencmw::zmq::debug::withLocation(location) << "<- call got a failed requirement:"; \ - REQUIRE(arg) - for (std::size_t i = 0; i < requiredResponses.size(); ++i) { - const auto response = http.Get(address); - requireWithSource(response); - const auto requiredStatusCode = i < requiredStatusCodes.size() ? requiredStatusCodes[i] : 200; - requireWithSource(response->status == requiredStatusCode); - requireWithSource(response->body.find(requiredResponses[i]) != std::string::npos); - } -#undef requireWithSource - }); -} - -std::jthread makeLongPollingRequestResponseCheckerThread(const std::string &address, const std::vector &requiredResponses, const std::vector &requiredStatusCodes = {}, [[maybe_unused]] std::source_location location = std::source_location::current()) { - return std::jthread([=] { - httplib::Client http("localhost", majordomo::DEFAULT_REST_PORT); - http.set_follow_location(true); - http.set_keep_alive(true); -#define requireWithSource(arg) \ - if (!(arg)) opencmw::zmq::debug::withLocation(location) << "<- call got a failed requirement:"; \ - REQUIRE(arg) - for (std::size_t i = 0; i < requiredResponses.size(); ++i) { - const std::string url = std::format("{}{}LongPollingIdx={}", address, address.contains('?') ? "&" : "?", i == 0 ? "Next" : std::format("{}", i)); - const auto response = http.Get(url); - if (i == 0) { // check forwarding to the explicit index - REQUIRE(response->location.find("LongPollingIdx=0") != std::string::npos); - } - requireWithSource(response); - const auto requiredStatusCode = i < requiredStatusCodes.size() ? requiredStatusCodes[i] : 200; - requireWithSource(response->status == requiredStatusCode); - requireWithSource(response->body.find(requiredResponses[i]) != std::string::npos); - } -#undef requireWithSource - }); -} +constexpr std::uint16_t kServerPort = 12348; struct ColorContext { bool red = false; @@ -88,6 +46,11 @@ class ColorWorker : public majordomo::Worker explicit ColorWorker(const BrokerType &broker, std::vector notificationContexts) : super_t(broker, {}) { + super_t::setCallback([](majordomo::RequestContext & /*rawCtx*/, const ColorContext &inCtx, const majordomo::Empty &, ColorContext &outCtx, SingleString &out) { + outCtx = inCtx; + out.value = std::format("red={}, green={}, blue={}\n", inCtx.red, inCtx.green, inCtx.blue); + FAIL(std::format("Unexpected GET/SET request: {}", out.value)); + }); notifyThread = std::jthread([this, contexts = std::move(notificationContexts)]() { int counter = 0; for (const auto &context : contexts) { @@ -134,10 +97,10 @@ struct WaitingContext { ENABLE_REFLECTION_FOR(WaitingContext, timeoutMs, contentType) struct UpdateTime { - long updateTime; + long updateTimeµs; std::vector payload; }; -ENABLE_REFLECTION_FOR(UpdateTime, updateTime, payload) +ENABLE_REFLECTION_FOR(UpdateTime, updateTimeµs, payload) template class WaitingWorker : public majordomo::Worker { @@ -169,11 +132,11 @@ class ClockWorker : public majordomo::Worker 0 && !_shutdownRequested) { std::this_thread::sleep_until(updateTime); std::print("publishing update\n"); - UpdateTime update{ std::chrono::duration_cast(updateTime.time_since_epoch()).count(), std::views::iota(0, payloadSize) | std::ranges::to() }; + UpdateTime update{ std::chrono::duration_cast(updateTime.time_since_epoch()).count(), std::views::iota(0, payloadSize) | std::ranges::to() }; this->notify(SimpleContext(), update); updateTime += _period; _nUpdates--; @@ -187,12 +150,32 @@ class ClockWorker : public majordomo::Worker +void waitFor(std::atomic &responseCount, T expected, std::chrono::milliseconds timeout = std::chrono::seconds(5)) { + const auto start = std::chrono::system_clock::now(); + while (responseCount.load() < expected && std::chrono::system_clock::now() - start < timeout) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + const auto result = responseCount.load() == expected; + if (!result) { + FAIL(std::format("Expected {} responses, but got {}\n", expected, responseCount.load())); + } +} +} // namespace + +TEST_CASE("Simple REST example", "[majordomo][majordomoworker][simple_example][http2]") { // We run both broker and worker inproc - majordomo::Broker broker("/TestBroker", testSettings()); - auto fs = cmrc::assets::get_filesystem(); - FileServerRestBackend rest(broker, fs); - RunInThread restServerRun(rest); + majordomo::Broker broker("/TestBroker", testSettings()); + majordomo::rest::Settings rest; + rest.port = kServerPort; + rest.protocols = majordomo::rest::Protocol::Http2; + rest.handlers = { majordomo::rest::cmrcHandler("/assets/*", "", std::make_shared(cmrc::assets::get_filesystem()), "") }; + + if (auto bound = broker.bindRest(rest); !bound) { + FAIL(std::format("Failed to bind REST server: {}", bound.error())); + return; + } // For subscription matching, it is necessary that broker knows how to handle the query params "ctx" and "contentType". // ("ctx" needs to use the TimingCtxFilter, and "contentType" compare the mime types (currently simply a string comparison)) @@ -205,59 +188,144 @@ TEST_CASE("Simple MajordomoWorker example showing its usage", "[majordomo][major // Create MajordomoWorker with our domain objects, and our TestHandler. majordomo::Worker<"/addressbook", SimpleContext, AddressRequest, AddressEntry> worker(broker, TestAddressHandler()); - // Run worker and broker in separate threads - RunInThread brokerRun(broker); - RunInThread workerRun(worker); - + RunInThread brokerRun(broker); + RunInThread workerRun(worker); REQUIRE(waitUntilWorkerServiceAvailable(broker.context, worker)); - SECTION("request Address information as JSON and as HTML") { - auto httpThreadJSON = makeGetRequestResponseCheckerThread("/addressbook?ctx=FAIR.SELECTOR.ALL&contentType=application%2Fjavascript", { "Santa Claus" }); - - auto httpThreadHTML = makeGetRequestResponseCheckerThread("/addressbook?ctx=FAIR.SELECTOR.ALL&contentType=text%2Fhtml", { "Elf Road" }); - } - - SECTION("post data") { - httplib::Client postData{ "http://localhost:8080" }; - postData.Post("/addressbook?ctx=FAIR.SELECTOR.ALL&contentType=application%json", "{\"streetNumber\": 1882}", "application/json"); - - auto httpThreadJSON = makeGetRequestResponseCheckerThread("/addressbook?ctx=FAIR.SELECTOR.ALL&contentType=application%json", { "1882" }); - } - - SECTION("post data as multipart") { - std::jthread putRequestThread{ - [] { - // set a value on the server - httplib::Client postRequest("http://localhost:8080"); - postRequest.set_keep_alive(true); - - httplib::MultipartFormDataItems items{ - { "name", "Kalle", "name_file", "text" }, - { "street", "calle", "street_file", "text" }, - // { "streetNumber", "8", "number_file", "number" }, // `error(22) parsing number at buffer position: 41"` , deserialiser finds "8" instead of 8 - { "postalCode", "14005", "postal_code_file", "text" }, - { "city", "ciudad", "city_file", "text" } - // "isCurrent", "true", "is_current_file", "text" }, // does not work because true will be quoted. which is not a valid boolean - //{ "isCurrent", "false", "is_current_file", "boolean" }, // content type boolean might not exist, anyway, content_type is not taken into account anyway - }; - - auto r = postRequest.Put("/addressbook?ctx=FAIR.SELECTOR.ALL&contentType=application%2Fjavascript", items); - - REQUIRE(r); - CAPTURE(r->reason); - CAPTURE(r->body); - REQUIRE(r->status == 200); - - auto httpThreadJSON = makeGetRequestResponseCheckerThread("/addressbook?ctx=FAIR.SELECTOR.ALL&contentType=application%2Fjavascript", { "Kalle" }); - } - }; - } + std::atomic responseCount = 0; + + opencmw::client::RestClient client; + + // Invalid port, assuming there's nobody listening on port 44444 + opencmw::client::Command cannotReach; + cannotReach.command = mdp::Command::Get; + cannotReach.topic = opencmw::URI<>("http://localhost:44444/"); + cannotReach.callback = [&responseCount](const auto &msg) { + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.error == "Connection refused"); + REQUIRE(msg.data.asString() == ""); + responseCount++; + }; + client.request(std::move(cannotReach)); + + auto serviceListCallback = [&responseCount](const auto &msg) { + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString() == "/addressbook,/mmi.dns,/mmi.echo,/mmi.openapi,/mmi.service"); + responseCount++; + }; + + // path "" returns service list + opencmw::client::Command serviceList; + serviceList.command = mdp::Command::Get; + serviceList.topic = opencmw::URI<>(std::format("http://localhost:{}", kServerPort)); + serviceList.callback = serviceListCallback; + client.request(std::move(serviceList)); + + // path "/" returns service list + opencmw::client::Command serviceList2; + serviceList2.command = mdp::Command::Get; + serviceList2.topic = opencmw::URI<>(std::format("http://localhost:{}/", kServerPort)); + serviceList2.callback = serviceListCallback; + client.request(std::move(serviceList2)); + + // Asset file + opencmw::client::Command getAsset; + getAsset.command = mdp::Command::Get; + getAsset.topic = opencmw::URI<>(std::format("http://localhost:{}/assets/main.css", kServerPort)); + getAsset.callback = [&responseCount](const auto &msg) { + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString().contains("body {")); + responseCount++; + }; + client.request(std::move(getAsset)); + + // Asset file that does not exist + opencmw::client::Command getAsset404; + getAsset404.command = mdp::Command::Get; + getAsset404.topic = opencmw::URI<>(std::format("http://localhost:{}/assets/does-not-exist", kServerPort)); + getAsset404.callback = [&responseCount](const auto &msg) { + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.error == "Not found"); + REQUIRE(msg.data.asString() == ""); + responseCount++; + }; + client.request(std::move(getAsset404)); + + waitFor(responseCount, 5); + responseCount = 0; + + constexpr int kExpectedResponses = 4; + + // The following requests are all to the same service and thus must be processed in order + + // Get worker data as JSON + opencmw::client::Command getJson; + getJson.command = mdp::Command::Get; + getJson.topic = opencmw::URI<>(std::format("http://localhost:{}/addressbook?ctx=FAIR.SELECTOR.ALL&contentType=application%2Fjson", kServerPort)); + getJson.callback = [&responseCount](const auto &msg) { + REQUIRE(responseCount == 0); + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString().contains("\"name\": \"Santa Claus\"")); + REQUIRE(msg.topic == opencmw::URI<>("/addressbook?contentType=application%2Fjson&testFilter=&ctx=FAIR.SELECTOR.ALL")); + responseCount++; + }; + client.request(std::move(getJson)); + + // Get worker data as HTML + opencmw::client::Command getHtml; + getHtml.command = mdp::Command::Get; + getHtml.topic = opencmw::URI<>(std::format("http://localhost:{}/addressbook?ctx=FAIR.SELECTOR.ALL&contentType=text%2Fhtml", kServerPort)); + getHtml.callback = [&responseCount](const auto &msg) { + REQUIRE(responseCount == 1); + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString().contains("Elf Road")); + REQUIRE(msg.topic == opencmw::URI<>("/addressbook?contentType=text%2Fhtml&testFilter=&ctx=FAIR.SELECTOR.ALL")); + responseCount++; + }; + client.request(std::move(getHtml)); + + // Set worker data + opencmw::client::Command postJson; + postJson.command = mdp::Command::Set; + postJson.topic = opencmw::URI<>(std::format("http://localhost:{}/addressbook?ctx=FAIR.SELECTOR.ALL&contentType=application%2Fjson", kServerPort)); + postJson.data = opencmw::IoBuffer("{\"streetNumber\": 1882}"); + postJson.callback = [&responseCount](const auto &msg) { + REQUIRE(responseCount == 2); + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.error == ""); + REQUIRE(msg.topic == opencmw::URI<>("/addressbook?contentType=application%2Fjson&testFilter=&ctx=FAIR.SELECTOR.ALL")); + responseCount++; + }; + client.request(std::move(postJson)); + + // Test that the Set call is correctly applied + opencmw::client::Command getJsonAfterSet; + getJsonAfterSet.command = mdp::Command::Get; + getJsonAfterSet.topic = opencmw::URI<>(std::format("http://localhost:{}/addressbook?ctx=FAIR.SELECTOR.ALL&contentType=application%2Fjson", kServerPort)); + getJsonAfterSet.callback = [&responseCount](const auto &msg) { + REQUIRE(responseCount == 3); + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString().contains("\"streetNumber\": 1882")); + REQUIRE(msg.topic == opencmw::URI<>("/addressbook?contentType=application%2Fjson&testFilter=&ctx=FAIR.SELECTOR.ALL")); + responseCount++; + }; + client.request(std::move(getJsonAfterSet)); + + waitFor(responseCount, kExpectedResponses); } + TEST_CASE("Invalid paths", "[majordomo][majordomoworker][rest]") { majordomo::Broker broker("/TestBroker", testSettings()); - auto fs = cmrc::assets::get_filesystem(); - FileServerRestBackend rest(broker, fs); - RunInThread restServerRun(rest); + majordomo::rest::Settings rest; + rest.port = kServerPort; + rest.protocols = majordomo::rest::Protocol::Http2; + auto bound = broker.bindRest(rest); + REQUIRE(bound); opencmw::query::registerTypes(PathContext(), broker); @@ -268,16 +336,48 @@ TEST_CASE("Invalid paths", "[majordomo][majordomoworker][rest]") { REQUIRE(waitUntilWorkerServiceAvailable(broker.context, worker)); - auto space = makeGetRequestResponseCheckerThread("/paths/with%20space", { "Invalid service name" }, { 500 }); - auto invalidSubscription = makeGetRequestResponseCheckerThread("/p-a-t-h-s/?LongPollIdx=Next", { "Invalid service name" }, { 500 }); + constexpr int kExpectedResponses = 1; + std::atomic responseCount = 0; + + opencmw::client::RestClient client; + +#if 0 // TODO Currently URI<>() throws on %20 in the path (which is a valid URI but not a valid service name) + opencmw::client::Command getSpace; + getSpace.command = mdp::Command::Get; + getSpace.topic = opencmw::URI<>(std::format("http://localhost:{}/paths/with%20space", kServerPort)); + getSpace.callback = [&responseCount](const auto &msg) { + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.error == "Invalid service name"); + REQUIRE(msg.data.empty()); + responseCount++; + }; + client.request(std::move(getSpace)); +#endif + opencmw::client::Command invalidSubscription; + invalidSubscription.command = mdp::Command::Get; + invalidSubscription.topic = opencmw::URI<>(std::format("http://localhost:{}//p-a-t-h-s/?LongPollIdx=Next", kServerPort)); + invalidSubscription.callback = [&responseCount](const auto &msg) { + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.error == "Invalid service name '//p-a-t-h-s'"); + REQUIRE(msg.data.empty()); + responseCount++; + }; + + client.request(std::move(invalidSubscription)); + + waitFor(responseCount, kExpectedResponses); } TEST_CASE("Get/Set with subpaths", "[majordomo][majordomoworker][rest]") { - majordomo::Broker broker("/TestBroker", testSettings()); - auto fs = cmrc::assets::get_filesystem(); - FileServerRestBackend rest(broker, fs); - RunInThread restServerRun(rest); - + majordomo::Broker broker("/TestBroker", testSettings()); + majordomo::rest::Settings rest; + rest.port = kServerPort; + rest.protocols = majordomo::rest::Protocol::Http2; + auto bound = broker.bindRest(rest); + if (!bound) { + FAIL(std::format("Failed to bind REST server: {}", bound.error())); + return; + } opencmw::query::registerTypes(PathContext(), broker); PathWorker<"/paths"> worker(broker); @@ -287,17 +387,60 @@ TEST_CASE("Get/Set with subpaths", "[majordomo][majordomoworker][rest]") { REQUIRE(waitUntilWorkerServiceAvailable(broker.context, worker)); - auto empty = makeGetRequestResponseCheckerThread("/paths", { "path=''" }); - auto one = makeGetRequestResponseCheckerThread("/paths/a", { "path='\\/a'" }); - auto two = makeGetRequestResponseCheckerThread("/paths/a/b", { "path='\\/a\\/b'" }); + constexpr int kExpectedResponses = 3; + std::atomic responseCount = 0; + + opencmw::client::RestClient client; + + opencmw::client::Command getEmpty; + getEmpty.command = mdp::Command::Get; + getEmpty.topic = opencmw::URI<>(std::format("http://localhost:{}/paths", kServerPort)); + getEmpty.callback = [&responseCount](const auto &msg) { + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString().contains("path=''")); + REQUIRE(msg.topic == opencmw::URI<>("/paths?contentType=application%2Fjson")); + responseCount++; + }; + client.request(std::move(getEmpty)); + + opencmw::client::Command getOne; + getOne.command = mdp::Command::Get; + getOne.topic = opencmw::URI<>(std::format("http://localhost:{}/paths/a", kServerPort)); + getOne.callback = [&responseCount](const auto &msg) { + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString().contains("path='\\/a'")); + REQUIRE(msg.topic == opencmw::URI<>("/paths/a?contentType=application%2Fjson")); + responseCount++; + }; + client.request(std::move(getOne)); + + opencmw::client::Command getTwo; + getTwo.command = mdp::Command::Get; + getTwo.topic = opencmw::URI<>(std::format("http://localhost:{}/paths/a/b", kServerPort)); + getTwo.callback = [&responseCount](const auto &msg) { + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString().contains("path='\\/a\\/b'")); + REQUIRE(msg.topic == opencmw::URI<>("/paths/a/b?contentType=application%2Fjson")); + responseCount++; + }; + client.request(std::move(getTwo)); + + waitFor(responseCount, kExpectedResponses); } TEST_CASE("Subscriptions", "[majordomo][majordomoworker][subscription]") { - majordomo::Broker broker("/TestBroker", testSettings()); - auto fs = cmrc::assets::get_filesystem(); - FileServerRestBackend rest(broker, fs); - RunInThread restServerRun(rest); - + majordomo::Broker broker("/TestBroker", testSettings()); + majordomo::rest::Settings rest; + rest.port = kServerPort; + rest.protocols = majordomo::rest::Protocol::Http2; + const auto bound = broker.bindRest(rest); + if (!bound) { + FAIL(std::format("Failed to bind REST server: {}", bound.error())); + return; + } opencmw::query::registerTypes(ColorContext(), broker); constexpr auto red = ColorContext{ .red = true }; @@ -315,14 +458,93 @@ TEST_CASE("Subscriptions", "[majordomo][majordomoworker][subscription]") { REQUIRE(waitUntilWorkerServiceAvailable(broker.context, worker)); - auto allListener = makeLongPollingRequestResponseCheckerThread("/colors", { "0", "1", "2", "3", "4", "5", "6" }); - auto redListener = makeLongPollingRequestResponseCheckerThread("/colors?red", { "0", "3", "4", "6" }); - auto yellowListener = makeLongPollingRequestResponseCheckerThread("/colors?red&green", { "4", "6" }); - auto whiteListener1 = makeLongPollingRequestResponseCheckerThread("/colors?red&green&blue", { "6" }); - auto whiteListener2 = makeLongPollingRequestResponseCheckerThread("/colors?green&red&blue", { "6" }); - auto whiteListener3 = makeLongPollingRequestResponseCheckerThread("/colors?blue&green&red", { "6" }); - - std::this_thread::sleep_for(50ms); // give time for subscriptions to happen + opencmw::client::RestClient client; + + constexpr auto allExpected = std::array{ "0", "1", "2", "3", "4", "5", "6" }; + constexpr auto redExpected = std::array{ "0", "3", "4", "6" }; + constexpr auto yellowExpected = std::array{ "4", "6" }; + constexpr auto whiteExpected = std::array{ "6" }; + + std::atomic allReceived = 0; + std::atomic redReceived = 0; + std::atomic yellowReceived = 0; + std::atomic whiteReceived1 = 0; + std::atomic whiteReceived2 = 0; + std::atomic whiteReceived3 = 0; + + opencmw::client::Command allSub; + allSub.command = mdp::Command::Subscribe; + allSub.topic = opencmw::URI<>(std::format("http://localhost:{}/colors", kServerPort)); + allSub.callback = [&allReceived, &allExpected](const auto &msg) { + REQUIRE(msg.command == mdp::Command::Notify); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString().contains(allExpected[allReceived])); + allReceived++; + }; + client.request(std::move(allSub)); + + opencmw::client::Command redSub; + redSub.command = mdp::Command::Subscribe; + redSub.topic = opencmw::URI<>(std::format("http://localhost:{}/colors?red", kServerPort)); + redSub.callback = [&redReceived, &redExpected](const auto &msg) { + REQUIRE(msg.command == mdp::Command::Notify); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString().contains(redExpected[redReceived])); + redReceived++; + }; + client.request(std::move(redSub)); + + opencmw::client::Command yellowSub; + yellowSub.command = mdp::Command::Subscribe; + yellowSub.topic = opencmw::URI<>(std::format("http://localhost:{}/colors?red&green", kServerPort)); + yellowSub.callback = [&yellowReceived, &yellowExpected](const auto &msg) { + REQUIRE(msg.command == mdp::Command::Notify); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString().contains(yellowExpected[yellowReceived])); + yellowReceived++; + }; + client.request(std::move(yellowSub)); + + opencmw::client::Command whiteSub1; + whiteSub1.command = mdp::Command::Subscribe; + whiteSub1.topic = opencmw::URI<>(std::format("http://localhost:{}/colors?red&green&blue", kServerPort)); + whiteSub1.callback = [&whiteReceived1, &whiteExpected](const auto &msg) { + REQUIRE(msg.command == mdp::Command::Notify); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString().contains(whiteExpected[whiteReceived1])); + whiteReceived1++; + }; + client.request(std::move(whiteSub1)); + + opencmw::client::Command whiteSub2; + whiteSub2.command = mdp::Command::Subscribe; + whiteSub2.topic = opencmw::URI<>(std::format("http://localhost:{}/colors?green&red&blue", kServerPort)); + whiteSub2.callback = [&whiteReceived2, &whiteExpected](const auto &msg) { + REQUIRE(msg.command == mdp::Command::Notify); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString().contains(whiteExpected[whiteReceived2])); + whiteReceived2++; + }; + client.request(std::move(whiteSub2)); + + opencmw::client::Command whiteSub3; + whiteSub3.command = mdp::Command::Subscribe; + whiteSub3.topic = opencmw::URI<>(std::format("http://localhost:{}/colors?blue&green&red", kServerPort)); + whiteSub3.callback = [&whiteReceived3, &whiteExpected](const auto &msg) { + REQUIRE(msg.command == mdp::Command::Notify); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString().contains(whiteExpected[whiteReceived3])); + whiteReceived3++; + }; + + client.request(std::move(whiteSub3)); + + waitFor(allReceived, allExpected.size()); + waitFor(redReceived, redExpected.size()); + waitFor(yellowReceived, yellowExpected.size()); + waitFor(whiteReceived1, whiteExpected.size()); + waitFor(whiteReceived2, whiteExpected.size()); + waitFor(whiteReceived3, whiteExpected.size()); std::vector subscriptions; for (const auto &subscription : worker.activeSubscriptions()) { @@ -332,40 +554,212 @@ TEST_CASE("Subscriptions", "[majordomo][majordomoworker][subscription]") { REQUIRE(subscriptions == std::vector{ "/colors#", "/colors?blue&green&red#", "/colors?green&red#", "/colors?red#" }); } +TEST_CASE("Handler matching", "[majordomo][rest]") { + majordomo::Broker broker("/TestBroker", testSettings()); + majordomo::rest::Settings rest; + rest.port = kServerPort; + rest.protocols = majordomo::rest::Protocol::Http2; + + auto echoHandler + = [](std::string method, std::string path) { + return majordomo::rest::Handler{ + .method = method, + .path = path, + .handler = [path, method](const auto &) { + majordomo::rest::Response response; + response.code = 200; + response.body.put(std::format("{}|{}", method, path)); + return response; + } + }; + }; + + // Cannot test POST because Command::Set is also using "GET" + rest.handlers = { + echoHandler("GET", "/"), + echoHandler("GET", "/assets/*"), + echoHandler("GET", "/assets/subresource") + }; + + auto bound = broker.bindRest(rest); + if (!bound) { + FAIL(std::format("Failed to bind REST server: {}", bound.error())); + return; + } + + RunInThread brokerRun(broker); + + opencmw::client::RestClient client; + + std::atomic responseCount = 0; + + opencmw::client::Command getRoot; + getRoot.command = mdp::Command::Get; + getRoot.topic = opencmw::URI<>(std::format("http://localhost:{}/", kServerPort)); + getRoot.callback = [&responseCount](const auto &msg) { + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString() == "GET|/"); + responseCount++; + }; + client.request(std::move(getRoot)); + + opencmw::client::Command getAssets; + getAssets.command = mdp::Command::Get; + getAssets.topic = opencmw::URI<>(std::format("http://localhost:{}/assets/test.txt", kServerPort)); + getAssets.callback = [&responseCount](const auto &msg) { + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString() == "GET|/assets/*"); + responseCount++; + }; + client.request(std::move(getAssets)); + + opencmw::client::Command getSubresource; + getSubresource.command = mdp::Command::Get; + getSubresource.topic = opencmw::URI<>(std::format("http://localhost:{}/assets/subresource", kServerPort)); + getSubresource.callback = [&responseCount](const auto &msg) { + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString() == "GET|/assets/subresource"); + responseCount++; + }; + client.request(std::move(getSubresource)); + + opencmw::client::Command getSubresource2; + getSubresource2.command = mdp::Command::Get; + getSubresource2.topic = opencmw::URI<>(std::format("http://localhost:{}/assets/subresource/extra", kServerPort)); + getSubresource2.callback = [&responseCount](const auto &msg) { + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString() == "GET|/assets/*"); + responseCount++; + }; + client.request(std::move(getSubresource2)); + waitFor(responseCount, 4); +} + +constexpr auto kExpectedFileContent = R"(-----BEGIN CERTIFICATE----- +MIIFiTCCA3GgAwIBAgIUDBxaxLvthSz4Knvh6R0/zDrxe3QwDQYJKoZIhvcNAQEL +BQAwVDELMAkGA1UEBhMCREUxEDAOBgNVBAgMB1Vua25vd24xEDAOBgNVBAcMB1Vu +a25vd24xITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0yMTEy +MDUxMDA0MjFaFw0zMTEyMDMxMDA0MjFaMFQxCzAJBgNVBAYTAkRFMRAwDgYDVQQI +DAdVbmtub3duMRAwDgYDVQQHDAdVbmtub3duMSEwHwYDVQQKDBhJbnRlcm5ldCBX +aWRnaXRzIFB0eSBMdGQwggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAwggIKAoICAQDx +1Es3W/5OyMdOPUmCqQuAYV/4xjhP8Fhoi1gjnnlaCrBbBXJl1nW8DdMwVhXF9Yy8 +wPP0SylqkbatiDnUwjviizL6v6DMNQbS+OES5OleCuwbCWAFH3vsDllRZl3LYAdB +6Ec4wNjX5EjE1RgIgT+GEkR13XqyHQi4ELOMEUxxpVcWeBjAFhgiTXvbpnBfBfJo +XPsMoCaWTWhRQksodKM4Mjfn/wxKAfbspNaX5zfPcr/5vNGY2CYiKbsZqwiM5VNq +ml9XoUc5BIWuj4liHerLOqdEj3zpBhn3i+RimGm+N2xDAXOKMlP4w5jyewd5FtLl +vmyLEakiqwSCODkjrP7rbmQ9hRohsF8V5Y4KwaYYEp+pZ4BFCBYdDv+drnVD8MOJ +/7f8LlCE3xN/MvEX2Um1xt5oT4gb3SdSTRZfUEFzkrQK5wodQBUZVhVTg0NRwiWx +hfeILR6/Qd8pciwSXbkT7JBEf1gyghKFDd+GfyDmm/+aIxNAItQ10KUDeLyavP3x +SN9ZR4zITDT9dwzd/yqsKudFB2mqioUv8zXah7lRpQDQmsrwkj4sKOrfS9BxtM8T +tUn56aOeyl7zdWvhI+wJOT07f7l7+c3WW06v9jPXk0IDjhvUdn11uL6aFQiO6cBE +4Jj2z2l4VHsLekcPdXzt6V1IhJuyGWSasOZBwb13tQIDAQABo1MwUTAdBgNVHQ4E +FgQUwA8sghflexohahmoZAFHcpgnJl0wHwYDVR0jBBgwFoAUwA8sghflexohahmo +ZAFHcpgnJl0wDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAgEAXlVh +oCF+7ywVU1ek2wcPIDayqAUA2JnPGvpjClSqqOfAvst07RspDftxTQ0/BVy7N1Ep +DX9uQP1YiuvZDxc4ySCYUnYiXmf7aeyjSGIwVHhWRx/uQQV880tJ+TIK+JJRLiJg +DLnixdyW/uPY0RkjUzHCADI1zZmrJyZpMAFbqwuoVOtEoMmCJkjJRrZzytRQkTbe +9cUAvHcdjbFADf1yZdTELgNTs9Xolb+aZhXb+DolrQiTpDQj9RSt/raaRyrlssFD +9V0ugW2e9nEEe7PlofDGOYqhaadka9s680xT47s6K8WjPVFwgDVKKojW186JCeAR +8W8AEd1tOmp8BQOY7OLY8hZ9kTnnd8XoLy+l8UH03kfIIPAulGARNCYYo7SJfQU7 +1rQEf28BGi9mL0QhkY0xSSTvuLVbG5DAceUVix6Y+NsTgF+YCWsqXMZX6M/qHjiy +7qf5LeKrLqG1RQsYbi4UzakKfwG2uVH/lETc/j1PMZz4WPPJGeeg+urhHvzvZJ5C +xbuJH3fc9TsTET4FiwQwINSwdY4i/iNCt3kE+6EYNf1G13f3ffCMy3JYxakXjDZB +ido85zpB/BELE69ap2g/pHCNEd9y2ZHCYDvIZaxnJtdBBhWmHqAK00HYEwG5LZPU +I0uPpRG4pT+BvLUZo8rKVpFLHj2nnflS5dXqKPo= +-----END CERTIFICATE----- +)"; + +template +void runFilesystemTest() { + majordomo::Broker broker("/TestBroker", testSettings()); + majordomo::rest::Settings rest; + rest.port = kServerPort; + rest.protocols = majordomo::rest::Protocol::Http2; + rest.handlers = { majordomo::rest::fileSystemHandler("/files/*", "/files/", std::filesystem::current_path(), {}, MMapThreshold) }; + auto bound = broker.bindRest(rest); + if (!bound) { + FAIL(std::format("Failed to bind REST server: {}", bound.error())); + return; + } + + RunInThread brokerRun(broker); + + std::atomic responseCount = 0; + opencmw::client::RestClient client; + + opencmw::client::Command fileExists; + fileExists.command = mdp::Command::Get; + fileExists.topic = opencmw::URI<>(std::format("http://localhost:{}/files/demo_public.crt", kServerPort)); + fileExists.callback = [&responseCount](const auto &msg) { + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.error == ""); + REQUIRE(msg.data.asString() == kExpectedFileContent); + responseCount++; + }; + client.request(std::move(fileExists)); + + opencmw::client::Command fileDoesNotExist; + fileDoesNotExist.command = mdp::Command::Get; + fileDoesNotExist.topic = opencmw::URI<>(std::format("http://localhost:{}/files/does_not_exist", kServerPort)); + fileDoesNotExist.callback = [&responseCount](const auto &msg) { + REQUIRE(msg.command == mdp::Command::Final); + REQUIRE(msg.error == "Not found"); + REQUIRE(msg.data.asString() == ""); + responseCount++; + }; + client.request(std::move(fileDoesNotExist)); + + waitFor(responseCount, 2); +}; + +TEST_CASE("File system handler (stream)", "[majordomo][majordomoworker][rest]") { + runFilesystemTest<1024 * 1024 * 100>(); +} + +TEST_CASE("File system handler (mmap)", "[majordomo][majordomoworker][rest]") { + runFilesystemTest<1>(); +} + TEST_CASE("Subscription latencies", "[majordomo][majordomoworker][rest]") { std::atomic nReceived = 0; std::atomic msLatency = 0; { majordomo::Broker broker("/TestBroker", testSettings()); - auto fs = cmrc::assets::get_filesystem(); - majordomo::RestBackend rest(broker, fs, opencmw::URI<>("http://localhost:12346")); - - ClockWorker<"/clock", 2550> worker(broker, 10ms, 70); - RunInThread restServerRun(rest); - RunInThread brokerRun(broker); - RunInThread workerRun(worker); + majordomo::rest::Settings rest; + rest.port = kServerPort; + rest.protocols = majordomo::rest::Protocol::Http2; + auto bound = broker.bindRest(rest); + if (!bound) { + FAIL(std::format("Failed to bind REST server: {}", bound.error())); + return; + } + ClockWorker<"/clock", 2550> worker(broker, 10ms, 70); + RunInThread brokerRun(broker); + RunInThread workerRun(worker); REQUIRE(waitUntilWorkerServiceAvailable(broker.context, worker)); - rest.setMajordomoTimeout(800ms); // set timeout to unit-test friendly interval - - opencmw::client::RestClient client{ std::string("RestSubLatencyClient") }; + opencmw::client::RestClient client; - opencmw::client::Command _command; - _command.command = opencmw::mdp::Command::Subscribe; - _command.topic = opencmw::URI<>("http://localhost:12346/clock"); - _command.callback = [&nReceived, &msLatency](const opencmw::mdp::Message &reply) { + opencmw::client::Command command; + command.command = opencmw::mdp::Command::Subscribe; + command.topic = opencmw::URI<>(std::format("http://localhost:{}/clock", kServerPort)); + command.callback = [&nReceived, &msLatency](const opencmw::mdp::Message &reply) { UpdateTime replyData; opencmw::IoBuffer buffer = reply.data; opencmw::deserialise(buffer, replyData); - auto now = std::chrono::system_clock::now(); - auto latency = now.time_since_epoch() - std::chrono::milliseconds(replyData.updateTime); + auto now = std::chrono::high_resolution_clock::now(); + auto latency = now.time_since_epoch() - std::chrono::microseconds(replyData.updateTimeµs); nReceived++; nReceived.notify_all(); - msLatency.fetch_add(static_cast(std::chrono::duration_cast(latency).count())); - std::print("Received {}th update with a latency of {} ms.\n", nReceived.load(), std::chrono::duration_cast(latency).count()); + msLatency.fetch_add(static_cast(std::chrono::duration_cast(latency).count())); + std::print("Received {}th update with a latency of {} µs.\n", nReceived.load(), std::chrono::duration_cast(latency).count()); }; - client.request(_command); + client.request(std::move(command)); std::print("waiting for 40 samples to be received\n"); int n = nReceived; @@ -375,55 +769,7 @@ TEST_CASE("Subscription latencies", "[majordomo][majordomoworker][rest]") { } } - std::print("Received {} updates with an average latency of {} ms.\n", nReceived.load(), nReceived > 0 ? static_cast(msLatency) / nReceived : 0.0); + std::print("Received {} updates with an average latency of {} µs.\n", nReceived.load(), nReceived > 0 ? static_cast(msLatency) / nReceived : 0.0); REQUIRE(nReceived > 10); - REQUIRE(static_cast(msLatency) / nReceived < 20); -} - -TEST_CASE("Majordomo timeouts", "[majordomo][majordomoworker][rest]") { - majordomo::Broker broker("/TestBroker", testSettings()); - auto fs = cmrc::assets::get_filesystem(); - FileServerRestBackend rest(broker, fs); - RunInThread restServerRun(rest); - - opencmw::query::registerTypes(WaitingContext(), broker); - - WaitingWorker<"/waiter"> worker(broker); - - RunInThread brokerRun(broker); - RunInThread workerRun(worker); - - REQUIRE(waitUntilWorkerServiceAvailable(broker.context, worker)); - - // set timeout to unit-test friendly interval - rest.setMajordomoTimeout(800ms); - - SECTION("Waiting for notification that doesn't happen in time returns 504 message") { - std::vector clientThreads; - for (int i = 0; i < 16; ++i) { - clientThreads.push_back(makeGetRequestResponseCheckerThread("/waiter?LongPollingIdx=Next", { "Timeout" }, { 504 })); - } - } - - SECTION("Waiting for notification that happens in time gives expected response") { - auto client = makeGetRequestResponseCheckerThread("/waiter?LongPollingIdx=Next", { "This is a notification" }); - std::this_thread::sleep_for(400ms); - worker.notify({}, { "This is a notification" }); - } - - SECTION("Response to request takes too long, timeout status is returned") { - httplib::Client postData{ "http://localhost:8080" }; - auto reply = postData.Post("/waiter?contentType=application%2Fjson&timeoutMs=1200", "{\"value\": \"Hello!\"}", "application/json"); - REQUIRE(reply); - REQUIRE(reply->status == 504); - REQUIRE(reply->body.find("No response") != std::string::npos); - } - - SECTION("Response to request arrives in time") { - httplib::Client postData{ "http://localhost:8080" }; - auto reply = postData.Post("/waiter?contentType=application%2Fjson&timeoutMs=0", "{\"value\": \"Hello!\"}", "application/json"); - REQUIRE(reply); - REQUIRE(reply->status == 200); - REQUIRE(reply->body.find("You said: Hello!") != std::string::npos); - } + REQUIRE(static_cast(msLatency) / nReceived < 20000); // unit is µs } diff --git a/src/rest/CMakeLists.txt b/src/rest/CMakeLists.txt new file mode 100644 index 00000000..755d0f5b --- /dev/null +++ b/src/rest/CMakeLists.txt @@ -0,0 +1,20 @@ +add_library(rest INTERFACE + include/rest/RestUtils.hpp +) + +target_include_directories(rest INTERFACE $ $) +target_link_libraries(rest + INTERFACE + nghttp2-static + nghttp3-static + ngtcp2-static + ngtcp2-crypto-ossl-static + openssl-ssl-static + openssl-crypto-static +) + +install( + TARGETS rest + EXPORT opencmwTargets + PUBLIC_HEADER DESTINATION include/opencmw +) diff --git a/src/rest/include/rest/RestUtils.hpp b/src/rest/include/rest/RestUtils.hpp new file mode 100644 index 00000000..9e4f1425 --- /dev/null +++ b/src/rest/include/rest/RestUtils.hpp @@ -0,0 +1,473 @@ +#ifndef OPENCMW_NGHTTP2HELPERS_H +#define OPENCMW_NGHTTP2HELPERS_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include + +#ifdef OPENCMW_DEBUG_HTTP +#include +#define HTTP_DBG(...) std::println(std::cerr, __VA_ARGS__); +#else +#define HTTP_DBG(...) +#endif + +namespace opencmw::rest::detail { +using SSL_CTX_Ptr = std::unique_ptr; +using SSL_Ptr = std::unique_ptr; +using X509_STORE_Ptr = std::unique_ptr; +using X509_Ptr = std::unique_ptr; +using EVP_PKEY_Ptr = std::unique_ptr; + +inline int readCertificateBundleFromBuffer(X509_STORE &cert_store, const std::string_view &X509_ca_bundle) { + BIO *cbio = BIO_new_mem_buf(X509_ca_bundle.data(), static_cast(X509_ca_bundle.size())); + if (!cbio) { + return -1; + } + STACK_OF(X509_INFO) *inf = PEM_X509_INFO_read_bio(cbio, nullptr, nullptr, nullptr); + + if (!inf) { + BIO_free(cbio); // cleanup + return -1; + } + // iterate over all entries from the pem file, add them to the x509_store one by one + int count = 0; + for (int i = 0; i < sk_X509_INFO_num(inf); i++) { + X509_INFO *itmp = sk_X509_INFO_value(inf, i); + if (itmp->x509) { + X509_STORE_add_cert(&cert_store, itmp->x509); + count++; + } + if (itmp->crl) { + X509_STORE_add_crl(&cert_store, itmp->crl); + count++; + } + } + + sk_X509_INFO_pop_free(inf, X509_INFO_free); + BIO_free(cbio); + return count; +} + +inline std::expected createCertificateStore(std::string_view x509_ca_bundle) { + X509_STORE_Ptr cert_store = X509_STORE_Ptr(X509_STORE_new(), X509_STORE_free); + if (readCertificateBundleFromBuffer(*cert_store.get(), x509_ca_bundle) <= 0) { + return std::unexpected(std::format("failed to read certificate bundle from buffer:\n#---start---\n{}\n#---end---\n", x509_ca_bundle)); + } + return cert_store; +} + +inline std::expected readServerCertificateFromBuffer(std::string_view X509_ca_bundle) { + BIO *certBio = BIO_new(BIO_s_mem()); + BIO_write(certBio, X509_ca_bundle.data(), static_cast(X509_ca_bundle.size())); + auto certX509 = X509_Ptr(PEM_read_bio_X509(certBio, nullptr, nullptr, nullptr), X509_free); + BIO_free(certBio); + if (!certX509) { + return std::unexpected(std::format("failed to read certificate from buffer:\n#---start---\n{}\n#---end---\n", X509_ca_bundle)); + } + return certX509; +} + +inline std::expected readServerCertificateFromFile(std::filesystem::path fpath) { + auto path = fpath.string(); + BIO *certBio + = BIO_new_file(path.data(), "r"); + if (!certBio) { + return std::unexpected(std::format("failed to read certificate from file {}: {}", path, ERR_error_string(ERR_get_error(), nullptr))); + } + auto certX509 = X509_Ptr(PEM_read_bio_X509(certBio, nullptr, nullptr, nullptr), X509_free); + BIO_free(certBio); + if (!certX509) { + return std::unexpected(std::format("failed to read certificate key from file: {}", path)); + } + return certX509; +} + +inline std::expected readServerPrivateKeyFromBuffer(std::string_view x509_private_key) { + BIO *certBio = BIO_new(BIO_s_mem()); + BIO_write(certBio, x509_private_key.data(), static_cast(x509_private_key.size())); + EVP_PKEY_Ptr privateKeyX509 = EVP_PKEY_Ptr(PEM_read_bio_PrivateKey(certBio, nullptr, nullptr, nullptr), EVP_PKEY_free); + BIO_free(certBio); + if (!privateKeyX509) { + return std::unexpected(std::format("failed to read private key from buffer")); + } + return privateKeyX509; +} + +inline std::expected readServerPrivateKeyFromFile(std::filesystem::path fpath) { + auto path = fpath.string(); + BIO *certBio = BIO_new_file(path.data(), "r"); + if (!certBio) { + return std::unexpected(std::format("failed to read private key from file {}: {}", path, ERR_error_string(ERR_get_error(), nullptr))); + } + EVP_PKEY_Ptr privateKeyX509 = EVP_PKEY_Ptr(PEM_read_bio_PrivateKey(certBio, nullptr, nullptr, nullptr), EVP_PKEY_free); + BIO_free(certBio); + if (!privateKeyX509) { + return std::unexpected(std::format("failed to read private key from file: {}", path)); + } + return privateKeyX509; +} + +#ifdef OPENCMW_PROFILE_HTTP +inline std::chrono::nanoseconds latency(const std::string_view &value) { + const auto now = std::chrono::high_resolution_clock::now(); + const uint64_t ts = std::stoull(std::string(value)); + const auto parsed = std::chrono::time_point(std::chrono::nanoseconds(ts)); + return std::chrono::duration_cast(now - parsed); +} +#endif + +inline std::expected create_ssl(SSL_CTX *ssl_ctx) { + auto ssl = SSL_Ptr(SSL_new(ssl_ctx), SSL_free); + if (!ssl) { + return std::unexpected(std::format("Could not create TLS session object: {}", ERR_error_string(ERR_get_error(), nullptr))); + } + return ssl; +} + +inline std::span u8span(std::string_view view) { + return { const_cast(reinterpret_cast(view.data())), view.size() }; +} + +inline nghttp2_nv nv(const std::span &name, const std::span &value, uint8_t flags = NGHTTP2_NV_FLAG_NO_COPY_NAME) { + return { name.data(), value.data(), name.size(), value.size(), flags }; +} + +inline nghttp3_nv nv3(const std::span &name, const std::span &value, uint8_t flags = NGHTTP3_NV_FLAG_NO_COPY_NAME) { + return { name.data(), value.data(), name.size(), value.size(), flags }; +} + +inline std::string_view as_view(nghttp2_rcbuf *rcbuf) { + auto vec = nghttp2_rcbuf_get_buf(rcbuf); + return { reinterpret_cast(vec.base), vec.len }; +} + +inline std::string_view as_view(const nghttp3_rcbuf *rcbuf) { + auto vec = nghttp3_rcbuf_get_buf(rcbuf); + return { reinterpret_cast(vec.base), vec.len }; +} + +// Convenience/RAII wrapper for a non-blocking/no-delay TCP socket that can be used with SSL or without. +struct TcpSocket { + using AddrinfoPtr = std::unique_ptr; + + enum Flags { + None = 0x0, + VerifyPeer = 0x1, + }; + + enum State { + Uninitialized, + Connecting, + SSLConnectWantsRead, + SSLConnectWantsWrite, + SSLAcceptWantsRead, + SSLAcceptWantsWrite, + Connected + }; + + int fd = -1; + int flags = None; + SSL_Ptr _ssl = SSL_Ptr(nullptr, SSL_free); + State _state = Uninitialized; + AddrinfoPtr address = AddrinfoPtr(nullptr, freeaddrinfo); + + TcpSocket() = default; + + TcpSocket(const TcpSocket &) = delete; + TcpSocket &operator=(const TcpSocket &) = delete; + TcpSocket(TcpSocket &&other) noexcept { + std::swap(fd, other.fd); + std::swap(flags, other.flags); + std::swap(_state, other._state); + _ssl = std::move(other._ssl); + } + TcpSocket &operator=(TcpSocket &&other) noexcept { + if (this != &other) { + close(); + std::swap(fd, other.fd); + std::swap(flags, other.flags); + std::swap(_state, other._state); + _ssl = std::move(other._ssl); + } + return *this; + } + + ~TcpSocket() { + close(); + } + + static std::expected create(SSL_Ptr ssl_, int fd_, int flags_ = VerifyPeer) { + if (fd_ == -1) { + return std::unexpected(std::format("Invalid socket file descriptor: {}", fd_)); + } + TcpSocket socket; + socket.fd = fd_; + socket.flags = flags_; + socket._ssl = std::move(ssl_); + int f = fcntl(socket.fd, F_GETFL, 0); + assert(f != -1); + fcntl(socket.fd, F_SETFL, f | O_NONBLOCK); + + int flag = 1; + setsockopt(socket.fd, IPPROTO_TCP, TCP_NODELAY, &flag, sizeof(flag)); + + if (socket._ssl) { + auto bio = BIO_new_socket(socket.fd, BIO_NOCLOSE); + if (!bio) { + return std::unexpected(std::format("Failed to create BIO object: {}", ERR_error_string(ERR_get_error(), nullptr))); + } + // SSL will take ownership of the BIO object + SSL_set_bio(socket._ssl.get(), bio, bio); + } + return socket; + } + + std::string lastError() { + if (_ssl) { + return ERR_error_string(ERR_get_error(), nullptr); + } else { + return strerror(errno); + } + } + + std::expected continueHandshake() { + switch (_state) { + case Uninitialized: + return std::unexpected("Socket not initialized"); + case SSLConnectWantsRead: + case SSLConnectWantsWrite: { + int ret = SSL_connect(_ssl.get()); + if (ret == 1) { + _state = Connected; + return _state; + } + if (ret == 0) { + return std::unexpected(std::format("SSL handshake failed: {}", ERR_error_string(ERR_get_error(), nullptr))); + } + int err = SSL_get_error(_ssl.get(), ret); + if (err == SSL_ERROR_WANT_READ) { + _state = SSLConnectWantsRead; + return _state; + } + if (err == SSL_ERROR_WANT_WRITE) { + _state = SSLConnectWantsWrite; + return _state; + } + return std::unexpected(std::format("SSL handshake failed: {}", ERR_error_string(ERR_get_error(), nullptr))); + } + case SSLAcceptWantsRead: + case SSLAcceptWantsWrite: { + int ret = SSL_accept(_ssl.get()); + if (ret == 1) { + _state = Connected; + return _state; + } + if (ret == 0) { + return std::unexpected(std::format("SSL handshake failed: {}", ERR_error_string(ERR_get_error(), nullptr))); + } + int err = SSL_get_error(_ssl.get(), ret); + if (err == SSL_ERROR_WANT_READ) { + _state = SSLAcceptWantsRead; + return _state; + } + if (err == SSL_ERROR_WANT_WRITE) { + _state = SSLAcceptWantsWrite; + return _state; + } + return std::unexpected(std::format("SSL handshake failed: {}", ERR_error_string(ERR_get_error(), nullptr))); + } + case Connecting: + case Connected: + return _state; + } + + return _state; + } + + std::expected prepareConnect(std::string_view host, uint16_t port) { + struct addrinfo hints; + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_INET; // IPv4 + hints.ai_socktype = SOCK_STREAM; // TCP + + struct addrinfo *res; + int status = getaddrinfo(host.data(), nullptr, &hints, &res); + if (status != 0) { + return std::unexpected(std::format("Could not resolve address: {}", strerror(status))); + } + address = AddrinfoPtr(res, freeaddrinfo); + reinterpret_cast(address->ai_addr)->sin_port = htons(port); + + if (_ssl && (flags & VerifyPeer) != 0) { + // Use instead of SSL_set_tlsext_host_name() to avoid warning about (void*) cast + if (auto r = SSL_ctrl(_ssl.get(), SSL_CTRL_SET_TLSEXT_HOSTNAME, TLSEXT_NAMETYPE_host_name, const_cast(host.data())); r != SSL_TLSEXT_ERR_OK && r != SSL_TLSEXT_ERR_ALERT_WARNING) { + return std::unexpected(std::format("Failed to set the TLS SNI hostname: {}", ERR_error_string(ERR_get_error(), nullptr))); + } + SSL_set_verify(_ssl.get(), SSL_VERIFY_PEER, nullptr); + } + + _state = Connecting; + return {}; + } + + std::expected connect() { + assert(address); + if (::connect(fd, address.get()->ai_addr, sizeof(sockaddr)) < 0) { + if (errno == EINPROGRESS) { + return {}; + } + + return std::unexpected(std::format("Connect failed: {}", strerror(errno))); + } + + if (!_ssl) { + _state = Connected; + return {}; + } + + _state = SSLConnectWantsWrite; + if (auto r = continueHandshake(); !r) { + return std::unexpected(r.error()); + } + return {}; + } + + std::expected, std::string> accept(SSL_CTX *ssl_ctx, int flags_) { + struct sockaddr_in client_addr; + socklen_t client_len = sizeof(client_addr); + int client_fd = ::accept(fd, static_cast(static_cast(&client_addr)), &client_len); + if (client_fd < 0) { + if (errno == EAGAIN) { + return {}; + } + return std::unexpected(std::format("Accept failed: {}", strerror(errno))); + } + + if (!ssl_ctx) { + auto maybeSocket = TcpSocket::create({ nullptr, SSL_free }, client_fd); + if (maybeSocket) { + maybeSocket->_state = Connected; + } + return maybeSocket; + } + + auto ssl = create_ssl(ssl_ctx); + if (!ssl) { + ::close(client_fd); + return std::unexpected(std::format("Failed to create SSL object: {}", ssl.error())); + } + + auto clientSocket = TcpSocket::create(std::move(ssl.value()), client_fd, flags_); + if (clientSocket) { + clientSocket->_state = SSLAcceptWantsRead; + } + return clientSocket; + } + + ssize_t read(uint8_t *data, std::size_t length) { + if (_ssl) { + return SSL_read(_ssl.get(), data, static_cast(length)); + } else { + return ::recv(fd, data, length, 0); + } + } + + ssize_t write(const uint8_t *data, std::size_t length, int wflags = 0) { + if (_ssl) { + return SSL_write(_ssl.get(), data, static_cast(length)); + } else { + return ::send(fd, data, length, wflags); + } + } + + void close() { + if (_ssl) { + SSL_shutdown(_ssl.get()); + _ssl.reset(); + } + if (fd != -1) { + ::close(fd); + fd = -1; + } + } +}; + +template +struct WriteBuffer { + std::vector buffer = std::vector(InitialCapacity); + std::size_t size = 0; + + bool hasData() const { + return size > 0; + } + + bool wantsToWrite(nghttp2_session *session) const { + return size > 0 || nghttp2_session_want_write(session); + } + + bool write(nghttp2_session *session, TcpSocket &socket) { + const uint8_t *chunk; + + while (size < Limit && nghttp2_session_want_write(session)) { + nghttp2_ssize len = nghttp2_session_mem_send2(session, &chunk); + if (len < 0) { + break; // out of memory, try again later + } + if (len == 0) { + continue; + } + const std::size_t newSize = size + static_cast(len); + if (newSize > buffer.size()) { + HTTP_DBG("Resizing buffer from {} to {}", buffer.size(), newSize); + buffer.resize(newSize); + } + + std::copy(chunk, chunk + len, buffer.data() + size); + size += static_cast(len); + } + + std::size_t written = 0; + while (size - written > 0) { + const ssize_t n = socket.write(buffer.data() + written, size - written); + if (n == 0 && errno == EAGAIN) { + break; + } + if (n <= 0) { + return false; + } + written += static_cast(n); + } + size -= written; + std::move(buffer.data() + written, buffer.data() + written + size, buffer.data()); + return true; + } +}; + +} // namespace opencmw::rest::detail + +#endif diff --git a/src/services/CMakeLists.txt b/src/services/CMakeLists.txt index c17721d0..61a4686c 100644 --- a/src/services/CMakeLists.txt +++ b/src/services/CMakeLists.txt @@ -13,7 +13,7 @@ target_include_directories(services INTERFACE $ #include - -#include - namespace opencmw::service::dns { using DnsWorkerType = majordomo::Worker<"/dns", Context, FlatEntryList, FlatEntryList, majordomo::description<"Register and Query Signals">>; @@ -36,57 +35,45 @@ class DnsHandler { } }; -Entry registerSignals(const std::vector &entries, std::string scheme_host_port = "http://localhost:8080") { - IoBuffer outBuffer; - FlatEntryList entrylist{ entries }; - opencmw::serialise(outBuffer, entrylist); - std::string contentType{ MIME::BINARY.typeName() }; - std::string body{ outBuffer.asString() }; - - // send request to register Signal - httplib::Client client{ - scheme_host_port - }; - - auto response = client.Post("dns", body, contentType); +inline Entry registerSignals(const std::vector &entries, std::string scheme_host_port = "http://localhost:8080") { + client::Command cmd; + cmd.command = mdp::Command::Set; + cmd.serviceName = "/dns"; + cmd.topic = URI<>::UriFactory{ URI<>(std::move(scheme_host_port)) }.path("/dns").build(); + opencmw::serialise(cmd.data, FlatEntryList{ entries }); - if (response.error() == httplib::Error::Read) throw std::runtime_error{ "Server did not send an answer" }; - if (response.error() != httplib::Error::Success || response->status == 500) throw std::runtime_error{ response->reason }; + client::RestClient client{ client::DefaultContentTypeHeader(MIME::BINARY) }; + auto reply = client.blockingRequest(std::move(cmd)); - // deserialise response - IoBuffer inBuffer; - inBuffer.put(response->body); + if (!reply.error.empty()) { + throw std::runtime_error{ reply.error }; + } FlatEntryList res; try { - opencmw::deserialise(inBuffer, res); + opencmw::deserialise(reply.data, res); } catch (const ProtocolException &exc) { throw std::runtime_error{ exc.what() }; // rethrowing, because ProtocolException behaves weird } - return res.toEntries().front(); } -std::vector querySignals(const Entry &a = {}, std::string scheme_host_port = "http://localhost:8080") { - // send request to register Service - httplib::Client client{ - scheme_host_port - }; +inline std::vector querySignals(const Entry &a = {}, std::string scheme_host_port = "http://localhost:8080") { + client::Command cmd; + cmd.command = mdp::Command::Get; + cmd.serviceName = "/dns"; + cmd.topic = URI<>::UriFactory{ URI<>(std::move(scheme_host_port)) }.path("/dns").addQueryParameter("service_name", a.service_name).addQueryParameter("signal_name", a.signal_name).addQueryParameter("signal_unit", a.signal_unit).addQueryParameter("signal_rate", std::to_string(a.signal_rate)).addQueryParameter("signal_type", a.signal_type).build(); - auto response = client.Get("dns", httplib::Params{ { "protocol", a.protocol }, { "hostname", a.hostname }, { "port", std::to_string(a.port) }, { "service_name", a.service_name }, { "service_type", a.service_type }, { "signal_name", a.signal_name }, { "signal_unit", a.signal_unit }, { "signal_rate", std::to_string(a.signal_rate) }, { "signal_type", a.signal_type } }, - httplib::Headers{ - { std::string{ "Content-Type" }, std::string{ MIME::BINARY.typeName() } } }); + client::RestClient client{ client::DefaultContentTypeHeader(MIME::BINARY) }; + auto reply = client.blockingRequest(std::move(cmd)); - if (response.error() == httplib::Error::Read) throw std::runtime_error{ "Server did not send an answer" }; - if (response.error() != httplib::Error::Success || response->status == 500) throw std::runtime_error{ response->reason }; - - // deserialise response - IoBuffer inBuffer; - inBuffer.put(response->body); + if (!reply.error.empty()) { + throw std::runtime_error{ reply.error }; + } FlatEntryList res; try { - opencmw::deserialise(inBuffer, res); + opencmw::deserialise(reply.data, res); } catch (const ProtocolException &exc) { throw std::runtime_error{ exc.what() }; // rethrowing, because ProtocolException behaves weird } diff --git a/src/services/include/services/dns_client.hpp b/src/services/include/services/dns_client.hpp index 196e6a38..a9508d2f 100644 --- a/src/services/include/services/dns_client.hpp +++ b/src/services/include/services/dns_client.hpp @@ -1,14 +1,13 @@ #ifndef DNS_CLIENT_HPP #define DNS_CLIENT_HPP -#include "Debug.hpp" +#include "ClientContext.hpp" #include "dns_types.hpp" #include "MdpMessage.hpp" -#include "RestClient.hpp" #include #include #include -#include +#include #include #include #include diff --git a/src/services/test/dns_tests.cpp b/src/services/test/dns_tests.cpp index 60f134c8..b56b94fd 100644 --- a/src/services/test/dns_tests.cpp +++ b/src/services/test/dns_tests.cpp @@ -3,15 +3,18 @@ #include #include -#include #include +#include + #ifndef __EMSCRIPTEN__ #include +#include +#include +#include // Concepts and tests use common types #include #include -#include #endif namespace fs = std::filesystem; @@ -211,12 +214,15 @@ TEST_CASE("data storage - Deleting Entries") { TEST_CASE("run services", "[DNS]") { FileDeleter fd; majordomo::Broker<> broker{ "/Broker", {} }; - std::string rootPath{ "./" }; - auto fs = cmrc::assets::get_filesystem(); - majordomo::RestBackend rest_backend{ broker, fs }; + majordomo::rest::Settings rest; + rest.protocols = majordomo::rest::Protocol::Http2; + const auto bound = broker.bindRest(rest); + if (!bound) { + FAIL(std::format("Failed to bind REST: {}", bound.error())); + } + REQUIRE(bound); DnsWorkerType dnsWorker{ broker, {} }; - RunInThread restThread(rest_backend); RunInThread brokerThread(broker); RunInThread dnsThread(dnsWorker); } @@ -224,13 +230,13 @@ TEST_CASE("run services", "[DNS]") { TEST_CASE("client", "[DNS]") { FileDeleter fd; majordomo::Broker<> broker{ "/Broker", {} }; - std::string rootPath{ "./" }; - auto fs = cmrc::assets::get_filesystem(); - majordomo::RestBackend rest_backend{ broker, fs }; + majordomo::rest::Settings rest; + rest.protocols = majordomo::rest::Protocol::Http2; + const auto bound = broker.bindRest(rest); + REQUIRE(bound); DnsWorkerType dnsWorker{ broker, DnsHandler{} }; broker.bind(URI<>{ "inproc://dns_server" }, majordomo::BindOption::Router); - RunInThread restThread(rest_backend); RunInThread dnsThread(dnsWorker); RunInThread brokerThread(broker); @@ -268,11 +274,12 @@ TEST_CASE("client", "[DNS]") { TEST_CASE("query", "[DNS]") { FileDeleter fd; majordomo::Broker<> broker{ "/Broker", {} }; - auto fs = cmrc::assets::get_filesystem(); - majordomo::RestBackend rest_backend{ broker, fs }; - DnsWorkerType dnsWorker{ broker, DnsHandler{} }; + majordomo::rest::Settings rest; + rest.protocols = majordomo::rest::Protocol::Http2; + const auto bound = broker.bindRest(rest); + REQUIRE(bound); - RunInThread restThread(rest_backend); + DnsWorkerType dnsWorker{ broker, DnsHandler{} }; RunInThread brokerThread(broker); RunInThread dnsThread(dnsWorker); @@ -308,11 +315,11 @@ TEST_CASE("query", "[DNS]") { TEST_CASE("client unregister entries", "[DNS]") { FileDeleter fd; majordomo::Broker<> broker{ "/Broker", {} }; - auto fs = cmrc::assets::get_filesystem(); - majordomo::RestBackend rest_backend{ broker, fs }; + majordomo::rest::Settings rest; + rest.protocols = majordomo::rest::Protocol::Http2; + const auto bound = broker.bindRest(rest); + REQUIRE(bound); DnsWorkerType dnsWorker{ broker, DnsHandler{} }; - - RunInThread restThread(rest_backend); RunInThread brokerThread(broker); RunInThread dnsThread(dnsWorker); @@ -320,7 +327,7 @@ TEST_CASE("client unregister entries", "[DNS]") { std::vector> clients; clients.emplace_back(std::make_unique(broker.context, 20ms, "dnsTestClient")); - clients.emplace_back(std::make_unique(opencmw::client::DefaultContentTypeHeader(MIME::BINARY))); + clients.emplace_back(std::make_unique(client::DefaultContentTypeHeader(MIME::BINARY))); client::ClientContext clientContext{ std::move(clients) }; DnsClient restClient{ clientContext, URI<>{ "http://localhost:8080/dns" } }; restClient.registerSignals({ entry_a, entry_b, entry_c });