diff --git a/src/node_api.cc b/src/node_api.cc
index 1638d096969826..15d2cbd1e3c876 100644
--- a/src/node_api.cc
+++ b/src/node_api.cc
@@ -196,7 +196,7 @@ inline napi_env NewEnv(v8::Local<v8::Context> context,
   return result;
 }
 
-class ThreadSafeFunction : public node::AsyncResource {
+class ThreadSafeFunction {
  public:
   ThreadSafeFunction(v8::Local<v8::Function> func,
                      v8::Local<v8::Object> resource,
@@ -208,11 +208,12 @@ class ThreadSafeFunction : public node::AsyncResource {
                      void* finalize_data_,
                      napi_finalize finalize_cb_,
                      napi_threadsafe_function_call_js call_js_cb_)
-      : AsyncResource(env_->isolate,
-                      resource,
-                      *v8::String::Utf8Value(env_->isolate, name)),
+      : async_resource(std::in_place,
+                       env_->isolate,
+                       resource,
+                       *v8::String::Utf8Value(env_->isolate, name)),
         thread_count(thread_count_),
-        is_closing(false),
+        state(OPEN),
         dispatch_state(kDispatchIdle),
         context(context_),
         max_queue_size(max_queue_size_),
@@ -226,76 +227,100 @@ class ThreadSafeFunction : public node::AsyncResource {
     env->Ref();
   }
 
-  ~ThreadSafeFunction() override {
-    node::RemoveEnvironmentCleanupHook(env->isolate, Cleanup, this);
-    env->Unref();
-  }
+  ~ThreadSafeFunction() { ReleaseResources(); }
 
   // These methods can be called from any thread.
 
   napi_status Push(void* data, napi_threadsafe_function_call_mode mode) {
-    node::Mutex::ScopedLock lock(this->mutex);
+    {
+      node::Mutex::ScopedLock lock(this->mutex);
 
-    while (queue.size() >= max_queue_size && max_queue_size > 0 &&
-           !is_closing) {
-      if (mode == napi_tsfn_nonblocking) {
-        return napi_queue_full;
+      while (queue.size() >= max_queue_size && max_queue_size > 0 &&
+             state == OPEN) {
+        if (mode == napi_tsfn_nonblocking) {
+          return napi_queue_full;
+        }
+        cond->Wait(lock);
       }
-      cond->Wait(lock);
-    }
 
-    if (is_closing) {
+      if (state == OPEN) {
+        queue.push(data);
+        Send();
+        return napi_ok;
+      }
       if (thread_count == 0) {
         return napi_invalid_arg;
-      } else {
-        thread_count--;
+      }
+      thread_count--;
+      if (!(state == CLOSED && thread_count == 0)) {
         return napi_closing;
       }
-    } else {
-      queue.push(data);
-      Send();
-      return napi_ok;
     }
+    // Make sure to release lock before destroying
+    delete this;
+    return napi_closing;
   }
 
   napi_status Acquire() {
     node::Mutex::ScopedLock lock(this->mutex);
 
-    if (is_closing) {
-      return napi_closing;
-    }
+    if (state == OPEN) {
+      thread_count++;
 
-    thread_count++;
+      return napi_ok;
+    }
 
-    return napi_ok;
+    return napi_closing;
   }
 
   napi_status Release(napi_threadsafe_function_release_mode mode) {
-    node::Mutex::ScopedLock lock(this->mutex);
+    {
+      node::Mutex::ScopedLock lock(this->mutex);
 
-    if (thread_count == 0) {
-      return napi_invalid_arg;
-    }
+      if (thread_count == 0) {
+        return napi_invalid_arg;
+      }
 
-    thread_count--;
+      thread_count--;
 
-    if (thread_count == 0 || mode == napi_tsfn_abort) {
-      if (!is_closing) {
-        is_closing = (mode == napi_tsfn_abort);
-        if (is_closing && max_queue_size > 0) {
-          cond->Signal(lock);
+      if (thread_count == 0 || mode == napi_tsfn_abort) {
+        if (state == OPEN) {
+          if (mode == napi_tsfn_abort) {
+            state = CLOSING;
+          }
+          if (state == CLOSING && max_queue_size > 0) {
+            cond->Signal(lock);
+          }
+          Send();
         }
-        Send();
       }
-    }
+      if (!(state == CLOSED && thread_count == 0)) {
+        return napi_ok;
+      }
+    }
+    // Make sure to release lock before destroying
+    delete this;
 
     return napi_ok;
   }
 
-  void EmptyQueueAndDelete() {
-    for (; !queue.empty(); queue.pop()) {
-      call_js_cb(nullptr, nullptr, context, queue.front());
+  void EmptyQueueAndMaybeDelete() {
+    {
+      node::Mutex::ScopedLock lock(this->mutex);
+      for (; !queue.empty(); queue.pop()) {
+        call_js_cb(nullptr, nullptr, context, queue.front());
+      }
+      if (thread_count > 0) {
+        // At this point this TSFN is effectively done, but we need to keep
+        // it alive for other threads that still have pointers to it until
+        // they release them.
+        // But we already release all the resources that we can at this point
+        queue = {};
+        ReleaseResources();
+        return;
+      }
     }
+    // Make sure to release lock before destroying
     delete this;
   }
 
@@ -347,6 +372,16 @@ class ThreadSafeFunction : public node::AsyncResource {
   inline void* Context() { return context; }
 
  protected:
+  void ReleaseResources() {
+    if (state != CLOSED) {
+      state = CLOSED;
+      ref.Reset();
+      node::RemoveEnvironmentCleanupHook(env->isolate, Cleanup, this);
+      env->Unref();
+      async_resource.reset();
+    }
+  }
+
   void Dispatch() {
     bool has_more = true;
 
@@ -375,9 +410,7 @@ class ThreadSafeFunction : public node::AsyncResource {
     {
       node::Mutex::ScopedLock lock(this->mutex);
 
-      if (is_closing) {
-        CloseHandlesAndMaybeDelete();
-      } else {
+      if (state == OPEN) {
         size_t size = queue.size();
         if (size > 0) {
           data = queue.front();
@@ -391,7 +424,7 @@ class ThreadSafeFunction : public node::AsyncResource {
 
         if (size == 0) {
           if (thread_count == 0) {
-            is_closing = true;
+            state = CLOSING;
             if (max_queue_size > 0) {
               cond->Signal(lock);
             }
@@ -400,12 +433,14 @@ class ThreadSafeFunction : public node::AsyncResource {
         } else {
           has_more = true;
         }
+      } else {
+        CloseHandlesAndMaybeDelete();
       }
     }
 
     if (popped_value) {
       v8::HandleScope scope(env->isolate);
-      CallbackScope cb_scope(this);
+      AsyncResource::CallbackScope cb_scope(&*async_resource);
       napi_value js_callback = nullptr;
       if (!ref.IsEmpty()) {
         v8::Local<v8::Function> js_cb =
@@ -422,17 +457,17 @@ class ThreadSafeFunction : public node::AsyncResource {
   void Finalize() {
     v8::HandleScope scope(env->isolate);
     if (finalize_cb) {
-      CallbackScope cb_scope(this);
+      AsyncResource::CallbackScope cb_scope(&*async_resource);
       env->CallFinalizer(finalize_cb, finalize_data, context);
     }
-    EmptyQueueAndDelete();
+    EmptyQueueAndMaybeDelete();
   }
 
   void CloseHandlesAndMaybeDelete(bool set_closing = false) {
     v8::HandleScope scope(env->isolate);
     if (set_closing) {
       node::Mutex::ScopedLock lock(this->mutex);
-      is_closing = true;
+      state = CLOSING;
       if (max_queue_size > 0) {
         cond->Signal(lock);
       }
@@ -497,19 +532,30 @@ class ThreadSafeFunction : public node::AsyncResource {
   }
 
  private:
+  // Needed because node::AsyncResource::CallbackScope is protected
+  class AsyncResource : public node::AsyncResource {
+   public:
+    using node::AsyncResource::AsyncResource;
+    using node::AsyncResource::CallbackScope;
+  };
+
+  enum State : unsigned char { OPEN, CLOSING, CLOSED };
+
   static const unsigned char kDispatchIdle = 0;
   static const unsigned char kDispatchRunning = 1 << 0;
   static const unsigned char kDispatchPending = 1 << 1;
 
   static const unsigned int kMaxIterationCount = 1000;
 
+  std::optional<AsyncResource> async_resource;
+
   // These are variables protected by the mutex.
   node::Mutex mutex;
   std::unique_ptr<node::ConditionVariable> cond;
   std::queue<void*> queue;
   uv_async_t async;
   size_t thread_count;
-  bool is_closing;
+  State state;
   std::atomic_uchar dispatch_state;
 
   // These are variables set once, upon creation, and then never again, which
diff --git a/test/node-api/test_threadsafe_function_shutdown/binding.cc b/test/node-api/test_threadsafe_function_shutdown/binding.cc
new file mode 100644
index 00000000000000..5b2541cb5dbcbc
--- /dev/null
+++ b/test/node-api/test_threadsafe_function_shutdown/binding.cc
@@ -0,0 +1,80 @@
+#include <node_api.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include <chrono>
+#include <cstdio>
+#include <cstdlib>
+#include <functional>
+#include <string>
+#include <thread>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+template <auto func, typename R, typename... Args>
+inline auto call(const char *name, Args &&...args) -> R {
+  napi_status status;
+  if constexpr (std::is_same_v<R, void>) {
+    status = func(std::forward<Args>(args)...);
+    if (status == napi_ok) {
+      return;
+    }
+  } else {
+    R ret;
+    status = func(std::forward<Args>(args)..., &ret);
+    if (status == napi_ok) {
+      return ret;
+    }
+  }
+  std::fprintf(stderr, "%s: %d\n", name, status);
+  std::abort();
+}
+
+#define NAPI_CALL(ret_type, func, ...)                                        \
+  call<func, ret_type>(#func, ##__VA_ARGS__)
+
+void thread_func(napi_threadsafe_function tsfn) {
+  fprintf(stderr, "thread_func: starting\n");
+  auto status =
+      napi_call_threadsafe_function(tsfn, nullptr, napi_tsfn_blocking);
+  while (status == napi_ok) {
+    std::this_thread::sleep_for(std::chrono::milliseconds(1));
+    status = napi_call_threadsafe_function(tsfn, nullptr, napi_tsfn_blocking);
+  }
+  fprintf(stderr, "thread_func: Got status %d, exiting...\n", status);
+}
+
+void tsfn_callback(napi_env env, napi_value js_cb, void *ctx, void *data) {
+  if (env == nullptr) {
+    fprintf(stderr, "tsfn_callback: env=%p\n", env);
+  }
+}
+
+void tsfn_finalize(napi_env env, void *finalize_data, void *finalize_hint) {
+  fprintf(stderr, "tsfn_finalize: env=%p\n", env);
+}
+
+std::vector<std::thread> threads;
+
+auto run(napi_env env, napi_callback_info info) -> napi_value {
+  auto global = NAPI_CALL(napi_value, napi_get_global, env);
+  auto undefined = NAPI_CALL(napi_value, napi_get_undefined, env);
+  auto n_threads = 32;
+  auto tsfn =
+      NAPI_CALL(napi_threadsafe_function, napi_create_threadsafe_function, env,
+                nullptr, global, undefined, 0, n_threads, nullptr,
+                tsfn_finalize, nullptr, tsfn_callback);
+  for (auto i = 0; i < n_threads; ++i) {
+    threads.emplace_back([tsfn] { thread_func(tsfn); });
+  }
+  NAPI_CALL(void, napi_unref_threadsafe_function, env, tsfn);
+  return NAPI_CALL(napi_value, napi_get_undefined, env);
+}
+
+napi_value init(napi_env env, napi_value exports) {
+  return NAPI_CALL(napi_value, napi_create_function, env, nullptr, 0,
+                   run, nullptr);
+}
+
+NAPI_MODULE(NODE_GYP_MODULE_NAME, init)
diff --git a/test/node-api/test_threadsafe_function_shutdown/binding.gyp b/test/node-api/test_threadsafe_function_shutdown/binding.gyp
new file mode 100644
index 00000000000000..eb08b447a94a86
--- /dev/null
+++ b/test/node-api/test_threadsafe_function_shutdown/binding.gyp
@@ -0,0 +1,11 @@
+{
+  "targets": [
+    {
+      "target_name": "binding",
+      "sources": ["binding.cc"],
+      "cflags_cc": ["--std=c++20"],
+      'cflags!': [ '-fno-exceptions', '-fno-rtti' ],
+      'cflags_cc!': [ '-fno-exceptions', '-fno-rtti' ],
+    }
+  ]
+}
diff --git a/test/node-api/test_threadsafe_function_shutdown/test.js b/test/node-api/test_threadsafe_function_shutdown/test.js
new file mode 100644
index 00000000000000..e1777dd1e6a292
--- /dev/null
+++ b/test/node-api/test_threadsafe_function_shutdown/test.js
@@ -0,0 +1,17 @@
+'use strict';
+
+const common = require('../../common');
+const process = require('process')
+const assert = require('assert');
+const { fork } = require('child_process');
+const binding = require(`./build/${common.buildType}/binding`);
+
+if (process.argv[2] === 'child') {
+  binding();
+  setTimeout(() => {}, 100);
+} else {
+  const child = fork(__filename, ['child']);
+  child.on('close', (code) => {
+    assert.strictEqual(code, 0);
+  });
+}