11package main
22
33import (
4+ "crypto/rand"
5+ "encoding/hex"
46 "encoding/json"
57 "fmt"
6- "math/rand"
8+ mathrand "math/rand"
79 "net/http"
810 "os"
11+ "strings"
912 "sync"
1013 "time"
14+ "unicode/utf8"
1115
1216 "github.com/gorilla/websocket"
1317 "go.uber.org/zap"
@@ -25,7 +29,6 @@ const (
2529 MsgTypeBufferReady = "buffer_ready"
2630 MsgTypeKickUser = "kick_user"
2731 MsgTypePing = "ping"
28- MsgTypeChat = "chat"
2932 MsgTypeRequestSync = "request_sync"
3033 MsgTypeReconnect = "reconnect"
3134 MsgTypeSuggestTrack = "suggest_track"
@@ -45,7 +48,6 @@ const (
4548 MsgTypeError = "error"
4649 MsgTypePong = "pong"
4750 MsgTypeRoomState = "room_state"
48- MsgTypeChatMessage = "chat_message"
4951 MsgTypeHostChanged = "host_changed"
5052 MsgTypeKicked = "kicked"
5153 MsgTypeSyncState = "sync_state"
@@ -228,19 +230,6 @@ type UserInfo struct {
228230 IsConnected bool `json:"is_connected"`
229231}
230232
231- // ChatPayload is for chat messages
232- type ChatPayload struct {
233- Message string `json:"message"`
234- }
235-
236- // ChatMessagePayload is sent to all users in a room
237- type ChatMessagePayload struct {
238- UserID string `json:"user_id"`
239- Username string `json:"username"`
240- Message string `json:"message"`
241- Timestamp int64 `json:"timestamp"`
242- }
243-
244233// KickUserPayload is for kicking a user from the room
245234type KickUserPayload struct {
246235 UserID string `json:"user_id"`
@@ -300,6 +289,12 @@ type Session struct {
300289 DisconnectAt time.Time
301290}
302291
292+ // RateLimiter tracks message rates per client
293+ type RateLimiter struct {
294+ messages []time.Time
295+ mu sync.Mutex
296+ }
297+
303298// Client represents a connected WebSocket client
304299type Client struct {
305300 ID string
@@ -310,6 +305,7 @@ type Client struct {
310305 Send chan []byte
311306 closed bool
312307 mu sync.Mutex
308+ rateLimiter * RateLimiter
313309}
314310
315311// Room represents a listening room
@@ -343,14 +339,29 @@ type Server struct {
343339 upgrader websocket.Upgrader
344340 mu sync.RWMutex
345341 logger * zap.Logger
346- rng * rand .Rand
342+ rng * mathrand .Rand
347343}
348344
349345const (
350346 // Grace period for reconnection (5 minutes)
351347 ReconnectGracePeriod = 5 * time .Minute
352348 // How often to clean up expired sessions
353349 SessionCleanupInterval = 1 * time .Minute
350+ // Security limits
351+ MaxUsernameLength = 50
352+ MaxRoomCodeLength = 10
353+ MaxMessageLength = 500
354+ MaxTrackTitleLength = 200
355+ MaxTrackArtistLength = 200
356+ MaxQueueSize = 1000
357+ // Rate limiting
358+ RateLimitWindow = time .Minute
359+ MaxMessagesPerWindow = 100
360+ // Connection limits
361+ MaxReadMessageSize = 65536
362+ WriteTimeout = 10 * time .Second
363+ ReadTimeout = 60 * time .Second
364+ PongTimeout = 10 * time .Second
354365)
355366
356367func NewServer (logger * zap.Logger ) * Server {
@@ -366,10 +377,10 @@ func NewServer(logger *zap.Logger) *Server {
366377 WriteBufferSize : 1024 ,
367378 },
368379 logger : logger ,
369- rng : rand .New (rand .NewSource (time .Now ().UnixNano ())),
380+ rng : mathrand .New (mathrand .NewSource (time .Now ().UnixNano ())),
370381 }
371382
372- // Start session cleanup goroutine
383+ // Start cleanup goroutines
373384 go s .cleanupExpiredSessions ()
374385
375386 return s
@@ -428,7 +439,7 @@ func (s *Server) cleanupExpiredSessions() {
428439}
429440
430441func (s * Server ) generateRoomCode () string {
431- const chars = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789" // Removed confusing chars
442+ const chars = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789"
432443 code := make ([]byte , 6 )
433444 for i := range code {
434445 code [i ] = chars [s .rng .Intn (len (chars ))]
@@ -441,25 +452,29 @@ func (s *Server) generateUserID() string {
441452}
442453
443454func (s * Server ) generateSessionToken () string {
444- const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
445- token := make ([]byte , 32 )
446- for i := range token {
447- token [i ] = chars [s .rng .Intn (len (chars ))]
455+ // Use crypto/rand for secure token generation
456+ b := make ([]byte , 32 )
457+ if _ , err := rand .Read (b ); err != nil {
458+ s .logger .Error ("Failed to generate secure token" , zap .Error (err ))
459+ // Fallback to less secure but functional token
460+ return fmt .Sprintf ("token_%d_%d" , time .Now ().UnixNano (), s .rng .Intn (1000000 ))
448461 }
449- return string ( token )
462+ return hex . EncodeToString ( b )
450463}
451464
452465func (s * Server ) handleWebSocket (w http.ResponseWriter , r * http.Request ) {
453466 conn , err := s .upgrader .Upgrade (w , r , nil )
454467 if err != nil {
455468 s .logger .Warn ("WebSocket upgrade error" , zap .Error (err ))
469+ s .mu .Unlock ()
456470 return
457471 }
458472
459473 client := & Client {
460- ID : s .generateUserID (),
461- Conn : conn ,
462- Send : make (chan []byte , 256 ),
474+ ID : s .generateUserID (),
475+ Conn : conn ,
476+ Send : make (chan []byte , 256 ),
477+ rateLimiter : & RateLimiter {messages : make ([]time.Time , 0 )},
463478 }
464479
465480 s .mu .Lock ()
@@ -508,10 +523,10 @@ func (c *Client) readPump(s *Server) {
508523 c .Conn .Close ()
509524 }()
510525
511- c .Conn .SetReadLimit (65536 )
512- c .Conn .SetReadDeadline (time .Now ().Add (60 * time . Second ))
526+ c .Conn .SetReadLimit (MaxReadMessageSize )
527+ c .Conn .SetReadDeadline (time .Now ().Add (ReadTimeout ))
513528 c .Conn .SetPongHandler (func (string ) error {
514- c .Conn .SetReadDeadline (time .Now ().Add (60 * time . Second ))
529+ c .Conn .SetReadDeadline (time .Now ().Add (ReadTimeout ))
515530 return nil
516531 })
517532
@@ -751,6 +766,38 @@ func (s *Server) handleReconnect(c *Client, payload json.RawMessage) {
751766 zap .Bool ("is_host" , isHost ))
752767}
753768
769+ // sanitizeString removes potentially dangerous characters and limits length
770+ func sanitizeString (s string , maxLen int ) string {
771+ // Remove null bytes and other control characters
772+ s = strings .Map (func (r rune ) rune {
773+ if r == 0 || (r < 32 && r != '\t' && r != '\n' && r != '\r' ) {
774+ return - 1
775+ }
776+ return r
777+ }, s )
778+
779+ // Trim whitespace
780+ s = strings .TrimSpace (s )
781+
782+ // Validate UTF-8
783+ if ! utf8 .ValidString (s ) {
784+ s = strings .ToValidUTF8 (s , "" )
785+ }
786+
787+ // Limit length
788+ if len (s ) > maxLen {
789+ // Ensure we don't cut in the middle of a multi-byte character
790+ for i := maxLen ; i > 0 && i > maxLen - 4 ; i -- {
791+ if utf8 .ValidString (s [:i ]) {
792+ return s [:i ]
793+ }
794+ }
795+ return s [:maxLen ]
796+ }
797+
798+ return s
799+ }
800+
754801func (s * Server ) handleMessage (c * Client , data []byte ) {
755802 var msg Message
756803 if err := json .Unmarshal (data , & msg ); err != nil {
@@ -785,8 +832,6 @@ func (s *Server) handleMessage(c *Client, data []byte) {
785832 s .handleKickUser (c , msg .Payload )
786833 case MsgTypePing :
787834 c .sendMessage (s .logger , MsgTypePong , nil )
788- case MsgTypeChat :
789- s .handleChat (c , msg .Payload )
790835 case MsgTypeRequestSync :
791836 s .handleRequestSync (c )
792837 case MsgTypeReconnect :
@@ -813,7 +858,19 @@ func (s *Server) handleSuggestTrack(c *Client, payload json.RawMessage) {
813858 c .sendError (s .logger , "not_in_room" , "You are not in a room" )
814859 return
815860 }
816- if p .TrackInfo == nil || p .TrackInfo .ID == "" || p .TrackInfo .Title == "" {
861+
862+ if p .TrackInfo == nil {
863+ c .sendError (s .logger , "missing_track_info" , "Track info is required" )
864+ return
865+ }
866+
867+ // Validate and sanitize track info
868+ p .TrackInfo .ID = sanitizeString (p .TrackInfo .ID , 200 )
869+ p .TrackInfo .Title = sanitizeString (p .TrackInfo .Title , MaxTrackTitleLength )
870+ p .TrackInfo .Artist = sanitizeString (p .TrackInfo .Artist , MaxTrackArtistLength )
871+ p .TrackInfo .Album = sanitizeString (p .TrackInfo .Album , MaxTrackArtistLength )
872+
873+ if p .TrackInfo .ID == "" || p .TrackInfo .Title == "" {
817874 c .sendError (s .logger , "invalid_track_info" , "Track must have ID and title" )
818875 return
819876 }
@@ -888,7 +945,7 @@ func (s *Server) handleApproveSuggestion(c *Client, payload json.RawMessage) {
888945
889946 // Update room state queue: insert next (front of upcoming queue)
890947 if suggestion .Track != nil {
891- if len (room .State .Queue ) >= 1000 {
948+ if len (room .State .Queue ) >= MaxQueueSize {
892949 c .sendError (s .logger , "queue_full" , "Queue is full" )
893950 return
894951 }
@@ -977,9 +1034,10 @@ func (s *Server) handleCreateRoom(c *Client, payload json.RawMessage) {
9771034 return
9781035 }
9791036
980- // Validate username length
981- if len (p .Username ) > 100 {
982- c .sendError (s .logger , "username_too_long" , "Username must be 100 characters or less" )
1037+ // Sanitize and validate username
1038+ p .Username = sanitizeString (p .Username , MaxUsernameLength )
1039+ if p .Username == "" {
1040+ c .sendError (s .logger , "invalid_username" , "Username is invalid" )
9831041 return
9841042 }
9851043
@@ -1054,8 +1112,10 @@ func (s *Server) handleJoinRoom(c *Client, payload json.RawMessage) {
10541112 return
10551113 }
10561114
1057- if len (p .Username ) > 100 {
1058- c .sendError (s .logger , "username_too_long" , "Username must be 100 characters or less" )
1115+ // Sanitize and validate username
1116+ p .Username = sanitizeString (p .Username , MaxUsernameLength )
1117+ if p .Username == "" {
1118+ c .sendError (s .logger , "invalid_username" , "Username is invalid" )
10591119 return
10601120 }
10611121
@@ -1064,6 +1124,13 @@ func (s *Server) handleJoinRoom(c *Client, payload json.RawMessage) {
10641124 return
10651125 }
10661126
1127+ // Sanitize and validate room code
1128+ p .RoomCode = sanitizeString (strings .ToUpper (p .RoomCode ), MaxRoomCodeLength )
1129+ if p .RoomCode == "" {
1130+ c .sendError (s .logger , "invalid_room_code" , "Room code is invalid" )
1131+ return
1132+ }
1133+
10671134 s .mu .RLock ()
10681135 room , exists := s .rooms [p .RoomCode ]
10691136 s .mu .RUnlock ()
@@ -1317,6 +1384,12 @@ func (s *Server) handlePlaybackAction(c *Client, payload json.RawMessage) {
13171384 return
13181385 }
13191386
1387+ // Validate and sanitize track info
1388+ p .TrackInfo .ID = sanitizeString (p .TrackInfo .ID , 200 )
1389+ p .TrackInfo .Title = sanitizeString (p .TrackInfo .Title , MaxTrackTitleLength )
1390+ p .TrackInfo .Artist = sanitizeString (p .TrackInfo .Artist , MaxTrackArtistLength )
1391+ p .TrackInfo .Album = sanitizeString (p .TrackInfo .Album , MaxTrackArtistLength )
1392+
13201393 if p .TrackInfo .ID == "" || p .TrackInfo .Title == "" {
13211394 c .sendError (s .logger , "invalid_track_info" , "Track must have ID and title" )
13221395 return
@@ -1384,13 +1457,19 @@ func (s *Server) handlePlaybackAction(c *Client, payload json.RawMessage) {
13841457 return
13851458 }
13861459
1460+ // Validate and sanitize track info
1461+ p .TrackInfo .ID = sanitizeString (p .TrackInfo .ID , 200 )
1462+ p .TrackInfo .Title = sanitizeString (p .TrackInfo .Title , MaxTrackTitleLength )
1463+ p .TrackInfo .Artist = sanitizeString (p .TrackInfo .Artist , MaxTrackArtistLength )
1464+ p .TrackInfo .Album = sanitizeString (p .TrackInfo .Album , MaxTrackArtistLength )
1465+
13871466 if p .TrackInfo .ID == "" || p .TrackInfo .Title == "" {
13881467 c .sendError (s .logger , "invalid_track_info" , "Track must have ID and title" )
13891468 return
13901469 }
13911470
13921471 // Limit queue size to prevent memory issues
1393- if len (room .State .Queue ) >= 1000 {
1472+ if len (room .State .Queue ) >= MaxQueueSize {
13941473 c .sendError (s .logger , "queue_full" , "Queue is full" )
13951474 return
13961475 }
@@ -1631,50 +1710,6 @@ func (s *Server) handleKickUser(c *Client, payload json.RawMessage) {
16311710 zap .String ("room_code" , room .Code ))
16321711}
16331712
1634- func (s * Server ) handleChat (c * Client , payload json.RawMessage ) {
1635- var p ChatPayload
1636- if err := json .Unmarshal (payload , & p ); err != nil {
1637- c .sendError (s .logger , "invalid_payload" , "Invalid chat payload" )
1638- return
1639- }
1640-
1641- if c .Room == nil {
1642- c .sendError (s .logger , "not_in_room" , "You are not in a room" )
1643- return
1644- }
1645-
1646- if p .Message == "" {
1647- return // Silently ignore empty messages
1648- }
1649-
1650- // Limit message length
1651- const maxMessageLength = 500
1652- if len (p .Message ) > maxMessageLength {
1653- p .Message = p .Message [:maxMessageLength ]
1654- }
1655-
1656- room := c .Room
1657- room .mu .RLock ()
1658- defer room .mu .RUnlock ()
1659-
1660- if len (room .Clients ) == 0 {
1661- return // Room is empty, don't send
1662- }
1663-
1664- chatMsg := ChatMessagePayload {
1665- UserID : c .ID ,
1666- Username : c .Username ,
1667- Message : p .Message ,
1668- Timestamp : time .Now ().UnixMilli (),
1669- }
1670-
1671- for _ , client := range room .Clients {
1672- if client != nil {
1673- client .sendMessage (s .logger , MsgTypeChatMessage , chatMsg )
1674- }
1675- }
1676- }
1677-
16781713func (s * Server ) handleRequestSync (c * Client ) {
16791714 if c .Room == nil {
16801715 c .sendError (s .logger , "not_in_room" , "You are not in a room" )
0 commit comments