Skip to content

Commit 54778ff

Browse files
committed
fix: reject oversized string/list/map sizes during decode
Decoded string/[]byte/list/map lengths that exceed the remaining buffer cannot possibly be valid, so allocating for them risks huge allocations or OOM on corrupted data. Detect these before allocating and return a thrift SIZE_LIMIT protocol exception, distinct from io.ErrShortBuffer. Also guard the size-header reads themselves: a buffer too short to hold even the length header now returns io.ErrShortBuffer instead of panicking with an index-out-of-range error.
1 parent 88e3cae commit 54778ff

3 files changed

Lines changed: 126 additions & 15 deletions

File tree

internal/reflect/decoder.go

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,28 @@ func decodeFixedSizeTypes(t ttype, b []byte, p unsafe.Pointer) int {
161161
}
162162
}
163163

164+
// minWireSize maps a wire type to the minimum number of bytes one value of that
165+
// type takes in Thrift binary encoding. It is used while decoding to reject
166+
// corrupted container/string lengths that cannot possibly fit in the remaining
167+
// buffer, before allocating memory for them.
//
// Wire types that are not listed keep the zero value. The list and map size
// checks in decodeType use these entries as divisors, so an unlisted type
// reaching those checks can produce a division by zero.
// NOTE(review): confirm every WT that reaches those checks is listed here.
168+
var minWireSize = [256]int8{
169+
tBOOL: 1,
170+
tBYTE: 1,
171+
tI16: 2,
172+
tI32: 4,
173+
tI64: 8,
174+
tDOUBLE: 8,
175+
tSTRING: 4, // 4-byte length header only; content may be empty
176+
tSTRUCT: 1, // tSTOP only
177+
tMAP: 6, // header only (key type, value type, 4-byte count); may hold zero entries
178+
tSET: 5, // header only (element type, 4-byte count); may hold zero elements
179+
tLIST: 5, // header only (element type, 4-byte count); may hold zero elements
180+
}
181+
164182
func decodeStringNoCopy(t *tType, b []byte, p unsafe.Pointer) (i int, err error) {
183+
if len(b) < strHeaderLen {
184+
return 0, io.ErrShortBuffer
185+
}
165186
l := int(int32(binary.BigEndian.Uint32(b)))
166187
if l < 0 {
167188
err = errNegativeSize
@@ -177,8 +198,8 @@ func decodeStringNoCopy(t *tType, b []byte, p unsafe.Pointer) (i int, err error)
177198
return
178199
}
179200

180-
if i+l-1 >= len(b) {
181-
return i, io.ErrShortBuffer
201+
if l > len(b)-i {
202+
return i, newSizeExceedsBufferException(l, len(b)-i)
182203
}
183204

184205
if t.Tag == defs.T_binary {
@@ -199,6 +220,9 @@ func (d *tDecoder) decodeType(t *tType, b []byte, p unsafe.Pointer, maxdepth int
199220
}
200221
switch t.T {
201222
case tSTRING:
223+
if len(b) < strHeaderLen {
224+
return 0, io.ErrShortBuffer
225+
}
202226
l := int(int32(binary.BigEndian.Uint32(b)))
203227
if l < 0 {
204228
return 0, errNegativeSize
@@ -213,8 +237,8 @@ func (d *tDecoder) decodeType(t *tType, b []byte, p unsafe.Pointer, maxdepth int
213237
return i, nil
214238
}
215239

216-
if i+l-1 >= len(b) {
217-
return i, io.ErrShortBuffer
240+
if l > len(b)-i {
241+
return i, newSizeExceedsBufferException(l, len(b)-i)
218242
}
219243

220244
x := d.Malloc(l, 1, 0)
@@ -229,6 +253,9 @@ func (d *tDecoder) decodeType(t *tType, b []byte, p unsafe.Pointer, maxdepth int
229253

230254
case tMAP:
231255
// map header
256+
if len(b) < mapHeaderLen {
257+
return 0, io.ErrShortBuffer
258+
}
232259
t0, t1, l := ttype(b[0]), ttype(b[1]), int(int32(binary.BigEndian.Uint32(b[2:])))
233260
if l < 0 {
234261
return 0, errNegativeSize
@@ -241,6 +268,13 @@ func (d *tDecoder) decodeType(t *tType, b []byte, p unsafe.Pointer, maxdepth int
241268
return 0, newTypeMismatchKV(kt.WT, vt.WT, t0, t1)
242269
}
243270

271+
// reject corrupted lengths before allocating the map: every entry needs
272+
// at least minWireSize[key]+minWireSize[value] bytes, so l entries can
273+
// not fit if they exceed the remaining buffer. likely data is broken.
274+
if remain := len(b) - mapHeaderLen; l > remain/(int(minWireSize[kt.WT])+int(minWireSize[vt.WT])) {
275+
return mapHeaderLen, newSizeExceedsBufferException(l, remain)
276+
}
277+
244278
// decode map
245279

246280
// tmp vars
@@ -318,6 +352,9 @@ func (d *tDecoder) decodeType(t *tType, b []byte, p unsafe.Pointer, maxdepth int
318352

319353
case tLIST, tSET: // NOTE: for tSET, it may be map in the future
320354
// list header
355+
if len(b) < listHeaderLen {
356+
return 0, io.ErrShortBuffer
357+
}
321358
tp, l := ttype(b[0]), int(int32(binary.BigEndian.Uint32(b[1:])))
322359
if l < 0 {
323360
return 0, errNegativeSize
@@ -336,6 +373,14 @@ func (d *tDecoder) decodeType(t *tType, b []byte, p unsafe.Pointer, maxdepth int
336373
h.Zero()
337374
return i, nil
338375
}
376+
377+
// reject corrupted lengths before allocating the slice: every element
378+
// needs at least minWireSize[et] bytes, so l elements can not fit if
379+
// they exceed the remaining buffer. likely the data is broken.
380+
if remain := len(b) - i; l > remain/int(minWireSize[et.WT]) {
381+
return i, newSizeExceedsBufferException(l, remain)
382+
}
383+
339384
x := d.Malloc(l*et.Size, et.Align, et.MallocAbiType) // malloc for slice. make([]Type, l, l)
340385
h.Data = x
341386
h.Len = l

internal/reflect/decoder_test.go

Lines changed: 66 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package reflect
1818

1919
import (
2020
"bytes"
21+
"errors"
2122
"io"
2223
"math"
2324
"math/rand"
@@ -456,21 +457,22 @@ func TestDecodeStringShortBuffer(t *testing.T) {
456457
var result string
457458
ptr := unsafe.Pointer(&result)
458459

459-
// io.ErrShortBuffer: decodeType
460-
data := []byte{0x00, 0x00, 0x00, 0x05, 'h', 'e', 'l', 'l'} // length=5, only 4 bytes
461-
n, err := decoder.decodeType(typ, data, ptr, 1)
462-
assert.True(t, err == io.ErrShortBuffer)
463-
assert.True(t, n <= len(data))
460+
// header truncated: fewer than strHeaderLen bytes to even read the length.
461+
// this is a genuine short buffer, not corrupted data.
462+
for _, data := range [][]byte{nil, {0x00}, {0x00, 0x00, 0x00}} {
463+
n, err := decoder.decodeType(typ, data, ptr, 1)
464+
assert.True(t, err == io.ErrShortBuffer)
465+
assert.True(t, n <= len(data))
464466

465-
// io.ErrShortBuffer: decodeStringNoCopy
466-
n, err = decodeStringNoCopy(typ, data, ptr)
467-
assert.True(t, err == io.ErrShortBuffer)
468-
assert.True(t, n <= len(data))
467+
n, err = decodeStringNoCopy(typ, data, ptr)
468+
assert.True(t, err == io.ErrShortBuffer)
469+
assert.True(t, n <= len(data))
470+
}
469471

470472
// Normal case: decodeType
471-
data = []byte{0x00, 0x00, 0x00, 0x05, 'h', 'e', 'l', 'l', 'o'}
473+
data := []byte{0x00, 0x00, 0x00, 0x05, 'h', 'e', 'l', 'l', 'o'}
472474
result = ""
473-
n, err = decoder.decodeType(typ, data, ptr, 1)
475+
n, err := decoder.decodeType(typ, data, ptr, 1)
474476
assert.Nil(t, err)
475477
assert.Equal(t, len(data), n)
476478
assert.Equal(t, "hello", result)
@@ -481,3 +483,56 @@ func TestDecodeStringShortBuffer(t *testing.T) {
481483
assert.Equal(t, len(data), n)
482484
assert.Equal(t, "hello", result)
483485
}
486+
487+
// assertSizeLimit asserts that err is a thrift SIZE_LIMIT protocol exception,
488+
// i.e. a decoded size that exceeds the remaining buffer (likely corrupted data).
//
// It stops the test immediately when err is not (and does not wrap) a
// *thrift.ProtocolException; a wrong exception type ID is reported through
// assert.Equal instead.
489+
func assertSizeLimit(t *testing.T, err error) {
490+
// mark as helper so failures are attributed to the calling test line
t.Helper()
491+
var pe *thrift.ProtocolException
492+
if !errors.As(err, &pe) {
493+
t.Fatalf("expected *thrift.ProtocolException, got %v", err)
494+
}
495+
assert.Equal(t, int32(thrift.SIZE_LIMIT), pe.TypeID())
496+
}
497+
498+
// TestDecodeSizeExceedsBuffer checks that a declared string length, list element
// count, or map entry count that cannot fit in the remaining buffer is reported
// as a thrift SIZE_LIMIT protocol exception (verified via assertSizeLimit).
func TestDecodeSizeExceedsBuffer(t *testing.T) {
499+
// string: declared length exceeds the remaining buffer, checked through both
// decodeType and decodeStringNoCopy (only the T_string tag is exercised here)
500+
decoder := &tDecoder{}
501+
typ := &tType{T: tSTRING, Tag: defs.T_string}
502+
var s string
503+
data := []byte{0x00, 0x00, 0x00, 0x05, 'h', 'e', 'l', 'l'} // length=5, but only 4 payload bytes follow
504+
_, err := decoder.decodeType(typ, data, unsafe.Pointer(&s), 1)
505+
assertSizeLimit(t, err)
506+
_, err = decodeStringNoCopy(typ, data, unsafe.Pointer(&s))
507+
assertSizeLimit(t, err)
508+
509+
// list: element count far exceeds the remaining buffer
510+
{
511+
type Msg struct {
512+
L []int32 `frugal:"1,default,list<i32>"`
513+
}
514+
b := []byte{
515+
byte(tLIST), 0x00, 0x01, // field header: type=LIST, id=1
516+
byte(tI32), // element type
517+
0x7f, 0xff, 0xff, 0xff, // count = 2147483647 (max int32)
518+
byte(tSTOP),
519+
}
520+
_, err := Decode(b, &Msg{})
521+
assertSizeLimit(t, err)
522+
}
523+
524+
// map: entry count far exceeds the remaining buffer
525+
{
526+
type Msg struct {
527+
M map[int32]int32 `frugal:"1,default,map<i32:i32>"`
528+
}
529+
b := []byte{
530+
byte(tMAP), 0x00, 0x01, // field header: type=MAP, id=1
531+
byte(tI32), byte(tI32), // key type, value type
532+
0x7f, 0xff, 0xff, 0xff, // count = 2147483647 (max int32)
533+
byte(tSTOP),
534+
}
535+
_, err := Decode(b, &Msg{})
536+
assertSizeLimit(t, err)
537+
}
538+
}

internal/reflect/exception.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,17 @@ func newRequiredFieldNotSetException(name string) error {
3434
)
3535
}
3636

37+
// newSizeExceedsBufferException builds the thrift SIZE_LIMIT protocol exception
38+
// returned when a decoded length or element count is larger than the remaining
39+
// buffer can possibly hold. Unlike io.ErrShortBuffer (which only means more
40+
// bytes are needed), it signals that the size field itself is bogus, so the data
// is most likely corrupted.
//
// size is the decoded length or count; remain is the number of buffer bytes left
// as computed by the caller. Both are only used to format the message.
// NOTE(review): the list and map callers in decodeType pass an element/entry
// count as size but a byte count as remain, so the message can read e.g.
// "decoded size 3 exceeds remaining buffer 10" - consider clarifying the units.
41+
func newSizeExceedsBufferException(size, remain int) error {
42+
return thrift.NewProtocolException(
43+
thrift.SIZE_LIMIT,
44+
fmt.Sprintf("decoded size %d exceeds remaining buffer %d, data may be corrupted", size, remain),
45+
)
46+
}
47+
3748
func newTypeMismatch(expect, got ttype) error {
3849
return thrift.NewProtocolException(
3950
thrift.INVALID_DATA,

0 commit comments

Comments
 (0)