Skip to content

Commit d9a5120

Browse files
committed
Remove bad rw usages
1 parent 01b7350 commit d9a5120

File tree

7 files changed

+108
-74
lines changed

7 files changed

+108
-74
lines changed

common/cond.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,10 +363,12 @@ func Close(closers ...any) error {
363363
return retErr
364364
}
365365

366+
// Deprecated: wtf is this?
366367
type Starter interface {
367368
Start() error
368369
}
369370

371+
// Deprecated: wtf is this?
370372
func Start(starters ...any) error {
371373
for _, rawStarter := range starters {
372374
if rawStarter == nil {

common/metadata/serializer.go

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import (
88
"github.com/sagernet/sing/common"
99
"github.com/sagernet/sing/common/buf"
1010
E "github.com/sagernet/sing/common/exceptions"
11-
"github.com/sagernet/sing/common/rw"
11+
"github.com/sagernet/sing/common/varbin"
1212
)
1313

1414
const (
@@ -116,7 +116,7 @@ func (s *Serializer) WriteAddrPort(writer io.Writer, destination Socksaddr) erro
116116
return err
117117
}
118118
if !isBuffer {
119-
err = rw.WriteBytes(writer, buffer.Bytes())
119+
err = common.Error(writer.Write(buffer.Bytes()))
120120
}
121121
return err
122122
}
@@ -129,8 +129,9 @@ func (s *Serializer) AddrPortLen(destination Socksaddr) int {
129129
}
130130
}
131131

132-
func (s *Serializer) ReadAddress(reader io.Reader) (Socksaddr, error) {
133-
af, err := rw.ReadByte(reader)
132+
func (s *Serializer) ReadAddress(rawRedaer io.Reader) (Socksaddr, error) {
133+
reader := varbin.NewReader(rawRedaer)
134+
af, err := reader.ReadByte()
134135
if err != nil {
135136
return Socksaddr{}, err
136137
}
@@ -164,11 +165,12 @@ func (s *Serializer) ReadAddress(reader io.Reader) (Socksaddr, error) {
164165
}
165166

166167
func (s *Serializer) ReadPort(reader io.Reader) (uint16, error) {
167-
port, err := rw.ReadBytes(reader, 2)
168+
var port uint16
169+
err := binary.Read(reader, binary.BigEndian, &port)
168170
if err != nil {
169171
return 0, E.Cause(err, "read port")
170172
}
171-
return binary.BigEndian.Uint16(port), nil
173+
return port, nil
172174
}
173175

174176
func (s *Serializer) ReadAddrPort(reader io.Reader) (destination Socksaddr, err error) {
@@ -194,12 +196,17 @@ func (s *Serializer) ReadAddrPort(reader io.Reader) (destination Socksaddr, err
194196
return addr, nil
195197
}
196198

197-
func ReadSockString(reader io.Reader) (string, error) {
198-
strLen, err := rw.ReadByte(reader)
199+
func ReadSockString(reader varbin.Reader) (string, error) {
200+
strLen, err := reader.ReadByte()
201+
if err != nil {
202+
return "", err
203+
}
204+
strBytes := make([]byte, strLen)
205+
_, err = io.ReadFull(reader, strBytes)
199206
if err != nil {
200207
return "", err
201208
}
202-
return rw.ReadString(reader, int(strLen))
209+
return string(strBytes), nil
203210
}
204211

205212
func WriteSocksString(buffer *buf.Buffer, str string) error {

protocol/socks/client.go

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
package socks
22

33
import (
4+
std_bufio "bufio"
45
"context"
56
"net"
67
"net/url"
78
"os"
89
"strings"
910

11+
"github.com/sagernet/sing/common/buf"
12+
"github.com/sagernet/sing/common/bufio"
1013
E "github.com/sagernet/sing/common/exceptions"
1114
M "github.com/sagernet/sing/common/metadata"
1215
N "github.com/sagernet/sing/common/network"
@@ -118,31 +121,53 @@ func (c *Client) DialContext(ctx context.Context, network string, address M.Sock
118121
return nil, err
119122
}
120123
if c.version == Version4 && address.IsFqdn() {
121-
tcpAddr, err := net.ResolveTCPAddr(network, address.String())
124+
var tcpAddr *net.TCPAddr
125+
tcpAddr, err = net.ResolveTCPAddr(network, address.String())
122126
if err != nil {
123127
tcpConn.Close()
124128
return nil, err
125129
}
126130
address = M.SocksaddrFromNet(tcpAddr)
127131
}
132+
reader := std_bufio.NewReader(tcpConn)
128133
switch c.version {
129134
case Version4, Version4A:
130-
_, err = ClientHandshake4(tcpConn, command, address, c.username)
135+
_, err = ClientHandshake4(reader, tcpConn, command, address, c.username)
131136
if err != nil {
132137
tcpConn.Close()
133138
return nil, err
134139
}
140+
if reader.Buffered() > 0 {
141+
buffer := buf.NewSize(reader.Buffered())
142+
_, err = buffer.ReadFullFrom(reader, reader.Buffered())
143+
if err != nil {
144+
tcpConn.Close()
145+
return nil, err
146+
}
147+
return bufio.NewCachedConn(tcpConn, buffer), nil
148+
}
135149
return tcpConn, nil
136150
case Version5:
137-
response, err := ClientHandshake5(tcpConn, command, address, c.username, c.password)
151+
var response socks5.Response
152+
response, err = ClientHandshake5(reader, tcpConn, command, address, c.username, c.password)
138153
if err != nil {
139154
tcpConn.Close()
140155
return nil, err
141156
}
142157
if command == socks5.CommandConnect {
158+
if reader.Buffered() > 0 {
159+
buffer := buf.NewSize(reader.Buffered())
160+
_, err = buffer.ReadFullFrom(reader, reader.Buffered())
161+
if err != nil {
162+
tcpConn.Close()
163+
return nil, err
164+
}
165+
return bufio.NewCachedConn(tcpConn, buffer), nil
166+
}
143167
return tcpConn, nil
144168
}
145-
udpConn, err := c.dialer.DialContext(ctx, N.NetworkUDP, response.Bind)
169+
var udpConn net.Conn
170+
udpConn, err = c.dialer.DialContext(ctx, N.NetworkUDP, response.Bind)
146171
if err != nil {
147172
tcpConn.Close()
148173
return nil, err
@@ -166,16 +191,17 @@ func (c *Client) BindContext(ctx context.Context, address M.Socksaddr) (net.Conn
166191
if err != nil {
167192
return nil, err
168193
}
194+
reader := std_bufio.NewReader(tcpConn)
169195
switch c.version {
170196
case Version4, Version4A:
171-
_, err = ClientHandshake4(tcpConn, socks4.CommandBind, address, c.username)
197+
_, err = ClientHandshake4(reader, tcpConn, socks4.CommandBind, address, c.username)
172198
if err != nil {
173199
tcpConn.Close()
174200
return nil, err
175201
}
176202
return tcpConn, nil
177203
case Version5:
178-
_, err = ClientHandshake5(tcpConn, socks5.CommandBind, address, c.username, c.password)
204+
_, err = ClientHandshake5(reader, tcpConn, socks5.CommandBind, address, c.username, c.password)
179205
if err != nil {
180206
tcpConn.Close()
181207
return nil, err

protocol/socks/handshake.go

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package socks
22

33
import (
4+
std_bufio "bufio"
45
"context"
56
"io"
67
"net"
@@ -13,7 +14,7 @@ import (
1314
E "github.com/sagernet/sing/common/exceptions"
1415
M "github.com/sagernet/sing/common/metadata"
1516
N "github.com/sagernet/sing/common/network"
16-
"github.com/sagernet/sing/common/rw"
17+
"github.com/sagernet/sing/common/varbin"
1718
"github.com/sagernet/sing/protocol/socks/socks4"
1819
"github.com/sagernet/sing/protocol/socks/socks5"
1920
)
@@ -23,16 +24,16 @@ type Handler interface {
2324
N.UDPConnectionHandler
2425
}
2526

26-
func ClientHandshake4(conn io.ReadWriter, command byte, destination M.Socksaddr, username string) (socks4.Response, error) {
27-
err := socks4.WriteRequest(conn, socks4.Request{
27+
func ClientHandshake4(reader varbin.Reader, writer io.Writer, command byte, destination M.Socksaddr, username string) (socks4.Response, error) {
28+
err := socks4.WriteRequest(writer, socks4.Request{
2829
Command: command,
2930
Destination: destination,
3031
Username: username,
3132
})
3233
if err != nil {
3334
return socks4.Response{}, err
3435
}
35-
response, err := socks4.ReadResponse(conn)
36+
response, err := socks4.ReadResponse(reader)
3637
if err != nil {
3738
return socks4.Response{}, err
3839
}
@@ -42,32 +43,32 @@ func ClientHandshake4(conn io.ReadWriter, command byte, destination M.Socksaddr,
4243
return response, err
4344
}
4445

45-
func ClientHandshake5(conn io.ReadWriter, command byte, destination M.Socksaddr, username string, password string) (socks5.Response, error) {
46+
func ClientHandshake5(reader varbin.Reader, writer io.Writer, command byte, destination M.Socksaddr, username string, password string) (socks5.Response, error) {
4647
var method byte
4748
if username == "" {
4849
method = socks5.AuthTypeNotRequired
4950
} else {
5051
method = socks5.AuthTypeUsernamePassword
5152
}
52-
err := socks5.WriteAuthRequest(conn, socks5.AuthRequest{
53+
err := socks5.WriteAuthRequest(writer, socks5.AuthRequest{
5354
Methods: []byte{method},
5455
})
5556
if err != nil {
5657
return socks5.Response{}, err
5758
}
58-
authResponse, err := socks5.ReadAuthResponse(conn)
59+
authResponse, err := socks5.ReadAuthResponse(reader)
5960
if err != nil {
6061
return socks5.Response{}, err
6162
}
6263
if authResponse.Method == socks5.AuthTypeUsernamePassword {
63-
err = socks5.WriteUsernamePasswordAuthRequest(conn, socks5.UsernamePasswordAuthRequest{
64+
err = socks5.WriteUsernamePasswordAuthRequest(writer, socks5.UsernamePasswordAuthRequest{
6465
Username: username,
6566
Password: password,
6667
})
6768
if err != nil {
6869
return socks5.Response{}, err
6970
}
70-
usernamePasswordResponse, err := socks5.ReadUsernamePasswordAuthResponse(conn)
71+
usernamePasswordResponse, err := socks5.ReadUsernamePasswordAuthResponse(reader)
7172
if err != nil {
7273
return socks5.Response{}, err
7374
}
@@ -77,14 +78,14 @@ func ClientHandshake5(conn io.ReadWriter, command byte, destination M.Socksaddr,
7778
} else if authResponse.Method != socks5.AuthTypeNotRequired {
7879
return socks5.Response{}, E.New("socks5: unsupported auth method: ", authResponse.Method)
7980
}
80-
err = socks5.WriteRequest(conn, socks5.Request{
81+
err = socks5.WriteRequest(writer, socks5.Request{
8182
Command: command,
8283
Destination: destination,
8384
})
8485
if err != nil {
8586
return socks5.Response{}, err
8687
}
87-
response, err := socks5.ReadResponse(conn)
88+
response, err := socks5.ReadResponse(reader)
8889
if err != nil {
8990
return socks5.Response{}, err
9091
}
@@ -94,18 +95,14 @@ func ClientHandshake5(conn io.ReadWriter, command byte, destination M.Socksaddr,
9495
return response, err
9596
}
9697

97-
func HandleConnection(ctx context.Context, conn net.Conn, authenticator *auth.Authenticator, handler Handler, metadata M.Metadata) error {
98-
version, err := rw.ReadByte(conn)
98+
func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Reader, authenticator *auth.Authenticator, handler Handler, metadata M.Metadata) error {
99+
version, err := reader.ReadByte()
99100
if err != nil {
100101
return err
101102
}
102-
return HandleConnection0(ctx, conn, version, authenticator, handler, metadata)
103-
}
104-
105-
func HandleConnection0(ctx context.Context, conn net.Conn, version byte, authenticator *auth.Authenticator, handler Handler, metadata M.Metadata) error {
106103
switch version {
107104
case socks4.Version:
108-
request, err := socks4.ReadRequest0(conn)
105+
request, err := socks4.ReadRequest0(reader)
109106
if err != nil {
110107
return err
111108
}
@@ -142,7 +139,7 @@ func HandleConnection0(ctx context.Context, conn net.Conn, version byte, authent
142139
return E.New("socks4: unsupported command ", request.Command)
143140
}
144141
case socks5.Version:
145-
authRequest, err := socks5.ReadAuthRequest0(conn)
142+
authRequest, err := socks5.ReadAuthRequest0(reader)
146143
if err != nil {
147144
return err
148145
}
@@ -167,7 +164,7 @@ func HandleConnection0(ctx context.Context, conn net.Conn, version byte, authent
167164
return err
168165
}
169166
if authMethod == socks5.AuthTypeUsernamePassword {
170-
usernamePasswordAuthRequest, err := socks5.ReadUsernamePasswordAuthRequest(conn)
167+
usernamePasswordAuthRequest, err := socks5.ReadUsernamePasswordAuthRequest(reader)
171168
if err != nil {
172169
return err
173170
}
@@ -186,7 +183,7 @@ func HandleConnection0(ctx context.Context, conn net.Conn, version byte, authent
186183
return E.New("socks5: authentication failed, username=", usernamePasswordAuthRequest.Username, ", password=", usernamePasswordAuthRequest.Password)
187184
}
188185
}
189-
request, err := socks5.ReadRequest(conn)
186+
request, err := socks5.ReadRequest(reader)
190187
if err != nil {
191188
return err
192189
}

protocol/socks/socks4/protocol.go

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import (
1010
"github.com/sagernet/sing/common/buf"
1111
E "github.com/sagernet/sing/common/exceptions"
1212
M "github.com/sagernet/sing/common/metadata"
13-
"github.com/sagernet/sing/common/rw"
13+
"github.com/sagernet/sing/common/varbin"
1414
)
1515

1616
const (
@@ -31,8 +31,8 @@ type Request struct {
3131
Username string
3232
}
3333

34-
func ReadRequest(reader io.Reader) (request Request, err error) {
35-
version, err := rw.ReadByte(reader)
34+
func ReadRequest(reader varbin.Reader) (request Request, err error) {
35+
version, err := reader.ReadByte()
3636
if err != nil {
3737
return
3838
}
@@ -43,8 +43,8 @@ func ReadRequest(reader io.Reader) (request Request, err error) {
4343
return ReadRequest0(reader)
4444
}
4545

46-
func ReadRequest0(reader io.Reader) (request Request, err error) {
47-
request.Command, err = rw.ReadByte(reader)
46+
func ReadRequest0(reader varbin.Reader) (request Request, err error) {
47+
request.Command, err = reader.ReadByte()
4848
if err != nil {
4949
return
5050
}
@@ -108,24 +108,24 @@ func WriteRequest(writer io.Writer, request Request) error {
108108
common.Must1(buffer.WriteString(request.Destination.AddrString()))
109109
common.Must(buffer.WriteZero())
110110
}
111-
return rw.WriteBytes(writer, buffer.Bytes())
111+
return common.Error(writer.Write(buffer.Bytes()))
112112
}
113113

114114
type Response struct {
115115
ReplyCode byte
116116
Destination M.Socksaddr
117117
}
118118

119-
func ReadResponse(reader io.Reader) (response Response, err error) {
120-
version, err := rw.ReadByte(reader)
119+
func ReadResponse(reader varbin.Reader) (response Response, err error) {
120+
version, err := reader.ReadByte()
121121
if err != nil {
122122
return
123123
}
124124
if version != 0 {
125125
err = E.New("excepted socks4 response version 0, got ", version)
126126
return
127127
}
128-
response.ReplyCode, err = rw.ReadByte(reader)
128+
response.ReplyCode, err = reader.ReadByte()
129129
if err != nil {
130130
return
131131
}
@@ -151,13 +151,13 @@ func WriteResponse(writer io.Writer, response Response) error {
151151
binary.Write(buffer, binary.BigEndian, response.Destination.Port),
152152
common.Error(buffer.Write(response.Destination.Addr.AsSlice())),
153153
)
154-
return rw.WriteBytes(writer, buffer.Bytes())
154+
return common.Error(writer.Write(buffer.Bytes()))
155155
}
156156

157-
func readString(reader io.Reader) (string, error) {
157+
func readString(reader varbin.Reader) (string, error) {
158158
buffer := bytes.Buffer{}
159159
for {
160-
b, err := rw.ReadByte(reader)
160+
b, err := reader.ReadByte()
161161
if err != nil {
162162
return "", err
163163
}

0 commit comments

Comments
 (0)