Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 7389afb

Browse files
committedNov 18, 2022
Compatibility with ReadDeadline
1 parent 4e7640f commit 7389afb

File tree

3 files changed

+154
-12
lines changed

3 files changed

+154
-12
lines changed
 

‎compatibility_read_deadline.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
package httpproxy
2+
3+
import (
4+
"net"
5+
"time"
6+
)
7+
8+
// aLongTimeAgo is a non-zero time, far in the past, used for
9+
// immediate cancellation of network operations.
10+
// copies from http
11+
var aLongTimeAgo = time.Unix(1, 0)
12+
13+
// NewListenerCompatibilityReadDeadline this is a wrapper used to be compatible with
14+
// the contents of ServerConn after wrapping it so that it can be hijacked properly.
15+
// there is no effect if the content is not manipulated.
16+
func NewListenerCompatibilityReadDeadline(listener net.Listener) net.Listener {
17+
return listenerCompatibilityReadDeadline{listener}
18+
}
19+
20+
type listenerCompatibilityReadDeadline struct {
21+
net.Listener
22+
}
23+
24+
func (w listenerCompatibilityReadDeadline) Accept() (net.Conn, error) {
25+
c, err := w.Listener.Accept()
26+
if err != nil {
27+
return nil, err
28+
}
29+
return connCompatibilityReadDeadline{c}, nil
30+
}
31+
32+
type connCompatibilityReadDeadline struct {
33+
net.Conn
34+
}
35+
36+
func (d connCompatibilityReadDeadline) SetReadDeadline(t time.Time) error {
37+
if aLongTimeAgo == t {
38+
t = time.Now().Add(time.Second)
39+
}
40+
return d.Conn.SetReadDeadline(t)
41+
}

‎compatibility_read_deadline_test.go

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
package httpproxy
2+
3+
import (
4+
"context"
5+
"encoding/hex"
6+
"fmt"
7+
"io"
8+
"net"
9+
"net/http"
10+
"net/http/httptest"
11+
"testing"
12+
)
13+
14+
func TestNewListenerCompatibilityReadDeadline(t *testing.T) {
15+
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
16+
fmt.Fprint(w, "check", r.RequestURI)
17+
}))
18+
19+
listener, err := net.Listen("tcp", ":0")
20+
if err != nil {
21+
t.Fatal(err)
22+
}
23+
listener = newHexListener(listener)
24+
listener = NewListenerCompatibilityReadDeadline(listener)
25+
26+
s, err := NewSimpleServer("http://u:p@:0")
27+
if err != nil {
28+
t.Fatal(err)
29+
}
30+
31+
s.Listener = listener
32+
s.Start(context.Background())
33+
defer s.Close()
34+
35+
dial, err := NewDialer(s.ProxyURL())
36+
if err != nil {
37+
t.Fatal(err)
38+
}
39+
dial.ProxyDial = func(ctx context.Context, network, address string) (net.Conn, error) {
40+
conn, err := net.Dial(network, address)
41+
if err != nil {
42+
return nil, err
43+
}
44+
conn = newHexConn(conn)
45+
return conn, nil
46+
}
47+
cli := testServer.Client()
48+
cli.Transport = &http.Transport{
49+
DialContext: dial.DialContext,
50+
}
51+
52+
resp, err := cli.Get(testServer.URL)
53+
if err != nil {
54+
t.Fatal(err)
55+
}
56+
resp.Body.Close()
57+
}
58+
59+
func newHexListener(listener net.Listener) net.Listener {
60+
return hexListener{
61+
Listener: listener,
62+
}
63+
}
64+
65+
type hexListener struct {
66+
net.Listener
67+
}
68+
69+
func (h hexListener) Accept() (net.Conn, error) {
70+
conn, err := h.Listener.Accept()
71+
if err != nil {
72+
return nil, err
73+
}
74+
return newHexConn(conn), nil
75+
}
76+
77+
func newHexConn(conn net.Conn) net.Conn {
78+
return hexConn{
79+
Conn: conn,
80+
r: hex.NewDecoder(conn),
81+
w: hex.NewEncoder(conn),
82+
}
83+
}
84+
85+
type hexConn struct {
86+
net.Conn
87+
r io.Reader
88+
w io.Writer
89+
}
90+
91+
func (h hexConn) Read(p []byte) (n int, err error) {
92+
return h.r.Read(p)
93+
}
94+
95+
func (h hexConn) Write(p []byte) (n int, err error) {
96+
return h.w.Write(p)
97+
}

‎simple_server.go

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -52,25 +52,29 @@ func NewSimpleServer(addr string) (*SimpleServer, error) {
5252
// Run the server
5353
func (s *SimpleServer) Run(ctx context.Context) error {
5454
var listenConfig net.ListenConfig
55-
listener, err := listenConfig.Listen(ctx, s.Network, s.Address)
56-
if err != nil {
57-
return err
55+
if s.Listener == nil {
56+
listener, err := listenConfig.Listen(ctx, s.Network, s.Address)
57+
if err != nil {
58+
return err
59+
}
60+
s.Listener = NewListenerCompatibilityReadDeadline(listener)
5861
}
59-
s.Listener = listener
60-
s.Address = listener.Addr().String()
61-
return s.Server.Serve(listener)
62+
s.Address = s.Listener.Addr().String()
63+
return s.Server.Serve(s.Listener)
6264
}
6365

6466
// Start the server
6567
func (s *SimpleServer) Start(ctx context.Context) error {
6668
var listenConfig net.ListenConfig
67-
listener, err := listenConfig.Listen(ctx, s.Network, s.Address)
68-
if err != nil {
69-
return err
69+
if s.Listener == nil {
70+
listener, err := listenConfig.Listen(ctx, s.Network, s.Address)
71+
if err != nil {
72+
return err
73+
}
74+
s.Listener = NewListenerCompatibilityReadDeadline(listener)
7075
}
71-
s.Listener = listener
72-
s.Address = listener.Addr().String()
73-
go s.Server.Serve(listener)
76+
s.Address = s.Listener.Addr().String()
77+
go s.Server.Serve(s.Listener)
7478
return nil
7579
}
7680

0 commit comments

Comments
 (0)
Please sign in to comment.