diff --git a/docs/advanced-guide/websocket/page.md b/docs/advanced-guide/websocket/page.md index 137eeeb87..4576a6beb 100644 --- a/docs/advanced-guide/websocket/page.md +++ b/docs/advanced-guide/websocket/page.md @@ -51,6 +51,10 @@ GoFr allows us to customize the WebSocket upgrader with several options. We can - `CheckOrigin (WithCheckOrigin)`: Sets a custom origin check function. - `Compression (WithCompression)`: Enables compression. +## Writing Messages + +GoFr provides the `WriteMessageToSocket` method to send messages to the underlying websocket connection in a thread-safe way. The data parameter can be a string, []byte, or any struct that can be marshaled to JSON. + ## Example: We can configure the Upgrader by creating a chain of option functions provided by GoFr. diff --git a/pkg/gofr/websocket/websocket.go b/pkg/gofr/websocket/websocket.go index 0f6c17f14..e91ef69af 100644 --- a/pkg/gofr/websocket/websocket.go +++ b/pkg/gofr/websocket/websocket.go @@ -19,6 +19,9 @@ const WSConnectionKey WSKey = "ws-connection-key" // Connection is a wrapper for gorilla websocket connection. type Connection struct { *websocket.Conn + + // Mutex to prevent race conditions on write operations + writeMutex sync.Mutex } // ErrorConnection is the connection error that occurs when webscoket connection cannot be established. @@ -76,6 +79,16 @@ func (w *Connection) Bind(v any) error { return nil } +// WriteMessage writes the data on the underlying ws connection. +// +// This method is thread-safe and be called concurrently with WriteJSON. +func (w *Connection) WriteMessage(messageType int, data []byte) error { + w.writeMutex.Lock() + defer w.writeMutex.Unlock() + + return w.Conn.WriteMessage(messageType, data) +} + func (*Connection) HostName() string { return "" // Not applicable for WebSocket, can be implemented if needed } diff --git a/pkg/gofr/websocket/websocket_test.go b/pkg/gofr/websocket/websocket_test.go index 602f2ce8b..5e663365f 100644 --- a/pkg/gofr/websocket/websocket_test.go +++ b/pkg/gofr/websocket/websocket_test.go @@ -5,6 +5,7 @@ import ( "net/http" "net/http/httptest" "os" + "sync" "testing" "time" @@ -150,3 +151,40 @@ func dereference(v any) any { return v } } + +func TestConcurrentWriteMessageCalls(t *testing.T) { + upgrader := websocket.Upgrader{} + + const message = "this is a test message" + + loop := 10 + workers := 10 + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + assert.NoError(t, err) + defer conn.Close() + + wc := &Connection{Conn: conn} + + wg := sync.WaitGroup{} + + for range loop { + for range workers { + wg.Add(1) + + go func() { + defer wg.Done() + + if err := wc.WriteMessage(websocket.TextMessage, []byte(message)); err != nil { + t.Errorf("concurrently wc.WriteMessage() returned %v", err) + } + }() + } + } + + wg.Wait() + })) + + server.Close() +}