diff --git a/src/AP_WS_Connection.cpp b/src/AP_WS_Connection.cpp index bf1979ee..d682fbd4 100644 --- a/src/AP_WS_Connection.cpp +++ b/src/AP_WS_Connection.cpp @@ -561,14 +561,14 @@ namespace OpenWifi { void AP_WS_Connection::OnSocketShutdown( [[maybe_unused]] const Poco::AutoPtr &pNf) { poco_trace(Logger_, fmt::format("SOCKET-SHUTDOWN({}): Closing.", CId_)); -// std::lock_guard G(ConnectionMutex_); + std::lock_guard G(ConnectionMutex_); return EndConnection(); } void AP_WS_Connection::OnSocketError( [[maybe_unused]] const Poco::AutoPtr &pNf) { poco_trace(Logger_, fmt::format("SOCKET-ERROR({}): Closing.", CId_)); -// std::lock_guard G(ConnectionMutex_); + std::lock_guard G(ConnectionMutex_); return EndConnection(); } @@ -652,9 +652,10 @@ namespace OpenWifi { case Poco::Net::WebSocket::FRAME_OP_TEXT: { poco_trace(Logger_, - fmt::format("FRAME({}): Frame received (length={}, flags={}). Msg={}", - CId_, IncomingSize, flags, IncomingFrame.begin())); + fmt::format("FRAME({}): Frame received (length={}, flags={}). Msg={}", + CId_, IncomingSize, flags, IncomingFrame.begin())); + Poco::JSON::Parser parser; auto ParsedMessage = parser.parse(IncomingFrame.begin()); auto IncomingJSON = ParsedMessage.extract(); diff --git a/src/AP_WS_Server.cpp b/src/AP_WS_Server.cpp index d4fa73dc..796472b8 100644 --- a/src/AP_WS_Server.cpp +++ b/src/AP_WS_Server.cpp @@ -57,8 +57,9 @@ namespace OpenWifi { if (request.find("Upgrade") != request.end() && Poco::icompare(request["Upgrade"], "websocket") == 0) { Utils::SetThreadName("ws:conn-init"); - session_id_++; - return new AP_WS_RequestHandler(Logger_, session_id_); + //session_id_++; + auto new_session_id = session_id_.fetch_add(1, std::memory_order_seq_cst) + 1; + return new AP_WS_RequestHandler(Logger_, new_session_id); } else { return nullptr; } @@ -514,10 +515,27 @@ namespace OpenWifi { Connection = SessionHint->second; Sessions_[sessionHash].erase(SessionHint); } - - auto deviceHash = MACHash::Hash(SerialNumber); - std::lock_guard DeviceLock(SerialNumbersMutex_[deviceHash]); - SerialNumbers_[deviceHash][SerialNumber] = Connection; + std::atomic_bool duplicate_session = false; + { + auto deviceHash = MACHash::Hash(SerialNumber); + std::lock_guard DeviceLock(SerialNumbersMutex_[deviceHash]); + auto DeviceHint = SerialNumbers_[deviceHash].find(SerialNumber); + if (DeviceHint == SerialNumbers_[deviceHash].end()) { + // No duplicate connection go ahead and add new connection + SerialNumbers_[deviceHash][SerialNumber] = Connection; + } + else { + // Mark a duplicate session + duplicate_session = true; + poco_information(Logger(), fmt::format("[session ID: {}] Found a duplicate connection for device serial: {}", session_id, Utils::IntToSerialNumber(SerialNumber))); + } + } + if (duplicate_session.load()){ + // This is only called if we have a duplicate session + // We remove the new incoming session that we just added a few lines above, forcing the destructor for this new session while not impacting the pointers to the old session. + std::lock_guard SessionLock(SessionMutex_[sessionHash]); + Sessions_[sessionHash].erase(session_id); + } } bool AP_WS_Server::EndSession(uint64_t session_id, uint64_t SerialNumber) {