Skip to content

Commit b2e302f

Browse files
committed
Fix compatibility
1 parent 629c912 commit b2e302f

File tree

5 files changed

+59
-16
lines changed

5 files changed

+59
-16
lines changed

compatibility_read_deadline.go

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,20 @@ func (w listenerCompatibilityReadDeadline) Accept() (net.Conn, error) {
2626
if err != nil {
2727
return nil, err
2828
}
29-
return connCompatibilityReadDeadline{c}, nil
29+
return NewConnCompatibilityReadDeadline(c), nil
30+
}
31+
32+
// NewConnCompatibilityReadDeadline this is a wrapper used to be compatible with
33+
// the net.Conn after wrapping it so that it can be hijacked properly.
34+
// there is no effect if the content is not manipulated.
35+
func NewConnCompatibilityReadDeadline(conn net.Conn) net.Conn {
36+
if conn == nil {
37+
return nil
38+
}
39+
if conn, ok := conn.(connCompatibilityReadDeadline); ok {
40+
return conn
41+
}
42+
return connCompatibilityReadDeadline{conn}
3043
}
3144

3245
type connCompatibilityReadDeadline struct {
@@ -35,7 +48,7 @@ type connCompatibilityReadDeadline struct {
3548

3649
func (d connCompatibilityReadDeadline) SetReadDeadline(t time.Time) error {
3750
if aLongTimeAgo == t {
38-
t = time.Now().Add(time.Second)
51+
t = time.Now().Add(1 * time.Second)
3952
}
4053
return d.Conn.SetReadDeadline(t)
4154
}

go.mod

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,6 @@ module github.com/wzshiming/httpproxy
22

33
go 1.18
44

5-
require golang.org/x/net v0.2.0
5+
require golang.org/x/net v0.5.0
66

7-
require golang.org/x/text v0.4.0 // indirect
7+
require golang.org/x/text v0.6.0 // indirect

go.sum

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
golang.org/x/net v0.2.0 h1:sZfSu1wtKLGlWI4ZZayP0ck9Y73K1ynO6gqzTdBVdPU=
2-
golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY=
3-
golang.org/x/text v0.4.0 h1:BrVqGRd7+k1DiOgtnFvAkoQEWQvBc25ouMJM6429SFg=
4-
golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
1+
golang.org/x/net v0.5.0 h1:GyT4nK/YDHSqa1c4753ouYCDajOYKTja9Xb/OHtgvSw=
2+
golang.org/x/net v0.5.0/go.mod h1:DivGGAXEgPSlEBzxGzZI+ZLohi+xUj054jfeKui00ws=
3+
golang.org/x/text v0.6.0 h1:3XmdazWV+ubf7QgHSTWeykHOci5oeekaGJBLkrkaw4k=
4+
golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=

proxy.go

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package httpproxy
22

33
import (
4+
"bufio"
45
"context"
56
"fmt"
67
"io"
@@ -96,23 +97,22 @@ func (p *ProxyHandler) proxyConnect(w http.ResponseWriter, r *http.Request) {
9697
http.Error(w, e, http.StatusInternalServerError)
9798
return
9899
}
100+
defer targetConn.Close()
99101

100-
if flusher, ok := w.(http.Flusher); ok {
101-
flusher.Flush()
102-
} else {
103-
w.WriteHeader(http.StatusOK)
104-
}
102+
w.WriteHeader(http.StatusOK)
105103

106-
clientConn, _, err := hijacker.Hijack()
104+
conn, rw, err := hijacker.Hijack()
107105
if err != nil {
108-
e := err.Error()
106+
e := fmt.Sprintf("hijack failed: %v", err)
109107
if p.Logger != nil {
110108
p.Logger.Println(e)
111109
}
112110
http.Error(w, e, http.StatusInternalServerError)
113111
return
114112
}
115113

114+
clientConn := newBufConn(conn, rw)
115+
116116
var buf1, buf2 []byte
117117
if p.BytesPool != nil {
118118
buf1 = p.BytesPool.Get()
@@ -151,3 +151,34 @@ func (p *ProxyHandler) proxyDial(ctx context.Context, network, address string) (
151151
}
152152
return proxyDial(ctx, network, address)
153153
}
154+
155+
func newBufConn(conn net.Conn, rw *bufio.ReadWriter) net.Conn {
156+
rw.Flush()
157+
if rw.Reader.Buffered() == 0 {
158+
// If there's no buffered data to be read,
159+
// we can just discard the bufio.ReadWriter.
160+
return conn
161+
}
162+
return &bufConn{conn, rw.Reader}
163+
}
164+
165+
// bufConn wraps a net.Conn, but reads drain the bufio.Reader first.
166+
type bufConn struct {
167+
net.Conn
168+
*bufio.Reader
169+
}
170+
171+
func (c *bufConn) Read(p []byte) (int, error) {
172+
if c.Reader == nil {
173+
return c.Conn.Read(p)
174+
}
175+
n := c.Reader.Buffered()
176+
if n == 0 {
177+
c.Reader = nil
178+
return c.Conn.Read(p)
179+
}
180+
if n < len(p) {
181+
p = p[:n]
182+
}
183+
return c.Reader.Read(p)
184+
}

tunnel.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55
"io"
66
)
77

8-
98
// tunnel create tunnels for two io.ReadWriteCloser
109
func tunnel(ctx context.Context, c1, c2 io.ReadWriteCloser, buf1, buf2 []byte) error {
1110
ctx, cancel := context.WithCancel(ctx)

0 commit comments

Comments
 (0)