Skip to content

Commit 518fc2c

Browse files
committed
refactor: enhance security and validation for user inputs; remove useless functions
1 parent 0aa6839 commit 518fc2c

1 file changed

Lines changed: 120 additions & 85 deletions

File tree

main.go

Lines changed: 120 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
package main
22

33
import (
4+
"crypto/rand"
5+
"encoding/hex"
46
"encoding/json"
57
"fmt"
6-
"math/rand"
8+
mathrand "math/rand"
79
"net/http"
810
"os"
11+
"strings"
912
"sync"
1013
"time"
14+
"unicode/utf8"
1115

1216
"github.com/gorilla/websocket"
1317
"go.uber.org/zap"
@@ -25,7 +29,6 @@ const (
2529
MsgTypeBufferReady = "buffer_ready"
2630
MsgTypeKickUser = "kick_user"
2731
MsgTypePing = "ping"
28-
MsgTypeChat = "chat"
2932
MsgTypeRequestSync = "request_sync"
3033
MsgTypeReconnect = "reconnect"
3134
MsgTypeSuggestTrack = "suggest_track"
@@ -45,7 +48,6 @@ const (
4548
MsgTypeError = "error"
4649
MsgTypePong = "pong"
4750
MsgTypeRoomState = "room_state"
48-
MsgTypeChatMessage = "chat_message"
4951
MsgTypeHostChanged = "host_changed"
5052
MsgTypeKicked = "kicked"
5153
MsgTypeSyncState = "sync_state"
@@ -228,19 +230,6 @@ type UserInfo struct {
228230
IsConnected bool `json:"is_connected"`
229231
}
230232

231-
// ChatPayload is for chat messages
232-
type ChatPayload struct {
233-
Message string `json:"message"`
234-
}
235-
236-
// ChatMessagePayload is sent to all users in a room
237-
type ChatMessagePayload struct {
238-
UserID string `json:"user_id"`
239-
Username string `json:"username"`
240-
Message string `json:"message"`
241-
Timestamp int64 `json:"timestamp"`
242-
}
243-
244233
// KickUserPayload is for kicking a user from the room
245234
type KickUserPayload struct {
246235
UserID string `json:"user_id"`
@@ -300,6 +289,12 @@ type Session struct {
300289
DisconnectAt time.Time
301290
}
302291

292+
// RateLimiter tracks message rates per client
293+
type RateLimiter struct {
294+
messages []time.Time
295+
mu sync.Mutex
296+
}
297+
303298
// Client represents a connected WebSocket client
304299
type Client struct {
305300
ID string
@@ -310,6 +305,7 @@ type Client struct {
310305
Send chan []byte
311306
closed bool
312307
mu sync.Mutex
308+
rateLimiter *RateLimiter
313309
}
314310

315311
// Room represents a listening room
@@ -343,14 +339,29 @@ type Server struct {
343339
upgrader websocket.Upgrader
344340
mu sync.RWMutex
345341
logger *zap.Logger
346-
rng *rand.Rand
342+
rng *mathrand.Rand
347343
}
348344

349345
const (
350346
// Grace period for reconnection (5 minutes)
351347
ReconnectGracePeriod = 5 * time.Minute
352348
// How often to clean up expired sessions
353349
SessionCleanupInterval = 1 * time.Minute
350+
// Security limits
351+
MaxUsernameLength = 50
352+
MaxRoomCodeLength = 10
353+
MaxMessageLength = 500
354+
MaxTrackTitleLength = 200
355+
MaxTrackArtistLength = 200
356+
MaxQueueSize = 1000
357+
// Rate limiting
358+
RateLimitWindow = time.Minute
359+
MaxMessagesPerWindow = 100
360+
// Connection limits
361+
MaxReadMessageSize = 65536
362+
WriteTimeout = 10 * time.Second
363+
ReadTimeout = 60 * time.Second
364+
PongTimeout = 10 * time.Second
354365
)
355366

356367
func NewServer(logger *zap.Logger) *Server {
@@ -366,10 +377,10 @@ func NewServer(logger *zap.Logger) *Server {
366377
WriteBufferSize: 1024,
367378
},
368379
logger: logger,
369-
rng: rand.New(rand.NewSource(time.Now().UnixNano())),
380+
rng: mathrand.New(mathrand.NewSource(time.Now().UnixNano())),
370381
}
371382

372-
// Start session cleanup goroutine
383+
// Start cleanup goroutines
373384
go s.cleanupExpiredSessions()
374385

375386
return s
@@ -428,7 +439,7 @@ func (s *Server) cleanupExpiredSessions() {
428439
}
429440

430441
func (s *Server) generateRoomCode() string {
431-
const chars = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789" // Removed confusing chars
442+
const chars = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789"
432443
code := make([]byte, 6)
433444
for i := range code {
434445
code[i] = chars[s.rng.Intn(len(chars))]
@@ -441,25 +452,29 @@ func (s *Server) generateUserID() string {
441452
}
442453

443454
func (s *Server) generateSessionToken() string {
444-
const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
445-
token := make([]byte, 32)
446-
for i := range token {
447-
token[i] = chars[s.rng.Intn(len(chars))]
455+
// Use crypto/rand for secure token generation
456+
b := make([]byte, 32)
457+
if _, err := rand.Read(b); err != nil {
458+
s.logger.Error("Failed to generate secure token", zap.Error(err))
459+
// Fallback to less secure but functional token
460+
return fmt.Sprintf("token_%d_%d", time.Now().UnixNano(), s.rng.Intn(1000000))
448461
}
449-
return string(token)
462+
return hex.EncodeToString(b)
450463
}
451464

452465
func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) {
453466
conn, err := s.upgrader.Upgrade(w, r, nil)
454467
if err != nil {
455468
s.logger.Warn("WebSocket upgrade error", zap.Error(err))
469+
s.mu.Unlock()
456470
return
457471
}
458472

459473
client := &Client{
460-
ID: s.generateUserID(),
461-
Conn: conn,
462-
Send: make(chan []byte, 256),
474+
ID: s.generateUserID(),
475+
Conn: conn,
476+
Send: make(chan []byte, 256),
477+
rateLimiter: &RateLimiter{messages: make([]time.Time, 0)},
463478
}
464479

465480
s.mu.Lock()
@@ -508,10 +523,10 @@ func (c *Client) readPump(s *Server) {
508523
c.Conn.Close()
509524
}()
510525

511-
c.Conn.SetReadLimit(65536)
512-
c.Conn.SetReadDeadline(time.Now().Add(60 * time.Second))
526+
c.Conn.SetReadLimit(MaxReadMessageSize)
527+
c.Conn.SetReadDeadline(time.Now().Add(ReadTimeout))
513528
c.Conn.SetPongHandler(func(string) error {
514-
c.Conn.SetReadDeadline(time.Now().Add(60 * time.Second))
529+
c.Conn.SetReadDeadline(time.Now().Add(ReadTimeout))
515530
return nil
516531
})
517532

@@ -751,6 +766,38 @@ func (s *Server) handleReconnect(c *Client, payload json.RawMessage) {
751766
zap.Bool("is_host", isHost))
752767
}
753768

769+
// sanitizeString removes potentially dangerous characters and limits length
770+
func sanitizeString(s string, maxLen int) string {
771+
// Remove null bytes and other control characters
772+
s = strings.Map(func(r rune) rune {
773+
if r == 0 || (r < 32 && r != '\t' && r != '\n' && r != '\r') {
774+
return -1
775+
}
776+
return r
777+
}, s)
778+
779+
// Trim whitespace
780+
s = strings.TrimSpace(s)
781+
782+
// Validate UTF-8
783+
if !utf8.ValidString(s) {
784+
s = strings.ToValidUTF8(s, "")
785+
}
786+
787+
// Limit length
788+
if len(s) > maxLen {
789+
// Ensure we don't cut in the middle of a multi-byte character
790+
for i := maxLen; i > 0 && i > maxLen-4; i-- {
791+
if utf8.ValidString(s[:i]) {
792+
return s[:i]
793+
}
794+
}
795+
return s[:maxLen]
796+
}
797+
798+
return s
799+
}
800+
754801
func (s *Server) handleMessage(c *Client, data []byte) {
755802
var msg Message
756803
if err := json.Unmarshal(data, &msg); err != nil {
@@ -785,8 +832,6 @@ func (s *Server) handleMessage(c *Client, data []byte) {
785832
s.handleKickUser(c, msg.Payload)
786833
case MsgTypePing:
787834
c.sendMessage(s.logger, MsgTypePong, nil)
788-
case MsgTypeChat:
789-
s.handleChat(c, msg.Payload)
790835
case MsgTypeRequestSync:
791836
s.handleRequestSync(c)
792837
case MsgTypeReconnect:
@@ -813,7 +858,19 @@ func (s *Server) handleSuggestTrack(c *Client, payload json.RawMessage) {
813858
c.sendError(s.logger, "not_in_room", "You are not in a room")
814859
return
815860
}
816-
if p.TrackInfo == nil || p.TrackInfo.ID == "" || p.TrackInfo.Title == "" {
861+
862+
if p.TrackInfo == nil {
863+
c.sendError(s.logger, "missing_track_info", "Track info is required")
864+
return
865+
}
866+
867+
// Validate and sanitize track info
868+
p.TrackInfo.ID = sanitizeString(p.TrackInfo.ID, 200)
869+
p.TrackInfo.Title = sanitizeString(p.TrackInfo.Title, MaxTrackTitleLength)
870+
p.TrackInfo.Artist = sanitizeString(p.TrackInfo.Artist, MaxTrackArtistLength)
871+
p.TrackInfo.Album = sanitizeString(p.TrackInfo.Album, MaxTrackArtistLength)
872+
873+
if p.TrackInfo.ID == "" || p.TrackInfo.Title == "" {
817874
c.sendError(s.logger, "invalid_track_info", "Track must have ID and title")
818875
return
819876
}
@@ -888,7 +945,7 @@ func (s *Server) handleApproveSuggestion(c *Client, payload json.RawMessage) {
888945

889946
// Update room state queue: insert next (front of upcoming queue)
890947
if suggestion.Track != nil {
891-
if len(room.State.Queue) >= 1000 {
948+
if len(room.State.Queue) >= MaxQueueSize {
892949
c.sendError(s.logger, "queue_full", "Queue is full")
893950
return
894951
}
@@ -977,9 +1034,10 @@ func (s *Server) handleCreateRoom(c *Client, payload json.RawMessage) {
9771034
return
9781035
}
9791036

980-
// Validate username length
981-
if len(p.Username) > 100 {
982-
c.sendError(s.logger, "username_too_long", "Username must be 100 characters or less")
1037+
// Sanitize and validate username
1038+
p.Username = sanitizeString(p.Username, MaxUsernameLength)
1039+
if p.Username == "" {
1040+
c.sendError(s.logger, "invalid_username", "Username is invalid")
9831041
return
9841042
}
9851043

@@ -1054,8 +1112,10 @@ func (s *Server) handleJoinRoom(c *Client, payload json.RawMessage) {
10541112
return
10551113
}
10561114

1057-
if len(p.Username) > 100 {
1058-
c.sendError(s.logger, "username_too_long", "Username must be 100 characters or less")
1115+
// Sanitize and validate username
1116+
p.Username = sanitizeString(p.Username, MaxUsernameLength)
1117+
if p.Username == "" {
1118+
c.sendError(s.logger, "invalid_username", "Username is invalid")
10591119
return
10601120
}
10611121

@@ -1064,6 +1124,13 @@ func (s *Server) handleJoinRoom(c *Client, payload json.RawMessage) {
10641124
return
10651125
}
10661126

1127+
// Sanitize and validate room code
1128+
p.RoomCode = sanitizeString(strings.ToUpper(p.RoomCode), MaxRoomCodeLength)
1129+
if p.RoomCode == "" {
1130+
c.sendError(s.logger, "invalid_room_code", "Room code is invalid")
1131+
return
1132+
}
1133+
10671134
s.mu.RLock()
10681135
room, exists := s.rooms[p.RoomCode]
10691136
s.mu.RUnlock()
@@ -1317,6 +1384,12 @@ func (s *Server) handlePlaybackAction(c *Client, payload json.RawMessage) {
13171384
return
13181385
}
13191386

1387+
// Validate and sanitize track info
1388+
p.TrackInfo.ID = sanitizeString(p.TrackInfo.ID, 200)
1389+
p.TrackInfo.Title = sanitizeString(p.TrackInfo.Title, MaxTrackTitleLength)
1390+
p.TrackInfo.Artist = sanitizeString(p.TrackInfo.Artist, MaxTrackArtistLength)
1391+
p.TrackInfo.Album = sanitizeString(p.TrackInfo.Album, MaxTrackArtistLength)
1392+
13201393
if p.TrackInfo.ID == "" || p.TrackInfo.Title == "" {
13211394
c.sendError(s.logger, "invalid_track_info", "Track must have ID and title")
13221395
return
@@ -1384,13 +1457,19 @@ func (s *Server) handlePlaybackAction(c *Client, payload json.RawMessage) {
13841457
return
13851458
}
13861459

1460+
// Validate and sanitize track info
1461+
p.TrackInfo.ID = sanitizeString(p.TrackInfo.ID, 200)
1462+
p.TrackInfo.Title = sanitizeString(p.TrackInfo.Title, MaxTrackTitleLength)
1463+
p.TrackInfo.Artist = sanitizeString(p.TrackInfo.Artist, MaxTrackArtistLength)
1464+
p.TrackInfo.Album = sanitizeString(p.TrackInfo.Album, MaxTrackArtistLength)
1465+
13871466
if p.TrackInfo.ID == "" || p.TrackInfo.Title == "" {
13881467
c.sendError(s.logger, "invalid_track_info", "Track must have ID and title")
13891468
return
13901469
}
13911470

13921471
// Limit queue size to prevent memory issues
1393-
if len(room.State.Queue) >= 1000 {
1472+
if len(room.State.Queue) >= MaxQueueSize {
13941473
c.sendError(s.logger, "queue_full", "Queue is full")
13951474
return
13961475
}
@@ -1631,50 +1710,6 @@ func (s *Server) handleKickUser(c *Client, payload json.RawMessage) {
16311710
zap.String("room_code", room.Code))
16321711
}
16331712

1634-
func (s *Server) handleChat(c *Client, payload json.RawMessage) {
1635-
var p ChatPayload
1636-
if err := json.Unmarshal(payload, &p); err != nil {
1637-
c.sendError(s.logger, "invalid_payload", "Invalid chat payload")
1638-
return
1639-
}
1640-
1641-
if c.Room == nil {
1642-
c.sendError(s.logger, "not_in_room", "You are not in a room")
1643-
return
1644-
}
1645-
1646-
if p.Message == "" {
1647-
return // Silently ignore empty messages
1648-
}
1649-
1650-
// Limit message length
1651-
const maxMessageLength = 500
1652-
if len(p.Message) > maxMessageLength {
1653-
p.Message = p.Message[:maxMessageLength]
1654-
}
1655-
1656-
room := c.Room
1657-
room.mu.RLock()
1658-
defer room.mu.RUnlock()
1659-
1660-
if len(room.Clients) == 0 {
1661-
return // Room is empty, don't send
1662-
}
1663-
1664-
chatMsg := ChatMessagePayload{
1665-
UserID: c.ID,
1666-
Username: c.Username,
1667-
Message: p.Message,
1668-
Timestamp: time.Now().UnixMilli(),
1669-
}
1670-
1671-
for _, client := range room.Clients {
1672-
if client != nil {
1673-
client.sendMessage(s.logger, MsgTypeChatMessage, chatMsg)
1674-
}
1675-
}
1676-
}
1677-
16781713
func (s *Server) handleRequestSync(c *Client) {
16791714
if c.Room == nil {
16801715
c.sendError(s.logger, "not_in_room", "You are not in a room")

0 commit comments

Comments
 (0)