diff --git a/definitions/net/client.luau b/definitions/net/client.luau index f2a80d28d..f9c74589d 100644 --- a/definitions/net/client.luau +++ b/definitions/net/client.luau @@ -18,4 +18,21 @@ function client.request(url: string, metadata: Metadata?): Response error("not implemented") end +export type WebSocketOptions = { + headers: { [string]: string }?, + onopen: (() -> ())?, + onmessage: ((message: string | buffer) -> ())?, + onclose: ((code: number, reason: string) -> ())?, + onerror: ((error: string) -> ())?, +} + +export type WebSocket = { + send: (self: WebSocket, data: string | buffer) -> (), + close: (self: WebSocket) -> (), +} + +function client.websocket(url: string, options: WebSocketOptions?): WebSocket + error("not implemented") +end + return client diff --git a/definitions/net/init.luau b/definitions/net/init.luau index e41ff84e1..f6d32ed27 100644 --- a/definitions/net/init.luau +++ b/definitions/net/init.luau @@ -5,11 +5,15 @@ local net = {} export type Metadata = client.Metadata export type Response = client.Response +export type WebSocketOptions = client.WebSocketOptions +export type WebSocket = client.WebSocket export type ReceivedRequest = server.ReceivedRequest export type ServerResponse = server.ServerResponse export type Handler = server.Handler export type Configuration = server.Configuration export type Server = server.Server +export type ServerWebSocket = server.ServerWebSocket +export type WebSocketHandlers = server.WebSocketHandlers net.client = client net.server = server diff --git a/definitions/net/server.luau b/definitions/net/server.luau index af69f65f4..51b33a600 100644 --- a/definitions/net/server.luau +++ b/definitions/net/server.luau @@ -15,20 +15,34 @@ export type ServerResponse = string | { headers: { [string]: string }?, } -export type Handler = (request: ReceivedRequest) -> ServerResponse +export type ServerWebSocket = { + send: (self: ServerWebSocket, data: string | buffer) -> number, + close: (self: ServerWebSocket, code: number?, message: string?) -> (), +} -export type Configuration = { - hostname: string?, - port: number?, - reuseport: boolean?, - tls: { certfilename: string, keyfilename: string, passphrase: string?, cafilename: string? }?, - handler: Handler, +export type WebSocketHandlers = { + open: ((ws: ServerWebSocket) -> ())?, + message: ((ws: ServerWebSocket, message: string | buffer) -> ())?, + close: ((ws: ServerWebSocket, code: number, message: string) -> ())?, + drain: ((ws: ServerWebSocket) -> ())?, } export type Server = { hostname: string, port: number, close: () -> (), + upgrade: (self: Server, req: ReceivedRequest) -> boolean, +} + +export type Handler = (request: ReceivedRequest, server: Server) -> ServerResponse? + +export type Configuration = { + hostname: string?, + port: number?, + reuseport: boolean?, + tls: { certfilename: string, keyfilename: string, passphrase: string?, cafilename: string? }?, + handler: Handler?, + websocket: WebSocketHandlers?, } function server.serve(config: Handler | Configuration): Server diff --git a/examples/serve_websocket.luau b/examples/serve_websocket.luau new file mode 100644 index 000000000..543ccaa53 --- /dev/null +++ b/examples/serve_websocket.luau @@ -0,0 +1,43 @@ +local server = require("@lute/net/server") + +print("starting server on ws://127.0.0.1:3000") + +local instance = server.serve({ + hostname = "127.0.0.1", + port = 3000, + handler = function(req: server.ReceivedRequest, instance: server.Server): server.ServerResponse? + if instance:upgrade(req) then + return nil + end + + return { + status = 200, + body = "Hello over HTTP. Try websocket upgrade.", + } + end, + websocket = { + open = function(ws: server.ServerWebSocket) + print("ws open") + ws:send("welcome") + end, + message = function(ws, message) + if type(message) == "buffer" then + print("ws binary message len:", buffer.len(message)) + else + print("ws message:", message) + if message == "close" then + ws:close() + end + end + ws:send(message) + end, + close = function(_ws, code, message) + print("ws close:", code, message) + end, + drain = function(_ws) + print("ws drain") + end, + }, +}) + +print(`listening on {instance.hostname}:{instance.port}`) diff --git a/examples/websocket_echo.luau b/examples/websocket_echo.luau new file mode 100644 index 000000000..0fe4c25c5 --- /dev/null +++ b/examples/websocket_echo.luau @@ -0,0 +1,37 @@ +local client = require("@lute/net/client") + +print("connecting to echo server...") + +local _ws: client.WebSocket? +_ws = client.websocket("wss://echo.websocket.org", { + onopen = function() + print("websocket opened") + if _ws then + _ws:send("hello from lute over websockets") + _ws:send(buffer.create(4)) + _ws:send("close") + end + end, + onmessage = function(message) + if type(message) == "buffer" then + print("received binary message len:", buffer.len(message)) + return + end + + print("received:", message) + if message == "close" then + print("closing...") + if _ws then + _ws:close() + end + end + end, + onclose = function() + print("websocket closed") + end, + onerror = function(err) + print("websocket error:", err) + end, +}) + +print("connected, waiting for messages...") diff --git a/lute/net/CMakeLists.txt b/lute/net/CMakeLists.txt index 06d40a0d4..a6037e5e2 100644 --- a/lute/net/CMakeLists.txt +++ b/lute/net/CMakeLists.txt @@ -4,6 +4,7 @@ target_sources(Lute.Net PRIVATE include/lute/net.h src/client.cpp + src/client_websocket.cpp src/net.cpp src/server.cpp ) diff --git a/lute/net/src/client.cpp b/lute/net/src/client.cpp index ad96c9c26..7c3551bcb 100644 --- a/lute/net/src/client.cpp +++ b/lute/net/src/client.cpp @@ -2,6 +2,7 @@ #include "lute/common.h" #include "lute/runtime.h" +#include "lute/userdatas.h" #include "Luau/DenseHash.h" @@ -10,14 +11,22 @@ #include "curl/curl.h" +#include #include #include #include #include -namespace +namespace net::client { +struct WebSocketHandle; +int websocket(lua_State* L); +int ws_send(lua_State* L); +int ws_close(lua_State* L); +} +namespace +{ struct CurlHolder { CurlHolder() @@ -37,6 +46,48 @@ static CurlHolder& globalCurlInit() return holder; } +static void initializeNetClient(lua_State* L) +{ + luaL_newmetatable(L, "WebSocketHandle"); + + lua_pushcfunction( + L, + [](lua_State* L) + { + const char* index = luaL_checkstring(L, -1); + + if (strcmp(index, "send") == 0) + { + lua_pushcfunction(L, net::client::ws_send, "WebSocketHandle.send"); + return 1; + } + + if (strcmp(index, "close") == 0) + { + lua_pushcfunction(L, net::client::ws_close, "WebSocketHandle.close"); + return 1; + } + + return 0; + }, + "WebSocketHandle.__index" + ); + lua_setfield(L, -2, "__index"); + + lua_pushstring(L, "WebSocketHandle"); + lua_setfield(L, -2, "__type"); + + lua_setuserdatadtor( + L, + kWebSocketHandleTag, + [](lua_State*, void* ud) + { + std::destroy_at(static_cast*>(ud)); + } + ); + + lua_setuserdatametatable(L, kWebSocketHandleTag); +} } // namespace namespace net::client @@ -197,7 +248,7 @@ int request(lua_State* L) // TODO: add cancellations token->runtime->runInWorkQueue( - [=] + [url = std::move(url), method = std::move(method), body = std::move(body), headers = std::move(headers), token] { CurlResponse resp = requestData(url, method, body, headers); if (!resp.error.empty()) @@ -248,12 +299,14 @@ const char* const NetClient::properties[] = {nullptr}; const luaL_Reg NetClient::lib[] = { {"request", net::client::request}, + {"websocket", net::client::websocket}, {nullptr, nullptr}, }; int NetClient::pushLibrary(lua_State* L) { globalCurlInit(); + initializeNetClient(L); lua_createtable(L, 0, std::size(NetClient::lib)); diff --git a/lute/net/src/client_websocket.cpp b/lute/net/src/client_websocket.cpp new file mode 100644 index 000000000..6846ca6ec --- /dev/null +++ b/lute/net/src/client_websocket.cpp @@ -0,0 +1,792 @@ +#include "lute/runtime.h" +#include "lute/userdatas.h" + +#include "Luau/VecDeque.h" + +#include "lua.h" +#include "lualib.h" + +#include "curl/curl.h" +#include "curl/websockets.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace net::client +{ +struct WebSocketConnection; +struct WebSocketHandle; + +struct WebSocketPayload +{ + const char* data = nullptr; + size_t length = 0; + bool binary = false; +}; + +struct PendingSend +{ + std::string payload; + bool binary = false; + size_t offset = 0; +}; + +struct WebSocketPollState +{ + uv_poll_t handle{}; + std::shared_ptr owner; +}; + +static WebSocketPayload extractWebSocketPayload(lua_State* L, int index) +{ + if (lua_isstring(L, index)) + { + size_t length = 0; + const char* data = lua_tolstring(L, index, &length); + return {data, length, false}; + } + + if (lua_isbuffer(L, index)) + { + size_t length = 0; + void* data = lua_tobuffer(L, index, &length); + return {static_cast(data), length, true}; + } + + luaL_typeerrorL(L, index, "string or buffer"); + return {}; +} + +static void clearWebSocketPollState(WebSocketPollState*& pollState, int& activePollEvents) +{ + if (!pollState) + return; + + WebSocketPollState* state = pollState; + pollState = nullptr; + activePollEvents = 0; + + uv_poll_stop(&state->handle); + uv_close( + reinterpret_cast(&state->handle), + [](uv_handle_t* handle) + { + delete static_cast(handle->data); + } + ); +} + +static void pushWebSocketMessageToLua(lua_State* L, const std::string& message, bool binary) +{ + if (binary) + { + void* buffer = lua_newbuffer(L, message.size()); + if (!message.empty()) + memcpy(buffer, message.data(), message.size()); + return; + } + + lua_pushlstring(L, message.data(), message.size()); +} + +static std::pair parseClosePayload(const std::string& payload) +{ + int closeCode = 1000; + std::string closeReason; + + if (payload.size() >= 2) + { + const unsigned char* bytes = reinterpret_cast(payload.data()); + closeCode = int((bytes[0] << 8) | bytes[1]); + if (payload.size() > 2) + closeReason.assign(payload.data() + 2, payload.size() - 2); + } + + return {closeCode, std::move(closeReason)}; +} + +struct WebSocketConnection : std::enable_shared_from_this +{ + CURL* curl = nullptr; + curl_slist* headerList = nullptr; + WebSocketPollState* pollState = nullptr; + std::mutex curlMutex; + int activePollEvents = 0; + std::vector recvBuffer = std::vector(16 * 1024); + std::string currentMessage; + bool currentBinary = false; + bool hasCurrentMessage = false; + std::string closePayload; + Luau::VecDeque pendingSends; + std::atomic isClosed{false}; + std::weak_ptr owner; + std::shared_ptr keepAliveHandle; + + void setOwner(const std::shared_ptr& handle) + { + owner = handle; + keepAliveHandle = handle; + } + + void releaseKeepAliveHandle() + { + keepAliveHandle.reset(); + } + + bool closed() const + { + return isClosed.load(); + } + + bool startPolling(Runtime* runtime, curl_socket_t socket) + { + if (!runtime || socket == CURL_SOCKET_BAD) + return false; + + auto* state = new WebSocketPollState(); + state->owner = shared_from_this(); + state->handle.data = state; + + int initResult = uv_poll_init_socket(runtime->getEventLoop(), &state->handle, socket); + if (initResult != 0) + { + delete state; + return false; + } + + pollState = state; + + if (!updatePollingInterest()) + { + closeTransport(false); + return false; + } + + return true; + } + + bool applyPollingEvents(int events) + { + if (!pollState) + return false; + + if (activePollEvents == events) + return true; + + int startResult = uv_poll_start( + &pollState->handle, + events, + [](uv_poll_t* handle, int status, int events) + { + auto* state = static_cast(handle->data); + if (!state || !state->owner) + return; + + state->owner->handlePollEvent(status, events); + } + ); + + if (startResult != 0) + return false; + + activePollEvents = events; + return true; + } + + bool updatePollingInterest() + { + int events = UV_READABLE; + if (!pendingSends.empty()) + events |= UV_WRITABLE; + + return applyPollingEvents(events); + } + + void handlePollEvent(int status, int events) + { + if (status < 0) + { + closeWithError(uv_strerror(status)); + return; + } + + if (events & UV_READABLE) + processIncoming(); + + if (!closed() && (events & UV_WRITABLE)) + flushOutgoing(); + } + + void enqueueSend(std::string payload, bool binary) + { + if (closed()) + return; + + pendingSends.push_back({std::move(payload), binary, 0}); + flushOutgoing(); + } + + bool closeTransport(bool sendCloseFrame) + { + bool expected = false; + if (!isClosed.compare_exchange_strong(expected, true)) + return false; + + clearWebSocketPollState(pollState, activePollEvents); + pendingSends.clear(); + currentMessage.clear(); + hasCurrentMessage = false; + closePayload.clear(); + + { + std::lock_guard lock(curlMutex); + if (curl) + { + if (sendCloseFrame) + { + size_t sent = 0; + (void)curl_ws_send(curl, "", 0, &sent, 0, CURLWS_CLOSE); + } + + curl_easy_cleanup(curl); + curl = nullptr; + } + + if (headerList) + { + curl_slist_free_all(headerList); + headerList = nullptr; + } + } + + return true; + } + + void closeWithCode(int closeCode = 1000, std::string closeReason = "", bool sendCloseFrame = true) + { + if (!closeTransport(sendCloseFrame)) + return; + + notifyClose(closeCode, std::move(closeReason)); + } + + void closeWithError(std::string error) + { + if (!closeTransport(false)) + return; + + notifyError(std::move(error)); + notifyClose(1006, ""); + } + + void processIncoming() + { + while (!closed()) + { + size_t receivedLength = 0; + const curl_ws_frame* meta = nullptr; + CURLcode result = CURLE_OK; + + { + std::lock_guard lock(curlMutex); + if (!curl) + return; + + result = curl_ws_recv(curl, recvBuffer.data(), recvBuffer.size(), &receivedLength, &meta); + } + + if (closed()) + return; + + if (result == CURLE_AGAIN) + return; + + if (result != CURLE_OK) + { + closeWithError(curl_easy_strerror(result)); + return; + } + + if (!meta) + continue; + + if (meta->flags & CURLWS_CLOSE) + { + closePayload.append(recvBuffer.data(), receivedLength); + + if (meta->bytesleft != 0) + continue; + + auto [closeCode, closeReason] = parseClosePayload(closePayload); + closeWithCode(closeCode, std::move(closeReason)); + return; + } + + if (meta->flags & CURLWS_PING) + { + size_t sent = 0; + std::lock_guard lock(curlMutex); + if (curl) + (void)curl_ws_send(curl, recvBuffer.data(), receivedLength, &sent, 0, CURLWS_PONG); + continue; + } + + if (meta->flags & (CURLWS_TEXT | CURLWS_BINARY | CURLWS_CONT)) + { + if (meta->flags & (CURLWS_TEXT | CURLWS_BINARY)) + { + bool binary = (meta->flags & CURLWS_BINARY) != 0; + if (!hasCurrentMessage) + { + currentMessage.clear(); + currentBinary = binary; + hasCurrentMessage = true; + } + else if (currentBinary != binary) + { + closeWithError("websocket received mixed message types"); + return; + } + } + else if (!hasCurrentMessage) + { + currentMessage.clear(); + currentBinary = false; + hasCurrentMessage = true; + } + + currentMessage.append(recvBuffer.data(), receivedLength); + + if (meta->bytesleft == 0 && !(meta->flags & CURLWS_CONT)) + { + std::string message = std::move(currentMessage); + bool binary = currentBinary; + hasCurrentMessage = false; + + notifyMessage(std::move(message), binary); + } + } + } + } + + void flushOutgoing() + { + while (!pendingSends.empty() && !closed()) + { + PendingSend& pending = pendingSends.front(); + size_t sent = 0; + CURLcode result = CURLE_OK; + + { + std::lock_guard lock(curlMutex); + if (!curl) + { + pendingSends.clear(); + return; + } + + size_t remaining = pending.payload.size() - pending.offset; + const char* data = + remaining > 0 ? pending.payload.data() + pending.offset : pending.payload.data(); + + result = curl_ws_send(curl, data, remaining, &sent, 0, pending.binary ? CURLWS_BINARY : CURLWS_TEXT); + } + + if (result == CURLE_AGAIN) + break; + + if (result != CURLE_OK) + { + closeWithError("websocket send failed: " + std::string(curl_easy_strerror(result))); + return; + } + + pending.offset += sent; + + if (pending.payload.empty() || pending.offset >= pending.payload.size()) + { + pendingSends.pop_front(); + continue; + } + + if (sent == 0) + break; + } + + if (!closed() && !updatePollingInterest()) + closeWithError("failed to update websocket polling"); + } + + ~WebSocketConnection() + { + closeTransport(false); + } + + void notifyMessage(std::string message, bool binary); + void notifyClose(int closeCode, std::string closeReason); + void notifyError(std::string error); +}; + +struct WebSocketHandle : std::enable_shared_from_this +{ + Runtime* runtime = nullptr; + std::shared_ptr onOpenRef; + std::shared_ptr onMessageRef; + std::shared_ptr onCloseRef; + std::shared_ptr onErrorRef; + std::weak_ptr connection; + std::atomic hasScheduledClose{false}; + bool isActive = false; + + std::shared_ptr lockConnection() const + { + return connection.lock(); + } + + void attachConnection(const std::shared_ptr& newConnection) + { + connection = newConnection; + newConnection->setOwner(shared_from_this()); + } + + void activate() + { + isActive = true; + } + + bool closed() const + { + auto lockedConnection = lockConnection(); + return !lockedConnection || lockedConnection->closed(); + } + + void scheduleCallback(const std::shared_ptr& callback, std::function argPusher) + { + if (!isActive || !callback || !runtime) + return; + + runtime->scheduleLuauCallback(callback, std::move(argPusher)); + } + + void scheduleCloseCallback(int closeCode = 1000, std::string closeReason = "") + { + bool expected = false; + if (!isActive || !hasScheduledClose.compare_exchange_strong(expected, true)) + return; + + scheduleCallback( + onCloseRef, + [closeCode, closeReason = std::move(closeReason)](lua_State* L) + { + lua_pushinteger(L, closeCode); + lua_pushlstring(L, closeReason.data(), closeReason.size()); + return 2; + } + ); + } + + void handleMessage(std::string message, bool binary) + { + scheduleCallback( + onMessageRef, + [message = std::move(message), binary](lua_State* L) + { + pushWebSocketMessageToLua(L, message, binary); + return 1; + } + ); + } + + void handleClose(int closeCode, std::string closeReason) + { + scheduleCloseCallback(closeCode, std::move(closeReason)); + releaseConnectionKeepAlive(); + } + + void releaseConnectionKeepAlive() + { + if (auto lockedConnection = lockConnection()) + lockedConnection->releaseKeepAliveHandle(); + } + + void handleError(std::string error) + { + scheduleErrorCallback(std::move(error)); + scheduleCloseCallback(1006); + releaseConnectionKeepAlive(); + } + + void scheduleErrorCallback(std::string error) + { + scheduleCallback( + onErrorRef, + [error = std::move(error)](lua_State* L) + { + lua_pushlstring(L, error.data(), error.size()); + return 1; + } + ); + } + + void notifyOpen() + { + scheduleCallback( + onOpenRef, + [](lua_State*) + { + return 0; + } + ); + } + + void send(std::string payload, bool binary) + { + if (auto lockedConnection = lockConnection()) + lockedConnection->enqueueSend(std::move(payload), binary); + } + + void close() + { + closeWithCode(); + } + + void closeWithCode(int closeCode = 1000, std::string closeReason = "", bool sendCloseFrame = true) + { + if (!isActive) + { + if (auto lockedConnection = lockConnection()) + { + lockedConnection->closeTransport(false); + lockedConnection->releaseKeepAliveHandle(); + } + return; + } + + scheduleCloseCallback(closeCode, std::move(closeReason)); + + if (auto lockedConnection = lockConnection()) + { + lockedConnection->closeTransport(sendCloseFrame); + lockedConnection->releaseKeepAliveHandle(); + } + + releaseConnectionKeepAlive(); + } + + ~WebSocketHandle() + { + close(); + + onOpenRef.reset(); + onMessageRef.reset(); + onCloseRef.reset(); + onErrorRef.reset(); + } +}; + +void WebSocketConnection::notifyMessage(std::string message, bool binary) +{ + if (auto handle = owner.lock()) + handle->handleMessage(std::move(message), binary); +} + +void WebSocketConnection::notifyClose(int closeCode, std::string closeReason) +{ + if (auto handle = owner.lock()) + handle->handleClose(closeCode, std::move(closeReason)); +} + +void WebSocketConnection::notifyError(std::string error) +{ + if (auto handle = owner.lock()) + handle->handleError(std::move(error)); +} + +static std::shared_ptr* getWebSocketHandle(lua_State* L, int index) +{ + return static_cast*>(lua_touserdatatagged(L, index, kWebSocketHandleTag)); +} + +int ws_send(lua_State* L) +{ + if (lua_gettop(L) != 2) + luaL_errorL(L, "websocket send expects exactly 1 payload argument"); + + luaL_checktype(L, 1, LUA_TUSERDATA); + auto* handleStorage = getWebSocketHandle(L, 1); + if (!handleStorage || !(*handleStorage) || (*handleStorage)->closed()) + luaL_errorL(L, "Invalid or closed websocket"); + + WebSocketPayload payload = extractWebSocketPayload(L, 2); + (*handleStorage)->send(std::string(payload.data, payload.length), payload.binary); + return 0; +} + +int ws_close(lua_State* L) +{ + luaL_checktype(L, 1, LUA_TUSERDATA); + auto* handleStorage = getWebSocketHandle(L, 1); + if (!handleStorage || !(*handleStorage)) + luaL_errorL(L, "Invalid websocket"); + + (*handleStorage)->close(); + return 0; +} + +int websocket(lua_State* L) +{ + std::string url = luaL_checkstring(L, 1); + std::vector> headers; + std::shared_ptr onOpenRef; + std::shared_ptr onMessageRef; + std::shared_ptr onCloseRef; + std::shared_ptr onErrorRef; + + if (lua_istable(L, 2)) + { + lua_getfield(L, 2, "headers"); + if (lua_istable(L, -1)) + { + lua_pushnil(L); + while (lua_next(L, -2)) + { + if (lua_isstring(L, -2) && lua_isstring(L, -1)) + headers.emplace_back(lua_tostring(L, -2), lua_tostring(L, -1)); + + lua_pop(L, 1); + } + } + lua_pop(L, 1); + + lua_getfield(L, 2, "onopen"); + if (lua_isfunction(L, -1)) + onOpenRef = std::make_shared(L, -1); + lua_pop(L, 1); + + lua_getfield(L, 2, "onmessage"); + if (lua_isfunction(L, -1)) + onMessageRef = std::make_shared(L, -1); + lua_pop(L, 1); + + lua_getfield(L, 2, "onclose"); + if (lua_isfunction(L, -1)) + onCloseRef = std::make_shared(L, -1); + lua_pop(L, 1); + + lua_getfield(L, 2, "onerror"); + if (lua_isfunction(L, -1)) + onErrorRef = std::make_shared(L, -1); + lua_pop(L, 1); + } + + auto token = getResumeToken(L); + Runtime* runtime = token->runtime; + + runtime->runInWorkQueue( + [token, + runtime, + url = std::move(url), + headers = std::move(headers), + onOpenRef = std::move(onOpenRef), + onMessageRef = std::move(onMessageRef), + onCloseRef = std::move(onCloseRef), + onErrorRef = std::move(onErrorRef)]() mutable + { + CURL* curl = curl_easy_init(); + if (!curl) + { + token->fail("failed to initialize websocket"); + return; + } + + curl_slist* headerList = nullptr; + + curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); + curl_easy_setopt(curl, CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA); + curl_easy_setopt(curl, CURLOPT_CONNECT_ONLY, 2L); + + for (const auto& headerPair : headers) + { + std::string headerString = headerPair.first + ": " + headerPair.second; + headerList = curl_slist_append(headerList, headerString.c_str()); + } + + if (headerList) + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headerList); + + CURLcode result = curl_easy_perform(curl); + if (result != CURLE_OK) + { + std::string error = curl_easy_strerror(result); + if (headerList) + curl_slist_free_all(headerList); + curl_easy_cleanup(curl); + token->fail("websocket connect failed: " + error); + return; + } + + curl_socket_t socket = CURL_SOCKET_BAD; + if (curl_easy_getinfo(curl, CURLINFO_ACTIVESOCKET, &socket) != CURLE_OK || socket == CURL_SOCKET_BAD) + { + if (headerList) + curl_slist_free_all(headerList); + curl_easy_cleanup(curl); + token->fail("failed to get websocket socket"); + return; + } + + token->complete( + [runtime, + curl, + socket, + headerList, + onOpenRef = std::move(onOpenRef), + onMessageRef = std::move(onMessageRef), + onCloseRef = std::move(onCloseRef), + onErrorRef = std::move(onErrorRef)](lua_State* L) mutable + { + auto connection = std::make_shared(); + connection->curl = curl; + connection->headerList = headerList; + + auto handle = std::make_shared(); + handle->runtime = runtime; + handle->onOpenRef = std::move(onOpenRef); + handle->onMessageRef = std::move(onMessageRef); + handle->onCloseRef = std::move(onCloseRef); + handle->onErrorRef = std::move(onErrorRef); + handle->attachConnection(connection); + + if (!connection->startPolling(runtime, socket)) + { + connection->closeTransport(false); + connection->releaseKeepAliveHandle(); + luaL_errorL(L, "failed to initialize websocket polling"); + } + + handle->activate(); + + auto* storage = + new (static_cast*>(lua_newuserdatataggedwithmetatable( + L, + sizeof(std::shared_ptr), + kWebSocketHandleTag + ))) std::shared_ptr(handle); + (void)storage; + + handle->notifyOpen(); + return 1; + } + ); + } + ); + + return lua_yield(L, 0); +} +} // namespace net::client diff --git a/lute/net/src/server.cpp b/lute/net/src/server.cpp index 8cee07e32..cc0696f38 100644 --- a/lute/net/src/server.cpp +++ b/lute/net/src/server.cpp @@ -2,6 +2,7 @@ #include "lute/common.h" #include "lute/runtime.h" +#include "lute/userdatas.h" #include "Luau/DenseHash.h" #include "Luau/Variant.h" @@ -9,11 +10,12 @@ #include "lua.h" #include "lualib.h" -#include "uv.h" - #include +#include #include +#include #include +#include #include #include #include @@ -28,10 +30,19 @@ namespace net::server using uWSApp = Luau::Variant, std::unique_ptr>; +struct WebSocketPayload +{ + const char* data = nullptr; + size_t length = 0; + bool binary = false; +}; + static const int kEmptyServerKey = 0; static Luau::DenseHashMap serverInstances(kEmptyServerKey); static Luau::DenseHashMap> serverStates(kEmptyServerKey); static int nextServerId = 1; +static int kRequestUpgradeKey = 0; +static constexpr unsigned int kWebSocketMaxPayloadLength = 16 * 1024 * 1024; struct ServerLoopState { @@ -40,11 +51,89 @@ struct ServerLoopState bool running = true; std::function loopFunction; std::shared_ptr handlerRef; + std::shared_ptr serverRef; + std::shared_ptr wsOpenRef; + std::shared_ptr wsMessageRef; + std::shared_ptr wsCloseRef; + std::shared_ptr wsDrainRef; + bool hasWebSocket = false; std::string hostname; int port; bool reusePort = false; }; +template +struct PerSocketData; + +struct ServerWebSocketHandle +{ + void* wsPtr = nullptr; + std::atomic closed{false}; + std::shared_ptr userdataRef; + int (*sendFn)(void* wsPtr, std::string_view data, bool binary) = nullptr; + void (*closeFn)(void* wsPtr, uint16_t code, std::string_view message) = nullptr; +}; + +static WebSocketPayload extractWebSocketPayload(lua_State* L, int index) +{ + if (lua_isstring(L, index)) + { + size_t length = 0; + const char* data = lua_tolstring(L, index, &length); + return {data, length, false}; + } + + if (lua_isbuffer(L, index)) + { + size_t length = 0; + void* data = lua_tobuffer(L, index, &length); + return {static_cast(data), length, true}; + } + + luaL_typeerrorL(L, index, "string or buffer"); +} + +template +struct PerSocketData +{ + std::shared_ptr handle; +}; + +struct RequestRouteData +{ + std::string method; + std::string path; + std::string query; +}; + +template +static RequestRouteData extractRequestRouteData(ReqT* req) +{ + RequestRouteData route; + route.method = std::string(req->getMethod()); + std::transform( + route.method.begin(), + route.method.end(), + route.method.begin(), + [](unsigned char ch) + { + return char(std::toupper(ch)); + } + ); + + std::string_view url = req->getFullUrl(); + size_t queryPos = url.find('?'); + if (queryPos == std::string::npos) + { + route.path.assign(url.data(), url.size()); + return route; + } + + route.path.assign(url.data(), queryPos); + route.query.assign(url.data() + queryPos, url.size() - queryPos); + return route; +} + static void parseQuery(const std::string_view& query, lua_State* L) { lua_createtable(L, 0, 0); @@ -74,7 +163,8 @@ static void parseQuery(const std::string_view& query, lua_State* L) } } -static void parseHeaders(auto* req, lua_State* L) +template +static void parseHeaders(ReqT* req, lua_State* L) { lua_createtable(L, 0, 0); for (const auto& header : *req) @@ -172,33 +262,229 @@ static void handleResponse(auto* res, lua_State* L, int responseIndex) res->end(body); } -static void processRequest( +static void resumeWith( std::shared_ptr state, - auto* res, - auto* req, - const std::string& method, - const std::string_view& path, - const std::string_view& query, - const std::string_view& body + const std::shared_ptr& callback, + std::function argPusher ) { - lua_State* L = lua_newthread(state->runtime->GL); + if (!callback) + return; + + state->runtime->scheduleLuauCallback(callback, std::move(argPusher)); +} + +template +static bool performWebSocketUpgrade( + uWS::HttpResponse* res, + uWS::HttpRequest* req, + us_socket_context_t* context +) +{ + std::string_view key = req->getHeader("sec-websocket-key"); + std::string_view protocol = req->getHeader("sec-websocket-protocol"); + std::string_view extensions = req->getHeader("sec-websocket-extensions"); + + if (key.empty()) + return false; + + PerSocketData userData; + res->template upgrade>(std::move(userData), key, protocol, extensions, context); + return true; +} + +static int server_upgrade_noop(lua_State* L) +{ + lua_pushboolean(L, 0); + return 1; +} + +static int server_upgrade(lua_State* L) +{ + luaL_checktype(L, 1, LUA_TTABLE); + luaL_checktype(L, 2, LUA_TTABLE); + + if (!lua_getmetatable(L, 2)) + { + lua_pushboolean(L, 0); + return 1; + } + + lua_pushlightuserdata(L, &kRequestUpgradeKey); + lua_rawget(L, -2); + if (!lua_isfunction(L, -1)) + { + lua_pop(L, 2); + lua_pushboolean(L, 0); + return 1; + } + + lua_call(L, 0, 1); + lua_remove(L, -2); + return 1; +} + +template +static int server_upgrade_do(lua_State* L) +{ + auto* res = static_cast*>(lua_touserdata(L, lua_upvalueindex(1))); + auto* req = static_cast(lua_touserdata(L, lua_upvalueindex(2))); + auto* context = static_cast(lua_touserdata(L, lua_upvalueindex(3))); + auto* upgradedPtr = static_cast(lua_touserdata(L, lua_upvalueindex(4))); + + if (!res || !req || !context || !upgradedPtr) + { + lua_pushboolean(L, 0); + return 1; + } + + bool upgraded = performWebSocketUpgrade(res, req, context); + *upgradedPtr = upgraded; + lua_pushboolean(L, upgraded); + return 1; +} + +template +static int wsSendImpl(void* wsPtr, std::string_view data, bool binary) +{ + auto* ws = static_cast>*>(wsPtr); + auto status = ws->send(data, binary ? uWS::OpCode::BINARY : uWS::OpCode::TEXT); + + if (status == decltype(status)::BACKPRESSURE) + return -1; + + if (status == decltype(status)::DROPPED) + return 0; + + return int(data.size() > 0 ? data.size() : 1); +} + +template +static void wsCloseImpl(void* wsPtr, uint16_t code, std::string_view message) +{ + auto* ws = static_cast>*>(wsPtr); + ws->end(int(code), message); +} + +static int server_ws_send(lua_State* L) +{ + if (lua_gettop(L) != 2) + luaL_errorL(L, "websocket send expects exactly 1 payload argument"); + + luaL_checktype(L, 1, LUA_TUSERDATA); + auto* handlePtr = + static_cast*>(lua_touserdatatagged(L, 1, kServerWebSocketHandleTag)); + if (!handlePtr || !(*handlePtr) || (*handlePtr)->closed.load()) + { + lua_pushinteger(L, 0); + return 1; + } + + WebSocketPayload payload = extractWebSocketPayload(L, 2); + + int result = 0; + if (!(*handlePtr)->closed.load() && (*handlePtr)->wsPtr && (*handlePtr)->sendFn) + result = (*handlePtr)->sendFn((*handlePtr)->wsPtr, std::string_view(payload.data, payload.length), payload.binary); + + lua_pushinteger(L, result); + return 1; +} + +static int server_ws_close(lua_State* L) +{ + luaL_checktype(L, 1, LUA_TUSERDATA); + auto* handlePtr = + static_cast*>(lua_touserdatatagged(L, 1, kServerWebSocketHandleTag)); + if (!handlePtr || !(*handlePtr) || (*handlePtr)->closed.load() || !(*handlePtr)->wsPtr) + return 0; + + int code = 1000; + if (!lua_isnoneornil(L, 2)) + { + code = int(luaL_checkinteger(L, 2)); + if (code < 0 || code > std::numeric_limits::max()) + luaL_errorL(L, "invalid websocket close code %d", code); + } + + std::string message; + if (!lua_isnoneornil(L, 3)) + { + size_t messageLength = 0; + const char* messageData = luaL_checklstring(L, 3, &messageLength); + message.assign(messageData, messageLength); + } + + if (!(*handlePtr)->closed.load() && (*handlePtr)->wsPtr && (*handlePtr)->closeFn) + { + (*handlePtr)->closeFn((*handlePtr)->wsPtr, uint16_t(code), message); + } + + return 0; +} + +static void pushServerWebSocket( + lua_State* L, + const std::shared_ptr& handle, + const std::shared_ptr& retainedRef = nullptr +) +{ + if (retainedRef) + { + retainedRef->push(L); + return; + } + + if (handle->userdataRef) + { + handle->userdataRef->push(L); + return; + } + + auto* storage = + new (static_cast*>(lua_newuserdatataggedwithmetatable( + L, + sizeof(std::shared_ptr), + kServerWebSocketHandleTag + ))) std::shared_ptr(handle); + (void)storage; + handle->userdataRef = std::make_shared(L, -1); +} + +static lua_State* createHandlerThread(Runtime* runtime) +{ + LUTE_ASSERT(runtime); + + lua_State* L = lua_newthread(runtime->GL); luaL_sandboxthread(L); - std::shared_ptr threadRef = getRefForThread(L); - lua_pop(state->runtime->GL, 1); + lua_checkstack(L, 64); + getRefForThread(L); + lua_pop(runtime->GL, 1); + return L; +} +template +static void pushRequestTable( + lua_State* L, + ReqT* req, + const RequestRouteData& route, + std::string_view body, + lua_CFunction upgradeFn, + int nUpvalues, + PushUpvalues pushUpvalues +) +{ lua_createtable(L, 0, 5); lua_pushstring(L, "method"); - lua_pushstring(L, method.c_str()); + lua_pushstring(L, route.method.c_str()); lua_settable(L, -3); lua_pushstring(L, "path"); - lua_pushlstring(L, path.data(), path.size()); + lua_pushlstring(L, route.path.data(), route.path.size()); lua_settable(L, -3); lua_pushstring(L, "query"); - parseQuery(query, L); + parseQuery(route.query, L); lua_settable(L, -3); lua_pushstring(L, "headers"); @@ -209,12 +495,114 @@ static void processRequest( lua_pushlstring(L, body.data(), body.size()); lua_settable(L, -3); + int requestIndex = lua_absindex(L, -1); + lua_createtable(L, 0, 1); + lua_pushlightuserdata(L, &kRequestUpgradeKey); + pushUpvalues(L); + lua_pushcclosure(L, upgradeFn, "request.upgrade", nUpvalues); + lua_rawset(L, -3); + lua_setmetatable(L, requestIndex); +} + +template +static void pushServerTable( + lua_State* L, + const std::shared_ptr& serverRef, + lua_CFunction upgradeFn, + int nUpvalues, + PushUpvalues pushUpvalues +) +{ + if (serverRef) + serverRef->push(L); + else + lua_newtable(L); + + int serverBaseIndex = lua_absindex(L, -1); + + lua_createtable(L, 0, 1); + lua_createtable(L, 0, 1); + lua_pushvalue(L, serverBaseIndex); + lua_setfield(L, -2, "__index"); + lua_setmetatable(L, -2); + + pushUpvalues(L); + lua_pushcclosure(L, upgradeFn, "server.upgrade", nUpvalues); + lua_setfield(L, -2, "upgrade"); + + lua_remove(L, serverBaseIndex); +} + +template +static lua_State* prepareHttpHandlerThread( + const std::shared_ptr& state, + ReqT* req, + const RequestRouteData& route, + std::string_view body +) +{ + LUTE_ASSERT(state); + LUTE_ASSERT(state->runtime); + LUTE_ASSERT(state->handlerRef); + + lua_State* L = createHandlerThread(state->runtime); + + // `lua_resume(L, nullptr, 2)` expects the stack shape `[handler, request, server]`. state->handlerRef->push(L); + pushRequestTable(L, req, route, body, server_upgrade_noop, 0, [](lua_State*) {}); + pushServerTable(L, state->serverRef, server_upgrade_noop, 0, [](lua_State*) {}); + return L; +} - lua_pushvalue(L, -2); - lua_remove(L, -3); +template +static lua_State* prepareUpgradeHandlerThread( + const std::shared_ptr& state, + uWS::HttpResponse* res, + uWS::HttpRequest* req, + us_socket_context_t* context, + const RequestRouteData& route, + bool& upgraded +) +{ + LUTE_ASSERT(state); + LUTE_ASSERT(state->runtime); + LUTE_ASSERT(state->handlerRef); - int status = lua_resume(L, nullptr, 1); + auto pushUpgradeUpvalues = [res, req, context, &upgraded](lua_State* L) + { + lua_pushlightuserdata(L, res); + lua_pushlightuserdata(L, req); + lua_pushlightuserdata(L, context); + lua_pushlightuserdata(L, &upgraded); + }; + + lua_State* L = createHandlerThread(state->runtime); + + // `lua_resume(L, nullptr, 2)` expects the stack shape `[handler, request, server]`. + state->handlerRef->push(L); + pushRequestTable(L, req, route, std::string_view(""), server_upgrade_do, 4, pushUpgradeUpvalues); + pushServerTable(L, state->serverRef, server_upgrade_do, 4, pushUpgradeUpvalues); + return L; +} + +template +static void processRequest( + const std::shared_ptr& state, + ResT* res, + ReqT* req, + const RequestRouteData& route, + std::string_view body +) +{ + if (!state->handlerRef) + { + res->writeStatus("404 Not Found"); + res->end("No handler configured"); + return; + } + + lua_State* L = prepareHttpHandlerThread(state, req, route, body); + int status = lua_resume(L, nullptr, 2); if (status != LUA_OK && status != LUA_YIELD) { std::string error = lua_tostring(L, -1); @@ -226,28 +614,163 @@ static void processRequest( } handleResponse(res, L, -1); - lua_pop(L, 1); } -static void setupAppAndListen(auto* app, std::shared_ptr state, bool& success) +template +static void installWebSocketRoutes(AppT* app, const std::shared_ptr& state) +{ + if (!state->hasWebSocket) + return; + + typename uWS::TemplatedApp::template WebSocketBehavior> behavior{}; + behavior.maxPayloadLength = kWebSocketMaxPayloadLength; + behavior.upgrade = + [state](auto* res, auto* req, auto* context) + { + if (!state->handlerRef) + { + if (!performWebSocketUpgrade(res, req, context)) + { + res->writeStatus("426 Upgrade Required"); + res->end("WebSocket upgrade required"); + } + return; + } + + RequestRouteData route = extractRequestRouteData(req); + bool upgraded = false; + lua_State* L = prepareUpgradeHandlerThread(state, res, req, context, route, upgraded); + int status = lua_resume(L, nullptr, 2); + + if (status == LUA_YIELD) + { + lua_resetthread(L); + + if (!upgraded) + { + res->writeStatus("500 Internal Server Error"); + res->end("upgrade handler cannot yield"); + } + return; + } + + if (status != LUA_OK) + { + std::string error = lua_isstring(L, -1) ? lua_tostring(L, -1) : "Server error"; + if (!upgraded) + { + res->writeStatus("500 Internal Server Error"); + res->end("Server error: " + error); + } + lua_pop(L, 1); + return; + } + + if (!upgraded) + handleResponse(res, L, -1); + + lua_pop(L, 1); + }; + behavior.open = + [state](auto* ws) + { + auto* data = ws->getUserData(); + data->handle = std::make_shared(); + data->handle->wsPtr = ws; + data->handle->sendFn = &wsSendImpl; + data->handle->closeFn = &wsCloseImpl; + + resumeWith( + state, + state->wsOpenRef, + [handle = data->handle](lua_State* L) + { + pushServerWebSocket(L, handle); + return 1; + } + ); + }; + behavior.message = + [state](auto* ws, std::string_view message, uWS::OpCode opCode) + { + auto handle = ws->getUserData()->handle; + std::string payload(message.data(), message.size()); + bool binary = (opCode == uWS::OpCode::BINARY); + + resumeWith( + state, + state->wsMessageRef, + [handle, payload = std::move(payload), binary](lua_State* L) + { + pushServerWebSocket(L, handle); + if (binary) + { + void* buf = lua_newbuffer(L, payload.size()); + if (!payload.empty()) + memcpy(buf, payload.data(), payload.size()); + } + else + { + lua_pushlstring(L, payload.data(), payload.size()); + } + return 2; + } + ); + }; + behavior.drain = + [state](auto* ws) + { + auto handle = ws->getUserData()->handle; + + resumeWith( + state, + state->wsDrainRef, + [handle](lua_State* L) + { + pushServerWebSocket(L, handle); + return 1; + } + ); + }; + behavior.close = + [state](auto* ws, int code, std::string_view message) + { + auto handle = ws->getUserData()->handle; + std::shared_ptr userdataRef; + if (handle) + { + handle->closed.store(true); + handle->wsPtr = nullptr; + userdataRef = std::move(handle->userdataRef); + } + + std::string payload(message.data(), message.size()); + + resumeWith( + state, + state->wsCloseRef, + [handle, userdataRef = std::move(userdataRef), code, payload = std::move(payload)](lua_State* L) + { + pushServerWebSocket(L, handle, userdataRef); + lua_pushinteger(L, code); + lua_pushlstring(L, payload.data(), payload.size()); + return 3; + } + ); + }; + + app->template ws>("/*", std::move(behavior)); +} + +template +static void installHttpRoutes(AppT* app, const std::shared_ptr& state) { app->any( "/*", [state](auto* res, auto* req) { - std::string method = std::string(req->getMethod()); - std::transform(method.begin(), method.end(), method.begin(), ::toupper); - std::string_view url = req->getFullUrl(); - std::string_view path = url; - - size_t queryPos = url.find('?'); - std::string query; - if (queryPos != std::string::npos) - { - path = std::string_view(url.data(), queryPos); - query = std::string_view(url.data() + queryPos, url.size() - queryPos); - } + RequestRouteData route = extractRequestRouteData(req); res->onAborted( []() @@ -258,18 +781,21 @@ static void setupAppAndListen(auto* app, std::shared_ptr state, std::unique_ptr bodyBuffer; res->onData( - [state, res, req, method, path, query, bodyBuffer = std::move(bodyBuffer)](std::string_view data, bool last) mutable + [state, res, req, route = std::move(route), bodyBuffer = std::move(bodyBuffer)]( + std::string_view data, + bool last + ) mutable { if (last) { if (bodyBuffer.get()) { bodyBuffer->append(data); - processRequest(state, res, req, method, path, query, *bodyBuffer); + processRequest(state, res, req, route, *bodyBuffer); } else { - processRequest(state, res, req, method, path, query, data); + processRequest(state, res, req, route, data); } } else @@ -287,16 +813,22 @@ static void setupAppAndListen(auto* app, std::shared_ptr state, ); } ); +} +template +static void listenApp(AppT* app, const std::shared_ptr& state, bool& success) +{ int options = state->reusePort ? LIBUS_LISTEN_DEFAULT : LIBUS_LISTEN_EXCLUSIVE_PORT; app->listen( state->hostname, state->port, options, - [&success](auto* listen_socket) + [state, &success](auto* listen_socket) { success = (listen_socket != nullptr); + if (listen_socket) + state->port = us_socket_local_port(SSL, (struct us_socket_t*)listen_socket); } ); } @@ -339,7 +871,8 @@ int serve(lua_State* L) int port = 3000; bool reusePort = false; std::optional tlsOptions; - int handlerIndex = 1; + int handlerIndex = 0; + int websocketIndex = 0; if (lua_istable(L, 1)) { @@ -404,20 +937,40 @@ int serve(lua_State* L) lua_pop(L, 1); lua_getfield(L, 1, "handler"); - if (!lua_isfunction(L, -1)) + if (lua_isfunction(L, -1)) + { + handlerIndex = lua_gettop(L); + } + else { lua_pop(L, 1); - luaL_errorL(L, "handler function is required in config table"); + } + + lua_getfield(L, 1, "websocket"); + if (lua_istable(L, -1)) + { + websocketIndex = lua_gettop(L); + } + else + { + lua_pop(L, 1); + } + + if (handlerIndex == 0 && websocketIndex == 0) + { + luaL_errorL(L, "config table requires a handler function, websocket config, or both"); return 0; } - lua_insert(L, -1); - handlerIndex = lua_gettop(L); } else if (!lua_isfunction(L, 1)) { luaL_errorL(L, "serve requires a handler function or config table"); return 0; } + else + { + handlerIndex = 1; + } Runtime* runtime = getRuntime(L); @@ -429,9 +982,37 @@ int serve(lua_State* L) state->port = port; state->reusePort = reusePort; - lua_pushvalue(L, handlerIndex); - state->handlerRef = std::make_shared(L, -1); - lua_pop(L, 1); + if (handlerIndex != 0) + { + lua_pushvalue(L, handlerIndex); + state->handlerRef = std::make_shared(L, -1); + lua_pop(L, 1); + } + + if (websocketIndex != 0) + { + state->hasWebSocket = true; + + lua_getfield(L, websocketIndex, "open"); + if (lua_isfunction(L, -1)) + state->wsOpenRef = std::make_shared(L, -1); + lua_pop(L, 1); + + lua_getfield(L, websocketIndex, "message"); + if (lua_isfunction(L, -1)) + state->wsMessageRef = std::make_shared(L, -1); + lua_pop(L, 1); + + lua_getfield(L, websocketIndex, "close"); + if (lua_isfunction(L, -1)) + state->wsCloseRef = std::make_shared(L, -1); + lua_pop(L, 1); + + lua_getfield(L, websocketIndex, "drain"); + if (lua_isfunction(L, -1)) + state->wsDrainRef = std::make_shared(L, -1); + lua_pop(L, 1); + } uWSApp app; bool success = false; @@ -440,14 +1021,18 @@ int serve(lua_State* L) { auto ssl_app = std::make_unique(*tlsOptions); state->app = ssl_app.get(); - setupAppAndListen(ssl_app.get(), state, success); + installWebSocketRoutes(ssl_app.get(), state); + installHttpRoutes(ssl_app.get(), state); + listenApp(ssl_app.get(), state, success); app = std::move(ssl_app); } else { auto plain_app = std::make_unique(); state->app = plain_app.get(); - setupAppAndListen(plain_app.get(), state, success); + installWebSocketRoutes(plain_app.get(), state); + installHttpRoutes(plain_app.get(), state); + listenApp(plain_app.get(), state, success); app = std::move(plain_app); } @@ -460,14 +1045,14 @@ int serve(lua_State* L) serverInstances[serverId] = std::move(app); serverStates[serverId] = state; - lua_createtable(L, 0, 3); + lua_createtable(L, 0, 4); lua_pushstring(L, "hostname"); lua_pushstring(L, hostname.c_str()); lua_settable(L, -3); lua_pushstring(L, "port"); - lua_pushinteger(L, port); + lua_pushinteger(L, state->port); lua_settable(L, -3); lua_pushstring(L, "close"); @@ -487,11 +1072,60 @@ int serve(lua_State* L) ); lua_settable(L, -3); + lua_pushstring(L, "upgrade"); + lua_pushcfunction(L, server_upgrade, "server_upgrade"); + lua_settable(L, -3); + + state->serverRef = std::make_shared(L, -1); + return 1; } } // namespace net::server +static void initializeNetServer(lua_State* L) +{ + luaL_newmetatable(L, "ServerWebSocketHandle"); + + lua_pushcfunction( + L, + [](lua_State* L) + { + const char* index = luaL_checkstring(L, -1); + + if (strcmp(index, "send") == 0) + { + lua_pushcfunction(L, net::server::server_ws_send, "ServerWebSocketHandle.send"); + return 1; + } + + if (strcmp(index, "close") == 0) + { + lua_pushcfunction(L, net::server::server_ws_close, "ServerWebSocketHandle.close"); + return 1; + } + + return 0; + }, + "ServerWebSocketHandle.__index" + ); + lua_setfield(L, -2, "__index"); + + lua_pushstring(L, "ServerWebSocketHandle"); + lua_setfield(L, -2, "__type"); + + lua_setuserdatadtor( + L, + kServerWebSocketHandleTag, + [](lua_State*, void* ud) + { + std::destroy_at(static_cast*>(ud)); + } + ); + + lua_setuserdatametatable(L, kServerWebSocketHandleTag); +} + const char* const NetServer::properties[] = {nullptr}; const luaL_Reg NetServer::lib[] = { @@ -501,6 +1135,8 @@ const luaL_Reg NetServer::lib[] = { int NetServer::pushLibrary(lua_State* L) { + initializeNetServer(L); + lua_createtable(L, 0, std::size(NetServer::lib)); for (auto& [name, func] : NetServer::lib) diff --git a/lute/runtime/include/lute/userdatas.h b/lute/runtime/include/lute/userdatas.h index 5cc6153d9..7ae9455c1 100644 --- a/lute/runtime/include/lute/userdatas.h +++ b/lute/runtime/include/lute/userdatas.h @@ -5,3 +5,5 @@ constexpr int kDurationTag = 127; constexpr int kInstantTag = 126; constexpr int kWatchHandleTag = 125; constexpr int kHashFunctionTag = 124; +constexpr int kWebSocketHandleTag = 123; +constexpr int kServerWebSocketHandleTag = 122; diff --git a/tests/lute/net.test.luau b/tests/lute/net.test.luau new file mode 100644 index 000000000..921a90260 --- /dev/null +++ b/tests/lute/net.test.luau @@ -0,0 +1,270 @@ +local test = require("@std/test") + +local client = require("@lute/net/client") +local server = require("@lute/net/server") +local task = require("@lute/task") + +test.suite("LuteNetSuite", function(suite) + suite:case("serveWebsocketAndConnect", function(assert) + local receivedEcho = false + local clientClosed = false + local clientCloseCode: number? = nil + local clientCloseReason: string? = nil + local serverCloseCode: number? = nil + local serverCloseMessage: string? = nil + local firstServerSendResult: number? = nil + + local instance = server.serve({ + handler = function(req: server.ReceivedRequest, current: server.Server): server.ServerResponse? + if current:upgrade(req) then + return nil + end + + return { + status = 200, + body = "no upgrade", + } + end, + websocket = { + message = function(ws, message) + local sendResult = ws:send(message) + if firstServerSendResult == nil then + firstServerSendResult = sendResult + end + if message == "bye" then + ws:close(4001, "done") + end + end, + close = function(_ws, code, message) + serverCloseCode = code + serverCloseMessage = message + end, + }, + }) + + local _ws: client.WebSocket? + _ws = client.websocket(`ws://{instance.hostname}:{instance.port}`, { + onopen = function() + if _ws then + _ws:send("hello") + end + end, + onmessage = function(message) + if type(message) == "string" and message == "hello" then + receivedEcho = true + if _ws then + _ws:send("bye") + end + end + end, + onclose = function(code, reason) + clientClosed = true + clientCloseCode = code + clientCloseReason = reason + end, + onerror = function(err: string) + error(err) + end, + }) + + local start = os.clock() + while (not receivedEcho or not clientClosed or serverCloseCode == nil) and os.clock() - start < 5 do + task.deferSelf() + end + + assert.eq(true, receivedEcho) + assert.eq(true, clientClosed) + if clientCloseCode == nil or clientCloseReason == nil then + error("client close callback did not capture close details") + end + if firstServerSendResult == nil then + error("server send result was not captured") + end + if serverCloseCode == nil or serverCloseMessage == nil then + error("websocket close callback did not run") + end + assert.eq(true, firstServerSendResult > 0) + assert.eq(4001, clientCloseCode) + assert.eq("done", clientCloseReason) + assert.eq(4001, serverCloseCode) + assert.eq("done", serverCloseMessage) + + instance.close() + end) + + suite:case("serveBinaryWebsocketAndConnect", function(assert) + local serverSawBinary = false + local clientSawBinary = false + local firstServerSendResult: number? = nil + local payload = buffer.create(4) + buffer.writeu8(payload, 0, 1) + buffer.writeu8(payload, 1, 2) + buffer.writeu8(payload, 2, 3) + buffer.writeu8(payload, 3, 4) + + local instance = server.serve({ + port = 0, + handler = function(req: server.ReceivedRequest, current: server.Server): server.ServerResponse? + if current:upgrade(req) then + return nil + end + + return { + status = 200, + body = "no upgrade", + } + end, + websocket = { + message = function(ws, message) + assert.eq("buffer", type(message)) + serverSawBinary = true + + local sendResult = ws:send(message) + if firstServerSendResult == nil then + firstServerSendResult = sendResult + end + end, + }, + }) + + local closed = false + local _ws: client.WebSocket? + _ws = client.websocket(`ws://{instance.hostname}:{instance.port}`, { + onopen = function() + if _ws then + _ws:send(payload) + end + end, + onmessage = function(message) + if type(message) ~= "buffer" then + error("expected binary websocket message") + end + clientSawBinary = true + local messageBuffer: buffer = message + assert.eq(payload, messageBuffer) + + if _ws then + _ws:close() + end + end, + onclose = function() + closed = true + end, + onerror = function(err: string) + error(err) + end, + }) + + local start = os.clock() + while (not serverSawBinary or not clientSawBinary or not closed) and os.clock() - start < 5 do + task.deferSelf() + end + + assert.eq(true, serverSawBinary) + assert.eq(true, clientSawBinary) + assert.eq(true, closed) + if firstServerSendResult == nil then + error("binary server send result was not captured") + end + assert.eq(true, firstServerSendResult > 0) + + instance.close() + end) + + suite:case("clientWebsocketCallbacksSurviveGc", function(assert) + local opened = false + + local instance = server.serve({ + handler = function(req: server.ReceivedRequest, current: server.Server): server.ServerResponse? + if current:upgrade(req) then + return nil + end + + return { + status = 200, + body = "no upgrade", + } + end, + websocket = {}, + }) + + local _ws: client.WebSocket? + _ws = client.websocket(`ws://{instance.hostname}:{instance.port}`, { + onopen = function() + opened = true + end, + }) + + local start = os.clock() + while os.clock() - start < 0.2 do + end + + _ws = nil + + do + local junk = {} + for i = 1, 20000 do + junk[i] = string.rep("x", 1024) + end + end + + start = os.clock() + while not opened and os.clock() - start < 2 do + task.deferSelf() + end + + assert.eq(true, opened) + instance.close() + end) + + suite:case("clientInitiatedCloseRunsOnclose", function(assert) + local closed = false + local closeCode: number? = nil + local closeReason: string? = nil + + local instance = server.serve({ + handler = function(req: server.ReceivedRequest, current: server.Server): server.ServerResponse? + if current:upgrade(req) then + return nil + end + + return { + status = 200, + body = "no upgrade", + } + end, + websocket = {}, + }) + + local _ws: client.WebSocket? + _ws = client.websocket(`ws://{instance.hostname}:{instance.port}`, { + onopen = function() + if _ws then + _ws:close() + end + end, + onclose = function(code, reason) + closed = true + closeCode = code + closeReason = reason + end, + onerror = function(err: string) + error(err) + end, + }) + + local start = os.clock() + while not closed and os.clock() - start < 2 do + task.deferSelf() + end + + assert.eq(true, closed) + if closeCode == nil or closeReason == nil then + error("client close callback did not capture close details") + end + assert.eq(1000, closeCode) + assert.eq("", closeReason) + + instance.close() + end) +end)