Skip to content

Commit f5d9633

Browse files
committed
Merge ThreadSafeReader into ReadWaiter interface
1 parent 86c131f commit f5d9633

File tree

13 files changed

+257
-284
lines changed

13 files changed

+257
-284
lines changed

common/buf/buffer.go

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,16 @@ import (
88
"sync/atomic"
99

1010
"github.com/sagernet/sing/common"
11+
"github.com/sagernet/sing/common/debug"
1112
E "github.com/sagernet/sing/common/exceptions"
13+
F "github.com/sagernet/sing/common/format"
1214
)
1315

1416
type Buffer struct {
1517
data []byte
1618
start int
1719
end int
18-
refs int32
20+
refs atomic.Int32
1921
managed bool
2022
closed bool
2123
}
@@ -281,24 +283,40 @@ func (b *Buffer) FullReset() {
281283
}
282284

283285
func (b *Buffer) IncRef() {
284-
atomic.AddInt32(&b.refs, 1)
286+
b.refs.Add(1)
285287
}
286288

287289
func (b *Buffer) DecRef() {
288-
atomic.AddInt32(&b.refs, -1)
290+
b.refs.Add(-1)
289291
}
290292

291293
func (b *Buffer) Release() {
292294
if b == nil || b.closed || !b.managed {
293295
return
294296
}
295-
if atomic.LoadInt32(&b.refs) > 0 {
297+
if b.refs.Load() > 0 {
296298
return
297299
}
298300
common.Must(Put(b.data))
299301
*b = Buffer{closed: true}
300302
}
301303

304+
func (b *Buffer) Leak() {
305+
if debug.Enabled {
306+
if b == nil || b.closed || !b.managed {
307+
return
308+
}
309+
refs := b.refs.Load()
310+
if refs == 0 {
311+
panic("leaking buffer")
312+
} else {
313+
panic(F.ToString("leaking buffer with ", refs, " references"))
314+
}
315+
} else {
316+
b.Release()
317+
}
318+
}
319+
302320
func (b *Buffer) Cut(start int, end int) *Buffer {
303321
b.start += start
304322
b.end = len(b.data) - end

common/bufio/bind_wait.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ type BindPacketReadWaiter struct {
1212
readWaiter N.PacketReadWaiter
1313
}
1414

15-
func (w *BindPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
16-
w.readWaiter.InitializeReadWaiter(newBuffer)
15+
func (w *BindPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
16+
return w.readWaiter.InitializeReadWaiter(options)
1717
}
1818

1919
func (w *BindPacketReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
@@ -28,8 +28,8 @@ type UnbindPacketReadWaiter struct {
2828
addr M.Socksaddr
2929
}
3030

31-
func (w *UnbindPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
32-
w.readWaiter.InitializeReadWaiter(newBuffer)
31+
func (w *UnbindPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
32+
return w.readWaiter.InitializeReadWaiter(options)
3333
}
3434

3535
func (w *UnbindPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {

common/bufio/copy.go

Lines changed: 27 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55
"errors"
66
"io"
77
"net"
8-
"reflect"
98
"syscall"
109

1110
"github.com/sagernet/sing/common"
@@ -57,19 +56,21 @@ func Copy(destination io.Writer, source io.Reader) (n int64, err error) {
5756
}
5857

5958
func CopyExtended(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
60-
safeSrc := N.IsSafeReader(source)
61-
headroom := N.CalculateFrontHeadroom(destination) + N.CalculateRearHeadroom(destination)
62-
if safeSrc != nil {
63-
if headroom == 0 {
64-
return CopyExtendedWithSrcBuffer(originSource, destination, safeSrc, readCounters, writeCounters)
65-
}
66-
}
59+
frontHeadroom := N.CalculateFrontHeadroom(destination)
60+
rearHeadroom := N.CalculateRearHeadroom(destination)
6761
readWaiter, isReadWaiter := CreateReadWaiter(source)
6862
if isReadWaiter {
69-
var handled bool
70-
handled, n, err = copyWaitWithPool(originSource, destination, readWaiter, readCounters, writeCounters)
71-
if handled {
72-
return
63+
needCopy := readWaiter.InitializeReadWaiter(N.ReadWaitOptions{
64+
FrontHeadroom: frontHeadroom,
65+
RearHeadroom: rearHeadroom,
66+
MTU: N.CalculateMTU(source, destination),
67+
})
68+
if !needCopy || common.LowMemory {
69+
var handled bool
70+
handled, n, err = copyWaitWithPool(originSource, destination, readWaiter, readCounters, writeCounters)
71+
if handled {
72+
return
73+
}
7374
}
7475
}
7576
return CopyExtendedWithPool(originSource, destination, source, readCounters, writeCounters)
@@ -113,38 +114,6 @@ func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, so
113114
}
114115
}
115116

116-
func CopyExtendedWithSrcBuffer(originSource io.Reader, destination N.ExtendedWriter, source N.ThreadSafeReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
117-
var notFirstTime bool
118-
for {
119-
var buffer *buf.Buffer
120-
buffer, err = source.ReadBufferThreadSafe()
121-
if err != nil {
122-
if errors.Is(err, io.EOF) {
123-
err = nil
124-
return
125-
}
126-
return
127-
}
128-
dataLen := buffer.Len()
129-
err = destination.WriteBuffer(buffer)
130-
if err != nil {
131-
buffer.Release()
132-
if !notFirstTime {
133-
err = N.ReportHandshakeFailure(originSource, err)
134-
}
135-
return
136-
}
137-
n += int64(dataLen)
138-
for _, counter := range readCounters {
139-
counter(int64(dataLen))
140-
}
141-
for _, counter := range writeCounters {
142-
counter(int64(dataLen))
143-
}
144-
notFirstTime = true
145-
}
146-
}
147-
148117
func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
149118
frontHeadroom := N.CalculateFrontHeadroom(destination)
150119
rearHeadroom := N.CalculateRearHeadroom(destination)
@@ -173,7 +142,7 @@ func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter,
173142
buffer.Resize(readBuffer.Start(), dataLen)
174143
err = destination.WriteBuffer(buffer)
175144
if err != nil {
176-
buffer.Release()
145+
buffer.Leak()
177146
if !notFirstTime {
178147
err = N.ReportHandshakeFailure(originSource, err)
179148
}
@@ -256,69 +225,32 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64,
256225
return
257226
}
258227
}
259-
safeSrc := N.IsSafePacketReader(source)
260228
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
261229
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
262-
headroom := frontHeadroom + rearHeadroom
263-
if safeSrc != nil {
264-
if headroom == 0 {
265-
var copyN int64
266-
copyN, err = CopyPacketWithSrcBuffer(originSource, destinationConn, safeSrc, readCounters, writeCounters, n > 0)
267-
n += copyN
268-
return
269-
}
270-
}
271230
var (
272231
handled bool
273232
copeN int64
274233
)
275234
readWaiter, isReadWaiter := CreatePacketReadWaiter(source)
276235
if isReadWaiter {
277-
handled, copeN, err = copyPacketWaitWithPool(originSource, destinationConn, readWaiter, readCounters, writeCounters, n > 0)
278-
if handled {
279-
n += copeN
280-
return
236+
needCopy := readWaiter.InitializeReadWaiter(N.ReadWaitOptions{
237+
FrontHeadroom: frontHeadroom,
238+
RearHeadroom: rearHeadroom,
239+
MTU: N.CalculateMTU(source, destinationConn),
240+
})
241+
if !needCopy || common.LowMemory {
242+
handled, copeN, err = copyPacketWaitWithPool(originSource, destinationConn, readWaiter, readCounters, writeCounters, n > 0)
243+
if handled {
244+
n += copeN
245+
return
246+
}
281247
}
282248
}
283249
copeN, err = CopyPacketWithPool(originSource, destinationConn, source, readCounters, writeCounters, n > 0)
284250
n += copeN
285251
return
286252
}
287253

288-
func CopyPacketWithSrcBuffer(originSource N.PacketReader, destinationConn N.PacketWriter, source N.ThreadSafePacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) {
289-
var buffer *buf.Buffer
290-
var destination M.Socksaddr
291-
for {
292-
buffer, destination, err = source.ReadPacketThreadSafe()
293-
if err != nil {
294-
return
295-
}
296-
if buffer == nil {
297-
panic("nil buffer returned from " + reflect.TypeOf(source).String())
298-
}
299-
dataLen := buffer.Len()
300-
if dataLen == 0 {
301-
continue
302-
}
303-
err = destinationConn.WritePacket(buffer, destination)
304-
if err != nil {
305-
buffer.Release()
306-
if !notFirstTime {
307-
err = N.ReportHandshakeFailure(originSource, err)
308-
}
309-
return
310-
}
311-
n += int64(dataLen)
312-
for _, counter := range readCounters {
313-
counter(int64(dataLen))
314-
}
315-
for _, counter := range writeCounters {
316-
counter(int64(dataLen))
317-
}
318-
notFirstTime = true
319-
}
320-
}
321-
322254
func CopyPacketWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) {
323255
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
324256
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
@@ -343,7 +275,7 @@ func CopyPacketWithPool(originSource N.PacketReader, destinationConn N.PacketWri
343275
buffer.Resize(readBuffer.Start(), dataLen)
344276
err = destinationConn.WritePacket(buffer, destination)
345277
if err != nil {
346-
buffer.Release()
278+
buffer.Leak()
347279
if !notFirstTime {
348280
err = N.ReportHandshakeFailure(originSource, err)
349281
}
@@ -379,7 +311,7 @@ func WritePacketWithPool(originSource N.PacketReader, destinationConn N.PacketWr
379311
buffer.Resize(readBuffer.Start(), dataLen)
380312
err = destinationConn.WritePacket(buffer, packetBuffer.Destination)
381313
if err != nil {
382-
buffer.Release()
314+
buffer.Leak()
383315
if !notFirstTime {
384316
err = N.ReportHandshakeFailure(originSource, err)
385317
}

0 commit comments

Comments
 (0)