Skip to content

Commit b8dfeb3

Browse files
committed
Improve the record layer fuzz tests
1 parent 7b68bd9 commit b8dfeb3

File tree

2 files changed

+253
-45
lines changed

2 files changed

+253
-45
lines changed

pkg/protocol/recordlayer/fuzz_test.go

Lines changed: 0 additions & 36 deletions
This file was deleted.

pkg/protocol/recordlayer/recordlayer_test.go

Lines changed: 253 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44
package recordlayer
55

66
import (
7+
"encoding/binary"
78
"testing"
89

910
"github.com/pion/dtls/v3/pkg/protocol"
1011
"github.com/stretchr/testify/assert"
12+
"github.com/stretchr/testify/require"
1113
)
1214

1315
func TestUDPDecode(t *testing.T) {
@@ -19,20 +21,20 @@ func TestUDPDecode(t *testing.T) {
1921
}{
2022
{
2123
Name: "Change Cipher Spec, single packet",
22-
Data: []byte{0x14, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, 0x00, 0x01, 0x01},
24+
Data: []byte{0x14, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, 0x00, 0x01, 0x01},
2325
Want: [][]byte{
24-
{0x14, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, 0x00, 0x01, 0x01},
26+
{0x14, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, 0x00, 0x01, 0x01},
2527
},
2628
},
2729
{
2830
Name: "Change Cipher Spec, multi packet",
2931
Data: []byte{
30-
0x14, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, 0x00, 0x01, 0x01,
31-
0x14, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x13, 0x00, 0x01, 0x01,
32+
0x14, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, 0x00, 0x01, 0x01,
33+
0x14, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x13, 0x00, 0x01, 0x01,
3234
},
3335
Want: [][]byte{
34-
{0x14, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, 0x00, 0x01, 0x01},
35-
{0x14, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x13, 0x00, 0x01, 0x01},
36+
{0x14, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, 0x00, 0x01, 0x01},
37+
{0x14, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x13, 0x00, 0x01, 0x01},
3638
},
3739
},
3840
{
@@ -42,7 +44,7 @@ func TestUDPDecode(t *testing.T) {
4244
},
4345
{
4446
Name: "Packet declared invalid length",
45-
Data: []byte{0x14, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, 0x00, 0xFF, 0x01},
47+
Data: []byte{0x14, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, 0x00, 0xFF, 0x01},
4648
WantError: ErrInvalidPacketLength,
4749
},
4850
} {
@@ -62,11 +64,11 @@ func TestRecordLayerRoundTrip(t *testing.T) {
6264
}{
6365
{
6466
Name: "Change Cipher Spec, single packet",
65-
Data: []byte{0x14, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, 0x00, 0x01, 0x01},
67+
Data: []byte{0x14, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, 0x00, 0x01, 0x01},
6668
Want: &RecordLayer{
6769
Header: Header{
6870
ContentType: protocol.ContentTypeChangeCipherSpec,
69-
Version: protocol.Version{Major: 0xfe, Minor: 0xff},
71+
Version: protocol.Version1_2,
7072
Epoch: 0,
7173
SequenceNumber: 18,
7274
},
@@ -83,3 +85,245 @@ func TestRecordLayerRoundTrip(t *testing.T) {
8385
assert.Equal(t, test.Data, data, "RecordLayer should match expected value after marshal")
8486
}
8587
}
88+
89+
func FuzzRecordLayer_Unmarshal_No_Panics(f *testing.F) {
90+
f.Add([]byte{
91+
0x14, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, 0x00, 0x01, 0x01,
92+
})
93+
94+
f.Fuzz(func(_ *testing.T, data []byte) {
95+
var r RecordLayer
96+
_ = r.Unmarshal(data)
97+
})
98+
}
99+
100+
func FuzzUnpackDatagram_No_Panics(f *testing.F) {
101+
Datasingle := []byte{
102+
0x14, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, 0x00, 0x01, 0x01,
103+
}
104+
Datamulti := []byte{
105+
0x14, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, 0x00, 0x01, 0x01,
106+
0x14, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x13, 0x00, 0x01, 0x01,
107+
}
108+
f.Add(Datasingle)
109+
f.Add(Datamulti)
110+
111+
f.Fuzz(func(_ *testing.T, data []byte) {
112+
_, _ = UnpackDatagram(data)
113+
})
114+
}
115+
116+
func FuzzRecordLayer_MarshalUnmarshal_RoundTrip(f *testing.F) {
117+
f.Add([]byte{}, uint16(0), uint64(0))
118+
f.Add([]byte{1, 2, 3}, uint16(1), uint64(5))
119+
120+
f.Fuzz(func(t *testing.T, payload []byte, epoch uint16, seq uint64) {
121+
if len(payload) > 1<<14 {
122+
payload = payload[:1<<14]
123+
}
124+
125+
recordLayer := &RecordLayer{
126+
Header: Header{
127+
ContentType: protocol.ContentTypeApplicationData,
128+
Version: protocol.Version1_2,
129+
Epoch: epoch,
130+
SequenceNumber: seq,
131+
},
132+
Content: &protocol.ApplicationData{Data: payload},
133+
}
134+
135+
raw, err := recordLayer.Marshal()
136+
require.NoError(t, err)
137+
138+
var back RecordLayer
139+
require.NoError(t, back.Unmarshal(raw))
140+
141+
require.Equal(t, recordLayer.Header.ContentType, back.Header.ContentType)
142+
require.Equal(t, recordLayer.Header.Version, back.Header.Version)
143+
require.Equal(t, recordLayer.Header.Epoch, back.Header.Epoch)
144+
require.Equal(t, recordLayer.Header.SequenceNumber, back.Header.SequenceNumber)
145+
146+
bodyLen := len(raw) - back.Header.Size()
147+
appData, ok := back.Content.(*protocol.ApplicationData)
148+
require.True(t, ok)
149+
require.Equal(t, bodyLen, len(appData.Data))
150+
151+
require.Equal(t, payload, appData.Data)
152+
153+
raw2, err := back.Marshal()
154+
require.NoError(t, err)
155+
require.Equal(t, raw, raw2)
156+
})
157+
}
158+
159+
func FuzzRecordLayer_UnpackDatagram_RoundTrip(f *testing.F) {
160+
f.Add(uint8(1), []byte("a"), []byte{}, []byte{}, []byte{})
161+
f.Add(uint8(3), []byte("one"), []byte("two"), []byte("three"), []byte(""))
162+
163+
f.Fuzz(func(t *testing.T, n uint8, p1, p2, p3, p4 []byte) {
164+
count := int(n%4) + 1
165+
all := [][]byte{p1, p2, p3, p4}
166+
all = all[:count]
167+
168+
for i := range all {
169+
if len(all[i]) > 1<<14 {
170+
all[i] = all[i][:1<<14]
171+
}
172+
if len(all[i]) == 0 {
173+
all[i] = []byte{0} // ensure a non-empty record
174+
}
175+
}
176+
177+
var dat []byte
178+
want := make([][]byte, 0, count)
179+
for i := 0; i < count; i++ {
180+
rl := &RecordLayer{
181+
Header: Header{
182+
ContentType: protocol.ContentTypeApplicationData,
183+
Version: protocol.Version1_2,
184+
Epoch: uint16(i), //nolint:gosec // G115: i is bounded (<= 4)
185+
SequenceNumber: uint64(1000) + uint64(i), //nolint:gosec // G115: i is bounded (<= 4)
186+
},
187+
Content: &protocol.ApplicationData{Data: all[i]},
188+
}
189+
raw, err := rl.Marshal()
190+
require.NoError(t, err)
191+
dat = append(dat, raw...)
192+
want = append(want, raw)
193+
}
194+
195+
chunks, err := UnpackDatagram(dat)
196+
require.NoError(t, err)
197+
require.Equal(t, len(want), len(chunks))
198+
199+
for i := range chunks {
200+
require.Equal(t, want[i], chunks[i])
201+
202+
require.True(t, len(chunks[i]) >= FixedHeaderSize+1)
203+
ln := int(binary.BigEndian.Uint16(chunks[i][11:]))
204+
require.Equal(t, ln, len(chunks[i])-FixedHeaderSize)
205+
206+
var rl RecordLayer
207+
require.NoError(t, rl.Unmarshal(chunks[i]))
208+
}
209+
210+
if len(dat) >= FixedHeaderSize+2 {
211+
bad := append([]byte{}, dat...)
212+
orig := binary.BigEndian.Uint16(bad[11:])
213+
binary.BigEndian.PutUint16(bad[11:], orig+1)
214+
_, err = UnpackDatagram(bad)
215+
require.ErrorIs(t, err, ErrInvalidPacketLength)
216+
}
217+
218+
if len(dat) > 0 {
219+
_, err = UnpackDatagram(dat[:len(dat)-1])
220+
require.ErrorIs(t, err, ErrInvalidPacketLength)
221+
}
222+
})
223+
}
224+
225+
func FuzzRecordLayer_ContentAwareUnpackDatagram_RoundTrip(f *testing.F) {
226+
f.Add(uint8(5), []byte("hello"), []byte("world"))
227+
f.Add(uint8(0), []byte{}, []byte("x"))
228+
229+
f.Fuzz(func(t *testing.T, cidLen uint8, p1, p2 []byte) {
230+
cl := int(cidLen % 8)
231+
232+
bound := func(b []byte) []byte {
233+
if len(b) > 1<<14 {
234+
b = b[:1<<14]
235+
}
236+
if len(b) == 0 {
237+
b = []byte{0}
238+
}
239+
240+
return b
241+
}
242+
p1, p2 = bound(p1), bound(p2)
243+
244+
cid := make([]byte, cl)
245+
for i := range cid {
246+
cid[i] = byte(i)
247+
}
248+
249+
makeCIDRecord := func(epoch uint16, seq uint64, payload []byte) []byte {
250+
header := make([]byte, FixedHeaderSize-2) // 11 bytes before len
251+
if cl > 0 {
252+
header[0] = byte(protocol.ContentTypeConnectionID)
253+
} else {
254+
header[0] = byte(protocol.ContentTypeApplicationData)
255+
}
256+
257+
header[1], header[2] = protocol.Version1_2.Major, protocol.Version1_2.Minor
258+
binary.BigEndian.PutUint16(header[3:], epoch)
259+
260+
// 48-bit sequence number
261+
seq48 := seq & 0x0000ffffffffffff
262+
header[5] = byte((seq48 >> 40) & 0xff)
263+
header[6] = byte((seq48 >> 32) & 0xff)
264+
header[7] = byte((seq48 >> 24) & 0xff)
265+
header[8] = byte((seq48 >> 16) & 0xff)
266+
header[9] = byte((seq48 >> 8) & 0xff)
267+
header[10] = byte(seq48 & 0xff)
268+
269+
out := make([]byte, 0, len(header)+cl+2+len(payload))
270+
out = append(out, header...)
271+
if cl > 0 {
272+
out = append(out, cid...)
273+
}
274+
275+
//nolint:gosec // G115: payload <= 1<<14
276+
binary.BigEndian.PutUint16(out[len(out):len(out)+2], uint16(len(payload)))
277+
out = out[:len(out)+2]
278+
out = append(out, payload...)
279+
280+
return out
281+
}
282+
283+
raw1 := makeCIDRecord(0, 10, p1)
284+
raw2 := makeCIDRecord(1, 11, p2)
285+
data := append(append([]byte{}, raw1...), raw2...)
286+
287+
parts, err := ContentAwareUnpackDatagram(data, cl)
288+
require.NoError(t, err)
289+
require.Equal(t, 2, len(parts))
290+
require.Equal(t, raw1, parts[0])
291+
require.Equal(t, raw2, parts[1])
292+
293+
// Validate length field and header size per record.
294+
for _, part := range parts {
295+
hdrExtra := 0
296+
if protocol.ContentType(part[0]) == protocol.ContentTypeConnectionID {
297+
hdrExtra = cl
298+
}
299+
300+
require.GreaterOrEqual(t, len(part), FixedHeaderSize+hdrExtra)
301+
302+
lenIdx := fixedHeaderLenIdx + hdrExtra
303+
require.GreaterOrEqual(t, len(part), lenIdx+2)
304+
305+
decl := int(binary.BigEndian.Uint16(part[lenIdx:]))
306+
require.Equal(t, decl, len(part)-(FixedHeaderSize+hdrExtra))
307+
}
308+
309+
// Negative: corrupt the first record's length.
310+
{
311+
bad := append([]byte{}, data...)
312+
hdrExtra := 0
313+
if protocol.ContentType(bad[0]) == protocol.ContentTypeConnectionID {
314+
hdrExtra = cl
315+
}
316+
lenIdx := fixedHeaderLenIdx + hdrExtra
317+
orig := binary.BigEndian.Uint16(bad[lenIdx:])
318+
binary.BigEndian.PutUint16(bad[lenIdx:], orig+1)
319+
_, err = ContentAwareUnpackDatagram(bad, cl)
320+
require.ErrorIs(t, err, ErrInvalidPacketLength)
321+
}
322+
323+
// Negative: truncate the datagram.
324+
if len(data) > 0 {
325+
_, err = ContentAwareUnpackDatagram(data[:len(data)-1], cl)
326+
require.ErrorIs(t, err, ErrInvalidPacketLength)
327+
}
328+
})
329+
}

0 commit comments

Comments
 (0)