diff --git a/cmd/shadowsocks-local/local.go b/cmd/shadowsocks-local/local.go index 604bf9f7..a5a9cc1d 100644 --- a/cmd/shadowsocks-local/local.go +++ b/cmd/shadowsocks-local/local.go @@ -77,7 +77,7 @@ func handShake(conn net.Conn) (err error) { return } -func getRequest(conn net.Conn) (rawaddr []byte, host string, err error) { +func getRequest(conn net.Conn) (header []byte, host string, err error) { const ( idVer = 0 idCmd = 1 @@ -94,8 +94,11 @@ func getRequest(conn net.Conn) (rawaddr []byte, host string, err error) { lenIPv6 = 3 + 1 + net.IPv6len + 2 // 3(ver+cmd+rsv) + 1addrType + ipv6 + 2port lenDmBase = 3 + 1 + 1 + 2 // 3 + 1addrType + 1addrLen + 2port, plus addrLen ) - // refer to getRequest in server.go for why set buffer size to 263 - buf := make([]byte, 263) + + // buf size should at least have the same size with the largest possible + // request size (when addrType is 3, domain name has at most 256 bytes) + // 1(addrType) + 1(lenByte) + 255(max length address) + 2(port) + 10(hmac-sha1) + [1~128](random length data buffer) + buf := make([]byte, 270+rand.Int()%128) var n int ss.SetReadTimeout(conn) // read till we get possible domain length field @@ -136,8 +139,6 @@ func getRequest(conn net.Conn) (rawaddr []byte, host string, err error) { return } - rawaddr = buf[idType:reqLen] - if debug { switch buf[idType] { case typeIPv4: @@ -151,6 +152,21 @@ func getRequest(conn net.Conn) (rawaddr []byte, host string, err error) { host = net.JoinHostPort(host, strconv.Itoa(int(port))) } + // Sending connection established message immediately to client. + // This some round trip time for creating socks connection with the client. + // But if connection failed, the client will get connection reset error. + _, err = conn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x08, 0x43}) + if err != nil { + debug.Println("send connection confirmation:", err) + return + } + + // read until the client sends at least 1 byte of data + if n, err = io.ReadAtLeast(conn, buf[reqLen:], 1); err != nil { + return + } + + header = buf[idType : reqLen+n] return } @@ -235,9 +251,9 @@ func parseServerConfig(config *ss.Config) { return } -func connectToServer(serverId int, rawaddr []byte, addr string) (remote *ss.Conn, err error) { +func connectToServer(serverId int, header []byte, addr string) (remote *ss.Conn, err error) { se := servers.srvCipher[serverId] - remote, err = ss.DialWithRawAddr(rawaddr, se.server, se.cipher.Copy()) + remote, err = ss.DiaAndWriteData(header, se.server, se.cipher.Copy()) if err != nil { log.Println("error connecting to shadowsocks server:", err) const maxFailCnt = 30 @@ -255,7 +271,7 @@ func connectToServer(serverId int, rawaddr []byte, addr string) (remote *ss.Conn // connection failure, try the next server. A failed server will be tried with // some probability according to its fail count, so we can discover recovered // servers. -func createServerConn(rawaddr []byte, addr string) (remote *ss.Conn, err error) { +func createServerConn(header []byte, addr string) (remote *ss.Conn, err error) { const baseFailCnt = 20 n := len(servers.srvCipher) skipped := make([]int, 0) @@ -265,14 +281,14 @@ func createServerConn(rawaddr []byte, addr string) (remote *ss.Conn, err error) skipped = append(skipped, i) continue } - remote, err = connectToServer(i, rawaddr, addr) + remote, err = connectToServer(i, header, addr) if err == nil { return } } // last resort, try skipped servers, not likely to succeed for _, i := range skipped { - remote, err = connectToServer(i, rawaddr, addr) + remote, err = connectToServer(i, header, addr) if err == nil { return } @@ -296,21 +312,13 @@ func handleConnection(conn net.Conn) { log.Println("socks handshake:", err) return } - rawaddr, addr, err := getRequest(conn) + header, addr, err := getRequest(conn) if err != nil { log.Println("error getting request:", err) return } - // Sending connection established message immediately to client. - // This some round trip time for creating socks connection with the client. - // But if connection failed, the client will get connection reset error. - _, err = conn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x08, 0x43}) - if err != nil { - debug.Println("send connection confirmation:", err) - return - } - remote, err := createServerConn(rawaddr, addr) + remote, err := createServerConn(header, addr) if err != nil { if len(servers.srvCipher) > 1 { log.Println("Failed connect to all available shadowsocks server") diff --git a/shadowsocks/conn.go b/shadowsocks/conn.go index 5d264b74..1edd621c 100644 --- a/shadowsocks/conn.go +++ b/shadowsocks/conn.go @@ -9,7 +9,7 @@ import ( ) const ( - AddrMask byte = 0xf + AddrMask byte = 0xf ) type Conn struct { @@ -69,6 +69,22 @@ func DialWithRawAddr(rawaddr []byte, server string, cipher *Cipher) (c *Conn, er return } +// DiaAndWriteData is intended for use by users implementing a local socks proxy. +// rawaddr shoud contain part of the data in socks request and forwarding data, +// starting from the ATYP field. (Refer to rfc1928 for more information.) +func DiaAndWriteData(data []byte, server string, cipher *Cipher) (c *Conn, err error) { + conn, err := net.Dial("tcp", server) + if err != nil { + return + } + c = NewConn(conn, cipher) + if _, err = c.Write(data); err != nil { + c.Close() + return nil, err + } + return +} + // Dial: addr should be in the form of host:port func Dial(addr, server string, cipher *Cipher) (c *Conn, err error) { ra, err := RawAddr(addr)