From 224407d86914a359306031824053f6a2ba5aeb0e Mon Sep 17 00:00:00 2001 From: apple1417 Date: Fri, 20 Jun 2025 14:12:42 +1200 Subject: [PATCH] make `DLLSafeCallback` safer idk why I've left it this awkward to use for so long, now acts more like a normal c++ type also worked out how to create this template based off of a `std::function`, so that we don't need to put the dll safe versions in the header --- src/unrealsdk/commands.cpp | 24 +++---- src/unrealsdk/commands.h | 3 +- src/unrealsdk/hook_manager.cpp | 51 +++++++------- src/unrealsdk/hook_manager.h | 3 +- src/unrealsdk/utils.h | 121 +++++++++++++++++++++++---------- 5 files changed, 121 insertions(+), 81 deletions(-) diff --git a/src/unrealsdk/commands.cpp b/src/unrealsdk/commands.cpp index c0c24b1..11518a9 100644 --- a/src/unrealsdk/commands.cpp +++ b/src/unrealsdk/commands.cpp @@ -6,10 +6,10 @@ namespace unrealsdk::commands { namespace { -#ifndef UNREALSDK_IMPORTING - -utils::StringViewMap commands{}; +using DLLSafeCallback = utils::DLLSafeCallback; +#ifndef UNREALSDK_IMPORTING +utils::StringViewMap commands{}; #endif } // namespace @@ -18,10 +18,10 @@ utils::StringViewMap commands{}; const std::wstring NEXT_LINE{}; #ifdef UNREALSDK_SHARED -UNREALSDK_CAPI(bool, add_command, const wchar_t* cmd, size_t size, DLLSafeCallback* callback); +UNREALSDK_CAPI(bool, add_command, const wchar_t* cmd, size_t size, DLLSafeCallback&& callback); #endif #ifndef UNREALSDK_IMPORTING -UNREALSDK_CAPI(bool, add_command, const wchar_t* cmd, size_t size, DLLSafeCallback* callback) { +UNREALSDK_CAPI(bool, add_command, const wchar_t* cmd, size_t size, DLLSafeCallback&& callback) { std::wstring lower_cmd(size, '\0'); std::transform(cmd, cmd + size, lower_cmd.begin(), &std::towlower); @@ -29,14 +29,14 @@ UNREALSDK_CAPI(bool, add_command, const wchar_t* cmd, size_t size, DLLSafeCallba return false; } - commands.emplace(std::move(lower_cmd), callback); + commands.emplace(std::move(lower_cmd), std::move(callback)); return true; } #endif bool add_command(std::wstring_view cmd, const Callback& callback) { // NOLINTNEXTLINE(cppcoreguidelines-owning-memory) - return UNREALSDK_MANGLE(add_command)(cmd.data(), cmd.size(), new DLLSafeCallback(callback)); + return UNREALSDK_MANGLE(add_command)(cmd.data(), cmd.size(), {callback}); } #ifdef UNREALSDK_SHARED @@ -67,9 +67,7 @@ UNREALSDK_CAPI(bool, remove_command, const wchar_t* cmd, size_t size) { return false; } - iter->second->destroy(); commands.erase(iter); - return true; } #endif @@ -101,12 +99,10 @@ bool is_command_valid(std::wstring_view line, bool direct_user_input) { void run_command(std::wstring_view line) { auto iter = commands.find(NEXT_LINE); if (iter != commands.end()) { - auto callback = iter->second; + auto callback = std::move(iter->second); commands.erase(iter); - callback->operator()(line.data(), line.size(), 0); - - callback->destroy(); + callback(line.data(), line.size(), 0); return; } @@ -122,7 +118,7 @@ void run_command(std::wstring_view line) { std::wstring cmd(cmd_end - non_space, '\0'); std::transform(non_space, cmd_end, cmd.begin(), &std::towlower); - commands.at(cmd)->operator()(line.data(), line.size(), cmd_end - line.begin()); + commands.at(cmd)(line.data(), line.size(), cmd_end - line.begin()); } #endif diff --git a/src/unrealsdk/commands.h b/src/unrealsdk/commands.h index 5393577..d514211 100644 --- a/src/unrealsdk/commands.h +++ b/src/unrealsdk/commands.h @@ -36,8 +36,7 @@ extern const std::wstring NEXT_LINE; * points to the first whitespace char after the command (or off the end of the * string if there was none). 0 in the case of a `NEXT_LINE` match. */ -using DLLSafeCallback = utils::DLLSafeCallback; -using Callback = DLLSafeCallback::InnerFunc; +using Callback = std::function; /** * @brief Adds a custom console command. diff --git a/src/unrealsdk/hook_manager.cpp b/src/unrealsdk/hook_manager.cpp index 46577ad..8266915 100644 --- a/src/unrealsdk/hook_manager.cpp +++ b/src/unrealsdk/hook_manager.cpp @@ -12,9 +12,13 @@ using namespace unrealsdk::unreal; +namespace unrealsdk::hook_manager { + +using DLLSafeCallback = utils::DLLSafeCallback; + #pragma region Implementation #ifndef UNREALSDK_IMPORTING -namespace unrealsdk::hook_manager::impl { +namespace impl { /* The fact that hooks run arbitrary user provided callbacks means our data structure can get modified @@ -81,7 +85,7 @@ struct Node { std::wstring full_name; Type type; std::wstring identifier; - DLLSafeCallback* callback; + DLLSafeCallback callback; // Using shared pointers because it's easy // Since we use std::make_shared, we're not really wasting allocations, but as a future @@ -96,22 +100,12 @@ struct Node { std::wstring_view full_name, Type type, std::wstring_view identifier, - DLLSafeCallback* callback) + DLLSafeCallback&& callback) : fname(fname), full_name(full_name), type(type), identifier(identifier), - callback(callback) {} - Node(const Node&) = default; - Node(Node&&) noexcept = default; - Node& operator=(const Node&) = default; - Node& operator=(Node&&) noexcept = default; - ~Node() { - if (this->callback != nullptr) { - this->callback->destroy(); - this->callback = nullptr; - } - } + callback(std::move(callback)) {} }; namespace { @@ -181,7 +175,7 @@ FName extract_func_obj_name(std::wstring_view func) { bool add_hook(std::wstring_view func, Type type, std::wstring_view identifier, - DLLSafeCallback* callback) { + DLLSafeCallback&& callback) { auto fname = extract_func_obj_name(func); auto hash_idx = get_table_index(fname); @@ -189,7 +183,7 @@ bool add_hook(std::wstring_view func, if (node == nullptr) { // This function isn't in the hash table, can just add directly. hooks_hash_table.at(hash_idx) = - std::make_shared(fname, func, type, identifier, callback); + std::make_shared(fname, func, type, identifier, std::move(callback)); return true; } @@ -197,7 +191,8 @@ bool add_hook(std::wstring_view func, while (node->fname != fname) { if (node->next_collision == nullptr) { // We found a collision, but nothing matched our name, so add it to the end - node->next_collision = std::make_shared(fname, func, type, identifier, callback); + node->next_collision = + std::make_shared(fname, func, type, identifier, std::move(callback)); return true; } node = node->next_collision; @@ -207,7 +202,8 @@ bool add_hook(std::wstring_view func, while (node->full_name != func) { if (node->next_function == nullptr) { // We found another function with the same fname, but nothing matches the full name - node->next_function = std::make_shared(fname, func, type, identifier, callback); + node->next_function = + std::make_shared(fname, func, type, identifier, std::move(callback)); return true; } node = node->next_function; @@ -217,7 +213,8 @@ bool add_hook(std::wstring_view func, while (node->type != type) { if (node->next_type == nullptr) { // We found the right function, but it doesn't have any hooks of this type yet - node->next_type = std::make_shared(fname, func, type, identifier, callback); + node->next_type = + std::make_shared(fname, func, type, identifier, std::move(callback)); return true; } node = node->next_type; @@ -228,7 +225,7 @@ bool add_hook(std::wstring_view func, if (node->next_in_collection == nullptr) { // Didn't find a matching identifier, add our new hook at the end node->next_in_collection = - std::make_shared(fname, func, type, identifier, callback); + std::make_shared(fname, func, type, identifier, std::move(callback)); } node = node->next_in_collection; } @@ -503,7 +500,7 @@ bool run_hooks_of_type(std::shared_ptr node, Type type, Details& hook) { bool ret = false; for (; node != nullptr; node = node->next_in_collection) { try { - ret |= node->callback->operator()(hook); + ret |= node->callback(hook); } catch (const std::exception& ex) { LOG(ERROR, "An exception occurred during hook processing"); LOG(ERROR, L"Function: {}", hook.func.func->get_path_name()); @@ -514,14 +511,13 @@ bool run_hooks_of_type(std::shared_ptr node, Type type, Details& hook) { return ret; } -} // namespace unrealsdk::hook_manager::impl +} // namespace impl #endif #pragma endregion // ================================================================================================= #pragma region Public Interface -namespace unrealsdk::hook_manager { #ifdef UNREALSDK_SHARED UNREALSDK_CAPI(void, log_all_calls, bool should_log); @@ -555,7 +551,7 @@ UNREALSDK_CAPI(bool, Type type, const wchar_t* identifier, size_t identifier_size, - DLLSafeCallback* callback); + DLLSafeCallback&& callback); #endif #ifndef UNREALSDK_IMPORTING UNREALSDK_CAPI(bool, @@ -565,8 +561,9 @@ UNREALSDK_CAPI(bool, Type type, const wchar_t* identifier, size_t identifier_size, - DLLSafeCallback* callback) { - return impl::add_hook({func, func_size}, type, {identifier, identifier_size}, callback); + DLLSafeCallback&& callback) { + return impl::add_hook({func, func_size}, type, {identifier, identifier_size}, + std::move(callback)); } #endif @@ -576,7 +573,7 @@ bool add_hook(std::wstring_view func, const Callback& callback) { // NOLINTBEGIN(cppcoreguidelines-owning-memory) return UNREALSDK_MANGLE(add_hook)(func.data(), func.size(), type, identifier.data(), - identifier.size(), new DLLSafeCallback(callback)); + identifier.size(), {callback}); // NOLINTEND(cppcoreguidelines-owning-memory) } diff --git a/src/unrealsdk/hook_manager.h b/src/unrealsdk/hook_manager.h index a45fa67..7cce2f6 100644 --- a/src/unrealsdk/hook_manager.h +++ b/src/unrealsdk/hook_manager.h @@ -55,8 +55,7 @@ struct Details { * will not be run. * In post-hooks: ignored. */ -using DLLSafeCallback = utils::DLLSafeCallback; -using Callback = DLLSafeCallback::InnerFunc; +using Callback = std::function; /** * @brief Toggles logging all unreal function calls. Best used in short bursts for debugging. diff --git a/src/unrealsdk/utils.h b/src/unrealsdk/utils.h index 1e0b531..0ca275d 100644 --- a/src/unrealsdk/utils.h +++ b/src/unrealsdk/utils.h @@ -68,59 +68,103 @@ struct IteratorProxy { /** * @brief A class used to safely pass callbacks, which may be lambdas or other complex callables, * across dll boundaries safely. + * @note This is used to address the fact that std::function may differ between STL implementations. * @note You should never have to use this directly, the wrappers should convert everything for you. * - * @warning When using this type to implement a wrapper, it must be handled with great care. It - * relies on two dangerous rules to work properly: - * - It must never be stored or passed by value, only ever by pointer. - * - Before destroying the pointer, you must manually call the `destroy` method. - * + * @tparam F The std::function type this callback is based on. * @tparam R The return type. May be void. * @tparam As The argument types. */ +template +struct DLLSafeCallback; template -// NOLINTNEXTLINE(cppcoreguidelines-special-member-functions) -struct DLLSafeCallback { - /// The inner type of the function this callback runs. - using InnerFunc = std::function; - +struct DLLSafeCallback> { private: - // We can't trust an actual virtual function table to have a consistent layout - e.g. clang - // uses two destructors (freeing and non-freeing), MSVC only uses one (with a bool arg) - // Instead, create our own manually. - - // Note MSVC relies on `noexcept(false)` to allow exceptions to cross dll boundaries. - struct PseudoVFTable { - void (*destroy)(DLLSafeCallback* self) noexcept(false); - R (*call)(DLLSafeCallback* self, As... args) noexcept(false); + /* + This Inner type is what does most of the work. It has some very non-standard semantics, so the + outer DLLSafeCallback just safely wraps a pointer around it. + + On constructing a new callback, we'll construct a new Inner struct in the current DLL - using + the current std::function layout and current allocator. We then rely on "virtual" functions + whenever we interact with it, to make sure it's always code from the original DLL. This is + critical because the inner struct's layout can be completely different between implementations, + we can't even rely on it being a consistent size. + + One complication is we can't trust an actual virtual function table to have a consistent layout. + For example, clang uses two destructors (freeing and non-freeing), while MSVC only uses one + (with a bool arg). We also need to make sure we avoid devirtualization, especially with LTO. To + handle these, we have to create our own vftable manually. + */ + struct Inner { + public: + struct PseudoVFTable { + // MSVC relies on `noexcept(false)` to allow exceptions to cross dll boundaries. + void (*destroy)(Inner* self) noexcept(false); + R (*call)(Inner* self, As... args) noexcept(false); + }; + + private: + // NOLINTNEXTLINE(cppcoreguidelines-owning-memory) + static void destroy(Inner* self) noexcept(false) { delete self; } + static R call(Inner* self, As... args) noexcept(false) { + return self->func(std::forward(args)...); + } + + static const constexpr PseudoVFTable DEFAULT_VFTABLE = { + &Inner::destroy, + &Inner::call, + }; + + public: + // Volatile means assume can be changed externally - e.g. because it was assigned by a + // different dll. Otherwise optimization might try devirtualize it. + volatile const PseudoVFTable* vftable = &DEFAULT_VFTABLE; + + private: + // Must be after the vftable, since we can't rely on it's size. + std::function func; + + public: + Inner(std::function func) : func(func) {} }; - // NOLINTNEXTLINE(cppcoreguidelines-owning-memory) - static void destroy(DLLSafeCallback* self) noexcept(false) { delete self; } - static R call(DLLSafeCallback* self, As... args) noexcept(false) { return self->func(args...); } - - static const constexpr PseudoVFTable DEFAULT_VFTABLE = { - &DLLSafeCallback::destroy, - &DLLSafeCallback::call, - }; - // Use volatile to prevent optimization from calling our copies of the function, it has to go - // through the vftable - volatile const PseudoVFTable* vftable = &DEFAULT_VFTABLE; - - InnerFunc func; + Inner* inner; public: /** - * @brief Constructs a new safe callback from a function. + * @brief Constructs a new dll safe callback. * * @param func The function to wrap. */ - DLLSafeCallback(InnerFunc func) : func(func) {} + DLLSafeCallback(std::function func) : inner(new Inner(std::move(func))) {} + DLLSafeCallback(DLLSafeCallback&& other) noexcept + : inner(std::exchange(other.inner, nullptr)) {} /** - * @brief Destroys the callback. + * @brief Assigns to the dll safe callback. + * + * @param other The other callback to assign from. + * @return A reference to this callback. */ - void destroy(void) { this->vftable->destroy(this); } + DLLSafeCallback& operator=(DLLSafeCallback&& other) noexcept { + std::swap(this->inner, other.inner); + } + + /** + * @brief Destroys the dll safe callback. + */ + ~DLLSafeCallback() { + if (this->inner != nullptr) { + this->inner->vftable->destroy(this->inner); + // The original dll has safely destroyed the inner struct, so we're free to "leak" it + this->inner = nullptr; + } + } + + // No copy construction/assignment. Haven't really needed it yet, and would require implementing + // some form of reference counting. + DLLSafeCallback(const DLLSafeCallback&) = delete; + DLLSafeCallback& operator=(const DLLSafeCallback& other) = delete; /** * @brief Runs the callback. @@ -128,7 +172,12 @@ struct DLLSafeCallback { * @param args The callback args. * @return The return value of the callback */ - R operator()(As... args) { return this->vftable->call(this, args...); } + R operator()(As... args) { + if (this->inner == nullptr) { + throw std::runtime_error("tried to run a null callback!"); + } + return this->inner->vftable->call(this->inner, std::forward(args)...); + } }; /**