Skip to content

Commit b671451

Browse files
committed
Improve read waiter interface
1 parent ab3e469 commit b671451

File tree

2 files changed

+52
-24
lines changed

2 files changed

+52
-24
lines changed

common/bufio/copy_direct_posix.go

+48-22
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"errors"
77
"io"
88
"net/netip"
9+
"os"
910
"syscall"
1011

1112
"github.com/sagernet/sing/common/buf"
@@ -25,10 +26,11 @@ func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter,
2526
bufferSize = buf.BufferSize
2627
}
2728
var (
28-
buffer *buf.Buffer
29-
readBuffer *buf.Buffer
29+
buffer *buf.Buffer
30+
readBuffer *buf.Buffer
31+
notFirstTime bool
3032
)
31-
newBuffer := func() *buf.Buffer {
33+
source.InitializeReadWaiter(func() *buf.Buffer {
3234
if buffer != nil {
3335
buffer.Release()
3436
}
@@ -37,10 +39,10 @@ func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter,
3739
readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
3840
readBuffer.Resize(frontHeadroom, 0)
3941
return readBuffer
40-
}
41-
var notFirstTime bool
42+
})
43+
defer source.InitializeReadWaiter(nil)
4244
for {
43-
err = source.WaitReadBuffer(newBuffer)
45+
err = source.WaitReadBuffer()
4446
if err != nil {
4547
buffer.Release()
4648
if errors.Is(err, io.EOF) {
@@ -55,10 +57,8 @@ func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter,
5557
dataLen := readBuffer.Len()
5658
buffer.Resize(readBuffer.Start(), dataLen)
5759
err = destination.WriteBuffer(buffer)
60+
buffer.Release()
5861
if err != nil {
59-
if buffer != nil {
60-
buffer.Release()
61-
}
6262
return
6363
}
6464
n += int64(dataLen)
@@ -83,10 +83,12 @@ func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadW
8383
bufferSize = buf.UDPBufferSize
8484
}
8585
var (
86-
buffer *buf.Buffer
87-
readBuffer *buf.Buffer
86+
buffer *buf.Buffer
87+
readBuffer *buf.Buffer
88+
destination M.Socksaddr
89+
notFirstTime bool
8890
)
89-
newBuffer := func() *buf.Buffer {
91+
source.InitializeReadWaiter(func() *buf.Buffer {
9092
if buffer != nil {
9193
buffer.Release()
9294
}
@@ -95,11 +97,10 @@ func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadW
9597
readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
9698
readBuffer.Resize(frontHeadroom, 0)
9799
return readBuffer
98-
}
99-
var destination M.Socksaddr
100-
var notFirstTime bool
100+
})
101+
defer source.InitializeReadWaiter(nil)
101102
for {
102-
destination, err = source.WaitReadPacket(newBuffer)
103+
destination, err = source.WaitReadPacket()
103104
if err != nil {
104105
buffer.Release()
105106
if !notFirstTime {
@@ -113,9 +114,8 @@ func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadW
113114
if err != nil {
114115
buffer.Release()
115116
return
116-
} else {
117-
buffer = nil
118117
}
118+
buffer = nil
119119
n += int64(dataLen)
120120
for _, counter := range readCounters {
121121
counter(int64(dataLen))
@@ -127,6 +127,8 @@ func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadW
127127
}
128128
}
129129

130+
var _ N.ReadWaiter = (*syscallReadWaiter)(nil)
131+
130132
type syscallReadWaiter struct {
131133
rawConn syscall.RawConn
132134
readErr error
@@ -143,8 +145,11 @@ func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) {
143145
return nil, false
144146
}
145147

146-
func (w *syscallReadWaiter) WaitReadBuffer(newBuffer func() *buf.Buffer) error {
147-
if w.readFunc == nil {
148+
func (w *syscallReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
149+
w.readErr = nil
150+
if newBuffer == nil {
151+
w.readFunc = nil
152+
} else {
148153
w.readFunc = func(fd uintptr) (done bool) {
149154
buffer := newBuffer()
150155
var readN int
@@ -164,16 +169,27 @@ func (w *syscallReadWaiter) WaitReadBuffer(newBuffer func() *buf.Buffer) error {
164169
return true
165170
}
166171
}
172+
}
173+
174+
func (w *syscallReadWaiter) WaitReadBuffer() error {
175+
if w.readFunc == nil {
176+
return os.ErrInvalid
177+
}
167178
err := w.rawConn.Read(w.readFunc)
168179
if err != nil {
169180
return err
170181
}
171182
if w.readErr != nil {
183+
if w.readErr == io.EOF {
184+
return io.EOF
185+
}
172186
return E.Cause(w.readErr, "raw read")
173187
}
174188
return nil
175189
}
176190

191+
var _ N.PacketReadWaiter = (*syscallPacketReadWaiter)(nil)
192+
177193
type syscallPacketReadWaiter struct {
178194
rawConn syscall.RawConn
179195
readErr error
@@ -191,8 +207,12 @@ func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool)
191207
return nil, false
192208
}
193209

194-
func (w *syscallPacketReadWaiter) WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error) {
195-
if w.readFunc == nil {
210+
func (w *syscallPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
211+
w.readErr = nil
212+
w.readFrom = M.Socksaddr{}
213+
if newBuffer == nil {
214+
w.readFunc = nil
215+
} else {
196216
w.readFunc = func(fd uintptr) (done bool) {
197217
buffer := newBuffer()
198218
var readN int
@@ -221,6 +241,12 @@ func (w *syscallPacketReadWaiter) WaitReadPacket(newBuffer func() *buf.Buffer) (
221241
return true
222242
}
223243
}
244+
}
245+
246+
func (w *syscallPacketReadWaiter) WaitReadPacket() (destination M.Socksaddr, err error) {
247+
if w.readFunc == nil {
248+
return M.Socksaddr{}, os.ErrInvalid
249+
}
224250
err = w.rawConn.Read(w.readFunc)
225251
if err != nil {
226252
return

common/network/direct.go

+4-2
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,17 @@ import (
66
)
77

88
type ReadWaiter interface {
9-
WaitReadBuffer(newBuffer func() *buf.Buffer) error
9+
InitializeReadWaiter(newBuffer func() *buf.Buffer)
10+
WaitReadBuffer() error
1011
}
1112

1213
type ReadWaitCreator interface {
1314
CreateReadWaiter() (ReadWaiter, bool)
1415
}
1516

1617
type PacketReadWaiter interface {
17-
WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error)
18+
InitializeReadWaiter(newBuffer func() *buf.Buffer)
19+
WaitReadPacket() (destination M.Socksaddr, err error)
1820
}
1921

2022
type PacketReadWaitCreator interface {

0 commit comments

Comments
 (0)