Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 31 additions & 11 deletions backend/websocket/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/json"
"net/http"
"slices"
"sync"

"github.com/charmbracelet/log"
"github.com/gorilla/websocket"
Expand Down Expand Up @@ -68,7 +69,18 @@ func (ws *WebsocketServer) handleMessages(conn *websocket.Conn) {
Event: event.EventType(msg.Event),
Data: msg.Data,
})
ws.SendMessage(conn, response)

// Find the connectionInfo for this connection
ws.mutex.RLock()
addr := conn.RemoteAddr().String()
connInfo, ok := ws.connections[addr]
ws.mutex.RUnlock()

if ok {
ws.SendMessage(connInfo, response)
} else {
log.Error("Connection not found in connections map during message handling", "addr", addr)
}
}
}

Expand All @@ -77,21 +89,29 @@ func (ws *WebsocketServer) AddConnection(conn *websocket.Conn) {
ws.mutex.Lock()
defer ws.mutex.Unlock()

addr := conn.RemoteAddr().String()

// close connection if remote addr tries to connect again
if connection, ok := ws.connections[conn.RemoteAddr().String()]; ok {
ws.RemoveConnection(connection)
_ = connection.Close()
if connInfo, ok := ws.connections[addr]; ok {
// Remove the connection directly since we already hold the lock
delete(ws.connections, addr)
delete(ws.dataListeners, addr)
_ = connInfo.conn.Close()
}

ws.connections[conn.RemoteAddr().String()] = conn
ws.connections[addr] = &connectionInfo{
conn: conn,
writeMux: sync.Mutex{},
}
}

// RemoveConnection removes a WebSocket connection
func (ws *WebsocketServer) RemoveConnection(conn *websocket.Conn) {
ws.mutex.Lock()
defer ws.mutex.Unlock()
delete(ws.connections, conn.RemoteAddr().String())
delete(ws.dataListeners, conn.RemoteAddr().String())
addr := conn.RemoteAddr().String()
delete(ws.connections, addr)
delete(ws.dataListeners, addr)
}

type RegisterResponse int
Expand Down Expand Up @@ -144,16 +164,16 @@ func (ws *WebsocketServer) BroadcastModuleUpdate(module types.Module, addr *stri
}

if addr != nil {
if conn, ok := ws.connections[*addr]; ok {
if connInfo, ok := ws.connections[*addr]; ok {
log.Debug("WS: Broadcasting module update to connection", "addr", *addr, "module", module.Name)
ws.SendMessage(conn, response)
ws.SendMessage(connInfo, response)
} else {
for remote_addr, conn := range ws.connections {
for remote_addr, connInfo := range ws.connections {
modules, ok := ws.dataListeners[remote_addr]

if ok && slices.Contains(modules, module.Name) {
log.Debug("WS: Broadcasting module update to listener", "addr", remote_addr, "module", module.Name)
ws.SendMessage(conn, response)
ws.SendMessage(connInfo, response)
}

}
Expand Down
36 changes: 30 additions & 6 deletions backend/websocket/messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,29 @@ import (
"github.com/timmo001/system-bridge/event"
)

func (ws *WebsocketServer) SendMessage(conn *websocket.Conn, message event.MessageResponse) {
func (ws *WebsocketServer) SendMessage(connInfo *connectionInfo, message event.MessageResponse) {
log.Debug("Sending message to connection", "response", message)

if err := conn.WriteJSON(message); err != nil {
// Use per-connection mutex to prevent concurrent writes
connInfo.writeMux.Lock()
defer connInfo.writeMux.Unlock()

if err := connInfo.conn.WriteJSON(message); err != nil {
log.Error("Failed to send response:", err)
// If there's an error, remove the connection
if closeErr := conn.Close(); closeErr != nil {
// If there's an error, close the connection
if closeErr := connInfo.conn.Close(); closeErr != nil {
log.Error("Error closing connection:", closeErr)
}
delete(ws.connections, conn.RemoteAddr().String())
// Remove from connections and dataListeners if and only if the pointer matches
go func(addr string, failedConn *websocket.Conn) {
ws.mutex.Lock()
defer ws.mutex.Unlock()
connInfo, ok := ws.connections[addr]
if ok && connInfo.conn == failedConn {
delete(ws.connections, addr)
delete(ws.dataListeners, addr)
}
}(connInfo.conn.RemoteAddr().String(), connInfo.conn)
}
}

Expand All @@ -27,5 +40,16 @@ func (ws *WebsocketServer) SendError(conn *websocket.Conn, req WebSocketRequest,
Data: map[string]string{},
Message: message,
}
ws.SendMessage(conn, response)

// Find the connectionInfo for this connection
ws.mutex.RLock()
addr := conn.RemoteAddr().String()
connInfo, ok := ws.connections[addr]
ws.mutex.RUnlock()

if ok {
ws.SendMessage(connInfo, response)
} else {
log.Error("Connection not found in connections map", "addr", addr)
}
}
10 changes: 8 additions & 2 deletions backend/websocket/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,16 @@ type WebSocketRequest struct {
Token string `json:"token" mapstructure:"token"`
}

// connectionInfo holds connection data with write synchronization
type connectionInfo struct {
conn *websocket.Conn
writeMux sync.Mutex
}

type WebsocketServer struct {
token string
upgrader websocket.Upgrader
connections map[string]*websocket.Conn
connections map[string]*connectionInfo
dataListeners map[string][]types.ModuleName
mutex sync.RWMutex
dataStore *data.DataStore
Expand All @@ -36,7 +42,7 @@ type WebsocketServer struct {
func NewWebsocketServer(settings *settings.Settings, dataStore *data.DataStore, eventRouter *event.MessageRouter) *WebsocketServer {
ws := &WebsocketServer{
token: settings.API.Token,
connections: make(map[string]*websocket.Conn),
connections: make(map[string]*connectionInfo),
dataListeners: make(map[string][]types.ModuleName),
dataStore: dataStore,
EventRouter: eventRouter,
Expand Down
Loading
Loading