Skip to content

Commit dd0be0d

Browse files
dyhkwongnekohasekai
authored andcommitted
Fix socks5 packet conn
1 parent 8fb1634 commit dd0be0d

File tree

3 files changed

+38
-40
lines changed

3 files changed

+38
-40
lines changed

common/bufio/nat.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ func (c *unidirectionalNATPacketConn) UpdateDestination(destinationAddress netip
6363
c.destination = M.SocksaddrFrom(destinationAddress, c.destination.Port)
6464
}
6565

66+
func (c *unidirectionalNATPacketConn) RemoteAddr() net.Addr {
67+
return c.destination.UDPAddr()
68+
}
69+
6670
func (c *unidirectionalNATPacketConn) Upstream() any {
6771
return c.NetPacketConn
6872
}
@@ -136,6 +140,10 @@ func (c *bidirectionalNATPacketConn) Upstream() any {
136140
return c.NetPacketConn
137141
}
138142

143+
func (c *bidirectionalNATPacketConn) RemoteAddr() net.Addr {
144+
return c.destination.UDPAddr()
145+
}
146+
139147
func socksaddrWithoutPort(destination M.Socksaddr) M.Socksaddr {
140148
destination.Port = 0
141149
return destination

protocol/socks/client.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import (
77
"os"
88
"strings"
99

10-
"github.com/sagernet/sing/common/bufio"
1110
E "github.com/sagernet/sing/common/exceptions"
1211
M "github.com/sagernet/sing/common/metadata"
1312
N "github.com/sagernet/sing/common/network"
@@ -148,7 +147,7 @@ func (c *Client) DialContext(ctx context.Context, network string, address M.Sock
148147
tcpConn.Close()
149148
return nil, err
150149
}
151-
return NewAssociatePacketConn(bufio.NewUnbindPacketConn(udpConn), address, tcpConn), nil
150+
return NewAssociatePacketConn(udpConn, address, tcpConn), nil
152151
}
153152
return nil, os.ErrInvalid
154153
}

protocol/socks/packet.go

Lines changed: 29 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -21,54 +21,41 @@ import (
2121
var ErrInvalidPacket = E.New("socks5: invalid packet")
2222

2323
type AssociatePacketConn struct {
24-
N.NetPacketConn
24+
N.AbstractConn
25+
conn N.ExtendedConn
2526
remoteAddr M.Socksaddr
2627
underlying net.Conn
2728
}
2829

29-
func NewAssociatePacketConn(conn net.PacketConn, remoteAddr M.Socksaddr, underlying net.Conn) *AssociatePacketConn {
30+
func NewAssociatePacketConn(conn net.Conn, remoteAddr M.Socksaddr, underlying net.Conn) *AssociatePacketConn {
3031
return &AssociatePacketConn{
31-
NetPacketConn: bufio.NewPacketConn(conn),
32-
remoteAddr: remoteAddr,
33-
underlying: underlying,
32+
AbstractConn: conn,
33+
conn: bufio.NewExtendedConn(conn),
34+
remoteAddr: remoteAddr,
35+
underlying: underlying,
3436
}
3537
}
3638

37-
// Deprecated: NewAssociatePacketConn(bufio.NewUnbindPacketConn(conn), remoteAddr, underlying) instead.
38-
func NewAssociateConn(conn net.Conn, remoteAddr M.Socksaddr, underlying net.Conn) *AssociatePacketConn {
39-
return &AssociatePacketConn{
40-
NetPacketConn: bufio.NewUnbindPacketConn(conn),
41-
remoteAddr: remoteAddr,
42-
underlying: underlying,
43-
}
44-
}
45-
46-
func (c *AssociatePacketConn) RemoteAddr() net.Addr {
47-
return c.remoteAddr.UDPAddr()
48-
}
49-
50-
//warn:unsafe
5139
func (c *AssociatePacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
52-
n, addr, err = c.NetPacketConn.ReadFrom(p)
40+
n, err = c.conn.Read(p)
5341
if err != nil {
5442
return
5543
}
5644
if n < 3 {
5745
return 0, nil, ErrInvalidPacket
5846
}
59-
c.remoteAddr = M.SocksaddrFromNet(addr)
6047
reader := bytes.NewReader(p[3:n])
6148
destination, err := M.SocksaddrSerializer.ReadAddrPort(reader)
6249
if err != nil {
6350
return
6451
}
52+
c.remoteAddr = destination
6553
addr = destination.UDPAddr()
6654
index := 3 + int(reader.Size()) - reader.Len()
6755
n = copy(p, p[index:n])
6856
return
6957
}
7058

71-
//warn:unsafe
7259
func (c *AssociatePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
7360
destination := M.SocksaddrFromNet(addr)
7461
buffer := buf.NewSize(3 + M.SocksaddrSerializer.AddrPortLen(destination) + len(p))
@@ -82,32 +69,23 @@ func (c *AssociatePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error
8269
if err != nil {
8370
return
8471
}
85-
return bufio.WritePacketBuffer(c.NetPacketConn, buffer, c.remoteAddr)
86-
}
87-
88-
func (c *AssociatePacketConn) Read(b []byte) (n int, err error) {
89-
n, _, err = c.ReadFrom(b)
90-
return
91-
}
92-
93-
func (c *AssociatePacketConn) Write(b []byte) (n int, err error) {
94-
return c.WriteTo(b, c.remoteAddr)
72+
return c.conn.Write(buffer.Bytes())
9573
}
9674

9775
func (c *AssociatePacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
98-
destination, err = c.NetPacketConn.ReadPacket(buffer)
76+
err = c.conn.ReadBuffer(buffer)
9977
if err != nil {
100-
return M.Socksaddr{}, err
78+
return
10179
}
10280
if buffer.Len() < 3 {
10381
return M.Socksaddr{}, ErrInvalidPacket
10482
}
105-
c.remoteAddr = destination
10683
buffer.Advance(3)
10784
destination, err = M.SocksaddrSerializer.ReadAddrPort(buffer)
10885
if err != nil {
10986
return
11087
}
88+
c.remoteAddr = destination
11189
return destination.Unwrap(), nil
11290
}
11391

@@ -118,11 +96,24 @@ func (c *AssociatePacketConn) WritePacket(buffer *buf.Buffer, destination M.Sock
11896
if err != nil {
11997
return err
12098
}
121-
return common.Error(bufio.WritePacketBuffer(c.NetPacketConn, buffer, c.remoteAddr))
99+
return c.conn.WriteBuffer(buffer)
100+
}
101+
102+
func (c *AssociatePacketConn) Read(b []byte) (n int, err error) {
103+
n, _, err = c.ReadFrom(b)
104+
return
105+
}
106+
107+
func (c *AssociatePacketConn) Write(b []byte) (n int, err error) {
108+
return c.WriteTo(b, c.remoteAddr)
109+
}
110+
111+
func (c *AssociatePacketConn) RemoteAddr() net.Addr {
112+
return c.remoteAddr.UDPAddr()
122113
}
123114

124115
func (c *AssociatePacketConn) Upstream() any {
125-
return c.NetPacketConn
116+
return c.conn
126117
}
127118

128119
func (c *AssociatePacketConn) FrontHeadroom() int {
@@ -131,7 +122,7 @@ func (c *AssociatePacketConn) FrontHeadroom() int {
131122

132123
func (c *AssociatePacketConn) Close() error {
133124
return common.Close(
134-
c.NetPacketConn,
125+
c.conn,
135126
c.underlying,
136127
)
137128
}

0 commit comments

Comments
 (0)