6
6
"errors"
7
7
"io"
8
8
"net/netip"
9
+ "os"
9
10
"syscall"
10
11
11
12
"github.com/sagernet/sing/common/buf"
@@ -25,10 +26,11 @@ func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter,
25
26
bufferSize = buf .BufferSize
26
27
}
27
28
var (
28
- buffer * buf.Buffer
29
- readBuffer * buf.Buffer
29
+ buffer * buf.Buffer
30
+ readBuffer * buf.Buffer
31
+ notFirstTime bool
30
32
)
31
- newBuffer := func () * buf.Buffer {
33
+ source . InitializeReadWaiter ( func () * buf.Buffer {
32
34
if buffer != nil {
33
35
buffer .Release ()
34
36
}
@@ -37,10 +39,10 @@ func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter,
37
39
readBuffer = buf .With (readBufferRaw [:len (readBufferRaw )- rearHeadroom ])
38
40
readBuffer .Resize (frontHeadroom , 0 )
39
41
return readBuffer
40
- }
41
- var notFirstTime bool
42
+ })
43
+ defer source . InitializeReadWaiter ( nil )
42
44
for {
43
- err = source .WaitReadBuffer (newBuffer )
45
+ err = source .WaitReadBuffer ()
44
46
if err != nil {
45
47
buffer .Release ()
46
48
if errors .Is (err , io .EOF ) {
@@ -55,10 +57,8 @@ func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter,
55
57
dataLen := readBuffer .Len ()
56
58
buffer .Resize (readBuffer .Start (), dataLen )
57
59
err = destination .WriteBuffer (buffer )
60
+ buffer .Release ()
58
61
if err != nil {
59
- if buffer != nil {
60
- buffer .Release ()
61
- }
62
62
return
63
63
}
64
64
n += int64 (dataLen )
@@ -83,10 +83,12 @@ func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadW
83
83
bufferSize = buf .UDPBufferSize
84
84
}
85
85
var (
86
- buffer * buf.Buffer
87
- readBuffer * buf.Buffer
86
+ buffer * buf.Buffer
87
+ readBuffer * buf.Buffer
88
+ destination M.Socksaddr
89
+ notFirstTime bool
88
90
)
89
- newBuffer := func () * buf.Buffer {
91
+ source . InitializeReadWaiter ( func () * buf.Buffer {
90
92
if buffer != nil {
91
93
buffer .Release ()
92
94
}
@@ -95,11 +97,10 @@ func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadW
95
97
readBuffer = buf .With (readBufferRaw [:len (readBufferRaw )- rearHeadroom ])
96
98
readBuffer .Resize (frontHeadroom , 0 )
97
99
return readBuffer
98
- }
99
- var destination M.Socksaddr
100
- var notFirstTime bool
100
+ })
101
+ defer source .InitializeReadWaiter (nil )
101
102
for {
102
- destination , err = source .WaitReadPacket (newBuffer )
103
+ destination , err = source .WaitReadPacket ()
103
104
if err != nil {
104
105
buffer .Release ()
105
106
if ! notFirstTime {
@@ -113,9 +114,8 @@ func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadW
113
114
if err != nil {
114
115
buffer .Release ()
115
116
return
116
- } else {
117
- buffer = nil
118
117
}
118
+ buffer = nil
119
119
n += int64 (dataLen )
120
120
for _ , counter := range readCounters {
121
121
counter (int64 (dataLen ))
@@ -127,6 +127,8 @@ func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadW
127
127
}
128
128
}
129
129
130
+ var _ N.ReadWaiter = (* syscallReadWaiter )(nil )
131
+
130
132
type syscallReadWaiter struct {
131
133
rawConn syscall.RawConn
132
134
readErr error
@@ -143,8 +145,11 @@ func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) {
143
145
return nil , false
144
146
}
145
147
146
- func (w * syscallReadWaiter ) WaitReadBuffer (newBuffer func () * buf.Buffer ) error {
147
- if w .readFunc == nil {
148
+ func (w * syscallReadWaiter ) InitializeReadWaiter (newBuffer func () * buf.Buffer ) {
149
+ w .readErr = nil
150
+ if newBuffer == nil {
151
+ w .readFunc = nil
152
+ } else {
148
153
w .readFunc = func (fd uintptr ) (done bool ) {
149
154
buffer := newBuffer ()
150
155
var readN int
@@ -164,16 +169,27 @@ func (w *syscallReadWaiter) WaitReadBuffer(newBuffer func() *buf.Buffer) error {
164
169
return true
165
170
}
166
171
}
172
+ }
173
+
174
+ func (w * syscallReadWaiter ) WaitReadBuffer () error {
175
+ if w .readFunc == nil {
176
+ return os .ErrInvalid
177
+ }
167
178
err := w .rawConn .Read (w .readFunc )
168
179
if err != nil {
169
180
return err
170
181
}
171
182
if w .readErr != nil {
183
+ if w .readErr == io .EOF {
184
+ return io .EOF
185
+ }
172
186
return E .Cause (w .readErr , "raw read" )
173
187
}
174
188
return nil
175
189
}
176
190
191
+ var _ N.PacketReadWaiter = (* syscallPacketReadWaiter )(nil )
192
+
177
193
type syscallPacketReadWaiter struct {
178
194
rawConn syscall.RawConn
179
195
readErr error
@@ -191,8 +207,12 @@ func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool)
191
207
return nil , false
192
208
}
193
209
194
- func (w * syscallPacketReadWaiter ) WaitReadPacket (newBuffer func () * buf.Buffer ) (destination M.Socksaddr , err error ) {
195
- if w .readFunc == nil {
210
+ func (w * syscallPacketReadWaiter ) InitializeReadWaiter (newBuffer func () * buf.Buffer ) {
211
+ w .readErr = nil
212
+ w .readFrom = M.Socksaddr {}
213
+ if newBuffer == nil {
214
+ w .readFunc = nil
215
+ } else {
196
216
w .readFunc = func (fd uintptr ) (done bool ) {
197
217
buffer := newBuffer ()
198
218
var readN int
@@ -221,6 +241,12 @@ func (w *syscallPacketReadWaiter) WaitReadPacket(newBuffer func() *buf.Buffer) (
221
241
return true
222
242
}
223
243
}
244
+ }
245
+
246
+ func (w * syscallPacketReadWaiter ) WaitReadPacket () (destination M.Socksaddr , err error ) {
247
+ if w .readFunc == nil {
248
+ return M.Socksaddr {}, os .ErrInvalid
249
+ }
224
250
err = w .rawConn .Read (w .readFunc )
225
251
if err != nil {
226
252
return
0 commit comments