Skip to content

Commit 9179445

Browse files
committed
Merge ThreadSafeReader into ReadWaiter interface
1 parent 86c131f commit 9179445

File tree

11 files changed

+209
-177
lines changed

11 files changed

+209
-177
lines changed

common/buf/buffer.go

+22-4
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,16 @@ import (
88
"sync/atomic"
99

1010
"github.com/sagernet/sing/common"
11+
"github.com/sagernet/sing/common/debug"
1112
E "github.com/sagernet/sing/common/exceptions"
13+
F "github.com/sagernet/sing/common/format"
1214
)
1315

1416
type Buffer struct {
1517
data []byte
1618
start int
1719
end int
18-
refs int32
20+
refs atomic.Int32
1921
managed bool
2022
closed bool
2123
}
@@ -281,24 +283,40 @@ func (b *Buffer) FullReset() {
281283
}
282284

283285
func (b *Buffer) IncRef() {
284-
atomic.AddInt32(&b.refs, 1)
286+
b.refs.Add(1)
285287
}
286288

287289
func (b *Buffer) DecRef() {
288-
atomic.AddInt32(&b.refs, -1)
290+
b.refs.Add(-1)
289291
}
290292

291293
func (b *Buffer) Release() {
292294
if b == nil || b.closed || !b.managed {
293295
return
294296
}
295-
if atomic.LoadInt32(&b.refs) > 0 {
297+
if b.refs.Load() > 0 {
296298
return
297299
}
298300
common.Must(Put(b.data))
299301
*b = Buffer{closed: true}
300302
}
301303

304+
func (b *Buffer) Leak() {
305+
if debug.Enabled {
306+
if b == nil || b.closed || !b.managed {
307+
return
308+
}
309+
refs := b.refs.Load()
310+
if refs == 0 {
311+
panic("leaking buffer")
312+
} else {
313+
panic(F.ToString("leaking buffer with ", refs, " references"))
314+
}
315+
} else {
316+
b.Release()
317+
}
318+
}
319+
302320
func (b *Buffer) Cut(start int, end int) *Buffer {
303321
b.start += start
304322
b.end = len(b.data) - end

common/bufio/bind_wait.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ type BindPacketReadWaiter struct {
1212
readWaiter N.PacketReadWaiter
1313
}
1414

15-
func (w *BindPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
16-
w.readWaiter.InitializeReadWaiter(newBuffer)
15+
func (w *BindPacketReadWaiter) InitializeReadWaiter(options *N.ReadWaitOptions) (needCopy bool) {
16+
return w.readWaiter.InitializeReadWaiter(options)
1717
}
1818

1919
func (w *BindPacketReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
@@ -28,8 +28,8 @@ type UnbindPacketReadWaiter struct {
2828
addr M.Socksaddr
2929
}
3030

31-
func (w *UnbindPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
32-
w.readWaiter.InitializeReadWaiter(newBuffer)
31+
func (w *UnbindPacketReadWaiter) InitializeReadWaiter(options *N.ReadWaitOptions) (needCopy bool) {
32+
return w.readWaiter.InitializeReadWaiter(options)
3333
}
3434

3535
func (w *UnbindPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {

common/bufio/copy.go

+31-30
Original file line numberDiff line numberDiff line change
@@ -57,19 +57,21 @@ func Copy(destination io.Writer, source io.Reader) (n int64, err error) {
5757
}
5858

5959
func CopyExtended(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
60-
safeSrc := N.IsSafeReader(source)
61-
headroom := N.CalculateFrontHeadroom(destination) + N.CalculateRearHeadroom(destination)
62-
if safeSrc != nil {
63-
if headroom == 0 {
64-
return CopyExtendedWithSrcBuffer(originSource, destination, safeSrc, readCounters, writeCounters)
65-
}
66-
}
60+
frontHeadroom := N.CalculateFrontHeadroom(destination)
61+
rearHeadroom := N.CalculateRearHeadroom(destination)
6762
readWaiter, isReadWaiter := CreateReadWaiter(source)
6863
if isReadWaiter {
69-
var handled bool
70-
handled, n, err = copyWaitWithPool(originSource, destination, readWaiter, readCounters, writeCounters)
71-
if handled {
72-
return
64+
needCopy := readWaiter.InitializeReadWaiter(&N.ReadWaitOptions{
65+
FrontHeadroom: frontHeadroom,
66+
RearHeadroom: rearHeadroom,
67+
MTU: N.CalculateMTU(source, destination),
68+
})
69+
if !needCopy || common.LowMemory {
70+
var handled bool
71+
handled, n, err = copyWaitWithPool(originSource, destination, readWaiter, readCounters, writeCounters)
72+
if handled {
73+
return
74+
}
7375
}
7476
}
7577
return CopyExtendedWithPool(originSource, destination, source, readCounters, writeCounters)
@@ -113,6 +115,7 @@ func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, so
113115
}
114116
}
115117

118+
// Deprecated: Use ReadWaiter interface instead.
116119
func CopyExtendedWithSrcBuffer(originSource io.Reader, destination N.ExtendedWriter, source N.ThreadSafeReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
117120
var notFirstTime bool
118121
for {
@@ -128,7 +131,7 @@ func CopyExtendedWithSrcBuffer(originSource io.Reader, destination N.ExtendedWri
128131
dataLen := buffer.Len()
129132
err = destination.WriteBuffer(buffer)
130133
if err != nil {
131-
buffer.Release()
134+
buffer.Leak()
132135
if !notFirstTime {
133136
err = N.ReportHandshakeFailure(originSource, err)
134137
}
@@ -173,7 +176,7 @@ func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter,
173176
buffer.Resize(readBuffer.Start(), dataLen)
174177
err = destination.WriteBuffer(buffer)
175178
if err != nil {
176-
buffer.Release()
179+
buffer.Leak()
177180
if !notFirstTime {
178181
err = N.ReportHandshakeFailure(originSource, err)
179182
}
@@ -256,35 +259,33 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64,
256259
return
257260
}
258261
}
259-
safeSrc := N.IsSafePacketReader(source)
260262
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
261263
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
262-
headroom := frontHeadroom + rearHeadroom
263-
if safeSrc != nil {
264-
if headroom == 0 {
265-
var copyN int64
266-
copyN, err = CopyPacketWithSrcBuffer(originSource, destinationConn, safeSrc, readCounters, writeCounters, n > 0)
267-
n += copyN
268-
return
269-
}
270-
}
271264
var (
272265
handled bool
273266
copeN int64
274267
)
275268
readWaiter, isReadWaiter := CreatePacketReadWaiter(source)
276269
if isReadWaiter {
277-
handled, copeN, err = copyPacketWaitWithPool(originSource, destinationConn, readWaiter, readCounters, writeCounters, n > 0)
278-
if handled {
279-
n += copeN
280-
return
270+
needCopy := readWaiter.InitializeReadWaiter(&N.ReadWaitOptions{
271+
FrontHeadroom: frontHeadroom,
272+
RearHeadroom: rearHeadroom,
273+
MTU: N.CalculateMTU(source, destinationConn),
274+
})
275+
if !needCopy || common.LowMemory {
276+
handled, copeN, err = copyPacketWaitWithPool(originSource, destinationConn, readWaiter, readCounters, writeCounters, n > 0)
277+
if handled {
278+
n += copeN
279+
return
280+
}
281281
}
282282
}
283283
copeN, err = CopyPacketWithPool(originSource, destinationConn, source, readCounters, writeCounters, n > 0)
284284
n += copeN
285285
return
286286
}
287287

288+
// Deprecated: Use PacketReadWaiter interface instead.
288289
func CopyPacketWithSrcBuffer(originSource N.PacketReader, destinationConn N.PacketWriter, source N.ThreadSafePacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) {
289290
var buffer *buf.Buffer
290291
var destination M.Socksaddr
@@ -302,7 +303,7 @@ func CopyPacketWithSrcBuffer(originSource N.PacketReader, destinationConn N.Pack
302303
}
303304
err = destinationConn.WritePacket(buffer, destination)
304305
if err != nil {
305-
buffer.Release()
306+
buffer.Leak()
306307
if !notFirstTime {
307308
err = N.ReportHandshakeFailure(originSource, err)
308309
}
@@ -343,7 +344,7 @@ func CopyPacketWithPool(originSource N.PacketReader, destinationConn N.PacketWri
343344
buffer.Resize(readBuffer.Start(), dataLen)
344345
err = destinationConn.WritePacket(buffer, destination)
345346
if err != nil {
346-
buffer.Release()
347+
buffer.Leak()
347348
if !notFirstTime {
348349
err = N.ReportHandshakeFailure(originSource, err)
349350
}
@@ -379,7 +380,7 @@ func WritePacketWithPool(originSource N.PacketReader, destinationConn N.PacketWr
379380
buffer.Resize(readBuffer.Start(), dataLen)
380381
err = destinationConn.WritePacket(buffer, packetBuffer.Destination)
381382
if err != nil {
382-
buffer.Release()
383+
buffer.Leak()
383384
if !notFirstTime {
384385
err = N.ReportHandshakeFailure(originSource, err)
385386
}

0 commit comments

Comments
 (0)