Skip to content

Commit f30d881

Browse files
committed
fix(reader): 修复读取数据时的边界检查和错误处理
在Reader的多个方法中添加了对缓冲区边界和读取错误的检查,确保在数据不足或读取失败时返回安全值。同时,新增了测试文件reader_test.go,验证了空缓冲区、不完整数据、错误Reader等情况下的行为。
1 parent 25727bd commit f30d881

File tree

2 files changed

+201
-3
lines changed

2 files changed

+201
-3
lines changed

utils/binary/reader.go

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,19 @@ func (r *Reader) ReadAll() []byte {
8181

8282
func (r *Reader) ReadU8() (v uint8) {
8383
if r.reader != nil {
84-
_, _ = r.reader.Read(unsafe.Slice(&v, 1))
84+
buf := make([]byte, 1)
85+
n, err := r.reader.Read(buf)
86+
if err != nil || n < 1 {
87+
// 读取失败或读取的数据不足,返回零值
88+
return 0
89+
}
90+
v = buf[0]
8591
return
8692
}
93+
// 确保缓冲区有足够的数据
94+
if r.pos >= len(r.buffer) {
95+
return 0
96+
}
8797
v = r.buffer[r.pos]
8898
r.pos++
8999
return
@@ -93,8 +103,16 @@ func readint[T ~uint16 | ~uint32 | ~uint64](r *Reader) (v T) {
93103
sz := unsafe.Sizeof(v)
94104
buf := make([]byte, 8)
95105
if r.reader != nil {
96-
_, _ = r.reader.Read(buf[8-sz:])
106+
n, err := r.reader.Read(buf[8-sz:])
107+
if err != nil || n < int(sz) {
108+
// 读取失败或读取的数据不足,返回零值
109+
return 0
110+
}
97111
} else {
112+
// 确保缓冲区有足够的数据
113+
if r.pos+int(sz) > len(r.buffer) {
114+
return 0
115+
}
98116
copy(buf[8-sz:], r.buffer[r.pos:r.pos+int(sz)])
99117
r.pos += int(sz)
100118
}
@@ -129,6 +147,10 @@ func (r *Reader) ReadBytesNoCopy(length int) (v []byte) {
129147
if r.reader != nil {
130148
return r.ReadBytes(length)
131149
}
150+
// 确保缓冲区有足够的数据
151+
if r.pos+length > len(r.buffer) {
152+
return make([]byte, 0)
153+
}
132154
v = r.buffer[r.pos : r.pos+length]
133155
r.pos += length
134156
return
@@ -138,8 +160,16 @@ func (r *Reader) ReadBytes(length int) (v []byte) {
138160
// 返回一个全新的数组罢
139161
v = make([]byte, length)
140162
if r.reader != nil {
141-
_, _ = r.reader.Read(v)
163+
n, err := io.ReadFull(r.reader, v)
164+
if err != nil || n < length {
165+
// 读取失败或读取的数据不足,返回空数组
166+
return make([]byte, 0)
167+
}
142168
} else {
169+
// 确保缓冲区有足够的数据
170+
if r.pos+length > len(r.buffer) {
171+
return make([]byte, 0)
172+
}
143173
copy(v, r.buffer[r.pos:r.pos+length])
144174
r.pos += length
145175
}

utils/binary/reader_test.go

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
package binary
2+
3+
import (
4+
"io"
5+
"testing"
6+
)
7+
8+
// TestReaderEmptyBuffer 测试空缓冲区的情况
9+
func TestReaderEmptyBuffer(t *testing.T) {
10+
// 测试空缓冲区
11+
r := NewReader([]byte{})
12+
13+
// 测试ReadU8
14+
if v := r.ReadU8(); v != 0 {
15+
t.Errorf("ReadU8 with empty buffer should return 0, got %d", v)
16+
}
17+
18+
// 测试ReadU16
19+
if v := r.ReadU16(); v != 0 {
20+
t.Errorf("ReadU16 with empty buffer should return 0, got %d", v)
21+
}
22+
23+
// 测试ReadU32
24+
if v := r.ReadU32(); v != 0 {
25+
t.Errorf("ReadU32 with empty buffer should return 0, got %d", v)
26+
}
27+
28+
// 测试ReadU64
29+
if v := r.ReadU64(); v != 0 {
30+
t.Errorf("ReadU64 with empty buffer should return 0, got %d", v)
31+
}
32+
33+
// 测试ReadBytes
34+
if bytes := r.ReadBytes(10); len(bytes) != 0 {
35+
t.Errorf("ReadBytes with empty buffer should return empty slice, got %v", bytes)
36+
}
37+
38+
// 测试ReadBytesNoCopy
39+
if bytes := r.ReadBytesNoCopy(10); len(bytes) != 0 {
40+
t.Errorf("ReadBytesNoCopy with empty buffer should return empty slice, got %v", bytes)
41+
}
42+
}
43+
44+
// TestReaderIncompleteData 测试不完整数据的情况
45+
func TestReaderIncompleteData(t *testing.T) {
46+
// 测试不完整数据 - 只有1个字节
47+
r := NewReader([]byte{0x01})
48+
49+
// 测试ReadU16 (需要2字节)
50+
if v := r.ReadU16(); v != 0 {
51+
t.Errorf("ReadU16 with incomplete data should return 0, got %d", v)
52+
}
53+
54+
// 重置Reader
55+
r = NewReader([]byte{0x01, 0x02})
56+
57+
// 测试ReadU32 (需要4字节)
58+
if v := r.ReadU32(); v != 0 {
59+
t.Errorf("ReadU32 with incomplete data should return 0, got %d", v)
60+
}
61+
62+
// 测试ReadBytes超出可用长度
63+
r = NewReader([]byte{0x01, 0x02, 0x03})
64+
if bytes := r.ReadBytes(10); len(bytes) != 0 {
65+
t.Errorf("ReadBytes with insufficient data should return empty slice, got %v", bytes)
66+
}
67+
}
68+
69+
// TestReaderWithIOReader 测试使用io.Reader的情况
70+
func TestReaderWithIOReader(t *testing.T) {
71+
// 创建一个会返回错误的Reader
72+
errReader := &errorReader{}
73+
r := ParseReader(errReader)
74+
75+
// 测试ReadU8
76+
if v := r.ReadU8(); v != 0 {
77+
t.Errorf("ReadU8 with error reader should return 0, got %d", v)
78+
}
79+
80+
// 测试ReadU16
81+
if v := r.ReadU16(); v != 0 {
82+
t.Errorf("ReadU16 with error reader should return 0, got %d", v)
83+
}
84+
85+
// 测试ReadU32
86+
if v := r.ReadU32(); v != 0 {
87+
t.Errorf("ReadU32 with error reader should return 0, got %d", v)
88+
}
89+
90+
// 测试ReadBytes
91+
if bytes := r.ReadBytes(10); len(bytes) != 0 {
92+
t.Errorf("ReadBytes with error reader should return empty slice, got %v", bytes)
93+
}
94+
95+
// 测试ReadAll
96+
if data := r.ReadAll(); data != nil {
97+
t.Errorf("ReadAll with error reader should return nil, got %v", data)
98+
}
99+
}
100+
101+
// TestReaderNormalData 测试正常数据的情况
102+
func TestReaderNormalData(t *testing.T) {
103+
// 准备测试数据
104+
data := []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}
105+
r := NewReader(data)
106+
107+
// 测试ReadU8
108+
if v := r.ReadU8(); v != 0x01 {
109+
t.Errorf("ReadU8 should return 0x01, got 0x%02x", v)
110+
}
111+
112+
// 测试ReadU16
113+
if v := r.ReadU16(); v != 0x0203 {
114+
t.Errorf("ReadU16 should return 0x0203, got 0x%04x", v)
115+
}
116+
117+
// 测试ReadU32
118+
if v := r.ReadU32(); v != 0x04050607 {
119+
t.Errorf("ReadU32 should return 0x04050607, got 0x%08x", v)
120+
}
121+
122+
// 测试ReadByte
123+
if b, err := r.ReadByte(); err != nil || b != 0x08 {
124+
t.Errorf("ReadByte should return 0x08, got 0x%02x, err: %v", b, err)
125+
}
126+
127+
// 测试读取完所有数据后的ReadByte
128+
if _, err := r.ReadByte(); err != io.EOF {
129+
t.Errorf("ReadByte after end should return EOF, got %v", err)
130+
}
131+
}
132+
133+
// TestReaderShortRead 测试短读的情况
134+
func TestReaderShortRead(t *testing.T) {
135+
// 创建一个会返回短读的Reader
136+
shortReader := &shortReader{data: []byte{0x01, 0x02, 0x03, 0x04}}
137+
r := ParseReader(shortReader)
138+
139+
// 测试ReadBytes
140+
if bytes := r.ReadBytes(10); len(bytes) != 0 {
141+
t.Errorf("ReadBytes with short reader should return empty slice, got %v", bytes)
142+
}
143+
}
144+
145+
// 辅助测试结构
146+
147+
// errorReader 总是返回错误的Reader
148+
type errorReader struct{}
149+
150+
func (r *errorReader) Read(p []byte) (n int, err error) {
151+
return 0, io.ErrUnexpectedEOF
152+
}
153+
154+
// shortReader 总是返回短读的Reader
155+
type shortReader struct {
156+
data []byte
157+
pos int
158+
}
159+
160+
func (r *shortReader) Read(p []byte) (n int, err error) {
161+
if r.pos >= len(r.data) {
162+
return 0, io.EOF
163+
}
164+
// 只读取一个字节,模拟短读
165+
p[0] = r.data[r.pos]
166+
r.pos++
167+
return 1, nil
168+
}

0 commit comments

Comments
 (0)