[WIP] node-api: fix data race and use-after-free in napi_threadsafe_function #55877

Status: Open. Wants to merge 3 commits into main. The diff below shows the changes from all commits.
150 changes: 98 additions & 52 deletions src/node_api.cc
@@ -196,7 +196,7 @@ inline napi_env NewEnv(v8::Local<v8::Context> context,
return result;
}

class ThreadSafeFunction : public node::AsyncResource {
class ThreadSafeFunction {
public:
ThreadSafeFunction(v8::Local<v8::Function> func,
v8::Local<v8::Object> resource,
@@ -208,11 +208,12 @@ class ThreadSafeFunction : public node::AsyncResource {
void* finalize_data_,
napi_finalize finalize_cb_,
napi_threadsafe_function_call_js call_js_cb_)
: AsyncResource(env_->isolate,
resource,
*v8::String::Utf8Value(env_->isolate, name)),
: async_resource(std::in_place,
env_->isolate,
resource,
*v8::String::Utf8Value(env_->isolate, name)),
thread_count(thread_count_),
is_closing(false),
state(OPEN),
dispatch_state(kDispatchIdle),
context(context_),
max_queue_size(max_queue_size_),
@@ -226,76 +227,100 @@ class ThreadSafeFunction : public node::AsyncResource {
env->Ref();
}

~ThreadSafeFunction() override {
node::RemoveEnvironmentCleanupHook(env->isolate, Cleanup, this);
env->Unref();
}
~ThreadSafeFunction() { ReleaseResources(); }

// These methods can be called from any thread.

napi_status Push(void* data, napi_threadsafe_function_call_mode mode) {
node::Mutex::ScopedLock lock(this->mutex);
{
node::Mutex::ScopedLock lock(this->mutex);

while (queue.size() >= max_queue_size && max_queue_size > 0 &&
!is_closing) {
if (mode == napi_tsfn_nonblocking) {
return napi_queue_full;
while (queue.size() >= max_queue_size && max_queue_size > 0 &&
state == OPEN) {
if (mode == napi_tsfn_nonblocking) {
return napi_queue_full;
}
cond->Wait(lock);
}
cond->Wait(lock);
}

if (is_closing) {
if (state == OPEN) {
queue.push(data);
Send();
return napi_ok;
}
if (thread_count == 0) {
return napi_invalid_arg;
} else {
thread_count--;
}
thread_count--;
if (!(state == CLOSED && thread_count == 0)) {
return napi_closing;
}
} else {
queue.push(data);
Send();
return napi_ok;
}
// Make sure to release lock before destroying
delete this;
return napi_closing;
}

napi_status Acquire() {
node::Mutex::ScopedLock lock(this->mutex);

if (is_closing) {
return napi_closing;
}
if (state == OPEN) {
thread_count++;

thread_count++;
return napi_ok;
}

return napi_ok;
return napi_closing;
}

napi_status Release(napi_threadsafe_function_release_mode mode) {
node::Mutex::ScopedLock lock(this->mutex);
{
node::Mutex::ScopedLock lock(this->mutex);

if (thread_count == 0) {
return napi_invalid_arg;
}
if (thread_count == 0) {
return napi_invalid_arg;
}

thread_count--;
thread_count--;

if (thread_count == 0 || mode == napi_tsfn_abort) {
if (!is_closing) {
is_closing = (mode == napi_tsfn_abort);
if (is_closing && max_queue_size > 0) {
cond->Signal(lock);
if (thread_count == 0 || mode == napi_tsfn_abort) {
if (state == OPEN) {
if (mode == napi_tsfn_abort) {
state = CLOSING;
}
if (state == CLOSING && max_queue_size > 0) {
cond->Signal(lock);
}
Send();
}
Send();
}
}

if (!(state == CLOSED && thread_count == 0)) {
return napi_ok;
}
}
// Make sure to release lock before destroying
delete this;
return napi_ok;
}

void EmptyQueueAndDelete() {
for (; !queue.empty(); queue.pop()) {
call_js_cb(nullptr, nullptr, context, queue.front());
void EmptyQueueAndMaybeDelete() {
{
node::Mutex::ScopedLock lock(this->mutex);
for (; !queue.empty(); queue.pop()) {
call_js_cb(nullptr, nullptr, context, queue.front());
}
if (thread_count > 0) {
// At this point this TSFN is effectively done, but we need to keep
// it alive for other threads that still have pointers to it until
// they release them.
// However, we already release all the resources that we can at this point.
queue = {};
ReleaseResources();
return;
}
}
// Make sure to release lock before destroying
delete this;
}

@@ -347,6 +372,16 @@ class ThreadSafeFunction : public node::AsyncResource {
inline void* Context() { return context; }

protected:
void ReleaseResources() {
if (state != CLOSED) {
state = CLOSED;
ref.Reset();
node::RemoveEnvironmentCleanupHook(env->isolate, Cleanup, this);
env->Unref();
async_resource.reset();
}
}

void Dispatch() {
bool has_more = true;

@@ -375,9 +410,7 @@

{
node::Mutex::ScopedLock lock(this->mutex);
if (is_closing) {
CloseHandlesAndMaybeDelete();
} else {
if (state == OPEN) {
size_t size = queue.size();
if (size > 0) {
data = queue.front();
@@ -391,7 +424,7 @@

if (size == 0) {
if (thread_count == 0) {
is_closing = true;
state = CLOSING;
if (max_queue_size > 0) {
cond->Signal(lock);
}
@@ -400,12 +433,14 @@
} else {
has_more = true;
}
} else {
CloseHandlesAndMaybeDelete();
}
}

if (popped_value) {
v8::HandleScope scope(env->isolate);
CallbackScope cb_scope(this);
AsyncResource::CallbackScope cb_scope(&*async_resource);
napi_value js_callback = nullptr;
if (!ref.IsEmpty()) {
v8::Local<v8::Function> js_cb =
@@ -422,17 +457,17 @@
void Finalize() {
v8::HandleScope scope(env->isolate);
if (finalize_cb) {
CallbackScope cb_scope(this);
AsyncResource::CallbackScope cb_scope(&*async_resource);
env->CallFinalizer<false>(finalize_cb, finalize_data, context);
}
EmptyQueueAndDelete();
EmptyQueueAndMaybeDelete();
}

void CloseHandlesAndMaybeDelete(bool set_closing = false) {
v8::HandleScope scope(env->isolate);
if (set_closing) {
node::Mutex::ScopedLock lock(this->mutex);
is_closing = true;
state = CLOSING;
if (max_queue_size > 0) {
cond->Signal(lock);
}
@@ -497,19 +532,30 @@
}

private:
// Needed because node::AsyncResource::CallbackScope is protected
class AsyncResource : public node::AsyncResource {
public:
using node::AsyncResource::AsyncResource;
using node::AsyncResource::CallbackScope;
};

enum State : unsigned char { OPEN, CLOSING, CLOSED };

Review comment (Member):
Please follow these guidelines for the enum names: https://google.github.io/styleguide/cppguide.html#Enumerator_Names
The Node.js guidelines https://github.com/nodejs/node/blob/main/doc/contributing/cpp-style-guide.md are based on the Google C++ Style Guide.
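
For reference, a minimal sketch of what that rename could look like under the kConstantStyle enumerator naming those guidelines recommend; the names kOpen, kClosing, and kClosed are illustrative and not taken from this PR:

// Hypothetical rename only; the underlying type, values, and semantics stay
// exactly as in the PR.
enum State : unsigned char { kOpen, kClosing, kClosed };

// Call sites would then read, for example:
//   if (state == kOpen) { ... }
//   state = kClosing;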


static const unsigned char kDispatchIdle = 0;
static const unsigned char kDispatchRunning = 1 << 0;
static const unsigned char kDispatchPending = 1 << 1;

static const unsigned int kMaxIterationCount = 1000;

std::optional<AsyncResource> async_resource;

// These are variables protected by the mutex.
node::Mutex mutex;
std::unique_ptr<node::ConditionVariable> cond;
std::queue<void*> queue;
uv_async_t async;
size_t thread_count;
bool is_closing;
State state;
std::atomic_uchar dispatch_state;

// These are variables set once, upon creation, and then never again, which
80 changes: 80 additions & 0 deletions test/node-api/test_threadsafe_function_shutdown/binding.cc
@@ -0,0 +1,80 @@
#include <js_native_api.h>
#include <node_api.h>
#include <node_api_types.h>

#include <chrono>
#include <cstdio>
#include <cstdlib>
#include <memory>
#include <mutex>
#include <shared_mutex>
#include <thread>
#include <type_traits>
#include <utility>
#include <vector>

// Invokes a napi_* function, aborting with a diagnostic if it returns
// anything other than napi_ok; for a non-void R, the result out-parameter is
// appended to the argument list automatically.
template <typename R, auto func, typename... Args>
inline auto call(const char *name, Args &&...args) -> R {
napi_status status;
if constexpr (std::is_same_v<R, void>) {
status = func(std::forward<Args>(args)...);
if (status == napi_ok) {
return;
}
} else {
R ret;
status = func(std::forward<Args>(args)..., &ret);
if (status == napi_ok) {
return ret;
}
}
std::fprintf(stderr, "%s: %d\n", name, status);
std::abort();
}

#define NAPI_CALL(ret_type, func, ...) \
call<ret_type, func>(#func, ##__VA_ARGS__)

// Worker thread: repeatedly calls into the threadsafe function (blocking)
// until it reports that the function is closing.
void thread_func(napi_threadsafe_function tsfn) {
fprintf(stderr, "thread_func: starting\n");
auto status =
napi_call_threadsafe_function(tsfn, nullptr, napi_tsfn_blocking);
while (status == napi_ok) {
std::this_thread::sleep_for(std::chrono::milliseconds(1));
status = napi_call_threadsafe_function(tsfn, nullptr, napi_tsfn_blocking);
}
fprintf(stderr, "thread_func: Got status %d, exiting...\n", status);
}

void tsfn_callback(napi_env env, napi_value js_cb, void *ctx, void *data) {
if (env == nullptr) {
fprintf(stderr, "tsfn_callback: env=%p\n", env);
}
}

void tsfn_finalize(napi_env env, void *finalize_data, void *finalize_hint) {
fprintf(stderr, "tsfn_finalize: env=%p\n", env);
}

std::vector<std::jthread> threads;

// Called from JS: creates a threadsafe function with an unlimited queue and 32
// initial threads, starts the worker threads, then unrefs the function so the
// event loop may exit while the workers are still calling it.
auto run(napi_env env, napi_callback_info info) -> napi_value {
auto global = NAPI_CALL(napi_value, napi_get_global, env);
auto undefined = NAPI_CALL(napi_value, napi_get_undefined, env);
auto n_threads = 32;
auto tsfn =
NAPI_CALL(napi_threadsafe_function, napi_create_threadsafe_function, env,
nullptr, global, undefined, 0, n_threads, nullptr,
tsfn_finalize, nullptr, tsfn_callback);
for (auto i = 0; i < n_threads; ++i) {
threads.emplace_back([tsfn] { thread_func(tsfn); });
}
NAPI_CALL(void, napi_unref_threadsafe_function, env, tsfn);
return NAPI_CALL(napi_value, napi_get_undefined, env);
}

napi_value init(napi_env env, napi_value exports) {
return NAPI_CALL(napi_value, napi_create_function, env, nullptr, 0,
run, nullptr);
}

NAPI_MODULE(NODE_GYP_MODULE_NAME, init)
11 changes: 11 additions & 0 deletions test/node-api/test_threadsafe_function_shutdown/binding.gyp
@@ -0,0 +1,11 @@
{
"targets": [
{
"target_name": "binding",
"sources": ["binding.cc"],
"cflags_cc": ["--std=c++20"],
'cflags!': [ '-fno-exceptions', '-fno-rtti' ],
'cflags_cc!': [ '-fno-exceptions', '-fno-rtti' ],
}
]
}
17 changes: 17 additions & 0 deletions test/node-api/test_threadsafe_function_shutdown/test.js
@@ -0,0 +1,17 @@
'use strict';

const common = require('../../common');
const process = require('process');

const assert = require('assert');
const { fork } = require('child_process');
const binding = require(`./build/${common.buildType}/binding`);

// The parent forks this file as a child; the child starts the native worker
// threads and exits shortly afterwards, exercising teardown while the threads
// are still calling the threadsafe function. The parent only checks that the
// child exits cleanly.
if (process.argv[2] === 'child') {
binding();
setTimeout(() => {}, 100);
} else {
const child = fork(__filename, ['child']);
child.on('close', (code) => {
assert.strictEqual(code, 0);
});
}