diff --git a/src/debug/HyprCtl.cpp b/src/debug/HyprCtl.cpp index 53bac0d8b1d..89533b3b5cd 100644 --- a/src/debug/HyprCtl.cpp +++ b/src/debug/HyprCtl.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -2250,6 +2251,8 @@ static bool isFollowUpRollingLogRequest(const std::string& request) { } static int hyprCtlFDTick(int fd, uint32_t mask, void* data) { + constexpr size_t MAX_REQUEST_SIZE = 64 * 1024; + if (mask & WL_EVENT_ERROR || mask & WL_EVENT_HANGUP) return 0; @@ -2259,7 +2262,13 @@ static int hyprCtlFDTick(int fd, uint32_t mask, void* data) { sockaddr_in clientAddress; socklen_t clientSize = sizeof(clientAddress); - const auto ACCEPTEDCONNECTION = accept4(g_pHyprCtl->m_socketFD.get(), rc(&clientAddress), &clientSize, SOCK_CLOEXEC); + const auto ACCEPTEDCONNECTION = accept4(g_pHyprCtl->m_socketFD.get(), rc(&clientAddress), &clientSize, SOCK_CLOEXEC | SOCK_NONBLOCK); + + if (ACCEPTEDCONNECTION < 0) { + if (errno != EAGAIN && errno != EWOULDBLOCK) + Debug::log(ERR, "Hyprctl: failed to accept connection: {}", strerror(errno)); + return 0; + } std::array readBuffer; @@ -2273,31 +2282,28 @@ static int hyprCtlFDTick(int fd, uint32_t mask, void* data) { Debug::log(LOG, "Hyprctl: new connection from pid {}", creds.CRED_PID); } - // - pollfd pollfds[1] = { - { - .fd = ACCEPTEDCONNECTION, - .events = POLLIN, - }, - }; - - int ret = poll(pollfds, 1, 5000); - - if (ret <= 0) { - close(ACCEPTEDCONNECTION); - return 0; - } - std::string request; - while (true) { + while (request.size() < MAX_REQUEST_SIZE) { readBuffer.fill(0); - auto messageSize = read(ACCEPTEDCONNECTION, readBuffer.data(), 1023); - if (messageSize < 1) - break; - std::string recvd = readBuffer.data(); - request += recvd; - if (messageSize < 1023) + auto messageSize = read(ACCEPTEDCONNECTION, readBuffer.data(), readBuffer.size()); + if (messageSize < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) + break; + Debug::log(ERR, "Hyprctl: read error: {}", strerror(errno)); + close(ACCEPTEDCONNECTION); + return 0; + } + if (messageSize == 0) break; + request.append(readBuffer.data(), rc(messageSize)); + if (messageSize < static_cast(readBuffer.size())) + break; // drained + } + + if (request.size() >= MAX_REQUEST_SIZE) { + Debug::log(ERR, "Hyprctl: request exceeded {} bytes, closing", MAX_REQUEST_SIZE); + close(ACCEPTEDCONNECTION); + return 0; } std::string reply = ""; @@ -2342,7 +2348,7 @@ static int hyprCtlFDTick(int fd, uint32_t mask, void* data) { } void CHyprCtl::startHyprCtlSocket() { - m_socketFD = CFileDescriptor{socket(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0)}; + m_socketFD = CFileDescriptor{socket(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC | SOCK_NONBLOCK, 0)}; if (!m_socketFD.isValid()) { Debug::log(ERR, "Couldn't start the Hyprland Socket. (1) IPC will not work.");