Skip to content

enh: add disableClientMask option for WebSocket payload masking and optimize mask calculation #985

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 33 additions & 5 deletions conn.go
Original file line number Diff line number Diff line change
@@ -252,6 +252,8 @@ type Conn struct {
writer io.WriteCloser // the current writer returned to the application
isWriting bool // for best-effort concurrent write detection

disableClientMask bool

writeErrMu sync.Mutex
writeErr error

@@ -315,6 +317,7 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int,
writeBufSize: writeBufferSize,
enableWriteCompression: true,
compressionLevel: defaultCompressionLevel,
disableClientMask: false,
}
c.SetCloseHandler(nil)
c.SetPingHandler(nil)
@@ -432,7 +435,12 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
if c.isServer {
buf = append(buf, data...)
} else {
key := newMaskKey()
var key [4]byte
if c.disableClientMask {
key = [4]byte{0, 0, 0, 0}
} else {
key = newMaskKey()
}
buf = append(buf, key[:]...)
buf = append(buf, data...)
maskBytes(key, 0, buf[6:])
@@ -610,7 +618,12 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error {
}

if !c.isServer {
key := newMaskKey()
var key [4]byte
if c.disableClientMask {
key = [4]byte{0, 0, 0, 0}
} else {
key = newMaskKey()
}
copy(c.writeBuf[maxFrameHeaderSize-4:], key[:])
maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:w.pos])
if len(extra) > 0 {
@@ -743,9 +756,10 @@ func (w *messageWriter) Close() error {
// WritePreparedMessage writes prepared message into connection.
func (c *Conn) WritePreparedMessage(pm *PreparedMessage) error {
frameType, frameData, err := pm.frame(prepareKey{
isServer: c.isServer,
compress: c.newCompressionWriter != nil && c.enableWriteCompression && isData(pm.messageType),
compressionLevel: c.compressionLevel,
isServer: c.isServer,
compress: c.newCompressionWriter != nil && c.enableWriteCompression && isData(pm.messageType),
compressionLevel: c.compressionLevel,
disableClientMask: c.disableClientMask,
})
if err != nil {
return err
@@ -1230,6 +1244,20 @@ func (c *Conn) SetCompressionLevel(level int) error {
return nil
}

// SetDisableClientMask configures WebSocket payload masking behavior for client-mode frames.
// When enabled (true), implements protocol-allowed optimization
// by generating zero-value mask keys ([4]byte{0,0,0,0}), effectively omitting XOR operations
// while maintaining formal protocol compliance.
//
// Security Advisory:
Copy link

@ghost ghost Jun 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This description of the security considerations is wrong. Secure transport does not obviate the need for masking.

Refer readers to the RFC sections 5.3 and 10.3 instead of what's written in the PR.

// - Safe to enable ONLY when using secure transport layers (TLS 1.2+/SSL)
// - May expose vulnerabilities to network intermediaries when unprotected
//
// Default: false (masking enabled) - Maintains protocol compliance for plaintext connections
func (c *Conn) SetDisableClientMask(value bool) {
c.disableClientMask = value
}

// FormatCloseMessage formats closeCode and text as a WebSocket close message.
// An empty message is returned for code CloseNoStatusReceived.
func FormatCloseMessage(closeCode int, text string) []byte {
24 changes: 19 additions & 5 deletions mask.go
Original file line number Diff line number Diff line change
@@ -13,14 +13,17 @@ const wordSize = int(unsafe.Sizeof(uintptr(0)))

func maskBytes(key [4]byte, pos int, b []byte) int {
// Mask one byte at a time for small buffers.
if len(b) < 2*wordSize {
if len(b) <= 2*wordSize {
for i := range b {
b[i] ^= key[pos&3]
pos++
}
return pos & 3
}

if key == [4]byte{} {
return (pos + len(b)) & 3
}
// Mask one byte at a time to word boundary.
if n := int(uintptr(unsafe.Pointer(&b[0]))) % wordSize; n != 0 {
n = wordSize - n
@@ -32,16 +35,27 @@ func maskBytes(key [4]byte, pos int, b []byte) int {
}

// Create aligned word size key.
var kw uintptr
var k [wordSize]byte
for i := range k {
k[i] = key[(pos+i)&3]
if wordSize == 8 {
k[0] = key[(pos+0)&3]
k[1] = key[(pos+1)&3]
k[2] = key[(pos+2)&3]
k[3] = key[(pos+3)&3]
kw = *(*uintptr)(unsafe.Pointer(&k))
kw = (kw << 32) | kw
} else {
for i := range k {
k[i] = key[(pos+i)&3]
}
kw = *(*uintptr)(unsafe.Pointer(&k))
}
kw := *(*uintptr)(unsafe.Pointer(&k))

// Mask one word at a time.
n := (len(b) / wordSize) * wordSize
p0 := unsafe.Pointer(&b[0])
for i := 0; i < n; i += wordSize {
*(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&b[0])) + uintptr(i))) ^= kw
*(*uintptr)(unsafe.Pointer(uintptr(p0) + uintptr(i))) ^= kw
}

// Mask one byte at a time for remaining bytes.
81 changes: 80 additions & 1 deletion mask_test.go
Original file line number Diff line number Diff line change
@@ -7,8 +7,11 @@
package websocket

import (
"bytes"
"fmt"
"math/rand"
"testing"
"unsafe"
)

func maskBytesByByte(key [4]byte, pos int, b []byte) int {
@@ -28,6 +31,49 @@ func notzero(b []byte) int {
return -1
}

func maskBytesV1(key [4]byte, pos int, b []byte) int {
// Mask one byte at a time for small buffers.
if len(b) < 2*wordSize {
for i := range b {
b[i] ^= key[pos&3]
pos++
}
return pos & 3
}

// Mask one byte at a time to word boundary.
if n := int(uintptr(unsafe.Pointer(&b[0]))) % wordSize; n != 0 {
n = wordSize - n
for i := range b[:n] {
b[i] ^= key[pos&3]
pos++
}
b = b[n:]
}

// Create aligned word size key.
var k [wordSize]byte
for i := range k {
k[i] = key[(pos+i)&3]
}
kw := *(*uintptr)(unsafe.Pointer(&k))

// Mask one word at a time.
n := (len(b) / wordSize) * wordSize
for i := 0; i < n; i += wordSize {
*(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&b[0])) + uintptr(i))) ^= kw
}

// Mask one byte at a time for remaining bytes.
b = b[n:]
for i := range b {
b[i] ^= key[pos&3]
pos++
}

return pos & 3
}

func TestMaskBytes(t *testing.T) {
key := [4]byte{1, 2, 3, 4}
for size := 1; size <= 1024; size++ {
@@ -44,8 +90,39 @@ func TestMaskBytes(t *testing.T) {
}
}

func TestMaskBytesWithRandomMessage(t *testing.T) {
keys := [][4]byte{
{1, 2, 3, 4},
{0, 0, 0, 0},
}
for _, key := range keys {
for size := 1; size <= 1024; size++ {
for align := 0; align < wordSize; align++ {
for pos := 0; pos < 4; pos++ {
byteMessage := make([]byte, size+align)[align:]
for i := 0; i < len(byteMessage); i++ {
byteMessage[i] = uint8(rand.Uint32())
}
byteMessageCopy := make([]byte, len(byteMessage))
copy(byteMessageCopy, byteMessage)
posBytes := maskBytes(key, pos, byteMessage)
posBytesByByte := maskBytesByByte(key, pos, byteMessageCopy)
if posBytes != posBytesByByte {
t.Errorf("keys:%v, size:%d, align:%d, pos:%d", key, size, align, pos)
return
}
if !bytes.Equal(byteMessage, byteMessageCopy) {
t.Errorf("keys:%v, size:%d, align:%d, pos:%d", key, size, align, pos)
return
}
}
}
}
}
}

func BenchmarkMaskBytes(b *testing.B) {
for _, size := range []int{2, 4, 8, 16, 32, 512, 1024} {
for _, size := range []int{2, 4, 8, 16, 32, 512, 1024, 1048576} {
b.Run(fmt.Sprintf("size-%d", size), func(b *testing.B) {
for _, align := range []int{wordSize / 2} {
b.Run(fmt.Sprintf("align-%d", align), func(b *testing.B) {
@@ -54,11 +131,13 @@ func BenchmarkMaskBytes(b *testing.B) {
fn func(key [4]byte, pos int, b []byte) int
}{
{"byte", maskBytesByByte},
{"wordV1", maskBytesV1},
{"word", maskBytes},
} {
b.Run(fn.name, func(b *testing.B) {
key := newMaskKey()
data := make([]byte, size+align)[align:]
b.ResetTimer()
for i := 0; i < b.N; i++ {
fn.fn(key, 0, data)
}
8 changes: 5 additions & 3 deletions prepared.go
Original file line number Diff line number Diff line change
@@ -25,9 +25,10 @@ type PreparedMessage struct {

// prepareKey defines a unique set of options to cache prepared frames in PreparedMessage.
type prepareKey struct {
isServer bool
compress bool
compressionLevel int
isServer bool
compress bool
compressionLevel int
disableClientMask bool
}

// preparedFrame contains data in wire representation.
@@ -83,6 +84,7 @@ func (pm *PreparedMessage) frame(key prepareKey) (int, []byte, error) {
compressionLevel: key.compressionLevel,
enableWriteCompression: true,
writeBuf: make([]byte, defaultWriteBufferSize+maxFrameHeaderSize),
disableClientMask: key.disableClientMask,
}
if key.compress {
c.newCompressionWriter = compressNoContextTakeover